6060 ],
6161]
6262
63+ TEST_CASES_CONDITIONAL = [
64+ [
65+ {
66+ "spatial_dims" : 2 ,
67+ "in_channels" : 1 ,
68+ "num_res_blocks" : 1 ,
69+ "num_channels" : (8 , 8 , 8 ),
70+ "attention_levels" : (False , False , True ),
71+ "num_head_channels" : 8 ,
72+ "norm_num_groups" : 8 ,
73+ "conditioning_embedding_in_channels" : 1 ,
74+ "conditioning_embedding_num_channels" : (8 , 8 ),
75+ "use_checkpointing" : False ,
76+ "with_conditioning" : True ,
77+ "cross_attention_dim" : 2 ,
78+ },
79+ 6 ,
80+ (1 , 8 , 4 , 4 ),
81+ ],
82+ [
83+ {
84+ "spatial_dims" : 3 ,
85+ "in_channels" : 1 ,
86+ "num_res_blocks" : 1 ,
87+ "num_channels" : (8 , 8 , 8 ),
88+ "attention_levels" : (False , False , True ),
89+ "num_head_channels" : 8 ,
90+ "norm_num_groups" : 8 ,
91+ "conditioning_embedding_in_channels" : 1 ,
92+ "conditioning_embedding_num_channels" : (8 , 8 ),
93+ "use_checkpointing" : True ,
94+ "with_conditioning" : True ,
95+ "cross_attention_dim" : 2 ,
96+ },
97+ 6 ,
98+ (1 , 8 , 4 , 4 , 4 ),
99+ ],
100+ ]
101+
102+ TEST_CASES_ERROR = [
103+ [
104+ {"spatial_dims" : 2 , "in_channels" : 1 , "with_conditioning" : True , "cross_attention_dim" : None },
105+ "ControlNet expects dimension of the cross-attention conditioning "
106+ "(cross_attention_dim) when using with_conditioning." ,
107+ ],
108+ [
109+ {"spatial_dims" : 2 , "in_channels" : 1 , "with_conditioning" : False , "cross_attention_dim" : 2 },
110+ "ControlNet expects with_conditioning=True when specifying the cross_attention_dim." ,
111+ ],
112+ [
113+ {"spatial_dims" : 2 , "in_channels" : 1 , "num_channels" : (8 , 16 ), "norm_num_groups" : 16 },
114+ "ControlNet expects all num_channels being multiple of norm_num_groups" ,
115+ ],
116+ [
117+ {
118+ "spatial_dims" : 2 ,
119+ "in_channels" : 1 ,
120+ "num_channels" : (8 , 16 ),
121+ "attention_levels" : (True ,),
122+ "norm_num_groups" : 8 ,
123+ },
124+ "ControlNet expects num_channels being same size of attention_levels" ,
125+ ],
126+ ]
127+
63128
64129@skipUnless (has_generative , "monai-generative required" )
65130class TestControlNet (unittest .TestCase ):
@@ -76,6 +141,27 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_
76141 self .assertEqual (len (result [0 ]), expected_num_down_blocks_residuals )
77142 self .assertEqual (result [1 ].shape , expected_shape )
78143
144+ @parameterized .expand (TEST_CASES_CONDITIONAL )
145+ def test_shape_conditioned_models (self , input_param , expected_num_down_blocks_residuals , expected_shape ):
146+ net = ControlNetMaisi (** input_param )
147+ with eval_mode (net ):
148+ x = torch .rand ((1 , 1 , 16 , 16 )) if input_param ["spatial_dims" ] == 2 else torch .rand ((1 , 1 , 16 , 16 , 16 ))
149+ timesteps = torch .randint (0 , 1000 , (1 ,)).long ()
150+ controlnet_cond = (
151+ torch .rand ((1 , 1 , 32 , 32 )) if input_param ["spatial_dims" ] == 2 else torch .rand ((1 , 1 , 32 , 32 , 32 ))
152+ )
153+ context = torch .randn ((1 , 1 , input_param ["cross_attention_dim" ]))
154+ result = net .forward (x , timesteps , controlnet_cond , context = context )
155+ self .assertEqual (len (result [0 ]), expected_num_down_blocks_residuals )
156+ self .assertEqual (result [1 ].shape , expected_shape )
157+
158+ @parameterized .expand (TEST_CASES_ERROR )
159+ def test_error_input (self , input_param , expected_error ):
160+ with self .assertRaises (ValueError ) as context : # output shape too small
161+ _ = ControlNetMaisi (** input_param )
162+ runtime_error = context .exception
163+ self .assertEqual (str (runtime_error ), expected_error )
164+
79165
80166if __name__ == "__main__" :
81167 unittest .main ()
0 commit comments