@@ -264,13 +264,13 @@ class TinyVideoEncoder : public UnaryBlock {
264264 struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * z) override {
265265 auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" 0" ]);
266266 auto h = first_conv->forward (ctx, z);
267- h = ggml_relu_inplace (ctx->ggml_ctx , h);
268-
267+ h = ggml_relu_inplace (ctx->ggml_ctx , h);
268+
269269 int index = 2 ;
270270 for (int i = 0 ; i < num_layers; i++) {
271271 auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string (index++)]);
272272 auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string (index++)]);
273-
273+
274274 h = pool->forward (ctx, h);
275275 h = conv->forward (ctx, h);
276276 for (int j = 0 ; j < num_blocks; j++) {
@@ -280,21 +280,77 @@ class TinyVideoEncoder : public UnaryBlock {
280280 h = block->forward (ctx, h, mem);
281281 }
282282 }
283- auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (index)]);
284- h = last_conv->forward (ctx, h);
283+ auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (index)]);
284+ h = last_conv->forward (ctx, h);
285285 return h;
286286 }
287287};
288288
289+
290+ struct ggml_tensor * patchify (struct ggml_context * ctx,
291+ struct ggml_tensor * x,
292+ int64_t patch_size,
293+ int64_t b = 1 ) {
294+ // x: [f, b*c, h*q, w*r]
295+ // return: [f, b*c*r*q, h, w]
296+ if (patch_size == 1 ) {
297+ return x;
298+ }
299+ int64_t r = patch_size;
300+ int64_t q = patch_size;
301+
302+ int64_t w_in = x->ne [0 ];
303+ int64_t h_in = x->ne [1 ];
304+ int64_t cb = x->ne [2 ]; // b*c
305+ int64_t f = x->ne [3 ];
306+
307+ int64_t w = w_in / r;
308+ int64_t h = h_in / q;
309+
310+ x = ggml_reshape_4d (ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w]
311+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c, r, h*q, w]
312+ x = ggml_reshape_4d (ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w]
313+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c*r, q, h, w]
314+ x = ggml_reshape_4d (ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w]
315+ return x;
316+ }
317+
318+ struct ggml_tensor * unpatchify (struct ggml_context * ctx,
319+ struct ggml_tensor * x,
320+ int64_t patch_size,
321+ int64_t b = 1 ) {
322+ // x: [f, b*c*r*q, h, w]
323+ // return: [f, b*c, h*q, w*r]
324+ if (patch_size == 1 ) {
325+ return x;
326+ }
327+ int64_t r = patch_size;
328+ int64_t q = patch_size;
329+ int64_t c = x->ne [2 ] / b / q / r;
330+ int64_t f = x->ne [3 ];
331+ int64_t h = x->ne [1 ];
332+ int64_t w = x->ne [0 ];
333+
334+
335+ x = ggml_reshape_4d (ctx, x, w * h, q * r, c * b, f); // [f, b*c, r*q, h*w]
336+ x = ggml_reshape_4d (ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q*h, w]
337+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 2 , 0 , 1 , 3 )); // [f*b*c, q*h, w, r]
338+ x = ggml_reshape_4d (ctx, x, r * w, h, q, f * c * b); // [f*b*c, q, h, w*r]
339+ x = ggml_ext_cont (ctx, ggml_ext_torch_permute (ctx, x, 0 , 2 , 1 , 3 )); // [f*b*c, h, q, w*r]
340+ x = ggml_reshape_4d (ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r]
341+ return x;
342+ }
343+
289344class TinyVideoDecoder : public UnaryBlock {
290345 int z_channels = 4 ;
291346 int out_channels = 3 ;
292347 int num_blocks = 3 ;
293348 static const int num_layers = 3 ;
294349 int channels[num_layers + 1 ] = {256 , 128 , 64 , 64 };
350+ int patch_size = 1 ;
295351
296352public:
297- TinyVideoDecoder (int z_channels = 4 , int patch_size = 1 ) : z_channels(z_channels) {
353+ TinyVideoDecoder (int z_channels = 4 , int patch_size = 1 ) : z_channels(z_channels), patch_size(patch_size) {
298354 int index = 1 ; // Clamp()
299355 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, channels[0 ], {3 , 3 }, {1 , 1 }, {1 , 1 }));
300356 index++; // nn.ReLU()
@@ -343,7 +399,9 @@ class TinyVideoDecoder : public UnaryBlock {
343399
344400 auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (++index)]);
345401 h = last_conv->forward (ctx, h);
346-
402+ if (patch_size > 1 ) {
403+ h = unpatchify (ctx->ggml_ctx , h, patch_size, 1 );
404+ }
347405 // shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
348406 h = ggml_view_4d (ctx->ggml_ctx , h, h->ne [0 ], h->ne [1 ], h->ne [2 ], h->ne [3 ] - 3 , h->nb [1 ], h->nb [2 ], h->nb [3 ], 3 * h->nb [3 ]);
349407 return h;
@@ -376,7 +434,7 @@ class TAEHV : public GGMLBlock {
376434 // (W, H, C, T) -> (W, H, T, C)
377435 z = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , z, 0 , 1 , 3 , 2 ));
378436 }
379- auto result = decoder->forward (ctx, z);
437+ auto result = decoder->forward (ctx, z);
380438 if (sd_version_is_wan (version)) {
381439 // (W, H, C, T) -> (W, H, T, C)
382440 result = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , result, 0 , 1 , 3 , 2 ));
@@ -387,7 +445,7 @@ class TAEHV : public GGMLBlock {
387445 struct ggml_tensor * encode (GGMLRunnerContext* ctx, struct ggml_tensor * x) {
388446 auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks[" encoder" ]);
389447 // (W, H, T, C) -> (W, H, C, T)
390- x = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , x, 0 , 1 , 3 , 2 ));
448+ x = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , x, 0 , 1 , 3 , 2 ));
391449 int64_t num_frames = x->ne [3 ];
392450 if (num_frames % 4 ) {
393451 // pad to multiple of 4 at the end
0 commit comments