Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 64 additions & 12 deletions package/PartSegImage/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import re
import sys
import typing
Expand Down Expand Up @@ -39,13 +40,27 @@ class ChannelInfo:

@dataclass(**ch_par)
class ChannelInfoFull:
"""Full channel information used in :py:class:`.Image`"""

name: str
color_map: str | np.ndarray
contrast_limits: tuple[float, float]

def __post_init__(self):
"""Normalize color_map to numpy array if it is not string."""
if not isinstance(self.color_map, (str, np.ndarray)):
self.color_map = np.array(self.color_map)
self.color_map = np.array(self.color_map, dtype=np.uint8)
if isinstance(self.color_map, np.ndarray):
if self.color_map.dtype != np.uint8:
message = f"Colormap as array need to be uint8, not {self.color_map.dtype}"
raise ValueError(message)
if self.color_map.ndim in {1, 2}:
if self.color_map.shape[0] not in {3, 4}:
message = f"Color map need to have 3 or 4 elements (RGB or RGBA), not {self.color_map.shape}"
raise ValueError(message)
else:
message = f"Colormap as sequence need to be 1d or 2d array, not {self.color_map.shape}"
raise ValueError(message)


def minimal_dtype(val: int):
Expand Down Expand Up @@ -896,7 +911,12 @@ def cut_image(
axes_order=self.axis_order,
)

def get_imagej_colors(self):
def get_imagej_colors(self) -> list[np.ndarray[tuple[typing.Literal[3], typing.Literal[256]], np.dtype[np.uint8]]]:
"""Get colors in format used by imagej

:return: list of 3x256 arrays with RGB values
:rtype: list of numpy.ndarray
"""
res = []
for color in self.default_coloring:

Expand All @@ -908,17 +928,43 @@ def get_imagej_colors(self):
res.append(np.array([np.linspace(0, x, num=256) for x in color_array]).astype(np.uint8))
elif color.ndim == 1:
res.append(np.array([np.linspace(0, x, num=256) for x in color]).astype(np.uint8))
else:
if color.shape[1] != 256:
res.append(
np.array(
[
np.interp(np.linspace(0, 255, num=256), np.linspace(0, color.shape[1], num=256), x)
for x in color
]
)
elif color.shape[1] != 256:
res.append(
np.array(
[
np.interp(np.linspace(0, 255, num=256), np.linspace(0, color.shape[1], num=256), x)
for x in color
]
)
res.append(color)
)
else:
res.append(color.astype(np.uint8))
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return res

def get_ome_colors(self) -> list[int]:
"""The ome stores colors as single integer encoding RGB value

:returns: list of integers representing colors
"""

res = []
default_colors = ["red", "blue", "green", "yellow", "magenta", "cyan"]
for i, color in enumerate(self.default_coloring):
if isinstance(color, str):
if color.startswith("#"):
color_array = _hex_to_rgb(color)
else:
color_array = _name_to_rgb(color)
res.append(_rgb_to_signed_int(color_array))
elif color.ndim == 1:
# treat as RGB
res.append(_rgb_to_signed_int(tuple(color)))
else:
logging.warning(
"Do not support custom colormap in ome colors. Use %s", default_colors[i % len(default_colors)]
)
color_array = _name_to_rgb(default_colors[i % len(default_colors)])
res.append(_rgb_to_signed_int(color_array))
return res
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def get_colors(self) -> list[str | list[int]]:
Expand Down Expand Up @@ -1001,6 +1047,12 @@ def _name_to_rgb(name: str) -> tuple[int, int, int]:
return _hex_to_rgb(_NAMED_COLORS[name])


def _rgb_to_signed_int(rgb: tuple[int, int, int]) -> int:
"""Convert an RGB tuple to a signed integer representation."""
r, g, b = rgb[:3]
return np.int32((np.int32(r) << 24) | (np.int32(g) << 16) | (np.int32(b) << 8) | np.int32(255))


try:
from vispy.color import get_color_dict
except ImportError: # pragma: no cover
Expand Down
25 changes: 22 additions & 3 deletions package/PartSegImage/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,15 +590,34 @@ def read_resolution_from_tags(self, image_file):
return x_spacing, y_spacing

@staticmethod
def _read_imagej_colors(image_file):
def _read_imagej_colors(
image_file: tifffile.TiffFile,
) -> list[np.ndarray[tuple[int, int], np.dtype[np.uint8]]]:
"""
Read colors from ImageJ metadata

:param image_file: tiff file to read
:return: list of colors or empty list if no colors
"""
colors = image_file.imagej_metadata.get("LUTs", [])
if isinstance(colors, list) and colors and colors[0].shape[0] == 24:
# drop buggy colors that comes from bug in PArtSeg with
# drop buggy colors that comes from bug in PartSeg with
# writing 64 bit integers in tifffile
return []
if isinstance(colors, np.ndarray):
return [colors]

return colors

def read_imagej_metadata(self, image_file):
def read_imagej_metadata(self, image_file: tifffile.TiffFile) -> None:
"""
Read metadata from the ImageJ tiff file.

Read spacing, colors, channel names, ranges and other metadata.
Save original metadata in :py:attr:`metadata`

:param image_file: file to read
"""
try:
z_spacing = image_file.imagej_metadata["spacing"] * name_to_scalar[image_file.imagej_metadata["unit"]]
except KeyError:
Expand Down
1 change: 1 addition & 0 deletions package/PartSegImage/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def save(cls, image: Image, save_path: typing.Union[str, BytesIO, Path], compres
metadata["Channel"] = {
"Name": image.channel_names,
"axes": "TZYXC",
"Color": image.get_ome_colors(),
}
cls._save(data, save_path, metadata, compression)

Expand Down
44 changes: 43 additions & 1 deletion package/tests/test_PartSegImage/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import os

import numpy as np
import numpy.testing as npt
import pytest
from skimage.morphology import diamond

from PartSegImage import Channel, ChannelInfo, Image, ImageWriter, TiffImageReader
from PartSegImage import Channel, ChannelInfo, ChannelInfoFull, Image, ImageWriter, TiffImageReader
from PartSegImage.image import FRAME_THICKNESS, _hex_to_rgb, _name_to_rgb


Expand Down Expand Up @@ -728,3 +729,44 @@ def test_name_to_rgb_vispy():
# This test check mapping not defined in fallback dictionary
pytest.importorskip("vispy", reason="vispy not installed")
assert _name_to_rgb("lime") == (0, 255, 0)


class TestChannelInfoFull:
RED_DEF = (255, 0, 0)
RED_2D_REF = [[0, 255], [0, 0], [0, 0]]

def test_str(self):
obj = ChannelInfoFull(name="test", color_map="gray", contrast_limits=(0, 20))
assert obj.name == "test"
assert obj.color_map == "gray"
assert obj.contrast_limits == (0, 20)

@pytest.mark.parametrize("color_map", [tuple(RED_DEF), np.array(RED_DEF, dtype=np.uint8), list(RED_DEF)])
def test_1d_colormap(self, color_map):
obj = ChannelInfoFull(name="test", color_map=color_map, contrast_limits=(0, 20))
assert obj.name == "test"
assert np.all(obj.contrast_limits == (0, 20))
assert isinstance(obj.color_map, np.ndarray)
assert obj.color_map.shape == (3,)
npt.assert_array_equal(obj.color_map, self.RED_DEF)

@pytest.mark.parametrize("color_map", [tuple(RED_2D_REF), np.array(RED_2D_REF, dtype=np.uint8), list(RED_2D_REF)])
def test_2d_colormap(self, color_map):
obj = ChannelInfoFull(name="test", color_map=color_map, contrast_limits=(0, 20))
assert obj.name == "test"
assert np.all(obj.contrast_limits == (0, 20))
assert isinstance(obj.color_map, np.ndarray)
assert obj.color_map.shape == (3, 2)
npt.assert_array_equal(obj.color_map, self.RED_2D_REF)

def test_wrong_dtype_colormap(self):
with pytest.raises(ValueError, match="Colormap as array need to be uint8"):
ChannelInfoFull(name="test", color_map=np.array(self.RED_DEF, dtype=np.float32), contrast_limits=(0, 20))

def test_wrong_colors_colormap(self):
with pytest.raises(ValueError, match="Color map need to have 3 or 4 elements"):
ChannelInfoFull(name="test", color_map=np.array([[0, 0], [0, 0]], dtype=np.uint8), contrast_limits=(0, 20))

def test_wrong_dimesnsion_colormap(self):
with pytest.raises(ValueError, match="Colormap as sequence need to"):
ChannelInfoFull(name="test", color_map=np.array([[[0]], [[0]]], dtype=np.uint8), contrast_limits=(0, 20))
33 changes: 28 additions & 5 deletions package/tests/test_PartSegImage/test_image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ def test_imagej_write_all_metadata(tmp_path, data_test_dir):
npt.assert_array_equal(image2.default_coloring, image.get_imagej_colors())


def test_imagej_save_color(tmp_path):
data = np.zeros((4, 20, 20), dtype=np.uint8)
@pytest.fixture
def data_to_save():
"""Create image with 5 channels with different color definitions"""
data = np.zeros((5, 20, 20), dtype=np.uint8)
data[:, 2:-2, 2:-2] = 20
img = Image(
data,
Expand All @@ -109,18 +111,39 @@ def test_imagej_save_color(tmp_path):
ChannelInfo(name="ch2", color_map="#FFAA00", contrast_limits=(0, 30)),
ChannelInfo(name="ch3", color_map="#FB1", contrast_limits=(0, 25)),
ChannelInfo(name="ch4", color_map=(0, 180, 0), contrast_limits=(0, 22)),
ChannelInfo(
name="ch5",
color_map=np.linspace((0, 0, 0), (128, 255, 0), num=256, dtype=np.uint8).T,
contrast_limits=(0, 20),
),
],
)
assert img.get_colors()[:3] == ["blue", "#FFAA00", "#FB1"]
assert tuple(img.get_colors()[3]) == (0, 180, 0)
IMAGEJImageWriter.save(img, tmp_path / "image.tif")
return img


def test_imagej_save_color(tmp_path, data_to_save):
IMAGEJImageWriter.save(data_to_save, tmp_path / "image.tif")
image2 = TiffImageReader.read_image(tmp_path / "image.tif")
assert image2.channel_names == ["ch1", "ch2", "ch3", "ch4"]
assert image2.ranges == [(0, 20), (0, 30), (0, 25), (0, 22)]
assert image2.channel_names == ["ch1", "ch2", "ch3", "ch4", "ch5"]
npt.assert_array_equal(image2.ranges, [(0, 20), (0, 30), (0, 25), (0, 22), (0, 20)])
assert tuple(image2.default_coloring[0][:, -1]) == (0, 0, 255)
assert tuple(image2.default_coloring[1][:, -1]) == (255, 170, 0)
assert tuple(image2.default_coloring[2][:, -1]) == (255, 187, 17)
assert tuple(image2.default_coloring[3][:, -1]) == (0, 180, 0)
assert tuple(image2.default_coloring[4][:, -1]) == (128, 255, 0)


def test_ome_save_color(tmp_path, data_to_save):
ImageWriter.save(data_to_save, tmp_path / "image.tif")
image2 = TiffImageReader.read_image(tmp_path / "image.tif")
assert image2.channel_names == ["ch1", "ch2", "ch3", "ch4", "ch5"]
assert tuple(image2.default_coloring[0]) == (0, 0, 255)
assert tuple(image2.default_coloring[1]) == (255, 170, 0)
assert tuple(image2.default_coloring[2]) == (255, 187, 17)
assert tuple(image2.default_coloring[3]) == (0, 180, 0)
assert tuple(image2.default_coloring[4]) == (255, 0, 255) # fallback to magenta


def test_save_mask_imagej(tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
# Minimum requirements for the build system to execute.
requires = ["setuptools>=61.2.0", "setuptools_scm[toml]>=8"] # PEP 508 specifications.
requires = ["setuptools>=77.0.0", "setuptools_scm[toml]>=8"] # PEP 508 specifications.
build-backend = "setuptools.build_meta"

[tool.setuptools_scm]
Expand Down
Loading