RandomDBGenerator/pythonProject/sqlparse/SQLParser.py
Sebastian Böckelmann ff967a4852 Parse Foreign Keys
2024-04-17 18:01:54 +02:00

117 lines
3.9 KiB
Python

import re
from sqlparse import SQLColumn
from sqlparse.SQLTable import SQLTable
# Regular expression whose single capture group is the backtick-quoted
# identifier that directly follows the literal text "CREATE TABLE ".
tableNamePattern = r"CREATE TABLE `(\w+)`"
def read_sql_file(sql_path: str):
    """Return the full text of the SQL file located at ``sql_path``.

    The file is decoded explicitly as UTF-8 so that the result does not
    depend on the locale's default encoding of the machine running the
    parser.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    with open(sql_path, 'r', encoding="utf-8") as file:
        return file.read()
def split_table_definition(sql_string: str):
    """Cut ``sql_string`` into one string per ``CREATE TABLE`` statement.

    Each returned string starts with the literal ``CREATE TABLE`` and runs
    up to (but not including) the next occurrence of that literal, or to the
    end of the input. An input without the literal yields an empty list.
    """
    marker = "CREATE TABLE"
    # The first piece is whatever precedes the first marker; it is dropped.
    _, *statement_tails = sql_string.split(marker)
    return [marker + tail for tail in statement_tails]
def parse_sql_file(sql_path: str):
    """Parse every ``CREATE TABLE`` statement found in the file ``sql_path``.

    The file is read with ``read_sql_file``, cut into statements with
    ``split_table_definition`` and each statement is handed to
    ``parse_table_sql``. The results are returned as a list, in file order.
    """
    table_statements = split_table_definition(read_sql_file(sql_path))
    return [parse_table_sql(statement) for statement in table_statements]
def parse_table_sql(table_sql: str):
    """Parse one ``CREATE TABLE`` statement.

    Extracts the table name and the text between the outer parentheses of
    ``table_sql`` and returns whatever ``process_table_content`` builds from
    them.
    """
    name = get_table_name(table_sql)
    body = get_table_content(table_sql)
    return process_table_content(body, name)
# Table-level clauses that describe indexes or checks rather than columns.
_NON_COLUMN_PREFIXES = ("KEY", "INDEX", "UNIQUE", "FULLTEXT", "SPATIAL", "CHECK")


def _split_top_level_commas(text: str):
    """Split ``text`` on commas that are outside parentheses and quotes."""
    pieces = []
    depth = 0
    open_quote = None
    skip_next = False
    piece_start = 0
    for position, char in enumerate(text):
        if skip_next:
            # This character was escaped by a backslash inside a literal.
            skip_next = False
        elif open_quote is not None:
            if char == "\\" and open_quote != "`":
                skip_next = True
            elif char == open_quote:
                open_quote = None
        elif char in "'\"`":
            open_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            # Never go negative so one stray ")" cannot disable splitting.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            pieces.append(text[piece_start:position])
            piece_start = position + 1
    pieces.append(text[piece_start:])
    return pieces


def process_table_content(table_content: str, table_name: str):
    """Build an ``SQLTable`` from the body of a ``CREATE TABLE`` statement.

    Args:
        table_content: the text between the outer parentheses of the
            statement (column definitions, keys and constraints separated
            by commas).
        table_name: passed through unchanged as the table's name.

    Returns:
        ``SQLTable(table_name, columns, primary_key, foreign_keys)`` where
        ``columns`` holds the result of ``parse_column_definition`` for each
        column entry, ``primary_key`` is the result of ``get_primary_key``
        for the ``PRIMARY KEY`` entry (``""`` when the table has none) and
        ``foreign_keys`` lists the local column names of every
        ``FOREIGN KEY`` entry.

    Entries are split on top-level commas only, so commas inside
    ``decimal(10,2)``, ``enum('a','b')``, composite key lists or quoted
    text no longer break an entry apart. Index and check clauses
    (``KEY``, ``UNIQUE KEY``, ``INDEX``, ``CHECK`` ...) are ignored instead
    of being mistaken for columns, and constraints that are not foreign
    keys no longer raise ``ValueError``.
    """
    primary_key = ""
    columns = []
    foreign_keys = []
    for raw_definition in _split_top_level_commas(table_content):
        definition = raw_definition.strip()
        if not definition:
            # Empty body or stray trailing comma: nothing to parse.
            continue
        # Drop an optional "CONSTRAINT `name`" prefix so that named and
        # unnamed clauses are dispatched the same way.
        clause = re.sub(r"^CONSTRAINT\s+(?:`\w+`\s+)?", "", definition)
        if clause.startswith("PRIMARY KEY"):
            primary_key = get_primary_key(clause)
        elif clause.startswith("FOREIGN KEY"):
            # The first parenthesised list holds the local column(s); the
            # list after REFERENCES belongs to the other table.
            column_list = re.search(r"FOREIGN KEY[^(]*\(([^)]*)\)", clause)
            if column_list:
                foreign_keys.extend(re.findall(r"`(\w+)`", column_list.group(1)))
        elif clause.startswith(_NON_COLUMN_PREFIXES):
            continue
        else:
            columns.append(parse_column_definition(clause))
    return SQLTable(table_name, columns, primary_key, foreign_keys)
def parse_column_definition(column_definition: str):
    """Turn the text of one column definition into an ``SQLColumn``.

    The column's name, type and nullable flag are extracted by
    ``get_column_name``, ``get_column_type`` and ``get_nullable`` and passed
    to the ``SQLColumn`` constructor in that order.
    """
    return SQLColumn.SQLColumn(
        get_column_name(column_definition),
        get_column_type(column_definition),
        get_nullable(column_definition),
    )
def get_nullable(column_definition: str):
    """Return ``True`` when the column may hold ``NULL`` values.

    A column is treated as nullable unless its definition contains the
    text ``NOT NULL``. The previous implementation returned the opposite
    value (``True`` for ``NOT NULL`` columns), contradicting the function
    name.
    """
    return "NOT NULL" not in column_definition
def get_column_type(column_definition: str):
    """Return the ``SQLColumnType`` member for a column definition.

    The definition is split on single spaces and the first token whose base
    type name (the part before any ``(``, compared case-insensitively) is
    one of ``bigint``, ``int``, ``varchar``, ``date``, ``time`` or
    ``datetime`` decides the result. Length and precision are ignored, so
    ``varchar(100)``, ``int`` and ``datetime`` are recognised just like
    the previously hard-coded ``varchar(255)``, ``int(11)`` and
    ``datetime(6)``.

    Returns:
        The matching ``SQLColumn.SQLColumnType`` member, or ``None`` when no
        token names a supported type.
    """
    for token in column_definition.split(" "):
        base_type = token.partition("(")[0].lower()
        if base_type == "bigint":
            return SQLColumn.SQLColumnType.BIGINT
        if base_type == "int":
            return SQLColumn.SQLColumnType.INT
        if base_type == "varchar":
            return SQLColumn.SQLColumnType.VARCHAR
        if base_type == "date":
            return SQLColumn.SQLColumnType.DATE
        if base_type == "time":
            return SQLColumn.SQLColumnType.TIME
        if base_type == "datetime":
            return SQLColumn.SQLColumnType.DATETIME
    return None
def get_column_name(column_definition: str):
    """Return the first backtick-quoted identifier in ``column_definition``.

    Returns:
        The identifier without its backticks, or ``None`` when the
        definition contains no backtick-quoted word.
    """
    # Raw string: "\w" is not a valid escape in a normal string literal.
    match = re.search(r"`(\w+)`", column_definition)
    return match.group(1) if match else None
def get_primary_key(primary_key_sql: str):
    """Return the column named by a single-column ``PRIMARY KEY`` clause.

    Only the exact form ``PRIMARY KEY (`column`)`` is recognised.

    Returns:
        The column name without backticks, or ``None`` when the clause does
        not have that form (for example a composite key).
    """
    # Raw string: "\(" and "\w" are not valid escapes in a normal literal.
    match = re.search(r"PRIMARY KEY \(`(\w+)`\)", primary_key_sql)
    return match.group(1) if match else None
def get_table_name(sql_string: str):
    """Return the table name of the first ``CREATE TABLE`` in ``sql_string``.

    The name is the capture group of ``tableNamePattern``.

    Returns:
        The table name without backticks, or ``None`` when ``sql_string``
        does not match the pattern. The no-match case is returned explicitly
        instead of relying on a local that is only bound on a match.
    """
    match = re.search(tableNamePattern, sql_string)
    return match.group(1) if match else None
def _find_closing_parenthesis(text: str, opening_index: int):
    """Return the index of the ``)`` that closes ``text[opening_index]``.

    Parentheses inside quoted text (single quotes, double quotes or
    backticks) are ignored. Returns ``-1`` when no matching ``)`` exists.
    """
    depth = 0
    open_quote = None
    skip_next = False
    for position in range(opening_index, len(text)):
        char = text[position]
        if skip_next:
            # This character was escaped by a backslash inside a literal.
            skip_next = False
        elif open_quote is not None:
            if char == "\\" and open_quote != "`":
                skip_next = True
            elif char == open_quote:
                open_quote = None
        elif char in "'\"`":
            open_quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return position
    return -1


def get_table_content(sql_string: str):
    """Return the body of a ``CREATE TABLE`` statement as a single line.

    The body is the text between the first ``(`` in ``sql_string`` and the
    parenthesis that closes it. Every run of whitespace is collapsed to one
    space and the result is stripped.

    Matching the balanced closing parenthesis (rather than taking the last
    ``)`` of the whole string) keeps text that follows the statement, such
    as table options or later ``INSERT`` statements, out of the body. If no
    balanced match is found the last ``)`` is used as a fallback.

    Raises:
        ValueError: if ``sql_string`` contains no ``(``.
    """
    opening_index = sql_string.index("(")
    closing_index = _find_closing_parenthesis(sql_string, opening_index)
    if closing_index == -1:
        closing_index = sql_string.rfind(")")
    body = sql_string[opening_index + 1:closing_index]
    return re.sub(r"\s+", " ", body).strip()