utils.py 440 B

12345678910111213141516171819202122
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2025/2/27
  5. @desc:
  6. """
  7. import torch.nn as nn
  class LR(nn.Module):
      """Logistic regression using an externally supplied weight Parameter.

      Computes ``sigmoid(linear(x))`` with no bias term. The given ``weight``
      object itself is installed as the inner ``nn.Linear`` layer's weight,
      so the caller's Parameter is shared rather than copied.

      Args:
          weight: Weight Parameter, either 1-D with shape ``(in_features,)``
              or 2-D in ``nn.Linear`` layout ``(out_features, in_features)``.
              Must be an ``nn.Parameter``; ``nn.Module`` rejects plain tensors
              when assigning to a registered parameter slot.
      """

      def __init__(self, weight: nn.Parameter):
          super().__init__()
          # nn.Linear stores its weight as (out_features, in_features), so the
          # input width is always the LAST axis. The previous shape[0] picked
          # the output axis for 2-D weights and mis-recorded in_features.
          in_features = weight.shape[-1]
          # A 1-D weight yields a single (squeezed) output per sample.
          out_features = weight.shape[0] if weight.dim() > 1 else 1
          self.linear = nn.Linear(in_features, out_features, bias=False)
          # Replace the freshly initialised weight with the caller's Parameter.
          self.linear.weight = weight
          self.sigmoid = nn.Sigmoid()

      def forward(self, x):
          """Return ``sigmoid(x @ weight.T)`` (no bias) for input ``x``."""
          return self.sigmoid(self.linear(x))
  if __name__ == "__main__":
      # Intentionally empty: running this file directly performs no action.
      ...