1. 添加sql解析校验

topo_dev
xueqingkun 9 months ago
parent 1e7b1bf8b6
commit 311670ce62

@ -115,6 +115,11 @@ public class RowMapperStatementBuilder {
return msId;
}
public SqlSource getSqlSource(String sql, Class<?> parameterType) {
return languageDriver.createSqlSource(configuration, sql, parameterType);
}
public String select(String sql, Class<?> resultType) {
String msId = generateMappedStatementId(resultType + sql, SqlCommandType.SELECT);
if (hasMappedStatement(msId)) {

@ -3,6 +3,7 @@ package com.supervision.police.mybatis;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.exceptions.TooManyResultsException;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.SqlSession;
import org.springframework.stereotype.Component;
@ -160,6 +161,10 @@ public class RowSqlMapper {
return sqlSession.selectList(msId, value);
}
public BoundSql getBoundSql(String sql, Object parameter) {
return mapperStatementBuilder.getSqlSource(sql, parameter.getClass()).getBoundSql(parameter);
}
/**
*
*

@ -5,6 +5,8 @@ import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSONUtil;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.supervision.common.domain.R;
import com.supervision.common.utils.StringUtils;
@ -18,11 +20,13 @@ import com.supervision.police.dto.caseScore.CaseScoreDetailBuilder;
import com.supervision.police.mapper.*;
import com.supervision.police.mybatis.RowSqlMapper;
import com.supervision.police.service.ModelService;
import com.supervision.utils.SqlParserUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.*;
@ -48,6 +52,9 @@ public class ModelServiceImpl implements ModelService {
private final RowSqlMapper rowSqlMapper;
@Value("${case.evidence.table}")
private List<String> allowedTables;
@Override
public R<?> analyseCase(AnalyseCaseDTO analyseCaseDTO) {
ModelCase modelCase = modelCaseMapper.selectById(analyseCaseDTO.getCaseId());
@ -292,20 +299,49 @@ public class ModelServiceImpl implements ModelService {
params.put("provider", null);
params.put("party_a", analyseCaseDTO.getLawActorName());
params.put("party_b", analyseCaseDTO.getLawParty());
boolean success = false;
if (checkSql(sql,allowedTables)){
success = parseResult(rowSqlMapper.selectList(sql, params, Map.class));
}
result.setAtomicResult(success ? "1" : "0");
/*
todo:sqlsqlselectsql
MappedStatement mappedStatement = rowSqlMapper.selectMappedStatement(sql, Map.class);
BoundSql boundSql = mappedStatement.getBoundSql(params);
String sql1 = boundSql.getSql();
*/
boolean b = parseResult(rowSqlMapper.selectList(sql, params, Map.class));
result.setAtomicResult(b ? "1" : "0");
}
private boolean checkSql(String sql,List<String> allowedTables) {
if (StringUtils.isEmpty(sql)) {
return false;
}
if (CollUtil.isEmpty(allowedTables)){
log.info("checkSql:未配置允许的表");
return false;
}
MySqlStatementParser parser = new MySqlStatementParser(sql);
SQLStatement sqlStatement = SqlParserUtil.parseStatement(parser);
if (Objects.isNull(sqlStatement)) {
log.warn("checkSql sql:{}语句解析失败", sql);
return false;
}
String sqlType = SqlParserUtil.detectSQLType(sqlStatement);
if (!"SELECT".equals(sqlType)) {
log.warn("checkSql:只支持查询类型语句");
return false;
}
List<String> tableList = SqlParserUtil.extractTableNames(sqlStatement);
if (CollUtil.isEmpty(tableList)){
log.warn("checkSql:未检测到表");
return false;
}
long count = tableList.stream().filter(table -> !allowedTables.contains(table)).count();
if (count > 0){
log.warn("checkSql:表{}不在允许的表列表中",tableList);
return false;
}
return true;
}
/**
*
* 1. 11=100=0

@ -5,6 +5,8 @@ import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.stat.TableStat;
import com.supervision.common.utils.StringUtils;
import java.util.List;
import java.util.Set;
@ -59,11 +61,11 @@ public class SqlParserUtil {
Set<TableStat.Name> tableStatNames = visitor.getTables().keySet();
return tableStatNames.stream().map(TableStat.Name::getName).toList();
return tableStatNames.stream().map(tableStat -> StringUtils.lowerCase(tableStat.getName())).toList();
}
public static void main(String[] args) {
String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )";
String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = ? )";
MySqlStatementParser parser = new MySqlStatementParser(sql);
SQLStatement sqlStatement = parseStatement(parser);

@ -18,3 +18,6 @@ server:
port: 8097
servlet:
context-path: /fu-hsi-server
case:
evidence:
table: case_evidence

@ -11,13 +11,16 @@ import com.supervision.police.dto.AtomicData;
import com.supervision.police.dto.JudgeLogic;
import com.supervision.police.mapper.ModelAtomicIndexMapper;
import com.supervision.police.mapper.ModelIndexMapper;
import com.supervision.police.mybatis.RowSqlMapper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.mapping.BoundSql;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -31,6 +34,9 @@ public class ModelIndexTest {
@Autowired
private ModelAtomicIndexMapper modelAtomicIndexMapper;
@Autowired
private RowSqlMapper rowSqlMapper;
//@Test
public void modelIndexGenerate() {
@ -44,6 +50,23 @@ public class ModelIndexTest {
}
}
@Test
public void modelAtomicIndexGenerate() {
String sql = "select * from model_index_result where index_id = #{id};";
HashMap<String, Object> params = new HashMap<>(){
{
put("id", "1815615911075151874");
}
};
BoundSql boundSql = rowSqlMapper.getBoundSql(sql, params);;
String sql1 = boundSql.getSql();
System.out.println(sql1);
Map map = rowSqlMapper.selectOne(sql, params);
System.out.println(map);
}
public List<ModelIndex> readModelIndex() {
String path = "F:\\supervision\\doc\\宁夏公安\\宁夏大模型项目计划0722.xlsx";

Loading…
Cancel
Save