From 9751c31c20ae68c0f58414c1cfc09227d9046220 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 29 May 2026 22:25:25 +0800 Subject: [PATCH] Add a benchmark for plotting 2-D xarray with cartopy --- benchmarks/bench_cartopy_xarray.py | 151 +++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 benchmarks/bench_cartopy_xarray.py diff --git a/benchmarks/bench_cartopy_xarray.py b/benchmarks/bench_cartopy_xarray.py new file mode 100644 index 0000000..c0fbca9 --- /dev/null +++ b/benchmarks/bench_cartopy_xarray.py @@ -0,0 +1,151 @@ +""" +Benchmark PyGMT and cartopy when plotting 2-D Earth relief grids. +""" + +import statistics +import time +from pathlib import Path + +import cartopy.crs as ccrs +import pygmt + +import matplotlib.pyplot as plt # noqa: E402 + + +OUTPUT_DIR = Path("plots/earth-relief") +REPEATS = 10 +GRID_SPACINGS = ("01d",) +REGION = [-180, 180, -90, 90] +CARTOPY_CMAP = "gist_earth" +PYGMT_CMAP = "geo" + + +def load_earth_relief_grid(spacing: str): + """Load an Earth relief grid before benchmark timing starts.""" + return pygmt.datasets.load_earth_relief( + resolution=spacing, + region=REGION, + registration="gridline", + ) + + +def plot_cartopy(grid): + """Create a 2-D Earth relief plot with cartopy.""" + fig = plt.figure(figsize=(6, 4), dpi=300) + ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson()) + ax.set_global() + # Use mpl's imshow. + ax.imshow( + grid.values, origin="lower", transform=ccrs.PlateCarree(), cmap=CARTOPY_CMAP + ) + # Use xarray's built-in plotting method. + # grid.plot(ax=ax, transform=ccrs.PlateCarree(), cmap=CARTOPY_CMAP) + return fig + + +def save_cartopy(fig, output: Path) -> None: + """Save a cartopy/matplotlib figure and release it.""" + fig.savefig(output) + plt.close(fig) + + +def plot_pygmt(grid) -> pygmt.Figure: + """Create a 2-D Earth relief plot with PyGMT.""" + fig = pygmt.Figure() + fig.grdimage( + grid=grid, + projection="R6i", + region=REGION, + cmap=PYGMT_CMAP, + frame="afg", + ) + return fig + + +def save_pygmt(fig: pygmt.Figure, output: Path) -> None: + """Save a PyGMT figure.""" + fig.savefig(output) + + +def benchmark( + name: str, + plot_func, + save_func, + grid, + output_dir: Path, + repeats: int, +) -> tuple[list[float], list[float]]: + """Time repeated relief plot creation and figure export runs.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # Warm up each backend once before recording timings. + fig = plot_func(grid) + save_func(fig, output_dir / f"{name}_warmup.png") + + plot_timings = [] + save_timings = [] + for run_id in range(repeats): + output = output_dir / f"{name}_{run_id + 1}.png" + + start = time.perf_counter() + fig = plot_func(grid) + plot_timings.append(time.perf_counter() - start) + + start = time.perf_counter() + save_func(fig, output) + save_timings.append(time.perf_counter() - start) + + return plot_timings, save_timings + + +def format_summary(name: str, timings: list[float]) -> str: + """Format benchmark timing statistics.""" + mean = statistics.fmean(timings) + median = statistics.median(timings) + minimum = min(timings) + maximum = max(timings) + return ( + f"{name:10s} " + f"mean={mean:.4f}s " + f"median={median:.4f}s " + f"min={minimum:.4f}s " + f"max={maximum:.4f}s" + ) + + +def main() -> None: + """Run the Earth relief plotting benchmark.""" + print(f"Running {REPEATS} timed run(s) per backend") + print(f"Writing PNG files to {OUTPUT_DIR}") + + for spacing in GRID_SPACINGS: + print(f"Grid spacing: {spacing}") + grid = load_earth_relief_grid(spacing) + + print("Benchmarking cartopy...", flush=True) + plot_timings, save_timings = benchmark( + name=f"cartopy_{spacing}", + plot_func=plot_cartopy, + save_func=save_cartopy, + grid=grid, + output_dir=OUTPUT_DIR, + repeats=REPEATS, + ) + print(format_summary("cartopy plot", plot_timings)) + print(format_summary("cartopy savefig", save_timings)) + + print("Benchmarking pygmt...", flush=True) + plot_timings, save_timings = benchmark( + name=f"pygmt_{spacing}", + plot_func=plot_pygmt, + save_func=save_pygmt, + grid=grid, + output_dir=OUTPUT_DIR, + repeats=REPEATS, + ) + print(format_summary("pygmt plot", plot_timings)) + print(format_summary("pygmt savefig", save_timings)) + + +if __name__ == "__main__": + main()