Skip to content

Commit 45d8bb7

Browse files
authored
updates to apply_pca_colormap (nerfstudio-project#3086)
* improvements to pca_colormap: allow input pca matrix, optional ignore_zeros arg * typo
1 parent babf577 commit 45d8bb7

1 file changed

Lines changed: 27 additions & 16 deletions

File tree

nerfstudio/utils/colormaps.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,35 +171,46 @@ def apply_boolean_colormap(
171171
return colored_image
172172

173173

174-
def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb=3"]:
174+
def apply_pca_colormap(
175+
image: Float[Tensor, "*bs dim"], pca_mat: Optional[Float[Tensor, "dim rgb=3"]] = None, ignore_zeros=True
176+
) -> Float[Tensor, "*bs rgb=3"]:
175177
"""Convert feature image to 3-channel RGB via PCA. The first three principle
176178
components are used for the color channels, with outlier rejection per-channel
177179
178180
Args:
179181
image: image of arbitrary vectors
182+
pca_mat: an optional argument of the PCA matrix, shape (dim, 3)
183+
ignore_zeros: whether to ignore zero values in the input image (they won't affect the PCA computation)
180184
181185
Returns:
182186
Tensor: Colored image
183187
"""
184188
original_shape = image.shape
185189
image = image.view(-1, image.shape[-1])
186-
_, _, v = torch.pca_lowrank(image)
187-
image = torch.matmul(image, v[..., :3])
188-
d = torch.abs(image - torch.median(image, dim=0).values)
190+
if ignore_zeros:
191+
valids = (image.abs().amax(dim=-1)) > 0
192+
else:
193+
valids = torch.ones(image.shape[0], dtype=torch.bool)
194+
195+
if pca_mat is None:
196+
_, _, pca_mat = torch.pca_lowrank(image[valids, :], q=3, niter=20)
197+
assert pca_mat is not None
198+
image = torch.matmul(image, pca_mat[..., :3])
199+
d = torch.abs(image[valids, :] - torch.median(image[valids, :], dim=0).values)
189200
mdev = torch.median(d, dim=0).values
190201
s = d / mdev
191-
m = 3.0 # this is a hyperparam controlling how many std dev outside for outliers
192-
rins = image[s[:, 0] < m, 0]
193-
gins = image[s[:, 1] < m, 1]
194-
bins = image[s[:, 2] < m, 2]
195-
196-
image[:, 0] -= rins.min()
197-
image[:, 1] -= gins.min()
198-
image[:, 2] -= bins.min()
199-
200-
image[:, 0] /= rins.max() - rins.min()
201-
image[:, 1] /= gins.max() - gins.min()
202-
image[:, 2] /= bins.max() - bins.min()
202+
m = 2.0 # this is a hyperparam controlling how many std dev outside for outliers
203+
rins = image[valids, :][s[:, 0] < m, 0]
204+
gins = image[valids, :][s[:, 1] < m, 1]
205+
bins = image[valids, :][s[:, 2] < m, 2]
206+
207+
image[valids, 0] -= rins.min()
208+
image[valids, 1] -= gins.min()
209+
image[valids, 2] -= bins.min()
210+
211+
image[valids, 0] /= rins.max() - rins.min()
212+
image[valids, 1] /= gins.max() - gins.min()
213+
image[valids, 2] /= bins.max() - bins.min()
203214

204215
image = torch.clamp(image, 0, 1)
205216
image_long = (image * 255).long()

0 commit comments

Comments
 (0)