添加实体抽取接口

main
fanpt 1 year ago
parent 385c41486f
commit bfcdae9699

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="virtual_patient_qa" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="interrorobot" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
<component name="PyDocumentationSettings"> <component name="PyDocumentationSettings">

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="virtual_patient_qa" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="interrorobot" project-jdk-type="Python SDK" />
</project> </project>

@ -1,3 +1,44 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
tts = TTSExecutor()
tts(text="今天天气十分不错。", output="output.wav") app = FastAPI()
tts_executor = TTSExecutor()
class TextRequest(BaseModel):
text: str
def warm_up_tts():
default_text = "初始化文本" # 你可以使用任何你喜欢的文本
output_file = "warm_up_output.wav" # 暖身合成结果的输出文件名
# 调用合成功能
try:
tts_executor(text=default_text, output=output_file)
except Exception as e:
print(f"Error during warm-up: {e}")
# 在应用启动时进行暖身合成
warm_up_tts()
@app.post("/synthesize/")
async def synthesize_text(text_request: TextRequest):
text = text_request.text
output_file = "output.wav" # You can customize the output file name if needed
try:
tts_executor(text=text, output=output_file)
return {"message": "语音合成成功", "output_file": output_file}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

@ -11,7 +11,6 @@ import shutil
import yaml import yaml
import logging import logging
app = FastAPI() app = FastAPI()
import sys import sys
@ -22,7 +21,7 @@ logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.FileHandler('log/app.log'), logging.FileHandler('log/app.log'),
logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序 logging.StreamHandler(sys.stdout) # 添加控制台处理程序
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,18 +30,24 @@ class QuestionRequest(BaseModel):
question: str question: str
scoreThreshold: float scoreThreshold: float
class QuestionResponse(BaseModel): class QuestionResponse(BaseModel):
code: int code: int
msg: str msg: str
data: list data: list
class QuestionItem(BaseModel): class QuestionItem(BaseModel):
questionId: str questionId: str
questionList: list[str] questionList: list[str]
class InputText(BaseModel):
inputText: str
class ExtractedInfo(BaseModel):
name: str
cardNumber: str
idNnumber: str
# 读取配置文件
with open('config/config.yaml', 'r') as config_file: with open('config/config.yaml', 'r') as config_file:
config_data = yaml.safe_load(config_file) config_data = yaml.safe_load(config_file)
@ -51,7 +56,6 @@ api_url = config_data['api']['url']
path = config_data['output_file_path'] path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases'] max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases(): def load_knowledge_bases():
"""加载知识库名称列表""" """加载知识库名称列表"""
if os.path.exists(knowledge_base_file): if os.path.exists(knowledge_base_file):
@ -60,13 +64,11 @@ def load_knowledge_bases():
else: else:
return [] return []
def save_knowledge_bases(names): def save_knowledge_bases(names):
"""保存知识库名称列表到文件""" """保存知识库名称列表到文件"""
with open(knowledge_base_file, "w") as file: with open(knowledge_base_file, "w") as file:
file.write("\n".join(names)) file.write("\n".join(names))
def update_kb(kb_name, qa_service, path, max_knowledge_bases): def update_kb(kb_name, qa_service, path, max_knowledge_bases):
"""更新知识库""" """更新知识库"""
store_data(qa_service, path) store_data(qa_service, path)
@ -82,7 +84,13 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
logger.info(f"Knowledge base updated: {kb_name}\n" logger.info(f"Knowledge base updated: {kb_name}\n"
f"Please wait while the database is being updated···") f"Please wait while the database is being updated···")
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
def text_to_number(text_id):
chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'}
for chinese_num, arabic_num in chinese_nums.items():
text_id = text_id.replace(chinese_num, arabic_num)
return text_id
@app.post("/updateDatabase") @app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks): async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
@ -139,8 +147,38 @@ def match_question(request: QuestionRequest):
logger.error(f"Error matching question: {e}") logger.error(f"Error matching question: {e}")
return QuestionResponse(code=500, msg="success", data=[]) return QuestionResponse(code=500, msg="success", data=[])
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases) @app.post("/extractInformation/")
async def extract_information(input_data: InputText):
"""提取信息的端点"""
try:
inputText = input_data.inputText
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
data = corrector(inputText)
target_value = data[0]['target']
converted_id = text_to_number(target_value)
schema = ["姓名", "卡号", "身份证号"]
ie = Taskflow('information_extraction', schema=schema, model='uie-base')
extracted_info = ie(converted_id)
result = {}
for item in extracted_info:
for key, value in item.items():
result[key.lower()] = value[0]['text']
extracted_result = ExtractedInfo(name=result.get('姓名', ''),
cardNumber=result.get('卡号', ''),
idNnumber=result.get('身份证号', ''))
return extracted_result
except Exception as e:
logger.error(f"Error extracting information: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

Loading…
Cancel
Save