Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/MaxText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from MaxText import pyconfig
from MaxText.layers import models
from maxtext.trainers.post_train.dpo import dpo_utils
from maxtext.utils import maxtext_utils
from maxtext.utils import model_creation_utils
from maxtext.utils.model_creation_utils import from_config
Expand Down
6 changes: 2 additions & 4 deletions src/MaxText/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def gradient_accumulation_loss_and_grad(
params_shardings,
data,
dropout_rng,
extra_dpo_args,
):
"""
Calculates gradients using gradient accumulation.
Expand All @@ -45,7 +44,7 @@ def gradient_accumulation_loss_and_grad(

Args:
_loss_fn: The loss function to differentiate. Its signature is expected
to be: `(model, config, data, dropout_rng, params, *extra_args, is_train=True)`.
to be: `(model, config, data, dropout_rng, params, is_train=True)`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we removing the *extra_args here?

config: Model and training configuration object. Must contain
`gradient_accumulation_steps` and `shard_optimizer_over_data`.
model: The model module.
Expand All @@ -54,7 +53,6 @@ def gradient_accumulation_loss_and_grad(
data: A PyTree of batched data. The leading dimension is assumed
to be the total batch size (microbatch_size * num_accumulations).
dropout_rng: JAX PRNGKey for dropout.
extra_dpo_args: A tuple of extra arguments to pass to the loss function.

Returns:
A tuple containing:
Expand Down Expand Up @@ -91,7 +89,7 @@ def convert_to_bf16(param):

def accumulate_gradient(acc_grad_and_loss, data):
ga_params = acc_grad_and_loss["ga_params"]
(_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True)
(_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, is_train=True)
acc_grad_and_loss["loss"] += aux["total_loss"]
acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"]
acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"]
Expand Down
155 changes: 32 additions & 123 deletions src/MaxText/input_pipeline/_grain_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,57 +303,6 @@ def pretrain_preprocessing_pipeline(
return dataset


def dpo_preprocessing_pipeline(
dataset,
config,
data_columns,
tokenize,
grain_worker_count,
grain_per_worker_buffer_size,
):
"""Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
if config.grain_file_type == "arrayrecord":
dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
tokenizer_model = tokenizer.build_tokenizer(
config.tokenizer_path,
config.tokenizer_type,
config.add_bos,
config.add_eos,
config.hf_access_token,
config.dataset_type,
)
if tokenizer_model.pad_id is not None:
pad_id = tokenizer_model.pad_id
elif tokenizer_model.unk_id is not None:
pad_id = tokenizer_model.unk_id
else:
pad_id = -1

if tokenize:
dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))

dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
batch_size = config.global_batch_size_to_load // jax.process_count()
batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
dataset = dataset.batch(batch_size, batch_fn=batch_fn)
multiprocessing_options = (
pick_performance_config(
ds=dataset,
ram_budget_mb=config.grain_ram_budget_mb,
max_workers=None,
max_buffer_size=None,
).multiprocessing_options
if grain_worker_count == -1
else grain.MultiprocessingOptions(
num_workers=grain_worker_count,
per_worker_buffer_size=grain_per_worker_buffer_size,
)
)
dataset = dataset.mp_prefetch(multiprocessing_options)
return dataset


def make_grain_train_iterator(
config: ml_collections.ConfigDict,
global_mesh,
Expand All @@ -378,24 +327,14 @@ def make_grain_train_iterator(
grain_data_source_max_workers=config.grain_data_source_max_workers,
mixture_config_path=config.grain_train_mixture_config_path,
)
if config.use_dpo:
train_dataloader = dpo_preprocessing_pipeline(
train_ds,
config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
else:
train_dataloader = pretrain_preprocessing_pipeline(
train_ds,
config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
train_dataloader = pretrain_preprocessing_pipeline(
train_ds,
config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
return multihost_dataloading.MultiHostDataLoadIterator(
train_dataloader,
global_mesh,
Expand All @@ -415,24 +354,14 @@ def make_grain_train_iterator(
grain_prefetch_buffer_size=config.grain_prefetch_buffer_size,
grain_data_source_max_workers=config.grain_data_source_max_workers,
)
if config.use_dpo:
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
config=config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
else:
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
config=config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
config=config,
data_columns=config.train_data_columns,
tokenize=config.tokenize_train_data,
grain_worker_count=config.grain_worker_count,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
if config.colocated_python_data_input:
global_shape = (config.global_batch_size_to_load, config.max_target_length)
return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
Expand Down Expand Up @@ -475,24 +404,14 @@ def make_grain_eval_iterator(
grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval,
grain_data_source_max_workers=config.grain_data_source_max_workers,
)
if config.use_dpo:
eval_dataloader = dpo_preprocessing_pipeline(
eval_ds,
config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
else:
eval_dataloader = pretrain_preprocessing_pipeline(
eval_ds,
config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
eval_dataloader = pretrain_preprocessing_pipeline(
eval_ds,
config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
return multihost_dataloading.MultiHostDataLoadIterator(
eval_dataloader, global_mesh, config.generate_padding_batch_eval
)
Expand All @@ -509,23 +428,13 @@ def make_grain_eval_iterator(
grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval,
grain_data_source_max_workers=config.grain_data_source_max_workers,
)
if config.use_dpo:
preprocessing_fn = functools.partial(
dpo_preprocessing_pipeline,
config=config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
else:
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
config=config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
preprocessing_fn = functools.partial(
pretrain_preprocessing_pipeline,
config=config,
data_columns=config.eval_data_columns,
tokenize=config.tokenize_eval_data,
grain_worker_count=config.grain_worker_count_eval,
grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)
global_shape = (config.global_batch_size_to_load, config.max_target_length)
return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
16 changes: 2 additions & 14 deletions src/MaxText/input_pipeline/_hf_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import grain.python as grain

import numpy as np

from MaxText.input_pipeline import _input_pipeline_utils
from MaxText.input_pipeline import instruction_data_processing
from MaxText import multihost_dataloading
Expand Down Expand Up @@ -205,7 +203,6 @@ def preprocessing_pipeline(
num_threads=1,
drop_remainder=True,
generate_padding_batch=False,
use_dpo=None,
use_sft=None,
use_tunix_gradient_accumulation=False,
num_microbatches=1,
Expand Down Expand Up @@ -312,19 +309,12 @@ def preprocessing_pipeline(
)
)
data_column_names = ("inputs", "targets")
elif use_dpo:

def lists2array(x):
"""Convert lists/tuples to array"""
return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))

operations.append(grain.MapOperation(lists2array))
else:
assert len(data_column_names) == 1
operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
data_column_names = ("inputs", "targets")

if packing and not use_dpo:
if packing:
length_struct = {col: max_target_length for col in data_column_names}
max_segments = max_segments_per_seq
if max_segments is not None and max_segments <= 0:
Expand All @@ -341,7 +331,7 @@ def lists2array(x):
operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

if shift and not use_dpo:
if shift:
operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1))

# Since HuggingFace IterableDataset does not support access through index
Expand Down Expand Up @@ -418,7 +408,6 @@ def make_hf_train_iterator(
add_eos=config.add_eos,
packing=config.packing,
generate_padding_batch=config.generate_padding_batch_train,
use_dpo=config.use_dpo,
use_sft=config.use_sft,
use_tunix_gradient_accumulation=config.use_tunix_gradient_accumulation,
num_microbatches=config.gradient_accumulation_steps,
Expand Down Expand Up @@ -476,7 +465,6 @@ def make_hf_eval_iterator(
add_eos=config.add_eos,
packing=config.packing,
generate_padding_batch=config.generate_padding_batch_eval,
use_dpo=config.use_dpo,
use_sft=config.use_sft,
num_microbatches=config.gradient_accumulation_steps,
sft_train_on_completion_only=config.sft_train_on_completion_only,
Expand Down
24 changes: 8 additions & 16 deletions src/MaxText/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,15 @@ def preprocessing_pipeline(
shift: bool = True,
drop_remainder: bool = True,
prefetch_size=tf.data.experimental.AUTOTUNE,
use_dpo: bool = False,
hf_access_token: str = "",
):
"""pipeline for preprocessing TFDS dataset."""
if not use_dpo:
assert len(data_column_names) == 1
dataset = dataset.map(
lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE
)
else:
dataset = dataset.map(lambda x: {col: x[col] for col in data_column_names}, num_parallel_calls=AUTOTUNE)
assert len(data_column_names) == 1
dataset = dataset.map(
lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE
)

data_column_names = data_column_names if use_dpo else ("inputs", "targets")
data_column_names = ("inputs", "targets")

tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
if tokenizer_model.pad_id is not None:
Expand All @@ -123,7 +119,7 @@ def preprocessing_pipeline(
if max_target_length > 0:
# in pre-training we can take upto max_length+1 because there would be truncation by
# 1 token for both inputs and targets
extra_tokens = 1 if not use_dpo else 0
extra_tokens = 1
dataset = dataset.map(
lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens),
num_parallel_calls=AUTOTUNE,
Expand All @@ -136,13 +132,13 @@ def preprocessing_pipeline(
dataset = dataset.repeat(num_epochs)

# Shift inputs for teacher-forced training
if shift and not use_dpo:
if shift:
dataset = dataset.map(
_input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True
)

# Perform greedy sequence packing and batching
if pack_examples and not use_dpo:
if pack_examples:
dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id)
dataset = dataset.batch(global_batch_size // jax.process_count(), drop_remainder=drop_remainder)
else:
Expand Down Expand Up @@ -202,7 +198,6 @@ def make_tfds_train_iterator(
add_eos=config.add_eos,
num_epochs=config.num_epoch,
pack_examples=config.packing,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)
return multihost_dataloading.MultiHostDataLoadIterator(
Expand All @@ -227,7 +222,6 @@ def make_tfds_train_iterator(
add_eos=config.add_eos,
num_epochs=config.num_epoch,
pack_examples=config.packing,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)
global_shape = (config.global_batch_size_to_load, config.max_target_length)
Expand Down Expand Up @@ -265,7 +259,6 @@ def make_tfds_eval_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
pack_examples=config.packing,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)
return multihost_dataloading.MultiHostDataLoadIterator(
Expand All @@ -292,7 +285,6 @@ def make_tfds_eval_iterator(
add_bos=config.add_bos,
add_eos=config.add_eos,
pack_examples=config.packing,
use_dpo=config.use_dpo,
hf_access_token=config.hf_access_token,
)
return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, config, global_mesh)
Loading
Loading