From 1e7b1bf8b63fd842882c4d64ec781fb901480762 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Wed, 24 Jul 2024 11:01:25 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E6=B7=BB=E5=8A=A0sql=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E6=B5=8B=E8=AF=95demo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/supervision/utils/SqlParserUtil.java | 81 ++++++++++++++++++ .../springaidemo/DruidTableExtractor.java | 85 +++++++++++++++++++ .../springaidemo/SQLTableExtractor.java | 33 +++++++ 3 files changed, 199 insertions(+) create mode 100644 src/main/java/com/supervision/utils/SqlParserUtil.java create mode 100644 src/test/java/com/supervision/springaidemo/DruidTableExtractor.java create mode 100644 src/test/java/com/supervision/springaidemo/SQLTableExtractor.java diff --git a/src/main/java/com/supervision/utils/SqlParserUtil.java b/src/main/java/com/supervision/utils/SqlParserUtil.java new file mode 100644 index 0000000..4aad264 --- /dev/null +++ b/src/main/java/com/supervision/utils/SqlParserUtil.java @@ -0,0 +1,81 @@ +package com.supervision.utils; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.statement.*; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; +import com.alibaba.druid.sql.parser.SQLStatementParser; +import com.alibaba.druid.stat.TableStat; +import java.util.List; +import java.util.Set; + +public class SqlParserUtil { + + + /** + * 解析 SQL + * @param parser parser + * @return null 表示解析失败 + */ + public static SQLStatement parseStatement(SQLStatementParser parser) { + try { + return parser.parseStatement(); + } catch (Exception e) { + return null; + } + } + + /** + *sql 类型 + * @param statement statement + * @return + */ + public static String detectSQLType(SQLStatement statement) { + try { + // 判断 SQL 类型 + if (statement instanceof SQLSelectStatement) { + return "SELECT"; + } else if (statement instanceof SQLInsertStatement) { + return "INSERT"; + } else if (statement instanceof SQLUpdateStatement) { + return "UPDATE"; + } else if (statement instanceof SQLDeleteStatement) { + return "DELETE"; + } else { + return "UNKNOWN"; + } + } catch (Exception e) { + return "ERROR"; + } + } + + /** + * 提取 SQL 中的表名 + * @param statement statement + * @return 表名 + */ + public static List extractTableNames(SQLStatement statement) { + MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); + statement.accept(visitor); + + Set tableStatNames = visitor.getTables().keySet(); + + return tableStatNames.stream().map(TableStat.Name::getName).toList(); + } + + public static void main(String[] args) { + String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )"; + + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement sqlStatement = parseStatement(parser); + if (sqlStatement == null){ + System.out.println("SQL 是否合法: " + false); + return; + } + + String s = detectSQLType(sqlStatement); + System.out.println("SQL 类型: " + s); + + List tableNames = extractTableNames(sqlStatement); + System.out.println("涉及到的表: " + tableNames); + } +} diff --git a/src/test/java/com/supervision/springaidemo/DruidTableExtractor.java b/src/test/java/com/supervision/springaidemo/DruidTableExtractor.java new file mode 100644 index 0000000..a09f7fc --- /dev/null +++ b/src/test/java/com/supervision/springaidemo/DruidTableExtractor.java @@ -0,0 +1,85 @@ +package com.supervision.springaidemo; + +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.ast.statement.*; +import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser; +import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor; +import com.alibaba.druid.stat.TableStat; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +public class DruidTableExtractor { + public static void main(String[] args) { + String sql = "drop u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )"; + + + boolean b = validateSQL(sql); + System.out.println("SQL 是否合法: " + b); + + + String s = detectSQLType(sql); + System.out.println("SQL 类型: " + s); + + List tableNames = extractTableNames(sql); + System.out.println("涉及到的表: " + tableNames); + } + + public static boolean validateSQL(String sql) { + try { + // 解析 SQL + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement statement = parser.parseStatement(); + return true; // 解析成功,SQL 语句合法 + } catch (Exception e) { + // 解析异常,SQL 语句不合法 + System.out.println("SQL 解析错误: " + e.getMessage()); + return false; + } + } + + public static String detectSQLType(String sql) { + try { + // 解析 SQL + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement statement = parser.parseStatement(); + + // 判断 SQL 类型 + if (statement instanceof SQLSelectStatement) { + return "SELECT"; + } else if (statement instanceof SQLInsertStatement) { + return "INSERT"; + } else if (statement instanceof SQLUpdateStatement) { + return "UPDATE"; + } else if (statement instanceof SQLDeleteStatement) { + return "DELETE"; + } else { + return "UNKNOWN"; + } + } catch (Exception e) { + System.out.println("SQL 解析错误: " + e.getMessage()); + return "ERROR"; + } + } + + public static List extractTableNames(String sql) { + // 解析 SQL + MySqlStatementParser parser = new MySqlStatementParser(sql); + SQLStatement statement = parser.parseStatement(); + + // 访问者模式提取表名 + MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor(); + statement.accept(visitor); + + Set tableStatNames = visitor.getTables().keySet(); + List tableNames = new ArrayList<>(); + + for (TableStat.Name name : tableStatNames) { + tableNames.add(name.getName()); + } + + return tableNames; + } +} + diff --git a/src/test/java/com/supervision/springaidemo/SQLTableExtractor.java b/src/test/java/com/supervision/springaidemo/SQLTableExtractor.java new file mode 100644 index 0000000..73e0377 --- /dev/null +++ b/src/test/java/com/supervision/springaidemo/SQLTableExtractor.java @@ -0,0 +1,33 @@ +package com.supervision.springaidemo; + +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.PlainSelect; +import net.sf.jsqlparser.statement.select.Select; +import net.sf.jsqlparser.statement.select.SelectBody; +import net.sf.jsqlparser.statement.select.SubSelect; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.util.TablesNamesFinder; +import net.sf.jsqlparser.JSQLParserException; + +import java.util.List; + +public class SQLTableExtractor { + public static void main(String[] args) { + String sql = "SELECT u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id WHERE u.id = 1 and u.id in ( select id from people where name = 'dd' )"; + + try { + Statement statement = CCJSqlParserUtil.parse(sql); + List tableNames = extractTableNames(statement); + System.out.println("涉及到的表: " + tableNames); + } catch (JSQLParserException e) { + System.out.println("SQL解析错误: " + e.getMessage()); + } + } + + private static List extractTableNames(Statement statement) { + TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); + return tablesNamesFinder.getTableList(statement); + } +} +