@@ -299,19 +299,20 @@ struct ggml_tensor* patchify(struct ggml_context* ctx,
299299 int64_t r = patch_size;
300300 int64_t q = patch_size;
301301
302- int64_t w_in = x->ne [0 ];
303- int64_t h_in = x->ne [1 ];
304- int64_t cb = x->ne [2 ]; // b*c
305- int64_t f = x->ne [3 ];
302+ int64_t W = x->ne [0 ];
303+ int64_t H = x->ne [1 ];
304+ int64_t C = x->ne [2 ];
305+ int64_t f = x->ne [3 ];
306+
307+ int64_t w = W / r;
308+ int64_t h = H / q;
309+
310+ x = ggml_reshape_4d (ctx, x, W, q, h, C * f); // [W, q, h, C*f]
311+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [W, h, q, C*f]
312+ x = ggml_reshape_4d (ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f]
313+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 1 , 2 , 0 , 3 )); // [w, h, r, q*C*f]
314+ x = ggml_reshape_4d (ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w]
306315
307- int64_t w = w_in / r;
308- int64_t h = h_in / q;
309-
310- x = ggml_reshape_4d (ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w]
311- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c, r, h*q, w]
312- x = ggml_reshape_4d (ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w]
313- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c*r, q, h, w]
314- x = ggml_reshape_4d (ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w]
315316 return x;
316317}
317318
@@ -332,12 +333,12 @@ struct ggml_tensor* unpatchify(struct ggml_context* ctx,
332333 int64_t w = x->ne [0 ];
333334
334335
335- x = ggml_reshape_4d (ctx, x, w * h, q * r, c * b, f); // [f, b*c , r*q , h* w]
336- x = ggml_reshape_4d (ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q* h, w ]
337- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 2 , 0 , 1 , 3 )) ; // [f *b*c , q*h, w , r]
338- x = ggml_reshape_4d (ctx, x, r * w, h, q, f * c * b); // [f*b*c , q, h, w*r ]
339- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c, h, q, w*r]
340- x = ggml_reshape_4d (ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r]
336+ x = ggml_reshape_4d (ctx, x, w, h, r, q * c * b * f); // [q*c*b*f , r, h, w]
337+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 2 , 0 , 1 , 3 )); // [r, w, h, q*c*b*f ]
338+ x = ggml_reshape_4d (ctx, x, r * w, h, q, c * b * f) ; // [c *b*f , q, h , r*w ]
339+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [r*w , q, h, c*b*f ]
340+ x = ggml_reshape_4d (ctx, x, r * w, q * h, c * b, f);
341+
341342 return x;
342343}
343344
0 commit comments