Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import mimetypes
import os
import re
import sys
import textwrap
import threading
Expand Down Expand Up @@ -105,6 +106,44 @@ def get_batch_size(t: Type) -> Optional[int]:
return None


class FileExtension:
    """
    This is used to annotate a FlyteFile when we want to download the file with a specific extension.
    It follows the same pattern as ``BatchSize``. For example,

    ```python
    # ContainerTask
    def t1(file: Annotated[FlyteFile, FileExtension("csv")]):
        ...  # copilot downloads the file to e.g. /inputs/file.csv

    versus...

    def t1(file: FlyteFile["csv"]):
        ...  # copilot downloads the file to e.g. /inputs/file
    ```

    val: (Default is "") The file extension (e.g. "csv", "parquet") to use during copilot download.
        A non-empty value must consist of one or more alphanumeric segments separated by single
        dots (e.g. "tar.gz"); it must not start or end with a dot.

    Raises ValueError if a non-empty ``val`` does not have that shape.
    """

    # One or more dot-separated alphanumeric segments. Compiled once for the class and applied
    # with ``fullmatch`` so that no anchors are needed in the pattern itself.
    _VALID_EXTENSION = re.compile(r"[a-zA-Z0-9]+(\.[a-zA-Z0-9]+)*")

    def __init__(self, val: str = ""):
        # The documented default "" must be constructible, so only non-empty values are
        # validated. ``fullmatch`` is used instead of ``match`` with a trailing "$" because
        # "$" also matches just before a final newline, which would let "csv\n" through.
        if val != "" and not self._VALID_EXTENSION.fullmatch(val):
            raise ValueError(f"Invalid file extension: {val}")
        self._val = val

    @property
    def val(self) -> str:
        """The validated file extension, or "" when none was given."""
        return self._val


def get_file_extension(t: Type) -> Optional[str]:
    """
    Look up the file extension attached to an annotated type.

    Returns the ``val`` of the first ``FileExtension`` found among the annotation
    arguments of ``t``. Returns None when ``t`` is not annotated or carries no
    ``FileExtension``.
    """
    if not is_annotated(t):
        return None

    # The first type argument is the underlying type; the rest are the annotations.
    extras = get_args(t)[1:]
    return next((extra.val for extra in extras if isinstance(extra, FileExtension)), None)


def modify_literal_uris(lit: Literal):
"""
Modifies the literal object recursively to replace the URIs with the native paths in case they are of
Expand Down
26 changes: 23 additions & 3 deletions flytekit/models/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,16 @@ class BlobDimensionality(object):
SINGLE = _types_pb2.BlobType.SINGLE
MULTIPART = _types_pb2.BlobType.MULTIPART

def __init__(self, format, dimensionality, file_extension=""):
    """
    :param Text format: A string describing the format of the underlying blob data.
    :param int dimensionality: An integer from BlobType.BlobDimensionality enum
    :param Text file_extension: The file extension (e.g. "csv", "parquet") to use
        during copilot download. Empty by default, which means no extension is appended.
    """
    self._format = format
    self._dimensionality = dimensionality
    self._file_extension = file_extension

@property
def format(self):
Expand All @@ -62,16 +65,33 @@ def dimensionality(self):
"""
return self._dimensionality

@property
def file_extension(self):
"""
The file extension (e.g. "csv", "parquet") to use during copilot download.
Default is "", which means no extension is appended.

The value is returned exactly as it was passed to the constructor.
:rtype: Text
"""
return self._file_extension

def to_flyte_idl(self):
    """
    Build the protobuf message for this blob type, carrying the format, the
    dimensionality and the file extension.

    :rtype: flyteidl.core.types_pb2.BlobType
    """
    # All three fields are read through their public properties.
    return _types_pb2.BlobType(
        format=self.format,
        dimensionality=self.dimensionality,
        file_extension=self.file_extension,
    )

@classmethod
def from_flyte_idl(cls, proto):
    """
    Build an instance from a protobuf message, copying its format, dimensionality
    and file extension.

    :param flyteidl.core.types_pb2.BlobType proto:
    :rtype: BlobType
    """
    return cls(
        format=proto.format,
        dimensionality=proto.dimensionality,
        file_extension=proto.file_extension,
    )
28 changes: 24 additions & 4 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AsyncTypeTransformer,
TypeEngine,
TypeTransformerFailedError,
get_file_extension,
get_underlying_type,
)
from flytekit.exceptions.user import FlyteAssertion
Expand Down Expand Up @@ -477,8 +478,17 @@ def get_format(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
return ""
return cast(FlyteFile, t).extension()

@staticmethod
def get_file_extension(t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> str:
    """
    Return the file extension annotated on ``t``, or "" when there is none.

    ``os.PathLike`` itself never carries an extension, so it short-circuits to "".
    """
    if t is os.PathLike:
        return ""
    # Inside the method body this name resolves to the module-level helper imported
    # from the type engine, not to this static method of the same name.
    file_extension = get_file_extension(t)
    return "" if file_extension is None else file_extension

def _blob_type(self, format: str, file_extension: str = "") -> BlobType:
    """
    Build the single-file ``BlobType`` for the given format.

    ``file_extension`` defaults to "" so callers that only pass ``format`` keep working.
    """
    return BlobType(
        format=format,
        dimensionality=BlobType.BlobDimensionality.SINGLE,
        file_extension=file_extension,
    )

def assert_type(
self, t: typing.Union[typing.Type[FlyteFile], os.PathLike], v: typing.Union[FlyteFile, os.PathLike, str]
Expand All @@ -491,7 +501,12 @@ def assert_type(
)

def get_literal_type(self, t: typing.Union[typing.Type[FlyteFile], os.PathLike]) -> LiteralType:
    """
    Return the blob ``LiteralType`` for ``t``, populated with both the format and
    the file extension derived from ``t``.
    """
    return LiteralType(
        blob=self._blob_type(
            format=FlyteFilePathTransformer.get_format(t),
            file_extension=FlyteFilePathTransformer.get_file_extension(t),
        )
    )

def get_mime_type_from_extension(self, extension: str) -> typing.Union[str, typing.Sequence[str]]:
extension_to_mime_type = {
Expand Down Expand Up @@ -565,7 +580,12 @@ async def async_to_literal(
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")

# information used by all cases
meta = BlobMetadata(type=self._blob_type(format=FlyteFilePathTransformer.get_format(python_type)))
meta = BlobMetadata(
type=self._blob_type(
format=FlyteFilePathTransformer.get_format(python_type),
file_extension=FlyteFilePathTransformer.get_file_extension(python_type),
)
)

if isinstance(python_val, FlyteFile):
# Cast the source path to str type to avoid error raised when the source path is used as the blob uri,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.16.1,<2.0.0a0",
"flyteidl @ git+https://github.com/ddl-rliu/flyte.git@93ff903e63de6384d41db4c9da8df155612d16db#subdirectory=flyteidl",
"fsspec>=2023.3.0",
# Bug in 2025.5.0, 2025.5.0post1 https://github.com/fsspec/gcsfs/issues/687
# Bug in 2024.2.0 https://github.com/fsspec/gcsfs/pull/643
Expand Down
28 changes: 27 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
from flytekit.core.type_engine import FileExtension, TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
Expand Down Expand Up @@ -764,6 +764,32 @@ def test_headers():
assert len(FlyteFilePathTransformer.get_additional_headers(".gz")) == 1


def test_transform_flytefile_with_file_extension():
    """The literal type's blob carries a file extension only when the type is annotated with one."""
    plain_csv = FlyteFile["csv"]
    plain_literal_type = FlyteFilePathTransformer().get_literal_type(plain_csv)
    assert plain_literal_type.blob.file_extension == ""

    annotated_csv = Annotated[FlyteFile["csv"], FileExtension("csv")]
    annotated_literal_type = FlyteFilePathTransformer().get_literal_type(annotated_csv)
    assert annotated_literal_type.blob.file_extension == "csv"


def test_file_extension_valid_compound_extension():
    """A dotted, multi-part extension is accepted and exposed unchanged through ``val``."""
    assert FileExtension("tar.gz").val == "tar.gz"


@pytest.mark.parametrize(
    "bad_ext",
    [".csv", "my file", "../../escape", "csv!"],
)
def test_file_extension_rejects_invalid_extensions(bad_ext):
    """Leading dots, whitespace, path separators and punctuation are all rejected."""
    with pytest.raises(ValueError, match="Invalid file extension"):
        FileExtension(bad_ext)


def test_new_remote_file():
nf = FlyteFile.new_remote_file(name="foo.txt")
assert isinstance(nf, FlyteFile)
Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/models/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,28 @@ def test_blob_type():
)
assert o.format == "csv"
assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o.file_extension == ""

o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
assert o == o2
assert o2.format == "csv"
assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o2.file_extension == ""

o = _types.BlobType(
format="csv",
dimensionality=_types.BlobType.BlobDimensionality.SINGLE,
file_extension="csv",
)
assert o.format == "csv"
assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o.file_extension == "csv"

o2 = _types.BlobType.from_flyte_idl(o.to_flyte_idl())
assert o == o2
assert o2.format == "csv"
assert o2.dimensionality == _types.BlobType.BlobDimensionality.SINGLE
assert o2.file_extension == "csv"


def test_enum_type():
Expand Down
Loading