Parse Foreign Keys

This commit is contained in:
Sebastian Böckelmann 2024-04-17 18:01:54 +02:00
parent 90e08dad2e
commit ff967a4852
4 changed files with 32 additions and 2 deletions

View File

@ -10,6 +10,7 @@ def main():
sql_tables = SQLParser.parse_sql_file("schema.sql")
inserts = []
for table in sql_tables:
print(table.foreign_keys)
for i in range(0, 5):
inserts.append(SQLGenerator.generate_random_insert(table))

View File

@ -8,3 +8,20 @@ CREATE TABLE `test` (
`ending_date_time` datetime(6) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
CREATE TABLE `tasks` (
`taskid` bigint(20) NOT NULL AUTO_INCREMENT,
`deadline` date DEFAULT NULL,
`eta` int(11) NOT NULL,
`start_date` date DEFAULT NULL,
`task_name` varchar(255) DEFAULT NULL,
`work_time` int(11) NOT NULL,
`parent` bigint(20) DEFAULT NULL,
`taskgroup_id` bigint(20) DEFAULT NULL,
PRIMARY KEY (`taskid`),
KEY `FKamiednrmm0puy94sf1o3q84bp` (`parent`),
KEY `FKk2et8snwh68sf7p6n6ltmk638` (`taskgroup_id`),
CONSTRAINT `FKamiednrmm0puy94sf1o3q84bp` FOREIGN KEY (`parent`) REFERENCES `tasks` (`taskid`),
CONSTRAINT `FKk2et8snwh68sf7p6n6ltmk638` FOREIGN KEY (`taskgroup_id`) REFERENCES `taskgroups` (`taskgroupid`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci

View File

@ -40,13 +40,21 @@ def process_table_content(table_content: str, table_name: str):
splitted_table_content = table_content.split(",")
primary_key = ""
columns = []
foreign_keys = []
for column_definition in splitted_table_content:
if column_definition.startswith(" PRIMARY KEY"):
primary_key = get_primary_key(column_definition)
elif column_definition.startswith(" KEY"):
pass
elif column_definition.startswith(" CONSTRAINT"):
foreign_key_definition = column_definition[column_definition.index("FOREIGN KEY"):len(column_definition):1]
match = re.search("FOREIGN KEY \(`(\w+)`\)", foreign_key_definition)
if match:
foreign_keys.append(match.group(1))
else:
column = parse_column_definition(column_definition)
columns.append(column)
table = SQLTable(table_name, columns, primary_key)
table = SQLTable(table_name, columns, primary_key, foreign_keys)
return table
@ -78,6 +86,7 @@ def get_column_type(column_definition: str):
return SQLColumn.SQLColumnType.DATETIME
def get_column_name(column_definition: str):
match = re.search("`(\w+)`", column_definition)
if match:

View File

@ -1,9 +1,12 @@
from typing import List
from sqlparse.SQLColumn import SQLColumn
class SQLTable:
def __init__(self, table_name, columns: [SQLColumn], primary_key: str):
def __init__(self, table_name, columns: List[SQLColumn], primary_key: str, foreign_keys: List[str]):
self.table_name = table_name
self.columns = columns
self.primary_key = primary_key
self.foreign_keys = foreign_keys