|
27 | 27 | from pyiceberg.catalog.rest import RestCatalog |
28 | 28 | from pyiceberg.catalog.sql import SqlCatalog |
29 | 29 | from pyiceberg.exceptions import ( |
| 30 | + CommitFailedException, |
30 | 31 | CommitFailedException, |
31 | 32 | NamespaceAlreadyExistsError, |
32 | 33 | NamespaceNotEmptyError, |
|
36 | 37 | ) |
37 | 38 | from pyiceberg.io import WAREHOUSE |
38 | 39 | from pyiceberg.partitioning import PartitionField, PartitionSpec |
| 40 | +from pyiceberg.partitioning import PartitionField, PartitionSpec |
39 | 41 | from pyiceberg.schema import INITIAL_SCHEMA_ID, Schema |
| 42 | +from pyiceberg.transforms import BucketTransform |
40 | 43 | from pyiceberg.table.metadata import INITIAL_SPEC_ID |
41 | 44 | from pyiceberg.table.sorting import INITIAL_SORT_ORDER_ID, SortField, SortOrder |
42 | 45 | from pyiceberg.transforms import DayTransform, IdentityTransform |
@@ -90,7 +93,6 @@ def rest_test_catalog() -> Generator[Catalog, None, None]: |
90 | 93 | else: |
91 | 94 | pytest.skip("PYICEBERG_TEST_CATALOG environment variables not set") |
92 | 95 |
|
93 | | - |
94 | 96 | @pytest.fixture(scope="function") |
95 | 97 | def hive_catalog() -> Generator[Catalog, None, None]: |
96 | 98 | test_catalog = HiveCatalog( |
@@ -503,6 +505,69 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str) |
503 | 505 | assert "updated test description" == test_catalog.load_namespace_properties(database_name)["comment"] |
504 | 506 |
|
505 | 507 |
|
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None:
    """Add a bucket partition field to a freshly created table and check the reloaded spec.

    The table is created without an explicit partition spec, a ``bucket[16]``
    field named ``shard`` is added on ``VendorID``, and the table is then
    reloaded through the catalog so the assertion runs against what the
    catalog returns rather than the in-memory handle.
    """
    identifier = (database_name, table_name)
    test_catalog.create_namespace(database_name)
    table = test_catalog.create_table(identifier, test_schema)

    # The update is expected to be committed when the `with` block exits.
    with table.update_spec() as update:
        update.add_field(source_column_name="VendorID", transform=BucketTransform(16), partition_field_name="shard")

    loaded = test_catalog.load_table(identifier)
    expected_spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="shard"), spec_id=1
    )
    # The spec ID may not match, so check equality of the fields only
    # (comparing whole specs would also compare the spec ID).
    assert loaded.spec().fields == expected_spec.fields
| 524 | + |
| 525 | + |
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None:
    """Committing a stale spec update after a concurrent spec change must fail.

    A spec update is staged on one table handle but not committed. A second
    handle to the same table then removes the only partition field. The staged
    update is expected to raise ``CommitFailedException`` on commit, and the
    catalog should still report the spec produced by the second handle.
    """
    bucketed_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket"))
    table_ident = (database_name, table_name)
    test_catalog.create_namespace(database_name)
    table = test_catalog.create_table(table_ident, test_schema, partition_spec=bucketed_spec)

    # Stage a change on the first handle; it is committed explicitly further down.
    pending_update = table.update_spec()
    pending_update.add_field(
        source_column_name="tpep_pickup_datetime",
        transform=BucketTransform(16),
        partition_field_name="shard",
    )

    # Conflicting change made through a second, freshly loaded handle.
    with test_catalog.load_table(table_ident).update_spec() as conflicting_update:
        conflicting_update.remove_field("id_bucket")

    # Two alternative messages are accepted by the pattern.
    expected_error = "Requirement failed: default spec id has changed|default partition spec changed"
    with pytest.raises(CommitFailedException, match=expected_error):
        pending_update.commit()

    # Only the conflicting update should be visible: an empty spec with ID 1.
    assert test_catalog.load_table(table_ident).spec() == PartitionSpec(spec_id=1)
| 549 | + |
| 550 | + |
@pytest.mark.integration
@pytest.mark.parametrize("test_catalog", CATALOGS)
def test_update_table_spec_then_revert(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None:
    """Adding a partition field and then removing it should restore the initial spec.

    The table is created as format version 2 with a single ``bucket[16]``
    field. An identity field is added in one spec update and removed in the
    next, after which the table's spec must compare equal to the spec it was
    created with.
    """
    table_ident = (database_name, table_name)
    test_catalog.create_namespace(database_name)

    original_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket"))

    table = test_catalog.create_table(
        table_ident,
        test_schema,
        partition_spec=original_spec,
        properties={"format-version": "2"},
    )
    assert table.format_version == 2

    # First update: add an identity partition on the pickup timestamp column.
    with table.update_spec() as add_step:
        add_step.add_identity(source_column_name="tpep_pickup_datetime")

    # Second update: remove that field again
    # (presumably the identity field is named after its source column - confirm).
    with table.update_spec() as remove_step:
        remove_step.remove_field("tpep_pickup_datetime")

    assert table.spec() == original_spec
| 569 | + |
| 570 | + |
506 | 571 | @pytest.mark.integration |
507 | 572 | @pytest.mark.parametrize("test_catalog", CATALOGS) |
508 | 573 | def test_register_table(test_catalog: Catalog, table_schema_nested: Schema, table_name: str, database_name: str) -> None: |
|
0 commit comments