@@ -1577,7 +1577,7 @@ struct WeightAdapter {
15771577 bool force_prec_f32 = false ;
15781578 float scale = 1 .f;
15791579 } linear;
1580- struct conv2d_params_t {
1580+ struct conv2d_params_t {
15811581 int s0 = 1 ;
15821582 int s1 = 1 ;
15831583 int p0 = 0 ;
@@ -2642,7 +2642,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26422642 bool is_conv,
26432643 WeightAdapter::ForwardParams::conv2d_params_t conv_params,
26442644 float scale) {
2645-
26462645 GGML_ASSERT ((w1 != NULL || (w1a != NULL && w1b != NULL )));
26472646 GGML_ASSERT ((w2 != NULL || (w2a != NULL && w2b != NULL )));
26482647
@@ -2660,16 +2659,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26602659
26612660 if (!is_conv) {
26622661 int batch = (int )h->ne [1 ];
2663- struct ggml_tensor * h_mat = ggml_reshape_2d (ctx, h, vq, uq * batch);
2662+ struct ggml_tensor * h_split = ggml_reshape_2d (ctx, h, vq, uq * batch);
26642663
26652664 if (w2 != NULL ) {
2666- hb = ggml_mul_mat (ctx, w2, h_mat );
2665+ hb = ggml_mul_mat (ctx, w2, h_split );
26672666 } else {
2668- hb = ggml_mul_mat (ctx, w2b, ggml_mul_mat (ctx, w2a, h_mat ));
2667+ hb = ggml_mul_mat (ctx, w2b, ggml_mul_mat (ctx, w2a, h_split ));
26692668 }
26702669
2671- struct ggml_tensor * hb_unbundled = ggml_reshape_3d (ctx, hb, vp, uq, batch);
2672- struct ggml_tensor * hb_t = ggml_cont (ctx,ggml_transpose (ctx, hb_unbundled ));
2670+ struct ggml_tensor * hb_cat = ggml_reshape_3d (ctx, hb, vp, uq, batch);
2671+ struct ggml_tensor * hb_t = ggml_cont (ctx, ggml_transpose (ctx, hb_cat ));
26732672
26742673 struct ggml_tensor * hc;
26752674 if (w1 != NULL ) {
@@ -2683,92 +2682,127 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26832682 return ggml_scale (ctx, out, scale);
26842683
26852684 } else {
2686- #if 0
2685+ #if 1
26872686 // very slow implementation for now (can this be optimized?)
26882687 int batch = (int )h->ne [3 ];
26892688
26902689 // 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
26912690 // This is free (metadata only)
2692- struct ggml_tensor* h_grouped = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
2691+ // print_ggml_tensor(h, true, "\nh");
2692+ struct ggml_tensor * h_split = ggml_reshape_4d (ctx, h, h->ne [0 ], h->ne [1 ], vq, uq * batch);
2693+ // print_ggml_tensor(h_split, true, "h_split");
26932694
26942695 struct ggml_tensor * hb;
26952696 if (w2 != NULL ) {
2696- hb = ggml_conv_2d(ctx, w2, h_grouped, conv_params.s0, conv_params.s1,
2697- conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
2697+ // no LoRA
2698+ // print_ggml_tensor(w2, true, "w2");
2699+ hb = ggml_ext_conv_2d (ctx, h_split, w2, nullptr ,
2700+ conv_params.s0 ,
2701+ conv_params.s1 ,
2702+ conv_params.p0 ,
2703+ conv_params.p1 ,
2704+ conv_params.d0 ,
2705+ conv_params.d1 ,
2706+ conv_params.direct ,
2707+ conv_params.circular_x ,
2708+ conv_params.circular_y ,
2709+ conv_params.scale );
2710+
26982711 } else {
2699- int rank = (int)w2b->ne[1];
2700- int k = (int)sqrt(w2b->ne[0] / vq);
2701- struct ggml_tensor* w2b_4d = (ggml_n_dims(w2b) < 4) ? ggml_reshape_4d(ctx, w2b, k, k, vq, rank) : w2b;
2702- struct ggml_tensor* w2a_4d = (ggml_n_dims(w2a) < 4) ? ggml_reshape_4d(ctx, w2a, 1, 1, rank, vp) : w2a;
2712+ // TODO: do not merge (loCon forward)
2713+ // w2a could be 2d
2714+ w2 = ggml_ext_merge_lora (ctx, w2b, w2a);
2715+ if (ggml_n_dims (w2) < 4 ) {
2716+ w2 = ggml_reshape_4d (ctx, w2, 1 , 1 , w2->ne [0 ], w2->ne [1 ]);
2717+ }
2718+ if (w2->ne [2 ] != h_split->ne [2 ]) {
2719+ int k = sqrt (w2->ne [2 ] / h_split->ne [2 ]);
2720+ GGML_ASSERT (k * k * h_split->ne [2 ] == w2->ne [2 ]);
2721+ w2 = ggml_reshape_4d (ctx, w2, w2->ne [0 ] * k, w2->ne [1 ] * k, w2->ne [2 ] / (k * k), w2->ne [3 ]);
2722+ }
2723+ hb = ggml_ext_conv_2d (ctx, h_split, w2, nullptr ,
2724+ conv_params.s0 ,
2725+ conv_params.s1 ,
2726+ conv_params.p0 ,
2727+ conv_params.p1 ,
2728+ conv_params.d0 ,
2729+ conv_params.d1 ,
2730+ conv_params.direct ,
2731+ conv_params.circular_x ,
2732+ conv_params.circular_y ,
2733+ conv_params.scale );
27032734
2704- struct ggml_tensor* ha = ggml_conv_2d(ctx, w2b_4d, h_grouped, conv_params.s0, conv_params.s1,
2705- conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
2706- hb = ggml_conv_2d(ctx, w2a_4d, ha, 1, 1, 0, 0, 1, 1);
2735+
2736+ // TODO: figure out why this is not working:
2737+ // struct ggml_tensor* ha = ggml_ext_conv_2d(ctx, h_split, w2a, nullptr,
2738+ // conv_params.s0,
2739+ // conv_params.s1,
2740+ // conv_params.p0,
2741+ // conv_params.p1,
2742+ // conv_params.d0,
2743+ // conv_params.d1);
2744+ // // not supporting lora_mid here
2745+ // hb = ggml_ext_conv_2d(ctx,
2746+ // ha,
2747+ // w2b,
2748+ // nullptr,
2749+ // 1,
2750+ // 1,
2751+ // 0,
2752+ // 0,
2753+ // 1,
2754+ // 1,
2755+ // conv_params.direct,
2756+ // conv_params.circular_x,
2757+ // conv_params.circular_y,
2758+ // conv_params.scale);
27072759 }
27082760
27092761 // Current hb shape: [W_out, H_out, vp, uq * batch]
27102762 int w_out = (int )hb->ne [0 ];
27112763 int h_out = (int )hb->ne [1 ];
27122764
2713- // 2. Prepare for Matrix Multiplication
2714- // Collapse spatial and 'vp' into one dimension to treat as 'M' in MatMul
2715- // Shape: [W*H*vp, uq, batch]
2716- struct ggml_tensor* hb_flat = ggml_reshape_3d(ctx, hb, w_out * h_out * vp, uq, batch);
2717- // Transpose to [uq, W*H*vp, batch] so that uq is ne[0] (the shared K dimension)
2718- struct ggml_tensor* hb_t = ggml_transpose(ctx, hb_flat);
2765+ // struct ggml_tensor* hb_cat = ggml_reshape_4d(ctx, hb, w_out , h_out , vp * uq, batch);
2766+ // [W_out, H_out, vp * uq, batch]
2767+ // Now left to compute (W1 kr Id) * hb_cat == (W1 kr W2) * h
27192768
2720- struct ggml_tensor* hc;
2769+ // merge the uq groups of size vp*w_out*h_out
2770+ struct ggml_tensor * hb_merged = ggml_reshape_2d (ctx, hb, w_out * h_out * vp, uq * batch);
2771+ struct ggml_tensor * hc_t ;
2772+ struct ggml_tensor * hb_merged_t = ggml_cont (ctx, ggml_transpose (ctx, hb_merged));
27212773 if (w1 != NULL ) {
2722- struct ggml_tensor* w1_mat = ggml_reshape_2d(ctx, w1, uq, up);
2723- hc = ggml_mul_mat(ctx, w1_mat, hb_t);
2774+ // Would be great to be able to transpose w1 instead to avoid transposing both hb and hc
2775+ hc_t = ggml_mul_mat (ctx, w1, hb_merged_t );
27242776 } else {
2725- // Low-rank: (up x rank) * (rank x uq) * (uq x Spatial)
2726- hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
2777+ hc_t = ggml_mul_mat (ctx, w1b, ggml_mul_mat (ctx, w1a, hb_merged_t ));
27272778 }
2728-
2729- // 3. Final Layout Transformation
2730- // Current hc shape: [up, W*H*vp, batch]
2731- // Logical dims in ne[1]: [W*H, vp]
2732- // We want final shape: [W, H, up*vp, batch]
2733-
2734- // Split ne[1] back into spatial and vp
2735- struct ggml_tensor* hc_split = ggml_reshape_4d(ctx, hc, up, w_out * h_out, vp, batch);
2736-
2737- // Permute to bring up and vp together: [spatial, up, vp, batch]
2738- // This moves spatial to ne[0], which is necessary for the final W,H,C layout
2739- struct ggml_tensor* hc_perm = ggml_permute(ctx, hc_split, 1, 0, 2, 3);
2740-
2741- // Resolve layout and scale in one go (if possible) or just cont
2742- // This is the only mandatory copy
2743- struct ggml_tensor* out_cont = ggml_cont(ctx, hc_perm);
2744-
2745- // Final reshape to merge up and vp into the channel dimension
2746- struct ggml_tensor* out = ggml_reshape_4d(ctx, out_cont, w_out, h_out, up * vp, batch);
2747-
2779+ struct ggml_tensor * hc = ggml_transpose (ctx, hc_t );
2780+ hc = ggml_cont (ctx, hc);
2781+ struct ggml_tensor * out = ggml_reshape_4d (ctx, hc, w_out, h_out, up * vp, batch);
27482782 return ggml_scale (ctx, out, scale);
27492783#else
27502784 // compute the weight diff and do a single conv
27512785 if (w1 == NULL) {
27522786 w1 = ggml_ext_merge_lora(ctx, w1b, w1a);
27532787 }
2754- if (ggml_n_dims (w1) < 4 ){
2788+ if (ggml_n_dims(w1) < 4) {
27552789 w1 = ggml_reshape_4d(ctx, w1, 1, 1, w1->ne[0], w1->ne[1]);
27562790 }
27572791 if (w2 == NULL) {
27582792 w2 = ggml_ext_merge_lora(ctx, w2b, w2a);
27592793 }
2760- if (ggml_n_dims (w2) < 4 ){
2794+ if (ggml_n_dims(w2) < 4) {
27612795 w2 = ggml_reshape_4d(ctx, w2, 1, 1, w2->ne[0], w2->ne[1]);
27622796 }
2763- if (w2->ne [2 ] * w1->ne [2 ] != h->ne [2 ]){
2764- int k = sqrt (w2->ne [2 ] * w1->ne [2 ]/ h->ne [2 ]);
2765- GGML_ASSERT (k* k * h->ne [2 ] == w2->ne [2 ] * w1->ne [2 ]);
2766- w2 = ggml_reshape_4d (ctx, w2, w2->ne [0 ]* k, w2->ne [1 ]* k, w2->ne [2 ]/(k* k), w2->ne [3 ]);
2797+ if (w2->ne[2] * w1->ne[2] != h->ne[2]) {
2798+ int k = sqrt(w2->ne[2] * w1->ne[2] / h->ne[2]);
2799+ GGML_ASSERT(k * k * h->ne[2] == w2->ne[2] * w1->ne[2]);
2800+ w2 = ggml_reshape_4d(ctx, w2, w2->ne[0] * k, w2->ne[1] * k, w2->ne[2] / (k * k), w2->ne[3]);
27672801 }
2768- w1 = ggml_ext_cast_f32 (ctx, w1);
2769- w2 = ggml_ext_cast_f32 (ctx, w2);
2770- struct ggml_tensor * w = ggml_ext_kronecker (ctx, w1, w2);
2771- struct ggml_tensor * out = ggml_conv_2d (ctx, w, h , conv_params.s0 , conv_params.s1 , conv_params.p0 , conv_params.p1 , conv_params.d0 , conv_params.d1 );
2802+ w1 = ggml_ext_cast_f32(ctx, w1);
2803+ w2 = ggml_ext_cast_f32(ctx, w2);
2804+ struct ggml_tensor* w = ggml_ext_kronecker(ctx, w1, w2);
2805+ struct ggml_tensor* out = ggml_ext_conv_2d (ctx, h, w, nullptr , conv_params.s0, conv_params.s1, conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1, conv_params.direct, conv_params.circular_x, conv_params.circular_y, conv_params.scale );
27722806
27732807 return ggml_scale(ctx, out, scale);
27742808 #endif
0 commit comments