Skip to content

Fix duplicate peak learning rate in warmup schedule#3095

Open
ChingTsai wants to merge 1 commit intomainfrom
jimmytsai/fix-learning-rate-schedule
Open

Fix duplicate peak learning rate in warmup schedule#3095
ChingTsai wants to merge 1 commit intomainfrom
jimmytsai/fix-learning-rate-schedule

Conversation

@ChingTsai
Copy link
Collaborator

@ChingTsai ChingTsai commented Feb 5, 2026

Description

This bug was introduced by recent changes in this PR, which caused the warmup schedule to reach its peak one step early, resulting in a duplicate peak learning rate.

from MaxText import pyconfig
config = pyconfig.initialize(
        [None, "/mnt/disks/jimmy_workspace/maxtext_dev/maxtext/src/MaxText/configs/base.yml"],
        enable_checkpointing=False,
        learning_rate=1,
        learning_rate_schedule_steps=10,
        steps=12,
        warmup_steps_fraction=0.3,
        lr_schedule_type="cosine",
        learning_rate_final_fraction=0.1,
    )
from maxtext.utils import maxtext_utils
schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
[(i, float(schedule_fn(i))) for i in range(7)]

Before

[(0, 0.0),
 (1, 0.5),
 (2, 1.0),
 (3, 1.0),
 (4, 0.9397114515304565),
 (5, 0.7749999761581421),
 (6, 0.5499999523162842)]

After

[(0, 0.0),
 (1, 0.3333333730697632),
 (2, 0.6666666865348816),
 (3, 1.0),
 (4, 0.9397114515304565),
 (5, 0.7749999761581421),
 (6, 0.5499999523162842)]

original schedule (git sha d4a259d)

[(0, 0.0),
 (1, 0.3333333730697632),
 (2, 0.6666666865348816),
 (3, 1.0),
 (4, 0.9554359912872314),
 (5, 0.8305704593658447),
 (6, 0.6501344442367554)]

b/481934309

Changes

  • Fix duplicate peak learning rate in warmup schedule
  • Add an additional unit test to ensure the delta in the warmup schedule is constant.

Tests

 python -m unittest tests.unit.maxtext_utils_test.TestLearningRateSchedules

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ChingTsai ChingTsai force-pushed the jimmytsai/fix-learning-rate-schedule branch 2 times, most recently from cac8863 to b6a6e9b Compare February 5, 2026 09:54
@codecov
Copy link

codecov bot commented Feb 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@ChingTsai
Copy link
Collaborator Author

Hi @NuojCheng,
Could you help review this simple fix for the learning rate schedule?
Thanks!

@ChingTsai ChingTsai force-pushed the jimmytsai/fix-learning-rate-schedule branch from 0d98891 to e4dc022 Compare February 6, 2026 01:18
@ChingTsai ChingTsai force-pushed the jimmytsai/fix-learning-rate-schedule branch from e4dc022 to 5f00b25 Compare February 6, 2026 06:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant