优化代码并添加注释

main
fanpt 9 months ago
parent 016dd673c6
commit 934f1d97cf

1
.gitignore vendored

@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
/.idea

8
.idea/.gitignore vendored

@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@ -1,56 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="fanpt@192.168.0.102:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>

@ -1,28 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="14">
<item index="0" class="java.lang.String" itemvalue="protobuf" />
<item index="1" class="java.lang.String" itemvalue="transformers" />
<item index="2" class="java.lang.String" itemvalue="tensorboard" />
<item index="3" class="java.lang.String" itemvalue="icetk" />
<item index="4" class="java.lang.String" itemvalue="cpm_kernels" />
<item index="5" class="java.lang.String" itemvalue="peft" />
<item index="6" class="java.lang.String" itemvalue="accelerate" />
<item index="7" class="java.lang.String" itemvalue="torch" />
<item index="8" class="java.lang.String" itemvalue="datasets" />
<item index="9" class="java.lang.String" itemvalue="bitsandbytes" />
<item index="10" class="java.lang.String" itemvalue="ConcurrentLogHandler" />
<item index="11" class="java.lang.String" itemvalue="uwsgi" />
<item index="12" class="java.lang.String" itemvalue="ultralytics" />
<item index="13" class="java.lang.String" itemvalue="tool_helpers" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="interro" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

@ -1,4 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="interro" project-jdk-type="Python SDK" />
</project>

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/interro_robot_tool.iml" filepath="$PROJECT_DIR$/.idea/interro_robot_tool.iml" />
</modules>
</component>
</project>

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@ -1,2 +1,2 @@
a0f9848b-2d65-4b37-85ca-6712061f01c0
38de6667-4f5d-4f0a-8165-992ab76c1424
f5361731-865c-4c36-90a5-70499c207562
2d5cdfb8-b1ec-4e29-9e0d-45bfd48afedf

@ -101,7 +101,7 @@ class KBFaissPool(_FaissPool):
if os.path.isfile(os.path.join(vs_path, "index.faiss")):
# load the embedding model
embeddings = self.load_kb_embeddings(local_model_path=embed_local_model_path, embed_device=embed_device)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT", allow_dangerous_deserialization=True)
elif create:
# create an empty vector store

@ -1,53 +1,58 @@
from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
# coding=gbk
import yaml
import sys
import os
import time
import uuid
import json
import shutil
import yaml
import logging
from collections import deque
from pydantic import BaseModel
from fastapi import BackgroundTasks
from fastapi import FastAPI, HTTPException
from qa_Ask import QAService, match_query, store_data
app = FastAPI()
import sys
# 配置日志记录到文件和终端
# 配置日志记录到文件和终端
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('log/app.log'),
logging.StreamHandler(sys.stdout) # 添加控制台处理程序
logging.StreamHandler(sys.stdout) # 添加控制台处理程序
]
)
logger = logging.getLogger(__name__)
class QuestionRequest(BaseModel):
question: str
scoreThreshold: float
class QuestionResponse(BaseModel):
code: int
msg: str
data: list
class QuestionItem(BaseModel):
questionId: str
questionList: list[str]
class InputText(BaseModel):
inputText: str
class ExtractedInfo(BaseModel):
name: str
cardNumber: str
idNnumber: str
idNumber: str
# 读取配置文件
with open('config/config.yaml', 'r') as config_file:
config_data = yaml.safe_load(config_file)
@ -56,21 +61,24 @@ api_url = config_data['api']['url']
path = config_data['output_file_path']
max_knowledge_bases = config_data['max_knowledge_bases']
def load_knowledge_bases():
"""加载知识库名称列表"""
"""加载知识库名称列表"""
if os.path.exists(knowledge_base_file):
with open(knowledge_base_file, "r") as file:
return file.read().splitlines()
else:
return []
def save_knowledge_bases(names):
"""保存知识库名称列表到文件"""
"""保存知识库名称列表到文件"""
with open(knowledge_base_file, "w") as file:
file.write("\n".join(names))
def update_kb(kb_name, qa_service, path, max_knowledge_bases):
"""更新知识库"""
"""更新知识库"""
store_data(qa_service, path)
if len(recent_knowledge_bases) == max_knowledge_bases:
@ -82,19 +90,21 @@ def update_kb(kb_name, qa_service, path, max_knowledge_bases):
os.remove(path)
logger.info(f"Knowledge base updated: {kb_name}\n"
f"Please wait while the database is being updated···")
f"Please wait while the database is being updated···")
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
def text_to_number(text_id):
chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'}
for chinese_num, arabic_num in chinese_nums.items():
text_id = text_id.replace(chinese_num, arabic_num)
return text_id
chinese_nums = {'': '0', '': '1', '': '2', '': '3', '': '4', '': '5', '': '6', '': '7', '': '8', '': '9'}
translation_table = str.maketrans(chinese_nums)
return text_id.translate(translation_table)
@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
"""接收问题数据并异步保存为JSON文件触发后台更新任务"""
"""接收问题数据并异步保存为JSON文件触发后台更新任务"""
try:
json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
path = "output.json"
@ -111,16 +121,17 @@ async def save_to_json(question_items: list[QuestionItem], background_tasks: Bac
update_kb, kb_name, qa_service, path, max_knowledge_bases
)
return {"status": "success", "message": "Please wait while the database is being updated···"}
return {"status": "success", "message": "Please wait while the database is being updated···"}
except Exception as e:
logger.error(f"Error saving data to file or scheduling knowledge base update task: {e}")
# raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
return {"status": "error", "message": "update task error···"}
return {"status": "error", "message": "update task error···"}
@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
"""匹配问题的端点"""
"""匹配问题的端点"""
try:
logger.info(f"match_question:Request: {request}")
start_time = time.time()
@ -147,33 +158,40 @@ def match_question(request: QuestionRequest):
logger.error(f"Error matching question: {e}")
return QuestionResponse(code=500, msg="success", data=[])
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号", ]
name = Taskflow('information_extraction', schema=schema[:2], model='uie-base')
identity = Taskflow('information_extraction', schema=schema[3:5], model='uie-base')
card = Taskflow('information_extraction', schema=schema[5:8], model='uie-base')
@app.post("/extractInformation")
async def extract_information(input_data: InputText):
"""提取信息的端点"""
"""提取信息的端点"""
try:
inputText = input_data.inputText
from paddlenlp import Taskflow
corrector = Taskflow("text_correction")
data = corrector(inputText)
input_text = input_data.inputText
data = corrector(input_text)
target_value = data[0]['target']
converted_id = text_to_number(target_value + '')
converted_id = text_to_number(target_value)
schema = ["姓名", '嫌疑人', '涉案人员', "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号", ]
ie = Taskflow('information_extraction', schema=schema, model='uie-base')
extracted_info = ie(converted_id)
extracted_info = {}
for model_name, model in zip(["name", "identity", "card"], [name, identity, card]):
extracted_info[model_name] = model(converted_id)
result = {}
for item in extracted_info:
for key, value in item.items():
result[key.lower()] = value[0]['text']
for model_name, info_list in extracted_info.items():
for item in info_list:
for key, value in item.items():
result[key.lower()] = value[0]['text']
extracted_result = ExtractedInfo(
name=result.get('姓名', '') or result.get('嫌疑人', '') or result.get('涉案人员', ''),
cardNumber=result.get('卡号', '') or result.get('交易卡号', '') or result.get('银行卡号', ''),
idNnumber=result.get('身份证号', '') or result.get('交易证件号', '') or result.get('交易证件号', '')
name=result.get('姓名', '') or result.get('嫌疑人', '') or result.get('涉案人员', ''),
cardNumber=result.get('卡号', '') or result.get('交易卡号', '') or result.get('银行卡号', ''),
idNumber=result.get('身份证号', '') or result.get('交易证件号', '') or result.get('交易证件号', '')
)
return extracted_result
@ -182,7 +200,9 @@ async def extract_information(input_data: InputText):
logger.error(f"Error extracting information: {e}")
raise HTTPException(status_code=500, detail="Internal Server Error")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8001)

Loading…
Cancel
Save