# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2022/8/29
@desc: FastAPI service exposing question classification (/classify/predict),
       chat (/chat) and a streaming test endpoint (/streaming_test), plus an
       HTTP middleware that propagates a per-request id and reports the
       processing time in response headers.
"""
import argparse
import json
import time
import traceback
import uuid

import uvicorn
from fastapi import FastAPI, Body, Request
from starlette.responses import RedirectResponse, StreamingResponse

from commom.logger import get_logger
from commom.traceId_util import request_id_context
from entitys.response import BaseResponse
from services.model_route import model_classify

logger = get_logger()
app = FastAPI()


def predict(
        question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
):
    """Classify ``question`` with ``model_classify`` and return the post-processed result.

    Any exception is logged with its traceback and converted into a
    ``BaseResponse.ofFailure`` carrying the exception message.
    """
    try:
        logger.info(f"{question=}")
        classify = model_classify.predict(question)
        logger.info(f"{classify=}")
        return model_classify.post_process(classify)
    except Exception as msg:
        logger.error(traceback.format_exc())
        return BaseResponse.ofFailure(str(msg))


def chat(
        question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
):
    """Answer ``question`` through ``model_classify.chat``, wrapped in ``BaseResponse.ofSuccess``.

    Any exception is logged with its traceback and converted into a
    ``BaseResponse.ofFailure`` carrying the exception message.
    """
    try:
        logger.info(f"{question=}")
        answer = model_classify.chat(question)
        logger.info(f"{answer=}")
        return BaseResponse.ofSuccess(answer)
    except Exception as msg:
        logger.error(traceback.format_exc())
        return BaseResponse.ofFailure(str(msg))


def generate_data(question: str = "", request_id: str = ""):
    """Yield newline-delimited JSON records that echo ``question``.

    One record is emitted per character of ``question``. A final record with
    ``"finished": true`` follows, so clients can detect the end of the stream
    even when ``question`` is empty. ``ensure_ascii=False`` keeps non-ASCII
    (e.g. Chinese) text readable on the wire.

    NOTE(review): the original module called ``generate_data()`` without ever
    defining it, so /streaming_test always failed with NameError. This echo
    generator is a placeholder for exercising streaming; replace it with the
    real streaming source once one exists.
    """
    for index, char in enumerate(question):
        yield json.dumps({"requestId": request_id, "index": index, "data": char}, ensure_ascii=False) + "\n"
    yield json.dumps({"requestId": request_id, "finished": True}, ensure_ascii=False) + "\n"


async def streaming_test(
        question: str = Body(..., embed=True, description="知识库问答", example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
        requestId: str = Body(..., embed=True, description="", example=""),
):
    """Stream ``generate_data`` output for ``question`` / ``requestId`` as a chunked response."""
    # Returning a Response directly bypasses FastAPI's response_model validation.
    return StreamingResponse(generate_data(question, requestId), media_type="application/json")


@app.middleware("http")
async def add_request_id_header(request: Request, call_next):
    """Attach a request id and processing time to every HTTP response.

    The id comes from the incoming ``X-REQUEST-ID`` header. When that header
    is absent or empty, a fresh UUID4 is generated. The id is stored in
    ``request_id_context`` before the request is handled and is echoed back in
    the ``X-REQUEST-ID`` response header. ``PROCESS-TIME`` holds the handling
    duration in seconds, formatted with two decimals.
    """
    request_id = request.headers.get("X-REQUEST-ID")
    if not request_id:
        request_id = str(uuid.uuid4())
    request_id_context.set(request_id)
    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump if the system wall clock is adjusted mid-request.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-REQUEST-ID"] = request_id_context.get()
    response.headers["PROCESS-TIME"] = f"{process_time:.2f}"
    return response


async def document():
    """Redirect the root URL to the interactive OpenAPI docs at ``/docs``."""
    return RedirectResponse(url="/docs")


def api_start(host, port):
    """Register all routes on ``app`` and serve it with uvicorn (blocking call).

    NOTE: routes are registered here rather than at import time. Importing
    ``app`` from another module, without calling this function, yields an app
    with only the middleware attached.
    """
    app.get("/", response_model=BaseResponse)(document)
    app.post("/classify/predict", response_model=BaseResponse)(predict)
    app.post("/chat", response_model=BaseResponse)(chat)
    app.post("/streaming_test", response_model=BaseResponse)(streaming_test)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=18080)
    args = parser.parse_args()
    api_start(args.host, args.port)