diff --git a/src/main/java/com/supervision/police/mybatis/RowMapperStatementBuilder.java b/src/main/java/com/supervision/police/mybatis/RowMapperStatementBuilder.java index a2bae54..bbc17fd 100644 --- a/src/main/java/com/supervision/police/mybatis/RowMapperStatementBuilder.java +++ b/src/main/java/com/supervision/police/mybatis/RowMapperStatementBuilder.java @@ -115,6 +115,11 @@ public class RowMapperStatementBuilder { return msId; } + public SqlSource getSqlSource(String sql, Class parameterType) { + + return languageDriver.createSqlSource(configuration, sql, parameterType); + } + public String select(String sql, Class resultType) { String msId = generateMappedStatementId(resultType + sql, SqlCommandType.SELECT); if (hasMappedStatement(msId)) { diff --git a/src/main/java/com/supervision/police/mybatis/RowSqlMapper.java b/src/main/java/com/supervision/police/mybatis/RowSqlMapper.java index 66bff8c..5abd95b 100644 --- a/src/main/java/com/supervision/police/mybatis/RowSqlMapper.java +++ b/src/main/java/com/supervision/police/mybatis/RowSqlMapper.java @@ -3,6 +3,7 @@ package com.supervision.police.mybatis; import cn.hutool.json.JSONUtil; import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.exceptions.TooManyResultsException; +import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.session.SqlSession; import org.springframework.stereotype.Component; @@ -160,6 +161,10 @@ public class RowSqlMapper { return sqlSession.selectList(msId, value); } + public BoundSql getBoundSql(String sql, Object parameter) { + return mapperStatementBuilder.getSqlSource(sql, parameter.getClass()).getBoundSql(parameter); + } + /** * 插入数据 * diff --git a/src/main/java/com/supervision/police/service/impl/ModelServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelServiceImpl.java index e0cc70e..de5c85b 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelServiceImpl.java @@ -5,6 +5,8 @@ import cn.hutool.core.lang.Assert; import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.json.JSONUtil; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.supervision.common.domain.R; import com.supervision.common.utils.StringUtils; @@ -18,11 +20,13 @@ import com.supervision.police.dto.caseScore.CaseScoreDetailBuilder; import com.supervision.police.mapper.*; import com.supervision.police.mybatis.RowSqlMapper; import com.supervision.police.service.ModelService; +import com.supervision.utils.SqlParserUtil; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.neo4j.driver.Driver; import org.neo4j.driver.Result; import org.neo4j.driver.Session; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.util.*; @@ -48,6 +52,9 @@ public class ModelServiceImpl implements ModelService { private final RowSqlMapper rowSqlMapper; + @Value("${case.evidence.table}") + private List allowedTables; + @Override public R analyseCase(AnalyseCaseDTO analyseCaseDTO) { ModelCase modelCase = modelCaseMapper.selectById(analyseCaseDTO.getCaseId()); @@ -292,20 +299,49 @@ public class ModelServiceImpl implements ModelService { params.put("provider", null); params.put("party_a", analyseCaseDTO.getLawActorName()); params.put("party_b", analyseCaseDTO.getLawParty()); + boolean success = false; + if (checkSql(sql,allowedTables)){ + success = parseResult(rowSqlMapper.selectList(sql, params, Map.class)); + } + result.setAtomicResult(success ? "1" : "0"); - /* - todo:添加语法解析功能,提前验证sql是否合法,并且限制sql智能是select语句和限制sql语句中出现的表 - MappedStatement mappedStatement = rowSqlMapper.selectMappedStatement(sql, Map.class); - BoundSql boundSql = mappedStatement.getBoundSql(params); - String sql1 = boundSql.getSql(); - */ - boolean b = parseResult(rowSqlMapper.selectList(sql, params, Map.class)); - result.setAtomicResult(b ? "1" : "0"); + } + private boolean checkSql(String sql,List allowedTables) { + if (StringUtils.isEmpty(sql)) { + return false; + } + if (CollUtil.isEmpty(allowedTables)){ + log.info("checkSql:未配置允许的表"); + return false; + } + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement sqlStatement = SqlParserUtil.parseStatement(parser); + if (Objects.isNull(sqlStatement)) { + log.warn("checkSql sql:{}语句解析失败", sql); + return false; + } + String sqlType = SqlParserUtil.detectSQLType(sqlStatement); + if (!"SELECT".equals(sqlType)) { + log.warn("checkSql:只支持查询类型语句"); + return false; + } + List tableList = SqlParserUtil.extractTableNames(sqlStatement); + if (CollUtil.isEmpty(tableList)){ + log.warn("checkSql:未检测到表"); + return false; + } + long count = tableList.stream().filter(table -> !allowedTables.contains(table)).count(); + if (count > 0){ + log.warn("checkSql:表{}不在允许的表列表中",tableList); + return false; + } + return true; } + /** * 执行结果分析: * 1. 如果查询出的结果只有一行,判断列数是否大于1,如果大于1,返回真,如果=1,继续判断值是否大于0,如果大于0,返回真,如果=0,返回假 diff --git a/src/main/java/com/supervision/utils/SqlParserUtil.java b/src/main/java/com/supervision/utils/SqlParserUtil.java index 4aad264..63d7dc5 100644 --- a/src/main/java/com/supervision/utils/SqlParserUtil.java +++ b/src/main/java/com/supervision/utils/SqlParserUtil.java @@ -5,6 +5,8 @@ import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; import com.alibaba.druid.sql.parser.SQLStatementParser; import com.alibaba.druid.stat.TableStat; +import com.supervision.common.utils.StringUtils; + import java.util.List; import java.util.Set; @@ -59,11 +61,11 @@ public class SqlParserUtil { Set tableStatNames = visitor.getTables().keySet(); - return tableStatNames.stream().map(TableStat.Name::getName).toList(); + return tableStatNames.stream().map(tableStat -> StringUtils.lowerCase(tableStat.getName())).toList(); } public static void main(String[] args) { - String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )"; + String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = ? )"; MySqlStatementParser parser = new MySqlStatementParser(sql); SQLStatement sqlStatement = parseStatement(parser); diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index f2acfa1..2708ad0 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -18,3 +18,6 @@ server: port: 8097 servlet: context-path: /fu-hsi-server +case: + evidence: + table: case_evidence diff --git a/src/test/java/com/supervision/springaidemo/ModelIndexTest.java b/src/test/java/com/supervision/springaidemo/ModelIndexTest.java index 0db2ee4..308daca 100644 --- a/src/test/java/com/supervision/springaidemo/ModelIndexTest.java +++ b/src/test/java/com/supervision/springaidemo/ModelIndexTest.java @@ -11,13 +11,16 @@ import com.supervision.police.dto.AtomicData; import com.supervision.police.dto.JudgeLogic; import com.supervision.police.mapper.ModelAtomicIndexMapper; import com.supervision.police.mapper.ModelIndexMapper; +import com.supervision.police.mybatis.RowSqlMapper; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.apache.ibatis.mapping.BoundSql; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -31,6 +34,9 @@ public class ModelIndexTest { @Autowired private ModelAtomicIndexMapper modelAtomicIndexMapper; + + @Autowired + private RowSqlMapper rowSqlMapper; //@Test public void modelIndexGenerate() { @@ -44,6 +50,23 @@ public class ModelIndexTest { } } + @Test + public void modelAtomicIndexGenerate() { + + String sql = "select * from model_index_result where index_id = #{id};"; + HashMap params = new HashMap<>(){ + { + put("id", "1815615911075151874"); + } + }; + + BoundSql boundSql = rowSqlMapper.getBoundSql(sql, params);; + String sql1 = boundSql.getSql(); + System.out.println(sql1); + Map map = rowSqlMapper.selectOne(sql, params); + System.out.println(map); + } + public List readModelIndex() { String path = "F:\\supervision\\doc\\宁夏公安\\宁夏大模型项目计划0722.xlsx";