main.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # -*- coding:utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2022/8/29
  5. @desc:
  6. """
  7. import argparse
  8. import json
  9. import time
  10. import traceback
  11. import uuid
  12. import uvicorn
  13. from fastapi import FastAPI, Body, Request
  14. from starlette.responses import RedirectResponse, StreamingResponse
  15. from commom.logger import get_logger
  16. from commom.traceId_util import request_id_context
  17. from entitys.response import BaseResponse
  18. from services.model_route import model_classify
  19. logger = get_logger()
  20. app = FastAPI()
  21. def predict(
  22. question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
  23. ):
  24. try:
  25. logger.info(f"{question=}")
  26. classify = model_classify.predict(question)
  27. logger.info(f"{classify=}")
  28. return model_classify.post_process(classify)
  29. except Exception as msg:
  30. logger.error(traceback.format_exc())
  31. return BaseResponse.ofFailure(str(msg))
  32. def chat(
  33. question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
  34. ):
  35. try:
  36. logger.info(f"{question=}")
  37. answer = model_classify.chat(question)
  38. logger.info(f"{answer=}")
  39. return BaseResponse.ofSuccess(answer)
  40. except Exception as msg:
  41. logger.error(traceback.format_exc())
  42. return BaseResponse.ofFailure(str(msg))
  43. async def streaming_test(
  44. question: str = Body(..., embed=True, description="知识库问答",
  45. example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
  46. requestId: str = Body(..., embed=True, description="", example=""),
  47. ):
  48. return StreamingResponse(generate_data(), media_type="application/json")
  49. @app.middleware("http")
  50. async def add_request_id_header(request: Request, call_next):
  51. request_id = request.headers.get("X-REQUEST-ID")
  52. if request_id is None or len(request_id) == 0:
  53. request_id = str(uuid.uuid4())
  54. request_id_context.set(request_id)
  55. start_time = time.time()
  56. response = await call_next(request)
  57. process_time = time.time() - start_time
  58. response.headers["X-REQUEST-ID"] = request_id_context.get()
  59. response.headers["PROCESS-TIME"] = f"{process_time:.2f}"
  60. return response
  61. async def document():
  62. return RedirectResponse(url="/docs")
  63. def api_start(host, port):
  64. app.get("/", response_model=BaseResponse)(document)
  65. app.post("/classify/predict", response_model=BaseResponse)(predict)
  66. app.post("/chat", response_model=BaseResponse)(chat)
  67. app.post("/streaming_test", response_model=BaseResponse)(streaming_test)
  68. uvicorn.run(app, host=host, port=port)
  69. if __name__ == "__main__":
  70. parser = argparse.ArgumentParser()
  71. parser.add_argument("--host", type=str, default="0.0.0.0")
  72. parser.add_argument("--port", type=int, default=18080)
  73. args = parser.parse_args()
  74. api_start(args.host, args.port)