Skip to content

Commit e4cbcdc

Browse files
committed
fix patched pixels order
1 parent 068d928 commit e4cbcdc

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

tae.hpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -299,19 +299,20 @@ struct ggml_tensor* patchify(struct ggml_context* ctx,
299299
int64_t r = patch_size;
300300
int64_t q = patch_size;
301301

302-
int64_t w_in = x->ne[0];
303-
int64_t h_in = x->ne[1];
304-
int64_t cb = x->ne[2]; // b*c
305-
int64_t f = x->ne[3];
302+
int64_t W = x->ne[0];
303+
int64_t H = x->ne[1];
304+
int64_t C = x->ne[2];
305+
int64_t f = x->ne[3];
306+
307+
int64_t w = W / r;
308+
int64_t h = H / q;
309+
310+
x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // [W, q, h, C*f]
311+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [W, h, q, C*f]
312+
x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f]
313+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [w, h, r, q*C*f]
314+
x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // [w, h, r*q*C, f]
306315

307-
int64_t w = w_in / r;
308-
int64_t h = h_in / q;
309-
310-
x = ggml_reshape_4d(ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w]
311-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, r, h*q, w]
312-
x = ggml_reshape_4d(ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w]
313-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c*r, q, h, w]
314-
x = ggml_reshape_4d(ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w]
315316
return x;
316317
}
317318

@@ -332,12 +333,12 @@ struct ggml_tensor* unpatchify(struct ggml_context* ctx,
332333
int64_t w = x->ne[0];
333334

334335

335-
x = ggml_reshape_4d(ctx, x, w * h, q * r, c * b, f); // [f, b*c, r*q, h*w]
336-
x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q*h, w]
337-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [f*b*c, q*h, w, r]
338-
x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [f*b*c, q, h, w*r]
339-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, h, q, w*r]
340-
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r]
336+
x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w]
337+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [r, w, h, q*c*b*f]
338+
x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w]
339+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [r*w, q, h, c*b*f]
340+
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f);
341+
341342
return x;
342343
}
343344

0 commit comments

Comments
 (0)