Fix sharding of quantized models with non-power-of-2 bits #3006
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Proposed changes
Sharding 6-bit (and other non-power-of-2) quantized models with certain input dimensions (like 1536) may fail because
input_dims *= 32 // bitstruncates incorrectly. See: ml-explore/mlx-lm#771 (comment)For 6-bit with packed dimension 288:
This caused
shard_linearto fail with:ValueError: [quantize] ... matrix has shape (6144,1440)Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes