diff --git a/pythonProject/main.py b/pythonProject/main.py index ce55fcf..ec0f157 100644 --- a/pythonProject/main.py +++ b/pythonProject/main.py @@ -10,6 +10,7 @@ def main(): sql_tables = SQLParser.parse_sql_file("schema.sql") inserts = [] for table in sql_tables: + print(table.foreign_keys) for i in range(0, 5): inserts.append(SQLGenerator.generate_random_insert(table)) diff --git a/pythonProject/schema.sql b/pythonProject/schema.sql index 9aa67fb..6199590 100644 --- a/pythonProject/schema.sql +++ b/pythonProject/schema.sql @@ -8,3 +8,20 @@ CREATE TABLE `test` ( `ending_date_time` datetime(6) DEFAULT NULL, PRIMARY KEY (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci + +CREATE TABLE `tasks` ( + `taskid` bigint(20) NOT NULL AUTO_INCREMENT, + `deadline` date DEFAULT NULL, + `eta` int(11) NOT NULL, + `start_date` date DEFAULT NULL, + `task_name` varchar(255) DEFAULT NULL, + `work_time` int(11) NOT NULL, + `parent` bigint(20) DEFAULT NULL, + `taskgroup_id` bigint(20) DEFAULT NULL, + PRIMARY KEY (`taskid`), + KEY `FKamiednrmm0puy94sf1o3q84bp` (`parent`), + KEY `FKk2et8snwh68sf7p6n6ltmk638` (`taskgroup_id`), + CONSTRAINT `FKamiednrmm0puy94sf1o3q84bp` FOREIGN KEY (`parent`) REFERENCES `tasks` (`taskid`), + CONSTRAINT `FKk2et8snwh68sf7p6n6ltmk638` FOREIGN KEY (`taskgroup_id`) REFERENCES `taskgroups` (`taskgroupid`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci + diff --git a/pythonProject/sqlparse/SQLParser.py b/pythonProject/sqlparse/SQLParser.py index 84321b2..0bcc962 100644 --- a/pythonProject/sqlparse/SQLParser.py +++ b/pythonProject/sqlparse/SQLParser.py @@ -40,13 +40,21 @@ def process_table_content(table_content: str, table_name: str): splitted_table_content = table_content.split(",") primary_key = "" columns = [] + foreign_keys = [] for column_definition in splitted_table_content: if column_definition.startswith(" PRIMARY KEY"): primary_key = get_primary_key(column_definition) + elif column_definition.startswith(" KEY"): + pass + elif column_definition.startswith(" CONSTRAINT"): + foreign_key_definition = column_definition[column_definition.index("FOREIGN KEY"):len(column_definition):1] + match = re.search("FOREIGN KEY \(`(\w+)`\)", foreign_key_definition) + if match: + foreign_keys.append(match.group(1)) else: column = parse_column_definition(column_definition) columns.append(column) - table = SQLTable(table_name, columns, primary_key) + table = SQLTable(table_name, columns, primary_key, foreign_keys) return table @@ -78,6 +86,7 @@ def get_column_type(column_definition: str): return SQLColumn.SQLColumnType.DATETIME + def get_column_name(column_definition: str): match = re.search("`(\w+)`", column_definition) if match: diff --git a/pythonProject/sqlparse/SQLTable.py b/pythonProject/sqlparse/SQLTable.py index e6ea0d0..cec5c44 100644 --- a/pythonProject/sqlparse/SQLTable.py +++ b/pythonProject/sqlparse/SQLTable.py @@ -1,9 +1,12 @@ +from typing import List + from sqlparse.SQLColumn import SQLColumn class SQLTable: - def __init__(self, table_name, columns: [SQLColumn], primary_key: str): + def __init__(self, table_name, columns: List[SQLColumn], primary_key: str, foreign_keys: List[str]): self.table_name = table_name self.columns = columns self.primary_key = primary_key + self.foreign_keys = foreign_keys