99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
|
import re
|
||
|
|
||
|
from sqlparse import SQLColumn
|
||
|
from sqlparse.SQLTable import SQLTable
|
||
|
|
||
|
# Regex capturing the backtick-quoted table name of a CREATE TABLE statement.
# Accepts the optional "IF NOT EXISTS" clause and any whitespace (including
# newlines) between keywords; the name may contain any non-backtick character.
tableNamePattern = r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`([^`]+)`"
|
||
|
|
||
|
|
||
|
def read_sql_file(sql_path: str, encoding: str = "utf-8"):
    """Return the full text content of the SQL file at *sql_path*.

    :param sql_path: path of the SQL file to read.
    :param encoding: text encoding used to decode the file. It defaults to
        UTF-8 rather than the platform locale encoding, so the result does
        not depend on the machine running the parser.
    :return: the decoded file contents as a single string.
    """
    with open(sql_path, "r", encoding=encoding) as sql_file:
        return sql_file.read()
|
||
|
|
||
|
|
||
|
def split_table_definition(sql_string: str):
    """Cut a SQL script into one chunk per ``CREATE TABLE`` occurrence.

    Text before the first ``CREATE TABLE`` is dropped. Every returned chunk
    begins with the literal ``"CREATE TABLE"`` and extends up to (not
    including) the next occurrence, so statements sitting between two table
    definitions stay attached to the earlier chunk.

    :param sql_string: full SQL text.
    :return: list of chunks, empty when there is no ``CREATE TABLE``.
    """
    marker = "CREATE TABLE"
    _preamble, *table_bodies = sql_string.split(marker)
    return [marker + body for body in table_bodies]
|
||
|
|
||
|
|
||
|
def parse_sql_file(sql_path: str):
    """Parse every ``CREATE TABLE`` statement in the SQL file at *sql_path*.

    :param sql_path: path of the SQL file to read.
    :return: list with one ``SQLTable`` per ``CREATE TABLE`` statement, in
        file order. The list starts empty; it no longer carries the
        ``SQLTable`` class itself as a spurious first element.
    """
    sql_string = read_sql_file(sql_path)
    return [parse_table_sql(table_sql) for table_sql in split_table_definition(sql_string)]
|
||
|
|
||
|
|
||
|
def parse_table_sql(table_sql: str):
    """Convert one ``CREATE TABLE`` chunk into an ``SQLTable``.

    The name comes from ``get_table_name``. The body comes from
    ``get_table_content``. Both are then passed to ``process_table_content``.
    """
    name = get_table_name(table_sql)
    body = get_table_content(table_sql)
    return process_table_content(body, name)
|
||
|
|
||
|
|
||
|
# Leading keywords of CREATE TABLE clauses that declare keys, indexes or
# constraints rather than a column (MySQL create_definition grammar).
_CONSTRAINT_CLAUSE_PATTERN = re.compile(
    r"(?:CONSTRAINT|PRIMARY\s+KEY|UNIQUE|KEY|INDEX|FULLTEXT|SPATIAL|FOREIGN\s+KEY|CHECK)\b",
    re.IGNORECASE,
)
_PRIMARY_KEY_CLAUSE_PATTERN = re.compile(r"\bPRIMARY\s+KEY\b", re.IGNORECASE)


def _split_top_level_commas(text: str):
    """Split *text* on commas that are outside parentheses and quotes.

    This keeps types such as ``decimal(10,2)`` or ``enum('a','b')``, and quoted
    values containing commas, in one piece.
    """
    parts = []
    depth = 0
    quote = None
    escaped = False
    start = 0
    for index, char in enumerate(text):
        if quote is not None:
            # Inside a quoted literal/identifier: only watch for its end.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "'\"`":
            quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        elif char == "," and depth == 0:
            parts.append(text[start:index])
            start = index + 1
    parts.append(text[start:])
    return parts


def process_table_content(table_content: str, table_name: str):
    """Build an ``SQLTable`` from the body of a ``CREATE TABLE`` statement.

    :param table_content: comma-separated column and key definitions (the
        text between the statement's outer parentheses).
    :param table_name: name to give the resulting table.
    :return: ``SQLTable(table_name, columns, primary_key)``. Here
        ``primary_key`` is whatever ``get_primary_key`` returns for the
        PRIMARY KEY clause, or ``""`` when the body has no such clause.
        Other key, index and constraint clauses are skipped rather than
        being parsed as columns.
    """
    primary_key = ""
    columns = []
    for definition in _split_top_level_commas(table_content):
        definition = definition.strip()
        if not definition:
            continue
        if _CONSTRAINT_CLAUSE_PATTERN.match(definition):
            if _PRIMARY_KEY_CLAUSE_PATTERN.search(definition):
                primary_key = get_primary_key(definition)
            continue
        columns.append(parse_column_definition(definition))
    return SQLTable(table_name, columns, primary_key)
|
||
|
|
||
|
|
||
|
def parse_column_definition(column_definition: str):
    """Turn one column clause (e.g. "`id` int(11) NOT NULL") into an ``SQLColumn``.

    The name, type and nullability are each delegated to their own helper and
    passed positionally to ``SQLColumn.SQLColumn``.
    """
    return SQLColumn.SQLColumn(
        get_column_name(column_definition),
        get_column_type(column_definition),
        get_nullable(column_definition),
    )
|
||
|
|
||
|
|
||
|
def get_nullable(column_definition: str):
    """Return whether the column described by *column_definition* accepts NULL.

    A column is nullable unless its definition contains a ``NOT NULL``
    constraint. The match ignores case and allows any whitespace between the
    two words. (The previous version returned the opposite, True for
    NOT NULL columns, which contradicted the function's name.)
    """
    return re.search(r"\bNOT\s+NULL\b", column_definition, re.IGNORECASE) is None
|
||
|
|
||
|
|
||
|
# Captures the data-type keyword following the (optionally backtick-quoted)
# column name, e.g. "bigint" in "`id` bigint(20) unsigned NOT NULL".
_COLUMN_TYPE_PATTERN = re.compile(r"\s*(?:`[^`]*`|\S+)\s+([A-Za-z]+)")


def get_column_type(column_definition: str):
    """Map the data type of a column clause to an ``SQLColumnType``.

    Only the type position (the token after the column name) is inspected.
    ``bigint`` maps to ``BIGINT``, and ``int``/``integer`` map to ``INT``. A
    display width (``int(11)``, ``bigint(20)``) is optional, so dumps that
    omit it are handled too.

    :return: the matching ``SQLColumn.SQLColumnType`` member, or ``None`` for
        any other type.
    """
    match = _COLUMN_TYPE_PATTERN.match(column_definition)
    if match is None:
        return None
    type_name = match.group(1).lower()
    if type_name == "bigint":
        return SQLColumn.SQLColumnType.BIGINT
    if type_name in ("int", "integer"):
        return SQLColumn.SQLColumnType.INT
    return None
|
||
|
|
||
|
|
||
|
def get_column_name(column_definition: str):
    """Return the first backtick-quoted identifier in *column_definition*.

    Any character other than a backtick may appear inside the quotes, so
    names such as "user id" or "order-no" are supported. The raw-string
    pattern avoids the invalid ``\\w`` escape warning.

    :return: the identifier without backticks, or ``None`` if there is none.
    """
    match = re.search(r"`([^`]+)`", column_definition)
    return match.group(1) if match else None
|
||
|
|
||
|
|
||
|
def get_primary_key(primary_key_sql: str):
    """Return the column named in a single-column PRIMARY KEY clause.

    Matches e.g. "PRIMARY KEY (`id`)" with optional whitespace around the
    parentheses. The pattern is a raw string, which avoids invalid-escape
    warnings.

    :return: the column name, or ``None`` when there is no such clause or
        the key names several columns (composite keys are not supported).
    """
    match = re.search(r"PRIMARY KEY\s*\(\s*`([^`]+)`\s*\)", primary_key_sql)
    return match.group(1) if match else None
|
||
|
|
||
|
|
||
|
def get_table_name(sql_string: str):
    """Extract the table name from a ``CREATE TABLE`` statement.

    :return: the first capture group of ``tableNamePattern``'s first match in
        *sql_string*, or ``None`` when the pattern does not match.
    """
    found = re.search(tableNamePattern, sql_string)
    if not found:
        return None
    return found.group(1)
|
||
|
|
||
|
|
||
|
def get_table_content(sql_string: str):
    """Return the body of a ``CREATE TABLE`` statement, whitespace-normalised.

    The body is the text between the first ``(`` in *sql_string* and the
    parenthesis that closes it. Parentheses inside single-, double- or
    backtick-quoted text are ignored. Table options and any statements that
    follow the definition (e.g. ``INSERT ... VALUES (...)`` rows in a dump)
    are therefore not swallowed. If the parentheses never balance, the last
    ``)`` in the string is used instead. Runs of whitespace are collapsed to
    one space and the result is stripped.

    :raises ValueError: if *sql_string* contains no ``(``.
    """
    open_index = sql_string.index("(")
    close_index = -1
    depth = 0
    quote = None
    escaped = False
    for index in range(open_index, len(sql_string)):
        char = sql_string[index]
        if quote is not None:
            # Inside a quoted literal/identifier: only watch for its end.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
        elif char in "'\"`":
            quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                close_index = index
                break
    if close_index == -1:
        # Unbalanced input: fall back to the last ")".
        close_index = sql_string.rfind(")")
    table_content = sql_string[open_index + 1:close_index]
    return re.sub(r"\s+", " ", table_content).strip()
|