Skip to content

Commit 2430989

Browse files
committed
LoKR: re-implement conv
1 parent 8553862 commit 2430989

1 file changed

Lines changed: 93 additions & 59 deletions

File tree

ggml_extend.hpp

Lines changed: 93 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -1577,7 +1577,7 @@ struct WeightAdapter {
15771577
bool force_prec_f32 = false;
15781578
float scale = 1.f;
15791579
} linear;
1580-
struct conv2d_params_t{
1580+
struct conv2d_params_t {
15811581
int s0 = 1;
15821582
int s1 = 1;
15831583
int p0 = 0;
@@ -2642,7 +2642,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26422642
bool is_conv,
26432643
WeightAdapter::ForwardParams::conv2d_params_t conv_params,
26442644
float scale) {
2645-
26462645
GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL)));
26472646
GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL)));
26482647

@@ -2660,16 +2659,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26602659

26612660
if (!is_conv) {
26622661
int batch = (int)h->ne[1];
2663-
struct ggml_tensor* h_mat = ggml_reshape_2d(ctx, h, vq, uq * batch);
2662+
struct ggml_tensor* h_split = ggml_reshape_2d(ctx, h, vq, uq * batch);
26642663

26652664
if (w2 != NULL) {
2666-
hb = ggml_mul_mat(ctx, w2, h_mat);
2665+
hb = ggml_mul_mat(ctx, w2, h_split);
26672666
} else {
2668-
hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_mat));
2667+
hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_split));
26692668
}
26702669

2671-
struct ggml_tensor* hb_unbundled = ggml_reshape_3d(ctx, hb, vp, uq, batch);
2672-
struct ggml_tensor* hb_t = ggml_cont(ctx,ggml_transpose(ctx, hb_unbundled));
2670+
struct ggml_tensor* hb_cat = ggml_reshape_3d(ctx, hb, vp, uq, batch);
2671+
struct ggml_tensor* hb_t = ggml_cont(ctx, ggml_transpose(ctx, hb_cat));
26732672

26742673
struct ggml_tensor* hc;
26752674
if (w1 != NULL) {
@@ -2683,92 +2682,127 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
26832682
return ggml_scale(ctx, out, scale);
26842683

26852684
} else {
2686-
#if 0
2685+
#if 1
26872686
// very slow implementation for now (can this be optimized?)
26882687
int batch = (int)h->ne[3];
26892688

26902689
// 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
26912690
// This is free (metadata only)
2692-
struct ggml_tensor* h_grouped = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
2691+
// print_ggml_tensor(h, true, "\nh");
2692+
struct ggml_tensor* h_split = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
2693+
// print_ggml_tensor(h_split, true, "h_split");
26932694

26942695
struct ggml_tensor* hb;
26952696
if (w2 != NULL) {
2696-
hb = ggml_conv_2d(ctx, w2, h_grouped, conv_params.s0, conv_params.s1,
2697-
conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
2697+
// no LoRA
2698+
// print_ggml_tensor(w2, true, "w2");
2699+
hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr,
2700+
conv_params.s0,
2701+
conv_params.s1,
2702+
conv_params.p0,
2703+
conv_params.p1,
2704+
conv_params.d0,
2705+
conv_params.d1,
2706+
conv_params.direct,
2707+
conv_params.circular_x,
2708+
conv_params.circular_y,
2709+
conv_params.scale);
2710+
26982711
} else {
2699-
int rank = (int)w2b->ne[1];
2700-
int k = (int)sqrt(w2b->ne[0] / vq);
2701-
struct ggml_tensor* w2b_4d = (ggml_n_dims(w2b) < 4) ? ggml_reshape_4d(ctx, w2b, k, k, vq, rank) : w2b;
2702-
struct ggml_tensor* w2a_4d = (ggml_n_dims(w2a) < 4) ? ggml_reshape_4d(ctx, w2a, 1, 1, rank, vp) : w2a;
2712+
// TODO: do not merge (loCon forward)
2713+
// w2a could be 2d
2714+
w2 = ggml_ext_merge_lora(ctx, w2b, w2a);
2715+
if (ggml_n_dims(w2) < 4) {
2716+
w2 = ggml_reshape_4d(ctx, w2, 1, 1, w2->ne[0], w2->ne[1]);
2717+
}
2718+
if (w2->ne[2] != h_split->ne[2]) {
2719+
int k = sqrt(w2->ne[2] / h_split->ne[2]);
2720+
GGML_ASSERT(k * k * h_split->ne[2] == w2->ne[2]);
2721+
w2 = ggml_reshape_4d(ctx, w2, w2->ne[0] * k, w2->ne[1] * k, w2->ne[2] / (k * k), w2->ne[3]);
2722+
}
2723+
hb = ggml_ext_conv_2d(ctx, h_split, w2, nullptr,
2724+
conv_params.s0,
2725+
conv_params.s1,
2726+
conv_params.p0,
2727+
conv_params.p1,
2728+
conv_params.d0,
2729+
conv_params.d1,
2730+
conv_params.direct,
2731+
conv_params.circular_x,
2732+
conv_params.circular_y,
2733+
conv_params.scale);
27032734

2704-
struct ggml_tensor* ha = ggml_conv_2d(ctx, w2b_4d, h_grouped, conv_params.s0, conv_params.s1,
2705-
conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
2706-
hb = ggml_conv_2d(ctx, w2a_4d, ha, 1, 1, 0, 0, 1, 1);
2735+
2736+
// TODO: figure out why this is not working:
2737+
// struct ggml_tensor* ha = ggml_ext_conv_2d(ctx, h_split, w2a, nullptr,
2738+
// conv_params.s0,
2739+
// conv_params.s1,
2740+
// conv_params.p0,
2741+
// conv_params.p1,
2742+
// conv_params.d0,
2743+
// conv_params.d1);
2744+
// // not supporting lora_mid here
2745+
// hb = ggml_ext_conv_2d(ctx,
2746+
// ha,
2747+
// w2b,
2748+
// nullptr,
2749+
// 1,
2750+
// 1,
2751+
// 0,
2752+
// 0,
2753+
// 1,
2754+
// 1,
2755+
// conv_params.direct,
2756+
// conv_params.circular_x,
2757+
// conv_params.circular_y,
2758+
// conv_params.scale);
27072759
}
27082760

27092761
// Current hb shape: [W_out, H_out, vp, uq * batch]
27102762
int w_out = (int)hb->ne[0];
27112763
int h_out = (int)hb->ne[1];
27122764

2713-
// 2. Prepare for Matrix Multiplication
2714-
// Collapse spatial and 'vp' into one dimension to treat as 'M' in MatMul
2715-
// Shape: [W*H*vp, uq, batch]
2716-
struct ggml_tensor* hb_flat = ggml_reshape_3d(ctx, hb, w_out * h_out * vp, uq, batch);
2717-
// Transpose to [uq, W*H*vp, batch] so that uq is ne[0] (the shared K dimension)
2718-
struct ggml_tensor* hb_t = ggml_transpose(ctx, hb_flat);
2765+
// struct ggml_tensor* hb_cat = ggml_reshape_4d(ctx, hb, w_out , h_out , vp * uq, batch);
2766+
// [W_out, H_out, vp * uq, batch]
2767+
// Now left to compute (W1 kr Id) * hb_cat == (W1 kr W2) * h
27192768

2720-
struct ggml_tensor* hc;
2769+
// merge the uq groups of size vp*w_out*h_out
2770+
struct ggml_tensor* hb_merged = ggml_reshape_2d(ctx, hb, w_out * h_out * vp, uq * batch);
2771+
struct ggml_tensor* hc_t;
2772+
struct ggml_tensor* hb_merged_t = ggml_cont(ctx, ggml_transpose(ctx, hb_merged));
27212773
if (w1 != NULL) {
2722-
struct ggml_tensor* w1_mat = ggml_reshape_2d(ctx, w1, uq, up);
2723-
hc = ggml_mul_mat(ctx, w1_mat, hb_t);
2774+
// Would be great to be able to transpose w1 instead to avoid transposing both hb and hc
2775+
hc_t = ggml_mul_mat(ctx, w1, hb_merged_t);
27242776
} else {
2725-
// Low-rank: (up x rank) * (rank x uq) * (uq x Spatial)
2726-
hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
2777+
hc_t = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_merged_t));
27272778
}
2728-
2729-
// 3. Final Layout Transformation
2730-
// Current hc shape: [up, W*H*vp, batch]
2731-
// Logical dims in ne[1]: [W*H, vp]
2732-
// We want final shape: [W, H, up*vp, batch]
2733-
2734-
// Split ne[1] back into spatial and vp
2735-
struct ggml_tensor* hc_split = ggml_reshape_4d(ctx, hc, up, w_out * h_out, vp, batch);
2736-
2737-
// Permute to bring up and vp together: [spatial, up, vp, batch]
2738-
// This moves spatial to ne[0], which is necessary for the final W,H,C layout
2739-
struct ggml_tensor* hc_perm = ggml_permute(ctx, hc_split, 1, 0, 2, 3);
2740-
2741-
// Resolve layout and scale in one go (if possible) or just cont
2742-
// This is the only mandatory copy
2743-
struct ggml_tensor* out_cont = ggml_cont(ctx, hc_perm);
2744-
2745-
// Final reshape to merge up and vp into the channel dimension
2746-
struct ggml_tensor* out = ggml_reshape_4d(ctx, out_cont, w_out, h_out, up * vp, batch);
2747-
2779+
struct ggml_tensor* hc = ggml_transpose(ctx, hc_t);
2780+
hc = ggml_cont(ctx, hc);
2781+
struct ggml_tensor* out = ggml_reshape_4d(ctx, hc, w_out, h_out, up * vp, batch);
27482782
return ggml_scale(ctx, out, scale);
27492783
#else
27502784
// compute the weight diff and do a single conv
27512785
if (w1 == NULL) {
27522786
w1 = ggml_ext_merge_lora(ctx, w1b, w1a);
27532787
}
2754-
if(ggml_n_dims(w1) < 4){
2788+
if (ggml_n_dims(w1) < 4) {
27552789
w1 = ggml_reshape_4d(ctx, w1, 1, 1, w1->ne[0], w1->ne[1]);
27562790
}
27572791
if (w2 == NULL) {
27582792
w2 = ggml_ext_merge_lora(ctx, w2b, w2a);
27592793
}
2760-
if(ggml_n_dims(w2) < 4){
2794+
if (ggml_n_dims(w2) < 4) {
27612795
w2 = ggml_reshape_4d(ctx, w2, 1, 1, w2->ne[0], w2->ne[1]);
27622796
}
2763-
if(w2->ne[2] * w1->ne[2] != h->ne[2]){
2764-
int k = sqrt(w2->ne[2] * w1->ne[2]/h->ne[2]);
2765-
GGML_ASSERT(k*k * h->ne[2] == w2->ne[2] * w1->ne[2]);
2766-
w2 = ggml_reshape_4d(ctx, w2, w2->ne[0]*k, w2->ne[1]*k, w2->ne[2]/(k*k), w2->ne[3]);
2797+
if (w2->ne[2] * w1->ne[2] != h->ne[2]) {
2798+
int k = sqrt(w2->ne[2] * w1->ne[2] / h->ne[2]);
2799+
GGML_ASSERT(k * k * h->ne[2] == w2->ne[2] * w1->ne[2]);
2800+
w2 = ggml_reshape_4d(ctx, w2, w2->ne[0] * k, w2->ne[1] * k, w2->ne[2] / (k * k), w2->ne[3]);
27672801
}
2768-
w1 = ggml_ext_cast_f32(ctx, w1);
2769-
w2 = ggml_ext_cast_f32(ctx, w2);
2770-
struct ggml_tensor* w = ggml_ext_kronecker(ctx, w1, w2);
2771-
struct ggml_tensor* out = ggml_conv_2d(ctx, w, h, conv_params.s0, conv_params.s1, conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
2802+
w1 = ggml_ext_cast_f32(ctx, w1);
2803+
w2 = ggml_ext_cast_f32(ctx, w2);
2804+
struct ggml_tensor* w = ggml_ext_kronecker(ctx, w1, w2);
2805+
struct ggml_tensor* out = ggml_ext_conv_2d(ctx, h, w, nullptr, conv_params.s0, conv_params.s1, conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1, conv_params.direct, conv_params.circular_x, conv_params.circular_y, conv_params.scale);
27722806
27732807
return ggml_scale(ctx, out, scale);
27742808
#endif

0 commit comments

Comments (0)