Parse Foreign Keys
This commit is contained in:
parent
90e08dad2e
commit
ff967a4852
@ -10,6 +10,7 @@ def main():
|
||||
sql_tables = SQLParser.parse_sql_file("schema.sql")
|
||||
inserts = []
|
||||
for table in sql_tables:
|
||||
print(table.foreign_keys)
|
||||
for i in range(0, 5):
|
||||
inserts.append(SQLGenerator.generate_random_insert(table))
|
||||
|
||||
|
@ -8,3 +8,20 @@ CREATE TABLE `test` (
|
||||
`ending_date_time` datetime(6) DEFAULT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
|
||||
|
||||
CREATE TABLE `tasks` (
|
||||
`taskid` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
`deadline` date DEFAULT NULL,
|
||||
`eta` int(11) NOT NULL,
|
||||
`start_date` date DEFAULT NULL,
|
||||
`task_name` varchar(255) DEFAULT NULL,
|
||||
`work_time` int(11) NOT NULL,
|
||||
`parent` bigint(20) DEFAULT NULL,
|
||||
`taskgroup_id` bigint(20) DEFAULT NULL,
|
||||
PRIMARY KEY (`taskid`),
|
||||
KEY `FKamiednrmm0puy94sf1o3q84bp` (`parent`),
|
||||
KEY `FKk2et8snwh68sf7p6n6ltmk638` (`taskgroup_id`),
|
||||
CONSTRAINT `FKamiednrmm0puy94sf1o3q84bp` FOREIGN KEY (`parent`) REFERENCES `tasks` (`taskid`),
|
||||
CONSTRAINT `FKk2et8snwh68sf7p6n6ltmk638` FOREIGN KEY (`taskgroup_id`) REFERENCES `taskgroups` (`taskgroupid`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
|
||||
|
||||
|
@ -40,13 +40,21 @@ def process_table_content(table_content: str, table_name: str):
|
||||
splitted_table_content = table_content.split(",")
|
||||
primary_key = ""
|
||||
columns = []
|
||||
foreign_keys = []
|
||||
for column_definition in splitted_table_content:
|
||||
if column_definition.startswith(" PRIMARY KEY"):
|
||||
primary_key = get_primary_key(column_definition)
|
||||
elif column_definition.startswith(" KEY"):
|
||||
pass
|
||||
elif column_definition.startswith(" CONSTRAINT"):
|
||||
foreign_key_definition = column_definition[column_definition.index("FOREIGN KEY"):len(column_definition):1]
|
||||
match = re.search("FOREIGN KEY \(`(\w+)`\)", foreign_key_definition)
|
||||
if match:
|
||||
foreign_keys.append(match.group(1))
|
||||
else:
|
||||
column = parse_column_definition(column_definition)
|
||||
columns.append(column)
|
||||
table = SQLTable(table_name, columns, primary_key)
|
||||
table = SQLTable(table_name, columns, primary_key, foreign_keys)
|
||||
return table
|
||||
|
||||
|
||||
@ -78,6 +86,7 @@ def get_column_type(column_definition: str):
|
||||
return SQLColumn.SQLColumnType.DATETIME
|
||||
|
||||
|
||||
|
||||
def get_column_name(column_definition: str):
|
||||
match = re.search("`(\w+)`", column_definition)
|
||||
if match:
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import List
|
||||
|
||||
from sqlparse.SQLColumn import SQLColumn
|
||||
|
||||
|
||||
class SQLTable:
|
||||
def __init__(self, table_name, columns: [SQLColumn], primary_key: str):
|
||||
def __init__(self, table_name, columns: List[SQLColumn], primary_key: str, foreign_keys: List[str]):
|
||||
self.table_name = table_name
|
||||
self.columns = columns
|
||||
self.primary_key = primary_key
|
||||
self.foreign_keys = foreign_keys
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user