Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ def render_labels(
scale=param_values["scale"],
table_name=param_values["table_name"],
table_layer=param_values["table_layer"],
transfunc=kwargs.get("transfunc"),
zorder=n_steps,
colorbar=param_values["colorbar"],
colorbar_params=param_values["colorbar_params"],
Expand Down
6 changes: 5 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _render_shapes(
sdata_filt[element] = shapes

# color_source_vector is None when the values aren't categorical
if values_are_categorical and render_params.transfunc is not None:
if not values_are_categorical and render_params.transfunc is not None:
color_vector = render_params.transfunc(color_vector)

norm = copy(render_params.cmap_params.norm)
Expand Down Expand Up @@ -1702,6 +1702,10 @@ def _render_labels(
if isinstance(color_vector.dtype, pd.CategoricalDtype):
color_vector = color_vector.remove_unused_categories()

# color_source_vector is None when the values aren't categorical
if color_source_vector is None and render_params.transfunc is not None:
color_vector = render_params.transfunc(color_vector)

def _draw_labels(
seg_erosionpx: int | None,
seg_boundaries: bool,
Expand Down
1 change: 1 addition & 0 deletions src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ class LabelsRenderParams:
scale: str | None = None
table_name: str | None = None
table_layer: str | None = None
transfunc: Callable[[float], float] | None = None
zorder: int = 0
colorbar: bool | str | None = "auto"
colorbar_params: dict[str, object] | None = None
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
cmap=_viridis_with_under_over(),
).pl.show()

def test_plot_transfunc_applied_to_continuous_labels(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=lambda x: x * 100).pl.show(
title="transfunc: x * 100"
)

def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
sdata_blobs["table"].layers["normalized"] = get_standard_RNG().random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()
Expand Down Expand Up @@ -454,3 +459,17 @@ def test_groups_warns_when_no_groups_match_labels(sdata_blobs: SpatialData, capl
sdata_blobs.pl.render_labels(
labels_name, color="cat", groups=["nonexistent"], table_name="label_table", na_color=None
).pl.show()


def test_transfunc_is_applied_for_continuous_labels(sdata_blobs: SpatialData):
called = []

def track(x):
called.append(True)
return x

fig, ax = plt.subplots()
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", transfunc=track).pl.show(ax=ax)
plt.close(fig)

assert called, "transfunc was not called for continuous labels data"
19 changes: 19 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,11 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs_shapes_annotated
element="blobs_polygons", color="value", norm=Normalize(2, 4, clip=False), cmap=_viridis_with_under_over()
).pl.show()

def test_plot_transfunc_applied_to_continuous_shapes(self, sdata_blobs_shapes_annotated: SpatialData):
sdata_blobs_shapes_annotated.pl.render_shapes(
element="blobs_polygons", color="value", transfunc=lambda x: x * 100
).pl.show(title="transfunc: x * 100")

def test_plot_datashader_can_color_with_norm_and_clipping(self, sdata_blobs_shapes_annotated: SpatialData):
sdata_blobs_shapes_annotated.pl.render_shapes(
element="blobs_polygons",
Expand Down Expand Up @@ -1310,3 +1315,17 @@ def test_datashader_na_color_nan_overlay(sdata_blobs: SpatialData, na_color: str
f"Expected {expected_images} image(s), got {len(ax.get_images())} for na_color={na_color!r}"
)
plt.close(fig)


def test_transfunc_is_applied_for_continuous_shapes(sdata_blobs_shapes_annotated: SpatialData):
called = []

def track(x):
called.append(True)
return x

fig, ax = plt.subplots()
sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="value", transfunc=track).pl.show(ax=ax)
plt.close(fig)

assert called, "transfunc was not called for continuous shapes data"
Loading