Skip to content

Commit a7a791d

Browse files
committed
taehv: patchify encode
1 parent e4cbcdc commit a7a791d

File tree

1 file changed

+61
-58
lines changed

1 file changed

+61
-58
lines changed

tae.hpp

Lines changed: 61 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,60 @@ class MemBlock : public GGMLBlock {
236236
}
237237
};
238238

239+
// patchify: rearranges x so that each patch_size x patch_size spatial patch
// ends up in the channel axis (per the shape notes on the chain below).
// Reads W, H, C, f from x->ne[0..3]; the last step reshapes to
// ne = [W / patch_size, H / patch_size, patch_size * patch_size * C, f].
// Returns x itself, with no new ops, when patch_size == 1.
// `b` is accepted for signature symmetry with unpatchify but is never read here.
// NOTE(review): W / r and H / q are integer divisions; this presumably expects
// W and H to be multiples of patch_size -- confirm against callers.
struct ggml_tensor* patchify(struct ggml_context* ctx,
240+
struct ggml_tensor* x,
241+
int64_t patch_size,
242+
int64_t b = 1) {
243+
// x: [f, b*c, h*q, w*r]   (slowest axis first, i.e. x->ne read in reverse)
244+
// return: [f, b*c*r*q, h, w]   (same slowest-first notation)
245+
if (patch_size == 1) {
246+
return x;
247+
}
248+
int64_t r = patch_size;  // patch extent along W (ne[0])
249+
int64_t q = patch_size;  // patch extent along H (ne[1])
250+
251+
int64_t W = x->ne[0];
252+
int64_t H = x->ne[1];
253+
int64_t C = x->ne[2];
254+
int64_t f = x->ne[3];
255+
256+
int64_t w = W / r;
257+
int64_t h = H / q;
258+
259+
// All shape notes below are in ne order (ne[0] first).
x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // ne: [W, q, h, C*f] -- split H into (q, h)
260+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // ne: [W, h, q, C*f]
261+
x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // ne: [r, w, h, q*C*f] -- split W into (r, w), fold q into the last axis
262+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // ne: [w, h, r, q*C*f]
263+
x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // ne: [w, h, r*q*C, f]; slowest-first: [f, b*c*r*q, h, w]
264+
265+
return x;
266+
}
267+
268+
// unpatchify: expands the channel axis of x back into patch_size x patch_size
// spatial patches (per the shape notes on the chain below).
// Reads w, h, f from x->ne[0], x->ne[1], x->ne[3] and derives the per-batch
// channel count c = x->ne[2] / b / q / r. The last step reshapes to
// ne = [patch_size * w, patch_size * h, c * b, f].
// Returns x itself, with no new ops, when patch_size == 1.
// NOTE(review): c is computed with integer division; this presumably expects
// x->ne[2] to be a multiple of b * patch_size * patch_size -- confirm against callers.
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
269+
struct ggml_tensor* x,
270+
int64_t patch_size,
271+
int64_t b = 1) {
272+
// x: [f, b*c*r*q, h, w]   (slowest axis first, i.e. x->ne read in reverse)
273+
// return: [f, b*c, h*q, w*r]   (same slowest-first notation)
274+
if (patch_size == 1) {
275+
return x;
276+
}
277+
int64_t r = patch_size;  // patch extent restored along W (ne[0])
278+
int64_t q = patch_size;  // patch extent restored along H (ne[1])
279+
int64_t c = x->ne[2] / b / q / r;
280+
int64_t f = x->ne[3];
281+
int64_t h = x->ne[1];
282+
int64_t w = x->ne[0];
283+
284+
// All shape notes below are in ne order (ne[0] first).
x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // ne: [w, h, r, q*c*b*f] -- peel r off the channel axis
285+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // ne: [r, w, h, q*c*b*f]
286+
x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // ne: [r*w, h, q, c*b*f] -- merge (r, w), peel q off the last axis
287+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // ne: [r*w, q, h, c*b*f]
288+
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // ne: [r*w, q*h, c*b, f]; slowest-first: [f, b*c, h*q, w*r]
289+
290+
return x;
291+
}
292+
239293
class TinyVideoEncoder : public UnaryBlock {
240294
int in_channels = 3;
241295
int hidden = 64;
@@ -263,8 +317,13 @@ class TinyVideoEncoder : public UnaryBlock {
263317

264318
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
265319
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
266-
auto h = first_conv->forward(ctx, z);
267-
h = ggml_relu_inplace(ctx->ggml_ctx, h);
320+
321+
if (patch_size > 1) {
322+
z = patchify(ctx->ggml_ctx, z, patch_size, 1);
323+
}
324+
325+
auto h = first_conv->forward(ctx, z);
326+
h = ggml_relu_inplace(ctx->ggml_ctx, h);
268327

269328
int index = 2;
270329
for (int i = 0; i < num_layers; i++) {
@@ -286,62 +345,6 @@ class TinyVideoEncoder : public UnaryBlock {
286345
}
287346
};
288347

289-
290-
// patchify: rearranges x so that each patch_size x patch_size spatial patch
// ends up in the channel axis (per the shape notes on the chain below).
// Reads W, H, C, f from x->ne[0..3]; the last step reshapes to
// ne = [W / patch_size, H / patch_size, patch_size * patch_size * C, f].
// Returns x itself, with no new ops, when patch_size == 1.
// `b` is accepted for signature symmetry with unpatchify but is never read here.
// NOTE(review): W / r and H / q are integer divisions; this presumably expects
// W and H to be multiples of patch_size -- confirm against callers.
struct ggml_tensor* patchify(struct ggml_context* ctx,
291-
struct ggml_tensor* x,
292-
int64_t patch_size,
293-
int64_t b = 1) {
294-
// x: [f, b*c, h*q, w*r]   (slowest axis first, i.e. x->ne read in reverse)
295-
// return: [f, b*c*r*q, h, w]   (same slowest-first notation)
296-
if (patch_size == 1) {
297-
return x;
298-
}
299-
int64_t r = patch_size;  // patch extent along W (ne[0])
300-
int64_t q = patch_size;  // patch extent along H (ne[1])
301-
302-
int64_t W = x->ne[0];
303-
int64_t H = x->ne[1];
304-
int64_t C = x->ne[2];
305-
int64_t f = x->ne[3];
306-
307-
int64_t w = W / r;
308-
int64_t h = H / q;
309-
310-
// All shape notes below are in ne order (ne[0] first).
x = ggml_reshape_4d(ctx, x, W, q, h, C * f); // ne: [W, q, h, C*f] -- split H into (q, h)
311-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // ne: [W, h, q, C*f]
312-
x = ggml_reshape_4d(ctx, x, r, w, h, q * C * f); // ne: [r, w, h, q*C*f] -- split W into (r, w), fold q into the last axis
313-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // ne: [w, h, r, q*C*f]
314-
x = ggml_reshape_4d(ctx, x, w, h, r * q * C, f); // ne: [w, h, r*q*C, f]; slowest-first: [f, b*c*r*q, h, w]
315-
316-
return x;
317-
}
318-
319-
// unpatchify: expands the channel axis of x back into patch_size x patch_size
// spatial patches (per the shape notes on the chain below).
// Reads w, h, f from x->ne[0], x->ne[1], x->ne[3] and derives the per-batch
// channel count c = x->ne[2] / b / q / r. The last step reshapes to
// ne = [patch_size * w, patch_size * h, c * b, f].
// Returns x itself, with no new ops, when patch_size == 1.
// NOTE(review): c is computed with integer division; this presumably expects
// x->ne[2] to be a multiple of b * patch_size * patch_size -- confirm against callers.
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
320-
struct ggml_tensor* x,
321-
int64_t patch_size,
322-
int64_t b = 1) {
323-
// x: [f, b*c*r*q, h, w]   (slowest axis first, i.e. x->ne read in reverse)
324-
// return: [f, b*c, h*q, w*r]   (same slowest-first notation)
325-
if (patch_size == 1) {
326-
return x;
327-
}
328-
int64_t r = patch_size;  // patch extent restored along W (ne[0])
329-
int64_t q = patch_size;  // patch extent restored along H (ne[1])
330-
int64_t c = x->ne[2] / b / q / r;
331-
int64_t f = x->ne[3];
332-
int64_t h = x->ne[1];
333-
int64_t w = x->ne[0];
334-
335-
336-
// All shape notes below are in ne order (ne[0] first).
x = ggml_reshape_4d(ctx, x, w, h, r, q * c * b * f); // ne: [w, h, r, q*c*b*f] -- peel r off the channel axis
337-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // ne: [r, w, h, q*c*b*f]
338-
x = ggml_reshape_4d(ctx, x, r * w, h, q, c * b * f); // ne: [r*w, h, q, c*b*f] -- merge (r, w), peel q off the last axis
339-
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // ne: [r*w, q, h, c*b*f]
340-
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // ne: [r*w, q*h, c*b, f]; slowest-first: [f, b*c, h*q, w*r]
341-
342-
return x;
343-
}
344-
345348
class TinyVideoDecoder : public UnaryBlock {
346349
int z_channels = 4;
347350
int out_channels = 3;

0 commit comments

Comments
 (0)