@@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
 protected:
     int n_in;
     int n_out;
+    bool use_midblock_gn;
 
 public:
-    TAEBlock(int n_in, int n_out)
-        : n_in(n_in), n_out(n_out) {
+    TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
+        : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
         blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         if (n_in != n_out) {
             blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
         }
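+        // optional residual pooling branch: 1x1 conv (n_in -> 4*n_in), GroupNorm over 4 groups, ReLU, 1x1 conv back to n_in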
+        if (use_midblock_gn) {
+            int n_gn = n_in * 4;
+            blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+            blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
+            // pool.2 is ReLU, handled in forward
+            blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
     }
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [n, n_in, h, w]
         // return: [n, n_out, h, w]
 
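+        // residual pooling applied before the conv stack: x = x + pool(x)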
+        if (use_midblock_gn) {
+            auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
+            auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
+            auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
+
+            auto p = pool_0->forward(ctx, x);
+            p = pool_1->forward(ctx, p);
+            p = ggml_relu_inplace(ctx->ggml_ctx, p);
+            p = pool_3->forward(ctx, p);
+
+            x = ggml_add(ctx->ggml_ctx, x, p);
+        }
+
         auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
         auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
         auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
@@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
     int num_blocks = 3;
 
 public:
-    TinyEncoder(int z_channels = 4)
+    TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@@ -80,7 +101,7 @@ class TinyEncoder : public UnaryBlock {
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@@ -107,15 +128,15 @@ class TinyDecoder : public UnaryBlock {
     int num_blocks = 3;
 
 public:
-    TinyDecoder(int z_channels = 4)
+    TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
         index++;  // nn.ReLU()
 
         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }
         index++;  // nn.Upsample()
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@@ -470,29 +491,44 @@ class TAEHV : public GGMLBlock {
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
+    bool taef2 = false;
 
 public:
     TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
         : decode_only(decode_only) {
-        int z_channels = 4;
+        int z_channels       = 4;
+        bool use_midblock_gn = false;
+        taef2                = sd_version_is_flux2(version);
+
         if (sd_version_is_dit(version)) {
             z_channels = 16;
         }
-        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
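+        // taef2 (FLUX.2) variant: 32-channel latents, GroupNorm pooling branch enabled in each TAEBlock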
+        if (taef2) {
+            z_channels      = 32;
+            use_midblock_gn = true;
+        }
+        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
 
         if (!decode_only) {
-            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
+            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
         }
     }
 
     struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
         auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
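+        // taef2 latents arrive patchified; convert back to spatial layout (patch size 2) before decoding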
+        if (taef2) {
+            z = unpatchify(ctx->ggml_ctx, z, 2);
+        }
         return decoder->forward(ctx, z);
     }
 
     struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
-        return encoder->forward(ctx, x);
+        auto z = encoder->forward(ctx, x);
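+        // repack the encoder output into the patchified layout taef2 expects (patch size 2)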
+        if (taef2) {
+            z = patchify(ctx->ggml_ctx, z, 2);
+        }
+        return z;
     }
 };
 