1. 添加sql解析校验

topo_dev
xueqingkun 11 months ago
parent 1e7b1bf8b6
commit 311670ce62

@ -115,6 +115,11 @@ public class RowMapperStatementBuilder {
return msId; return msId;
} }
public SqlSource getSqlSource(String sql, Class<?> parameterType) {
return languageDriver.createSqlSource(configuration, sql, parameterType);
}
public String select(String sql, Class<?> resultType) { public String select(String sql, Class<?> resultType) {
String msId = generateMappedStatementId(resultType + sql, SqlCommandType.SELECT); String msId = generateMappedStatementId(resultType + sql, SqlCommandType.SELECT);
if (hasMappedStatement(msId)) { if (hasMappedStatement(msId)) {

@ -3,6 +3,7 @@ package com.supervision.police.mybatis;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.exceptions.TooManyResultsException; import org.apache.ibatis.exceptions.TooManyResultsException;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.SqlSession; import org.apache.ibatis.session.SqlSession;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -160,6 +161,10 @@ public class RowSqlMapper {
return sqlSession.selectList(msId, value); return sqlSession.selectList(msId, value);
} }
public BoundSql getBoundSql(String sql, Object parameter) {
return mapperStatementBuilder.getSqlSource(sql, parameter.getClass()).getBoundSql(parameter);
}
/** /**
* *
* *

@ -5,6 +5,8 @@ import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.supervision.common.domain.R; import com.supervision.common.domain.R;
import com.supervision.common.utils.StringUtils; import com.supervision.common.utils.StringUtils;
@ -18,11 +20,13 @@ import com.supervision.police.dto.caseScore.CaseScoreDetailBuilder;
import com.supervision.police.mapper.*; import com.supervision.police.mapper.*;
import com.supervision.police.mybatis.RowSqlMapper; import com.supervision.police.mybatis.RowSqlMapper;
import com.supervision.police.service.ModelService; import com.supervision.police.service.ModelService;
import com.supervision.utils.SqlParserUtil;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.Driver; import org.neo4j.driver.Driver;
import org.neo4j.driver.Result; import org.neo4j.driver.Result;
import org.neo4j.driver.Session; import org.neo4j.driver.Session;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
@ -48,6 +52,9 @@ public class ModelServiceImpl implements ModelService {
private final RowSqlMapper rowSqlMapper; private final RowSqlMapper rowSqlMapper;
@Value("${case.evidence.table}")
private List<String> allowedTables;
@Override @Override
public R<?> analyseCase(AnalyseCaseDTO analyseCaseDTO) { public R<?> analyseCase(AnalyseCaseDTO analyseCaseDTO) {
ModelCase modelCase = modelCaseMapper.selectById(analyseCaseDTO.getCaseId()); ModelCase modelCase = modelCaseMapper.selectById(analyseCaseDTO.getCaseId());
@ -292,20 +299,49 @@ public class ModelServiceImpl implements ModelService {
params.put("provider", null); params.put("provider", null);
params.put("party_a", analyseCaseDTO.getLawActorName()); params.put("party_a", analyseCaseDTO.getLawActorName());
params.put("party_b", analyseCaseDTO.getLawParty()); params.put("party_b", analyseCaseDTO.getLawParty());
boolean success = false;
if (checkSql(sql,allowedTables)){
success = parseResult(rowSqlMapper.selectList(sql, params, Map.class));
}
result.setAtomicResult(success ? "1" : "0");
/*
todo:sqlsqlselectsql
MappedStatement mappedStatement = rowSqlMapper.selectMappedStatement(sql, Map.class);
BoundSql boundSql = mappedStatement.getBoundSql(params);
String sql1 = boundSql.getSql();
*/
boolean b = parseResult(rowSqlMapper.selectList(sql, params, Map.class)); }
result.setAtomicResult(b ? "1" : "0");
private boolean checkSql(String sql,List<String> allowedTables) {
if (StringUtils.isEmpty(sql)) {
return false;
}
if (CollUtil.isEmpty(allowedTables)){
log.info("checkSql:未配置允许的表");
return false;
}
MySqlStatementParser parser = new MySqlStatementParser(sql);
SQLStatement sqlStatement = SqlParserUtil.parseStatement(parser);
if (Objects.isNull(sqlStatement)) {
log.warn("checkSql sql:{}语句解析失败", sql);
return false;
}
String sqlType = SqlParserUtil.detectSQLType(sqlStatement);
if (!"SELECT".equals(sqlType)) {
log.warn("checkSql:只支持查询类型语句");
return false;
}
List<String> tableList = SqlParserUtil.extractTableNames(sqlStatement);
if (CollUtil.isEmpty(tableList)){
log.warn("checkSql:未检测到表");
return false;
}
long count = tableList.stream().filter(table -> !allowedTables.contains(table)).count();
if (count > 0){
log.warn("checkSql:表{}不在允许的表列表中",tableList);
return false;
}
return true;
} }
/** /**
* *
* 1. 11=100=0 * 1. 11=100=0

@ -5,6 +5,8 @@ import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.parser.SQLStatementParser; import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.stat.TableStat; import com.alibaba.druid.stat.TableStat;
import com.supervision.common.utils.StringUtils;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -59,11 +61,11 @@ public class SqlParserUtil {
Set<TableStat.Name> tableStatNames = visitor.getTables().keySet(); Set<TableStat.Name> tableStatNames = visitor.getTables().keySet();
return tableStatNames.stream().map(TableStat.Name::getName).toList(); return tableStatNames.stream().map(tableStat -> StringUtils.lowerCase(tableStat.getName())).toList();
} }
public static void main(String[] args) { public static void main(String[] args) {
String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )"; String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = ? )";
MySqlStatementParser parser = new MySqlStatementParser(sql); MySqlStatementParser parser = new MySqlStatementParser(sql);
SQLStatement sqlStatement = parseStatement(parser); SQLStatement sqlStatement = parseStatement(parser);

@ -18,3 +18,6 @@ server:
port: 8097 port: 8097
servlet: servlet:
context-path: /fu-hsi-server context-path: /fu-hsi-server
case:
evidence:
table: case_evidence

@ -11,13 +11,16 @@ import com.supervision.police.dto.AtomicData;
import com.supervision.police.dto.JudgeLogic; import com.supervision.police.dto.JudgeLogic;
import com.supervision.police.mapper.ModelAtomicIndexMapper; import com.supervision.police.mapper.ModelAtomicIndexMapper;
import com.supervision.police.mapper.ModelIndexMapper; import com.supervision.police.mapper.ModelIndexMapper;
import com.supervision.police.mybatis.RowSqlMapper;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.BoundSql;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -31,6 +34,9 @@ public class ModelIndexTest {
@Autowired @Autowired
private ModelAtomicIndexMapper modelAtomicIndexMapper; private ModelAtomicIndexMapper modelAtomicIndexMapper;
@Autowired
private RowSqlMapper rowSqlMapper;
//@Test //@Test
public void modelIndexGenerate() { public void modelIndexGenerate() {
@ -44,6 +50,23 @@ public class ModelIndexTest {
} }
} }
@Test
public void modelAtomicIndexGenerate() {
String sql = "select * from model_index_result where index_id = #{id};";
HashMap<String, Object> params = new HashMap<>(){
{
put("id", "1815615911075151874");
}
};
BoundSql boundSql = rowSqlMapper.getBoundSql(sql, params);;
String sql1 = boundSql.getSql();
System.out.println(sql1);
Map map = rowSqlMapper.selectOne(sql, params);
System.out.println(map);
}
public List<ModelIndex> readModelIndex() { public List<ModelIndex> readModelIndex() {
String path = "F:\\supervision\\doc\\宁夏公安\\宁夏大模型项目计划0722.xlsx"; String path = "F:\\supervision\\doc\\宁夏公安\\宁夏大模型项目计划0722.xlsx";

Loading…
Cancel
Save