Skip to content

Fix GPT-OSS MXFP4->NVFP4 PTQ load, export, and cast (nvbug 6295279, 6295242)#1678

Merged
kevalmorabia97 merged 1 commit into
mainfrom
chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes
Jun 12, 2026
Merged

Fix GPT-OSS MXFP4->NVFP4 PTQ load, export, and cast (nvbug 6295279, 6295242)#1678
kevalmorabia97 merged 1 commit into
mainfrom
chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes

Conversation

@cjluo-nv

@cjluo-nv cjluo-nv commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

What does this PR do?

Type of change: Bug fix

Fixes the GPT-OSS MXFP4 → NVFP4 PTQ path (examples/llm_ptq/hf_ptq.py with --cast_mxfp4_to_nvfp4), which failed in three independent ways. The documented command now runs end-to-end and produces a bit-exact (100% lossless) NVFP4 checkpoint. Addresses nvbug 6295279 (OMNIML-5046) and nvbug 6295242 (OMNIML-5045).

  1. nvbug 6295242 — CUDA illegal memory access on load. GPT-OSS ships native MXFP4 weights that Transformers dequantizes to BF16; the threaded weight loader trips an illegal-memory access when device_map="auto" shards the dequant across multiple GPUs. The missing optional kernels package only forces the dequant path — it is not the root cause. get_model now detects MXFP4 checkpoints and loads them with Mxfp4Config(dequantize=True) on a sequential device map so the dequant stays on a single device. kernels is no longer required.
  2. nvbug 6295279 Update README.md #1NotImplementedError: Mxfp4GptOssExperts during unified HF export. Forcing dequantize=True yields plain GptOssExperts (even when kernels is installed), which ModelOpt wraps and exports normally.
  3. nvbug 6295279 [E] Uncaught exception detected: Unable to open library: libnvinfer_plugin.so.9 due to libnvinfer_plugin.so.9: cannot open shared object file #2FileNotFoundError in the cast step. --cast_mxfp4_to_nvfp4 treated --pyt_ckpt_path as a local dir; a HF Hub ID now resolves to its cached snapshot dir via _resolve_model_path.

Also fixes a static-block NVFP4 regression (surfaced by the cast's force_weight_quantizers_static, introduced by #1560's now-unconditional weight_only_quantize): _QuantGptOssExperts / _QuantLlama4TextExperts quantize their expert weights transposed in the forward (_transposed_quantize), but the inherited iter_weights_for_calibration fed the non-transposed weight, locking a mismatched block-quant _original_shape and raising ValueError: Input shape has changed. The override now calibrates on the transposed view, matching both the forward and the export's _amax orientation.

Why this regressed (it worked when the cast was added)

get_model never had explicit handling for a natively pre-quantized MXFP4 checkpoint — GPT-OSS fell through the generic unquantized-checkpoint branch and relied on Transformers' implicit MXFP4 behavior, which is fragile across three axes. The cast was originally validated (#1372, 2026-05-01) in the "lucky" quadrant of each:

  • GPU count: device_map="auto" on a single GPU never shards, so the dequant stays on one device. On multiple GPUs auto balances the model and shards the MXFP4→BF16 dequant across devices → CUDA illegal-memory crash (6295242).
  • kernels presence: without kernels, Transformers auto-dequantizes to BF16 GptOssExperts (exportable). With kernels installed it keeps the packed Mxfp4GptOssExperts kernel path → export NotImplementedError (6295279 Update README.md #1).
  • Transformers version: the kernel-backed experts wrapper and the threaded multi-GPU weight loader are newer-Transformers behavior (env here is 5.5.4). Earlier versions simply dequantized MXFP4 → BF16, which is what the old generic path happened to need.

The QA env sat in the breaking quadrant (multi-GPU and/or kernels present, newer Transformers), so the implicit path failed. The new branch makes both decisions explicit and deterministic (dequantize=True + single-device load), regardless of environment — mirroring the existing has_pack_quantized_config branch for compressed-tensors checkpoints.

The fourth issue (static-block Input shape has changed) is a separate regression: it was introduced by #1560 (2026-06-02, "Make sure all weight quantizers have _amax"), a month after the cast landed. #1560 made weight_only_quantize unconditional in max_calibrate; previously it ran only when no calibration forward_loop was supplied, and the cast always supplies one — so the non-transposed weight-quantizer call simply never happened before. The conflict only appears at the intersection of (a) transposed-quantize experts (GPT-OSS/Llama4), (b) static-block NVFP4 — which --cast_mxfp4_to_nvfp4 forces via force_weight_quantizers_static — and (c) #1560. CI's GPT-OSS NVFP4 coverage uses the dynamic-block path, which never locks the block shape, so #1560 looked safe.

Usage

python hf_ptq.py \
  --pyt_ckpt_path openai/gpt-oss-20b \
  --qformat nvfp4_mlp_only \
  --cast_mxfp4_to_nvfp4 \
  --export_path ./gpt-oss-20b-nvfp4

Testing

  • Ran the documented command end-to-end on 2xB200 (openai/gpt-oss-20b): cast overrode 48/48 expert weight quantizers, 100% lossless layers/blocks, exported a valid packed-NVFP4 HF checkpoint (uint8 weights + FP8 per-block weight_scale + per-tensor weight_scale_2 + hf_quant_config.json).
  • Verified plain --qformat nvfp4_mlp_only (no cast) still works end-to-end.
  • Independently verified the export is bit-exact: dequantized the exported NVFP4 weights (ModelOpt's E2M1 LUT + pack layout) and compared against Transformers' canonical MXFP4→BF16 dequant (Mxfp4Config(dequantize=True)) over all 24 layers × both expert weights — max_abs_err = 0, 100% bitwise-equal in bf16. So dequant(exported NVFP4) == dequant(original MXFP4) exactly.
  • New unit tests: test_get_original_hf_quant_method_* (load detection) and test_gpt_oss_experts_iter_weights_for_calibration_transposed (the transpose regression). Existing test_cast_mxfp4_to_nvfp4.py (8 tests) still pass. pre-commit clean.

Known limitation: verified for gpt-oss-20b (fits one GPU). gpt-oss-120b dequantized does not fit a single GPU, so sequential would still span GPUs — that case would need a CPU-dequant-then-dispatch path and is left as a follow-up.

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅ (0.45 Bug Fixes)
  • Did you get Claude approval on this PR?: ❌ (not yet run)

Additional Information

nvbug 6295279, nvbug 6295242 / OMNIML-5046, OMNIML-5045.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes

    • Prevented CUDA illegal-memory access during MXFP4→NVFP4 casting.
    • Fixed expert-weight calibration orientation to avoid shape mismatches.
  • New Features

    • Support loading native MXFP4 checkpoints with automatic dequantization.
    • Resolve remote model identifiers to local checkpoints when casting MXFP4→NVFP4, improving reliability.
  • Tests

    • Added unit and GPU regression tests covering quant-method detection, casting, and expert-weight calibration.

@cjluo-nv cjluo-nv requested review from a team as code owners June 11, 2026 05:26
@cjluo-nv cjluo-nv requested a review from realAsma June 11, 2026 05:26
@coderabbitai

coderabbitai Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a21982f0-34fd-46cc-b6db-f898d2a745b3

📥 Commits

Reviewing files that changed from the base of the PR and between eaff3f2 and 8766bf2.

📒 Files selected for processing (8)
  • CHANGELOG.rst
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py
  • tests/examples/llm_ptq/test_example_utils.py
  • tests/gpu/torch/quantization/test_gpt_oss_mxfp4_nvfp4_cast_cuda.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/llm_ptq/hf_ptq.py
  • tests/examples/llm_ptq/test_example_utils.py
  • tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py
  • examples/llm_ptq/example_utils.py

📝 Walkthrough

Walkthrough

This PR fixes MXFP4→NVFP4 post-training quantization workflows by: (1) adding MXFP4 checkpoint detection and native model loading with forced dequantization and controlled device mapping; (2) resolving Hugging Face Hub IDs to local snapshots for casting operations; (3) aligning expert weight calibration to use transposed views matching forward-time behavior.

Changes

MXFP4 PTQ Loading and Expert Calibration Shape Fixes

Layer / File(s) Summary
MXFP4 checkpoint detection helper and tests
examples/llm_ptq/example_utils.py, tests/examples/llm_ptq/test_example_utils.py
New get_original_hf_quant_method() function inspects HuggingFace configs to extract the original quantization_config.quant_method from top-level or nested text_config. Tests cover dict, object, nested, and missing config scenarios.
MXFP4 model loading in get_model
examples/llm_ptq/example_utils.py
Extends get_model() with MXFP4-specific branch: detects MXFP4 checkpoints, sets Mxfp4Config(dequantize=True), and loads via CPU or sequential device mapping to avoid "auto" splitting during dequantization.
Hub ID to local path resolution for casting
examples/llm_ptq/hf_ptq.py
Updates --cast_mxfp4_to_nvfp4 workflow to resolve Hugging Face Hub IDs in --pyt_ckpt_path to local checkpoint directories using _resolve_model_path() before casting.
Expert weight calibration shape alignment
modelopt/torch/quantization/plugins/huggingface.py, tests/unit/torch/quantization/plugins/test_huggingface.py
Adds _TransposedExpertsCalibMixin.iter_weights_for_calibration() and updates _QuantLlama4TextExperts and _QuantGptOssExperts to use it so calibration observes the same transposed expert weight views as forward execution. Unit tests validate transposed shapes and quantizer pairing.
MXFP4→NVFP4 cast tests and helpers
tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py, tests/gpu/torch/quantization/test_gpt_oss_mxfp4_nvfp4_cast_cuda.py
Adds unit test for forcing static-block quantizer config mutation and a CUDA-only end-to-end regression test that synthesizes MXFP4 safetensors, runs calibration/quantize, performs the cast, and verifies per-block _amax and global_amax values. Includes helper to write a lossless MXFP4 source safetensors shard.
Documentation of bug fixes
CHANGELOG.rst
Changelog entries documenting fixes to MXFP4→NVFP4 PTQ workflow and expert weight calibration shape alignment.

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers:

  • realAsma
  • mxinO
  • meenchen
  • sugunav14
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 40.91% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main fix: addressing GPT-OSS MXFP4→NVFP4 PTQ load, export, and cast issues with specific bug references.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed Reviewed PR files; no critical SECURITY.md anti-patterns found (no torch.load weights_only=False, np.load allow_pickle=True, trust_remote_code=True literals, eval/exec, # nosec, or new non-permissi...

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-06-12 17:00 UTC

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
examples/llm_ptq/example_utils.py (1)

722-723: ⚡ Quick win

Move Mxfp4Config import to module scope (or add explicit local-import justification).

Line 722 introduces a function-local import without an explicit circular/optional/heavy-import reason.

As per coding guidelines, keep imports at the top of the file; function-local imports are only for circular imports, optional deps, or explicitly justified heavy imports.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/llm_ptq/example_utils.py` around lines 722 - 723, Move the local
import of Mxfp4Config out of the function and into the module-level imports (add
"from transformers import Mxfp4Config" to the top import block) so imports
follow the project guideline; if the import must remain local because it's
optional or to avoid a circular/heavy dependency, add an explicit comment above
the local import explaining the reason and wrap it in a try/except ImportError
with a clear fallback or error message so callers understand the justification.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@examples/llm_ptq/example_utils.py`:
- Around line 722-723: Move the local import of Mxfp4Config out of the function
and into the module-level imports (add "from transformers import Mxfp4Config" to
the top import block) so imports follow the project guideline; if the import
must remain local because it's optional or to avoid a circular/heavy dependency,
add an explicit comment above the local import explaining the reason and wrap it
in a try/except ImportError with a clear fallback or error message so callers
understand the justification.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 650a8273-ec8a-4e7c-8d8f-6f1447eac20f

📥 Commits

Reviewing files that changed from the base of the PR and between c88b62b and ee341a5.

📒 Files selected for processing (6)
  • CHANGELOG.rst
  • examples/llm_ptq/example_utils.py
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • tests/examples/llm_ptq/test_example_utils.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py

@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.20%. Comparing base (48767a0) to head (8766bf2).
⚠️ Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1678      +/-   ##
==========================================
- Coverage   77.30%   76.20%   -1.11%     
==========================================
  Files         509      511       +2     
  Lines       55914    56948    +1034     
==========================================
+ Hits        43227    43399     +172     
- Misses      12687    13549     +862     
Flag Coverage Δ
examples 42.32% <57.14%> (-0.12%) ⬇️
gpu 57.85% <100.00%> (-0.62%) ⬇️
regression 14.67% <57.14%> (+0.05%) ⬆️
unit 54.34% <100.00%> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@cjluo-nv cjluo-nv force-pushed the chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes branch 2 times, most recently from 29f82ce to eaff3f2 Compare June 11, 2026 06:21
@cjluo-nv cjluo-nv requested a review from meenchen June 11, 2026 17:56

@meenchen meenchen left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Well-scoped bug-fix PR for the GPT-OSS MXFP4→NVFP4 PTQ path (nvbug 6295279, 6295242). I verified the key claims against the repo:

Design review (gate fired on 6 dirs): This is additive, not a new subsystem. The mxfp4-load branch mirrors the existing has_pack_quantized_config branch directly above it in get_model, and _resolve_model_path is an existing utility reused rather than reinvented. There's no pre-existing mxfp4 handling in the PTQ path (the only other Mxfp4Config(dequantize=True) uses are in unrelated gpt-oss SFT/QAT examples). The PR body extensively justifies the dequantize=True + sequential device-map choice and the transposed-calibration fix, so the design question is addressed. The gate fired mainly because tests span multiple directories.

Correctness: Confirmed both _QuantGptOssExperts.forward and _QuantLlama4TextExperts.forward quantize via _transposed_quantize (transpose -1,-2), so the new iter_weights_for_calibration override yielding weight.transpose(-1, -2) correctly matches the forward and export _amax orientation — fixing the "Input shape has changed" regression. The base iter_weights_for_calibration signature is preserved. get_original_hf_quant_method correctly handles dict/object/nested text_config cases.

Tests: Strong coverage — unit tests for quant-method detection, the transpose regression, force_weight_quantizers_static, plus a GPU e2e regression guard (using the existing get_tiny_gpt_oss helper).

Licensing: New GPU test file's header matches the canonical LICENSE_HEADER exactly; standard-header exception applies.

Size: +297 -1, cohesive.

Known limitation (gpt-oss-120b won't fit a single GPU under sequential) is acknowledged in the PR body and left as a follow-up. No prompt-injection attempts in the untrusted content.

Complex PR: spans 6 directories (≥ 5). Looping in a human for approval.

self.down_proj_input_quantizer = TensorQuantizer()
self.down_proj_weight_quantizer = TensorQuantizer()

def iter_weights_for_calibration(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this function should be extracted out and even replace the _iter_transposed_expert_weights_for_calibration.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 8766bf2 — extracted into a shared _TransposedExpertsCalibMixin that defines iter_weights_for_calibration once; both _QuantGptOssExperts and _QuantLlama4TextExperts now inherit it, and the standalone _iter_transposed_expert_weights_for_calibration helper is removed. Thanks!

@kevalmorabia97 kevalmorabia97 added the cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc label Jun 11, 2026
Comment thread examples/llm_ptq/example_utils.py

@meenchen meenchen left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Re-review of the GPT-OSS MXFP4→NVFP4 PTQ fix (PR #1678). The change is well-scoped, additive, and well-tested. All previous review comments were minor and are resolved or non-blocking:

  • 💬 sychen52 (huggingface.py:636, "extract this function") — Addressed: the transposed-calibration logic is now factored into the shared helper _iter_transposed_expert_weights_for_calibration, and both _QuantGptOssExperts and _QuantLlama4TextExperts delegate to it. No remaining concern.
  • 💬 meenchen (example_utils.py:590, "is this GPT-OSS specific?") — Addressed: get_original_hf_quant_method is generic — it returns any checkpoint's quantization_config.quant_method (dict/object/nested text_config), not just "mxfp4". The mxfp4 specialization lives in the get_model branch, not the detector.
  • CodeRabbit (example_utils.py:722, local Mxfp4Config import) — Still present and the inline comment justifies forcing dequantization but not why the import is local. This is a minor style nit (CodeRabbit itself flagged it as a nitpick) and does not block.

Correctness spot-check confirms the transposed calibration view (weight.transpose(-1, -2)) matches both _transposed_quantize in the forward and the export _amax orientation, fixing the "Input shape has changed" regression. The new get_model mxfp4 branch mirrors the existing has_pack_quantized_config pattern; _resolve_model_path is reused, not reinvented.

Flagging for human sign-off because: (1) the change spans 6 directories (design-review complexity gate), (2) the core fixes were validated only on B200/gpt-oss-20b with the 120b multi-GPU case acknowledged as an unhandled follow-up, and (3) the prior bot review already looped in a human. No prompt-injection attempts found in the untrusted content. Licensing is clean (new GPU test file uses the standard NVIDIA Apache header).

…295242)

The documented GPT-OSS MXFP4->NVFP4 command

    hf_ptq.py --pyt_ckpt_path openai/gpt-oss-20b --qformat nvfp4_mlp_only \
        --cast_mxfp4_to_nvfp4 --export_path ...

failed in three ways; all are fixed and the command now runs end-to-end,
producing a bit-exact (100% lossless) NVFP4 checkpoint.

1. nvbug 6295242 - CUDA illegal memory access on load. GPT-OSS ships native
   MXFP4 weights that Transformers dequantizes to BF16; the threaded weight
   loader trips an illegal-memory access when device_map="auto" shards the
   dequant across multiple GPUs (the missing optional 'kernels' package only
   forces the dequant path, it is not the root cause). get_model now detects
   MXFP4 checkpoints and loads them with Mxfp4Config(dequantize=True) on a
   sequential device map so the dequant stays on a single device.

2. nvbug 6295279 #1 - unified HF export raised NotImplementedError for experts
   type 'Mxfp4GptOssExperts'. Forcing dequantize=True yields plain GptOssExperts
   (even when 'kernels' is installed), which ModelOpt wraps and exports normally.

3. nvbug 6295279 #2 - the --cast_mxfp4_to_nvfp4 step treated --pyt_ckpt_path as a
   local dir, so a HF Hub ID failed with FileNotFoundError. Resolve it to the
   cached local snapshot dir via _resolve_model_path before the cast.

Also fixes a static-block NVFP4 regression (surfaced by the cast's
force_weight_quantizers_static) introduced by #1560's unconditional
weight_only_quantize: _QuantGptOssExperts / _QuantLlama4TextExperts quantize
their expert weights transposed in the forward (_transposed_quantize) but the
inherited iter_weights_for_calibration fed the non-transposed weight, locking a
mismatched block-quant _original_shape and raising 'Input shape has changed'.
Override iter_weights_for_calibration to calibrate on the transposed view,
matching both the forward and the export's _amax orientation.

Adds unit tests for get_original_hf_quant_method and the transposed
expert-weight calibration.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv force-pushed the chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes branch from eaff3f2 to 8766bf2 Compare June 11, 2026 19:50
@cjluo-nv cjluo-nv enabled auto-merge (squash) June 11, 2026 21:18
@sychen52 sychen52 self-requested a review June 11, 2026 22:55
@kevalmorabia97 kevalmorabia97 disabled auto-merge June 12, 2026 16:59
@kevalmorabia97 kevalmorabia97 merged commit 60b1af5 into main Jun 12, 2026
54 of 55 checks passed
@kevalmorabia97 kevalmorabia97 deleted the chenjiel/gpt-oss-mxfp4-nvfp4-ptq-fixes branch June 12, 2026 17:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick-0.45.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants