@@ -171,35 +171,46 @@ def apply_boolean_colormap(
171171 return colored_image
172172
173173
174- def apply_pca_colormap (image : Float [Tensor , "*bs dim" ]) -> Float [Tensor , "*bs rgb=3" ]:
174+ def apply_pca_colormap (
175+ image : Float [Tensor , "*bs dim" ], pca_mat : Optional [Float [Tensor , "dim rgb=3" ]] = None , ignore_zeros = True
176+ ) -> Float [Tensor , "*bs rgb=3" ]:
175177 """Convert feature image to 3-channel RGB via PCA. The first three principle
176178 components are used for the color channels, with outlier rejection per-channel
177179
178180 Args:
179181 image: image of arbitrary vectors
182+ pca_mat: an optional argument of the PCA matrix, shape (dim, 3)
183+ ignore_zeros: whether to ignore zero values in the input image (they won't affect the PCA computation)
180184
181185 Returns:
182186 Tensor: Colored image
183187 """
184188 original_shape = image .shape
185189 image = image .view (- 1 , image .shape [- 1 ])
186- _ , _ , v = torch .pca_lowrank (image )
187- image = torch .matmul (image , v [..., :3 ])
188- d = torch .abs (image - torch .median (image , dim = 0 ).values )
190+ if ignore_zeros :
191+ valids = (image .abs ().amax (dim = - 1 )) > 0
192+ else :
193+ valids = torch .ones (image .shape [0 ], dtype = torch .bool )
194+
195+ if pca_mat is None :
196+ _ , _ , pca_mat = torch .pca_lowrank (image [valids , :], q = 3 , niter = 20 )
197+ assert pca_mat is not None
198+ image = torch .matmul (image , pca_mat [..., :3 ])
199+ d = torch .abs (image [valids , :] - torch .median (image [valids , :], dim = 0 ).values )
189200 mdev = torch .median (d , dim = 0 ).values
190201 s = d / mdev
191- m = 3 .0 # this is a hyperparam controlling how many std dev outside for outliers
192- rins = image [s [:, 0 ] < m , 0 ]
193- gins = image [s [:, 1 ] < m , 1 ]
194- bins = image [s [:, 2 ] < m , 2 ]
195-
196- image [: , 0 ] -= rins .min ()
197- image [: , 1 ] -= gins .min ()
198- image [: , 2 ] -= bins .min ()
199-
200- image [: , 0 ] /= rins .max () - rins .min ()
201- image [: , 1 ] /= gins .max () - gins .min ()
202- image [: , 2 ] /= bins .max () - bins .min ()
202+ m = 2 .0 # this is a hyperparam controlling how many std dev outside for outliers
203+ rins = image [valids , :][ s [:, 0 ] < m , 0 ]
204+ gins = image [valids , :][ s [:, 1 ] < m , 1 ]
205+ bins = image [valids , :][ s [:, 2 ] < m , 2 ]
206+
207+ image [valids , 0 ] -= rins .min ()
208+ image [valids , 1 ] -= gins .min ()
209+ image [valids , 2 ] -= bins .min ()
210+
211+ image [valids , 0 ] /= rins .max () - rins .min ()
212+ image [valids , 1 ] /= gins .max () - gins .min ()
213+ image [valids , 2 ] /= bins .max () - bins .min ()
203214
204215 image = torch .clamp (image , 0 , 1 )
205216 image_long = (image * 255 ).long ()
0 commit comments