From bfcdae96997706a63481cbb76fd798630c97f5dd Mon Sep 17 00:00:00 2001
From: fanpt <320622572@qq.com>
Date: Wed, 27 Mar 2024 13:36:26 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AE=9E=E4=BD=93=E6=8A=BD?=
 =?UTF-8?q?=E5=8F=96=E6=8E=A5=E5=8F=A3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .idea/interro_robot_tool.iml |  2 +-
 .idea/misc.xml               |  2 +-
 PaddelSpeech.py              | 45 ++++++++++++++++++++++++++++--
 fast_api.py                  | 54 ++++++++++++++++++++++++++++++------
 4 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/.idea/interro_robot_tool.iml b/.idea/interro_robot_tool.iml
index c0f1096..d62d042 100644
--- a/.idea/interro_robot_tool.iml
+++ b/.idea/interro_robot_tool.iml
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="virtual_patient_qa" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="interrorobot" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="PyDocumentationSettings">
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 51993a2..ac5d38c 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="virtual_patient_qa" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="interrorobot" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
diff --git a/PaddelSpeech.py b/PaddelSpeech.py
index 2367f09..780ca80 100644
--- a/PaddelSpeech.py
+++ b/PaddelSpeech.py
@@ -1,3 +1,44 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
 from paddlespeech.cli.tts.infer import TTSExecutor
-tts = TTSExecutor()
-tts(text="今天天气十分不错。", output="output.wav")
\ No newline at end of file
+
+app = FastAPI()
+tts_executor = TTSExecutor()
+
+
+class TextRequest(BaseModel):
+    text: str
+
+
+def warm_up_tts():
+    default_text = "初始化文本"  # 你可以使用任何你喜欢的文本
+    output_file = "warm_up_output.wav"  # 暖身合成结果的输出文件名
+
+    # 调用合成功能
+    try:
+        tts_executor(text=default_text, output=output_file)
+    except Exception as e:
+        print(f"Error during warm-up: {e}")
+
+
+# 在应用启动时进行暖身合成
+warm_up_tts()
+
+
+@app.post("/synthesize/")
+async def synthesize_text(text_request: TextRequest):
+    text = text_request.text
+    output_file = "output.wav"  # You can customize the output file name if needed
+
+    try:
+        tts_executor(text=text, output=output_file)
+        return {"message": "语音合成成功", "output_file": output_file}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/fast_api.py b/fast_api.py
index c7f1308..52bd174 100644
--- a/fast_api.py
+++ b/fast_api.py
@@ -11,7 +11,6 @@ import shutil
 import yaml
 import logging
 
-
 app = FastAPI()
 
 import sys
@@ -22,7 +21,7 @@ logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s',
     handlers=[
         logging.FileHandler('log/app.log'),
-        logging.StreamHandler(sys.stdout)  # 这里添加控制台处理程序
+        logging.StreamHandler(sys.stdout)  # 添加控制台处理程序
     ]
 )
 logger = logging.getLogger(__name__)
@@ -31,18 +30,24 @@ class QuestionRequest(BaseModel):
     question: str
     scoreThreshold: float
 
-
 class QuestionResponse(BaseModel):
     code: int
     msg: str
     data: list
 
-
 class QuestionItem(BaseModel):
     questionId: str
     questionList: list[str]
 
+class InputText(BaseModel):
+    inputText: str
 
+class ExtractedInfo(BaseModel):
+    name: str
+    cardNumber: str
+    idNnumber: str
+
+# 读取配置文件
 with open('config/config.yaml', 'r') as config_file:
     config_data = yaml.safe_load(config_file)
 
@@ -51,7 +56,6 @@ api_url = config_data['api']['url']
 path = config_data['output_file_path']
 max_knowledge_bases = config_data['max_knowledge_bases']
 
-
 def load_knowledge_bases():
     """加载知识库名称列表"""
     if os.path.exists(knowledge_base_file):
@@ -60,13 +64,11 @@ def load_knowledge_bases():
     else:
         return []
 
-
 def save_knowledge_bases(names):
     """保存知识库名称列表到文件"""
     with open(knowledge_base_file, "w") as file:
         file.write("\n".join(names))
 
-
 def update_kb(kb_name, qa_service, path, max_knowledge_bases):
     """更新知识库"""
     store_data(qa_service, path)
@@ -82,7 +84,13 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
     logger.info(f"Knowledge base updated: {kb_name}\n"
                 f"Please wait while the database is being updated···")
 
+recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
 
+def text_to_number(text_id):
+    chinese_nums = {'零': '0', '一': '1', '二': '2', '三': '3', '四': '4', '五': '5', '六': '6', '七': '7', '八': '8', '九': '9'}
+    for chinese_num, arabic_num in chinese_nums.items():
+        text_id = text_id.replace(chinese_num, arabic_num)
+    return text_id
 
 @app.post("/updateDatabase")
 async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
@@ -139,8 +147,38 @@ def match_question(request: QuestionRequest):
         logger.error(f"Error matching question: {e}")
         return QuestionResponse(code=500, msg="success", data=[])
 
-recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
+@app.post("/extractInformation/")
+async def extract_information(input_data: InputText):
+    """提取信息的端点"""
+    try:
+        inputText = input_data.inputText
+        from paddlenlp import Taskflow
+
+        corrector = Taskflow("text_correction")
+        data = corrector(inputText)
+
+        target_value = data[0]['target']
+
+        converted_id = text_to_number(target_value)
+
+        schema = ["姓名", "卡号", "身份证号"]
+        ie = Taskflow('information_extraction', schema=schema, model='uie-base')
+        extracted_info = ie(converted_id)
 
+        result = {}
+        for item in extracted_info:
+            for key, value in item.items():
+                result[key.lower()] = value[0]['text']
+
+        extracted_result = ExtractedInfo(name=result.get('姓名', ''),
+                                         cardNumber=result.get('卡号', ''),
+                                         idNnumber=result.get('身份证号', ''))
+
+        return extracted_result
+
+    except Exception as e:
+        logger.error(f"Error extracting information: {e}")
+        raise HTTPException(status_code=500, detail="Internal Server Error")
 
 if __name__ == "__main__":
     import uvicorn