Skip to content

Commit 54f9a5a

Browse files
committed
taehv: support patchified latents
1 parent 1cbcca2 commit 54f9a5a

File tree

1 file changed

+67
-9
lines changed

1 file changed

+67
-9
lines changed

tae.hpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -264,13 +264,13 @@ class TinyVideoEncoder : public UnaryBlock {
264264
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
265265
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
266266
auto h = first_conv->forward(ctx, z);
267-
h = ggml_relu_inplace(ctx->ggml_ctx, h);
268-
267+
h = ggml_relu_inplace(ctx->ggml_ctx, h);
268+
269269
int index = 2;
270270
for (int i = 0; i < num_layers; i++) {
271271
auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
272272
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
273-
273+
274274
h = pool->forward(ctx, h);
275275
h = conv->forward(ctx, h);
276276
for (int j = 0; j < num_blocks; j++) {
@@ -280,21 +280,77 @@ class TinyVideoEncoder : public UnaryBlock {
280280
h = block->forward(ctx, h, mem);
281281
}
282282
}
283-
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
284-
h = last_conv->forward(ctx, h);
283+
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
284+
h = last_conv->forward(ctx, h);
285285
return h;
286286
}
287287
};
288288

289+
290+
struct ggml_tensor* patchify(struct ggml_context* ctx,
291+
struct ggml_tensor* x,
292+
int64_t patch_size,
293+
int64_t b = 1) {
294+
// x: [f, b*c, h*q, w*r]
295+
// return: [f, b*c*r*q, h, w]
296+
if (patch_size == 1) {
297+
return x;
298+
}
299+
int64_t r = patch_size;
300+
int64_t q = patch_size;
301+
302+
int64_t w_in = x->ne[0];
303+
int64_t h_in = x->ne[1];
304+
int64_t cb = x->ne[2]; // b*c
305+
int64_t f = x->ne[3];
306+
307+
int64_t w = w_in / r;
308+
int64_t h = h_in / q;
309+
310+
x = ggml_reshape_4d(ctx, x, w, r, h_in, cb * f); // [f*b*c, h*q, r, w]
311+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, r, h*q, w]
312+
x = ggml_reshape_4d(ctx, x, w, q, h, r * cb * f); // [f*b*c*r, h, q, w]
313+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c*r, q, h, w]
314+
x = ggml_reshape_4d(ctx, x, w, h, q * r * cb, f); // [f, b*c*r*q, h, w]
315+
return x;
316+
}
317+
318+
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
319+
struct ggml_tensor* x,
320+
int64_t patch_size,
321+
int64_t b = 1) {
322+
// x: [f, b*c*r*q, h, w]
323+
// return: [f, b*c, h*q, w*r]
324+
if (patch_size == 1) {
325+
return x;
326+
}
327+
int64_t r = patch_size;
328+
int64_t q = patch_size;
329+
int64_t c = x->ne[2] / b / q / r;
330+
int64_t f = x->ne[3];
331+
int64_t h = x->ne[1];
332+
int64_t w = x->ne[0];
333+
334+
335+
x = ggml_reshape_4d(ctx, x, w * h, q * r, c * b, f); // [f, b*c, r*q, h*w]
336+
x = ggml_reshape_4d(ctx, x, w, h * q, r, f * c * b); // [f*b*c, r, q*h, w]
337+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [f*b*c, q*h, w, r]
338+
x = ggml_reshape_4d(ctx, x, r * w, h, q, f * c * b); // [f*b*c, q, h, w*r]
339+
x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [f*b*c, h, q, w*r]
340+
x = ggml_reshape_4d(ctx, x, r * w, q * h, c * b, f); // [f, b*c, h*q, w*r]
341+
return x;
342+
}
343+
289344
class TinyVideoDecoder : public UnaryBlock {
290345
int z_channels = 4;
291346
int out_channels = 3;
292347
int num_blocks = 3;
293348
static const int num_layers = 3;
294349
int channels[num_layers + 1] = {256, 128, 64, 64};
350+
int patch_size = 1;
295351

296352
public:
297-
TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) {
353+
TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels), patch_size(patch_size) {
298354
int index = 1; // Clamp()
299355
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
300356
index++; // nn.ReLU()
@@ -343,7 +399,9 @@ class TinyVideoDecoder : public UnaryBlock {
343399

344400
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
345401
h = last_conv->forward(ctx, h);
346-
402+
if (patch_size > 1) {
403+
h = unpatchify(ctx->ggml_ctx, h, patch_size, 1);
404+
}
347405
// shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
348406
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
349407
return h;
@@ -376,7 +434,7 @@ class TAEHV : public GGMLBlock {
376434
// (W, H, C, T) -> (W, H, T, C)
377435
z = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, z, 0, 1, 3, 2));
378436
}
379-
auto result = decoder->forward(ctx, z);
437+
auto result = decoder->forward(ctx, z);
380438
if (sd_version_is_wan(version)) {
381439
// (W, H, C, T) -> (W, H, T, C)
382440
result = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2));
@@ -387,7 +445,7 @@ class TAEHV : public GGMLBlock {
387445
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
388446
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
389447
// (W, H, T, C) -> (W, H, C, T)
390-
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
448+
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
391449
int64_t num_frames = x->ne[3];
392450
if (num_frames % 4) {
393451
// pad to multiple of 4 at the end

0 commit comments

Comments
 (0)