Skip to content

Commit 663be50

Browse files
committed
feat: Add models for rest scan planning
1 parent 4c9d887 commit 663be50

File tree

2 files changed

+658
-0
lines changed

2 files changed

+658
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from datetime import date, datetime, time
20+
from decimal import Decimal
21+
from typing import Annotated, Generic, Literal, TypeAlias, TypeVar
22+
from uuid import UUID
23+
24+
from pydantic import Field, model_validator
25+
26+
from pyiceberg.catalog.rest.response import ErrorResponseMessage
27+
from pyiceberg.expressions import BooleanExpression
28+
from pyiceberg.manifest import FileFormat
29+
from pyiceberg.typedef import IcebergBaseModel
30+
31+
# Primitive types that can appear in partition values and bounds
32+
PrimitiveTypeValue: TypeAlias = bool | int | float | str | Decimal | UUID | date | time | datetime | bytes
33+
34+
V = TypeVar("V")
35+
36+
37+
class KeyValueMap(IcebergBaseModel, Generic[V]):
38+
"""Map serialized as parallel key/value arrays for column statistics."""
39+
40+
keys: list[int] = Field(default_factory=list)
41+
values: list[V] = Field(default_factory=list)
42+
43+
@model_validator(mode="after")
44+
def _validate_lengths_match(self) -> KeyValueMap[V]:
45+
if len(self.keys) != len(self.values):
46+
raise ValueError(f"keys and values must have same length: {len(self.keys)} != {len(self.values)}")
47+
return self
48+
49+
def to_dict(self) -> dict[int, V]:
50+
"""Convert to dictionary mapping field ID to value."""
51+
return dict(zip(self.keys, self.values, strict=True))
52+
53+
54+
class CountMap(KeyValueMap[int]):
55+
"""Map of field IDs to counts."""
56+
57+
58+
class ValueMap(KeyValueMap[PrimitiveTypeValue]):
59+
"""Map of field IDs to primitive values (for lower/upper bounds)."""
60+
61+
62+
class StorageCredential(IcebergBaseModel):
63+
"""Storage credential for accessing content files."""
64+
65+
prefix: str = Field(description="Storage location prefix this credential applies to")
66+
config: dict[str, str] = Field(default_factory=dict)
67+
68+
69+
class RESTContentFile(IcebergBaseModel):
70+
"""Base model for data and delete files from REST API."""
71+
72+
spec_id: int = Field(alias="spec-id")
73+
partition: list[PrimitiveTypeValue] = Field(default_factory=list)
74+
content: Literal["data", "position-deletes", "equality-deletes"]
75+
file_path: str = Field(alias="file-path")
76+
file_format: FileFormat = Field(alias="file-format")
77+
file_size_in_bytes: int = Field(alias="file-size-in-bytes")
78+
record_count: int = Field(alias="record-count")
79+
key_metadata: str | None = Field(alias="key-metadata", default=None)
80+
split_offsets: list[int] | None = Field(alias="split-offsets", default=None)
81+
sort_order_id: int | None = Field(alias="sort-order-id", default=None)
82+
83+
84+
class RESTDataFile(RESTContentFile):
85+
"""Data file from REST API."""
86+
87+
content: Literal["data"] = Field(default="data")
88+
first_row_id: int | None = Field(alias="first-row-id", default=None)
89+
column_sizes: CountMap | None = Field(alias="column-sizes", default=None)
90+
value_counts: CountMap | None = Field(alias="value-counts", default=None)
91+
null_value_counts: CountMap | None = Field(alias="null-value-counts", default=None)
92+
nan_value_counts: CountMap | None = Field(alias="nan-value-counts", default=None)
93+
lower_bounds: ValueMap | None = Field(alias="lower-bounds", default=None)
94+
upper_bounds: ValueMap | None = Field(alias="upper-bounds", default=None)
95+
96+
97+
class RESTPositionDeleteFile(RESTContentFile):
98+
"""Position delete file from REST API."""
99+
100+
content: Literal["position-deletes"] = Field(default="position-deletes")
101+
referenced_data_file: str | None = Field(alias="referenced-data-file", default=None)
102+
content_offset: int | None = Field(alias="content-offset", default=None)
103+
content_size_in_bytes: int | None = Field(alias="content-size-in-bytes", default=None)
104+
105+
106+
class RESTEqualityDeleteFile(RESTContentFile):
107+
"""Equality delete file from REST API."""
108+
109+
content: Literal["equality-deletes"] = Field(default="equality-deletes")
110+
equality_ids: list[int] | None = Field(alias="equality-ids", default=None)
111+
112+
113+
# Discriminated union for delete files
114+
RESTDeleteFile = Annotated[
115+
RESTPositionDeleteFile | RESTEqualityDeleteFile,
116+
Field(discriminator="content"),
117+
]
118+
119+
120+
class RESTFileScanTask(IcebergBaseModel):
121+
"""A file scan task from the REST server."""
122+
123+
data_file: RESTDataFile = Field(alias="data-file")
124+
delete_file_references: list[int] | None = Field(alias="delete-file-references", default=None)
125+
residual_filter: BooleanExpression | None = Field(alias="residual-filter", default=None)
126+
127+
128+
class ScanTasks(IcebergBaseModel):
129+
"""Container for scan tasks returned by the server."""
130+
131+
delete_files: list[RESTDeleteFile] = Field(alias="delete-files", default_factory=list)
132+
file_scan_tasks: list[RESTFileScanTask] = Field(alias="file-scan-tasks", default_factory=list)
133+
plan_tasks: list[str] = Field(alias="plan-tasks", default_factory=list)
134+
135+
@model_validator(mode="after")
136+
def _validate_delete_file_references(self) -> ScanTasks:
137+
# validate delete file references are in bounds
138+
max_idx = len(self.delete_files) - 1
139+
for task in self.file_scan_tasks:
140+
for idx in task.delete_file_references or []:
141+
if idx < 0 or idx > max_idx:
142+
raise ValueError(f"Invalid delete file reference: {idx} (valid range: 0-{max_idx})")
143+
144+
if self.delete_files and not self.file_scan_tasks:
145+
raise ValueError("Invalid response: deleteFiles should only be returned with fileScanTasks that reference them")
146+
147+
return self
148+
149+
150+
class PlanCompleted(ScanTasks):
151+
"""Completed scan plan result."""
152+
153+
status: Literal["completed"] = "completed"
154+
plan_id: str | None = Field(alias="plan-id", default=None)
155+
storage_credentials: list[StorageCredential] | None = Field(alias="storage-credentials", default=None)
156+
157+
158+
class PlanSubmitted(IcebergBaseModel):
159+
"""Scan plan submitted, poll for completion."""
160+
161+
status: Literal["submitted"] = "submitted"
162+
plan_id: str | None = Field(alias="plan-id", default=None)
163+
164+
165+
class PlanCancelled(IcebergBaseModel):
166+
"""Planning was cancelled."""
167+
168+
status: Literal["cancelled"] = "cancelled"
169+
170+
171+
class PlanFailed(IcebergBaseModel):
172+
"""Planning failed with error."""
173+
174+
status: Literal["failed"] = "failed"
175+
error: ErrorResponseMessage
176+
177+
178+
PlanningResponse = Annotated[
179+
PlanCompleted | PlanSubmitted | PlanCancelled | PlanFailed,
180+
Field(discriminator="status"),
181+
]
182+
183+
184+
class PlanTableScanRequest(IcebergBaseModel):
185+
"""Request body for planning a REST scan."""
186+
187+
snapshot_id: int | None = Field(alias="snapshot-id", default=None)
188+
select: list[str] | None = Field(default=None)
189+
filter: BooleanExpression | None = Field(default=None)
190+
case_sensitive: bool = Field(alias="case-sensitive", default=True)
191+
use_snapshot_schema: bool = Field(alias="use-snapshot-schema", default=False)
192+
start_snapshot_id: int | None = Field(alias="start-snapshot-id", default=None)
193+
end_snapshot_id: int | None = Field(alias="end-snapshot-id", default=None)
194+
stats_fields: list[str] | None = Field(alias="stats-fields", default=None)
195+
196+
@model_validator(mode="after")
197+
def _validate_snapshot_fields(self) -> PlanTableScanRequest:
198+
if self.start_snapshot_id is not None and self.end_snapshot_id is None:
199+
raise ValueError("end-snapshot-id is required when start-snapshot-id is specified")
200+
if self.snapshot_id is not None and self.start_snapshot_id is not None:
201+
raise ValueError("Cannot specify both snapshot-id and start-snapshot-id")
202+
return self
203+
204+
205+
class FetchScanTasksRequest(IcebergBaseModel):
206+
"""Request body for fetching scan tasks endpoint."""
207+
208+
plan_task: str = Field(alias="plan-task")

0 commit comments

Comments
 (0)