12345678910111213141516171819202122 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2025/2/27
- @desc:
- """
- import torch.nn as nn
class LR(nn.Module):
    """Bias-free logistic-regression head computing ``sigmoid(x @ weight.T)``.

    The supplied ``weight`` is installed directly as the weight of the inner
    ``nn.Linear`` layer. The module therefore uses (and trains) that exact
    parameter object rather than a copy.

    Args:
        weight: Coefficients, either 1-D with shape ``(in_features,)`` or 2-D
            with shape ``(1, in_features)`` (the ``nn.Linear`` weight layout).
            An ``nn.Parameter`` is used as-is. A plain tensor is wrapped in
            ``nn.Parameter``, which shares storage with it.
    """

    def __init__(self, weight: nn.Parameter):
        super().__init__()
        # nn.Linear stores its weight as (out_features, in_features), so the
        # input size is the LAST dimension. Reading shape[0] reported
        # in_features=1 for a (1, d) weight. For 1-D weights both agree.
        in_features = weight.shape[-1]
        self.linear = nn.Linear(in_features, 1, bias=False)
        if not isinstance(weight, nn.Parameter):
            # nn.Module rejects assigning a plain Tensor over a registered
            # parameter (TypeError), so wrap it first.
            weight = nn.Parameter(weight)
        # Replace the randomly initialised weight created by nn.Linear with
        # the caller's parameter.
        self.linear.weight = weight
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(x @ weight.T)``.

        Args:
            x: Input tensor whose last dimension equals ``in_features``.

        Returns:
            Sigmoid outputs. The shape is ``x.shape[:-1] + (1,)`` for a 2-D
            ``(1, in_features)`` weight and ``x.shape[:-1]`` for a 1-D weight.
        """
        return self.sigmoid(self.linear(x))
if __name__ == "__main__":
    # Intentionally empty: this module only defines the LR model and has no
    # command-line behaviour when executed directly.
    ...
|