添加实体抽取接口

main
fanpt 1 year ago
parent 385c41486f
commit bfcdae9699

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="virtual_patient_qa" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="interrorobot" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="virtual_patient_qa" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="interrorobot" project-jdk-type="Python SDK" />
</project>

@ -1,3 +1,44 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from paddlespeech.cli.tts.infer import TTSExecutor
tts = TTSExecutor()
tts(text="今天天气十分不错。", output="output.wav")
app = FastAPI()
tts_executor = TTSExecutor()
class TextRequest(BaseModel):
text: str
def warm_up_tts():
default_text = "初始化文本" # 你可以使用任何你喜欢的文本
output_file = "warm_up_output.wav" # 暖身合成结果的输出文件名
# 调用合成功能
try:
tts_executor(text=default_text, output=output_file)
except Exception as e:
print(f"Error during warm-up: {e}")
# 在应用启动时进行暖身合成
warm_up_tts()
@app.post("/synthesize/")
async def synthesize_text(text_request: TextRequest):
text = text_request.text
output_file = "output.wav" # You can customize the output file name if needed
try:
tts_executor(text=text, output=output_file)
return {"message": "语音合成成功", "output_file": output_file}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

@ -11,7 +11,6 @@ import shutil
import yaml
import logging
app = FastAPI()
import sys
@ -22,7 +21,7 @@ logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('log/app.log'),
logging.StreamHandler(sys.stdout) # 这里添加控制台处理程序
logging.StreamHandler(sys.stdout) # 添加控制台处理程序
]
)
logger = logging.getLogger(__name__)
@ -31,18 +30,24 @@ class QuestionRequest(BaseModel):
question: str
scoreThreshold: float
class QuestionResponse(BaseModel):
code: int
msg: str
data: list
class QuestionItem(BaseModel):
questionId: str
questionList: list[str]
class InputText(BaseModel):
inputText: str
class ExtractedInfo(BaseModel):
name: str
cardNumber: str
idNnumber: str
# 读取配置文件
with open('config/config.yaml', 'r') as config_file:
config_data = yaml.safe_load(config_file)
@ -51,7 +56,6 @@ api_url = config_data['api']['url']
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases():
"""加载知识库名称列表"""
if os.path.exists(knowledge_base_file):
@ -60,13 +64,11 @@ def load_knowledge_bases():
else:
return []
def save_knowledge_bases(names):
"""保存知识库名称列表到文件"""
with open(knowledge_base_file, "w") as file:
file.write("\n".join(names))
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
"""更新知识库"""
store_data(qa_service, path)
@ -82,7 +84,13 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
logger.info(f"Knowledge base updated: {kb_name}\n"
f"Please wait while the database is being updated···")
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
def text_to_number(text_id):
chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'}
for chinese_num, arabic_num in chinese_nums.items():
text_id = text_id.replace(chinese_num, arabic_num)
return text_id
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
@ -139,8 +147,38 @@ def match_question(request: QuestionRequest):
logger.error(f"Error matching question: {e}")
return QuestionResponse(code=500, msg="success", data=[])
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
@app.post("/extractInformation/")
async def extract_information(input_data: InputText):
"""提取信息的端点"""
try:
inputText = input_data.inputText
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
data = corrector(inputText)
target_value = data[0]['target']
converted_id = text_to_number(target_value)
schema = ["姓名", "卡号", "身份证号"]
ie = Taskflow('information_extraction', schema=schema, model='uie-base')
extracted_info = ie(converted_id)
result = {}
for item in extracted_info:
for key, value in item.items():
result[key.lower()] = value[0]['text']
extracted_result = ExtractedInfo(name=result.get('姓名', ''),
cardNumber=result.get('卡号', ''),
idNnumber=result.get('身份证号', ''))
return extracted_result
except Exception as e:
logger.error(f"Error extracting information: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__":
import uvicorn

Loading…
Cancel
Save