Commit e50e1f2

feat: add taef2 support (#1211)
1 parent c6206fb commit e50e1f2

tae.hpp

Lines changed: 46 additions & 10 deletions
@@ -17,22 +17,43 @@ class TAEBlock : public UnaryBlock {
 protected:
     int n_in;
     int n_out;
+    bool use_midblock_gn;
 
 public:
-    TAEBlock(int n_in, int n_out)
-        : n_in(n_in), n_out(n_out) {
+    TAEBlock(int n_in, int n_out, bool use_midblock_gn = false)
+        : n_in(n_in), n_out(n_out), use_midblock_gn(use_midblock_gn) {
         blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
         if (n_in != n_out) {
             blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
         }
+        if (use_midblock_gn) {
+            int n_gn = n_in * 4;
+            blocks["pool.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_gn, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+            blocks["pool.1"] = std::shared_ptr<GGMLBlock>(new GroupNorm(4, n_gn));
+            // pool.2 is ReLU, handled in forward
+            blocks["pool.3"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_gn, n_in, {1, 1}, {1, 1}, {0, 0}, {1, 1}, false));
+        }
     }
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [n, n_in, h, w]
         // return: [n, n_out, h, w]
 
+        if (use_midblock_gn) {
+            auto pool_0 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.0"]);
+            auto pool_1 = std::dynamic_pointer_cast<GroupNorm>(blocks["pool.1"]);
+            auto pool_3 = std::dynamic_pointer_cast<Conv2d>(blocks["pool.3"]);
+
+            auto p = pool_0->forward(ctx, x);
+            p = pool_1->forward(ctx, p);
+            p = ggml_relu_inplace(ctx->ggml_ctx, p);
+            p = pool_3->forward(ctx, p);
+
+            x = ggml_add(ctx->ggml_ctx, x, p);
+        }
+
         auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
         auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
         auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
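Note on the pool.* branch added above: with use_midblock_gn set, the block prepends a residual channel-mixing stage to its usual conv stack. pool.0 is a 1x1 convolution expanding n_in to n_gn = 4 * n_in channels, pool.1 is a GroupNorm with 4 groups, pool.2 is a ReLU, pool.3 is a 1x1 convolution projecting back to n_in, and the result is added onto x. As a rough reference for the GroupNorm step only, here is a minimal CPU sketch over a flat NCHW buffer; the group_norm helper, its signature, and the eps value are assumptions for illustration, not the ggml kernel the block actually calls.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference GroupNorm over an NCHW tensor stored as a flat vector.
// For each sample and each group of channels, normalize over
// (channels_per_group * h * w) elements, then apply per-channel gamma/beta.
// Assumes c is divisible by num_groups (true for n_gn = 4 * n_in with 4 groups).
void group_norm(std::vector<float>& x, int n, int c, int h, int w,
                const std::vector<float>& gamma, const std::vector<float>& beta,
                int num_groups = 4, float eps = 1e-6f /* assumed default */) {
    const int cpg   = c / num_groups;   // channels per group
    const size_t hw = (size_t)h * w;
    for (int ni = 0; ni < n; ni++) {
        for (int g = 0; g < num_groups; g++) {
            // mean/variance over the group's channels and all spatial positions
            double sum = 0.0, sq = 0.0;
            for (int ci = g * cpg; ci < (g + 1) * cpg; ci++) {
                const float* p = x.data() + ((size_t)ni * c + ci) * hw;
                for (size_t i = 0; i < hw; i++) {
                    sum += p[i];
                    sq += (double)p[i] * p[i];
                }
            }
            const double count  = (double)cpg * hw;
            const double mean   = sum / count;
            const double var    = sq / count - mean * mean;
            const float inv_std = 1.0f / std::sqrt((float)var + eps);
            // normalize, then apply the per-channel affine parameters
            for (int ci = g * cpg; ci < (g + 1) * cpg; ci++) {
                float* p = x.data() + ((size_t)ni * c + ci) * hw;
                for (size_t i = 0; i < hw; i++) {
                    p[i] = ((p[i] - (float)mean) * inv_std) * gamma[ci] + beta[ci];
                }
            }
        }
    }
}

If channels is 64, as in the original TAESD design, n_gn is 256 and each of the 4 groups normalizes one 64-channel slice per sample.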
@@ -62,7 +83,7 @@ class TinyEncoder : public UnaryBlock {
     int num_blocks = 3;
 
 public:
-    TinyEncoder(int z_channels = 4)
+    TinyEncoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
@@ -80,7 +101,7 @@ class TinyEncoder : public UnaryBlock {
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
@@ -107,15 +128,15 @@ class TinyDecoder : public UnaryBlock {
     int num_blocks = 3;
 
 public:
-    TinyDecoder(int z_channels = 4)
+    TinyDecoder(int z_channels = 4, bool use_midblock_gn = false)
         : z_channels(z_channels) {
         int index = 0;
 
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
         index++; // nn.ReLU()
 
         for (int i = 0; i < num_blocks; i++) {
-            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
+            blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels, use_midblock_gn));
         }
         index++; // nn.Upsample()
         blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
@@ -470,29 +491,44 @@ class TAEHV : public GGMLBlock {
 class TAESD : public GGMLBlock {
 protected:
     bool decode_only;
+    bool taef2 = false;
 
 public:
     TAESD(bool decode_only = true, SDVersion version = VERSION_SD1)
         : decode_only(decode_only) {
-        int z_channels = 4;
+        int z_channels = 4;
+        bool use_midblock_gn = false;
+        taef2 = sd_version_is_flux2(version);
+
         if (sd_version_is_dit(version)) {
             z_channels = 16;
         }
-        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels));
+        if (taef2) {
+            z_channels = 32;
+            use_midblock_gn = true;
+        }
+        blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder(z_channels, use_midblock_gn));
 
         if (!decode_only) {
-            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels));
+            blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder(z_channels, use_midblock_gn));
         }
     }
 
     struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) {
         auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
+        if (taef2) {
+            z = unpatchify(ctx->ggml_ctx, z, 2);
+        }
         return decoder->forward(ctx, z);
     }
 
     struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
-        return encoder->forward(ctx, x);
+        auto z = encoder->forward(ctx, x);
+        if (taef2) {
+            z = patchify(ctx->ggml_ctx, z, 2);
+        }
+        return z;
     }
 };

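Note on the taef2 wrapping of encode()/decode(): the TAEF2 path keeps the latent in a 2x2 patch-packed form outside the tiny networks, so decode() unfolds the packed channels back into spatial positions before TinyDecoder runs, and encode() folds the encoder output back into packed form. The sketch below is a hypothetical plain-C++ illustration of the conventional space-to-depth rearrangement ([n, c, h, w] -> [n, c*p*p, h/p, w/p] for patchify, and the inverse for unpatchify); the real patchify/unpatchify helpers in this file operate on ggml tensors and their channel ordering for the folded patches may differ.

#include <cstddef>
#include <vector>

// Hypothetical space-to-depth ("patchify") for a flat NCHW tensor:
// each p x p spatial patch is folded into the channel dimension, so
// [n, c, h, w] -> [n, c*p*p, h/p, w/p]. unpatchify() is the exact inverse.
std::vector<float> patchify(const std::vector<float>& x,
                            int n, int c, int h, int w, int p) {
    const int ho = h / p, wo = w / p, co = c * p * p;
    std::vector<float> out((size_t)n * co * ho * wo);
    for (int ni = 0; ni < n; ni++)
        for (int ci = 0; ci < c; ci++)
            for (int py = 0; py < p; py++)
                for (int px = 0; px < p; px++)
                    for (int y = 0; y < ho; y++)
                        for (int xk = 0; xk < wo; xk++) {
                            // output channel index groups (ci, py, px) together
                            const int oc     = (ci * p + py) * p + px;
                            const size_t src = (((size_t)ni * c + ci) * h + (y * p + py)) * w + (xk * p + px);
                            const size_t dst = (((size_t)ni * co + oc) * ho + y) * wo + xk;
                            out[dst] = x[src];
                        }
    return out;
}

// Inverse: scatter each folded channel back to its 2D offset inside the patch.
std::vector<float> unpatchify(const std::vector<float>& z,
                              int n, int co, int ho, int wo, int p) {
    const int c = co / (p * p), h = ho * p, w = wo * p;
    std::vector<float> out((size_t)n * c * h * w);
    for (int ni = 0; ni < n; ni++)
        for (int oc = 0; oc < co; oc++)
            for (int y = 0; y < ho; y++)
                for (int xk = 0; xk < wo; xk++) {
                    const int ci     = oc / (p * p);
                    const int py     = (oc / p) % p;
                    const int px     = oc % p;
                    const size_t src = (((size_t)ni * co + oc) * ho + y) * wo + xk;
                    const size_t dst = (((size_t)ni * c + ci) * h + (y * p + py)) * w + (xk * p + px);
                    out[dst] = z[src];
                }
    return out;
}

Under that convention, a packed latent carrying 4 * 32 = 128 channels at half resolution becomes the [n, 32, h, w] tensor that TinyDecoder(z_channels = 32) expects, and patchify re-packs the encoder output the same way.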