Skip to content

Commit c342b23

Browse files
committed
add more test cases
Signed-off-by: Pengfei Guo <pengfeig@nvidia.com>
1 parent 9ae94fe commit c342b23

2 files changed

Lines changed: 92 additions & 6 deletions

File tree

monai/apps/generation/maisi/networks/controlnet_maisi.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def _apply_mid_block(self, emb, context, h):
168168

169169
def _apply_controlnet_blocks(self, h, down_block_res_samples):
170170
# 6. Control net blocks
171-
controlnet_down_block_res_samples = []
172-
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
173-
down_block_res_sample = controlnet_block(down_block_res_sample)
174-
controlnet_down_block_res_samples.append(down_block_res_sample)
171+
controlnet_down_block_res_samples = []
172+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
173+
down_block_res_sample = controlnet_block(down_block_res_sample)
174+
controlnet_down_block_res_samples.append(down_block_res_sample)
175175

176-
mid_block_res_sample = self.controlnet_mid_block(h)
176+
mid_block_res_sample = self.controlnet_mid_block(h)
177177

178-
return controlnet_down_block_res_samples, mid_block_res_sample
178+
return controlnet_down_block_res_samples, mid_block_res_sample

tests/test_controlnet_maisi.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,71 @@
6060
],
6161
]
6262

63+
TEST_CASES_CONDITIONAL = [
64+
[
65+
{
66+
"spatial_dims": 2,
67+
"in_channels": 1,
68+
"num_res_blocks": 1,
69+
"num_channels": (8, 8, 8),
70+
"attention_levels": (False, False, True),
71+
"num_head_channels": 8,
72+
"norm_num_groups": 8,
73+
"conditioning_embedding_in_channels": 1,
74+
"conditioning_embedding_num_channels": (8, 8),
75+
"use_checkpointing": False,
76+
"with_conditioning": True,
77+
"cross_attention_dim": 2,
78+
},
79+
6,
80+
(1, 8, 4, 4),
81+
],
82+
[
83+
{
84+
"spatial_dims": 3,
85+
"in_channels": 1,
86+
"num_res_blocks": 1,
87+
"num_channels": (8, 8, 8),
88+
"attention_levels": (False, False, True),
89+
"num_head_channels": 8,
90+
"norm_num_groups": 8,
91+
"conditioning_embedding_in_channels": 1,
92+
"conditioning_embedding_num_channels": (8, 8),
93+
"use_checkpointing": True,
94+
"with_conditioning": True,
95+
"cross_attention_dim": 2,
96+
},
97+
6,
98+
(1, 8, 4, 4, 4),
99+
],
100+
]
101+
102+
TEST_CASES_ERROR = [
103+
[
104+
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None},
105+
"ControlNet expects dimension of the cross-attention conditioning "
106+
"(cross_attention_dim) when using with_conditioning.",
107+
],
108+
[
109+
{"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2},
110+
"ControlNet expects with_conditioning=True when specifying the cross_attention_dim.",
111+
],
112+
[
113+
{"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16},
114+
"ControlNet expects all num_channels being multiple of norm_num_groups",
115+
],
116+
[
117+
{
118+
"spatial_dims": 2,
119+
"in_channels": 1,
120+
"num_channels": (8, 16),
121+
"attention_levels": (True,),
122+
"norm_num_groups": 8,
123+
},
124+
"ControlNet expects num_channels being same size of attention_levels",
125+
],
126+
]
127+
63128

64129
@skipUnless(has_generative, "monai-generative required")
65130
class TestControlNet(unittest.TestCase):
@@ -76,6 +141,27 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_
76141
self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
77142
self.assertEqual(result[1].shape, expected_shape)
78143

144+
@parameterized.expand(TEST_CASES_CONDITIONAL)
145+
def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
146+
net = ControlNetMaisi(**input_param)
147+
with eval_mode(net):
148+
x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16))
149+
timesteps = torch.randint(0, 1000, (1,)).long()
150+
controlnet_cond = (
151+
torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32))
152+
)
153+
context = torch.randn((1, 1, input_param["cross_attention_dim"]))
154+
result = net.forward(x, timesteps, controlnet_cond, context=context)
155+
self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
156+
self.assertEqual(result[1].shape, expected_shape)
157+
158+
@parameterized.expand(TEST_CASES_ERROR)
159+
def test_error_input(self, input_param, expected_error):
160+
with self.assertRaises(ValueError) as context: # output shape too small
161+
_ = ControlNetMaisi(**input_param)
162+
runtime_error = context.exception
163+
self.assertEqual(str(runtime_error), expected_error)
164+
79165

80166
if __name__ == "__main__":
81167
unittest.main()

0 commit comments

Comments
 (0)