@@ -237,19 +237,9 @@ class MemBlock : public GGMLBlock {
237237 }
238238};
239239
240- class Clamp : public UnaryBlock {
241- public:
242- struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * x) override {
243- return ggml_scale_inplace (ctx->ggml_ctx ,
244- ggml_tanh_inplace (ctx->ggml_ctx ,
245- ggml_scale (ctx->ggml_ctx , x, 1 .0f / 3 .0f )),
246- 3 .0f );
247- }
248- };
249-
250240class TinyVideoEncoder : public UnaryBlock {
251241 int in_channels = 3 ;
252- int channels = 64 ;
242+ int hidden = 64 ;
253243 int z_channels = 4 ;
254244 int num_blocks = 3 ;
255245 int num_layers = 3 ;
@@ -259,17 +249,17 @@ class TinyVideoEncoder : public UnaryBlock {
259249 TinyVideoEncoder (int z_channels = 4 , int patch_size = 1 )
260250 : z_channels(z_channels), patch_size(patch_size) {
261251 int index = 0 ;
262- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels * patch_size * patch_size, channels , {3 , 3 }, {1 , 1 }, {1 , 1 }));
252+ blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (in_channels * patch_size * patch_size, hidden , {3 , 3 }, {1 , 1 }, {1 , 1 }));
263253 index++; // nn.ReLU()
264254 for (int i = 0 ; i < num_layers; i++) {
265255 int stride = i == num_layers - 1 ? 1 : 2 ;
266- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new TPool (channels , stride));
267- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (channels, channels , {3 , 3 }, {2 , 2 }, {1 , 1 }, {1 , 1 }, false ));
256+ blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new TPool (hidden , stride));
257+ blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (hidden, hidden , {3 , 3 }, {2 , 2 }, {1 , 1 }, {1 , 1 }, false ));
268258 for (int j = 0 ; j < num_blocks; j++) {
269- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new MemBlock (channels, channels ));
259+ blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new MemBlock (hidden, hidden ));
270260 }
271261 }
272- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (channels , z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
262+ blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (hidden , z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
273263 }
274264
275265 struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * z) override {
@@ -301,9 +291,7 @@ class TinyVideoDecoder : public UnaryBlock {
301291
302292public:
303293 TinyVideoDecoder (int z_channels = 4 , int patch_size = 1 ) : z_channels(z_channels) {
304- int index = 0 ;
305- // n_f = [256, 128, 64, 64]
306- blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Clamp ());
294+ int index = 1 ; // Clamp()
307295 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new Conv2d (z_channels, channels[0 ], {3 , 3 }, {1 , 1 }, {1 , 1 }));
308296 index++; // nn.ReLU()
309297 for (int i = 0 ; i < num_layers; i++) {
@@ -322,11 +310,17 @@ class TinyVideoDecoder : public UnaryBlock {
322310
323311 struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * z) override {
324312 LOG_DEBUG (" Here" );
325- auto clamp = std::dynamic_pointer_cast<Clamp>(blocks[" 0" ]);
326313 auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" 1" ]);
327- auto h = first_conv->forward (ctx, clamp->forward (ctx, z));
328- h = ggml_relu_inplace (ctx->ggml_ctx , h);
329- int index = 3 ;
314+
315+ // Clamp()
316+ auto h = ggml_scale_inplace (ctx->ggml_ctx ,
317+ ggml_tanh_inplace (ctx->ggml_ctx ,
318+ ggml_scale (ctx->ggml_ctx , z, 1 .0f / 3 .0f )),
319+ 3 .0f );
320+
321+ h = first_conv->forward (ctx, h);
322+ h = ggml_relu_inplace (ctx->ggml_ctx , h);
323+ int index = 3 ;
330324 for (int i = 0 ; i < num_layers; i++) {
331325 for (int j = 0 ; j < num_blocks; j++) {
332326 auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string (index++)]);
@@ -350,17 +344,19 @@ class TinyVideoDecoder : public UnaryBlock {
350344 // shape(W, H, 3, T+3) => shape(W, H, 3, T)
351345 h = ggml_view_4d (ctx->ggml_ctx , h, h->ne [0 ], h->ne [1 ], h->ne [2 ], h->ne [3 ] - 3 , h->nb [1 ], h->nb [2 ], h->nb [3 ], 0 );
352346 LOG_DEBUG (" Here" );
347+ print_ggml_tensor (h, true );
353348 return h;
354349 }
355350};
356351
357352class TAEHV : public GGMLBlock {
358353protected:
359354 bool decode_only;
355+ SDVersion version;
360356
361357public:
362358 TAEHV (bool decode_only = true , SDVersion version = VERSION_WAN2)
363- : decode_only(decode_only) {
359+ : decode_only(decode_only), version(version) {
364360 int z_channels = 16 ;
365361 int patch = 1 ;
366362 if (version == VERSION_WAN2_2_TI2V) {
@@ -376,7 +372,12 @@ class TAEHV : public GGMLBlock {
376372 struct ggml_tensor * decode (GGMLRunnerContext* ctx, struct ggml_tensor * z) {
377373 LOG_DEBUG (" Decode" );
378374 auto decoder = std::dynamic_pointer_cast<TinyVideoDecoder>(blocks[" decoder" ]);
379- return decoder->forward (ctx, z);
375+ auto result = decoder->forward (ctx, z);
376+ LOG_DEBUG (" Decoded" );
377+ if (sd_version_is_wan (version)) {
378+ result = ggml_permute (ctx->ggml_ctx , result, 0 , 1 , 3 , 2 );
379+ }
380+ return result;
380381 }
381382
382383 struct ggml_tensor * encode (GGMLRunnerContext* ctx, struct ggml_tensor * x) {
0 commit comments