package com.supervision.utils ;
import com.alibaba.druid.sql.ast.SQLStatement ;
import com.alibaba.druid.sql.ast.statement.* ;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser ;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor ;
import com.alibaba.druid.sql.parser.SQLStatementParser ;
import com.alibaba.druid.stat.TableStat ;
import com.supervision.common.utils.StringUtils ;
import java.util.List ;
import java.util.Set ;
public class SqlParserUtil {
/ * *
* 解 析 SQL
* @param parser parser
* @return null 表 示 解 析 失 败
* /
public static SQLStatement parseStatement ( SQLStatementParser parser ) {
try {
return parser . parseStatement ( ) ;
} catch ( Exception e ) {
return null ;
}
}
/ * *
* sql 类 型
* @param statement statement
* @return
* /
public static String detectSQLType ( SQLStatement statement ) {
try {
// 判断 SQL 类型
if ( statement instanceof SQLSelectStatement ) {
return "SELECT" ;
} else if ( statement instanceof SQLInsertStatement ) {
return "INSERT" ;
} else if ( statement instanceof SQLUpdateStatement ) {
return "UPDATE" ;
} else if ( statement instanceof SQLDeleteStatement ) {
return "DELETE" ;
} else {
return "UNKNOWN" ;
}
} catch ( Exception e ) {
return "ERROR" ;
}
}
/ * *
* 提 取 SQL 中 的 表 名
* @param statement statement
* @return 表 名
* /
public static List < String > extractTableNames ( SQLStatement statement ) {
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor ( ) ;
statement . accept ( visitor ) ;
Set < TableStat . Name > tableStatNames = visitor . getTables ( ) . keySet ( ) ;
return tableStatNames . stream ( ) . map ( tableStat - > StringUtils . lowerCase ( tableStat . getName ( ) ) ) . toList ( ) ;
}
public static void main ( String [ ] args ) {
String sql = "select u.name, o.id FROM users u left JOIN orders o ON u.id = o.user_id left join dept d on u.id = d.id WHERE u.id = 1 and u.id in ( select id from people where name = ? )" ;
MySqlStatementParser parser = new MySqlStatementParser ( sql ) ;
SQLStatement sqlStatement = parseStatement ( parser ) ;
if ( sqlStatement = = null ) {
System . out . println ( "SQL 是否合法: " + false ) ;
return ;
}
String s = detectSQLType ( sqlStatement ) ;
System . out . println ( "SQL 类型: " + s ) ;
List < String > tableNames = extractTableNames ( sqlStatement ) ;
System . out . println ( "涉及到的表: " + tableNames ) ;
}
}