"""Minimal parser for MySQL-style (backtick-quoted) ``CREATE TABLE`` dumps."""

import re
from typing import Iterator, List, Optional, Tuple

from sqlparse import SQLColumn
from sqlparse.SQLTable import SQLTable

# Matches the table name in "CREATE TABLE `name`" (optionally with
# "IF NOT EXISTS"); group 1 is the bare name. Public name kept for callers.
tableNamePattern = r"CREATE TABLE (?:IF NOT EXISTS )?`(\w+)`"

_WHITESPACE_RUN = re.compile(r"\s+")
_BACKTICKED_NAME = re.compile(r"`(\w+)`")
_PRIMARY_KEY = re.compile(r"PRIMARY KEY\s*\(\s*`(\w+)`\s*\)", re.IGNORECASE)
# "int", "int(11)", "bigint", "bigint(20)", ... in any letter case.
_INTEGER_TYPE = re.compile(r"(big)?int(?:\(\d+\))?", re.IGNORECASE)
_QUOTE_CHARS = ("'", '"', "`")
# First words of table-body entries that declare an index or constraint
# rather than a column (these words are reserved, so an unquoted column
# cannot start with them).
_INDEX_KEYWORDS = frozenset(
    {"KEY", "INDEX", "UNIQUE", "FULLTEXT", "SPATIAL", "CONSTRAINT", "FOREIGN", "CHECK"}
)


def read_sql_file(sql_path: str, encoding: str = "utf-8") -> str:
    """Return the full text of the SQL file at *sql_path*.

    The file is decoded with *encoding* (UTF-8 by default) instead of the
    platform locale encoding, so the result does not vary between machines.
    """
    with open(sql_path, "r", encoding=encoding) as file:
        return file.read()


def split_table_definition(sql_string: str) -> List[str]:
    """Split *sql_string* into one chunk per ``CREATE TABLE`` statement.

    Text before the first ``CREATE TABLE`` is discarded. Each chunk starts
    with ``"CREATE TABLE"`` and runs up to (not including) the next one, so
    it may carry trailing statements; ``get_table_content`` only reads the
    first parenthesised group of a chunk.
    """
    _, *table_chunks = sql_string.split("CREATE TABLE")
    return ["CREATE TABLE" + chunk for chunk in table_chunks]


def parse_sql_file(sql_path: str) -> List[SQLTable]:
    """Parse every ``CREATE TABLE`` statement in the file at *sql_path*.

    Returns one ``SQLTable`` per statement, in file order.
    """
    sql_string = read_sql_file(sql_path)
    return [parse_table_sql(sql_table) for sql_table in split_table_definition(sql_string)]


def parse_table_sql(table_sql: str) -> SQLTable:
    """Build an ``SQLTable`` from a single ``CREATE TABLE`` statement."""
    table_name = get_table_name(table_sql)
    table_content = get_table_content(table_sql)
    return process_table_content(table_content, table_name)


def process_table_content(table_content: str, table_name: str) -> SQLTable:
    """Build an ``SQLTable`` from the text between a table's outer parentheses.

    *table_content* is split only on top-level commas, so commas inside
    ``decimal(10,2)``, ``enum('a','b')`` or key column lists do not break an
    entry apart. A ``PRIMARY KEY`` entry (or ``CONSTRAINT ... PRIMARY KEY``)
    sets the primary key. Other index/constraint entries are skipped. Every
    remaining non-empty entry is parsed as a column.

    The primary key stays ``""`` when no such entry exists, and is ``None``
    when an entry exists but ``get_primary_key`` cannot read it.
    """
    primary_key = ""
    columns = []
    for definition in _split_top_level(table_content):
        definition = definition.strip()
        if not definition:
            continue
        upper_definition = definition.upper()
        first_word = upper_definition.split(None, 1)[0]
        if first_word == "PRIMARY" or (
            first_word == "CONSTRAINT" and "PRIMARY KEY" in upper_definition
        ):
            primary_key = get_primary_key(definition)
        elif first_word in _INDEX_KEYWORDS:
            continue
        else:
            columns.append(parse_column_definition(definition))
    return SQLTable(table_name, columns, primary_key)


def parse_column_definition(column_definition: str):
    """Build an ``SQLColumn.SQLColumn`` from one column entry.

    Example entry: ``"`id` bigint(20) NOT NULL"``.
    """
    column_name = get_column_name(column_definition)
    column_type = get_column_type(column_definition)
    nullable = get_nullable(column_definition)
    return SQLColumn.SQLColumn(column_name, column_type, nullable)


def get_nullable(column_definition: str) -> bool:
    """Return True when *column_definition* contains ``NOT NULL`` (any case).

    NOTE(review): this is True for columns that are *not* nullable, which is
    the opposite of what the name suggests. Confirm how ``SQLColumn``
    interprets this flag before changing it.
    """
    return "NOT NULL" in column_definition.upper()


def get_column_type(column_definition: str):
    """Map the declared type of a column entry to an ``SQLColumnType``.

    Only the type token that directly follows the column name is inspected,
    so words inside ``DEFAULT``/``COMMENT`` text cannot be mistaken for the
    type. ``bigint`` maps to ``BIGINT`` and ``int`` to ``INT``, each with or
    without a display width such as ``(20)``. Any other type returns None.
    """
    type_token = _column_type_token(column_definition)
    if type_token is None:
        return None
    match = _INTEGER_TYPE.fullmatch(type_token)
    if match is None:
        return None
    if match.group(1):
        return SQLColumn.SQLColumnType.BIGINT
    return SQLColumn.SQLColumnType.INT


def _column_type_token(column_definition: str) -> Optional[str]:
    """Return the first whitespace-separated token after the column name, or None."""
    text = column_definition.strip()
    if text.startswith("`"):
        closing = text.find("`", 1)
        if closing == -1:
            return None
        rest = text[closing + 1:]
    else:
        _, _, rest = text.partition(" ")
    tokens = rest.split()
    return tokens[0] if tokens else None


def get_column_name(column_definition: str) -> Optional[str]:
    """Return the first backtick-quoted ``\\w+`` identifier, or None if absent."""
    match = _BACKTICKED_NAME.search(column_definition)
    return match.group(1) if match else None


def get_primary_key(primary_key_sql: str) -> Optional[str]:
    """Return the column named in ``PRIMARY KEY (`col`)``, or None.

    Only single-column keys are recognised. A composite key such as
    ``PRIMARY KEY (`a`,`b`)`` yields None.
    """
    match = _PRIMARY_KEY.search(primary_key_sql)
    return match.group(1) if match else None


def get_table_name(sql_string: str) -> Optional[str]:
    """Return the table name matched by ``tableNamePattern``, or None."""
    match = re.search(tableNamePattern, sql_string)
    return match.group(1) if match else None


def get_table_content(sql_string: str) -> str:
    """Return the text inside a statement's first parenthesised group.

    Whitespace runs in the result are collapsed to single spaces and the
    result is stripped.

    The closing parenthesis is the one that matches the first ``(``
    (parentheses inside quoted literals are ignored). It is not simply the
    last ``)`` in *sql_string*, because a chunk may end with statements such
    as ``INSERT ... VALUES (...)``. If the group is never closed, the rest of
    the string is used.

    Raises:
        ValueError: if *sql_string* contains no ``(``.
    """
    open_index = sql_string.find("(")
    if open_index == -1:
        raise ValueError("no '(' found in CREATE TABLE statement")
    close_index = len(sql_string)
    for index, char, depth in _scan_unquoted(sql_string, open_index):
        if char == ")" and depth == 0:
            close_index = index
            break
    table_content = sql_string[open_index + 1:close_index]
    return _WHITESPACE_RUN.sub(" ", table_content).strip()


def _split_top_level(text: str) -> List[str]:
    """Split *text* on commas outside parentheses and quoted literals."""
    pieces = []
    start = 0
    for index, char, depth in _scan_unquoted(text):
        if char == "," and depth == 0:
            pieces.append(text[start:index])
            start = index + 1
    pieces.append(text[start:])
    return pieces


def _scan_unquoted(text: str, start: int = 0) -> Iterator[Tuple[int, str, int]]:
    """Yield ``(index, char, depth)`` for characters of *text* outside quotes.

    Scanning begins at *start*. *depth* is the parenthesis nesting level of
    the surrounding text: an opening ``(`` and its matching ``)`` are
    reported at the same depth. A stray ``)`` never drives depth below zero.

    Characters of single-quoted, double-quoted and backtick-quoted literals
    (quotes included) are not yielded. Inside single/double-quoted strings a
    backslash escapes the next character.
    """
    depth = 0
    quote = None
    escaped = False
    for index in range(start, len(text)):
        char = text[index]
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\" and quote != "`":
                escaped = True
            elif char == quote:
                quote = None
            continue
        if char in _QUOTE_CHARS:
            quote = char
            continue
        if char == ")":
            depth = max(depth - 1, 0)
        yield index, char, depth
        if char == "(":
            depth += 1