
Fix Nemotron-H: add mlp layer type support#45300

Open
w4nderlust wants to merge 2 commits into huggingface:main from w4nderlust:fix/nemotron-h-mlp-layer-type

Conversation

@w4nderlust
Contributor

@w4nderlust w4nderlust commented Apr 7, 2026

What does this PR do?

Nemotron-H models use standalone MLP layers in their hybrid_override_pattern (the - character), but the config parser, validators, and modeling code only know about mamba/attention/moe. This means every Nemotron-H model on the hub (nvidia/Nemotron-H-8B-Base-8K, nvidia/Nemotron-H-56B-Base-8K, and all Nemotron-3-Nano variants) crashes on load:

KeyError: '-'

in _pattern_to_list().
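The failing lookup can be sketched in plain Python (mapping and function names follow this PR's description, not necessarily the exact transformers source):

```python
# Sketch of the pattern parser as it behaved before the fix: the mapping
# only knows mamba/attention/moe, so the "-" character is unhandled.
PATTERN_TO_TYPE = {"M": "mamba", "*": "attention", "E": "moe"}

def pattern_to_list(pattern: str) -> list[str]:
    # "-" is absent from the mapping, so any real Nemotron-H
    # hybrid_override_pattern raises KeyError: '-' here.
    return [PATTERN_TO_TYPE[ch] for ch in pattern]

try:
    pattern_to_list("M-M*")
except KeyError as e:
    print(f"KeyError: {e}")  # → KeyError: '-'
```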

NVIDIA's own hub modeling code (the trust_remote_code=True path) handles - as "mlp" and dispatches to NemotronHMLP in its mixer, but the native transformers integration missed it. The NemotronHMLP class already exists in the codebase, it just wasn't wired up.

The fix is small: four files, all one-line changes:

  • configuration_nemotron_h.py: add "-": "mlp" to _pattern_to_list mapping, add "mlp" to valid_types in validate_layers_block_type
  • modeling_nemotron_h.py: add "mlp": NemotronHMLP to MIXER_TYPES, add "mlp": None to block_type_to_mask, accept **kwargs in NemotronHMLP.__init__ (because NemotronHBlock passes layer_idx=)
  • configuration_utils.py: add "mlp" to ALLOWED_LAYER_TYPES
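The modeling-side wiring can be sketched as follows; class and dict names follow the list above, heavily simplified from the actual modeling code:

```python
# Hedged sketch of the dispatch change, not the real transformers classes.
class NemotronHMLP:
    def __init__(self, config, **kwargs):
        # **kwargs absorbs layer_idx=, which the block constructor always
        # passes; without it, instantiation raised TypeError.
        self.config = config

MIXER_TYPES = {"mlp": NemotronHMLP}  # plus the mamba/attention/moe entries
block_type_to_mask = {"mlp": None}   # a standalone MLP needs no attention mask

mixer = MIXER_TYPES["mlp"](config=None, layer_idx=7)  # no TypeError
```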

Tested locally: loaded nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16 (pattern M-M-M-MM-M-M*-M-M*-M-M-M*-M-M-MM*-MMM-M-M-), ran model.generate() with a chat prompt, got coherent output. Before the fix it crashes on config parsing. Also ran make fix-repo checks (ruff, auto_mappings, copies, dummies) -- all pass.

Code Agent Policy

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

Nemotron-H models use standalone MLP layers in their
hybrid_override_pattern (the '-' character), but the config parser,
validators, and modeling code only handle mamba/attention/moe.

This means every Nemotron-H model on the hub fails to load:
  KeyError: '-'  in _pattern_to_list()

Changes:
- _pattern_to_list: add '-' -> 'mlp' mapping
- validate_layers_block_type: add 'mlp' to valid_types
- MIXER_TYPES: add 'mlp' -> NemotronHMLP
- block_type_to_mask: add 'mlp' -> None
- NemotronHMLP.__init__: accept **kwargs (layer_idx passed by NemotronHBlock)
- ALLOWED_LAYER_TYPES: add 'mlp'
- modular_nemotron_h.py: same changes (source of truth for modeling code)
@w4nderlust w4nderlust force-pushed the fix/nemotron-h-mlp-layer-type branch from 2556332 to 54b8841 on April 7, 2026 21:10
@Rocketknight1
Member

Interesting, and seems like a big gap in our tests. How did we miss this one, and is it possible to make a regression test for it that doesn't have to load a multi-billion parameter model?

- Add 4 regression tests that exercise the MLP layer type end-to-end:
  config parsing, forward pass, generation, and real Nemotron-H patterns
- Fix _list_to_pattern missing "mlp": "-" mapping (would crash on roundtrip)
- Fix _check_past_key_values_for_generate to handle "mlp" layer type
- Extend test_pattern_conversion_methods with MLP roundtrip coverage

All tests use tiny models (hidden_size=32, ~5-8 layers) — no downloads needed.
@github-actions
Contributor

github-actions Bot commented Apr 8, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: nemotron_h

@w4nderlust
Contributor Author

I believe the gap happened because the existing test suite only exercised patterns with M (mamba), * (attention), and E (moe), but all the actual Nemotron-H models on the Hub use - (standalone MLP) in their hybrid_override_pattern. The NemotronHMLP class was already in the codebase; it just wasn't wired into the dispatcher or validation. NVIDIA's trust_remote_code path handled it, which likely delayed reports.

I've pushed a follow-up commit with regression tests that don't need any model downloads: they instantiate tiny models (hidden_size=32, 5-8 layers) with MLP-containing patterns and exercise the full path:

  • test_mlp_layer_type_config: verifies "mlp" is accepted in layers_block_type and via the legacy "-" pattern
  • test_mlp_layer_type_forward: runs a forward pass through NemotronHModel with MLP layers
  • test_mlp_layer_type_causal_lm: runs model.generate() through NemotronHForCausalLM with MLP layers and cache
  • test_mlp_layer_type_nemotron_h_pattern: uses a shortened version of the real Nano-4B pattern (M-M-*M-M)

Also found and fixed two secondary bugs while writing the tests:

  • _list_to_pattern was missing "mlp": "-" in its reverse mapping, so any config roundtrip with MLP layers would crash
  • _check_past_key_values_for_generate in the test helper raised ValueError("Unknown layer type.") for "mlp" layers
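The roundtrip bug can be illustrated with a minimal sketch (plain Python, helper names mirroring the PR description rather than the exact source):

```python
# After the fix, both directions of the pattern/list conversion know "mlp".
PATTERN_TO_TYPE = {"M": "mamba", "*": "attention", "E": "moe", "-": "mlp"}
TYPE_TO_PATTERN = {v: k for k, v in PATTERN_TO_TYPE.items()}

def pattern_to_list(pattern: str) -> list[str]:
    return [PATTERN_TO_TYPE[ch] for ch in pattern]

def list_to_pattern(layer_types: list[str]) -> str:
    # Before the fix, "mlp" was missing from the reverse mapping,
    # so serializing a config with MLP layers raised KeyError.
    return "".join(TYPE_TO_PATTERN[t] for t in layer_types)

print(list_to_pattern(pattern_to_list("M-M*")))  # → M-M*
```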

Let me know if these tests sound good, happy to change them if not.

By the way, some of the CI failures look flaky/transient and aren't caused by my changes.

@ArthurZucker
Collaborator

this is a duplicate of #44763 please check it out

@w4nderlust
Contributor Author

this is a duplicate of #44763 please check it out

Happy to change it however you want, but right now the Nemotron-H models are unusable. I'd suggest either merging this or settling on a better approach that I'll implement; please don't leave it hanging.
