|
@@ -7,7 +7,7 @@
|
|
|
import pandas as pd
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
|
|
|
|
-from entitys import DataFeatureEntity
|
|
|
+from entitys import DataFeatureEntity, MetricTrainEntity
|
|
|
from model_base import ModelBase
|
|
|
|
|
|
|
|
@@ -15,8 +15,9 @@ class ModelLr(ModelBase):
|
|
|
def __init__(self, ):
|
|
|
self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
|
|
|
|
|
|
- def train(self, data: DataFeatureEntity, *args, **kwargs):
|
|
|
+ def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
|
|
|
self.lr.fit(data.get_Xdata(), data.get_Ydata())
|
|
|
+ return MetricTrainEntity(0.7, 0.4)
|
|
|
|
|
|
def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
|
|
|
return self.lr.predict_proba(x)[:, 1]
|