@@ -679,3 +679,186 @@ func NewNVFloat4Storage(src []float32, shape []int) *NVFloat4Storage {
679679
680680// Ensure NVFloat4Storage implements Storage[float32].
681681var _ Storage [float32 ] = (* NVFloat4Storage )(nil )
682+
683+ // --- NF4 Quantization (QLoRA double quantization) ---
684+
// Block geometry used by the NF4 encoder and decoder.
const (
	// nf4BlockSize is the number of consecutive elements that share one
	// absmax scale.
	nf4BlockSize = 64

	// nf4MetaBlockSize is the number of consecutive block scales that share
	// one meta scale.
	nf4MetaBlockSize = 256
)
689+
// nf4Codebook lists the 16 reconstruction levels of the 4-bit normal-float
// format (Dettmers et al., 2023), intended for normally distributed weights.
// The levels are strictly increasing from -1 to 1 and index 7 is exactly zero;
// nf4FindNearest relies on that ordering.
var nf4Codebook = [16]float32{
	// Indices 0-6: negative levels.
	-1.0,
	-0.6961928009986877,
	-0.5250730514526367,
	-0.39491748809814453,
	-0.28444138169288635,
	-0.18477343022823334,
	-0.09105003625154495,
	// Index 7: exact zero.
	0.0,
	// Indices 8-15: positive levels.
	0.07958029955625534,
	0.16093020141124725,
	0.24611230194568634,
	0.33791524171829224,
	0.44070982933044434,
	0.5626170039176941,
	0.7229568362236023,
	1.0,
}
698+
// NF4Storage keeps float32 data in 4-bit normal-float (NF4) form.
//
// Elements are grouped into blocks of nf4BlockSize. Each element is stored as
// a 4-bit index into nf4Codebook, taken relative to its block's absmax scale.
// Block scales are in turn grouped into meta-blocks of nf4MetaBlockSize, and
// the absmax of each group is recorded in MetaScales.
//
// NOTE(review): Scales is held as plain float32 and Dequantize reads only Data
// and Scales, so MetaScales is bookkeeping here rather than a second
// compression stage.
type NF4Storage struct {
	// Data packs two 4-bit codebook indices per byte: the even-numbered
	// element in the low nibble, the following odd-numbered element in the
	// high nibble.
	Data []byte
	// Scales holds one absmax value per block of nf4BlockSize elements.
	Scales []float32
	// MetaScales holds one absmax value per group of nf4MetaBlockSize
	// block scales.
	MetaScales []float32
	// Shape is the tensor shape; NewNF4Storage stores the slice it is given.
	Shape []int

	// n is the number of elements encoded in Data.
	n int
}
709+
710+ // NewNF4Storage quantizes src into NF4 format.
711+ func NewNF4Storage (src []float32 , shape []int ) * NF4Storage {
712+ s := & NF4Storage {Shape : shape }
713+ _ = s .Quantize (src )
714+ return s
715+ }
716+
717+ // Quantize encodes src into NF4 with double quantization.
718+ func (s * NF4Storage ) Quantize (src []float32 ) error {
719+ if len (src ) == 0 {
720+ s .n = 0
721+ s .Data = nil
722+ s .Scales = nil
723+ s .MetaScales = nil
724+ return nil
725+ }
726+ n := len (src )
727+ s .n = n
728+ nBlocks := (n + nf4BlockSize - 1 ) / nf4BlockSize
729+ nMeta := (nBlocks + nf4MetaBlockSize - 1 ) / nf4MetaBlockSize
730+
731+ scales := make ([]float32 , nBlocks )
732+ for b := 0 ; b < nBlocks ; b ++ {
733+ start := b * nf4BlockSize
734+ end := start + nf4BlockSize
735+ if end > n {
736+ end = n
737+ }
738+ var absMax float32
739+ for _ , v := range src [start :end ] {
740+ if v < 0 {
741+ v = - v
742+ }
743+ if v > absMax {
744+ absMax = v
745+ }
746+ }
747+ scales [b ] = absMax
748+ }
749+
750+ // Double-quantize the scales.
751+ metaScales := make ([]float32 , nMeta )
752+ for m := 0 ; m < nMeta ; m ++ {
753+ start := m * nf4MetaBlockSize
754+ end := start + nf4MetaBlockSize
755+ if end > nBlocks {
756+ end = nBlocks
757+ }
758+ var absMax float32
759+ for _ , v := range scales [start :end ] {
760+ if v > absMax {
761+ absMax = v
762+ }
763+ }
764+ metaScales [m ] = absMax
765+ }
766+
767+ // Encode each element.
768+ packed := make ([]byte , (n + 1 )/ 2 )
769+ for b := 0 ; b < nBlocks ; b ++ {
770+ start := b * nf4BlockSize
771+ end := start + nf4BlockSize
772+ if end > n {
773+ end = n
774+ }
775+ scale := scales [b ]
776+ for i := start ; i < end ; i ++ {
777+ var normalized float32
778+ if scale > 0 {
779+ normalized = src [i ] / scale
780+ }
781+ idx := nf4FindNearest (normalized )
782+ byteIdx := i / 2
783+ if i % 2 == 0 {
784+ packed [byteIdx ] = byte (idx & 0xF )
785+ } else {
786+ packed [byteIdx ] |= byte ((idx & 0xF ) << 4 )
787+ }
788+ }
789+ }
790+
791+ s .Data = packed
792+ s .Scales = scales
793+ s .MetaScales = metaScales
794+ return nil
795+ }
796+
797+ // Dequantize decodes NF4 data back to float32.
798+ func (s * NF4Storage ) Dequantize () []float32 {
799+ if s .n == 0 {
800+ return nil
801+ }
802+ out := make ([]float32 , s .n )
803+ nBlocks := len (s .Scales )
804+ for b := 0 ; b < nBlocks ; b ++ {
805+ start := b * nf4BlockSize
806+ end := start + nf4BlockSize
807+ if end > s .n {
808+ end = s .n
809+ }
810+ scale := s .Scales [b ]
811+ for i := start ; i < end ; i ++ {
812+ byteIdx := i / 2
813+ var nibble byte
814+ if i % 2 == 0 {
815+ nibble = s .Data [byteIdx ] & 0xF
816+ } else {
817+ nibble = (s .Data [byteIdx ] >> 4 ) & 0xF
818+ }
819+ out [i ] = nf4Codebook [nibble ] * scale
820+ }
821+ }
822+ return out
823+ }
824+
825+ // nf4FindNearest returns the codebook index closest to v via binary search on midpoints.
826+ func nf4FindNearest (v float32 ) int {
827+ lo , hi := 0 , 15
828+ for lo < hi {
829+ mid := (lo + hi ) / 2
830+ midpoint := (nf4Codebook [mid ] + nf4Codebook [mid + 1 ]) / 2
831+ if v <= midpoint {
832+ hi = mid
833+ } else {
834+ lo = mid + 1
835+ }
836+ }
837+ return lo
838+ }
839+
840+ // Len returns the number of elements.
841+ func (s * NF4Storage ) Len () int { return s .n }
842+
843+ // NumBlocks returns the number of NF4 blocks.
844+ func (s * NF4Storage ) NumBlocks () int { return len (s .Scales ) }
845+
846+ // ByteSize returns the raw byte size of packed NF4 data plus scales.
847+ func (s * NF4Storage ) ByteSize () int64 {
848+ return int64 (len (s .Data )) + int64 (len (s .Scales )* 4 ) + int64 (len (s .MetaScales )* 4 )
849+ }
850+
851+ // Slice dequantizes and returns a CPU float32 slice.
852+ func (s * NF4Storage ) Slice () []float32 { return s .Dequantize () }
853+
854+ // Set re-quantizes from a new float32 slice.
855+ func (s * NF4Storage ) Set (data []float32 ) {
856+ s .n = len (data )
857+ _ = s .Quantize (data )
858+ }
859+
860+ // DeviceType returns device.CPU.
861+ func (s * NF4Storage ) DeviceType () device.Type { return device .CPU }
862+
863+ // Ensure NF4Storage implements Storage[float32].
864+ var _ Storage [float32 ] = (* NF4Storage )(nil )
0 commit comments