diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 78ba818ca..74eeb5502 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -357,9 +357,11 @@ def _(
 
 name2optimizer_id = {
     "momentum": MOMENTUM,
+    "lars": MOMENTUM,
     "rmsprop": RMSPROP,
     "adagrad": ADAGRAD,
     "adam": ADAM,
+    "lamb": ADAM,
     "lion": LION,
     "ademamix": ADEMAMIX,
 }
diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py
index 2cd6d8c93..d7abc4af1 100644
--- a/bitsandbytes/backends/triton/kernels_optim.py
+++ b/bitsandbytes/backends/triton/kernels_optim.py
@@ -24,9 +24,11 @@
 
 name2optimizer_id = {
     "momentum": MOMENTUM,
+    "lars": MOMENTUM,
     "rmsprop": RMSPROP,
     "adagrad": ADAGRAD,
     "adam": ADAM,
+    "lamb": ADAM,
     "lion": LION,
     "ademamix": ADEMAMIX,
 }
@@ -121,7 +123,8 @@ def _optimizer_precondition_1state_32bit(
 
     if OPTIMIZER_ID == 0:  # MOMENTUM
         if step == 1:
-            s1_vals = g_vals
+            # Cast to fp32 to avoid type mismatch: s1_vals is fp32 but g_vals may be fp16.
+            s1_vals = g_vals.to(tl.float32)
         else:
             s1_vals = s1_vals * beta1 + g_vals
         update_norm = s1_vals * s1_vals
@@ -313,6 +316,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
         "preprocess": _optimizer_precondition_2state_32bit,
         "update": _optimizer_update_2state_32bit_triton_kernel,
     },
+    "lamb": {
+        "preprocess": _optimizer_precondition_2state_32bit,
+        "update": _optimizer_update_2state_32bit_triton_kernel,
+    },
     "ademamix": {
         "preprocess": _optimizer_precondition_2state_32bit,
         "update": _optimizer_update_2state_32bit_triton_kernel,
@@ -321,6 +328,10 @@ def _optimizer_update_1state_32bit_triton_kernel(
         "preprocess": _optimizer_precondition_1state_32bit,
         "update": _optimizer_update_1state_32bit_triton_kernel,
     },
+    "lars": {
+        "preprocess": _optimizer_precondition_1state_32bit,
+        "update": _optimizer_update_1state_32bit_triton_kernel,
+    },
     "rmsprop": {
         "preprocess": _optimizer_precondition_1state_32bit,
         "update": _optimizer_update_1state_32bit_triton_kernel,
@@ -1065,9 +1076,11 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
 
 name2optimizer_fn = {
     "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+    "lars": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
+    "lamb": _optimizer_update_2state_8bit_blockwise_triton_kernel,
     "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
     "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
 }
diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index b4c980078..686d42929 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -18,7 +18,7 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> t
     torch._check_is_size(blocksize)
     # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
     with torch_accelerator_module.device(A.device):
-        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+        out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A.contiguous(), code, blocksize)
 
     return out, absmax.float()
 
@@ -30,7 +30,7 @@ def dequantize_blockwise(
     # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
     with torch_accelerator_module.device(A.device):
         out = kernels_8bit_quant.dequant_8bit_blockwise(
-            A,
+            A.contiguous(),
             absmax,
             code,
             blocksize,
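
Note on the fp32 cast in _optimizer_precondition_1state_32bit: per the added
comment, s1_vals is computed in float32 while g_vals is loaded in the
gradient's storage dtype, which may be fp16, so the first-step direct
assignment needs an explicit cast. A minimal plain-PyTorch sketch of the same
pattern (illustration only, not the Triton kernel itself):

    import torch

    g_vals = torch.randn(8, dtype=torch.float16)   # half-precision gradient block
    s1_vals = torch.zeros(8, dtype=torch.float32)  # momentum state kept in fp32
    step, beta1 = 1, 0.9

    if step == 1:
        s1_vals = g_vals.to(torch.float32)  # explicit cast, mirroring g_vals.to(tl.float32)
    else:
        s1_vals = s1_vals * beta1 + g_vals  # later steps promote to fp32 automatically

    update_norm = s1_vals * s1_vals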
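
The .contiguous() calls in triton/ops.py matter because the Triton
quantization kernels index their input as a flat buffer, so a non-contiguous
view (e.g. a transpose) would be read in memory order rather than logical
order. A round-trip sketch of the case this covers, assuming the public
bitsandbytes.functional entry points and a Triton-backed device (the "xpu"
device string is an assumption; adjust to your setup):

    import torch
    import bitsandbytes.functional as F

    device = "xpu"  # assumption: any accelerator served by the Triton backend

    A = torch.randn(64, 128, device=device).t()  # transposed view, not contiguous
    assert not A.is_contiguous()

    # With the patch, the backend flattens A via .contiguous() before the
    # kernel launch, so the quantize/dequantize round-trip stays consistent.
    quantized, state = F.quantize_blockwise(A, blocksize=256)
    restored = F.dequantize_blockwise(quantized, state)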
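
Finally, the new "lars"/"lamb" aliases reuse existing kernels: LARS shares the
single-state MOMENTUM kernel and LAMB the two-state ADAM kernel. A hedged
smoke test, assuming the standard bitsandbytes.optim exports and a device
served by these backends:

    import torch
    import bitsandbytes as bnb

    p = torch.nn.Parameter(torch.randn(4096, device="xpu"))

    opt = bnb.optim.LAMB([p], lr=1e-3)  # 32-bit state; dispatches via the ADAM entry
    # opt = bnb.optim.LARS([p], lr=1e-2, momentum=0.9)  # dispatches via MOMENTUM

    p.grad = torch.randn_like(p)
    opt.step()  # would previously fail the name2optimizer_id lookup on these backends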