# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2022/8/29
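@desc: FastAPI HTTP service exposing question classification (/classify/predict), chat (/chat) and a streaming test endpoint (/streaming_test).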
"""

import argparse
import json
import time
import traceback
import uuid

import uvicorn
from fastapi import FastAPI, Body, Request
from starlette.responses import RedirectResponse, StreamingResponse

from commom.logger import get_logger
from commom.traceId_util import request_id_context
from entitys.response import BaseResponse
from services.model_route import model_classify

# Logger obtained from the project's commom.logger module.
logger = get_logger()
# ASGI application; the HTTP middleware below is attached at import time, while the
# routes are registered on it in api_start() just before uvicorn starts serving.
app = FastAPI()


def predict(
        question: str = Body(..., embed=True, description="question", example="人保最近半年收入多少?")
):
    """Classify ``question`` with the routing model and return the post-processed result.

    :param question: user question, read from the JSON body field ``question``.
    :return: whatever ``model_classify.post_process`` produces for the raw
        classification, or ``BaseResponse.ofFailure(<exception text>)`` if any
        step (prediction or post-processing) raises.
    """
    try:
        # NOTE: the f-strings use the `{name=}` form, so the variable names
        # `question` / `classify` are part of the log output.
        logger.info(f"{question=}")
        classify = model_classify.predict(question)
        logger.info(f"{classify=}")
        outcome = model_classify.post_process(classify)
    except Exception as exc:
        logger.error(traceback.format_exc())
        return BaseResponse.ofFailure(str(exc))
    return outcome


def chat(
        question: str = Body(..., embed=True, description="question", example="为什么会下雨?")
):
    """Answer a free-form ``question`` through ``model_classify.chat``.

    :param question: user question, read from the JSON body field ``question``.
    :return: ``BaseResponse.ofSuccess(answer)`` on success, otherwise
        ``BaseResponse.ofFailure(<exception text>)``; the traceback is logged.
    """
    try:
        # `{name=}` f-strings: the names `question` / `answer` appear in the log.
        logger.info(f"{question=}")
        answer = model_classify.chat(question)
        logger.info(f"{answer=}")
        # Wrapping stays inside the try so a failure here is also reported as ofFailure.
        reply = BaseResponse.ofSuccess(answer)
    except Exception as exc:
        logger.error(traceback.format_exc())
        return BaseResponse.ofFailure(str(exc))
    return reply

async def streaming_test(
        question: str = Body(..., embed=True, description="知识库问答",
                             example="金融机构、非银行支付机构应当履行下列防范非法集资的义务的具体条款是?"),
        requestId: str = Body(..., embed=True, description="", example=""),
):
    """Stream a JSON echo of the request back to the client in small chunks.

    Test endpoint for chunked/streamed responses. The full response body is
    the single JSON document ``{"requestId": ..., "question": ...}``, with
    non-ASCII text kept as-is. It is delivered ``chunk_size`` characters at a
    time so clients can check that they handle incremental delivery.

    :param question: text echoed back in the payload.
    :param requestId: caller-supplied id echoed back in the payload.
    :return: a ``StreamingResponse`` with media type ``application/json``.
    """
    payload = json.dumps({"requestId": requestId, "question": question}, ensure_ascii=False)
    chunk_size = 16

    def generate_data():
        # Slice the serialized text (not the dict) so the concatenated chunks
        # form one valid JSON document, matching the declared media type.
        for start in range(0, len(payload), chunk_size):
            yield payload[start:start + chunk_size]

    return StreamingResponse(generate_data(), media_type="application/json")

@app.middleware("http")
async def add_request_id_header(request: Request, call_next):
    """Attach a request id and the processing time to every HTTP response.

    The id comes from the incoming ``X-REQUEST-ID`` header. When that header is
    missing or empty, a random UUID4 is generated instead. The id is stored in
    ``request_id_context`` (presumably read by the logging/trace utilities for
    correlation; confirm in commom.traceId_util) and echoed back in the
    ``X-REQUEST-ID`` response header.

    ``PROCESS-TIME`` is the time spent in ``call_next``, in seconds, formatted
    with two decimals.
    """
    request_id = request.headers.get("X-REQUEST-ID")
    if not request_id:
        request_id = str(uuid.uuid4())
    request_id_context.set(request_id)

    # perf_counter is monotonic: unlike time.time() it cannot jump or go
    # backwards when the system clock is adjusted, so the duration is reliable.
    start = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - start

    response.headers["X-REQUEST-ID"] = request_id
    response.headers["PROCESS-TIME"] = f"{elapsed:.2f}"

    return response


async def document():
    """Send requests for the service root to FastAPI's interactive docs page at ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)


def api_start(host, port):
    """Register the HTTP routes on ``app`` and serve it with uvicorn.

    Every route declares ``BaseResponse`` as its response model.

    :param host: interface address to bind (the CLI default is ``"0.0.0.0"``).
    :param port: TCP port to listen on.
    """
    # (decorator factory, path, endpoint), kept in the original registration order.
    route_table = (
        (app.get, "/", document),
        (app.post, "/classify/predict", predict),
        (app.post, "/chat", chat),
        (app.post, "/streaming_test", streaming_test),
    )
    for register, path, endpoint in route_table:
        register(path, response_model=BaseResponse)(endpoint)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    # Command-line entry point: python <this file> [--host HOST] [--port PORT]
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=18080)
    options = cli.parse_args()
    api_start(host=options.host, port=options.port)