# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2025/2/27
@desc: Logistic-regression module driven by an externally supplied weight.
"""
import torch.nn as nn


class LR(nn.Module):
    """Bias-free linear layer followed by a sigmoid.

    The linear layer's weight is not learned from a fresh initialisation:
    it is replaced by the ``weight`` parameter handed to the constructor,
    so the module and the caller share the very same ``nn.Parameter``.
    """

    def __init__(self, weight: nn.Parameter) -> None:
        """Build the module around ``weight``.

        Args:
            weight: Parameter used as the weight of the internal linear
                layer. It must be an ``nn.Parameter``; ``nn.Module``
                refuses to assign a plain tensor to a registered
                parameter slot.
                Presumably a 1-D vector with one entry per input
                feature -- TODO confirm against callers.
        """
        super().__init__()
        # The layer is sized from weight.shape[0]; its freshly initialised
        # weight is discarded on the next line.
        # NOTE(review): for a 2-D weight of shape (1, n_features),
        # weight.shape[0] is 1, so ``linear.in_features`` reads 1 even
        # though the forward pass uses all n_features columns -- confirm
        # the expected weight shape.
        self.linear = nn.Linear(weight.shape[0], 1, bias=False)
        # Re-register the caller's parameter as the layer weight (same
        # object, no copy).
        self.linear.weight = weight
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(linear(x))``.

        Args:
            x: Input tensor whose last dimension matches the feature
                dimension of the weight.

        Returns:
            Tensor of values in (0, 1). With a 1-D weight the feature
            dimension is reduced away; with a 2-D ``(1, n)`` weight a
            trailing dimension of size 1 is kept.
        """
        return self.sigmoid(self.linear(x))


if __name__ == "__main__":
    pass