You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
6.1 KiB
Python

# coding=gbk
import yaml
import sys
import os
import time
import uuid
import json
import shutil
import logging
from collections import deque
from pydantic import BaseModel
from fastapi import BackgroundTasks
from fastapi import FastAPI, HTTPException
from qa_Ask import QAService, match_query, store_data
app = FastAPI()
# Log INFO and above both to log/app.log and to the console (stdout).
# logging.FileHandler does not create missing parent directories, so the
# "log" directory is created first; otherwise importing this module fails
# with FileNotFoundError on a fresh checkout.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII (e.g. Chinese) request text can always
        # be written, independent of the platform's locale encoding.
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console output
    ],
)
logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel):
    """Request body of ``POST /matchQuestion``."""

    # Query text to match against the newest knowledge base.
    question: str
    # Score threshold passed through to match_query.
    scoreThreshold: float
class QuestionResponse(BaseModel):
    """Response body of ``POST /matchQuestion``."""

    # Status code chosen by the handler: 200 on success, 500 on failure.
    code: int
    # Status message.
    msg: str
    # Matches returned by match_query (an empty list on failure).
    data: list
class QuestionItem(BaseModel):
    """One element of the ``POST /updateDatabase`` request body."""

    # Identifier of this entry.
    questionId: str
    # Question texts stored under questionId.
    questionList: list[str]
class InputText(BaseModel):
    """Request body of ``POST /extractInformation``."""

    # Free text from which a name, a card number and an ID number are extracted.
    inputText: str
class ExtractedInfo(BaseModel):
    """Response body of ``POST /extractInformation``.

    Each field holds the extracted text span, or ``''`` when nothing was found.
    """

    name: str  # person name
    cardNumber: str  # card number
    idNumber: str  # identity-document number
# Runtime settings read once at import time. Every key below is required;
# a missing one raises KeyError while this module is being imported.
with open('config/config.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)

# NOTE(review): api_url and the module-level path are not read by the code
# visible in this file (update_kb / save_to_json use their own local path).
knowledge_base_file, api_url, path, max_knowledge_bases = (
    config_data['knowledge_base_file'],
    config_data['api']['url'],
    config_data['output_file_path'],
    config_data['max_knowledge_bases'],
)
def load_knowledge_bases(filename=None):
    """Return the persisted knowledge-base names, oldest first.

    Args:
        filename: File to read. Defaults to the configured
            ``knowledge_base_file``.

    Returns:
        One name per non-blank line of the file, or ``[]`` when the file
        does not exist.
    """
    if filename is None:
        filename = knowledge_base_file
    try:
        with open(filename, "r") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        return []
    # Skip blank lines (e.g. a trailing newline added by hand) so that an
    # empty string is never treated as a knowledge-base name.
    return [line.strip() for line in lines if line.strip()]
def save_knowledge_bases(names, filename=None):
    """Persist knowledge-base names, one per line, in the given order.

    Args:
        names: Iterable of names (e.g. the ``recent_knowledge_bases`` deque).
        filename: Destination file. Defaults to the configured
            ``knowledge_base_file``.

    The names are first written to a sibling ``.tmp`` file, which is then
    moved over the destination with ``os.replace``. A crash mid-write
    therefore cannot leave a truncated list, which would lose track of
    existing knowledge-base folders.
    """
    if filename is None:
        filename = knowledge_base_file
    tmp_filename = f"{filename}.tmp"
    with open(tmp_filename, "w") as file:
        file.write("\n".join(names))
    os.replace(tmp_filename, filename)
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a knowledge base from a JSON dump and register it as the newest.

    Scheduled as a background task by ``POST /updateDatabase``.

    Args:
        kb_name: Name of the new knowledge base. Existing knowledge bases live
            in ``knowledge_base/<name>`` folders.
        qa_service: ``QAService`` bound to ``kb_name``; handed to ``store_data``.
        path: JSON file to ingest; deleted once it has been stored.
        max_knowledge_bases: How many knowledge bases to keep. When the limit
            is reached, the oldest ones are dropped and their folders deleted.
    """
    store_data(qa_service, path)
    # Make room for the new entry by evicting the oldest knowledge bases.
    # ``>=`` (rather than ``==``) also copes with a list already over the limit.
    while recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        try:
            shutil.rmtree(os.path.join("knowledge_base", folder_to_delete))
        except OSError as e:
            # The name is already popped. Raising here would also skip
            # registering kb_name below and leave the input file behind,
            # so log the problem and carry on.
            logger.warning(f"Could not delete knowledge base folder {folder_to_delete}: {e}")
    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)
    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated<65><64><EFBFBD><EFBFBD><EFBFBD><EFBFBD>")
# Bounded in-memory list of knowledge-base names (oldest first), seeded from
# the persisted file. Because maxlen is set, extend() keeps only the newest
# max_knowledge_bases names if the file holds more.
recent_knowledge_bases = deque(maxlen=max_knowledge_bases)
recent_knowledge_bases.extend(load_knowledge_bases())
# Chinese digit characters -> ASCII digits, built once at import time.
# The characters are written as \u escapes so the table does not depend on
# the source encoding this file is saved in (it declares gbk). Each key must
# be a single character: str.maketrans rejects longer string keys.
_CHINESE_DIGIT_TABLE = str.maketrans({
    '\u96f6': '0',  # ling (zero)
    '\u3007': '0',  # circle-style zero
    '\u4e00': '1',  # yi
    '\u4e8c': '2',  # er
    '\u4e09': '3',  # san
    '\u56db': '4',  # si
    '\u4e94': '5',  # wu
    '\u516d': '6',  # liu
    '\u4e03': '7',  # qi
    '\u516b': '8',  # ba
    '\u4e5d': '9',  # jiu
})


def text_to_number(text_id):
    """Return *text_id* with Chinese digit characters replaced by ASCII digits.

    Characters are mapped one at a time, so a digit-by-digit reading such as
    "yi er san" becomes "123". Positional numerals (e.g. shi, "ten") and all
    other characters pass through unchanged.
    """
    return text_id.translate(_CHINESE_DIGIT_TABLE)
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Dump the posted items to JSON and schedule a knowledge-base rebuild.

    The items are written to a per-request JSON file. ``update_kb`` is then
    scheduled as a background task to ingest that file into a new knowledge
    base named with a fresh UUID. The response is sent before the rebuild
    runs.

    Returns:
        ``{"status": "success", ...}`` once the task is scheduled, or
        ``{"status": "error", ...}`` if writing the file or creating the
        ``QAService`` fails.
    """
    kb_name = str(uuid.uuid4())
    # One file per request. With a shared file name, a concurrent request
    # could overwrite the file, or another request's update_kb could delete
    # it, before this request's background task has read it.
    json_path = f"output_{kb_name}.json"
    try:
        with open(json_path, "w", encoding="utf-8") as file:
            json.dump([item.dict() for item in question_items], file, ensure_ascii=False, indent=2)
        qa_service = QAService(kb_name, None)  # device left as None
        background_tasks.add_task(
            update_kb, kb_name, qa_service, json_path, max_knowledge_bases
        )
    except Exception as e:
        logger.exception(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No task will consume the dump, so don't leave it behind (best effort).
        try:
            os.remove(json_path)
        except OSError:
            pass
        return {"status": "error", "message": "update task error<6F><72><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"}
    return {"status": "success", "message": "Please wait while the database is being updated<65><64><EFBFBD><EFBFBD><EFBFBD><EFBFBD>"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match ``request.question`` against the newest knowledge base.

    Calls ``match_query`` with ``top_k=3`` and ``request.scoreThreshold``.
    The ``QAService`` it uses is bound to the most recently built knowledge
    base.

    Returns:
        ``QuestionResponse`` with code 200 and the matches. Returns code 500
        with an empty ``data`` list if no knowledge base exists yet or
        matching fails.
    """
    logger.info(f"match_question:Request: {request}")
    if not recent_knowledge_bases:
        # Nothing has been built through /updateDatabase yet; indexing [-1]
        # would raise IndexError.
        logger.error("Error matching question: no knowledge base available")
        return QuestionResponse(code=500, msg="no knowledge base available", data=[])
    try:
        start_time = time.perf_counter()  # monotonic clock for the duration log
        newest = recent_knowledge_bases[-1]
        top_k = 3
        qa_service = QAService(newest, None)  # device left as None
        result = match_query(qa_service, request.question, top_k, request.scoreThreshold)
        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time
        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response
    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        # A failure must not be reported with msg="success".
        return QuestionResponse(code=500, msg="error", data=[])
from paddlenlp import Taskflow
# NLP pipelines, created once at import time and used by /extractInformation.
corrector = Taskflow("text_correction")
# UIE extraction labels, grouped by the ExtractedInfo field they fill:
#   schema[0:3] -> name, schema[3:5] -> idNumber, schema[5:8] -> cardNumber.
# NOTE(review): the label literals are mis-encoded in this copy of the file
# (source declares gbk); they are reproduced unchanged and should be checked
# against the original.
schema = ["<EFBFBD><EFBFBD><EFBFBD><EFBFBD>", '<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>', '<EFBFBD><EFBFBD><EFBFBD>Ա', "<EFBFBD><EFBFBD><EFBFBD><EFBFBD>֤<EFBFBD><EFBFBD>", "<EFBFBD><EFBFBD><EFBFBD><EFBFBD>֤<EFBFBD><EFBFBD><EFBFBD><EFBFBD>", "<EFBFBD><EFBFBD><EFBFBD><EFBFBD>", "<EFBFBD><EFBFBD><EFBFBD>׿<EFBFBD><EFBFBD><EFBFBD>", "<EFBFBD><EFBFBD><EFBFBD>п<EFBFBD><EFBFBD><EFBFBD>", ]
# All three name labels are extracted. The endpoint looks up schema[2] as a
# name fallback, so slicing to [:2] would make that lookup always miss.
name = Taskflow('information_extraction', schema=schema[0:3], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')
@app.post("/extractInformation")
async def extract_information(input_data: InputText):
    """Extract a name, a card number and an ID number from free text.

    Pipeline:

    1. Correct the text with ``corrector``.
    2. Convert Chinese digit characters to ASCII with ``text_to_number``.
    3. Run the UIE extractors ``name``, ``identity`` and ``card``.

    For each output field, the first label of its ``schema`` group with a
    non-empty hit wins; fields without a hit are ``''``.

    Raises:
        HTTPException: status 500 if any stage fails.
    """
    try:
        corrected = corrector(input_data.inputText)[0]['target']
        # NOTE(review): the appended suffix literal is mis-encoded in this copy
        # of the file; it is kept unchanged.
        converted_id = text_to_number(corrected + '<EFBFBD><EFBFBD>')

        # Merge every extractor's hits into one label -> text map, keeping the
        # first span's text. Later extractors overwrite equal labels.
        result = {}
        for model in (name, identity, card):
            for item in model(converted_id):
                for key, value in item.items():
                    result[key.lower()] = value[0]['text']

        def first_hit(labels):
            # Labels are lower-cased to match how ``result`` is keyed above.
            return next((result[label.lower()] for label in labels if result.get(label.lower())), '')

        return ExtractedInfo(
            name=first_hit(schema[0:3]),
            cardNumber=first_hit(schema[5:8]),
            idNumber=first_hit(schema[3:5]),
        )
    except Exception as e:
        logger.exception(f"Error extracting information: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces, port 8001.
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")