|
@@ -14,11 +14,11 @@ class TrainPipeline():
|
|
|
|
|
|
def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
|
|
|
self.model.train(train_data)
|
|
|
- self.model.predict_prob(test_data.data)
|
|
|
+ self.model.predict_prob(test_data.get_Xdata())
|
|
|
|
|
|
def generate_report(self):
|
|
|
pass
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- pass
|
|
|
+ pass
|