Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ repos:
rev: 0.7.22
hooks:
- id: mdformat
args: ['--number']
additional_dependencies: [mdformat-myst, mdformat-ruff]
files: (docs/.)
exclude: docs/guides/checkpointing_solutions.md
1 change: 0 additions & 1 deletion docs/guides/data_input_pipeline/data_input_hf.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,3 @@ tokenizer_path: 'google-t5/t5-large' # for using https://huggingface.co/google-

1. Streaming data directly from Hugging Face Hub may be impacted by the traffic of the server. During peak hours you may encounter "504 Server Error: Gateway Time-out". It's recommended to download the Hugging Face dataset to a Cloud Storage bucket or disk for the most stable experience.
2. Streaming data directly from Hugging Face Hub works in multi-host settings with a small number of hosts. With a host number larger than 16, you might encounter a "read time out" error.
3. Only supports `num_epoch=1` at the moment.
10 changes: 5 additions & 5 deletions docs/guides/optimization/pallas_kernels_performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,11 @@ Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpl
## ✅ Putting it all together (checklist)

1. **Profile** the baseline using `named_scope` and `block_until_ready`.
1. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
1. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
1. **Validate** end-to-end performance in the model, not just microbenchmarks.
1. Consider **maintainability** and guard the new kernel with tests.
1. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.
2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
4. **Validate** end-to-end performance in the model, not just microbenchmarks.
5. Consider **maintainability** and guard the new kernel with tests.
6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.

## 📚 References

Expand Down
6 changes: 3 additions & 3 deletions docs/run_maxtext/run_maxtext_localhost.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Before you can begin a training run, you need to configure your storage environm
You'll need a GCS bucket to store all your training artifacts, such as logs, metrics, and model checkpoints.

1. In your Google Cloud project, create a new storage bucket.
1. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs.
2. Your TPU or GPU VMs require read/write access to this bucket. The simplest way to grant this is by assigning the `Storage Admin` (`roles/storage.admin`) role to the service account associated with your VMs.

### Setup MaxText

Expand All @@ -36,14 +36,14 @@ Local development on a single host TPU/GPU VM is a convenient way to run MaxText

1. Create and SSH to the single host VM of your choice. You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. For GPUs, you can use `nvidia-h100-mega-80gb`, `nvidia-h200-141gb`, or `nvidia-b200`. For setting up a TPU VM, use the Cloud TPU documentation available at https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm. For a GPU setup, refer to the guide at https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus.

1. Clone MaxText onto that VM.
2. Clone MaxText onto that VM.

```bash
git clone https://github.com/google/maxtext.git
cd maxtext
```

1. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: Installing in a Python Environment, or building a Docker container. For single host workloads, we recommend to install dependencies in a python environment, and for multihost workloads we recommend the containerized approach.
3. Once you have cloned the repository, you have two primary options for setting up the necessary dependencies on your VM: Installing in a Python Environment, or building a Docker container. For single host workloads, we recommend to install dependencies in a python environment, and for multihost workloads we recommend the containerized approach.

Within the root directory of the cloned repo, create a virtual environment and install dependencies and the pre-commit hook by running:

Expand Down
2 changes: 1 addition & 1 deletion src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ train_image_column: 'image'
eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
eval_image_column: 'image'
packing: True
num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
num_epoch: 1
generate_padding_batch_train: False
generate_padding_batch_eval: False
# Maximum number of segments that can be packed into a single sequence
Expand Down
28 changes: 26 additions & 2 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2288,10 +2288,34 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
):
logger.warning("`tokenizer_type` is not 'tiktoken' when using llama3 tokenizer. Overriding to 'tiktoken'.")
self.tokenizer_type = TokenizerType.TIKTOKEN
# Data input validations
if self.dataset_type == DatasetType.HF:
if not self.hf_path:
raise ValueError("hf_path can't be empty when dataset_type=hf")
if self.hf_eval_files:
self.hf_eval_split = "train"
if self.eval_interval > 0 and not self.hf_eval_split:
raise ValueError("Please specify hf_eval_split or set eval_interval to <=0.")
elif self.dataset_type == DatasetType.GRAIN:
if not self.grain_train_files and not self.grain_train_mixture_config_path:
raise ValueError("When dataset_type=grain, please set grain_train_files or grain_train_mixture_config_path")
if self.eval_interval > 0 and not self.grain_eval_files:
raise ValueError("Please specify grain_eval_files or set eval_interval to <=0.")
if self.tokenizer_type not in (TokenizerType.SENTENCEPIECE, TokenizerType.HUGGINGFACE):
raise ValueError(
f"grain pipeline only supports tokenizer_type: sentencepiece, huggingface, but got {self.tokenizer_type}"
)
elif self.dataset_type == DatasetType.TFDS:
if not self.dataset_name:
raise ValueError("dataset_name can't be empty when dataset_type=tfds")
if self.eval_interval > 0 and not self.eval_split:
raise ValueError("Please specify eval_split or set eval_interval to <=0.")

if self.sharding_tolerance > 1.0 or self.sharding_tolerance < 0.0:
logger.warning("'sharding_tolerance: allowed percentage of non-sharded parameters' should be between 0.0 and 1.0")

if self.eval_interval > 0 >= self.eval_steps and self.generate_padding_batch_eval:
raise ValueError("`eval_steps` must be > 0 when `generate_padding_batch_eval` is True.")
if self.dataset_type == "hf" and self.num_epoch != 1:
raise ValueError("HuggingFace pipeline only supports num_epoch=1.")
if self.rl.loss_algo == "grpo":
self.use_grpo = True
else:
Expand Down
18 changes: 16 additions & 2 deletions src/MaxText/input_pipeline/_hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,14 @@ def vision_sft_preprocessing_pipeline(
else:
batch_size = global_batch_size // jax.process_count()

if config.enable_data_shuffling:
# for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
if config.enable_data_shuffling and config.num_epoch > 1:
Comment thread
aireenmei marked this conversation as resolved.
epoch_datasets = [dataset.shuffle(seed=config.data_shuffle_seed + i) for i in range(config.num_epoch)]
dataset = datasets.concatenate_datasets(epoch_datasets)
elif config.enable_data_shuffling:
dataset = dataset.shuffle(seed=config.data_shuffle_seed)
elif config.num_epoch > 1:
dataset = dataset.repeat(config.num_epoch)

# If multiple image columns are provided, merge them into a single 'images' column.
if isinstance(image_column, list):
Expand Down Expand Up @@ -206,6 +212,7 @@ def preprocessing_pipeline(
sft_train_on_completion_only=True,
grain_worker_count=1, # only support 0 or 1
max_segments_per_seq=None,
num_epoch=1,
):
"""pipeline for preprocessing HF dataset"""

Expand All @@ -217,8 +224,14 @@ def preprocessing_pipeline(
else:
batch_size = global_batch_size // jax.process_count()

if shuffle:
# for multi-epoch with shuffle, shuffle each epoch with different seeds then concat
if shuffle and num_epoch > 1:
epoch_datasets = [dataset.shuffle(seed=data_shuffle_seed + i) for i in range(num_epoch)]
dataset = datasets.concatenate_datasets(epoch_datasets)
elif shuffle:
dataset = dataset.shuffle(seed=data_shuffle_seed)
elif num_epoch > 1:
dataset = dataset.repeat(num_epoch)

tokenizer = transformers.AutoTokenizer.from_pretrained(
tokenizer_path,
Expand Down Expand Up @@ -409,6 +422,7 @@ def make_hf_train_iterator(
sft_train_on_completion_only=config.sft_train_on_completion_only,
chat_template_path=config.chat_template_path,
max_segments_per_seq=config.max_segments_per_seq,
num_epoch=config.num_epoch,
)
return train_iter

Expand Down
1 change: 1 addition & 0 deletions src/MaxText/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def validate_tokamax_usage(keys):
raise ValueError(f"Invalid tokamax's megablox kernel usage for hardware {keys['hardware']}. Only TPU is supported.")


# All data input validations have been migrated to config/types.py
def validate_data_input(keys):
"""validate provided parameters for data input"""
if not keys["hf_access_token"]:
Expand Down
Loading