diff --git a/docs/cli-reference.rst b/docs/cli-reference.rst index 6231dbf2..ff4800fb 100644 --- a/docs/cli-reference.rst +++ b/docs/cli-reference.rst @@ -474,6 +474,8 @@ See :ref:`cli_transform_table`. Add a foreign key constraint from a column to another table with another column --drop-foreign-key TEXT Drop foreign key constraint for this column + --update-incoming-fks Update foreign keys in other tables that + reference renamed columns --sql Output SQL without executing it --load-extension TEXT Path to SQLite extension, with optional :entrypoint diff --git a/docs/cli.rst b/docs/cli.rst index a6081609..b93868c6 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -2113,6 +2113,9 @@ Every option for this table (with the exception of ``--pk-none``) can be specifi ``--add-foreign-key column other_table other_column`` Add a foreign key constraint to ``column`` pointing to ``other_table.other_column``. +``--update-incoming-fks`` + When renaming columns, automatically update foreign key constraints in other tables that reference the renamed columns. For example, if ``books.author_id`` references ``authors.id`` and you rename ``authors.id`` to ``authors.author_pk``, this flag will also update the foreign key in ``books`` to reference the new column name. + If you want to see the SQL that will be executed to make the change without actually executing it, add the ``--sql`` flag. For example: .. code-block:: bash diff --git a/docs/python-api.rst b/docs/python-api.rst index 267591ac..2fe2aeb5 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -1631,6 +1631,31 @@ This example drops two foreign keys - the one from ``places.country`` to ``count drop_foreign_keys=("country", "continent") ) +.. _python_api_transform_update_incoming_fks: + +Updating foreign keys in other tables +------------------------------------- + +When renaming columns that are referenced by foreign keys in other tables, you can use the ``update_incoming_fks=True`` parameter to automatically update those foreign key constraints. + +For example, if you have a ``books`` table with a foreign key from ``books.author_id`` to ``authors.id``, and you want to rename ``authors.id`` to ``authors.author_pk``: + +.. code-block:: python + + db["authors"].transform( + rename={"id": "author_pk"}, + update_incoming_fks=True, + ) + +This will rename the column in the ``authors`` table and also update the foreign key constraint in the ``books`` table to reference ``authors.author_pk`` instead of ``authors.id``. + +Without ``update_incoming_fks=True``, this operation would fail with a foreign key mismatch error (if foreign key enforcement is enabled) because the ``books`` table would still reference the old column name. + +This parameter also correctly handles: + +- Multiple tables referencing the renamed column +- Self-referential foreign keys (e.g., an ``employees.manager_id`` column referencing ``employees.id``) + .. 
_python_api_transform_sql: Custom transformations with .transform_sql() diff --git a/pyproject.toml b/pyproject.toml index a50a3a8e..55ccda9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ [dependency-groups] dev = [ - "black>=24.1.1", + "black>=26.1.0", "cogapp", "hypothesis", "pytest", diff --git a/sqlite_utils/cli.py b/sqlite_utils/cli.py index 9b9ee20e..38d5c7f7 100644 --- a/sqlite_utils/cli.py +++ b/sqlite_utils/cli.py @@ -42,7 +42,6 @@ TypeTracker, ) - CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) @@ -2545,6 +2544,11 @@ def schema( multiple=True, help="Drop foreign key constraint for this column", ) +@click.option( + "--update-incoming-fks", + is_flag=True, + help="Update foreign keys in other tables that reference renamed columns", +) @click.option("--sql", is_flag=True, help="Output SQL without executing it") @load_extension_option def transform( @@ -2562,6 +2566,7 @@ def transform( default_none, add_foreign_keys, drop_foreign_keys, + update_incoming_fks, sql, load_extension, ): @@ -2615,6 +2620,8 @@ def transform( kwargs["drop_foreign_keys"] = drop_foreign_keys if add_foreign_keys: kwargs["add_foreign_keys"] = add_foreign_keys + if update_incoming_fks: + kwargs["update_incoming_fks"] = True if sql: for line in db.table(table).transform_sql(**kwargs): @@ -2911,8 +2918,7 @@ def _analyze(db, tables, columns, save, common_limit=10, no_most=False, no_least ) details = ( ( - textwrap.dedent( - """ + textwrap.dedent(""" {table}.{column}: ({i}/{total}) Total rows: {total_rows} @@ -2920,8 +2926,7 @@ def _analyze(db, tables, columns, save, common_limit=10, no_most=False, no_least Blank rows: {num_blank} Distinct values: {num_distinct}{most_common_rendered}{least_common_rendered} - """ - ) + """) .strip() .format( i=i + 1, @@ -2968,8 +2973,7 @@ def uninstall(packages, yes): def _generate_convert_help(): - help = textwrap.dedent( - """ + help = textwrap.dedent(""" Convert columns using Python code you supply. For example: \b @@ -2982,8 +2986,7 @@ def _generate_convert_help(): Use "-" for CODE to read Python code from standard input. 
The following common operations are available as recipe functions: - """ - ).strip() + """).strip() recipe_names = [ n for n in dir(recipes) @@ -2997,15 +3000,13 @@ def _generate_convert_help(): name, str(inspect.signature(fn)), textwrap.dedent(fn.__doc__.rstrip()) ) help += "\n\n" - help += textwrap.dedent( - """ + help += textwrap.dedent(""" You can use these recipes like so: \b sqlite-utils convert my.db mytable mycolumn \\ 'r.jsonsplit(value, delimiter=":")' - """ - ).strip() + """).strip() return help diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index aacdc893..77a6f809 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -974,15 +974,17 @@ def sort_key(p): column_items.insert(0, (hash_id, str)) pk = hash_id # Soundness check foreign_keys point to existing tables - for fk in foreign_keys: - if fk.other_table == name and columns.get(fk.other_column): - continue - if fk.other_column != "rowid" and not any( - c for c in self[fk.other_table].columns if c.name == fk.other_column - ): - raise AlterError( - "No such column: {}.{}".format(fk.other_table, fk.other_column) - ) + # (can be skipped for internal operations like update_incoming_fks) + if not getattr(self, "_skip_fk_validation", False): + for fk in foreign_keys: + if fk.other_table == name and columns.get(fk.other_column): + continue + if fk.other_column != "rowid" and not any( + c for c in self[fk.other_table].columns if c.name == fk.other_column + ): + raise AlterError( + "No such column: {}.{}".format(fk.other_table, fk.other_column) + ) column_defs = [] # ensure pk is a tuple @@ -1850,6 +1852,40 @@ def duplicate(self, new_name: str) -> "Table": self.db.execute(sql) return self.db.table(new_name) + def _get_incoming_fks_needing_update(self, rename: dict) -> list: + """ + Find all tables with FK constraints pointing to columns being renamed. + + Returns a list of (table_name, new_fks) tuples where new_fks is the + updated list of foreign keys for that table. 
+ + :param rename: Dictionary mapping old column names to new column names + """ + tables_needing_update = [] + + for other_table_name in self.db.table_names(): + if other_table_name == self.name: + continue + + other_table = self.db.table(other_table_name) + other_fks = other_table.foreign_keys + + # Check if any FK references a column being renamed + needs_update = False + new_fks = [] + for fk in other_fks: + if fk.other_table == self.name and fk.other_column in rename: + # This FK needs updating + needs_update = True + new_fks.append((fk.column, fk.other_table, rename[fk.other_column])) + else: + new_fks.append((fk.column, fk.other_table, fk.other_column)) + + if needs_update: + tables_needing_update.append((other_table_name, new_fks)) + + return tables_needing_update + def transform( self, *, @@ -1864,6 +1900,7 @@ def transform( foreign_keys: Optional[ForeignKeysType] = None, column_order: Optional[List[str]] = None, keep_table: Optional[str] = None, + update_incoming_fks: bool = False, ) -> "Table": """ Apply an advanced alter table, including operations that are not supported by @@ -1884,21 +1921,50 @@ def transform( to use when creating the table :param keep_table: If specified, the existing table will be renamed to this and will not be dropped + :param update_incoming_fks: If True, automatically update foreign key constraints in other + tables that reference columns being renamed in this table """ assert self.exists(), "Cannot transform a table that doesn't exist yet" - sqls = self.transform_sql( - types=types, - rename=rename, - drop=drop, - pk=pk, - not_null=not_null, - defaults=defaults, - drop_foreign_keys=drop_foreign_keys, - add_foreign_keys=add_foreign_keys, - foreign_keys=foreign_keys, - column_order=column_order, - keep_table=keep_table, - ) + + # Collect SQL for updating incoming FKs if needed + incoming_fk_sqls: List[str] = [] + if update_incoming_fks and rename: + tables_needing_update = self._get_incoming_fks_needing_update(rename) + for other_table_name, new_fks in tables_needing_update: + other_table = self.db.table(other_table_name) + # Generate transform SQL for the other table with updated FKs + # Skip FK validation since the new column doesn't exist yet + try: + setattr(self.db, "_skip_fk_validation", True) + incoming_fk_sqls.extend( + other_table.transform_sql(foreign_keys=new_fks) + ) + finally: + setattr(self.db, "_skip_fk_validation", False) + + # Skip FK validation for main transform if update_incoming_fks is True + # because self-referential FKs will reference the new column name + # that only exists after the transform completes + if update_incoming_fks and rename: + setattr(self.db, "_skip_fk_validation", True) + try: + sqls = self.transform_sql( + types=types, + rename=rename, + drop=drop, + pk=pk, + not_null=not_null, + defaults=defaults, + drop_foreign_keys=drop_foreign_keys, + add_foreign_keys=add_foreign_keys, + foreign_keys=foreign_keys, + column_order=column_order, + keep_table=keep_table, + ) + finally: + if update_incoming_fks and rename: + setattr(self.db, "_skip_fk_validation", False) + pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[ 0 ] @@ -1906,8 +1972,12 @@ def transform( if pragma_foreign_keys_was_on: self.db.execute("PRAGMA foreign_keys=0;") with self.db.conn: + # First: transform the main table (so renamed columns exist) for sql in sqls: self.db.execute(sql) + # Then: update incoming FKs in other tables + for sql in incoming_fk_sqls: + self.db.execute(sql) # Run the foreign_key_check before we commit if 
pragma_foreign_keys_was_on: self.db.execute("PRAGMA foreign_key_check;") @@ -1972,6 +2042,9 @@ def transform_sql( for table, column, other_table, other_column in self.foreign_keys: # Copy over old foreign keys, unless we are dropping them if (drop_foreign_keys is None) or (column not in drop_foreign_keys): + # For self-referential FKs, also update the referenced column if renamed + if other_table == self.name: + other_column = rename.get(other_column) or other_column create_table_foreign_keys.append( ForeignKey( table, @@ -2275,12 +2348,10 @@ def create_index( "{}_{}".format(index_name, suffix) if suffix else index_name ) sql = ( - textwrap.dedent( - """ + textwrap.dedent(""" CREATE {unique}INDEX {if_not_exists}{index_name} ON {table_name} ({columns}); - """ - ) + """) .strip() .format( index_name=quote_identifier(created_index_name), @@ -2475,8 +2546,7 @@ def enable_counts(self) -> None: See :ref:`python_api_cached_table_counts` for details. """ sql = ( - textwrap.dedent( - """ + textwrap.dedent(""" {create_counts_table} CREATE TRIGGER IF NOT EXISTS {trigger_insert} AFTER INSERT ON {table} BEGIN @@ -2501,8 +2571,7 @@ def enable_counts(self) -> None: ); END; INSERT OR REPLACE INTO _counts VALUES ({table_quoted}, (select count(*) from {table})); - """ - ) + """) .strip() .format( create_counts_table=_COUNTS_TABLE_CREATE_SQL.format( @@ -2554,14 +2623,12 @@ def enable_fts( :param replace: Should any existing FTS index for this table be replaced by the new one? """ create_fts_sql = ( - textwrap.dedent( - """ + textwrap.dedent(""" CREATE VIRTUAL TABLE {table_fts} USING {fts_version} ( {columns},{tokenize} content={table} ) - """ - ) + """) .strip() .format( table=quote_identifier(self.name), @@ -2599,8 +2666,7 @@ def enable_fts( table = quote_identifier(self.name) table_fts = quote_identifier(self.name + "_fts") triggers = ( - textwrap.dedent( - """ + textwrap.dedent(""" CREATE TRIGGER {table_ai} AFTER INSERT ON {table} BEGIN INSERT INTO {table_fts} (rowid, {columns}) VALUES (new.rowid, {new_cols}); END; @@ -2611,8 +2677,7 @@ def enable_fts( INSERT INTO {table_fts} ({table_fts}, rowid, {columns}) VALUES('delete', old.rowid, {old_cols}); INSERT INTO {table_fts} (rowid, {columns}) VALUES (new.rowid, {new_cols}); END; - """ - ) + """) .strip() .format( table=table, @@ -2637,12 +2702,10 @@ def populate_fts(self, columns: Iterable[str]) -> "Table": """ columns_quoted = ", ".join(quote_identifier(c) for c in columns) sql = ( - textwrap.dedent( - """ + textwrap.dedent(""" INSERT INTO {table_fts} (rowid, {columns}) SELECT rowid, {columns} FROM {table}; - """ - ) + """) .strip() .format( table=quote_identifier(self.name), @@ -2659,17 +2722,11 @@ def disable_fts(self) -> "Table": if fts_table: self.db[fts_table].drop() # Now delete the triggers that related to that table - sql = ( - textwrap.dedent( - """ + sql = textwrap.dedent(""" SELECT name FROM sqlite_master WHERE type = 'trigger' AND (sql LIKE '% INSERT INTO [{}]%' OR sql LIKE '% INSERT INTO "{}"%') - """ - ) - .strip() - .format(fts_table, fts_table) - ) + """).strip().format(fts_table, fts_table) trigger_names = [] for row in self.db.execute(sql).fetchall(): trigger_names.append(row[0]) @@ -2695,8 +2752,7 @@ def rebuild_fts(self) -> "Table": def detect_fts(self) -> Optional[str]: "Detect if table has a corresponding FTS virtual table and return it" - sql = textwrap.dedent( - """ + sql = textwrap.dedent(""" SELECT name FROM sqlite_master WHERE rootpage = 0 AND ( @@ -2707,8 +2763,7 @@ def detect_fts(self) -> Optional[str]: AND sql LIKE '%VIRTUAL 
TABLE%USING FTS%' ) ) - """ - ).strip() + """).strip() args = { "like": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name), "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name), @@ -2724,13 +2779,9 @@ def optimize(self) -> "Table": "Run the ``optimize`` operation against the associated full-text search index table." fts_table = self.detect_fts() if fts_table is not None: - self.db.execute( - """ + self.db.execute(""" INSERT INTO {table} ({table}) VALUES ("optimize"); - """.strip().format( - table=quote_identifier(fts_table) - ) - ) + """.strip().format(table=quote_identifier(fts_table))) return self def search_sql( @@ -2768,8 +2819,7 @@ def search_sql( ) fts_table_quoted = quote_identifier(fts_table) virtual_table_using = self.db.table(fts_table).virtual_table_using - sql = textwrap.dedent( - """ + sql = textwrap.dedent(""" with {original} as ( select rowid, @@ -2786,8 +2836,7 @@ def search_sql( order by {order_by} {limit_offset} - """ - ).strip() + """).strip() if virtual_table_using == "FTS5": rank_implementation = "{}.rank".format(fts_table_quoted) else: diff --git a/tests/conftest.py b/tests/conftest.py index 3990d76e..6fd35b85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,14 +42,12 @@ def fresh_db(): @pytest.fixture def existing_db(): database = Database(memory=True) - database.executescript( - """ + database.executescript(""" CREATE TABLE foo (text TEXT); INSERT INTO foo (text) values ("one"); INSERT INTO foo (text) values ("two"); INSERT INTO foo (text) values ("three"); - """ - ) + """) return database diff --git a/tests/test_analyze_tables.py b/tests/test_analyze_tables.py index 4618eff1..a2ce585d 100644 --- a/tests/test_analyze_tables.py +++ b/tests/test_analyze_tables.py @@ -143,10 +143,7 @@ def db_to_analyze_path(db_to_analyze, tmpdir): def test_analyze_table(db_to_analyze_path): result = CliRunner().invoke(cli.cli, ["analyze-tables", db_to_analyze_path]) - assert ( - result.output.strip() - == ( - """ + assert result.output.strip() == (""" stuff.id: (1/3) Total rows: 8 @@ -179,9 +176,7 @@ def test_analyze_table(db_to_analyze_path): Most common: 5: 5 - 3: 4""" - ).strip() - ) + 3: 4""").strip() def test_analyze_table_save(db_to_analyze_path): diff --git a/tests/test_cli.py b/tests/test_cli.py index 40c3595c..e11508ad 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -967,12 +967,9 @@ def test_query_json_with_json_cols(db_path): result = CliRunner().invoke( cli.cli, [db_path, "select id, name, friends from dogs"] ) - assert ( - r""" + assert r""" [{"id": 1, "name": "Cleo", "friends": "[{\"name\": \"Pancakes\"}, {\"name\": \"Bailey\"}]"}] - """.strip() - == result.output.strip() - ) + """.strip() == result.output.strip() # With --json-cols: result = CliRunner().invoke( cli.cli, [db_path, "select id, name, friends from dogs", "--json-cols"] @@ -1810,6 +1807,46 @@ def test_transform_add_or_drop_foreign_key(db_path, extra_args, expected_schema) assert schema == expected_schema +def test_transform_update_incoming_fks_cli(db_path): + """Test --update-incoming-fks flag updates foreign keys in other tables""" + db = Database(db_path) + with db.conn: + db["authors"].insert({"id": 1, "name": "Alice"}, pk="id") + db["books"].insert( + {"id": 1, "title": "Book A", "author_id": 1}, + pk="id", + foreign_keys=[("author_id", "authors", "id")], + ) + + # Rename authors.id to authors.author_pk with --update-incoming-fks + result = CliRunner().invoke( + cli.cli, + [ + "transform", + db_path, + "authors", + "--rename", + "id", + "author_pk", + 
"--update-incoming-fks", + ], + ) + assert result.exit_code == 0, result.output + + # Verify authors column was renamed + assert "author_pk" in db["authors"].columns_dict + assert "id" not in db["authors"].columns_dict + + # Verify books FK was updated + assert db["books"].schema == ( + 'CREATE TABLE "books" (\n' + ' "id" INTEGER PRIMARY KEY,\n' + ' "title" TEXT,\n' + ' "author_id" INTEGER REFERENCES "authors"("author_pk")\n' + ")" + ) + + _common_other_schema = ( 'CREATE TABLE "species" (\n "id" INTEGER PRIMARY KEY,\n "species" TEXT\n)' ) @@ -1998,12 +2035,10 @@ def test_search_quote(tmpdir): def test_indexes(tmpdir): db_path = str(tmpdir / "test.db") db = Database(db_path) - db.conn.executescript( - """ + db.conn.executescript(""" create table Gosh (c1 text, c2 text, c3 text); create index Gosh_idx on Gosh(c2, c3 desc); - """ - ) + """) result = CliRunner().invoke( cli.cli, ["indexes", str(db_path)], @@ -2094,16 +2129,12 @@ def test_triggers(tmpdir, extra_args, expected): pk="id", ) db["counter"].insert({"count": 1}) - db.conn.execute( - textwrap.dedent( - """ + db.conn.execute(textwrap.dedent(""" CREATE TRIGGER blah AFTER INSERT ON articles BEGIN UPDATE counter SET count = count + 1; END - """ - ) - ) + """)) args = ["triggers", db_path] if extra_args: args.extend(extra_args) diff --git a/tests/test_cli_convert.py b/tests/test_cli_convert.py index 387b3181..443e72c6 100644 --- a/tests/test_cli_convert.py +++ b/tests/test_cli_convert.py @@ -371,16 +371,14 @@ def test_convert_multi_complex_column_types(fresh_db_and_path): ], pk="id", ) - code = textwrap.dedent( - """ + code = textwrap.dedent(""" if value == 1: return {"is_str": "", "is_float": 1.2, "is_int": None} elif value == 2: return {"is_float": 1, "is_int": 12} elif value == 3: return {"is_bytes": b"blah"} - """ - ) + """) result = CliRunner().invoke( cli.cli, [ diff --git a/tests/test_create.py b/tests/test_create.py index ea09c459..b1a6ad1f 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -20,7 +20,6 @@ import pytest import uuid - try: import pandas as pd # type: ignore except ImportError: diff --git a/tests/test_default_value.py b/tests/test_default_value.py index 9ffdb144..c594c9f5 100644 --- a/tests/test_default_value.py +++ b/tests/test_default_value.py @@ -1,6 +1,5 @@ import pytest - EXAMPLES = [ ("TEXT DEFAULT 'foo'", "'foo'", "'foo'"), ("TEXT DEFAULT 'foo)'", "'foo)'", "'foo)'"), diff --git a/tests/test_duplicate.py b/tests/test_duplicate.py index 552f697c..28961d2e 100644 --- a/tests/test_duplicate.py +++ b/tests/test_duplicate.py @@ -5,14 +5,12 @@ def test_duplicate(fresh_db): # Create table using native Sqlite statement: - fresh_db.execute( - """CREATE TABLE "table1" ( + fresh_db.execute("""CREATE TABLE "table1" ( "text_col" TEXT, "real_col" REAL, "int_col" INTEGER, "bool_col" INTEGER, - "datetime_col" TEXT)""" - ) + "datetime_col" TEXT)""") # Insert one row of mock data: dt = datetime.datetime.now() data = { diff --git a/tests/test_extract.py b/tests/test_extract.py index 435c1aea..d24c5974 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -126,9 +126,7 @@ def test_extract_rowid_table(fresh_db): ' "common_name_latin_name_id" INTEGER REFERENCES "common_name_latin_name"("id")\n' ")" ) - assert ( - fresh_db.execute( - """ + assert fresh_db.execute(""" select tree.name, common_name_latin_name.common_name, @@ -136,10 +134,7 @@ def test_extract_rowid_table(fresh_db): from tree join common_name_latin_name on tree.common_name_latin_name_id = common_name_latin_name.id - """ - ).fetchall() - == 
[("Tree 1", "Palm", "Arecaceae")] - ) + """).fetchall() == [("Tree 1", "Palm", "Arecaceae")] def test_reuse_lookup_table(fresh_db): diff --git a/tests/test_introspect.py b/tests/test_introspect.py index ab61c158..4ff3f772 100644 --- a/tests/test_introspect.py +++ b/tests/test_introspect.py @@ -109,13 +109,11 @@ def test_table_repr(fresh_db): def test_indexes(fresh_db): - fresh_db.executescript( - """ + fresh_db.executescript(""" create table Gosh (c1 text, c2 text, c3 text); create index Gosh_c1 on Gosh(c1); create index Gosh_c2c3 on Gosh(c2, c3); - """ - ) + """) assert [ Index( seq=0, @@ -130,13 +128,11 @@ def test_indexes(fresh_db): def test_xindexes(fresh_db): - fresh_db.executescript( - """ + fresh_db.executescript(""" create table Gosh (c1 text, c2 text, c3 text); create index Gosh_c1 on Gosh(c1); create index Gosh_c2c3 on Gosh(c2, c3 desc); - """ - ) + """) assert fresh_db["Gosh"].xindexes == [ XIndex( name="Gosh_c2c3", diff --git a/tests/test_transform.py b/tests/test_transform.py index e763a6cd..e7bb4143 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -638,15 +638,13 @@ def test_transform_with_indexes_errors(fresh_db, transform_params): def test_transform_with_unique_constraint_implicit_index(fresh_db): dogs = fresh_db["dogs"] # Create a table with a UNIQUE constraint on 'name', which creates an implicit index - fresh_db.execute( - """ + fresh_db.execute(""" CREATE TABLE dogs ( id INTEGER PRIMARY KEY, name TEXT UNIQUE, age INTEGER ); - """ - ) + """) dogs.insert({"id": 1, "name": "Cleo", "age": 5}) # Attempt to transform the table without modifying 'name' @@ -661,3 +659,188 @@ def test_transform_with_unique_constraint_implicit_index(fresh_db): "You must manually drop this index prior to running this transformation and manually recreate the new index after running this transformation." in str(excinfo.value) ) + + +def test_transform_update_incoming_fks_on_column_rename(fresh_db): + """ + Test that update_incoming_fks=True updates FK constraints in other tables + when a referenced column is renamed. 
+ """ + fresh_db.execute("PRAGMA foreign_keys=ON") + + # Create authors table with id as PK + fresh_db["authors"].insert({"id": 1, "name": "Alice"}, pk="id") + + # Create books table with FK to authors.id + fresh_db["books"].insert( + {"id": 1, "title": "Book A", "author_id": 1}, + pk="id", + foreign_keys=[("author_id", "authors", "id")], + ) + + # Verify initial FK + assert fresh_db["books"].foreign_keys == [ + ForeignKey( + table="books", column="author_id", other_table="authors", other_column="id" + ) + ] + + # Rename authors.id to authors.author_pk with update_incoming_fks=True + fresh_db["authors"].transform( + rename={"id": "author_pk"}, + update_incoming_fks=True, + ) + + # Verify authors column was renamed + assert "author_pk" in fresh_db["authors"].columns_dict + assert "id" not in fresh_db["authors"].columns_dict + + # Verify books FK was updated to point to new column name + assert fresh_db["books"].foreign_keys == [ + ForeignKey( + table="books", + column="author_id", + other_table="authors", + other_column="author_pk", + ) + ] + + # Verify data integrity + assert list(fresh_db["authors"].rows) == [{"author_pk": 1, "name": "Alice"}] + assert list(fresh_db["books"].rows) == [ + {"id": 1, "title": "Book A", "author_id": 1} + ] + + # Verify FK enforcement still works + assert fresh_db.execute("PRAGMA foreign_keys").fetchone()[0] == 1 + violations = list(fresh_db.execute("PRAGMA foreign_key_check").fetchall()) + assert violations == [] + + +def test_transform_update_incoming_fks_multiple_tables(fresh_db): + """ + Test that update_incoming_fks=True updates FK constraints in multiple tables + when a referenced column is renamed. + """ + fresh_db.execute("PRAGMA foreign_keys=ON") + + # Create authors table with id as PK + fresh_db["authors"].insert({"id": 1, "name": "Alice"}, pk="id") + + # Create multiple tables with FKs to authors.id + fresh_db["books"].insert( + {"id": 1, "title": "Book A", "author_id": 1}, + pk="id", + foreign_keys=[("author_id", "authors", "id")], + ) + fresh_db["articles"].insert( + {"id": 1, "headline": "Article A", "writer_id": 1}, + pk="id", + foreign_keys=[("writer_id", "authors", "id")], + ) + fresh_db["quotes"].insert( + {"id": 1, "text": "Quote A", "speaker_id": 1}, + pk="id", + foreign_keys=[("speaker_id", "authors", "id")], + ) + + # Rename authors.id to authors.author_pk with update_incoming_fks=True + fresh_db["authors"].transform( + rename={"id": "author_pk"}, + update_incoming_fks=True, + ) + + # Verify authors column was renamed + assert "author_pk" in fresh_db["authors"].columns_dict + assert "id" not in fresh_db["authors"].columns_dict + + # Verify all FKs were updated + assert fresh_db["books"].foreign_keys == [ + ForeignKey( + table="books", + column="author_id", + other_table="authors", + other_column="author_pk", + ) + ] + assert fresh_db["articles"].foreign_keys == [ + ForeignKey( + table="articles", + column="writer_id", + other_table="authors", + other_column="author_pk", + ) + ] + assert fresh_db["quotes"].foreign_keys == [ + ForeignKey( + table="quotes", + column="speaker_id", + other_table="authors", + other_column="author_pk", + ) + ] + + # Verify FK enforcement still works + violations = list(fresh_db.execute("PRAGMA foreign_key_check").fetchall()) + assert violations == [] + + +def test_transform_update_incoming_fks_self_referential(fresh_db): + """ + Test that update_incoming_fks=True handles self-referential FK constraints. 
+ """ + fresh_db.execute("PRAGMA foreign_keys=ON") + + # Create employees table with self-referential FK (manager_id -> id) + fresh_db.execute(""" + CREATE TABLE employees ( + id INTEGER PRIMARY KEY, + name TEXT, + manager_id INTEGER REFERENCES employees(id) + ) + """) + fresh_db["employees"].insert_all( + [ + {"id": 1, "name": "CEO", "manager_id": None}, + {"id": 2, "name": "VP", "manager_id": 1}, + {"id": 3, "name": "Dev", "manager_id": 2}, + ] + ) + + # Verify initial FK + assert fresh_db["employees"].foreign_keys == [ + ForeignKey( + table="employees", + column="manager_id", + other_table="employees", + other_column="id", + ) + ] + + # Rename employees.id to employees.emp_id with update_incoming_fks=True + fresh_db["employees"].transform( + rename={"id": "emp_id"}, + update_incoming_fks=True, + ) + + # Verify column was renamed + assert "emp_id" in fresh_db["employees"].columns_dict + assert "id" not in fresh_db["employees"].columns_dict + + # Verify self-referential FK was updated + assert fresh_db["employees"].foreign_keys == [ + ForeignKey( + table="employees", + column="manager_id", + other_table="employees", + other_column="emp_id", + ) + ] + + # Verify data integrity + rows = list(fresh_db.execute("SELECT * FROM employees ORDER BY emp_id").fetchall()) + assert rows == [(1, "CEO", None), (2, "VP", 1), (3, "Dev", 2)] + + # Verify FK enforcement still works + violations = list(fresh_db.execute("PRAGMA foreign_key_check").fetchall()) + assert violations == []
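
For reference, a minimal sketch of the new ``update_incoming_fks=True`` option from the Python API side, mirroring the scenario exercised in ``tests/test_transform.py`` (an in-memory database keeps it self-contained; this is an illustrative sketch, not part of the patch)::

    import sqlite_utils

    db = sqlite_utils.Database(memory=True)
    db["authors"].insert({"id": 1, "name": "Alice"}, pk="id")
    db["books"].insert(
        {"id": 1, "title": "Book A", "author_id": 1},
        pk="id",
        foreign_keys=[("author_id", "authors", "id")],
    )

    # Renaming the referenced column also rewrites the books.author_id foreign
    # key so it points at authors.author_pk rather than the dropped authors.id
    db["authors"].transform(rename={"id": "author_pk"}, update_incoming_fks=True)

    assert db["books"].foreign_keys[0].other_column == "author_pk"

Without ``update_incoming_fks=True`` the same transform would leave ``books`` referencing the old ``authors.id`` column, which is exactly the foreign key mismatch the docs change describes.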
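
The matching command-line invocation with the new ``--update-incoming-fks`` flag looks like the following sketch; the database filename is a placeholder, and ``--sql`` previews the generated statements without executing them::

    sqlite-utils transform mydatabase.db authors \
        --rename id author_pk \
        --update-incoming-fks \
        --sql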