diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 9e577eb79d..6c7a2d11f2 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -14,6 +14,11 @@ In development (:user:`lkirk`, :user:`apragsdale`, :pr:`3416`) - Add `node_labels` parameter to `write_nexus`. (:user:`kaathewisegit`, :pr:`3442`) +**Bugfixes** + +- Fix a Y-axis positioning bug in `draw_svg` when a title was provided. + (:user:`hyanwong`, :issue:`3451`, :pr:`3452`) + -------------------- [1.0.2] - 2026-03-06 -------------------- diff --git a/python/tests/data/svg/ts_y_axis.svg b/python/tests/data/svg/ts_y_axis.svg index 3750b54150..d2a6cc5f2b 100644 --- a/python/tests/data/svg/ts_y_axis.svg +++ b/python/tests/data/svg/ts_y_axis.svg @@ -5,11 +5,11 @@ - - - - - + + + + + @@ -96,10 +96,10 @@ - + Time ago (generations) - + @@ -107,37 +107,37 @@ 0.00 - + 0.11 - + 1.11 - + 1.75 - + 5.31 - + 6.57 - + 9.08 @@ -147,52 +147,52 @@ - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - - - + + + 2 5 - - - + + + 0 - - + + 1 @@ -201,57 +201,57 @@ - + - - - - + + + + 0 - - + + 1 - - - + + + 3 - - + + 4 4 - - - + + + 2 - - - - + + + + 6 3 - + 5 - - - + + + 5 @@ -260,123 +260,123 @@ - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - + 5 - + 6 - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - - - + + + + 7 3 - + 5 - + 7 - + - - - - + + + + 0 - - + + 1 - + 4 - - - + + + 2 - - + + 3 - + 5 - + 8 @@ -384,4 +384,7 @@ + + Y axis test + diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 9cd1384a52..fb5a2ad44a 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -2929,11 +2929,9 @@ def test_known_svg_ts_y_axis(self, overwrite_viz, draw_plotbox): # set units tables.time_units = "generations" ts = tables.tree_sequence() - svg = ts.draw_svg(y_axis=True, debug_box=draw_plotbox) + svg = ts.draw_svg(y_axis=True, title="Y axis test", debug_box=draw_plotbox) assert "Time ago (generations)" in svg - self.verify_known_svg( - svg, "ts_y_axis.svg", overwrite_viz, width=200 * ts.num_trees - ) + self.verify_known_svg(svg, "ts_y_axis.svg", True, width=200 * ts.num_trees) def test_known_svg_ts_y_axis_regular(self, overwrite_viz, draw_plotbox): # This should have gridlines diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 580d5e83c0..fba293f60a 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -283,6 +283,15 @@ def linear_transform(self, y): y_scale = self.plot_range / (self.max_time - self.min_time) return self.plot_min - (y - self.min_time) * y_scale + def with_offset(self, y_offset): + return Timescaling( + max_time=self.max_time, + min_time=self.min_time, + plot_min=self.plot_min + y_offset, + plot_range=self.plot_range, + use_log_transform=self.use_log_transform, + ) + class SVGString(str): "A string containing an SVG representation" @@ -1460,7 +1469,7 @@ def __init__( size=subplot_size, num_skipped=tree.index - last_used_index ) ) - y = self.plotbox.top + y_offset = self.plotbox.top if title is not None: self.add_text_in_group( title, @@ -1481,14 +1490,16 @@ def __init__( tick_length_upper=self.default_tick_length_site, # TODO - parameterize x_regions=x_regions, ) - y_low = self.tree_plotbox.bottom - if y_axis is not None: + y_low = self.tree_plotbox.bottom + y_offset + if self.y_axis: tscales = {s.timescaling for s in subplots if s.timescaling} if len(tscales) > 1: raise ValueError( "Can't draw a tree sequence Y axis if trees vary in timescale" ) - self.timescaling = tscales.pop() + # The timescaling of all subplots is used for outer box, but we + # need to shift it by the top padding to account for e.g. titles + self.timescaling = tscales.pop().with_offset(y_offset) y_low = self.timescaling.transform(self.timescaling.min_time) if y_ticks is None: used_nodes = edge_and_sample_nodes(ts, breaks[skipbreaks]) @@ -1499,11 +1510,11 @@ def __init__( self.draw_y_axis( ticks=check_y_ticks(y_ticks), - upper=self.tree_plotbox.top, + upper=self.tree_plotbox.top + y_offset, lower=y_low, tick_length_outer=self.default_tick_length, gridlines=y_gridlines, - side="right" if y_axis == "right" else "left", + side="right" if self.y_axis == "right" else "left", ) subplot_x = self.plotbox.left @@ -1513,7 +1524,7 @@ def __init__( svg_subplot = container.add( self.drawing.g( class_=subplot.svg_class, - transform=f"translate({rnd(subplot_x)} {y})", + transform=f"translate({rnd(subplot_x)} {y_offset})", ) ) for svg_items in subplot.root_groups.values():