|
@@ -0,0 +1,93 @@
|
|
|
+# -*- coding:utf-8 -*-
|
|
|
+"""
|
|
|
+@author: yq
|
|
|
+@time: 2022/8/29
|
|
|
+@desc:
|
|
|
+"""
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import time
|
|
|
+import traceback
|
|
|
+import uuid
|
|
|
+
|
|
|
+import uvicorn
|
|
|
+from fastapi import FastAPI, Body, Request
|
|
|
+from starlette.responses import RedirectResponse, StreamingResponse
|
|
|
+
|
|
|
+from commom.logger import get_logger
|
|
|
+from commom.traceId_util import request_id_context
|
|
|
+from entitys.response import BaseResponse
|
|
|
+from services.model_route import model_classify
|
|
|
+
|
|
|
+logger = get_logger()
|
|
|
+app = FastAPI()
|
|
|
+
|
|
|
+
|
|
|
+def predict(
|
|
|
+ question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
|
|
|
+):
|
|
|
+ try:
|
|
|
+ logger.info(f"{question=}")
|
|
|
+ classify = model_classify.predict(question)
|
|
|
+ logger.info(f"{classify=}")
|
|
|
+ return model_classify.post_process(classify)
|
|
|
+
|
|
|
+ except Exception as msg:
|
|
|
+ logger.error(traceback.format_exc())
|
|
|
+ return BaseResponse.ofFailure(str(msg))
|
|
|
+
|
|
|
+
|
|
|
+def chat(
|
|
|
+ question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
|
|
|
+):
|
|
|
+ try:
|
|
|
+ logger.info(f"{question=}")
|
|
|
+ answer = model_classify.chat(question)
|
|
|
+ logger.info(f"{answer=}")
|
|
|
+ return BaseResponse.ofSuccess(answer)
|
|
|
+
|
|
|
+ except Exception as msg:
|
|
|
+ logger.error(traceback.format_exc())
|
|
|
+ return BaseResponse.ofFailure(str(msg))
|
|
|
+
|
|
|
+async def streaming_test(
|
|
|
+ question: str = Body(..., embed=True, description="知识库问答",
|
|
|
+ example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
|
|
|
+ requestId: str = Body(..., embed=True, description="", example=""),
|
|
|
+):
|
|
|
+ return StreamingResponse(generate_data(), media_type="application/json")
|
|
|
+
|
|
|
+@app.middleware("http")
|
|
|
+async def add_request_id_header(request: Request, call_next):
|
|
|
+ request_id = request.headers.get("X-REQUEST-ID")
|
|
|
+ if request_id is None or len(request_id) == 0:
|
|
|
+ request_id = str(uuid.uuid4())
|
|
|
+ request_id_context.set(request_id)
|
|
|
+ start_time = time.time()
|
|
|
+ response = await call_next(request)
|
|
|
+ process_time = time.time() - start_time
|
|
|
+ response.headers["X-REQUEST-ID"] = request_id_context.get()
|
|
|
+ response.headers["PROCESS-TIME"] = f"{process_time:.2f}"
|
|
|
+
|
|
|
+ return response
|
|
|
+
|
|
|
+
|
|
|
+async def document():
|
|
|
+ return RedirectResponse(url="/docs")
|
|
|
+
|
|
|
+
|
|
|
+def api_start(host, port):
|
|
|
+ app.get("/", response_model=BaseResponse)(document)
|
|
|
+ app.post("/classify/predict", response_model=BaseResponse)(predict)
|
|
|
+ app.post("/chat", response_model=BaseResponse)(chat)
|
|
|
+ app.post("/streaming_test", response_model=BaseResponse)(streaming_test)
|
|
|
+ uvicorn.run(app, host=host, port=port)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+ parser.add_argument("--host", type=str, default="0.0.0.0")
|
|
|
+ parser.add_argument("--port", type=int, default=18080)
|
|
|
+ args = parser.parse_args()
|
|
|
+ api_start(args.host, args.port)
|