diff --git a/algorithms/archived_paper_baselines/adamw/jax/submission.py b/algorithms/archived_paper_baselines/adamw/jax/submission.py index b8ea5d30a..c0ffe7601 100644 --- a/algorithms/archived_paper_baselines/adamw/jax/submission.py +++ b/algorithms/archived_paper_baselines/adamw/jax/submission.py @@ -254,6 +254,8 @@ def get_batch_size(workload_name): return 16 elif workload_name == 'cifar': return 32 + elif workload_name == 'finewebedu_lm': + return 64 else: raise ValueError(f'Unsupported workload name: {workload_name}.')