Skip to content

Commit ea34cdb

Browse files
committed
progress towards video support
1 parent 60bec38 commit ea34cdb

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

stable-diffusion.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ class StableDiffusionGGML {
491491
"decoder",
492492
vae_decode_only,
493493
version);
494+
if (sd_ctx_params->vae_conv_direct) {
495+
LOG_INFO("Using Conv2d direct in the tae model");
496+
tae_first_stage->set_conv2d_direct_enabled(true);
497+
}
494498
}
495499
} else if (version == VERSION_CHROMA_RADIANCE) {
496500
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
@@ -1718,6 +1722,10 @@ class StableDiffusionGGML {
17181722
first_stage_model->free_compute_buffer();
17191723
process_vae_output_tensor(result);
17201724
} else {
1725+
if (sd_version_is_wan(version)) {
1726+
x = ggml_permute(work_ctx, x, 0, 1, 3, 2);
1727+
}
1728+
17211729
if (vae_tiling_params.enabled && !decode_video) {
17221730
// split latent in 64x64 tiles and compute in several steps
17231731
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
@@ -1728,6 +1736,7 @@ class StableDiffusionGGML {
17281736
tae_first_stage->compute(n_threads, x, true, &result);
17291737
}
17301738
tae_first_stage->free_compute_buffer();
1739+
17311740
}
17321741

17331742
int64_t t1 = ggml_time_ms();

tae.hpp

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -237,19 +237,9 @@ class MemBlock : public GGMLBlock {
237237
}
238238
};
239239

240-
class Clamp : public UnaryBlock {
241-
public:
242-
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
243-
return ggml_scale_inplace(ctx->ggml_ctx,
244-
ggml_tanh_inplace(ctx->ggml_ctx,
245-
ggml_scale(ctx->ggml_ctx, x, 1.0f / 3.0f)),
246-
3.0f);
247-
}
248-
};
249-
250240
class TinyVideoEncoder : public UnaryBlock {
251241
int in_channels = 3;
252-
int channels = 64;
242+
int hidden = 64;
253243
int z_channels = 4;
254244
int num_blocks = 3;
255245
int num_layers = 3;
@@ -259,17 +249,17 @@ class TinyVideoEncoder : public UnaryBlock {
259249
TinyVideoEncoder(int z_channels = 4, int patch_size = 1)
260250
: z_channels(z_channels), patch_size(patch_size) {
261251
int index = 0;
262-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels * patch_size * patch_size, channels, {3, 3}, {1, 1}, {1, 1}));
252+
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels * patch_size * patch_size, hidden, {3, 3}, {1, 1}, {1, 1}));
263253
index++; // nn.ReLU()
264254
for (int i = 0; i < num_layers; i++) {
265255
int stride = i == num_layers - 1 ? 1 : 2;
266-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(channels, stride));
267-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
256+
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TPool(hidden, stride));
257+
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, hidden, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
268258
for (int j = 0; j < num_blocks; j++) {
269-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(channels, channels));
259+
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
270260
}
271261
}
272-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
262+
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
273263
}
274264

275265
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
@@ -301,9 +291,7 @@ class TinyVideoDecoder : public UnaryBlock {
301291

302292
public:
303293
TinyVideoDecoder(int z_channels = 4, int patch_size = 1) : z_channels(z_channels) {
304-
int index = 0;
305-
// n_f = [256, 128, 64, 64]
306-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Clamp());
294+
int index = 1; // Clamp()
307295
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels[0], {3, 3}, {1, 1}, {1, 1}));
308296
index++; // nn.ReLU()
309297
for (int i = 0; i < num_layers; i++) {
@@ -322,11 +310,17 @@ class TinyVideoDecoder : public UnaryBlock {
322310

323311
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
324312
LOG_DEBUG("Here");
325-
auto clamp = std::dynamic_pointer_cast<Clamp>(blocks["0"]);
326313
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["1"]);
327-
auto h = first_conv->forward(ctx, clamp->forward(ctx, z));
328-
h = ggml_relu_inplace(ctx->ggml_ctx, h);
329-
int index = 3;
314+
315+
// Clamp()
316+
auto h = ggml_scale_inplace(ctx->ggml_ctx,
317+
ggml_tanh_inplace(ctx->ggml_ctx,
318+
ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)),
319+
3.0f);
320+
321+
h = first_conv->forward(ctx, h);
322+
h = ggml_relu_inplace(ctx->ggml_ctx, h);
323+
int index = 3;
330324
for (int i = 0; i < num_layers; i++) {
331325
for (int j = 0; j < num_blocks; j++) {
332326
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
@@ -350,17 +344,19 @@ class TinyVideoDecoder : public UnaryBlock {
350344
// shape(W, H, 3, T+3) => shape(W, H, 3, T)
351345
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 0);
352346
LOG_DEBUG("Here");
347+
print_ggml_tensor(h, true);
353348
return h;
354349
}
355350
};
356351

357352
class TAEHV : public GGMLBlock {
358353
protected:
359354
bool decode_only;
355+
SDVersion version;
360356

361357
public:
362358
TAEHV(bool decode_only = true, SDVersion version = VERSION_WAN2)
363-
: decode_only(decode_only) {
359+
: decode_only(decode_only), version(version) {
364360
int z_channels = 16;
365361
int patch = 1;
366362
if (version == VERSION_WAN2_2_TI2V) {
@@ -376,7 +372,12 @@ class TAEHV : public GGMLBlock {
376372
struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
377373
LOG_DEBUG("Decode");
378374
auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks["decoder"]);
379-
return decoder->forward(ctx, z);
375+
auto result = decoder->forward(ctx, z);
376+
LOG_DEBUG("Decoded");
377+
if (sd_version_is_wan(version)) {
378+
result = ggml_permute(ctx->ggml_ctx, result, 0, 1, 3, 2);
379+
}
380+
return result;
380381
}
381382

382383
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {

0 commit comments

Comments
 (0)