Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/datajoint/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ def analyze_columns(schema: Schema) -> dict:
for (table_name,) in tables:
# Get all columns for this table
columns_query = """
SELECT COLUMN_NAME, COLUMN_TYPE, DATA_TYPE, COLUMN_COMMENT
SELECT COLUMN_NAME, COLUMN_TYPE, DATA_TYPE, COLUMN_COMMENT, IS_NULLABLE
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
"""
columns = connection.query(columns_query, args=(schema.database, table_name)).fetchall()

for column_name, column_type, data_type, comment in columns:
for column_name, column_type, data_type, comment, is_nullable in columns:
comment = comment or ""

# Check if column already has a type label (starts with :type:)
Expand All @@ -167,6 +167,7 @@ def analyze_columns(schema: Schema) -> dict:
"column": column_name,
"native_type": column_type,
"comment": comment,
"is_nullable": is_nullable == "YES",
}

if is_external:
Expand Down Expand Up @@ -270,9 +271,10 @@ def migrate_columns(
new_comment_escaped = new_comment.replace("\\", "\\\\").replace("'", "\\'")

# Generate ALTER TABLE statement
not_null = "" if col["is_nullable"] else " NOT NULL"
sql = (
f"ALTER TABLE `{db_name}`.`{table_name}` "
f"MODIFY COLUMN `{col['column']}` {col['native_type']} "
f"MODIFY COLUMN `{col['column']}` {col['native_type']}{not_null} "
f"COMMENT '{new_comment_escaped}'"
)
result["sql_statements"].append(sql)
Expand Down Expand Up @@ -365,7 +367,7 @@ def analyze_blob_columns(schema: Schema) -> list[dict]:
for (table_name,) in tables:
# Get column information for each table
columns_query = """
SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT
SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_COMMENT, IS_NULLABLE
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
Expand All @@ -374,7 +376,7 @@ def analyze_blob_columns(schema: Schema) -> list[dict]:

columns = connection.query(columns_query, args=(schema.database, table_name)).fetchall()

for column_name, column_type, comment in columns:
for column_name, column_type, comment, is_nullable in columns:
# Check if comment already has a codec type (starts with :type:)
has_codec = comment and comment.startswith(":")

Expand All @@ -385,6 +387,7 @@ def analyze_blob_columns(schema: Schema) -> list[dict]:
"column_type": column_type,
"current_comment": comment or "",
"needs_migration": not has_codec,
"is_nullable": is_nullable == "YES",
}
)

Expand Down Expand Up @@ -447,9 +450,10 @@ def generate_migration_sql(
db_name, table_name = col["table_name"].split(".")

# Generate ALTER TABLE statement
not_null = "" if col.get("is_nullable", True) else " NOT NULL"
sql = (
f"ALTER TABLE `{db_name}`.`{table_name}` "
f"MODIFY COLUMN `{col['column_name']}` {col['column_type']} "
f"MODIFY COLUMN `{col['column_name']}` {col['column_type']}{not_null} "
f"COMMENT '{new_comment_escaped}'"
)
sql_statements.append(sql)
Expand Down
Loading