@@ -236,6 +236,60 @@ class MemBlock : public GGMLBlock {
236236 }
237237};
238238
239+ struct ggml_tensor * patchify (struct ggml_context * ctx,
240+ struct ggml_tensor * x,
241+ int64_t patch_size,
242+ int64_t b = 1 ) {
243+ // x: [f, b*c, h*q, w*r]
244+ // return: [f, b*c*r*q, h, w]
245+ if (patch_size == 1 ) {
246+ return x;
247+ }
248+ int64_t r = patch_size;
249+ int64_t q = patch_size;
250+
251+ int64_t W = x->ne [0 ];
252+ int64_t H = x->ne [1 ];
253+ int64_t C = x->ne [2 ];
254+ int64_t f = x->ne [3 ];
255+
256+ int64_t w = W / r;
257+ int64_t h = H / q;
258+
259+ x = ggml_reshape_4d (ctx, x, W, q, h, C * f); // [W, q, h, C*f]
260+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [W, h, q, C*f]
261+ x = ggml_reshape_4d (ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f]
262+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 1 , 2 , 0 , 3 )); // [w, h, r, q*C*f]
263+ x = ggml_reshape_4d (ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w]
264+
265+ return x;
266+ }
267+
268+ struct ggml_tensor * unpatchify (struct ggml_context * ctx,
269+ struct ggml_tensor * x,
270+ int64_t patch_size,
271+ int64_t b = 1 ) {
272+ // x: [f, b*c*r*q, h, w]
273+ // return: [f, b*c, h*q, w*r]
274+ if (patch_size == 1 ) {
275+ return x;
276+ }
277+ int64_t r = patch_size;
278+ int64_t q = patch_size;
279+ int64_t c = x->ne [2 ] / b / q / r;
280+ int64_t f = x->ne [3 ];
281+ int64_t h = x->ne [1 ];
282+ int64_t w = x->ne [0 ];
283+
284+ x = ggml_reshape_4d (ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w]
285+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 2 , 0 , 1 , 3 )); // [r, w, h, q*c*b*f]
286+ x = ggml_reshape_4d (ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w]
287+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [r*w, q, h, c*b*f]
288+ x = ggml_reshape_4d (ctx, x, r * w, q * h, c * b, f);
289+
290+ return x;
291+ }
292+
239293class TinyVideoEncoder : public UnaryBlock {
240294 int in_channels = 3 ;
241295 int hidden = 64 ;
@@ -263,8 +317,13 @@ class TinyVideoEncoder : public UnaryBlock {
263317
264318 struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * z) override {
265319 auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" 0" ]);
266- auto h = first_conv->forward (ctx, z);
267- h = ggml_relu_inplace (ctx->ggml_ctx , h);
320+
321+ if (patch_size > 1 ) {
322+ z = patchify (ctx->ggml_ctx , z, patch_size, 1 );
323+ }
324+
325+ auto h = first_conv->forward (ctx, z);
326+ h = ggml_relu_inplace (ctx->ggml_ctx , h);
268327
269328 int index = 2 ;
270329 for (int i = 0 ; i < num_layers; i++) {
@@ -286,62 +345,6 @@ class TinyVideoEncoder : public UnaryBlock {
286345 }
287346};
288347
289-
290- struct ggml_tensor * patchify (struct ggml_context * ctx,
291- struct ggml_tensor * x,
292- int64_t patch_size,
293- int64_t b = 1 ) {
294- // x: [f, b*c, h*q, w*r]
295- // return: [f, b*c*r*q, h, w]
296- if (patch_size == 1 ) {
297- return x;
298- }
299- int64_t r = patch_size;
300- int64_t q = patch_size;
301-
302- int64_t W = x->ne [0 ];
303- int64_t H = x->ne [1 ];
304- int64_t C = x->ne [2 ];
305- int64_t f = x->ne [3 ];
306-
307- int64_t w = W / r;
308- int64_t h = H / q;
309-
310- x = ggml_reshape_4d (ctx, x, W, q, h, C * f); // [W, q, h, C*f]
311- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [W, h, q, C*f]
312- x = ggml_reshape_4d (ctx, x, r, w, h, q * C * f); // [r, w, h, q*C*f]
313- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 1 , 2 , 0 , 3 )); // [w, h, r, q*C*f]
314- x = ggml_reshape_4d (ctx, x, w, h, r * q * C, f); // [f, b*c*r*q, h, w]
315-
316- return x;
317- }
318-
319- struct ggml_tensor * unpatchify (struct ggml_context * ctx,
320- struct ggml_tensor * x,
321- int64_t patch_size,
322- int64_t b = 1 ) {
323- // x: [f, b*c*r*q, h, w]
324- // return: [f, b*c, h*q, w*r]
325- if (patch_size == 1 ) {
326- return x;
327- }
328- int64_t r = patch_size;
329- int64_t q = patch_size;
330- int64_t c = x->ne [2 ] / b / q / r;
331- int64_t f = x->ne [3 ];
332- int64_t h = x->ne [1 ];
333- int64_t w = x->ne [0 ];
334-
335-
336- x = ggml_reshape_4d (ctx, x, w, h, r, q * c * b * f); // [q*c*b*f, r, h, w]
337- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 2 , 0 , 1 , 3 )); // [r, w, h, q*c*b*f]
338- x = ggml_reshape_4d (ctx, x, r * w, h, q, c * b * f); // [c*b*f, q, h, r*w]
339- x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [r*w, q, h, c*b*f]
340- x = ggml_reshape_4d (ctx, x, r * w, q * h, c * b, f);
341-
342- return x;
343- }
344-
345348class TinyVideoDecoder : public UnaryBlock {
346349 int z_channels = 4 ;
347350 int out_channels = 3 ;
0 commit comments