|
27 | 27 | from pyiceberg.catalog.rest import RestCatalog |
28 | 28 | from pyiceberg.catalog.sql import SqlCatalog |
29 | 29 | from pyiceberg.exceptions import ( |
| 30 | + CommitFailedException, |
30 | 31 | NamespaceAlreadyExistsError, |
31 | 32 | NamespaceNotEmptyError, |
32 | 33 | NoSuchNamespaceError, |
33 | 34 | NoSuchTableError, |
34 | 35 | TableAlreadyExistsError, |
35 | 36 | ) |
36 | 37 | from pyiceberg.io import WAREHOUSE |
| 38 | +from pyiceberg.partitioning import PartitionField, PartitionSpec |
37 | 39 | from pyiceberg.schema import Schema |
| 40 | +from pyiceberg.transforms import BucketTransform |
38 | 41 | from tests.conftest import clean_up |
39 | 42 |
|
40 | 43 |
|
@@ -85,7 +88,6 @@ def rest_test_catalog() -> Generator[Catalog, None, None]: |
85 | 88 | pytest.skip("PYICEBERG_TEST_CATALOG environment variables not set") |
86 | 89 |
|
87 | 90 |
|
88 | | -@pytest.fixture(scope="function") |
89 | 91 | def hive_catalog() -> Generator[Catalog, None, None]: |
90 | 92 | test_catalog = HiveCatalog( |
91 | 93 | "test_hive_catalog", |
@@ -355,3 +357,66 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str) |
355 | 357 | else: |
356 | 358 | assert k in update_report.removed |
357 | 359 | assert "updated test description" == test_catalog.load_namespace_properties(database_name)["comment"] |
| 360 | + |
| 361 | + |
| 362 | +@pytest.mark.integration |
| 363 | +@pytest.mark.parametrize("test_catalog", CATALOGS) |
| 364 | +def test_update_table_spec(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: |
| 365 | + identifier = (database_name, table_name) |
| 366 | + test_catalog.create_namespace(database_name) |
| 367 | + table = test_catalog.create_table(identifier, test_schema) |
| 368 | + |
| 369 | + with table.update_spec() as update: |
| 370 | + update.add_field(source_column_name="VendorID", transform=BucketTransform(16), partition_field_name="shard") |
| 371 | + |
| 372 | + loaded = test_catalog.load_table(identifier) |
| 373 | + expected_spec = PartitionSpec( |
| 374 | + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="shard"), spec_id=1 |
| 375 | + ) |
| 376 | + # The spec ID may not match, so check equality of the fields |
| 377 | + assert loaded.spec() == expected_spec |
| 378 | + |
| 379 | + |
| 380 | +@pytest.mark.integration |
| 381 | +@pytest.mark.parametrize("test_catalog", CATALOGS) |
| 382 | +def test_update_table_spec_conflict(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: |
| 383 | + identifier = (database_name, table_name) |
| 384 | + test_catalog.create_namespace(database_name) |
| 385 | + spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")) |
| 386 | + table = test_catalog.create_table(identifier, test_schema, partition_spec=spec) |
| 387 | + |
| 388 | + update = table.update_spec() |
| 389 | + update.add_field(source_column_name="tpep_pickup_datetime", transform=BucketTransform(16), partition_field_name="shard") |
| 390 | + |
| 391 | + # update with conflict |
| 392 | + conflict_table = test_catalog.load_table(identifier) |
| 393 | + with conflict_table.update_spec() as conflict_update: |
| 394 | + conflict_update.remove_field("id_bucket") |
| 395 | + |
| 396 | + with pytest.raises( |
| 397 | + CommitFailedException, match="Requirement failed: default spec id has changed|default partition spec changed" |
| 398 | + ): |
| 399 | + update.commit() |
| 400 | + |
| 401 | + loaded = test_catalog.load_table(identifier) |
| 402 | + assert loaded.spec() == PartitionSpec(spec_id=1) |
| 403 | + |
| 404 | + |
| 405 | +@pytest.mark.integration |
| 406 | +@pytest.mark.parametrize("test_catalog", CATALOGS) |
| 407 | +def test_update_table_spec_then_revert(test_catalog: Catalog, test_schema: Schema, table_name: str, database_name: str) -> None: |
| 408 | + identifier = (database_name, table_name) |
| 409 | + test_catalog.create_namespace(database_name) |
| 410 | + |
| 411 | + initial_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name="id_bucket")) |
| 412 | + |
| 413 | + table = test_catalog.create_table(identifier, test_schema, partition_spec=initial_spec, properties={"format-version": "2"}) |
| 414 | + assert table.format_version == 2 |
| 415 | + |
| 416 | + with table.update_spec() as update: |
| 417 | + update.add_identity(source_column_name="tpep_pickup_datetime") |
| 418 | + |
| 419 | + with table.update_spec() as update: |
| 420 | + update.remove_field("tpep_pickup_datetime") |
| 421 | + |
| 422 | + assert table.spec() == initial_spec |
0 commit comments