2 changes: 2 additions & 0 deletions bitsandbytes/backends/default/ops.py
@@ -357,9 +357,11 @@ def _(

 name2optimizer_id = {
     "momentum": MOMENTUM,
+    "lars": MOMENTUM,
     "rmsprop": RMSPROP,
     "adagrad": ADAGRAD,
     "adam": ADAM,
+    "lamb": ADAM,
     "lion": LION,
     "ademamix": ADEMAMIX,
 }
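The two new entries alias existing kernel IDs rather than defining new ones: LARS is SGD-with-momentum plus a layer-wise trust ratio, and LAMB is Adam plus the same kind of layer-wise scaling, so both can reuse their base optimizer's update kernel (the per-layer scaling is presumably handled outside the kernel, e.g. through the update-norm preconditioning pass that appears below). A minimal sketch of the aliasing; all ID constants are assumed values except MOMENTUM = 0, which the kernel's dispatch check confirms:

# Hypothetical ID constants; only MOMENTUM == 0 is confirmed by the kernel below.
MOMENTUM, RMSPROP, ADAGRAD, ADAM, LION, ADEMAMIX = 0, 1, 2, 3, 4, 5

name2optimizer_id = {
    "momentum": MOMENTUM,
    "lars": MOMENTUM,  # LARS reuses the momentum update kernel
    "adam": ADAM,
    "lamb": ADAM,      # LAMB reuses the Adam update kernel
}

# Both aliases resolve to the same dispatch constant as their base optimizer.
assert name2optimizer_id["lars"] == name2optimizer_id["momentum"]
assert name2optimizer_id["lamb"] == name2optimizer_id["adam"]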
Expand Down
15 changes: 14 additions & 1 deletion bitsandbytes/backends/triton/kernels_optim.py
@@ -24,9 +24,11 @@

 name2optimizer_id = {
     "momentum": MOMENTUM,
+    "lars": MOMENTUM,
     "rmsprop": RMSPROP,
     "adagrad": ADAGRAD,
     "adam": ADAM,
+    "lamb": ADAM,
     "lion": LION,
     "ademamix": ADEMAMIX,
 }
@@ -121,7 +123,8 @@ def _optimizer_precondition_1state_32bit(

     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
-            s1_vals = g_vals
+            # Cast to fp32 to avoid type mismatch: s1_vals is fp32 but g_vals may be fp16.
+            s1_vals = g_vals.to(tl.float32)
         else:
             s1_vals = s1_vals * beta1 + g_vals
         update_norm = s1_vals * s1_vals
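This cast fixes a branch-dependent dtype: s1_vals is loaded from the fp32 optimizer state, but on the first step it was reassigned directly from g_vals, whose dtype follows the gradient (possibly fp16). The two branches then produce values of different types, which in Triton can fail to compile or corrupt the fp32 state. A plain-PyTorch sketch of the same pitfall (names mirror the kernel; this is an illustration, not the kernel itself):

import torch

g_vals = torch.randn(4, dtype=torch.float16)   # gradients may arrive in fp16
s1_vals = torch.zeros(4, dtype=torch.float32)  # optimizer state is kept in fp32
beta1, step = 0.9, 1

if step == 1:
    # Without the explicit cast, the local value silently becomes fp16 here,
    # while the else-branch yields fp32 -- exactly the mismatch the fix removes.
    s1_vals = g_vals.to(torch.float32)
else:
    s1_vals = s1_vals * beta1 + g_vals  # type promotion already gives fp32

assert s1_vals.dtype == torch.float32
update_norm = s1_vals * s1_vals  # safe: both operands are fp32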
@@ -313,6 +316,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
"preprocess": _optimizer_precondition_2state_32bit,
"update": _optimizer_update_2state_32bit_triton_kernel,
},
"lamb": {
"preprocess": _optimizer_precondition_2state_32bit,
"update": _optimizer_update_2state_32bit_triton_kernel,
},
"ademamix": {
"preprocess": _optimizer_precondition_2state_32bit,
"update": _optimizer_update_2state_32bit_triton_kernel,
@@ -321,6 +328,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
"lars": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
},
"rmsprop": {
"preprocess": _optimizer_precondition_1state_32bit,
"update": _optimizer_update_1state_32bit_triton_kernel,
@@ -1065,9 +1076,11 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(

 name2optimizer_fn = {
     "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+    "lars": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
+    "lamb": _optimizer_update_2state_8bit_blockwise_triton_kernel,
     "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
 }
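Net effect of the three tables: the 8-bit blockwise path routes LARS through the one-state kernel and LAMB through the two-state kernel, matching their momentum and Adam bases. A hypothetical dispatcher showing how such a table is typically consumed (the function name, grid argument, and signature are assumptions, not the library's API):

def launch_8bit_blockwise_update(name, grid, *kernel_args, **kernel_kwargs):
    # Hypothetical dispatcher: resolve the optimizer name to its registered
    # Triton kernel and launch it on the given grid.
    try:
        kernel = name2optimizer_fn[name]
    except KeyError:
        raise ValueError(f"no 8-bit blockwise kernel registered for {name!r}")
    kernel[grid](*kernel_args, **kernel_kwargs)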
Expand Down
4 changes: 2 additions & 2 deletions bitsandbytes/backends/triton/ops.py
@@ -18,7 +18,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     torch._check_is_size(blocksize)
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
     with torch_accelerator_module.device(A.device):
-        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
     return out, absmax.float()


@@ -30,7 +30,7 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     with torch_accelerator_module.device(A.device):
         out = kernels_8bit_quant.dequant_8bit_blockwise(
-            A,
+            A.contiguous(),
             absmax,
             code,
             blocksize,
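Both call sites now force the input into contiguous memory before it reaches the Triton kernels. The likely reason: the quantization kernels address elements by flat offset into the underlying storage, and for a strided view (a transpose, a slice) that storage order differs from the tensor's logical element order, so the kernel would read the wrong values. .contiguous() copies such views into dense row-major storage and is a no-op for tensors that are already dense. A small PyTorch illustration of the layout mismatch:

import torch

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
t = x.t()                    # transposed view: same storage, strides (1, 3)
assert not t.is_contiguous()

# A kernel reading by flat offset sees the underlying storage order...
storage_order = x.view(-1)               # tensor([0., 1., 2., 3., 4., 5.])
# ...but the logical element order of t is different:
logical_order = t.contiguous().view(-1)  # tensor([0., 3., 1., 4., 2., 5.])
assert not torch.equal(storage_order, logical_order)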