Skip to content

Commit 1e4beaa

Browse files
committed
fix(tensor): add missing NF4Storage implementation (T9.3 agent omitted impl)
1 parent 9710e8d commit 1e4beaa

1 file changed

Lines changed: 183 additions & 0 deletions

File tree

tensor/quantized.go

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,3 +679,186 @@ func NewNVFloat4Storage(src []float32, shape []int) *NVFloat4Storage {
679679

680680
// Ensure NVFloat4Storage implements Storage[float32].
681681
var _ Storage[float32] = (*NVFloat4Storage)(nil)
682+
683+
// --- NF4 Quantization (QLoRA double quantization) ---
684+
685+
const (
686+
nf4BlockSize = 64
687+
nf4MetaBlockSize = 256
688+
)
689+
690+
// nf4Codebook is the fixed 16-level normal float 4-bit quantization codebook.
691+
// Values are sorted and optimized for normally distributed weights (Dettmers et al., 2023).
692+
var nf4Codebook = [16]float32{
693+
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
694+
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
695+
0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
696+
0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
697+
}
698+
699+
// NF4Storage stores float32 data quantized to 4-bit normal floats with
700+
// double quantization: block scales (one per nf4BlockSize elements) are
701+
// themselves stored quantized using meta-blocks of nf4MetaBlockSize blocks.
702+
type NF4Storage struct {
703+
Data []byte // packed: 2 NF4 indices per byte (lo nibble = first, hi nibble = second)
704+
Scales []float32 // one absmax scale per block of nf4BlockSize elements
705+
MetaScales []float32 // one absmax scale per meta-block of nf4MetaBlockSize blocks
706+
Shape []int
707+
n int // number of elements
708+
}
709+
710+
// NewNF4Storage quantizes src into NF4 format.
711+
func NewNF4Storage(src []float32, shape []int) *NF4Storage {
712+
s := &NF4Storage{Shape: shape}
713+
_ = s.Quantize(src)
714+
return s
715+
}
716+
717+
// Quantize encodes src into NF4 with double quantization.
718+
func (s *NF4Storage) Quantize(src []float32) error {
719+
if len(src) == 0 {
720+
s.n = 0
721+
s.Data = nil
722+
s.Scales = nil
723+
s.MetaScales = nil
724+
return nil
725+
}
726+
n := len(src)
727+
s.n = n
728+
nBlocks := (n + nf4BlockSize - 1) / nf4BlockSize
729+
nMeta := (nBlocks + nf4MetaBlockSize - 1) / nf4MetaBlockSize
730+
731+
scales := make([]float32, nBlocks)
732+
for b := 0; b < nBlocks; b++ {
733+
start := b * nf4BlockSize
734+
end := start + nf4BlockSize
735+
if end > n {
736+
end = n
737+
}
738+
var absMax float32
739+
for _, v := range src[start:end] {
740+
if v < 0 {
741+
v = -v
742+
}
743+
if v > absMax {
744+
absMax = v
745+
}
746+
}
747+
scales[b] = absMax
748+
}
749+
750+
// Double-quantize the scales.
751+
metaScales := make([]float32, nMeta)
752+
for m := 0; m < nMeta; m++ {
753+
start := m * nf4MetaBlockSize
754+
end := start + nf4MetaBlockSize
755+
if end > nBlocks {
756+
end = nBlocks
757+
}
758+
var absMax float32
759+
for _, v := range scales[start:end] {
760+
if v > absMax {
761+
absMax = v
762+
}
763+
}
764+
metaScales[m] = absMax
765+
}
766+
767+
// Encode each element.
768+
packed := make([]byte, (n+1)/2)
769+
for b := 0; b < nBlocks; b++ {
770+
start := b * nf4BlockSize
771+
end := start + nf4BlockSize
772+
if end > n {
773+
end = n
774+
}
775+
scale := scales[b]
776+
for i := start; i < end; i++ {
777+
var normalized float32
778+
if scale > 0 {
779+
normalized = src[i] / scale
780+
}
781+
idx := nf4FindNearest(normalized)
782+
byteIdx := i / 2
783+
if i%2 == 0 {
784+
packed[byteIdx] = byte(idx & 0xF)
785+
} else {
786+
packed[byteIdx] |= byte((idx & 0xF) << 4)
787+
}
788+
}
789+
}
790+
791+
s.Data = packed
792+
s.Scales = scales
793+
s.MetaScales = metaScales
794+
return nil
795+
}
796+
797+
// Dequantize decodes NF4 data back to float32.
798+
func (s *NF4Storage) Dequantize() []float32 {
799+
if s.n == 0 {
800+
return nil
801+
}
802+
out := make([]float32, s.n)
803+
nBlocks := len(s.Scales)
804+
for b := 0; b < nBlocks; b++ {
805+
start := b * nf4BlockSize
806+
end := start + nf4BlockSize
807+
if end > s.n {
808+
end = s.n
809+
}
810+
scale := s.Scales[b]
811+
for i := start; i < end; i++ {
812+
byteIdx := i / 2
813+
var nibble byte
814+
if i%2 == 0 {
815+
nibble = s.Data[byteIdx] & 0xF
816+
} else {
817+
nibble = (s.Data[byteIdx] >> 4) & 0xF
818+
}
819+
out[i] = nf4Codebook[nibble] * scale
820+
}
821+
}
822+
return out
823+
}
824+
825+
// nf4FindNearest returns the codebook index closest to v via binary search on midpoints.
826+
func nf4FindNearest(v float32) int {
827+
lo, hi := 0, 15
828+
for lo < hi {
829+
mid := (lo + hi) / 2
830+
midpoint := (nf4Codebook[mid] + nf4Codebook[mid+1]) / 2
831+
if v <= midpoint {
832+
hi = mid
833+
} else {
834+
lo = mid + 1
835+
}
836+
}
837+
return lo
838+
}
839+
840+
// Len returns the number of elements.
841+
func (s *NF4Storage) Len() int { return s.n }
842+
843+
// NumBlocks returns the number of NF4 blocks.
844+
func (s *NF4Storage) NumBlocks() int { return len(s.Scales) }
845+
846+
// ByteSize returns the raw byte size of packed NF4 data plus scales.
847+
func (s *NF4Storage) ByteSize() int64 {
848+
return int64(len(s.Data)) + int64(len(s.Scales)*4) + int64(len(s.MetaScales)*4)
849+
}
850+
851+
// Slice dequantizes and returns a CPU float32 slice.
852+
func (s *NF4Storage) Slice() []float32 { return s.Dequantize() }
853+
854+
// Set re-quantizes from a new float32 slice.
855+
func (s *NF4Storage) Set(data []float32) {
856+
s.n = len(data)
857+
_ = s.Quantize(data)
858+
}
859+
860+
// DeviceType returns device.CPU.
861+
func (s *NF4Storage) DeviceType() device.Type { return device.CPU }
862+
863+
// Ensure NF4Storage implements Storage[float32].
864+
var _ Storage[float32] = (*NF4Storage)(nil)

0 commit comments

Comments
 (0)