优化代码并添加注释

main
fanpt 9 months ago
parent 016dd673c6
commit 934f1d97cf

1
.gitignore vendored

@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
/.idea

8
.idea/.gitignore vendored

@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="fanpt@192.168.0.102:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>

@ -1,28 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="14">
<item index="0" class="java.lang.String" itemvalue="protobuf" />
<item index="1" class="java.lang.String" itemvalue="transformers" />
<item index="2" class="java.lang.String" itemvalue="tensorboard" />
<item index="3" class="java.lang.String" itemvalue="icetk" />
<item index="4" class="java.lang.String" itemvalue="cpm_kernels" />
<item index="5" class="java.lang.String" itemvalue="peft" />
<item index="6" class="java.lang.String" itemvalue="accelerate" />
<item index="7" class="java.lang.String" itemvalue="torch" />
<item index="8" class="java.lang.String" itemvalue="datasets" />
<item index="9" class="java.lang.String" itemvalue="bitsandbytes" />
<item index="10" class="java.lang.String" itemvalue="ConcurrentLogHandler" />
<item index="11" class="java.lang.String" itemvalue="uwsgi" />
<item index="12" class="java.lang.String" itemvalue="ultralytics" />
<item index="13" class="java.lang.String" itemvalue="tool_helpers" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="interro" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

@ -1,4 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="interro" project-jdk-type="Python SDK" />
</project>

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/interro_robot_tool.iml" filepath="$PROJECT_DIR$/.idea/interro_robot_tool.iml" />
</modules>
</component>
</project>

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@ -1,2 +1,2 @@
a0f9848b-2d65-4b37-85ca-6712061f01c0 f5361731-865c-4c36-90a5-70499c207562
38de6667-4f5d-4f0a-8165-992ab76c1424 2d5cdfb8-b1ec-4e29-9e0d-45bfd48afedf

@ -101,7 +101,7 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")): if os.path.isfile(os.path.join(vs_path, "index.faiss")):
# load the embedding model # load the embedding model
embeddings = self.load_kb_embeddings(local_model_path=embed_local_model_path, embed_device=embed_device) embeddings = self.load_kb_embeddings(local_model_path=embed_local_model_path, embed_device=embed_device)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT") vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT", allow_dangerous_deserialization=True)
elif create: elif create:
# create an empty vector store # create an empty vector store

@ -1,53 +1,58 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks # coding=gbk
from qa_Ask import QAService, match_query, store_data import yaml
from pydantic import BaseModel import sys
from collections import deque
import requests
import os import os
import time import time
import uuid import uuid
import json import json
import shutil import shutil
import yaml
import logging import logging
from collections import deque
from pydantic import BaseModel
from fastapi import BackgroundTasks
from fastapi import FastAPI, HTTPException
from qa_Ask import QAService, match_query, store_data
app = FastAPI() app = FastAPI()
import sys # 配置日志记录到文件和终端
# 配置日志记录到文件和终端
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[ handlers=[
logging.FileHandler('log/app.log'), logging.FileHandler('log/app.log'),
logging.StreamHandler(sys.stdout) # 添加控制台处理程序 logging.StreamHandler(sys.stdout) # 添加控制台处理程序
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel): class QuestionRequest(BaseModel):
question: str question: str
scoreThreshold: float scoreThreshold: float
class QuestionResponse(BaseModel): class QuestionResponse(BaseModel):
code: int code: int
msg: str msg: str
data: list data: list
class QuestionItem(BaseModel): class QuestionItem(BaseModel):
questionId: str questionId: str
questionList: list[str] questionList: list[str]
class InputText(BaseModel): class InputText(BaseModel):
inputText: str inputText: str
class ExtractedInfo(BaseModel): class ExtractedInfo(BaseModel):
name: str name: str
cardNumber: str cardNumber: str
idNnumber: str idNumber: str
# 读取配置文件
with open('config/config.yaml', 'r') as config_file: with open('config/config.yaml', 'r') as config_file:
config_data = yaml.safe_load(config_file) config_data = yaml.safe_load(config_file)
@ -56,21 +61,24 @@ api_url = config_data['api']['url']
path = config_data['output_file_path'] path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases'] max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases(): def load_knowledge_bases():
"""加载知识库名称列表""" """加载知识库名称列表"""
if os.path.exists(knowledge_base_file): if os.path.exists(knowledge_base_file):
with open(knowledge_base_file, "r") as file: with open(knowledge_base_file, "r") as file:
return file.read().splitlines() return file.read().splitlines()
else: else:
return [] return []
def save_knowledge_bases(names): def save_knowledge_bases(names):
"""保存知识库名称列表到文件""" """保存知识库名称列表到文件"""
with open(knowledge_base_file, "w") as file: with open(knowledge_base_file, "w") as file:
file.write("\n".join(names)) file.write("\n".join(names))
def update_kb(kb_name, qa_service, path, max_knowledge_bases): def update_kb(kb_name, qa_service, path, max_knowledge_bases):
"""更新知识库""" """更新知识库"""
store_data(qa_service, path) store_data(qa_service, path)
if len(recent_knowledge_bases) == max_knowledge_bases: if len(recent_knowledge_bases) == max_knowledge_bases:
@ -82,19 +90,21 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
os.remove(path) os.remove(path)
logger.info(f"Knowledge base updated: {kb_name}\n" logger.info(f"Knowledge base updated: {kb_name}\n"
f"Please wait while the database is being updated···") f"Please wait while the database is being updated···")
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases) recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
def text_to_number(text_id): def text_to_number(text_id):
chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'} chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'}
for chinese_num, arabic_num in chinese_nums.items(): translation_table = str.maketrans(chinese_nums)
text_id = text_id.replace(chinese_num, arabic_num) return text_id.translate(translation_table)
return text_id
@app.post("/updateDatabase") @app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks): async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
"""接收问题数据并异步保存为JSON文件触发后台更新任务""" """接收问题数据并异步保存为JSON文件触发后台更新任务"""
try: try:
json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2) json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
path = "output.json" path = "output.json"
@ -111,16 +121,17 @@ async def save_to_json(question_items: list[QuestionItem], background_tasks: Bac
update_kb, kb_name, qa_service, path, max_knowledge_bases update_kb, kb_name, qa_service, path, max_knowledge_bases
) )
return {"status": "success", "message": "Please wait while the database is being updated···"} return {"status": "success", "message": "Please wait while the database is being updated···"}
except Exception as e: except Exception as e:
logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}") logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
# raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") # raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
return {"status": "error", "message": "update task error···"} return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion") @app.post("/matchQuestion")
def match_question(request: QuestionRequest): def match_question(request: QuestionRequest):
"""匹配问题的端点""" """匹配问题的端点"""
try: try:
logger.info(f"match_question:Request: {request}") logger.info(f"match_question:Request: {request}")
start_time = time.time() start_time = time.time()
@ -147,33 +158,40 @@ def match_question(request: QuestionRequest):
logger.error(f"Error matching question: {e}") logger.error(f"Error matching question: {e}")
return QuestionResponse(code=500, msg="success", data=[]) return QuestionResponse(code=500, msg="success", data=[])
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号", ]
name = Taskflow('information_extraction', schema=schema[:2], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')
@app.post("/extractInformation") @app.post("/extractInformation")
async def extract_information(input_data: InputText): async def extract_information(input_data: InputText):
"""提取信息的端点""" """提取信息的端点"""
try: try:
inputText = input_data.inputText input_text = input_data.inputText
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
data = corrector(inputText)
data = corrector(input_text)
target_value = data[0]['target'] target_value = data[0]['target']
converted_id = text_to_number(target_value + '')
converted_id = text_to_number(target_value) extracted_info = {}
for model_name, model in zip(["name", "identity", "card"], [name, identity, card]):
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号", ] extracted_info[model_name] = model(converted_id)
ie = Taskflow('information_extraction', schema=schema, model='uie-base')
extracted_info = ie(converted_id)
result = {} result = {}
for item in extracted_info: for model_name, info_list in extracted_info.items():
for key, value in item.items(): for item in info_list:
result[key.lower()] = value[0]['text'] for key, value in item.items():
result[key.lower()] = value[0]['text']
extracted_result = ExtractedInfo( extracted_result = ExtractedInfo(
name=result.get('姓名', '') or result.get('嫌疑人', '') or result.get('涉案人员', ''), name=result.get('姓名', '') or result.get('嫌疑人', '') or result.get('涉案人员', ''),
cardNumber=result.get('卡号', '') or result.get('交易卡号', '') or result.get('银行卡号', ''), cardNumber=result.get('卡号', '') or result.get('交易卡号', '') or result.get('银行卡号', ''),
idNnumber=result.get('身份证号', '') or result.get('交易证件号', '') or result.get('交易证件号', '') idNumber=result.get('身份证号', '') or result.get('交易证件号', '') or result.get('交易证件号', '')
) )
return extracted_result return extracted_result
@ -182,7 +200,9 @@ async def extract_information(input_data: InputText):
logger.error(f"Error extracting information: {e}") logger.error(f"Error extracting information: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error") raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8001)

Loading…
Cancel
Save