Browse Source

del: 删除无用代码

yq 3 months ago
parent
commit
2a92c9ca74
3 changed files with 6 additions and 118 deletions
  1. 6 1
      __init__.py
  2. 0 93
      main.py
  3. 0 24
      strategy_test1.py

+ 6 - 1
__init__.py

@@ -7,6 +7,10 @@
 import sys
 from os.path import dirname, realpath
 
+from feature import FilterStrategyFactory
+from model import ModelFactory
+from trainer import TrainPipeline
+
 sys.path.append(dirname(realpath(__file__)))
 
 from data import DataLoaderMysql
@@ -14,4 +18,5 @@ from entitys import DbConfigEntity
 from monitor import MonitorMetric
 from metrics import MetricBase
 
-__all__ = ['MonitorMetric', 'DataLoaderMysql', 'DbConfigEntity', 'MetricBase']
+__all__ = ['MonitorMetric', 'DataLoaderMysql', 'DbConfigEntity', 'MetricBase', 'FilterStrategyFactory', 'ModelFactory',
+           'TrainPipeline']

+ 0 - 93
main.py

@@ -1,93 +0,0 @@
-# -*- coding:utf-8 -*-
-"""
-@author: yq
-@time: 2022/8/29
-@desc:
-"""
-
-import argparse
-import json
-import time
-import traceback
-import uuid
-
-import uvicorn
-from fastapi import FastAPI, Body, Request
-from starlette.responses import RedirectResponse, StreamingResponse
-
-from commom.logger import get_logger
-from commom.traceId_util import request_id_context
-from entitys.response import BaseResponse
-from services.model_route import model_classify
-
-logger = get_logger()
-app = FastAPI()
-
-
-def predict(
-        question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
-):
-    try:
-        logger.info(f"{question=}")
-        classify = model_classify.predict(question)
-        logger.info(f"{classify=}")
-        return model_classify.post_process(classify)
-
-    except Exception as msg:
-        logger.error(traceback.format_exc())
-        return BaseResponse.ofFailure(str(msg))
-
-
-def chat(
-        question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
-):
-    try:
-        logger.info(f"{question=}")
-        answer = model_classify.chat(question)
-        logger.info(f"{answer=}")
-        return BaseResponse.ofSuccess(answer)
-
-    except Exception as msg:
-        logger.error(traceback.format_exc())
-        return BaseResponse.ofFailure(str(msg))
-
-async def streaming_test(
-        question: str = Body(..., embed=True, description="知识库问答",
-                             example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
-        requestId: str = Body(..., embed=True, description="", example=""),
-):
-    return StreamingResponse(generate_data(), media_type="application/json")
-
-@app.middleware("http")
-async def add_request_id_header(request: Request, call_next):
-    request_id = request.headers.get("X-REQUEST-ID")
-    if request_id is None or len(request_id) == 0:
-        request_id = str(uuid.uuid4())
-    request_id_context.set(request_id)
-    start_time = time.time()
-    response = await call_next(request)
-    process_time = time.time() - start_time
-    response.headers["X-REQUEST-ID"] = request_id_context.get()
-    response.headers["PROCESS-TIME"] = f"{process_time:.2f}"
-
-    return response
-
-
-async def document():
-    return RedirectResponse(url="/docs")
-
-
-def api_start(host, port):
-    app.get("/", response_model=BaseResponse)(document)
-    app.post("/classify/predict", response_model=BaseResponse)(predict)
-    app.post("/chat", response_model=BaseResponse)(chat)
-    app.post("/streaming_test", response_model=BaseResponse)(streaming_test)
-    uvicorn.run(app, host=host, port=port)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="0.0.0.0")
-    parser.add_argument("--port", type=int, default=18080)
-    args = parser.parse_args()
-    api_start(args.host, args.port)

+ 0 - 24
strategy_test1.py

@@ -1,24 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-@author: yq
-@time: 2024/11/1
-@desc: 
-"""
-import time
-
-from entitys import DataSplitEntity, DataProcessConfigEntity
-from feature import FilterStrategyFactory
-from feature.strategy_iv import StrategyIv
-
-if __name__ == "__main__":
-    time_now = time.time()
-    import scorecardpy as sc
-    dat = sc.germancredit()
-    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
-    data = DataSplitEntity(dat[:700], None, dat[700:])
-    filter_strategy_factory= FilterStrategyFactory(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
-    strategy = filter_strategy_factory.get_strategy()
-    candidate_feature = strategy.filter(data)
-    candidate_feature = strategy.feature_generate(data, candidate_feature)
-
-    print(time.time() - time_now)