yq 5 kuukautta sitten
vanhempi
sitoutus
38cc2a7d9a

+ 2 - 1
.gitignore

@@ -57,4 +57,5 @@ docs/_build/
 
 # PyBuilder
 target/
-
+/.idea/
+/logs

+ 3 - 1
README.md

@@ -1,3 +1,5 @@
 # easy-ml
 
-模型自动化及监控自动化
+模型自动化及监控自动化
+
+环境 python3.8 及以上

+ 6 - 0
commom/__init__.py

@@ -0,0 +1,6 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2021/11/9
+@desc: 
+"""

+ 65 - 0
commom/logger.py

@@ -0,0 +1,65 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/8/29
+@desc:
+"""
+
+import datetime
+import logging
+import logging.handlers
+import os
+import threading
+import time
+from os.path import dirname, realpath
+
+import pytz
+
+from commom.traceId_util import TraceIdFilter
+
+
+def my_time(*args):
+    """Convert a log record's timestamp to a ``struct_time`` in Asia/Shanghai.
+
+    This function is installed as the formatter's ``converter``, so ``logging``
+    calls it with the record's creation time (seconds since the epoch) as the
+    first positional argument. Using that timestamp keeps ``%(asctime)s``
+    consistent with the record itself rather than with the moment the record
+    happens to be formatted.
+
+    :param args: optional; ``args[0]`` is the epoch timestamp to convert.
+        When it is missing or ``None`` the current time is used.
+    :return: ``time.struct_time`` expressed in Asia/Shanghai local time.
+    """
+    shanghai = pytz.timezone("Asia/Shanghai")
+    if args and args[0] is not None:
+        moment = datetime.datetime.fromtimestamp(args[0], shanghai)
+    else:
+        moment = datetime.datetime.now(shanghai)
+    return moment.timetuple()
+
+
+# Lock held by get_logger() while it builds and registers a new logger.
+_instance_lock = threading.Lock()
+# Already-configured loggers, keyed by logger name.
+logger_map = {}
+
+
+def get_logger(logger_name: str = None) -> logging.Logger:
+    """Return the logger registered under ``logger_name``, building it on first use.
+
+    A newly built logger is set to INFO, carries a ``TraceIdFilter`` and gets
+    two handlers sharing one formatter: a ``TimedRotatingFileHandler`` writing
+    ``logs/<logger_name>.log`` under the parent of this module's directory,
+    and a ``StreamHandler``. Built loggers are stored in ``logger_map``, so
+    repeated calls with the same name return the same object.
+
+    :param logger_name: cache key and log file stem; ``None`` means ``"app"``.
+    :return: the configured ``logging.Logger``.
+    """
+    if logger_name is None:
+        logger_name = "app"
+
+    # Fast path without taking the lock.
+    cached = logger_map.get(logger_name)
+    if cached is not None:
+        return cached
+
+    with _instance_lock:
+        # Re-check under the lock: another thread may have built it meanwhile.
+        cached = logger_map.get(logger_name)
+        if cached is not None:
+            return cached
+
+        _logger = logging.Logger(logger_name)
+        _logger.setLevel(logging.INFO)
+        # The filter supplies the requestId field used by the format string below.
+        _logger.addFilter(TraceIdFilter())
+
+        formatter = logging.Formatter(
+            '[%(asctime)s] [requestId-%(requestId)s] [%(levelname)s] [%(threadName)s] [%(filename)s] [func:%(funcName)s line:%(lineno)d]\n %(message)s')
+        # Render asctime through my_time instead of the default converter.
+        formatter.converter = my_time
+
+        log_path = os.path.join(dirname(dirname(realpath(__file__))), "logs")
+        filename = os.path.join(log_path, f"{logger_name}.log")
+
+        # exist_ok avoids the check-then-create race when several processes
+        # start at the same time and all find the directory missing.
+        os.makedirs(dirname(filename), exist_ok=True)
+        print(f"日志路径:{filename}")
+
+        # NOTE(review): with when="MIDNIGHT" the rollover moment is derived from
+        # atTime — confirm that interval=7 has the intended effect.
+        handler = logging.handlers.TimedRotatingFileHandler(filename, when="MIDNIGHT", interval=7, backupCount=4,
+                                                            encoding="utf8", atTime=datetime.time(0, 0, 0, 0))
+        handler.setFormatter(formatter)
+        _logger.addHandler(handler)
+
+        console_handler = logging.StreamHandler()
+        console_handler.setFormatter(formatter)
+        _logger.addHandler(console_handler)
+
+        logger_map[logger_name] = _logger
+
+        return _logger

+ 18 - 0
commom/traceId_util.py

@@ -0,0 +1,18 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2023/5/16
+@desc: 
+"""
+
+import logging
+from contextvars import ContextVar
+
+# Context variable holding the current request id. It is created without a
+# default, so get() with no argument raises LookupError until set() is called.
+request_id_context = ContextVar('request_id')
+
+
+class TraceIdFilter(logging.Filter):
+    """Logging filter that stamps each record with the current request id."""
+
+    def filter(self, record):
+        """Set ``record.requestId`` from ``request_id_context`` and keep the record.
+
+        :param record: the ``logging.LogRecord`` being processed.
+        :return: always ``True``; this filter never drops a record.
+        """
+        # The context variable has no default, so a bare get() raises
+        # LookupError for anything logged before a request id has been set
+        # (for example at start-up). Fall back to a placeholder instead.
+        record.requestId = request_id_context.get("none")
+        return True

+ 6 - 0
commom/utils.py

@@ -0,0 +1,6 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2023/12/28
+@desc: 
+"""

+ 8 - 0
config/__init__.py

@@ -0,0 +1,8 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 
+"""
+
+

+ 10 - 0
config/base_config.py

@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 
+"""
+
+
+class BaseConfig:
+    """Base configuration class; it declares no settings."""
+    # NOTE(review): model/model_route.py reads BaseConfig.classify_model_name,
+    # which is not defined here — confirm where that attribute should come from.

+ 10 - 0
dao/__init__.py

@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 数据库相关
+"""
+
+
+# Placeholder entry point: running this module directly does nothing.
+if __name__ == "__main__":
+    pass

+ 0 - 0
entitys/__init__.py


+ 33 - 0
entitys/response.py

@@ -0,0 +1,33 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2023/12/29
+@desc: 
+"""
+import pydantic
+from pydantic import BaseModel
+
+
+class BaseResponse(BaseModel):
+    """Uniform response envelope with ``data``, ``code``, ``msg`` and ``success``.
+
+    The defaults describe a successful, empty response (``data=None``,
+    ``code=200``, ``msg="success"``, ``success=True``). Use ``ofSuccess`` and
+    ``ofFailure`` to build the two standard variants.
+    """
+    # The description previously read "request id", which does not describe this field.
+    data: object = pydantic.Field(None, description="response payload")
+    code: int = pydantic.Field(200, description="HTTP status code")
+    msg: str = pydantic.Field("success", description="HTTP status message")
+    success: bool = pydantic.Field(True, description="success status")
+
+    class Config:
+        # Example payload attached to the generated schema.
+        schema_extra = {
+            "example": {
+                "data": None,
+                "code": 200,
+                "msg": "success",
+                "success": True
+            }
+        }
+
+    @staticmethod
+    def ofSuccess(data: object) -> "BaseResponse":
+        """Build a success response (code 200, ``success=True``) wrapping ``data``."""
+        return BaseResponse(data=data, code=200, msg="success", success=True)
+
+    @staticmethod
+    def ofFailure(msg: str = "error") -> "BaseResponse":
+        """Build a failure response (code 500, ``success=False``) carrying ``msg``."""
+        return BaseResponse(data=None, code=500, msg=msg, success=False)

+ 0 - 0
enums/__init__.py


+ 93 - 0
main.py

@@ -0,0 +1,93 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/8/29
+@desc:
+"""
+
+import argparse
+import json
+import time
+import traceback
+import uuid
+
+import uvicorn
+from fastapi import FastAPI, Body, Request
+from starlette.responses import RedirectResponse, StreamingResponse
+
+from commom.logger import get_logger
+from commom.traceId_util import request_id_context
+from entitys.response import BaseResponse
+from services.model_route import model_classify
+
+# Module-wide logger (get_logger() called without a name) and the FastAPI
+# application object that the middleware and routes below attach to.
+logger = get_logger()
+app = FastAPI()
+
+
+def predict(
+        question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
+):
+    """Classify ``question`` and return the post-processed classification.
+
+    Logs the question, runs ``model_classify.predict`` on it, logs the raw
+    result and returns ``model_classify.post_process`` of that result. Any
+    exception raised along the way is logged with its traceback and turned
+    into ``BaseResponse.ofFailure`` carrying the exception text.
+    """
+    try:
+        logger.info(f"{question=}")
+        classify = model_classify.predict(question)
+        logger.info(f"{classify=}")
+        processed = model_classify.post_process(classify)
+    except Exception as exc:
+        failure_trace = traceback.format_exc()
+        logger.error(failure_trace)
+        return BaseResponse.ofFailure(str(exc))
+    return processed
+
+
+def chat(
+        question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
+):
+    """Answer ``question`` through ``model_classify.chat``.
+
+    Logs the question and the answer and returns the answer wrapped in
+    ``BaseResponse.ofSuccess``. Any exception raised along the way is logged
+    with its traceback and turned into ``BaseResponse.ofFailure`` carrying the
+    exception text.
+    """
+    try:
+        logger.info(f"{question=}")
+        answer = model_classify.chat(question)
+        logger.info(f"{answer=}")
+        reply = BaseResponse.ofSuccess(answer)
+    except Exception as exc:
+        failure_trace = traceback.format_exc()
+        logger.error(failure_trace)
+        return BaseResponse.ofFailure(str(exc))
+    return reply
+
+async def streaming_test(
+        question: str = Body(..., embed=True, description="知识库问答",
+                             example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
+        requestId: str = Body(..., embed=True, description="", example=""),
+):
+    """Return a ``StreamingResponse`` (media type ``application/json``) fed by ``generate_data()``.
+
+    ``question`` and ``requestId`` are read from the request body but are not
+    used inside this function.
+    """
+    # NOTE(review): generate_data is neither defined nor imported in main.py as
+    # shown — confirm where it comes from; as written the call raises NameError.
+    return StreamingResponse(generate_data(), media_type="application/json")
+
+@app.middleware("http")
+async def add_request_id_header(request: Request, call_next):
+    """Attach a request id and the processing time to every HTTP response.
+
+    The id is taken from the incoming ``X-REQUEST-ID`` header, or generated
+    with ``uuid.uuid4()`` when the header is missing or empty. It is stored in
+    ``request_id_context`` before the request is handled. The response gets
+    ``X-REQUEST-ID`` and ``PROCESS-TIME`` (seconds, two decimals) headers.
+
+    :param request: the incoming request.
+    :param call_next: callable that passes the request on and returns the response.
+    :return: the response with the two headers added.
+    """
+    request_id = request.headers.get("X-REQUEST-ID")
+    if not request_id:
+        request_id = str(uuid.uuid4())
+    request_id_context.set(request_id)
+
+    # perf_counter is monotonic; time.time() can jump when the system clock is
+    # adjusted, which would distort the measured duration.
+    start_time = time.perf_counter()
+    response = await call_next(request)
+    process_time = time.perf_counter() - start_time
+
+    response.headers["X-REQUEST-ID"] = request_id_context.get()
+    response.headers["PROCESS-TIME"] = f"{process_time:.2f}"
+
+    return response
+
+
+async def document():
+    """Redirect the caller to ``/docs``."""
+    docs_redirect = RedirectResponse(url="/docs")
+    return docs_redirect
+
+
+def api_start(host, port):
+    """Register the routes on ``app`` and run it with uvicorn.
+
+    :param host: address passed to ``uvicorn.run``.
+    :param port: port passed to ``uvicorn.run``.
+    """
+    # (registrar, path, endpoint) in registration order; every route declares
+    # BaseResponse as its response model.
+    route_table = (
+        (app.get, "/", document),
+        (app.post, "/classify/predict", predict),
+        (app.post, "/chat", chat),
+        (app.post, "/streaming_test", streaming_test),
+    )
+    for register, path, endpoint in route_table:
+        register(path, response_model=BaseResponse)(endpoint)
+
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    # Command-line entry point: --host and --port are handed to api_start().
+    cli = argparse.ArgumentParser()
+    cli.add_argument("--host", default="0.0.0.0", type=str)
+    cli.add_argument("--port", default=18080, type=int)
+    options = cli.parse_args()
+    api_start(options.host, options.port)

+ 9 - 0
metric/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 指标计算相关
+"""
+
+# Placeholder entry point: running this module directly does nothing.
+if __name__ == "__main__":
+    pass

+ 6 - 0
model/__init__.py

@@ -0,0 +1,6 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2023/12/28
+@desc: 模型相关
+"""

+ 18 - 0
model/model_base.py

@@ -0,0 +1,18 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: 
+"""
+import pandas as pd
+import abc
+
+class ModelBase(metaclass=abc.ABCMeta):
+    """Abstract base class for models: subclasses must implement ``train`` and ``predict``."""
+
+    @abc.abstractmethod
+    def train(self, data: pd.DataFrame):
+        """Train on ``data``; must be overridden by concrete subclasses."""
+
+    @abc.abstractmethod
+    def predict(self, x: pd.DataFrame):
+        """Predict for ``x``; must be overridden by concrete subclasses."""

+ 15 - 0
model/model_route.py

@@ -0,0 +1,15 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: 
+"""
+from config.base_config import BaseConfig
+
+# Registry mapping a configured model name to the class implementing it.
+# NOTE(review): ModelQwen is neither imported nor defined in this module as
+# shown — confirm the missing import.
+model_route_map = {"lr": ModelQwen}
+
+# NOTE(review): BaseConfig in config/base_config.py declares no
+# classify_model_name in the visible source — confirm where it is set.
+model_clazz = model_route_map.get(BaseConfig.classify_model_name)
+
+# Explicit check instead of ``assert`` so the validation still runs under
+# ``python -O``; AssertionError is kept so the failure type does not change.
+if model_clazz is None:
+    raise AssertionError(f"模型【{BaseConfig.classify_model_name}】不存在")
+
+# Single shared model instance, created at import time.
+model_classify = model_clazz()

+ 9 - 0
monitor/__init__.py

@@ -0,0 +1,9 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2022/10/24
+@desc: 指标监控
+"""
+
+# Placeholder entry point: running this module directly does nothing.
+if __name__ == "__main__":
+    pass

+ 2 - 0
requirements.txt

@@ -0,0 +1,2 @@
+# NOTE(review): the modules added in this change import fastapi, uvicorn,
+# starlette, pydantic, pytz and pandas, none of which are listed here, while
+# psutil is not imported by any file shown — confirm the dependency list.
+psutil==5.9.5
+

+ 31 - 0
start.sh

@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Activate the chatglm2 environment (presumably a conda env — confirm).
+source activate chatglm2
+
+# Directory the script is started from; main.py and nohup.out are resolved against it.
+PATH_APP=$(pwd)
+
+function get_pid() {
+  # Collect the PIDs (second column of `ps -ef`) of processes whose command
+  # line contains "python $PATH_APP/main.py", excluding the grep itself.
+  # The result is stored in the global APP_PID.
+  APP_PID=$(
+    ps -ef |
+      grep "python $PATH_APP/main.py" |
+      grep -v grep |
+      awk '{print $2}'
+  )
+}
+
+function kill_app() {
+  # Kill every PID that get_pid stored in APP_PID.
+  # The variable is quoted in the test: unquoted, an empty value made `[ -n ]`
+  # succeed and several PIDs made `[` fail with a syntax error, so the kill
+  # loop was skipped exactly when more than one process was running.
+  if [ -n "$APP_PID" ]; then
+    # Left unquoted on purpose so the list splits into one PID per iteration.
+    for v in $APP_PID; do
+      echo $(date +%F%n%T) "开始杀死已有进程: $v"
+      kill -9 $v
+    done
+  fi
+}
+
+function start_app() {
+  # Start main.py in the background with stdout/stderr redirected to nohup.out,
+  # wait three seconds, then show the tail of the log.
+  echo $(date +%F%n%T) "开始启动model-api-classify..."
+  # Paths are quoted so a working directory containing spaces still works.
+  PYTHONIOENCODING=utf-8 nohup python "$PATH_APP/main.py" > "$PATH_APP/nohup.out" 2>&1 &
+  sleep 3
+  # Print the log lines as they are; `echo $(tail ...)` joined them onto one
+  # line and let the shell glob-expand characters such as `*` in the log text.
+  tail -50 "$PATH_APP/nohup.out"
+  echo "启动完成..."
+  echo "日志请查看 $PATH_APP/nohup.out"
+}
+
+# Restart sequence: find running instances, kill them, then start a new one.
+get_pid
+kill_app
+start_app