From f94e46209ee79619cf95955f7eef20ba0d53cc70 Mon Sep 17 00:00:00 2001 From: Aireen Mei Date: Tue, 3 Feb 2026 15:59:55 -0800 Subject: [PATCH] Roll forward after fix: https://github.com/AI-Hypercomputer/maxtext/pull/3050 Reverts 33209a81f0bc61aa5aa2f0762e8593ca9b3745e6 PiperOrigin-RevId: 865085994 --- src/MaxText/layers/engram.py | 2 +- src/MaxText/rl/train_rl.py | 3 +- src/MaxText/train.py | 1 + src/maxtext/common/checkpointing.py | 4 +- src/maxtext/configs/base.yml | 2 +- .../examples/sft_train_and_evaluate.py | 4 +- .../experimental/rl/grpo_input_pipeline.py | 10 ++-- .../inference/inference_microbenchmark.py | 2 +- .../inference/mlperf/offline_inference.py | 4 +- src/maxtext/inference/offline_engine.py | 2 +- .../input_pipeline/__init__.py | 2 +- .../distillation_data_processing.py} | 4 +- .../input_pipeline/grain_data_processing.py} | 34 ++++++------- .../input_pipeline/grain_tokenizer.py} | 2 +- .../input_pipeline/hf_data_processing.py} | 50 +++++++++---------- .../input_pipeline_interface.py | 10 ++-- .../input_pipeline/input_pipeline_utils.py} | 2 +- .../instruction_data_processing.py | 0 .../input_pipeline}/multihost_dataloading.py | 0 .../input_pipeline/packing/__init__.py | 13 +++++ .../packing}/prefill_packing.py | 0 .../packing}/sequence_packing.py | 0 .../synthetic_data_processing.py | 2 +- .../input_pipeline/tfds_data_processing.py} | 18 +++---- .../tfds_data_processing_c4_mlperf.py} | 8 +-- .../input_pipeline}/tokenizer.py | 0 .../post_train/distillation/train_distill.py | 4 +- src/maxtext/trainers/post_train/sft/hooks.py | 2 +- src/maxtext/utils/train_utils.py | 2 +- .../sft_trainer_correctness_test.py | 10 ++-- .../unit/distillation_data_processing_test.py | 14 +++--- tests/unit/grain_data_processing_test.py | 21 ++++---- tests/unit/hf_data_processing_test.py | 8 +-- .../unit/instruction_data_processing_test.py | 2 +- tests/unit/multihost_dataloading_test.py | 2 +- tests/unit/multimodal_rope_check.py | 26 +++++----- 
tests/unit/sft_data_processing_test.py | 8 +-- tests/unit/tfds_data_processing_test.py | 10 ++-- tests/unit/tokenizer_test.py | 12 ++--- tests/unit/tokenizer_transform_test.py | 16 +++--- .../generate_distillation_data.py | 8 +-- 41 files changed, 169 insertions(+), 155 deletions(-) rename src/{MaxText => maxtext}/input_pipeline/__init__.py (93%) rename src/{MaxText/input_pipeline/_distillation_data_processing.py => maxtext/input_pipeline/distillation_data_processing.py} (98%) rename src/{MaxText/input_pipeline/_grain_data_processing.py => maxtext/input_pipeline/grain_data_processing.py} (93%) rename src/{MaxText/input_pipeline/_grain_tokenizer.py => maxtext/input_pipeline/grain_tokenizer.py} (98%) rename src/{MaxText/input_pipeline/_hf_data_processing.py => maxtext/input_pipeline/hf_data_processing.py} (91%) rename src/{MaxText => maxtext}/input_pipeline/input_pipeline_interface.py (89%) rename src/{MaxText/input_pipeline/_input_pipeline_utils.py => maxtext/input_pipeline/input_pipeline_utils.py} (99%) rename src/{MaxText => maxtext}/input_pipeline/instruction_data_processing.py (100%) rename src/{MaxText => maxtext/input_pipeline}/multihost_dataloading.py (100%) create mode 100644 src/maxtext/input_pipeline/packing/__init__.py rename src/{MaxText => maxtext/input_pipeline/packing}/prefill_packing.py (100%) rename src/{MaxText => maxtext/input_pipeline/packing}/sequence_packing.py (100%) rename src/{MaxText => maxtext}/input_pipeline/synthetic_data_processing.py (98%) rename src/{MaxText/input_pipeline/_tfds_data_processing.py => maxtext/input_pipeline/tfds_data_processing.py} (92%) rename src/{MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py => maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py} (98%) rename src/{MaxText => maxtext/input_pipeline}/tokenizer.py (100%) diff --git a/src/MaxText/layers/engram.py b/src/MaxText/layers/engram.py index d47d6d80bd..7278fe76fa 100644 --- a/src/MaxText/layers/engram.py +++ b/src/MaxText/layers/engram.py @@ 
-29,7 +29,7 @@ from jax.sharding import Mesh from flax import nnx -from MaxText.tokenizer import HFTokenizer +from maxtext.input_pipeline.tokenizer import HFTokenizer from MaxText.common_types import MODEL_MODE_TRAIN, Array, Config from MaxText.layers.embeddings import Embed from MaxText.layers.initializers import nd_dense_init, NdInitializer diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index 70a76d9ce3..8473f72865 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -77,7 +77,7 @@ from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter from MaxText.rl.evaluate_rl import evaluate from MaxText.rl import utils_rl -from MaxText.input_pipeline.instruction_data_processing import load_template_from_file +from maxtext.input_pipeline.instruction_data_processing import load_template_from_file from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils @@ -370,7 +370,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): max_logging.log("Creating policy model with same config as reference model on trainer mesh") actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) - if trainer_config.debug.rl: max_logging.log("Policy Model initialized successfully") nnx.display(actor_model) diff --git a/src/MaxText/train.py b/src/MaxText/train.py index c66472e685..aac919a602 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -538,6 +538,7 @@ def initialize(argv: Sequence[str]) -> tuple[pyconfig.HyperParameters, Any, Any] if config.shard_mode == ShardMode.EXPLICIT: jax.config.update("jax_remove_size_one_mesh_axis_from_type", True) os.environ["TFDS_DATA_DIR"] = config.dataset_path or "" + os.environ["HF_DATASETS_DISABLE_DILL"] = "1" vertex_tensorboard_manager = VertexTensorboardManager() if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): vertex_tensorboard_manager.configure_vertex_tensorboard(config) diff 
--git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 9d3a347268..873a689565 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -23,8 +23,8 @@ from flax.training import train_state import jax from MaxText.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE -from MaxText.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator -from MaxText.input_pipeline.input_pipeline_interface import PlaceHolderDataIterator +from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator, RemoteIterator +from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator from maxtext.utils import exceptions from maxtext.utils import max_logging import numpy as np diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index b152f6d081..e520c309ae 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -556,7 +556,7 @@ add_eos: True # If False, use chunking for long sequences instead of truncation. # Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline. # See the TokenizeAndTrim and TokenizeAndChunk classes in -# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details. +# `src/maxtext/input_pipeline/grain_tokenizer.py` for implementation details.
use_truncation: True # Dataset diff --git a/src/maxtext/examples/sft_train_and_evaluate.py b/src/maxtext/examples/sft_train_and_evaluate.py index 305a6b0b17..449c4dadc4 100644 --- a/src/maxtext/examples/sft_train_and_evaluate.py +++ b/src/maxtext/examples/sft_train_and_evaluate.py @@ -85,10 +85,10 @@ from flax import nnx -from MaxText.globals import MAXTEXT_REPO_ROOT from MaxText import pyconfig -from MaxText.input_pipeline import instruction_data_processing +from MaxText.globals import MAXTEXT_REPO_ROOT from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter +from maxtext.input_pipeline import instruction_data_processing from maxtext.trainers.post_train.sft import train_sft from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/maxtext/experimental/rl/grpo_input_pipeline.py b/src/maxtext/experimental/rl/grpo_input_pipeline.py index cb1e31039d..7748391300 100644 --- a/src/maxtext/experimental/rl/grpo_input_pipeline.py +++ b/src/maxtext/experimental/rl/grpo_input_pipeline.py @@ -32,8 +32,8 @@ import grain.python as grain -from MaxText.input_pipeline import input_pipeline_interface -from MaxText.input_pipeline import _input_pipeline_utils +from maxtext.input_pipeline import input_pipeline_interface +from maxtext.input_pipeline import input_pipeline_utils class SingleHostDataLoader: @@ -141,7 +141,7 @@ def preprocessing_pipeline( ) dataset = dataset.map( - _input_pipeline_utils.tokenization, + input_pipeline_utils.tokenization, batched=True, fn_kwargs={ "hf_tokenizer": tokenizer, @@ -151,7 +151,7 @@ def preprocessing_pipeline( }, ) dataset = dataset.select_columns(data_column_names) - dataset = _input_pipeline_utils.HFDataSource( + dataset = input_pipeline_utils.HFDataSource( dataset, dataloading_host_index, dataloading_host_count, @@ -166,7 +166,7 @@ def lists2array(x): operations = [ grain.MapOperation(lists2array), - _input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True), + 
input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, add_true_length=True), grain.Batch(batch_size=global_batch_size // jax.process_count(), drop_remainder=drop_remainder), ] diff --git a/src/maxtext/inference/inference_microbenchmark.py b/src/maxtext/inference/inference_microbenchmark.py index 22e9d2e6c9..0d378faf5c 100644 --- a/src/maxtext/inference/inference_microbenchmark.py +++ b/src/maxtext/inference/inference_microbenchmark.py @@ -23,9 +23,9 @@ from collections.abc import MutableMapping from MaxText import maxengine -from MaxText import prefill_packing from MaxText import pyconfig from maxtext.common import profiler +from maxtext.input_pipeline.packing import prefill_packing from maxtext.utils import gcs_utils from maxtext.utils import max_utils from maxtext.utils import maxtext_utils diff --git a/src/maxtext/inference/mlperf/offline_inference.py b/src/maxtext/inference/mlperf/offline_inference.py index 451ab9ef07..c1aea9bafe 100644 --- a/src/maxtext/inference/mlperf/offline_inference.py +++ b/src/maxtext/inference/mlperf/offline_inference.py @@ -35,8 +35,8 @@ # pylint: disable=no-name-in-module from MaxText.maxengine import MaxEngine from MaxText.maxengine import set_engine_vars_from_base_engine -from MaxText.prefill_packing import PrefillProcessor -from MaxText.prefill_packing import BatchedPrefillProcessor +from maxtext.input_pipeline.packing.prefill_packing import PrefillProcessor +from maxtext.input_pipeline.packing.prefill_packing import BatchedPrefillProcessor DecodeState = Any Params = Any diff --git a/src/maxtext/inference/offline_engine.py b/src/maxtext/inference/offline_engine.py index 232b2bcb5e..f1fc566190 100644 --- a/src/maxtext/inference/offline_engine.py +++ b/src/maxtext/inference/offline_engine.py @@ -54,7 +54,7 @@ from jax.experimental import mesh_utils from MaxText.maxengine import MaxEngine -from MaxText.prefill_packing import PrefillProcessor, BatchedPrefillProcessor +from maxtext.input_pipeline.packing.prefill_packing import 
PrefillProcessor, BatchedPrefillProcessor from maxtext.utils import max_logging from maxtext.utils import max_utils diff --git a/src/MaxText/input_pipeline/__init__.py b/src/maxtext/input_pipeline/__init__.py similarity index 93% rename from src/MaxText/input_pipeline/__init__.py rename to src/maxtext/input_pipeline/__init__.py index 2237c9162e..5c7e6e3878 100644 --- a/src/MaxText/input_pipeline/__init__.py +++ b/src/maxtext/input_pipeline/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2023-2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/MaxText/input_pipeline/_distillation_data_processing.py b/src/maxtext/input_pipeline/distillation_data_processing.py similarity index 98% rename from src/MaxText/input_pipeline/_distillation_data_processing.py rename to src/maxtext/input_pipeline/distillation_data_processing.py index 574ebb4884..44495a39e8 100644 --- a/src/MaxText/input_pipeline/_distillation_data_processing.py +++ b/src/maxtext/input_pipeline/distillation_data_processing.py @@ -23,7 +23,7 @@ from dataclasses import dataclass, field -from MaxText.input_pipeline import _input_pipeline_utils +from maxtext.input_pipeline import input_pipeline_utils from maxtext.utils import max_logging @@ -83,7 +83,7 @@ def process_dataset(config, dataset): # pylint: disable=redefined-outer-name assert any( set(data_column_names) == set(supported) for supported in supported_columns ), f"Dataset column names mismatch. Expected columns to match one of {supported_columns}, but got {data_column_names}" - assert _input_pipeline_utils.is_conversational( + assert input_pipeline_utils.is_conversational( dataset.features, data_column_names ), "Dataset is not in conversational format." 
diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/maxtext/input_pipeline/grain_data_processing.py similarity index 93% rename from src/MaxText/input_pipeline/_grain_data_processing.py rename to src/maxtext/input_pipeline/grain_data_processing.py index 61258ca493..cf2dd649bd 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/maxtext/input_pipeline/grain_data_processing.py @@ -26,10 +26,10 @@ from grain.experimental import BestFitPackIterDataset, pick_performance_config import grain.python as grain -from MaxText.input_pipeline import _input_pipeline_utils -from MaxText.input_pipeline import _grain_tokenizer -from MaxText import multihost_dataloading -from MaxText import tokenizer +from maxtext.input_pipeline import input_pipeline_utils +from maxtext.input_pipeline import grain_tokenizer +from maxtext.input_pipeline import multihost_dataloading +from maxtext.input_pipeline import tokenizer from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -199,10 +199,10 @@ def pretrain_preprocessing_pipeline( ): """Use grain pipeline to pre-process the dataset and return iterators for pretrain""" if config.grain_file_type == "arrayrecord": - dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize)) - dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) + dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) + dataset = dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) else: - dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns)) + dataset = dataset.map(input_pipeline_utils.KeepFeatures(feature_names=data_columns)) assert len(data_columns) == 1 text_column = data_columns[0] @@ -224,13 +224,13 @@ def pretrain_preprocessing_pipeline( if tokenize: if config.use_truncation: - dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, 
tokenizer_model)) + dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model)) else: - dataset = dataset.apply(_grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model)) + dataset = dataset.apply(grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model)) data_columns = ("inputs", "targets") rekey_dict = {col: text_column for col in data_columns} - dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict)) + dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) # Pack and Batch examples. batch_size = config.global_batch_size_to_load // jax.process_count() @@ -273,15 +273,15 @@ def pretrain_preprocessing_pipeline( "targets_position": "targets_positions", "inputs_position": "inputs_positions", } - dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict)) + dataset = dataset.map(input_pipeline_utils.Rekey(rekey_dict)) else: - dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) + dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) dataset = dataset.batch(batch_size, batch_fn=batch_fn) # Shift inputs for teacher-forced training dataset = dataset.map( - _input_pipeline_utils.ShiftData( + input_pipeline_utils.ShiftData( ignored_ids=[pad_id], axis=1, ) @@ -313,8 +313,8 @@ def dpo_preprocessing_pipeline( ): """Use grain to pre-process the dataset and return iterators for dpo fine-tuning""" if config.grain_file_type == "arrayrecord": - dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize)) - dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) + dataset = dataset.map(input_pipeline_utils.ParseFeatures(data_columns, tokenize)) + dataset = 
dataset.map(input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) tokenizer_model = tokenizer.build_tokenizer( config.tokenizer_path, config.tokenizer_type, @@ -331,9 +331,9 @@ def dpo_preprocessing_pipeline( pad_id = -1 if tokenize: - dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model)) + dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model)) - dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) + dataset = dataset.map(input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) batch_size = config.global_batch_size_to_load // jax.process_count() batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) dataset = dataset.batch(batch_size, batch_fn=batch_fn) diff --git a/src/MaxText/input_pipeline/_grain_tokenizer.py b/src/maxtext/input_pipeline/grain_tokenizer.py similarity index 98% rename from src/MaxText/input_pipeline/_grain_tokenizer.py rename to src/maxtext/input_pipeline/grain_tokenizer.py index d43b0d8ce2..b43c1a5681 100644 --- a/src/MaxText/input_pipeline/_grain_tokenizer.py +++ b/src/maxtext/input_pipeline/grain_tokenizer.py @@ -20,7 +20,7 @@ from typing import Any import grain.python as grain import numpy as np -from MaxText import tokenizer +from maxtext.input_pipeline import tokenizer @dataclasses.dataclass diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/maxtext/input_pipeline/hf_data_processing.py similarity index 91% rename from src/MaxText/input_pipeline/_hf_data_processing.py rename to src/maxtext/input_pipeline/hf_data_processing.py index 5c2fdac568..84ad2c0975 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/maxtext/input_pipeline/hf_data_processing.py @@ -24,9 +24,9 @@ import numpy as np -from MaxText.input_pipeline import _input_pipeline_utils -from 
MaxText.input_pipeline import instruction_data_processing -from MaxText import multihost_dataloading +from maxtext.input_pipeline import input_pipeline_utils +from maxtext.input_pipeline import instruction_data_processing +from maxtext.input_pipeline import multihost_dataloading def _get_pad_id(tokenizer): @@ -73,7 +73,7 @@ def vision_sft_preprocessing_pipeline( # If multiple image columns are provided, merge them into a single 'images' column. if isinstance(image_column, list): dataset = dataset.map( - _input_pipeline_utils.merge_image_columns, + input_pipeline_utils.merge_image_columns, fn_kwargs={ "image_columns": image_column, "max_num_images_per_example": config.max_num_images_per_example, @@ -87,7 +87,7 @@ def vision_sft_preprocessing_pipeline( dataset = dataset.rename_column(image_column, "images") dataset = dataset.map( - _input_pipeline_utils.reformat_prompt, + input_pipeline_utils.reformat_prompt, fn_kwargs={ "column": text_columns[0], "image_placeholder": config.image_placeholder, @@ -95,12 +95,12 @@ def vision_sft_preprocessing_pipeline( }, ) dataset = dataset.map( - _input_pipeline_utils.reformat_response, + input_pipeline_utils.reformat_response, fn_kwargs={"column": text_columns[1], "model_name": config.model_name}, ) dataset = dataset.map( - _input_pipeline_utils.pre_process_image_sft, + input_pipeline_utils.pre_process_image_sft, fn_kwargs={"image_column": "images", "model_name": config.model_name}, ) @@ -114,7 +114,7 @@ def vision_sft_preprocessing_pipeline( pad_id = _get_pad_id(tokenizer) dataset = dataset.map( - _input_pipeline_utils.tokenization, + input_pipeline_utils.tokenization, batched=True, batch_size=global_batch_size, fn_kwargs={ @@ -125,11 +125,11 @@ def vision_sft_preprocessing_pipeline( }, ) dataset = dataset.map( - _input_pipeline_utils.prepare_text_for_image_fusion, + input_pipeline_utils.prepare_text_for_image_fusion, fn_kwargs={"column_name": text_columns[0], "model_name": config.model_name}, ) - dataset = 
_input_pipeline_utils.HFDataSource( + dataset = input_pipeline_utils.HFDataSource( dataset=dataset, dataloading_host_index=dataloading_host_index, dataloading_host_count=dataloading_host_count, @@ -139,7 +139,7 @@ def vision_sft_preprocessing_pipeline( ) operations = [] operations.append( - _input_pipeline_utils.SFTPromptMaskingVision( + input_pipeline_utils.SFTPromptMaskingVision( query_column=text_columns[0], response_column=text_columns[1], max_target_length=config.max_target_length, @@ -148,17 +148,17 @@ def vision_sft_preprocessing_pipeline( ) # TODO(aireenmei, hengtaoguo): support packing operations.append( - _input_pipeline_utils.PadOrTrimToMaxLength( + input_pipeline_utils.PadOrTrimToMaxLength( config.max_target_length, pad_id, model_name=config.model_name, max_num_images_per_example=config.max_num_images_per_example, ) ) - operations.append(_input_pipeline_utils.ExtractImagesAndMasks()) + operations.append(input_pipeline_utils.ExtractImagesAndMasks()) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=True)) - operations.append(_input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name)) - operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1)) + operations.append(input_pipeline_utils.FoldImagesIntoBatch(model_name=config.model_name)) + operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id], axis=1)) dummy_index_sampler = grain.IndexSampler( num_records=len(dataset), num_epochs=1, @@ -255,7 +255,7 @@ def preprocessing_pipeline( dataset=dataset, data_columns=data_column_names, chat_template_path=chat_template_path ) - assert _input_pipeline_utils.is_conversational( + assert input_pipeline_utils.is_conversational( dataset.features, data_column_names ), "Dataset is not in conversational format." 
@@ -265,7 +265,7 @@ def preprocessing_pipeline( {combined_column_name: [{"content": datasets.Value(dtype="string"), "role": datasets.Value(dtype="string")}]} ) dataset = dataset.map( - _input_pipeline_utils.combine_columns, + input_pipeline_utils.combine_columns, fn_kwargs={"columns": data_column_names, "data_column": combined_column_name}, remove_columns=data_column_names, features=dataset_features, @@ -273,7 +273,7 @@ def preprocessing_pipeline( data_column_names = list(dataset.features.keys()) dataset = dataset.map( - _input_pipeline_utils.apply_chat_template, + input_pipeline_utils.apply_chat_template, fn_kwargs={"tokenizer_model": tokenizer, "data_column_name": data_column_names[0]}, ) else: @@ -283,7 +283,7 @@ def preprocessing_pipeline( if tokenize: dataset = dataset.map( - _input_pipeline_utils.tokenization, + input_pipeline_utils.tokenization, batched=True, fn_kwargs={ "hf_tokenizer": tokenizer, @@ -293,7 +293,7 @@ def preprocessing_pipeline( }, ) - dataset = _input_pipeline_utils.HFDataSource( + dataset = input_pipeline_utils.HFDataSource( dataset, dataloading_host_index, dataloading_host_count, @@ -304,7 +304,7 @@ def preprocessing_pipeline( operations = [] if use_sft: operations.append( - _input_pipeline_utils.SFTPromptMasking( + input_pipeline_utils.SFTPromptMasking( text_column_name=data_column_names[0], completion_only=sft_train_on_completion_only, max_target_length=max_target_length, @@ -321,7 +321,7 @@ def lists2array(x): operations.append(grain.MapOperation(lists2array)) else: assert len(data_column_names) == 1 - operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) + operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) data_column_names = ("inputs", "targets") if packing and not use_dpo: @@ -336,13 +336,13 @@ def lists2array(x): max_sequences_per_bin=max_segments, ) ) - operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) + 
operations.append(input_pipeline_utils.ReformatPacking(data_column_names)) else: - operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) + operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) if shift and not use_dpo: - operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1)) + operations.append(input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1)) # Since HuggingFace IterableDataset does not support access through index # Indexes generated by dummy_index_sampler is not used. diff --git a/src/MaxText/input_pipeline/input_pipeline_interface.py b/src/maxtext/input_pipeline/input_pipeline_interface.py similarity index 89% rename from src/MaxText/input_pipeline/input_pipeline_interface.py rename to src/maxtext/input_pipeline/input_pipeline_interface.py index 9a21d463c9..a4ed110b5f 100644 --- a/src/MaxText/input_pipeline/input_pipeline_interface.py +++ b/src/maxtext/input_pipeline/input_pipeline_interface.py @@ -19,11 +19,11 @@ from jax.sharding import PartitionSpec as P from MaxText import pyconfig -from MaxText.input_pipeline._grain_data_processing import make_grain_train_iterator, make_grain_eval_iterator -from MaxText.input_pipeline._hf_data_processing import make_hf_train_iterator, make_hf_eval_iterator -from MaxText.input_pipeline._tfds_data_processing import make_tfds_train_iterator, make_tfds_eval_iterator -from MaxText.input_pipeline._tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator -from MaxText.input_pipeline.synthetic_data_processing import SyntheticDataIterator, PlaceHolderDataIterator +from maxtext.input_pipeline.grain_data_processing import make_grain_train_iterator, make_grain_eval_iterator +from maxtext.input_pipeline.hf_data_processing import make_hf_train_iterator, 
make_hf_eval_iterator +from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator, make_tfds_eval_iterator +from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator +from maxtext.input_pipeline.synthetic_data_processing import SyntheticDataIterator, PlaceHolderDataIterator from maxtext.utils import max_logging diff --git a/src/MaxText/input_pipeline/_input_pipeline_utils.py b/src/maxtext/input_pipeline/input_pipeline_utils.py similarity index 99% rename from src/MaxText/input_pipeline/_input_pipeline_utils.py rename to src/maxtext/input_pipeline/input_pipeline_utils.py index c83bc46f50..3238691efa 100644 --- a/src/MaxText/input_pipeline/_input_pipeline_utils.py +++ b/src/maxtext/input_pipeline/input_pipeline_utils.py @@ -25,7 +25,7 @@ import grain.python as grain import numpy as np import tensorflow as tf -from MaxText import tokenizer +from maxtext.input_pipeline import tokenizer from maxtext.multimodal import processor as mm_processor from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging diff --git a/src/MaxText/input_pipeline/instruction_data_processing.py b/src/maxtext/input_pipeline/instruction_data_processing.py similarity index 100% rename from src/MaxText/input_pipeline/instruction_data_processing.py rename to src/maxtext/input_pipeline/instruction_data_processing.py diff --git a/src/MaxText/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py similarity index 100% rename from src/MaxText/multihost_dataloading.py rename to src/maxtext/input_pipeline/multihost_dataloading.py diff --git a/src/maxtext/input_pipeline/packing/__init__.py b/src/maxtext/input_pipeline/packing/__init__.py new file mode 100644 index 0000000000..7c14f5fcde --- /dev/null +++ b/src/maxtext/input_pipeline/packing/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/MaxText/prefill_packing.py b/src/maxtext/input_pipeline/packing/prefill_packing.py similarity index 100% rename from src/MaxText/prefill_packing.py rename to src/maxtext/input_pipeline/packing/prefill_packing.py diff --git a/src/MaxText/sequence_packing.py b/src/maxtext/input_pipeline/packing/sequence_packing.py similarity index 100% rename from src/MaxText/sequence_packing.py rename to src/maxtext/input_pipeline/packing/sequence_packing.py diff --git a/src/MaxText/input_pipeline/synthetic_data_processing.py b/src/maxtext/input_pipeline/synthetic_data_processing.py similarity index 98% rename from src/MaxText/input_pipeline/synthetic_data_processing.py rename to src/maxtext/input_pipeline/synthetic_data_processing.py index e0107d937c..80332ace9f 100644 --- a/src/MaxText/input_pipeline/synthetic_data_processing.py +++ b/src/maxtext/input_pipeline/synthetic_data_processing.py @@ -25,7 +25,7 @@ import jax.numpy as jnp from jax.sharding import PartitionSpec as P -from MaxText import multihost_dataloading +from maxtext.input_pipeline import multihost_dataloading from MaxText import pyconfig diff --git a/src/MaxText/input_pipeline/_tfds_data_processing.py b/src/maxtext/input_pipeline/tfds_data_processing.py similarity index 92% rename from src/MaxText/input_pipeline/_tfds_data_processing.py rename to src/maxtext/input_pipeline/tfds_data_processing.py index f103acc628..dd39c721c3 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing.py +++ 
b/src/maxtext/input_pipeline/tfds_data_processing.py @@ -24,10 +24,10 @@ import jax -from MaxText import multihost_dataloading -from MaxText import tokenizer -from MaxText import sequence_packing -from MaxText.input_pipeline import _input_pipeline_utils +from maxtext.input_pipeline import multihost_dataloading +from maxtext.input_pipeline import tokenizer +from maxtext.input_pipeline.packing import sequence_packing +from maxtext.input_pipeline import input_pipeline_utils AUTOTUNE = tf.data.experimental.AUTOTUNE @@ -99,14 +99,14 @@ def preprocessing_pipeline( if not use_dpo: assert len(data_column_names) == 1 dataset = dataset.map( - lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE + lambda x: input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE ) else: dataset = dataset.map(lambda x: {col: x[col] for col in data_column_names}, num_parallel_calls=AUTOTUNE) data_column_names = data_column_names if use_dpo else ("inputs", "targets") - tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token) + tokenizer_model = input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token) if tokenizer_model.pad_id is not None: pad_id = tokenizer_model.pad_id elif tokenizer_model.unk_id is not None: @@ -125,7 +125,7 @@ def preprocessing_pipeline( # 1 token for both inputs and targets extra_tokens = 1 if not use_dpo else 0 dataset = dataset.map( - lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens), + lambda x: input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens), num_parallel_calls=AUTOTUNE, ) @@ -138,7 +138,7 @@ def preprocessing_pipeline( # Shift inputs for teacher-forced training if shift and not use_dpo: dataset = dataset.map( - _input_pipeline_utils.shift_data_by_truncation, 
num_parallel_calls=tf.data.AUTOTUNE, deterministic=True + input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True ) # Perform greedy sequence packing and batching @@ -154,7 +154,7 @@ def preprocessing_pipeline( drop_remainder=drop_remainder, ) dataset = dataset.map( - lambda x: _input_pipeline_utils.add_segmentation_and_position(x, data_column_names, padding_token=pad_id), + lambda x: input_pipeline_utils.add_segmentation_and_position(x, data_column_names, padding_token=pad_id), num_parallel_calls=tf.data.AUTOTUNE, deterministic=True, ) diff --git a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py b/src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py similarity index 98% rename from src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py rename to src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py index 69b5534692..a5cc547d1c 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing_c4_mlperf.py +++ b/src/maxtext/input_pipeline/tfds_data_processing_c4_mlperf.py @@ -27,10 +27,10 @@ import jax.numpy as jnp from jax.experimental import multihost_utils -from MaxText import tokenizer -from MaxText import multihost_dataloading -from MaxText import sequence_packing -from MaxText.input_pipeline._input_pipeline_utils import get_tokenizer +from maxtext.input_pipeline import tokenizer +from maxtext.input_pipeline import multihost_dataloading +from maxtext.input_pipeline.packing import sequence_packing +from maxtext.input_pipeline.input_pipeline_utils import get_tokenizer from maxtext.utils import max_logging AUTOTUNE = tf.data.experimental.AUTOTUNE diff --git a/src/MaxText/tokenizer.py b/src/maxtext/input_pipeline/tokenizer.py similarity index 100% rename from src/MaxText/tokenizer.py rename to src/maxtext/input_pipeline/tokenizer.py diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 
508febc4b1..3bcb3ce378 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -48,8 +48,8 @@ # MaxText Imports from MaxText import optimizers from MaxText import pyconfig -from MaxText import tokenizer -from MaxText.input_pipeline import input_pipeline_interface +from maxtext.input_pipeline import tokenizer +from maxtext.input_pipeline import input_pipeline_interface from maxtext.utils import max_logging from maxtext.utils import maxtext_utils from maxtext.utils import model_creation_utils diff --git a/src/maxtext/trainers/post_train/sft/hooks.py b/src/maxtext/trainers/post_train/sft/hooks.py index bc66db7c94..b197bd862c 100644 --- a/src/maxtext/trainers/post_train/sft/hooks.py +++ b/src/maxtext/trainers/post_train/sft/hooks.py @@ -33,7 +33,7 @@ from tunix.sft.hooks import DataHooks, TrainingHooks from MaxText import sharding -from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator +from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator from maxtext.common.data_loader import DataLoader from maxtext.common.goodput import GoodputEvent, record_goodput from maxtext.common.metric_logger import MetricLogger, MetadataKey diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ecc66aa9c2..eb5656a9fc 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -190,7 +190,7 @@ def setup_train_loop(config, recorder, devices=None): state: the initialized train state """ # pylint: disable=import-outside-toplevel - from MaxText.input_pipeline.input_pipeline_interface import create_data_iterator + from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): model = model_creation_utils.from_config(config, devices) diff --git a/tests/integration/sft_trainer_correctness_test.py 
b/tests/integration/sft_trainer_correctness_test.py index b57cf7ae9f..e17c66a30a 100644 --- a/tests/integration/sft_trainer_correctness_test.py +++ b/tests/integration/sft_trainer_correctness_test.py @@ -38,13 +38,13 @@ from jax.sharding import Mesh from transformers import AutoTokenizer -from maxtext.utils import maxtext_utils from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_TRAIN from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT, MAXTEXT_TEST_ASSETS_ROOT -from MaxText.input_pipeline import _input_pipeline_utils from MaxText.layers import models from MaxText.layers import quantizations +from maxtext.input_pipeline import input_pipeline_utils +from maxtext.utils import maxtext_utils def get_golden_data(model_name): @@ -80,15 +80,15 @@ def prepare_maxtext_inputs(maxtext_data, config): add_eos_token=False, model_max_length=config.max_target_length, ) - data = _input_pipeline_utils.apply_chat_template(maxtext_data, tokenizer, "messages") - tokenized_data = _input_pipeline_utils.tokenization( + data = input_pipeline_utils.apply_chat_template(maxtext_data, tokenizer, "messages") + tokenized_data = input_pipeline_utils.tokenization( data, hf_tokenizer=tokenizer, truncation=False, max_length=config.max_target_length, column_names=["messages"], ) - masked_inputs = _input_pipeline_utils.SFTPromptMasking( + masked_inputs = input_pipeline_utils.SFTPromptMasking( text_column_name="messages", completion_only=False, max_target_length=config.max_target_length, diff --git a/tests/unit/distillation_data_processing_test.py b/tests/unit/distillation_data_processing_test.py index 191a4e4f9d..c810fb9c12 100644 --- a/tests/unit/distillation_data_processing_test.py +++ b/tests/unit/distillation_data_processing_test.py @@ -25,7 +25,7 @@ from datasets import Dataset from MaxText.globals import MAXTEXT_ASSETS_ROOT -from MaxText.input_pipeline import _distillation_data_processing +from maxtext.input_pipeline import distillation_data_processing 
PROMPT_DATA = [ [ @@ -101,7 +101,7 @@ def test_data_processing_with_messages(self): config = self.parser.parse_args(["--data-columns", "messages"]) dataset = Dataset.from_dict({"messages": MESSAGES_DATA}) - processed_dataset = _distillation_data_processing.process_dataset(config, dataset) + processed_dataset = distillation_data_processing.process_dataset(config, dataset) expected_prompts = [ ["What color is the sky?", "Why is the sky blue?"], @@ -126,8 +126,8 @@ def test_data_filtering_with_messages(self): config = self.parser.parse_args(["--data-columns", "messages", "--use-chat-template"]) dataset = Dataset.from_dict({"messages": MESSAGES_DATA}) - processed_dataset = _distillation_data_processing.process_dataset(config, dataset) - filtered_dataset = _distillation_data_processing.filter_dataset(config, processed_dataset, self.tokenizer) + processed_dataset = distillation_data_processing.process_dataset(config, dataset) + filtered_dataset = distillation_data_processing.filter_dataset(config, processed_dataset, self.tokenizer) self.assertEqual(len(filtered_dataset), 1) self.assertEqual(filtered_dataset[0].prompt, "What color is the sky?") @@ -137,7 +137,7 @@ def test_data_processing_with_prompt_completion(self): config = self.parser.parse_args(["--data-columns", "prompt", "completion"]) dataset = Dataset.from_dict({"prompt": PROMPT_DATA, "completion": COMPLETION_DATA}) - processed_dataset = _distillation_data_processing.process_dataset(config, dataset) + processed_dataset = distillation_data_processing.process_dataset(config, dataset) expected_prompts = [ ["What color is the sky?", "Why is the sky blue?"], @@ -162,8 +162,8 @@ def test_data_filtering_with_prompt_completion(self): config = self.parser.parse_args(["--data-columns", "prompt", "completion", "--use-chat-template"]) dataset = Dataset.from_dict({"prompt": PROMPT_DATA, "completion": COMPLETION_DATA}) - processed_dataset = _distillation_data_processing.process_dataset(config, dataset) - filtered_dataset = 
_distillation_data_processing.filter_dataset(config, processed_dataset, self.tokenizer) + processed_dataset = distillation_data_processing.process_dataset(config, dataset) + filtered_dataset = distillation_data_processing.filter_dataset(config, processed_dataset, self.tokenizer) self.assertEqual(len(filtered_dataset), 1) self.assertEqual(filtered_dataset[0].prompt, "What color is the sky?") diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index cbb11bcd2e..a0aa33d962 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -28,9 +28,9 @@ from jax.experimental import mesh_utils from MaxText import pyconfig -from MaxText.input_pipeline import _grain_data_processing -from MaxText.input_pipeline import input_pipeline_interface from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT +from maxtext.input_pipeline import grain_data_processing +from maxtext.input_pipeline import input_pipeline_interface from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_base_output_directory, get_test_config_path, get_test_dataset_path @@ -93,7 +93,7 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] @@ -115,7 +115,7 @@ def test_train_ds(self): @pytest.mark.external_serving # Skipped in decoupled mode due to rocBLAS scratch buffer TF issues on GPU def test_batch_determinism(self): batch1 = next(self.train_iter) - train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, 
self.process_indices) batch2 = next(train_iter) self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) self.assertTrue((batch1["targets"] == batch2["targets"]).all()) @@ -194,7 +194,7 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) class GrainArrayRecordProcessingWithMixtureConfigTest(GrainArrayRecordProcessingTest): @@ -267,9 +267,10 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) +@pytest.mark.skip(reason="Skipping this test class due to out of memory issue.") class GrainArrayRecordAutoTuneTest(GrainArrayRecordProcessingTest): """Test grain data processing with auto-tuning enabled (grain_worker_count=-1).""" @@ -323,7 +324,7 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) @pytest.mark.skip( reason=( @@ -395,7 +396,7 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) class GrainParquetProcessingTest(unittest.TestCase): @@ -456,7 +457,7 @@ def setUp(self): self.config.max_target_length, self.mesh, ) - self.train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, 
self.process_indices) + self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] @@ -477,7 +478,7 @@ def test_train_ds(self): def test_batch_determinism(self): batch1 = next(self.train_iter) - train_iter = _grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) + train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices) batch2 = next(train_iter) self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) self.assertTrue((batch1["targets"] == batch2["targets"]).all()) diff --git a/tests/unit/hf_data_processing_test.py b/tests/unit/hf_data_processing_test.py index 138f6f99fa..a16165f573 100644 --- a/tests/unit/hf_data_processing_test.py +++ b/tests/unit/hf_data_processing_test.py @@ -23,8 +23,8 @@ from jax.experimental import mesh_utils from MaxText import pyconfig -from MaxText.input_pipeline import _hf_data_processing -from MaxText.input_pipeline import input_pipeline_interface +from maxtext.input_pipeline import hf_data_processing +from maxtext.input_pipeline import input_pipeline_interface from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory @@ -73,7 +73,7 @@ def setUp(self): self.mesh, ) - self.train_iter = _hf_data_processing.make_hf_train_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = hf_data_processing.make_hf_train_iterator(self.config, self.mesh, self.process_indices) def test_train_ds(self): expected_shape = [jax.device_count(), self.config.max_target_length] @@ -94,7 +94,7 @@ def test_train_ds(self): def test_batch_determinism(self): batch1 = next(self.train_iter) - train_iter = _hf_data_processing.make_hf_train_iterator(self.config, self.mesh, self.process_indices) + train_iter = 
hf_data_processing.make_hf_train_iterator(self.config, self.mesh, self.process_indices) batch2 = next(train_iter) self.assertTrue((batch1["inputs"] == batch2["inputs"]).all()) self.assertTrue((batch1["targets"] == batch2["targets"]).all()) diff --git a/tests/unit/instruction_data_processing_test.py b/tests/unit/instruction_data_processing_test.py index 38bfa806bc..396b336c92 100644 --- a/tests/unit/instruction_data_processing_test.py +++ b/tests/unit/instruction_data_processing_test.py @@ -16,7 +16,7 @@ import unittest -from MaxText.input_pipeline import instruction_data_processing +from maxtext.input_pipeline import instruction_data_processing class InstructionDataProcessingTest(unittest.TestCase): diff --git a/tests/unit/multihost_dataloading_test.py b/tests/unit/multihost_dataloading_test.py index 3404ec51d6..c9fa14e540 100644 --- a/tests/unit/multihost_dataloading_test.py +++ b/tests/unit/multihost_dataloading_test.py @@ -28,8 +28,8 @@ import tensorflow as tf from MaxText import pyconfig -from MaxText import multihost_dataloading from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory +from maxtext.input_pipeline import multihost_dataloading class MultihostDataloadingTest(unittest.TestCase): diff --git a/tests/unit/multimodal_rope_check.py b/tests/unit/multimodal_rope_check.py index 21f92d5081..2c0c380247 100644 --- a/tests/unit/multimodal_rope_check.py +++ b/tests/unit/multimodal_rope_check.py @@ -30,19 +30,19 @@ apply_rotary_pos_emb, ) -from MaxText import multimodal_utils -from MaxText.input_pipeline._input_pipeline_utils import ComputeQwen3OmniPositions from MaxText.layers.embeddings import Qwen3OmniMoeThinkerTextRotaryEmbedding as JaxMRoPE +from maxtext.input_pipeline.input_pipeline_utils import ComputeQwen3OmniPositions +from maxtext.multimodal import processor_qwen3_omni # Qwen3-Omni special token IDs -VISION_START = multimodal_utils.QWEN3_OMNI_VISION_START_TOKEN -VISION_END = 
multimodal_utils.QWEN3_OMNI_VISION_END_TOKEN -AUDIO_START = multimodal_utils.QWEN3_OMNI_AUDIO_START_TOKEN -AUDIO_END = multimodal_utils.QWEN3_OMNI_AUDIO_END_TOKEN -IMAGE_TOKEN = multimodal_utils.QWEN3_OMNI_IMAGE_TOKEN -VIDEO_TOKEN = multimodal_utils.QWEN3_OMNI_VIDEO_TOKEN -AUDIO_TOKEN = multimodal_utils.QWEN3_OMNI_AUDIO_TOKEN +VISION_START = processor_qwen3_omni.QWEN3_OMNI_VISION_START_TOKEN +VISION_END = processor_qwen3_omni.QWEN3_OMNI_VISION_END_TOKEN +AUDIO_START = processor_qwen3_omni.QWEN3_OMNI_AUDIO_START_TOKEN +AUDIO_END = processor_qwen3_omni.QWEN3_OMNI_AUDIO_END_TOKEN +IMAGE_TOKEN = processor_qwen3_omni.QWEN3_OMNI_IMAGE_TOKEN +VIDEO_TOKEN = processor_qwen3_omni.QWEN3_OMNI_VIDEO_TOKEN +AUDIO_TOKEN = processor_qwen3_omni.QWEN3_OMNI_AUDIO_TOKEN def create_pytorch_config(head_dim=128, mrope_section=(24, 20, 20), rope_max_timescale=1_000_000): @@ -109,7 +109,7 @@ def create_audio_in_video_sequence( np.ndarray of interleaved token IDs for audio-in-video. """ # Compute token counts - expected_audio_tokens = int(multimodal_utils._get_feat_extract_output_lengths(jnp.array(audio_lengths[0])).item()) # pylint: disable=protected-access + expected_audio_tokens = int(processor_qwen3_omni._get_feat_extract_output_lengths(jnp.array(audio_lengths[0])).item()) # pylint: disable=protected-access # Video tokens video_tokens_per_frame = (video_grid_thw[0, 1] // spatial_merge_size) * (video_grid_thw[0, 2] // spatial_merge_size) @@ -255,7 +255,7 @@ def compare_with_pytorch( Returns: Tuple of (jax_position_ids, pytorch_position_ids, match_status) """ - jax_position_ids_np, jax_deltas_np = multimodal_utils.get_rope_index( + jax_position_ids_np, jax_deltas_np = processor_qwen3_omni.get_rope_index( input_ids=input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, @@ -435,7 +435,7 @@ def test_single_audio(self): # Compute expected audio tokens from raw length audio_lengths = np.array([1600]) # pylint: disable=protected-access - expected_tokens = 
int(multimodal_utils._get_feat_extract_output_lengths(jnp.array(1600)).item()) + expected_tokens = int(processor_qwen3_omni._get_feat_extract_output_lengths(jnp.array(1600)).item()) audio_tokens = [AUDIO_TOKEN] * expected_tokens input_ids = np.array([[AUDIO_START, *audio_tokens, AUDIO_END]]) @@ -645,7 +645,7 @@ def test_transform_wrapper(self): self.assertIn("inputs_mrope_deltas", result) # Verify it matches direct get_rope_index call - expected_pos, expected_deltas = multimodal_utils.get_rope_index( + expected_pos, expected_deltas = processor_qwen3_omni.get_rope_index( input_ids=jnp.array(element["inputs"]), image_grid_thw=jnp.array(element["image_grid_thw"]), video_grid_thw=None, diff --git a/tests/unit/sft_data_processing_test.py b/tests/unit/sft_data_processing_test.py index 00c43cbc35..8b704fc2af 100644 --- a/tests/unit/sft_data_processing_test.py +++ b/tests/unit/sft_data_processing_test.py @@ -27,9 +27,9 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_CONFIGS_DIR, MAXTEXT_ASSETS_ROOT -from MaxText.input_pipeline import _hf_data_processing -from MaxText.input_pipeline import input_pipeline_interface -from MaxText.input_pipeline._hf_data_processing import _get_pad_id +from maxtext.input_pipeline import hf_data_processing +from maxtext.input_pipeline import input_pipeline_interface +from maxtext.input_pipeline.hf_data_processing import _get_pad_id PROMPT_DATA = [ [ @@ -347,7 +347,7 @@ def setUp(self): def get_data_iterator(self, train_ds, data_columns): """Get data iterator.""" - return _hf_data_processing.preprocessing_pipeline( + return hf_data_processing.preprocessing_pipeline( dataloading_host_index=self.process_indices.index(jax.process_index()), dataloading_host_count=len(self.process_indices), global_mesh=self.mesh, diff --git a/tests/unit/tfds_data_processing_test.py b/tests/unit/tfds_data_processing_test.py index 08b5dcad24..8ce5aaa8f1 100644 --- a/tests/unit/tfds_data_processing_test.py +++ 
b/tests/unit/tfds_data_processing_test.py @@ -26,8 +26,8 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_ASSETS_ROOT -from MaxText.input_pipeline import _tfds_data_processing -from MaxText.input_pipeline import input_pipeline_interface +from maxtext.input_pipeline import tfds_data_processing +from maxtext.input_pipeline import input_pipeline_interface from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory @@ -76,8 +76,8 @@ def setUp(self): ) self.read_config.add_tfds_id = True self.train_ds = self._get_datasets() - self.train_iter = _tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices) - self.eval_iter = _tfds_data_processing.make_tfds_eval_iterator(self.config, self.mesh, self.process_indices) + self.train_iter = tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices) + self.eval_iter = tfds_data_processing.make_tfds_eval_iterator(self.config, self.mesh, self.process_indices) def _get_datasets(self): ds_builder = tfds.builder(self.config.dataset_name) @@ -120,7 +120,7 @@ def test_ds_determinism(self): def test_batch_determinism(self): batch1 = next(self.train_iter) - train_iter = _tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices) + train_iter = tfds_data_processing.make_tfds_train_iterator(self.config, self.mesh, self.process_indices) batch2 = next(train_iter) self.assertTrue(tf.reduce_all(tf.equal(batch1["inputs"], batch2["inputs"]))) self.assertTrue(tf.reduce_all(tf.equal(batch1["targets"], batch2["targets"]))) diff --git a/tests/unit/tokenizer_test.py b/tests/unit/tokenizer_test.py index 100c0076a7..c03c6f4cc5 100644 --- a/tests/unit/tokenizer_test.py +++ b/tests/unit/tokenizer_test.py @@ -16,8 +16,8 @@ import numpy as np from MaxText import train_tokenizer -from MaxText.input_pipeline import _input_pipeline_utils from 
MaxText.globals import MAXTEXT_ASSETS_ROOT +from maxtext.input_pipeline import input_pipeline_utils import unittest import pytest @@ -38,7 +38,7 @@ def setUpClass(cls): assets_path = "tests" vocab_model_name = "test_tokenizer" cls.tokenizer_path = os.path.join(assets_path, vocab_model_name) - cls.source_tokenizer = _input_pipeline_utils.get_tokenizer( + cls.source_tokenizer = input_pipeline_utils.get_tokenizer( os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"), "sentencepiece", add_bos=False, @@ -56,7 +56,7 @@ def setUpClass(cls): vocab_size=cls.vocab_size, max_corpus_chars=cls.max_corpus_chars, ) - cls.test_tokenizer = _input_pipeline_utils.get_tokenizer( + cls.test_tokenizer = input_pipeline_utils.get_tokenizer( cls.tokenizer_path, "sentencepiece", add_bos=False, add_eos=False ) @@ -82,7 +82,7 @@ class TikTokenTest(unittest.TestCase): def setUpClass(cls): dataset_name = "c4/en:3.0.1" dataset_path = "gs://maxtext-dataset" - cls.source_tokenizer = _input_pipeline_utils.get_tokenizer( + cls.source_tokenizer = input_pipeline_utils.get_tokenizer( os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer_llama3.tiktoken"), "tiktoken", add_bos=False, @@ -119,10 +119,10 @@ def setUpClass(cls): ["gcloud", "storage", "cp", "-R", source, destination], check=True, ) - cls.hf_tokenizer = _input_pipeline_utils.get_tokenizer( + cls.hf_tokenizer = input_pipeline_utils.get_tokenizer( os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "gemma2-2b"), "huggingface", add_bos=False, add_eos=False ) - cls.sp_tokenizer = _input_pipeline_utils.get_tokenizer( + cls.sp_tokenizer = input_pipeline_utils.get_tokenizer( os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.gemma"), "sentencepiece", add_bos=False, add_eos=False ) diff --git a/tests/unit/tokenizer_transform_test.py b/tests/unit/tokenizer_transform_test.py index 825ca15f61..0c2bfc3274 100644 --- a/tests/unit/tokenizer_transform_test.py +++ b/tests/unit/tokenizer_transform_test.py @@ -19,8 +19,8 @@ import 
grain.python as grain import numpy as np -from MaxText.input_pipeline import _grain_tokenizer -from MaxText.input_pipeline import _input_pipeline_utils +from maxtext.input_pipeline import grain_tokenizer +from maxtext.input_pipeline import input_pipeline_utils from numpy.testing import assert_array_equal @@ -51,7 +51,7 @@ def setUp(self): def test_tokenize_and_trim(self): """Tests the 1:1 MapTransform (truncation) logic.""" - trim_op = _grain_tokenizer.TokenizeAndTrim( + trim_op = grain_tokenizer.TokenizeAndTrim( feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer ) trim_ds = self.base_ds.map(trim_op) @@ -70,7 +70,7 @@ def test_tokenize_and_trim(self): def test_tokenize_and_chunk(self): """Tests the 1:N FlatMapTransform (chunking) logic.""" - chunk_op = _grain_tokenizer.TokenizeAndChunk( + chunk_op = grain_tokenizer.TokenizeAndChunk( feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer ) chunk_ds = self.base_ds.apply(chunk_op) @@ -90,10 +90,10 @@ def test_tokenize_and_chunk(self): def test_trim_and_pad_chaining(self): """Tests chaining TokenizeAndTrim.map() -> PadOrTrimToMaxLength.map()""" - trim_op = _grain_tokenizer.TokenizeAndTrim( + trim_op = grain_tokenizer.TokenizeAndTrim( feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer ) - pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id) + pad_op = input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id) chained_ds = self.base_ds.map(trim_op).map(pad_op) results = list(chained_ds) self.assertEqual(len(results), len(self.source_data)) @@ -110,10 +110,10 @@ def test_trim_and_pad_chaining(self): def test_chunk_and_pad_chaining(self): """Tests chaining TokenizeAndChunk.apply() -> PadOrTrimToMaxLength.map()""" - chunk_op = _grain_tokenizer.TokenizeAndChunk( + chunk_op = grain_tokenizer.TokenizeAndChunk( 
feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer ) - pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id) + pad_op = input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id) chained_ds = self.base_ds.apply(chunk_op).map(pad_op) results = list(chained_ds) self.assertEqual(len(results), 5) diff --git a/tools/data_generation/generate_distillation_data.py b/tools/data_generation/generate_distillation_data.py index a1b259723f..06852a5d5e 100644 --- a/tools/data_generation/generate_distillation_data.py +++ b/tools/data_generation/generate_distillation_data.py @@ -60,7 +60,7 @@ from maxtext.utils import gcs_utils from maxtext.utils import max_logging -from MaxText.input_pipeline import _distillation_data_processing +from maxtext.input_pipeline import distillation_data_processing from jetstream.core.proto import jetstream_pb2 from jetstream.core.proto import jetstream_pb2_grpc @@ -231,7 +231,7 @@ def upload_data(config, data, batch_num): # pylint: disable=redefined-outer-nam def generate_data(config): # pylint: disable=redefined-outer-name """Generates data for distillation.""" - dataset = _distillation_data_processing.load_dataset(config) + dataset = distillation_data_processing.load_dataset(config) tokenizer = transformers.AutoTokenizer.from_pretrained( config.tokenizer_path, @@ -244,8 +244,8 @@ def generate_data(config): # pylint: disable=redefined-outer-name data = dataset[start_idx : start_idx + config.batch_size] start_idx += config.batch_size sampled_dataset = Dataset.from_dict(data) - sampled_dataset = _distillation_data_processing.process_dataset(config, sampled_dataset) - requests = _distillation_data_processing.filter_dataset(config, sampled_dataset, tokenizer) + sampled_dataset = distillation_data_processing.process_dataset(config, sampled_dataset) + requests = distillation_data_processing.filter_dataset(config, sampled_dataset, 
tokenizer) distillation_data = generate_completions(config, requests, tokenizer) upload_data(config, distillation_data, batch_num)