diff --git a/poetry.lock b/poetry.lock index 290f9b82..75530217 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1661,6 +1661,18 @@ test-arrow = ["arro3-compute", "arro3-core", "nanoarrow", "pyarrow"] tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma (>=5)", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "trove-classifiers (>=2024.10.12)"] xmp = ["defusedxml"] +[[package]] +name = "pip" +version = "26.1.2" +description = "The PyPA recommended tool for installing Python packages." +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pip-26.1.2-py3-none-any.whl", hash = "sha256:382ff9f685ee3bc25864f820aa50505825f10f5458ffff07e30a6d96e5715cab"}, + {file = "pip-26.1.2.tar.gz", hash = "sha256:f49cd134c61cf2fd75e0ce2676db03e4054504a5a4986d00f8299ae632dc4605"}, +] + [[package]] name = "plotly" version = "6.7.0" @@ -2563,4 +2575,4 @@ plot = ["matplotlib", "plotly"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.15" -content-hash = "7fe4b5b43ccc91f2385df603af3105e77acf96969197e63dea24199857e93847" +content-hash = "17044c3f2505450a7d70bdcbeca05ea7e52da2823e0297ab3e80f40b8429d896" diff --git a/prek.toml b/prek.toml index b70b8a27..7d9dc3de 100644 --- a/prek.toml +++ b/prek.toml @@ -36,7 +36,7 @@ id = "check-added-large-files" repo = "https://github.com/astral-sh/ruff-pre-commit" rev = "v0.15.14" [[repos.hooks]] -id = "ruff" +id = "ruff-check" args = ["--fix", "--exit-non-zero-on-fix", "--ignore=C901"] [[repos.hooks]] diff --git a/pyproject.toml b/pyproject.toml index a1089df1..d1937ee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dependencies = [ "flatdict (==4.1.0)", "pytest (>=9.0.3,<10.0.0)", "unyt (>=3.1.0,<4.0.0)", + "pip>=26.1.2", ] [project.urls] @@ -88,9 +89,62 @@ dev = [ ] [tool.ruff] -lint.extend-select = ["C901", "T201"] -lint.mccabe.max-complexity = 11 -extend-exclude = ["tests", "examples", "notebooks"] +extend-exclude = ["tests", "examples", "notebooks", "simvue/pynvml.py"] +preview = true + +[tool.ruff.lint] +ignore = ["COM812", "PLR0904", "PLR6301", "D203", "D212"] +dummy-variable-rgx = "^_+$" +extend-select = [ + "C90", + "D417", + "E", + "F", + "W", + "B", + "UP", + "I", + "SIM", + "ARG", + "RUF", + "RET", + "S", + "BLE", + "FBT", + "COM", + "A", + "C4", + "DTZ", + "ICN", + "G", + "PIE", + "T20", + "INP", + "PYI", + "PT", + "Q", + "RSE", + "SLOT", + "TID", + "TC", + "ARG", + "I", + "D417", + "ERA", + "PL", + "UP", + "FURB", + "N", +] + +[tool.ruff.lint.mccabe] +max-complexity = 14 + +[tool.ruff.lint.pylint] +max-returns = 8 +max-args = 20 +max-branches = 15 +max-statements = 60 [tool.pytest.ini_options] addopts = "-p no:warnings --no-cov -n 0" diff --git a/simvue/api/__init__.py b/simvue/api/__init__.py index f56eb001..bc6a837d 100644 --- a/simvue/api/__init__.py +++ b/simvue/api/__init__.py @@ -1,5 +1,4 @@ -""" -Simvue API +"""Simvue API. ========== Module contains methods for interacting with a Simvue server diff --git a/simvue/api/objects/__init__.py b/simvue/api/objects/__init__.py index 78a99131..48cc0a71 100644 --- a/simvue/api/objects/__init__.py +++ b/simvue/api/objects/__init__.py @@ -1,5 +1,4 @@ -""" -Simvue API Objects +"""Simvue API Objects. ================== The following module defines objects which provide exact representations @@ -12,50 +11,49 @@ from .alert import ( Alert, EventsAlert, - MetricsThresholdAlert, MetricsRangeAlert, + MetricsThresholdAlert, UserAlert, ) -from .storage import ( - S3Storage, - FileStorage, - Storage, -) from .artifact import ( + Artifact, FileArtifact, ObjectArtifact, - Artifact, ) - -from .stats import Stats -from .run import Run -from .tag import Tag -from .folder import Folder, get_folder_from_path from .events import Events as Events -from .metrics import Metrics as Metrics +from .folder import Folder, get_folder_from_path from .grids import Grid, GridMetrics +from .metrics import Metrics as Metrics +from .run import Run +from .stats import Stats +from .storage import ( + FileStorage, + S3Storage, + Storage, +) +from .tag import Tag __all__ = [ + "Alert", + "Artifact", + "Events", + "EventsAlert", + "FileArtifact", + "FileStorage", + "Folder", "Grid", "GridMetrics", "Metrics", - "Events", - "get_folder_from_path", - "Folder", - "Stats", - "Run", - "Tag", - "Artifact", - "FileArtifact", + "MetricsRangeAlert", + "MetricsThresholdAlert", "ObjectArtifact", + "Run", "S3Storage", - "FileStorage", + "Stats", "Storage", - "MetricsRangeAlert", - "MetricsThresholdAlert", - "UserAlert", - "EventsAlert", - "Alert", + "Tag", "Tenant", "User", + "UserAlert", + "get_folder_from_path", ] diff --git a/simvue/api/objects/administrator/tenant.py b/simvue/api/objects/administrator/tenant.py index 61c4b7fc..3b1c407d 100644 --- a/simvue/api/objects/administrator/tenant.py +++ b/simvue/api/objects/administrator/tenant.py @@ -9,11 +9,14 @@ from typing import Self, override except ImportError: from typing_extensions import Self, override + +import datetime +import typing from collections.abc import Generator + import pydantic -import datetime -from simvue.api.objects.base import write_only, SimvueObject, staging_check +from simvue.api.objects.base import SimvueObject, staging_check, write_only from simvue.models import DATETIME_FORMAT @@ -33,7 +36,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Tenant + """Initialise a Tenant. If an identifier is provided a connection will be made to the object matching the identifier on the target server. @@ -51,9 +54,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @override @@ -123,7 +130,7 @@ def get( offset: pydantic.NonNegativeInt | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, Self | None]]: """Retrieve tenants from the server. @@ -137,6 +144,8 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ @@ -146,86 +155,95 @@ def get( Returns ------- Generator[tuple[str, Tenant | None]] + """ # Currently no tenant filters _ = kwargs.pop("filters", None) return super().get( - count=count, offset=offset, server_url=server_url, server_token=server_token + count=count, + offset=offset, + server_url=server_url, + server_token=server_token, ) @property def name(self) -> str: - """Retrieve the name of the tenant""" + """Retrieve the name of the tenant.""" return self._get_attribute("name") @name.setter @write_only @pydantic.validate_call def name(self, name: str) -> None: - """Change name of tenant""" + """Change name of tenant.""" self._staging["name"] = name @property @staging_check def is_enabled(self) -> bool: - """Retrieve if tenant is enabled""" + """Retrieve if tenant is enabled.""" return self._get_attribute("is_enabled") @is_enabled.setter @write_only @pydantic.validate_call def is_enabled(self, is_enabled: bool) -> None: - """Enable/disable tenant""" + """Enable/disable tenant.""" self._staging["is_enabled"] = is_enabled @property @staging_check def max_request_rate(self) -> int: - """Retrieve the tenant's maximum request rate""" + """Retrieve the tenant's maximum request rate.""" return self._get_attribute("max_request_rate") @max_request_rate.setter @write_only @pydantic.validate_call def max_request_rate(self, max_request_rate: int) -> None: - """Update tenant's maximum request rate""" + """Update tenant's maximum request rate.""" self._staging["max_request_rate"] = max_request_rate @property @staging_check def max_runs(self) -> int: - """Retrieve the tenant's maximum runs""" + """Retrieve the tenant's maximum runs.""" return self._get_attribute("max_runs") @max_runs.setter @write_only @pydantic.validate_call def max_runs(self, max_runs: int) -> None: - """Update tenant's maximum runs""" + """Update tenant's maximum runs.""" self._staging["max_runs"] = max_runs @property @staging_check def max_data_volume(self) -> int: - """Retrieve the tenant's maximum data volume""" + """Retrieve the tenant's maximum data volume.""" return self._get_attribute("max_data_volume") @max_data_volume.setter @write_only @pydantic.validate_call def max_data_volume(self, max_data_volume: int) -> None: - """Update tenant's maximum data volume""" + """Update tenant's maximum data volume.""" self._staging["max_data_volume"] = max_data_volume @property def created(self) -> datetime.datetime | None: - """Set/retrieve created datetime for the run. + """Set/retrieve created datetime in UTC for the run. Returns ------- datetime.datetime + """ _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + datetime.UTC, + ) + if _created + else None ) diff --git a/simvue/api/objects/administrator/user.py b/simvue/api/objects/administrator/user.py index 272cc7c5..f658a7fb 100644 --- a/simvue/api/objects/administrator/user.py +++ b/simvue/api/objects/administrator/user.py @@ -5,8 +5,10 @@ """ -import pydantic import datetime +import typing + +import pydantic from simvue.models import DATETIME_FORMAT @@ -34,7 +36,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a User + """Initialise a User. If an identifier is provided a connection will be made to the object matching the identifier on the target server. @@ -52,9 +54,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @override @@ -145,7 +151,7 @@ def get( offset: int | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> dict[str, "User"]: """Retrieve users from the Simvue server. @@ -159,11 +165,14 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ User user instance representing user on server + """ # Currently no user filters _ = kwargs.pop("filters", None) @@ -178,7 +187,7 @@ def get( @property @staging_check def username(self) -> str: - """Retrieve the username for the user""" + """Retrieve the username for the user.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["username"] return self._get_attribute("username") @@ -187,13 +196,13 @@ def username(self) -> str: @write_only @pydantic.validate_call def username(self, username: str) -> None: - """Set the username for the user""" + """Set the username for the user.""" self._staging["username"] = username @property @staging_check def fullname(self) -> str: - """Retrieve the full name for the user""" + """Retrieve the full name for the user.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["fullname"] return self._get_attribute("fullname") @@ -202,13 +211,13 @@ def fullname(self) -> str: @write_only @pydantic.validate_call def fullname(self, fullname: str) -> None: - """Set the full name for the user""" + """Set the full name for the user.""" self._staging["fullname"] = fullname @property @staging_check def is_manager(self) -> bool: - """Retrieve if the user has manager privileges""" + """Retrieve if the user has manager privileges.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["is_manager"] return self._get_attribute("is_manager") @@ -217,13 +226,13 @@ def is_manager(self) -> bool: @write_only @pydantic.validate_call def is_manager(self, is_manager: bool) -> None: - """Set if the user has manager privileges""" + """Set if the user has manager privileges.""" self._staging["is_manager"] = is_manager @property @staging_check def is_admin(self) -> bool: - """Retrieve if the user has admin privileges""" + """Retrieve if the user has admin privileges.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["is_admin"] return self._get_attribute("is_admin") @@ -232,12 +241,12 @@ def is_admin(self) -> bool: @write_only @pydantic.validate_call def is_admin(self, is_admin: bool) -> None: - """Set if the user has admin privileges""" + """Set if the user has admin privileges.""" self._staging["is_admin"] = is_admin @property def deleted(self) -> bool: - """Retrieve if the user is pending deletion""" + """Retrieve if the user is pending deletion.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["is_deleted"] return self._get_attribute("is_deleted") @@ -245,7 +254,7 @@ def deleted(self) -> bool: @property @staging_check def is_readonly(self) -> bool: - """Retrieve if the user has read-only access""" + """Retrieve if the user has read-only access.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["is_readonly"] return self._get_attribute("is_readonly") @@ -254,13 +263,13 @@ def is_readonly(self) -> bool: @write_only @pydantic.validate_call def is_readonly(self, is_readonly: bool) -> None: - """Set if the user has read-only access""" + """Set if the user has read-only access.""" self._staging["is_readonly"] = is_readonly @property @staging_check def enabled(self) -> bool: - """Retrieve if the user is enabled""" + """Retrieve if the user is enabled.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["is_enabled"] return self._get_attribute("is_enabled") @@ -269,13 +278,13 @@ def enabled(self) -> bool: @write_only @pydantic.validate_call def enabled(self, is_enabled: bool) -> None: - """Set if the user is enabled""" + """Set if the user is enabled.""" self._staging["is_enabled"] = is_enabled @property @staging_check def email(self) -> str: - """Retrieve the user email""" + """Retrieve the user email.""" if self.id and self.id.startswith("offline_"): return self._get_attribute("user")["email"] return self._get_attribute("email") @@ -284,18 +293,23 @@ def email(self) -> str: @write_only @pydantic.validate_call def email(self, email: str) -> None: - """Set the user email""" + """Set the user email.""" self._staging["email"] = email @property def created(self) -> datetime.datetime | None: - """Set/retrieve created datetime for the run. + """Set/retrieve created datetime in UTC for the run. Returns ------- datetime.datetime + """ _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + datetime.UTC, + ) + if _created + else None ) diff --git a/simvue/api/objects/alert/__init__.py b/simvue/api/objects/alert/__init__.py index 71cabde8..361ad035 100644 --- a/simvue/api/objects/alert/__init__.py +++ b/simvue/api/objects/alert/__init__.py @@ -6,15 +6,15 @@ """ -from .fetch import Alert -from .metrics import MetricsThresholdAlert, MetricsRangeAlert from .events import EventsAlert +from .fetch import Alert +from .metrics import MetricsRangeAlert, MetricsThresholdAlert from .user import UserAlert __all__ = [ "Alert", + "EventsAlert", "MetricsRangeAlert", "MetricsThresholdAlert", - "EventsAlert", "UserAlert", ] diff --git a/simvue/api/objects/alert/base.py b/simvue/api/objects/alert/base.py index 8d1258a0..0286c9d9 100644 --- a/simvue/api/objects/alert/base.py +++ b/simvue/api/objects/alert/base.py @@ -4,23 +4,28 @@ """ -import http -import pydantic import datetime +import http import typing + +import pydantic + from simvue.api.objects.base import SimvueObject, staging_check, write_only -from simvue.api.request import get as sv_get, get_json_from_response -from simvue.api.url import URL -from simvue.models import NAME_REGEX, DATETIME_FORMAT +from simvue.api.request import get as sv_get +from simvue.api.request import get_json_from_response +from simvue.models import DATETIME_FORMAT, NAME_REGEX + +if typing.TYPE_CHECKING: + from simvue.api.url import URL try: from typing import Self, override except ImportError: - from typing_extensions import Self, override # noqa: UP035 + from typing_extensions import Self, override class AlertBase(SimvueObject): - """Class for interfacing with Simvue alerts + """Class for interfacing with Simvue alerts. Contains properties common to all alert types. """ @@ -53,9 +58,9 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Retrieve an alert from the Simvue server by identifier""" + """Retrieve an alert from the Simvue server by identifier.""" _params: dict[str, str | bool] = kwargs.pop("_params", {}) | { - "deduplicate": not kwargs.get("allow_duplicates", True) + "deduplicate": not kwargs.get("allow_duplicates", True), } super().__init__( identifier=identifier, @@ -78,38 +83,41 @@ def _compare_objects(self, other: "AlertBase") -> bool: self.description == other.description, self.source == other.source, self.notification == other.notification, - ] + ], ) @override def __eq__(self, other: "AlertBase") -> bool: """Check if alerts are the same.""" - # Need to ensure objects are read-only for this # operation as we do not want staging to alter _self_is_read_only: bool = self._read_only _other_is_read_only: bool = other._read_only - self.read_only(True) - other.read_only(True) + self.read_only(is_read_only=True) + other.read_only(is_read_only=True) _comparison = self._compare_objects(other) # Restore to write allowed unless the input object # was read-only to begin with if not _self_is_read_only: - self.read_only(False, clear_staged=False) + self.read_only(is_read_only=False, clear_staged=False) if not _other_is_read_only: - other.read_only(False, clear_staged=False) + other.read_only(is_read_only=False, clear_staged=False) return _comparison + @override + def __hash__(self) -> int: + return hash(f"{self.name}+{self.description}+{self.source}+{self.notification}") + def compare(self, other: "AlertBase") -> bool: - """Compare this alert to another""" + """Compare this alert to another.""" return type(self) is type(other) and self.name == other.name @staging_check def get_alert(self) -> dict[str, typing.Any]: - """Retrieve alert definition""" + """Retrieve alert definition.""" try: return self._get_attribute("alert") except AttributeError: @@ -117,121 +125,127 @@ def get_alert(self) -> dict[str, typing.Any]: @property def name(self) -> str: - """Retrieve alert name""" + """Retrieve alert name.""" return self._get_attribute("name") @name.setter @write_only @pydantic.validate_call def name( - self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] + self, + name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)], ) -> None: - """Set alert name""" + """Set alert name.""" self._staging["name"] = name @property @staging_check def description(self) -> str | None: - """Retrieve alert description""" + """Retrieve alert description.""" return self._get_attribute("description") @description.setter @write_only @pydantic.validate_call def description(self, description: str | None) -> None: - """Set alert description""" + """Set alert description.""" self._staging["description"] = description @property def run_tags(self) -> list[str]: - """Retrieve automatically assigned tags from runs""" + """Retrieve automatically assigned tags from runs.""" return self._get_attribute("run_tags") @property @staging_check def auto(self) -> bool: - """Retrieve if alert has run tag auto-assign""" + """Retrieve if alert has run tag auto-assign.""" return self._get_attribute("auto") @auto.setter @write_only @pydantic.validate_call def auto(self, auto: bool) -> None: - """Set alert to use run tag auto-assign""" + """Set alert to use run tag auto-assign.""" self._staging["auto"] = auto @property @staging_check def notification(self) -> typing.Literal["none", "email"]: - """Retrieve alert notification setting""" + """Retrieve alert notification setting.""" return self._get_attribute("notification") @notification.setter @write_only @pydantic.validate_call def notification(self, notification: typing.Literal["none", "email"]) -> None: - """Configure alert notification setting""" + """Configure alert notification setting.""" self._staging["notification"] = notification @property def source(self) -> typing.Literal["events", "metrics", "user"]: - """Retrieve alert source""" + """Retrieve alert source.""" return self._get_attribute("source") @property @staging_check def enabled(self) -> bool: - """Retrieve if alert is enabled""" + """Retrieve if alert is enabled.""" return self._get_attribute("enabled") @enabled.setter @write_only @pydantic.validate_call def enabled(self, enabled: str) -> None: - """Enable/disable alert""" + """Enable/disable alert.""" self._staging["enabled"] = enabled @property @staging_check def abort(self) -> bool: - """Retrieve if alert can abort simulations""" + """Retrieve if alert can abort simulations.""" return self._get_attribute("abort") @property @staging_check def delay(self) -> int: - """Retrieve delay value for this alert""" + """Retrieve delay value for this alert.""" return self._get_attribute("delay") @property def created(self) -> datetime.datetime | None: - """Retrieve created datetime for the alert""" + """Retrieve created datetime in UTC for the alert.""" _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + datetime.UTC, + ) + if _created + else None ) @abort.setter @write_only @pydantic.validate_call def abort(self, abort: bool) -> None: - """Configure alert to trigger aborts""" + """Configure alert to trigger aborts.""" self._staging["abort"] = abort @pydantic.validate_call - def set_status(self, run_id: str, status: typing.Literal["ok", "critical"]) -> None: - """Set the status of this alert for a given run""" + def set_status(self, _: str, __: typing.Literal["ok", "critical"]) -> None: + """Set the status of this alert for a given run.""" raise AttributeError( - f"Cannot update state for alert of type '{self.__class__.__name__}'" + f"Cannot update state for alert of type '{self.__class__.__name__}'", ) def get_status(self, run_id: str) -> typing.Literal["ok", "critical"]: - """Retrieve the status of this alert for a given run""" + """Retrieve the status of this alert for a given run.""" _offline_run: bool = run_id.startswith("offline") if not self._offline and run_id.startswith("offline"): raise ValueError( - f"Cannot retrieve status of online alert '{self.id}' for offline run '{run_id}'" + f"Cannot retrieve status of online alert '{self.id}' " + f"for offline run '{run_id}'", ) _url: URL = self.url / f"status/{run_id}" diff --git a/simvue/api/objects/alert/events.py b/simvue/api/objects/alert/events.py index f4a504d3..38619dae 100644 --- a/simvue/api/objects/alert/events.py +++ b/simvue/api/objects/alert/events.py @@ -5,25 +5,27 @@ """ import typing -import pydantic - from collections.abc import Generator +import pydantic + try: from typing import Self, override except ImportError: from typing_extensions import Self, override -from simvue.api.objects.base import write_only, staging_check -from .base import AlertBase +from simvue.api.objects.base import staging_check, write_only from simvue.models import NAME_REGEX +from .base import AlertBase + class EventsAlert(AlertBase): """Simvue Events Alert. - This class is used to connect to/create event-based alert objects on the Simvue server, - any modification of EventsAlert instance attributes is mirrored on the remote object. + This class is used to connect to/create event-based alert objects + on the Simvue server, any modification of EventsAlert instance + attributes is mirrored on the remote object. """ @@ -36,11 +38,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise an Events Alert + """Initialise an Events Alert. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new EventsAlert instance will be created using arguments provided in kwargs. + Else a new EventsAlert instance will be created using + arguments provided in kwargs. Parameters ---------- @@ -52,10 +55,14 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ self.alert = EventAlertDefinition(self) super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @override @@ -68,7 +75,7 @@ def get( server_url: str | None = None, server_token: pydantic.SecretStr | None = None, ) -> Generator[dict[str, typing.Any]]: - """Retrieve only alerts of the event alert type""" + """Retrieve only alerts of the event alert type.""" raise NotImplementedError("Retrieval of only event alerts is not yet supported") @classmethod @@ -87,7 +94,7 @@ def new( server_token: pydantic.SecretStr | None = None, **_, ) -> Self: - """Create a new event-based alert + """Create a new event-based alert. Note parameters are keyword arguments only. @@ -118,7 +125,6 @@ def new( a new event alert with changes staged """ - _alert_definition = {"pattern": pattern, "frequency": frequency} _alert = cls( name=name, @@ -145,46 +151,49 @@ def _compare_objects(self, other: "AlertBase") -> bool: class EventAlertDefinition: - """Event alert definition sub-class""" + """Event alert definition sub-class.""" def __init__(self, alert: EventsAlert) -> None: - """Initialise an alert definition with its parent alert""" + """Initialise an alert definition with its parent alert.""" self._sv_obj = alert def __eq__(self, other: "EventAlertDefinition") -> bool: - """Compare this definition with that of another EventAlert""" + """Compare this definition with that of another EventAlert.""" return all( [ self.frequency == other.frequency, self.pattern == other.pattern, - ] + ], ) + def __hash__(self) -> int: + return hash(f"{self.frequency}+{self.pattern}") + @property def pattern(self) -> str: - """Retrieve the event log pattern monitored by this alert""" + """Retrieve the event log pattern monitored by this alert.""" try: return self._sv_obj.get_alert()["pattern"] except KeyError as e: raise RuntimeError( - "Expected key 'pattern' in alert definition retrieval" + "Expected key 'pattern' in alert definition retrieval", ) from e @property @staging_check def frequency(self) -> int: - """Retrieve the update frequency for this alert""" + """Retrieve the update frequency for this alert.""" try: return self._sv_obj.get_alert()["frequency"] except KeyError as e: raise RuntimeError( - "Expected key 'frequency' in alert definition retrieval" + "Expected key 'frequency' in alert definition retrieval", ) from e @frequency.setter @write_only @pydantic.validate_call def frequency(self, frequency: int) -> None: - """Set the update frequency for this alert""" + """Set the update frequency for this alert.""" _alert = self._sv_obj.get_alert() | {"frequency": frequency} - self._sv_obj._staging["alert"] = _alert + self._sv_obj.staging["alert"] = _alert diff --git a/simvue/api/objects/alert/fetch.py b/simvue/api/objects/alert/fetch.py index 2d6d503f..d68bbbbe 100644 --- a/simvue/api/objects/alert/fetch.py +++ b/simvue/api/objects/alert/fetch.py @@ -6,6 +6,7 @@ import http import json +import typing import pydantic @@ -18,13 +19,15 @@ from typing_extensions import override from collections.abc import Generator + from simvue.api.objects.alert.user import UserAlert from simvue.api.objects.base import Sort -from simvue.api.request import get_json_from_response from simvue.api.request import get as sv_get -from .events import EventsAlert -from .metrics import MetricsThresholdAlert, MetricsRangeAlert +from simvue.api.request import get_json_from_response + from .base import AlertBase +from .events import EventsAlert +from .metrics import MetricsRangeAlert, MetricsThresholdAlert AlertType = EventsAlert | UserAlert | MetricsThresholdAlert | MetricsRangeAlert @@ -33,7 +36,7 @@ class AlertSort(Sort): @pydantic.field_validator("column") @classmethod def check_column(cls, column: str) -> str: - if column and column not in ("name", "created"): + if column and column not in {"name", "created"}: raise ValueError(f"Invalid sort column for alerts '{column}'") return column @@ -53,7 +56,7 @@ def __new__( *, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> AlertType: """Retrieve an object representing an alert on the server by id. @@ -65,11 +68,14 @@ def __new__( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional retrieval arguments Returns ------- MetricsThresholdAlert | MetricRangeAlert | UserAlert | EventsAlert object representing an alert + """ _alert_pre = AlertBase( identifier=identifier, @@ -80,12 +86,14 @@ def __new__( if ( identifier is not None and identifier.startswith("offline_") - and not _alert_pre._staging.get("source", None) + and not _alert_pre.staging.get("source", None) ): raise RuntimeError( - "Cannot determine Alert type - this is likely because you are attempting to reconnect " - + "to an offline alert which has already been sent to the server. To fix this, use the " - + "exact Alert type instead (eg MetricThresholdAlert, MetricRangeAlert etc)." + "Cannot determine Alert type - this is likely because you " + "are attempting to reconnect to an offline alert which " + "has already been sent to the server. To fix this, use the " + "exact Alert type instead " + "(eg MetricThresholdAlert, MetricRangeAlert etc).", ) if _alert_pre.source == "events": return EventsAlert( @@ -94,21 +102,21 @@ def __new__( server_token=server_token, **kwargs, ) - elif _alert_pre.source == "metrics" and _alert_pre.get_alert().get("threshold"): + if _alert_pre.source == "metrics" and _alert_pre.get_alert().get("threshold"): return MetricsThresholdAlert( identifier=identifier, server_url=server_url, server_token=server_token, **kwargs, ) - elif _alert_pre.source == "metrics": + if _alert_pre.source == "metrics": return MetricsRangeAlert( identifier=identifier, server_url=server_url, server_token=server_token, **kwargs, ) - elif _alert_pre.source == "user": + if _alert_pre.source == "user": return UserAlert( identifier=identifier, server_url=server_url, @@ -128,7 +136,7 @@ def get( sorting: list[AlertSort] | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, AlertType]]: """Fetch all alerts from the server for the current user. @@ -144,19 +152,23 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional retrieval arguments Yields ------ tuple[str, AlertType] identifier for an alert the alert itself as a class instance - """ + """ # Currently no alert filters _ = kwargs.pop("filters", None) _config: SimvueConfiguration = SimvueConfiguration.fetch( - mode="online", server_url=server_url, server_token=server_token + mode="online", + server_url=server_url, + server_token=server_token, ) _url = URL(f"{_config.server.url}") / AlertBase.endpoint() @@ -202,7 +214,10 @@ def get( yield ( _id, MetricsThresholdAlert( - _local=True, _read_only=True, identifier=_id, **_entry + _local=True, + _read_only=True, + identifier=_id, + **_entry, ), ) elif ( @@ -212,10 +227,14 @@ def get( yield ( _id, MetricsRangeAlert( - _local=True, _read_only=True, identifier=_id, **_entry + _local=True, + _read_only=True, + identifier=_id, + **_entry, ), ) else: raise RuntimeError( - f"Unrecognised alert source '{_entry['source']}' with data '{_entry}'" + f"Unrecognised alert source '{_entry['source']}' " + f"with data '{_entry}'", ) diff --git a/simvue/api/objects/alert/metrics.py b/simvue/api/objects/alert/metrics.py index 43ce4843..4611e947 100644 --- a/simvue/api/objects/alert/metrics.py +++ b/simvue/api/objects/alert/metrics.py @@ -5,32 +5,35 @@ """ -import pydantic import typing +import pydantic + try: from typing import Self except ImportError: from typing_extensions import Self from simvue.api.objects.base import write_only -from .base import AlertBase, staging_check from simvue.models import NAME_REGEX +from .base import AlertBase, staging_check + Aggregate = typing.Literal["average", "sum", "at least one", "all"] Rule = typing.Literal["is above", "is below", "is inside range", "is outside range"] try: from typing import override except ImportError: - from typing_extensions import override # noqa: UP035 + from typing_extensions import override class MetricsThresholdAlert(AlertBase): """Simvue Metrics Threshold Alert. - This class is used to connect to/create metrics threshold alert objects on the Simvue server, - any modification of MetricsThresholdAlert instance attributes is mirrored on the remote object. + This class is used to connect to/create metrics threshold alert + objects on the Simvue server, any modification of MetricsThresholdAlert + instance attributes is mirrored on the remote object. """ @@ -42,11 +45,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Metrics Threshold Alert + """Initialise a Metrics Threshold Alert. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new MetricsThresholdAlert instance will be created using arguments provided in kwargs. + Else a new MetricsThresholdAlert instance will be created using + arguments provided in kwargs. Parameters ---------- @@ -58,10 +62,14 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ self.alert = MetricThresholdAlertDefinition(self) super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._local_only_args += [ "rule", @@ -81,7 +89,7 @@ def get( server_token: pydantic.SecretStr | None = None, **_, ) -> dict[str, typing.Any]: - """Retrieve only MetricsThresholdAlerts""" + """Retrieve only MetricsThresholdAlerts.""" raise NotImplementedError("Retrieve of only metric alerts is not yet supported") @override @@ -97,7 +105,7 @@ def new( aggregation: Aggregate, rule: typing.Literal["is above", "is below"], window: pydantic.PositiveInt, - threshold: float | int, + threshold: float, frequency: pydantic.PositiveInt, enabled: bool = True, offline: bool = False, @@ -105,7 +113,7 @@ def new( server_token: pydantic.SecretStr | None = None, **_, ) -> Self: - """Create a new metric threshold alert either locally or on the server + """Create a new metric threshold alert either locally or on the server. Note all arguments are keyword arguments. @@ -142,6 +150,7 @@ def new( ------- MetricsThresholdAlert object representing a metric threshold alert + """ _alert_definition = { "rule": rule, @@ -163,7 +172,7 @@ def new( _read_only=False, _offline=offline, ) - _alert._staging |= _alert_definition + _alert.append_to_staging(_alert_definition) _alert._params = {"deduplicate": True} return _alert @@ -178,8 +187,9 @@ def _compare_objects(self, other: "AlertBase") -> bool: class MetricsRangeAlert(AlertBase): """Simvue Metrics Range Alert. - This class is used to connect to/create metrics range alert objects on the Simvue server, - any modification of MetricsRangeAlert instance attributes is mirrored on the remote object. + This class is used to connect to/create metrics range alert objects + on the Simvue server, any modification of MetricsRangeAlert instance + attributes is mirrored on the remote object. """ @@ -191,11 +201,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Metrics Range Alert + """Initialise a Metrics Range Alert. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new MetricsRangeAlert instance will be created using arguments provided in kwargs. + Else a new MetricsRangeAlert instance will be created using arguments + provided in kwargs. Parameters ---------- @@ -207,10 +218,14 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ self.alert = MetricRangeAlertDefinition(self) super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._local_only_args += [ "rule", @@ -222,7 +237,7 @@ def __init__( @override def _compare_objects(self, other: "AlertBase") -> bool: - """Compare two MetricRangeAlerts""" + """Compare two MetricRangeAlerts.""" if not isinstance(other, MetricsRangeAlert): return False return super()._compare_objects(other) and self.alert == other.alert @@ -248,7 +263,7 @@ def new( server_token: pydantic.SecretStr | None = None, **_, ) -> Self: - """Create a new metric range alert either locally or on the server + """Create a new metric range alert either locally or on the server. Note all arguments are keyword arguments. @@ -314,42 +329,46 @@ def new( class MetricsAlertDefinition: - """General alert definition for a metric alert""" + """General alert definition for a metric alert.""" def __init__(self, alert: MetricsRangeAlert) -> None: - """Initialise definition with target alert""" + """Initialise definition with target alert.""" self._sv_obj = alert def __eq__(self, other: "MetricsAlertDefinition") -> bool: - """Compare a MetricsAlertDefinition with another""" + """Compare a MetricsAlertDefinition with another.""" return all( [ self.aggregation == other.aggregation, self.frequency == other.frequency, self.rule == other.rule, self.window == other.window, - ] + ], ) + def __hash__(self) -> int: + """Return definition hash.""" + return hash(f"{self.aggregation}-{self.frequency}-{self.rule}-{self.window}") + @property def aggregation(self) -> Aggregate: - """Retrieve the aggregation strategy for this alert""" + """Retrieve the aggregation strategy for this alert.""" if (_aggregation := self._sv_obj.get_alert().get("aggregation")) is None: raise RuntimeError( - "Expected key 'aggregation' in alert definition retrieval" + "Expected key 'aggregation' in alert definition retrieval", ) return _aggregation @property def rule(self) -> Rule: - """Retrieve the rule for this alert""" + """Retrieve the rule for this alert.""" if (_rule := self._sv_obj.get_alert().get("rule")) is None: raise RuntimeError("Expected key 'rule' in alert definition retrieval") return _rule @property def window(self) -> int: - """Retrieve the aggregation window for this alert""" + """Retrieve the aggregation window for this alert.""" if (_window := self._sv_obj.get_alert().get("window")) is None: raise RuntimeError("Expected key 'window' in alert definition retrieval") return _window @@ -357,46 +376,49 @@ def window(self) -> int: @property @staging_check def frequency(self) -> int: - """Retrieve the monitor frequency for this alert""" + """Retrieve the monitor frequency for this alert.""" try: return self._sv_obj.get_alert()["frequency"] except KeyError as e: raise RuntimeError( - "Expected key 'frequency' in alert definition retrieval" + "Expected key 'frequency' in alert definition retrieval", ) from e @frequency.setter @write_only @pydantic.validate_call def frequency(self, frequency: int) -> None: - """Set the monitor frequency for this alert""" + """Set the monitor frequency for this alert.""" _alert = self._sv_obj.get_alert() | {"frequency": frequency} - self._sv_obj._staging["alert"] = _alert + self._sv_obj.append_to_staging({"alert": _alert}) class MetricThresholdAlertDefinition(MetricsAlertDefinition): - """Alert definition for metric threshold alerts""" + """Alert definition for metric threshold alerts.""" def __eq__(self, other: "MetricThresholdAlertDefinition") -> bool: - """Compare this MetricThresholdAlertDefinition with another""" + """Compare this MetricThresholdAlertDefinition with another.""" if not super().__eq__(other): return False return self.threshold == other.threshold + def __hash__(self) -> int: + return hash(f"{super().__hash__()}+{self.threshold}") + @property def threshold(self) -> float: - """Retrieve the threshold value for this alert""" + """Retrieve the threshold value for this alert.""" if (threshold_l := self._sv_obj.get_alert().get("threshold")) is None: raise RuntimeError("Expected key 'threshold' in alert definition retrieval") return threshold_l class MetricRangeAlertDefinition(MetricsAlertDefinition): - """Alert definition for metric range alerts""" + """Alert definition for metric range alerts.""" def __eq__(self, other: "MetricRangeAlertDefinition") -> bool: - """Compare a MetricRangeAlertDefinition with another""" + """Compare a MetricRangeAlertDefinition with another.""" if not super().__eq__(other): return False @@ -404,21 +426,24 @@ def __eq__(self, other: "MetricRangeAlertDefinition") -> bool: [ self.range_high == other.range_high, self.range_low == other.range_low, - ] + ], ) + def __hash__(self) -> int: + return hash(f"{super().__hash__()}+{self.range_high}+{self.range_low}") + @property def range_low(self) -> float: - """Retrieve the lower limit for metric range""" + """Retrieve the lower limit for metric range.""" if (range_l := self._sv_obj.get_alert().get("range_low")) is None: raise RuntimeError("Expected key 'range_low' in alert definition retrieval") return range_l @property def range_high(self) -> float: - """Retrieve upper limit for metric range""" + """Retrieve upper limit for metric range.""" if (range_u := self._sv_obj.get_alert().get("range_high")) is None: raise RuntimeError( - "Expected key 'range_high' in alert definition retrieval" + "Expected key 'range_high' in alert definition retrieval", ) return range_u diff --git a/simvue/api/objects/alert/user.py b/simvue/api/objects/alert/user.py index 3e9dedac..97573a5a 100644 --- a/simvue/api/objects/alert/user.py +++ b/simvue/api/objects/alert/user.py @@ -4,19 +4,22 @@ """ -import pydantic import typing +import pydantic + try: from typing import Self, override except ImportError: from typing_extensions import Self, override import http -from simvue.api.request import get_json_from_response, put as sv_put -from .base import AlertBase +from simvue.api.request import get_json_from_response +from simvue.api.request import put as sv_put from simvue.models import NAME_REGEX +from .base import AlertBase + class UserAlert(AlertBase): """Simvue User Alert. @@ -34,11 +37,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a User Alert + """Initialise a User Alert. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new UserAlert instance will be created using arguments provided in kwargs. + Else a new UserAlert instance will be created using arguments + provided in kwargs. Parameters ---------- @@ -50,9 +54,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._local_status: dict[str, str | None] = kwargs.pop("status", {}) @@ -71,7 +79,7 @@ def new( server_token: pydantic.SecretStr | None = None, **_, ) -> Self: - """Create a new user-defined alert + """Create a new user-defined alert. Note all arguments are keyword arguments. @@ -93,7 +101,7 @@ def new( token for alternative server, default None """ - _alert = cls( + return cls( name=name, description=description, notification=notification, @@ -105,7 +113,6 @@ def new( _read_only=False, _offline=offline, ) - return _alert @override def _compare_objects(self, other: "AlertBase") -> bool: @@ -115,34 +122,37 @@ def _compare_objects(self, other: "AlertBase") -> bool: @classmethod def get( - cls, count: int | None = None, offset: int | None = None + cls, + count: int | None = None, + offset: int | None = None, ) -> dict[str, typing.Any]: - """Return only UserAlerts""" + """Return only UserAlerts.""" raise NotImplementedError("Retrieve of only user alerts is not yet supported") def get_status(self, run_id: str) -> typing.Literal["ok", "critical"] | None: - """Retrieve current alert status for the given run""" + """Retrieve current alert status for the given run.""" if self._offline: return self._staging.get("status", self._local_status).get(run_id) return super().get_status(run_id) def on_reconnect(self, id_mapping: dict[str, str]) -> None: - """Set status update on reconnect""" + """Set status update on reconnect.""" for offline_id, status in self._staging.get("status", {}).items(): self.set_status(id_mapping.get(offline_id), status) @pydantic.validate_call def set_status(self, run_id: str, status: typing.Literal["ok", "critical"]) -> None: - """Set the status of this alert for a given run""" + """Set the status of this alert for a given run.""" if self._offline: if "status" not in self._staging: self._staging["status"] = {} self._staging["status"][run_id] = status return - elif run_id.startswith("offline"): + if run_id.startswith("offline"): raise ValueError( - f"Cannot set status of online alert '{self.id}' for offline run '{run_id}'" + f"Cannot set status of online alert '{self.id}' for " + f"offline run '{run_id}'", ) _response = sv_put( diff --git a/simvue/api/objects/artifact/base.py b/simvue/api/objects/artifact/base.py index a1ef4356..db76e5ff 100644 --- a/simvue/api/objects/artifact/base.py +++ b/simvue/api/objects/artifact/base.py @@ -9,25 +9,33 @@ import io import logging import typing + import pydantic try: from typing import Self, override except ImportError: - from typing_extensions import Self, override # noqa: F401, + from typing_extensions import Self, override -from simvue.api.url import URL from collections.abc import Generator -from simvue.exception import ObjectNotFoundError -from simvue.models import DATETIME_FORMAT + from simvue.api.objects.base import SimvueObject, staging_check, write_only from simvue.api.objects.run import Run from simvue.api.request import ( - put as sv_put, + get as sv_get, +) +from simvue.api.request import ( get_json_from_response, +) +from simvue.api.request import ( post as sv_post, - get as sv_get, ) +from simvue.api.request import ( + put as sv_put, +) +from simvue.api.url import URL +from simvue.exception import ObjectNotFoundError +from simvue.models import DATETIME_FORMAT Category = typing.Literal["code", "input", "output"] @@ -40,7 +48,7 @@ class ArtifactBase(SimvueObject): - """Connect to/create an artifact locally or on the server""" + """Connect to/create an artifact locally or on the server.""" _label: str = "artifact" @@ -54,7 +62,6 @@ def __init__( **kwargs, ) -> None: """Retrieve an artifact instance from the Simvue server by identifier.""" - super().__init__( identifier=identifier, server_url=server_url, @@ -84,6 +91,7 @@ def attach_to_run(self, run_id: str, category: Category) -> None: identifier of run to associate this artifact with. category : Literal['input', 'output', 'code'] category of this artifact with respect to the run. + """ self._init_data["runs"][run_id] = category @@ -110,16 +118,17 @@ def attach_to_run(self, run_id: str, category: Category) -> None: ) def on_reconnect(self, id_mapping: dict[str, str]) -> None: - """Operations performed when this artifact is switched from offline to online mode. + """Operations performed when artifact mode switched from offline to online. Parameters ---------- id_mapping : dict[str, str] mapping from offline identifier to new online identifier. + """ _offline_staging = self._init_data["runs"].copy() - for id, category in _offline_staging.items(): - self.attach_to_run(run_id=id_mapping[id], category=category) + for _id, category in _offline_staging.items(): + self.attach_to_run(run_id=id_mapping[_id], category=category) def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: if self._offline: @@ -133,13 +142,15 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: timeout = BASE_TIMEOUT + UPLOAD_TIMEOUT_PER_MB * file_size / 1024 / 1024 self._logger.debug( - f"Will wait for a period of {timeout:.0f}s for upload of file for {file_size}B file to complete." + "Will wait for a period of %s for upload of file for %s file to complete.", + f"{timeout:.0f}s", + f"{file_size}B", ) _name = self._staging["name"] if _fields := self._init_data.get("fields"): - _logger.debug(f"Using POST for artifact upload to '{_url}': {_fields}") + _logger.debug("Using POST for artifact upload to '%s': %s", _url, _fields) _response = sv_post( url=_url, headers={}, @@ -151,7 +162,7 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: ) else: - _logger.debug(f"Using PUT for artifact upload to '{_url}'") + _logger.debug("Using PUT for artifact upload to '%s'", _url) _response = sv_put( url=_url, headers={}, @@ -173,15 +184,18 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: ) # Temporarily remove read-only state - self.read_only(False) + self.read_only(is_read_only=False) # Update the server status to confirm file uploaded self.uploaded = True super().commit() - self.read_only(True) + self.read_only(is_read_only=True) def _get( - self, storage: str | None = None, url: str | None = None, **kwargs + self, + storage: str | None = None, + url: str | None = None, + **kwargs, ) -> dict[str, typing.Any]: return super()._get( storage=storage or self._staging.get("server", {}).get("storage_id"), @@ -196,6 +210,7 @@ def checksum(self) -> str: Returns ------- str + """ return self._get_attribute("checksum") @@ -206,6 +221,7 @@ def storage_url(self) -> URL | None: Returns ------- simvue.api.url.URL | None + """ return URL(_url) if (_url := self._init_data.get("url")) else None @@ -216,6 +232,7 @@ def original_path(self) -> str: Returns ------- str + """ return self._get_attribute("original_path") @@ -226,6 +243,7 @@ def storage_id(self) -> str | None: Returns ------- str | None + """ return self._get_attribute("storage_id") @@ -236,6 +254,7 @@ def mime_type(self) -> str: Returns ------- str + """ return self._get_attribute("mime_type") @@ -246,6 +265,7 @@ def size(self) -> int: Returns ------- int + """ return self._get_attribute("size") @@ -256,20 +276,26 @@ def name(self) -> str | None: Returns ------- str | None + """ return self._get_attribute("name") @property def created(self) -> datetime.datetime | None: - """Retrieve created datetime for the artifact. + """Retrieve created datetime in UTC for the artifact. Returns ------- datetime.datetime | None + """ _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + datetime.UTC, + ) + if _created + else None ) @property @@ -280,6 +306,7 @@ def uploaded(self) -> bool: Returns ------- bool + """ return self._get_attribute("uploaded") @@ -292,11 +319,12 @@ def uploaded(self, is_uploaded: bool) -> None: @property def download_url(self) -> URL | None: - """Retrieve the URL for downloading this artifact + """Retrieve the URL for downloading this artifact. Returns ------- simvue.api.url.URL | None + """ return self._get_attribute("url") @@ -312,6 +340,7 @@ def runs(self) -> Generator[str]: Returns ------- Generator[str, None, None] + """ for _id, _ in Run.get(filters=[f"artifact.id == {self.id}"]): yield _id @@ -322,6 +351,7 @@ def get_category(self, run_id: str) -> Category: Returns ------- Literal['input', 'output', 'code'] + """ _run_url = ( URL(self._user_config.server.url) @@ -331,11 +361,16 @@ def get_category(self, run_id: str) -> Category: _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK, http.HTTPStatus.NOT_FOUND], - scenario=f"Retrieval of category for artifact '{self._identifier}' with respect to run '{run_id}'", + scenario=( + f"Retrieval of category for artifact '{self._identifier}' " + f"with respect to run '{run_id}'" + ), ) if _response.status_code == http.HTTPStatus.NOT_FOUND: raise ObjectNotFoundError( - self.label(), self._identifier, extra=f"for run '{run_id}'" + self.label(), + self._identifier, + extra=f"for run '{run_id}'", ) return _json_response["category"] @@ -352,16 +387,20 @@ def download_content(self) -> Generator[bytes]: Returns ------- Generator[bytes, None, None] + """ if not self.download_url: raise ValueError( - f"Could not retrieve URL for artifact '{self._identifier}'" + f"Could not retrieve URL for artifact '{self._identifier}'", ) _timeout = BASE_TIMEOUT + DOWNLOAD_TIMEOUT_PER_MB * self.size / 1024 / 1024 self._logger.debug( - f"Will wait {_timeout:.0f}s for download of file {self.name} of size {self.size}B" + "Will wait %s for download of file %s of size %s", + f"{_timeout:.0f}s", + self.name, + f"{self.size}B", ) _response = sv_get( diff --git a/simvue/api/objects/artifact/fetch.py b/simvue/api/objects/artifact/fetch.py index b15f7f77..109a366c 100644 --- a/simvue/api/objects/artifact/fetch.py +++ b/simvue/api/objects/artifact/fetch.py @@ -5,21 +5,22 @@ """ import http -import typing -import pydantic import json +import typing +from collections.abc import Generator +import pydantic from simvue.api.objects.artifact.base import ArtifactBase -from simvue.api.objects.base import Sort -from simvue.config.user import SimvueConfiguration -from .file import FileArtifact -from collections.abc import Generator from simvue.api.objects.artifact.object import ObjectArtifact -from simvue.api.request import get_json_from_response, get as sv_get +from simvue.api.objects.base import Sort +from simvue.api.request import get as sv_get +from simvue.api.request import get_json_from_response from simvue.api.url import URL +from simvue.config.user import SimvueConfiguration from simvue.exception import ObjectNotFoundError +from .file import FileArtifact __all__ = ["Artifact"] @@ -29,7 +30,7 @@ class ArtifactSort(Sort): @classmethod def check_column(cls, column: str) -> str: if column and ( - column not in ("name", "created") and not column.startswith("metadata.") + column not in {"name", "created"} and not column.startswith("metadata.") ): raise ValueError(f"Invalid sort column for artifacts '{column}'") return column @@ -48,7 +49,7 @@ def __new__( *, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> FileArtifact | ObjectArtifact: """Retrieve an object representing an artifact on the server by id. @@ -60,11 +61,14 @@ def __new__( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Returns ------- FileArtifact | ObjectArtifact object representing storage + """ _artifact_pre = ArtifactBase( identifier=identifier, @@ -79,13 +83,12 @@ def __new__( server_token=server_token, **kwargs, ) - else: - return ObjectArtifact( - identifier=identifier, - server_url=server_url, - server_token=server_token, - **kwargs, - ) + return ObjectArtifact( + identifier=identifier, + server_url=server_url, + server_token=server_token, + **kwargs, + ) @classmethod def from_run( @@ -127,13 +130,18 @@ def from_run( ------ ObjectNotFoundError Raised if artifacts could not be found for that run + """ _config: SimvueConfiguration = SimvueConfiguration.fetch( - mode="online", server_url=server_url, server_token=server_token + mode="online", + server_url=server_url, + server_token=server_token, ) _url = URL(f"{_config.server.url}") / f"runs/{run_id}/artifacts" _response = sv_get( - url=f"{_url}", params={"category": category}, headers=_config.headers + url=f"{_url}", + params={"category": category}, + headers=_config.headers, ) _json_response = get_json_from_response( expected_type=list, @@ -144,7 +152,9 @@ def from_run( if _response.status_code == http.HTTPStatus.NOT_FOUND or not _json_response: raise ObjectNotFoundError( - ArtifactBase.label, category, extra=f"for run '{run_id}'" + ArtifactBase.label, + category, + extra=f"for run '{run_id}'", ) for _entry in _json_response: @@ -190,13 +200,18 @@ def from_name( ------ RuntimeError when duplicate artifacts are found within a single run + """ _config: SimvueConfiguration = SimvueConfiguration.fetch( - mode="online", server_url=server_url, server_token=server_token + mode="online", + server_url=server_url, + server_token=server_token, ) _url = URL(f"{_config.server.url}") / f"runs/{run_id}/artifacts" _response = sv_get( - url=f"{_url}", params={"name": name}, headers=_config.headers + url=f"{_url}", + params={"name": name}, + headers=_config.headers, ) _json_response = get_json_from_response( expected_type=list, @@ -207,13 +222,15 @@ def from_name( if _response.status_code == http.HTTPStatus.NOT_FOUND or not _json_response: raise ObjectNotFoundError( - ArtifactBase.label(), name, extra=f"for run '{run_id}'" + ArtifactBase.label(), + name, + extra=f"for run '{run_id}'", ) if (_n_res := len(_json_response)) > 1 and not force_overwrite: raise RuntimeError( f"Expected single result for artifact '{name}' for run '{run_id}'" - f" but got {_n_res}" + f" but got {_n_res}", ) _first_result: dict[str, typing.Any] = _json_response[0] @@ -238,7 +255,7 @@ def get( sorting: list[ArtifactSort] | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, FileArtifact | ObjectArtifact]]: """Returns artifacts associated with the current user. @@ -254,16 +271,20 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ tuple[str, FileArtifact | ObjectArtifact] identifier for artifact the artifact itself as a class instance - """ + """ _config: SimvueConfiguration = SimvueConfiguration.fetch( - mode="online", server_url=server_url, server_token=server_token + mode="online", + server_url=server_url, + server_token=server_token, ) _url = URL(f"{_config.server.url}") / ArtifactBase.endpoint() _params = {"start": offset, "count": count} diff --git a/simvue/api/objects/artifact/file.py b/simvue/api/objects/artifact/file.py index bfc07e61..6b11d9d6 100644 --- a/simvue/api/objects/artifact/file.py +++ b/simvue/api/objects/artifact/file.py @@ -5,17 +5,19 @@ """ -from .base import ArtifactBase - -import typing -import pydantic +import datetime import os import pathlib import shutil +import typing + +import pydantic + from simvue.config.user import SimvueConfiguration -from datetime import datetime from simvue.models import NAME_REGEX -from simvue.utilities import get_mimetype_for_file, get_mimetypes, calculate_sha256 +from simvue.utilities import calculate_file_sha256, get_mimetype_for_file, get_mimetypes + +from .base import ArtifactBase try: from typing import Self @@ -38,11 +40,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a File Artifact + """Initialise a File Artifact. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new FileArtifact instance will be created using arguments provided in kwargs. + Else a new FileArtifact instance will be created using + arguments provided in kwargs. Parameters ---------- @@ -54,6 +57,7 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( identifier=identifier, @@ -76,9 +80,9 @@ def new( snapshot: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: - """Create a new artifact either locally or on the server + """Create a new artifact either locally or on the server. Note all arguments are keyword arguments @@ -99,11 +103,14 @@ def new( offline : bool, optional whether to define this artifact locally, default is False snapshot : bool, optional - whether to create a snapshot of this file before uploading it, default is False + whether to create a snapshot of this file before uploading it, + default is False server_url: str | None, optional alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialisation Returns ------- @@ -129,18 +136,19 @@ def new( ) _local_staging_dir: pathlib.Path = _user_config.offline.cache.joinpath( - "artifacts" + "artifacts", ) _local_staging_dir.mkdir(parents=True, exist_ok=True) _local_staging_file = _local_staging_dir.joinpath( - f"{file_path.stem}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S_%f')[:-3]}.file" + f"{file_path.stem}_" + f"{datetime.datetime.now(tz=datetime.UTC).strftime('%Y-%m-%d_%H-%M-%S_%f')[:-3]}.file", ) shutil.copy(file_path, _local_staging_file) file_path = _local_staging_file _file_size = file_path.stat().st_size _file_orig_path = file_path.expanduser().absolute() - _file_checksum = calculate_sha256(f"{file_path}", is_file=True) + _file_checksum = calculate_file_sha256(file_path) _artifact = cls( name=name, @@ -169,7 +177,7 @@ def new( if offline: return _artifact - with open(_file_orig_path, "rb") as out_f: + with pathlib.Path(_file_orig_path).open("rb") as out_f: _artifact._upload(file=out_f, timeout=upload_timeout, file_size=_file_size) # If snapshot created, delete it after uploading diff --git a/simvue/api/objects/artifact/object.py b/simvue/api/objects/artifact/object.py index 8faff12d..0a234df3 100644 --- a/simvue/api/objects/artifact/object.py +++ b/simvue/api/objects/artifact/object.py @@ -5,15 +5,18 @@ """ -from .base import ArtifactBase +import io +import pathlib +import sys +import typing + +import pydantic + from simvue.models import NAME_REGEX from simvue.serialization import serialize_object -from simvue.utilities import calculate_sha256 +from simvue.utilities import calculate_object_sha256 -import pydantic -import typing -import sys -import io +from .base import ArtifactBase try: from typing import Self, override @@ -24,8 +27,9 @@ class ObjectArtifact(ArtifactBase): """Simvue Object Artifact. - This class is used to connect to/create file object artifact objects on the Simvue server, - any modification of instance attributes is mirrored on the remote object. + This class is used to connect to/create file object artifact + objects on the Simvue server, any modification of instance + attributes is mirrored on the remote object. """ @@ -38,11 +42,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Object Artifact + """Initialise a Object Artifact. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new ObjectArtifact instance will be created using arguments provided in kwargs. + Else a new ObjectArtifact instance will be created using + arguments provided in kwargs. Parameters ---------- @@ -54,6 +59,7 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ kwargs.pop("original_path", None) super().__init__(identifier, original_path="", **kwargs) @@ -73,9 +79,9 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: - """Create a new artifact either locally or on the server + """Create a new artifact either locally or on the server. Note all arguments are keyword arguments @@ -100,6 +106,8 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialisation Returns ------- @@ -113,23 +121,26 @@ def new( _data_type = kwargs.pop("mime_type") _serialized = kwargs.pop("serialized") _checksum = kwargs.pop("checksum") - kwargs.pop("size") - kwargs.pop("original_path") + _ = kwargs.pop("size") + _ = kwargs.pop("original_path") except KeyError: - raise ValueError("Must provide an object to be saved, not None.") + raise ValueError( + "Must provide an object to be saved, not None.", + ) from None else: - _serialization = serialize_object(obj, allow_pickling) + _serialization = serialize_object(obj, allow_pickle=allow_pickling) if not _serialization or not (_serialized := _serialization[0]): raise ValueError(f"Could not serialize object of type '{type(obj)}'") if not (_data_type := _serialization[1]) and not allow_pickling: raise ValueError( - f"Could not serialize object of type '{type(obj)}' without pickling" + f"Could not serialize object of type '{type(obj)}' " + "without pickling", ) - _checksum = calculate_sha256(_serialized, is_file=False) + _checksum = calculate_object_sha256(_serialized) _artifact = cls( name=name, @@ -149,11 +160,9 @@ def new( _artifact._init_data = {} _artifact._staging["obj"] = None _artifact._local_staging_file.parent.mkdir(parents=True, exist_ok=True) - with open( + pathlib.Path( _artifact._local_staging_file.parent.joinpath(f"{_artifact.id}.object"), - "wb", - ) as file: - file.write(_serialized) + ).write_bytes(_serialized) else: _artifact._init_data = _artifact._post_single(**_artifact._staging) diff --git a/simvue/api/objects/base.py b/simvue/api/objects/base.py index 4a729571..bc44613b 100644 --- a/simvue/api/objects/base.py +++ b/simvue/api/objects/base.py @@ -4,43 +4,54 @@ """ import abc -import pathlib -import types -import typing -import inspect -import uuid import http +import inspect import json import logging +import types +import typing +import uuid +from collections.abc import Generator import msgpack import pydantic -from collections.abc import Generator -from simvue.utilities import staging_merger -from simvue.config.user import SimvueConfiguration -from simvue.exception import ObjectNotFoundError +from simvue.api.request import ( + delete as sv_delete, +) from simvue.api.request import ( get as sv_get, +) +from simvue.api.request import ( + get_json_from_response, get_paginated, +) +from simvue.api.request import ( post as sv_post, +) +from simvue.api.request import ( put as sv_put, - delete as sv_delete, - get_json_from_response, ) from simvue.api.url import URL +from simvue.config.user import SimvueConfiguration +from simvue.exception import ObjectNotFoundError +from simvue.utilities import staging_merger + +if typing.TYPE_CHECKING: + import pathlib try: from typing import Self, override except ImportError: - from typing_extensions import Self, override # noqa: UP035 + from typing_extensions import Self, override -# Need to use this inside of Generator typing to fix bug present in Python 3.10 - see issue #745 +# Need to use this inside of Generator typing to +# fix bug present in Python 3.10 - see issue #745 T = typing.TypeVar("T", bound="SimvueObject") def staging_check(member_func: typing.Callable) -> typing.Callable: - """Decorator for checking if requested attribute has uncommitted changes""" + """Decorator for checking if requested attribute has uncommitted changes.""" def _wrapper(self) -> typing.Any: if isinstance(self, SimvueObject): @@ -49,13 +60,14 @@ def _wrapper(self) -> typing.Any: _sv_obj = self._sv_obj else: raise RuntimeError( - f"Cannot use 'staging_check' decorator on type '{type(self).__name__}'" + f"Cannot use 'staging_check' decorator on type '{type(self).__name__}'", ) - if _sv_obj._offline: + if _sv_obj.user_config.run.mode == "offline": return member_func(self) - if not _sv_obj._read_only and member_func.__name__ in _sv_obj._staging: + if not _sv_obj.is_read_only and member_func.__name__ in _sv_obj.staging: _sv_obj._logger.warning( - f"Uncommitted change found for attribute '{member_func.__name__}'" + "Uncommitted change found for attribute '%s'", + member_func.__name__, ) return member_func(self) @@ -71,10 +83,10 @@ def _wrapper(self) -> typing.Any: def write_only(attribute_func: typing.Callable) -> typing.Callable: def _wrapper(self: "SimvueObject", *args, **kwargs) -> typing.Any: _sv_obj = getattr(self, "_sv_obj", self) - if _sv_obj._read_only: + if _sv_obj.is_read_only: raise AssertionError( f"Cannot set property '{attribute_func.__name__}' " - f"on read-only object of type '{self.label()}'" + f"on read-only object of type '{self.label()}'", ) return attribute_func(self, *args, **kwargs) @@ -88,16 +100,16 @@ def _wrapper(self: "SimvueObject", *args, **kwargs) -> typing.Any: class Visibility: - """Interface for object visibility definition""" + """Interface for object visibility definition.""" def __init__(self, sv_obj: "SimvueObject") -> None: - """Initialise visibility with target object""" + """Initialise visibility with target object.""" self._sv_obj = sv_obj def _update_visibility(self, key: str, value: typing.Any) -> None: - """Update the visibility configuration for this object""" - _visibility = self._sv_obj._get_visibility() | {key: value} - self._sv_obj._staging["visibility"] = _visibility + """Update the visibility configuration for this object.""" + _visibility = self._sv_obj.get_visibility() | {key: value} + self._sv_obj.staging["visibility"] = _visibility @property @staging_check @@ -111,8 +123,9 @@ def users(self) -> list[str]: Returns ------- list[str] + """ - return self._sv_obj._get_visibility().get("users", []) + return self._sv_obj.get_visibility().get("users", []) @users.setter @write_only @@ -131,8 +144,9 @@ def public(self) -> bool: Returns ------- bool + """ - return self._sv_obj._get_visibility().get("public", False) # type: ignore + return self._sv_obj.get_visibility().get("public", False) @public.setter @write_only @@ -151,8 +165,9 @@ def tenant(self) -> bool: Returns ------- bool + """ - return self._sv_obj._get_visibility().get("tenant", False) # type: ignore + return self._sv_obj.get_visibility().get("tenant", False) @tenant.setter @write_only @@ -214,7 +229,7 @@ def __init__( identifier is not None and identifier.startswith("offline_") ) - self._user_config = SimvueConfiguration.fetch( + self._user_config: SimvueConfiguration = SimvueConfiguration.fetch( mode="offline" if self._offline else "online", server_token=server_token, server_url=server_url, @@ -224,7 +239,8 @@ def __init__( # e.g. multiple runs writing at the same time self._local_staging_file: pathlib.Path = ( self._user_config.offline.cache.joinpath( - self.endpoint(), f"{self._identifier}.json" + self.endpoint(), + f"{self._identifier}.json", ) ) @@ -239,30 +255,28 @@ def __init__( # If this object is read-only, but not a local construction, make an API call if ( not self._identifier.startswith("offline_") - and self._read_only + and self.is_read_only and not self._local ): self._staging = self._get() # Recover any locally staged changes if not read-only self._staging |= ( - {} if (_read_only and not self._offline) else self._get_local_staged() + {} if (self._read_only and not self._offline) else self._get_local_staged() ) self._staging |= kwargs - def _get_local_staged(self, obj_label: str | None = None) -> dict[str, typing.Any]: - """Retrieve any locally staged data for this identifier""" + def _get_local_staged(self) -> dict[str, typing.Any]: + """Retrieve any locally staged data for this identifier.""" if not self._local_staging_file.exists() or not self._identifier: return {} with self._local_staging_file.open() as in_f: - _staged_data = json.load(in_f) - - return _staged_data + return json.load(in_f) def _stage_to_other(self, obj_label: str, key: str, value: typing.Any) -> None: - """Stage a change to another object type""" + """Stage a change to another object type.""" with self._local_staging_file.open() as in_f: _staged_data = json.load(in_f) @@ -303,11 +317,12 @@ def _get_attribute( ------- object the attribute value + """ # In the case where the object is read-only, staging is the data # already retrieved from the server _attribute_is_property: bool = attribute in self._properties - _state_is_read_only: bool = getattr(self, "_read_only", True) + _state_is_read_only: bool = getattr(self, "is_read_only", True) _offline_state: bool = ( self._identifier is not None and self._identifier.startswith("offline_") ) @@ -327,22 +342,26 @@ def _get_attribute( return _attribute raise AttributeError( f"Could not retrieve attribute '{attribute}' " - f"for {self.label()} '{self._identifier}' from cached data" + f"for {self.label()} '{self._identifier}' from cached data", ) from e try: self._logger.debug( - f"Retrieving attribute '{attribute}' from {self.label()} '{self._identifier}'" + "Retrieving attribute '%s' from %s '%s'", + attribute, + self.label(), + self.id, ) return self._get(url=url)[attribute] except KeyError as e: if self._offline: raise AttributeError( f"A value for attribute '{attribute}' has " - f"not yet been committed for offline {self.label()} '{self._identifier}'" + f"not yet been committed for offline {self.label()}" + f" '{self._identifier}'", ) from e raise RuntimeError( - f"Expected key '{attribute}' for {self.label()} '{self._identifier}'" + f"Expected key '{attribute}' for {self.label()} '{self._identifier}'", ) from e def _clear_staging(self) -> None: @@ -360,7 +379,7 @@ def _clear_staging(self) -> None: with self._local_staging_file.open("w") as out_f: json.dump(_staged_data, out_f, indent=2) - def _get_visibility(self) -> dict[str, bool | list[str]]: + def get_visibility(self) -> dict[str, bool | list[str]]: try: return self._get_attribute("visibility") except AttributeError: @@ -373,7 +392,9 @@ def new(cls, **_) -> Self: @classmethod def batch_create( - cls, obj_args: ObjectBatchArgs, visibility: VisibilityBatchArgs + cls, + obj_args: ObjectBatchArgs, + visibility: VisibilityBatchArgs, ) -> Generator[str]: _, __ = obj_args, visibility raise NotImplementedError @@ -386,7 +407,7 @@ def ids( offset: int | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[str, None, None]: """Retrieve a list of all object identifiers. @@ -400,11 +421,14 @@ def ids( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for request Yields - ------- + ------ str identifiers for all objects of this type. + """ _count: int = 0 for response in cls._get_all_objects( @@ -416,7 +440,7 @@ def ids( ): if (_data := response.get("data")) is None: raise RuntimeError( - f"Expected key 'data' for retrieval of {cls.__name__.lower()}s" + f"Expected key 'data' for retrieval of {cls.__name__.lower()}s", ) for entry in _data: yield entry["id"] @@ -433,7 +457,7 @@ def get( offset: pydantic.NonNegativeInt | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, T | None]]: """Retrieve items of this object type from the server. @@ -447,6 +471,8 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for request Yields ------ @@ -456,6 +482,7 @@ def get( Returns ------- Generator[tuple[str, SimvueObject | None]] + """ _count: int = 0 @@ -470,7 +497,7 @@ def get( return if (_data := _response.get("data")) is None: raise RuntimeError( - f"Expected key 'data' for retrieval of {cls.__name__.lower()}s" + f"Expected key 'data' for retrieval of {cls.__name__.lower()}s", ) # If data is an empty list @@ -482,7 +509,7 @@ def get( yield ( _id, cls( - _read_only=True, + is_read_only=True, identifier=_id, server_url=server_url, server_token=server_token, @@ -498,7 +525,7 @@ def count( *, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> int: """Return the total number of entries for this object type from the server. @@ -508,11 +535,14 @@ def count( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for request Returns ------- int total from server database for current user. + """ _count_total: int = 0 for _data in cls._get_all_objects( @@ -524,7 +554,7 @@ def count( ): if not (_count := _data.get("count")): raise RuntimeError( - f"Expected key 'count' for retrieval of {cls.__name__.lower()}s" + f"Expected key 'count' for retrieval of {cls.__name__.lower()}s", ) _count_total += _count return _count_total @@ -541,7 +571,9 @@ def _get_all_objects( **kwargs, ) -> Generator[dict, None, None]: _config: SimvueConfiguration = SimvueConfiguration.fetch( - mode="online", server_url=server_url, server_token=server_token + mode="online", + server_url=server_url, + server_token=server_token, ) # Allow the possibility of paginating a URL that is not the @@ -549,24 +581,32 @@ def _get_all_objects( _url = f"{_config.server.url}/{endpoint or cls.endpoint()}" _label = cls.label() - if _label.endswith("s"): - _label = _label[:-1] + _label = _label.removesuffix("s") for response in get_paginated( - _url, headers=_config.headers, offset=offset, count=count, **kwargs + _url, + headers=_config.headers, + offset=offset, + count=count, + **kwargs, ): _generator = get_json_from_response( response=response, expected_status=[http.HTTPStatus.OK], scenario=f"Retrieval of {_label}s", expected_type=expected_type, - ) # type: ignore + ) if expected_type is dict: yield _generator else: yield from _generator + @property + def is_read_only(self) -> bool: + """Returns if this instance is in read-only mode.""" + return self._read_only + def read_only(self, is_read_only: bool, *, clear_staged: bool = True) -> None: """Set whether this object is in read only state. @@ -576,6 +616,7 @@ def read_only(self, is_read_only: bool, *, clear_staged: bool = True) -> None: whether object is read only. clear_staged : bool, optional whether to clear staging data, default is True. + """ self._read_only = is_read_only @@ -583,20 +624,23 @@ def read_only(self, is_read_only: bool, *, clear_staged: bool = True) -> None: # in this context it contains existing data retrieved # from the server/local entry which we dont want to # re-push unnecessarily, then read any locally staged changes - if not self._read_only and clear_staged: + if not self.is_read_only and clear_staged: self._staging = self._get_local_staged() def commit(self) -> dict | list[dict] | None: """Send updates to the server, or if offline, store locally.""" - if self._read_only: + if self.is_read_only: raise AttributeError("Cannot commit object in 'read-only' mode") if self._offline: self._logger.debug( - f"Writing updates to staging file for {self.label()} '{self.id}': {self._staging}" + "Writing updates to staging file for %s '%s': %s", + self.label(), + self.id, + self._staging, ) self._cache() - return + return None _response: dict[str, str] | list[dict[str, str]] | None = None @@ -606,17 +650,25 @@ def commit(self) -> dict | list[dict] | None: # If batch upload send as list, else send as dictionary of params if _batch_commit := self._staging.get("batch"): self._logger.debug( - f"Posting batched data to server: {len(_batch_commit)} {self.label()}s" + "Posting batched data to server: %s %s", + len(_batch_commit), + f"{self.label()}s", ) _response = self._post_batch(batch_data=_batch_commit) else: self._logger.debug( - f"Posting from staged data for {self.label()} '{self.id}': {self._staging}" + "Posting from staged data for %s '%s': %s", + self.label(), + self.id, + self._staging, ) _response = self._post_single(**self._staging) elif self._staging: self._logger.debug( - f"Pushing updates from staged data for {self.label()} '{self.id}': {self._staging}" + "Pushing updates from staged data for %s '%s': %s", + self.label(), + self.id, + self._staging, ) _response = self._put(**self._staging) @@ -625,6 +677,15 @@ def commit(self) -> dict | list[dict] | None: return _response + @property + def staging(self) -> dict[str, Any]: + """Return current staging for this object.""" + return self._staging + + def append_to_staging(self, items: dict[str, Any]) -> None: + """Add additional items to staging.""" + self._staging |= items + @property def id(self) -> str | None: """The identifier for this object if applicable. @@ -632,11 +693,12 @@ def id(self) -> str | None: Returns ------- str | None + """ return self._identifier @property - def _base_url(self) -> URL: + def base_url(self) -> URL: return URL(self._user_config.server.url) / self.endpoint() @property @@ -646,15 +708,16 @@ def url(self) -> URL | None: Returns ------- simvue.api.url.URL | None + """ - return None if self._identifier is None else self._base_url / self._identifier + return None if self._identifier is None else self.base_url / self._identifier def _post_batch( self, batch_data: list[ObjectBatchArgs], ) -> list[dict[str, str]]: _response = sv_post( - url=f"{self._base_url}", + url=f"{self.base_url}", headers=self._headers | {"Content-Type": "application/msgpack"}, params=self._params or {}, data=batch_data, @@ -663,7 +726,8 @@ def _post_batch( if _response.status_code == http.HTTPStatus.FORBIDDEN: raise RuntimeError( - f"Forbidden: You do not have permission to create object of type '{self.label()}'" + "Forbidden: You do not have permission to " + f"create object of type '{self.label()}'", ) _json_response = get_json_from_response( @@ -675,15 +739,21 @@ def _post_batch( if not len(batch_data) == (_n_created := len(_json_response)): raise RuntimeError( - f"Expected {len(batch_data)} to be created, but only {_n_created} found." + "Expected %s to be created, but only %s found.", + len(batch_data), + _n_created, ) - self._logger.debug(f"successfully created {_n_created} {self.label()}s") + self._logger.debug("successfully created %s %s", _n_created, f"{self.label()}s") return _json_response def _post_single( - self, *, is_json: bool = True, data: list | dict | None = None, **kwargs + self, + *, + is_json: bool = True, + data: list | dict | None = None, + **kwargs, ) -> dict[str, typing.Any] | list[dict[str, typing.Any]]: # Remove any extra keys for key in self._local_only_args: @@ -693,7 +763,7 @@ def _post_single( kwargs = msgpack.packb(data or kwargs, use_bin_type=True) _response = sv_post( - url=f"{self._base_url}", + url=f"{self.base_url}", headers=self._headers | {"Content-Type": "application/msgpack"}, params=self._params or {}, data=data or kwargs, @@ -702,7 +772,8 @@ def _post_single( if _response.status_code == http.HTTPStatus.FORBIDDEN: raise RuntimeError( - f"Forbidden: You do not have permission to create object of type '{self.label()}'" + "Forbidden: You do not have permission to create " + f"object of type '{self.label()}'", ) _json_response = get_json_from_response( @@ -721,7 +792,7 @@ def _post_single( _detail = "No information in JSON response." raise RuntimeError( - f"Expected new ID for {self.label()} but none found: {_detail}." + f"Expected new ID for {self.label()} but none found: {_detail}.", ) return _json_response @@ -735,12 +806,16 @@ def _put(self, **kwargs) -> dict[str, typing.Any]: _ = kwargs.pop(key, None) _response = sv_put( - url=f"{self.url}", headers=self._headers, data=kwargs, is_json=True + url=f"{self.url}", + headers=self._headers, + data=kwargs, + is_json=True, ) if _response.status_code == http.HTTPStatus.FORBIDDEN: raise RuntimeError( - f"Forbidden: You do not have permission to create object of type '{self.label()}'" + "Forbidden: You do not have permission to " + f"create object of type '{self.label()}'", ) return get_json_from_response( @@ -756,8 +831,8 @@ def delete(self, **kwargs) -> dict[str, typing.Any]: ------- dict[str, Any] response from server on deletion. - """ + """ if self._get_local_staged(): self._local_staging_file.unlink(missing_ok=True) @@ -780,7 +855,11 @@ def delete(self, **kwargs) -> dict[str, typing.Any]: return _json_response def _get( - self, url: str | None = None, allow_parse_failure: bool = False, **kwargs + self, + url: str | None = None, + *, + allow_parse_failure: bool = False, + **kwargs, ) -> dict[str, typing.Any]: if self._identifier.startswith("offline_"): return self._get_local_staged() @@ -789,12 +868,15 @@ def _get( raise RuntimeError(f"Identifier for instance of {self.label()} Unknown") _response = sv_get( - url=f"{url or self.url}", headers=self._headers, params=kwargs + url=f"{url or self.url}", + headers=self._headers, + params=kwargs, ) if _response.status_code == http.HTTPStatus.NOT_FOUND: raise ObjectNotFoundError( - obj_type=self.label(), name=self._identifier or "Unknown" + obj_type=self.label(), + name=self._identifier or "Unknown", ) _json_response = get_json_from_response( @@ -807,14 +889,15 @@ def _get( if not isinstance(_json_response, dict): raise RuntimeError( - f"Expected dictionary from JSON response during {self.label()} retrieval " - f"but got '{type(_json_response)}'" + "Expected dictionary from JSON response " + f"during {self.label()} retrieval " + f"but got '{type(_json_response)}'", ) return _json_response def refresh(self) -> None: """Refresh staging from local data if in read-only mode.""" - if self._read_only: + if self.is_read_only: self._staging = self._get() def _cache(self) -> None: @@ -839,6 +922,7 @@ def to_dict(self) -> dict[str, typing.Any]: ------- dict[str, Any] dictionary representation of this object + """ return self._get() | self._staging @@ -857,9 +941,15 @@ def staged(self) -> dict[str, typing.Any] | None: ------- dict[str, Any] | None the locally staged data if available. + """ return self._staging or None + @property + def user_config(self) -> SimvueConfiguration: + """Return current user configuration.""" + return self._user_config + @classmethod def label(cls) -> str: """Return API label for this object type.""" @@ -881,20 +971,20 @@ def __repr__(self) -> str: _property_values: list[str] = [] _property_warn_list: list[str] = [] - for property in self._properties: + for _property in self._properties: try: - _value = getattr(self, property) - except (KeyError, Exception): + _value = getattr(self, _property) + except AttributeError: # Display a warning only once if a property could not be retrieved if property not in _property_warn_list: - self._logger.warning(f"Failed to retrieve property '{property}'") - _property_warn_list.append(property) + self._logger.warning("Failed to retrieve property '%s'", _property) + _property_warn_list.append(_property) continue if isinstance(_value, types.GeneratorType): continue - _property_values.append(f"{property}={_value!r}") + _property_values.append(f"{_property}={_value!r}") _out_str += ", ".join(_property_values) _out_str += ")" diff --git a/simvue/api/objects/events.py b/simvue/api/objects/events.py index 066c635e..7496249d 100644 --- a/simvue/api/objects/events.py +++ b/simvue/api/objects/events.py @@ -1,5 +1,4 @@ -""" -Simvue Server Events +"""Simvue Server Events. ==================== Contains a class for remotely connecting to Simvue events, or defining @@ -7,24 +6,27 @@ """ +import datetime import http import typing -import datetime +from collections.abc import Generator import pydantic -from simvue.api.url import URL +from simvue.api.request import get as sv_get +from simvue.api.request import get_json_from_response +from simvue.models import EventSet, simvue_timestamp from .base import SimvueObject -from simvue.models import EventSet, simvue_timestamp -from simvue.api.request import get as sv_get, get_json_from_response -from collections.abc import Generator try: from typing import Self except ImportError: from typing_extensions import Self +if typing.TYPE_CHECKING: + from simvue.api.url import URL + __all__ = ["Events"] @@ -58,6 +60,7 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( identifier=identifier, @@ -78,7 +81,7 @@ def get( offset: pydantic.PositiveInt | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[EventSet]: """Retrieve events from the server. @@ -94,6 +97,8 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ @@ -103,6 +108,7 @@ def get( Returns ------- Generator[EventSet] + """ _class_instance = cls(_read_only=True, _local=True) _count: int = 0 @@ -117,7 +123,8 @@ def get( ): if (_data := response.get("data")) is None: raise RuntimeError( - f"Expected key 'data' for retrieval of {_class_instance.__class__.__name__.lower()}s" + "Expected key 'data' for retrieval of " + f"{_class_instance.__class__.__name__.lower()}s", ) for _entry in _data: @@ -136,7 +143,7 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new Events entry on the Simvue server. @@ -152,11 +159,14 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional initialisation arguments Returns ------- Events an object representing this event set + """ return cls( run=run, @@ -185,9 +195,9 @@ def histogram( if timestamp_end - timestamp_begin <= datetime.timedelta(seconds=window): raise ValueError( "Invalid arguments for datetime range, " - "value difference must be greater than window" + "value difference must be greater than window", ) - _url: URL = self._base_url / "histogram" + _url: URL = self.base_url / "histogram" _time_begin: str = simvue_timestamp(timestamp_begin) _time_end: str = simvue_timestamp(timestamp_end) _response = sv_get( @@ -215,6 +225,7 @@ def delete(self, **kwargs) -> dict[str, typing.Any]: ------ NotImplementedError as event set deletion not supported + """ raise NotImplementedError("Cannot delete event set") diff --git a/simvue/api/objects/filter/base.py b/simvue/api/objects/filter/base.py index 306f3498..637298be 100644 --- a/simvue/api/objects/filter/base.py +++ b/simvue/api/objects/filter/base.py @@ -1,10 +1,10 @@ """Base Filter object for RestAPI queries.""" -import abc -from collections.abc import Generator -import typing import enum import json +import typing +from collections.abc import Generator + import pydantic as pyd from simvue.utilities import prettify_pydantic @@ -15,7 +15,7 @@ try: from typing import Self except ImportError: - from typing_extensions import Self # noqa: UP035 + from typing_extensions import Self class Time(str, enum.Enum): @@ -27,25 +27,30 @@ class Time(str, enum.Enum): Ended = "ended" -class RestAPIFilter(abc.ABC): +class RestAPIFilter: """RestAPI query filter object.""" def __init__(self, simvue_object: "type[SimvueObject] | None" = None) -> None: """Initialise a query object using a Simvue object class.""" - self._sv_object: "type[SimvueObject] | None" = simvue_object + self._sv_object: type[SimvueObject] | None = simvue_object self._filters: list[str] = [] def _time_within( - self, time_type: Time, *, hours: int = 0, days: int = 0, years: int = 0 + self, + time_type: Time, + *, + hours: int = 0, + days: int = 0, + years: int = 0, ) -> Self: """Define filter using time range.""" - if len(_non_zero := list(i for i in (hours, days, years) if i != 0)) > 1: + if len(_non_zero := [i for i in (hours, days, years) if i != 0]) > 1: raise AssertionError( - "Only one duration type may be provided: hours, days or years" + "Only one duration type may be provided: hours, days or years", ) if len(_non_zero) < 1: raise AssertionError( - f"No duration provided for filter '{time_type.value}_within'" + f"No duration provided for filter '{time_type.value}_within'", ) if hours: @@ -124,14 +129,14 @@ def exclude_metadata_attribute(self, attribute: str) -> Self: @prettify_pydantic @pyd.validate_call - def has_metadata_value(self, attribute: str, value: str | float | int) -> Self: + def has_metadata_value(self, attribute: str, value: str | float) -> Self: """Filter by the value of a metadata attribute.""" self._filters.append(f"metadata.{attribute} == {value}") return self @prettify_pydantic @pyd.validate_call - def exclude_metadata_value(self, attribute: str, value: str | float | int) -> Self: + def exclude_metadata_value(self, attribute: str, value: str | float) -> Self: """Veto by the value of a metadata attribute.""" self._filters.append(f"metadata.{attribute} != {value}") return self @@ -139,7 +144,9 @@ def exclude_metadata_value(self, attribute: str, value: str | float | int) -> Se @prettify_pydantic @pyd.validate_call def has_metadata_value_greater_than( - self, attribute: str, value: float | int + self, + attribute: str, + value: float, ) -> Self: """Filter by the value of a metadata value threshold.""" self._filters.append(f"metadata.{attribute} > {value}") @@ -147,7 +154,7 @@ def has_metadata_value_greater_than( @prettify_pydantic @pyd.validate_call - def has_metadata_value_less_than(self, attribute: str, value: float | int) -> Self: + def has_metadata_value_less_than(self, attribute: str, value: float) -> Self: """Filter by the value of a metadata value threshold.""" self._filters.append(f"metadata.{attribute} < {value}") return self @@ -155,7 +162,9 @@ def has_metadata_value_less_than(self, attribute: str, value: float | int) -> Se @prettify_pydantic @pyd.validate_call def has_metadata_value_greater_than_or_equal_to( - self, attribute: str, value: float | int + self, + attribute: str, + value: float, ) -> Self: """Filter by the value of a metadata value threshold.""" self._filters.append(f"metadata.{attribute} >= {value}") @@ -164,7 +173,9 @@ def has_metadata_value_greater_than_or_equal_to( @prettify_pydantic @pyd.validate_call def has_metadata_value_less_than_or_equal_to( - self, attribute: str, value: float | int + self, + attribute: str, + value: float, ) -> Self: """Filter by the value of a metadata value threshold.""" self._filters.append(f"metadata.{attribute} <= {value}") @@ -181,7 +192,10 @@ def get( raise RuntimeError("No object type associated with filter.") _filters: str = json.dumps(self._filters) return self._sv_object.get( - count=count, offset=offset, filters=_filters, **kwargs + count=count, + offset=offset, + filters=_filters, + **kwargs, ) def count(self, **kwargs) -> int: diff --git a/simvue/api/objects/filter/folder.py b/simvue/api/objects/filter/folder.py index a1ff94e2..b4019029 100644 --- a/simvue/api/objects/filter/folder.py +++ b/simvue/api/objects/filter/folder.py @@ -1,15 +1,17 @@ """Simvue RestAPI Folders Filter.""" import typing + import pydantic as pyd try: from typing import Self except ImportError: - from typing_extensions import Self # noqa: UP035 + from typing_extensions import Self from simvue.models import FOLDER_REGEX from simvue.utilities import prettify_pydantic + from .base import RestAPIFilter @@ -19,7 +21,8 @@ class FoldersFilter(RestAPIFilter): @prettify_pydantic @pyd.validate_call def has_path( - self, name: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)] + self, + name: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)], ) -> Self: """Check if a folder has the given path.""" self._filters.append(f"path == {name}") diff --git a/simvue/api/objects/filter/run.py b/simvue/api/objects/filter/run.py index 75d6d7ef..a92a710e 100644 --- a/simvue/api/objects/filter/run.py +++ b/simvue/api/objects/filter/run.py @@ -1,13 +1,14 @@ """Simvue RestAPI Runs Filter.""" import typing -import semver + import pydantic as pyd +import semver try: from typing import Self except ImportError: - from typing_extensions import Self # noqa: UP035 + from typing_extensions import Self from simvue.models import FOLDER_REGEX from simvue.utilities import prettify_pydantic @@ -17,10 +18,15 @@ try: from typing import override except ImportError: - from typing_extensions import override # noqa: UP035 + from typing_extensions import override Status = typing.Literal[ - "lost", "failed", "completed", "terminated", "running", "created" + "lost", + "failed", + "completed", + "terminated", + "running", + "created", ] @@ -135,7 +141,8 @@ def ended_within( @prettify_pydantic @pyd.validate_call def in_folder( - self, folder_path: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)] + self, + folder_path: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)], ) -> Self: """Filter by whether run is within the given folder.""" self._filters.append(f"folder.path == {folder_path}") @@ -151,7 +158,8 @@ def in_folder_containing(self, folder_path: str) -> Self: @prettify_pydantic @pyd.validate_call def exclude_in_folder( - self, folder_path: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)] + self, + folder_path: typing.Annotated[str, pyd.Field(pattern=FOLDER_REGEX)], ) -> Self: """Filter by whether run is not within the given folder.""" self._filters.append(f"folder.path != {folder_path}") @@ -188,7 +196,10 @@ def exclude_hostname(self, hostname: str) -> Self: @prettify_pydantic @pyd.validate_call def has_cpu( - self, *, architecture: str | None = None, processor: str | None = None + self, + *, + architecture: str | None = None, + processor: str | None = None, ) -> Self: """Filter by CPU architecture and processor.""" if architecture: @@ -200,7 +211,10 @@ def has_cpu( @prettify_pydantic @pyd.validate_call def exclude_cpu( - self, *, architecture: str | None = None, processor: str | None = None + self, + *, + architecture: str | None = None, + processor: str | None = None, ) -> Self: """Veto by CPU architecture and processor.""" if architecture: @@ -226,7 +240,10 @@ def has_gpu(self, *, name: str | None = None, processor: str | None = None) -> S @prettify_pydantic @pyd.validate_call def exclude_gpu( - self, *, name: str | None = None, processor: str | None = None + self, + *, + name: str | None = None, + processor: str | None = None, ) -> Self: """Veto by GPU name or processor.""" if name: @@ -242,7 +259,7 @@ def has_python_version(self, python_version: str) -> Self: _ = semver.Version.parse(python_version) except ValueError as e: raise ValueError( - f"'{python_version}' is not a valid semantic version." + f"'{python_version}' is not a valid semantic version.", ) from e self._filters.append(f"system.pythonversion == {python_version}") return self @@ -254,7 +271,7 @@ def exclude_python_version(self, python_version: str) -> Self: _ = semver.Version.parse(python_version) except ValueError as e: raise ValueError( - f"'{python_version}' is not a valid semantic version." + f"'{python_version}' is not a valid semantic version.", ) from e self._filters.append(f"system.pythonversion != {python_version}") return self @@ -262,7 +279,11 @@ def exclude_python_version(self, python_version: str) -> Self: @prettify_pydantic @pyd.validate_call def has_platform( - self, platform: str, *, release: str | None = None, version: str | None = None + self, + platform: str, + *, + release: str | None = None, + version: str | None = None, ) -> Self: """Filter by simulation host platform.""" self._filters.append(f"system.platform.system == {platform}") @@ -275,18 +296,22 @@ def has_platform( @prettify_pydantic @pyd.validate_call def exclude_platform( - self, platform: str, *, release: str | None = None, version: str | None = None + self, + platform: str, + *, + release: str | None = None, + version: str | None = None, ) -> Self: """Veto by simulation host platform. If platform is specified then results WITHOUT this platform are returned. - However if a version and/or release is given then results WITH the given platform - but NOT the given release/version are returned. + However if a version and/or release is given then results WITH + the given platform but NOT the given release/version are returned. """ self._filters.append( "system.platform.system " + "!=" if not release and not version - else "==" + " " + platform + else "==" + " " + platform, ) if release: self._filters.append(f"system.platform.release != {release}") diff --git a/simvue/api/objects/folder.py b/simvue/api/objects/folder.py index 03943336..5d880f54 100644 --- a/simvue/api/objects/folder.py +++ b/simvue/api/objects/folder.py @@ -1,5 +1,4 @@ -""" -Simvue Server Folder +"""Simvue Server Folder. ==================== Contains a class for remotely connecting to a Simvue folder, or defining @@ -7,22 +6,24 @@ """ -import http -import typing import datetime +import http import json +import typing +from collections.abc import Generator import pydantic from simvue.api.objects.filter import FoldersFilter +from simvue.api.request import get_json_from_response +from simvue.api.request import put as sv_put from simvue.exception import ObjectNotFoundError -from simvue.api.request import put as sv_put, get_json_from_response +from simvue.models import DATETIME_FORMAT, FOLDER_REGEX -from .base import SimvueObject, staging_check, write_only, Sort -from simvue.models import FOLDER_REGEX, DATETIME_FORMAT -from collections.abc import Generator +from .base import SimvueObject, Sort, staging_check, write_only -# Need to use this inside of Generator typing to fix bug present in Python 3.10 - see issue #745 +# Need to use this inside of Generator typing to +# fix bug present in Python 3.10 - see issue #745 try: from typing import Self, override except ImportError: @@ -41,7 +42,7 @@ class FolderSort(Sort): def check_column(cls, column: str) -> str: if ( column - and column not in ("created", "modified", "path") + and column not in {"created", "modified", "path"} and not column.startswith("metadata.") ): raise ValueError(f"Invalid sort column for folders '{column}") @@ -63,7 +64,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Folder + """Initialise a Folder. If an identifier is provided a connection will be made to the object matching the identifier on the target server. @@ -79,9 +80,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_token=server_token, server_url=server_url, **kwargs + identifier, + server_token=server_token, + server_url=server_url, + **kwargs, ) self._properties.remove("tree") @@ -96,13 +101,13 @@ def new( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> Self: - """Create a new Folder on the Simvue server with the given path""" + """Create a new Folder on the Simvue server with the given path.""" return cls( path=path, - _read_only=False, _offline=offline, server_url=server_url, server_token=server_token, + _read_only=False, **kwargs, ) @@ -116,7 +121,7 @@ def get( sorting: list[FolderSort] | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, T | None]]: """Get folders from the server. @@ -132,19 +137,28 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional initialisation arguments Yields ------ tuple[str, Folder] id of run Folder object representing object on server + """ _params: dict[str, str] = kwargs if sorting: _params["sorting"] = json.dumps([i.to_params() for i in sorting]) - return super().get(count=count, offset=offset, **_params) + return super().get( + server_url=server_url, + server_token=server_token, + count=count, + offset=offset, + **_params, + ) @classmethod def filter(cls) -> FoldersFilter: @@ -167,10 +181,11 @@ def tree(self) -> dict[str, object]: ------- dict a nested dictionary describing the hierarchy + """ _level: int = len(self.path.split("/")) _folders = self.__class__.get( - filters=json.dumps([f"path contains {self.path}"]) + filters=json.dumps([f"path contains {self.path}"]), ) _paths = [folder.path.split("/") for _, folder in _folders] _paths = sorted(_paths, key=len) @@ -188,84 +203,84 @@ def tree(self) -> dict[str, object]: @property @staging_check def tags(self) -> list[str]: - """Return list of tags assigned to this folder""" + """Return list of tags assigned to this folder.""" return self._get_attribute("tags") @tags.setter @write_only @pydantic.validate_call def tags(self, tags: list[str]) -> None: - """Set tags assigned to this folder""" + """Set tags assigned to this folder.""" self._staging["tags"] = tags @property def path(self) -> str: - """Return the path of this folder""" + """Return the path of this folder.""" return self._get_attribute("path") @property @staging_check def description(self) -> str | None: - """Return the folder description""" + """Return the folder description.""" return self._get().get("description") @description.setter @write_only @pydantic.validate_call def description(self, description: str) -> None: - """Update the folder description""" + """Update the folder description.""" self._staging["description"] = description @property @staging_check def name(self) -> str | None: - """Return the folder name""" + """Return the folder name.""" return self._get().get("name") @name.setter @write_only @pydantic.validate_call def name(self, name: str) -> None: - """Update the folder name""" + """Update the folder name.""" self._staging["name"] = name @property @staging_check - def metadata(self) -> dict[str, int | str | None | float | dict] | None: - """Return the folder metadata""" + def metadata(self) -> dict[str, int | str | float | dict | None] | None: + """Return the folder metadata.""" return self._get().get("metadata") @metadata.setter @write_only @pydantic.validate_call - def metadata(self, metadata: dict[str, int | str | None | float | dict]) -> None: - """Update the folder metadata""" + def metadata(self, metadata: dict[str, int | float | str | dict | None]) -> None: + """Update the folder metadata.""" self._staging["metadata"] = metadata @property @staging_check def star(self) -> bool: - """Return if this folder is starred""" + """Return if this folder is starred.""" return self._get().get("starred", False) @star.setter @write_only @pydantic.validate_call def star(self, is_true: bool = True) -> None: - """Star this folder as a favourite""" + """Star this folder as a favourite.""" self._staging["starred"] = is_true @property @staging_check def ttl(self) -> int: - """Return the retention period for this folder""" + """Return the retention period for this folder.""" return self._get_attribute("ttl") @ttl.setter @write_only @pydantic.validate_call def ttl(self, time_seconds: int) -> None: - """Update the retention period for this folder""" + """Update the retention period for this folder.""" self._staging["ttl"] = time_seconds def delete( # should params to this be optional and default to False? @@ -276,16 +291,18 @@ def delete( # should params to this be optional and default to False? runs_only: bool | None = False, ) -> dict[str, typing.Any]: return super().delete( - recursive=recursive, runs=delete_runs, runs_only=runs_only + recursive=recursive, + runs=delete_runs, + runs_only=runs_only, ) @property def created(self) -> datetime.datetime | None: - """Retrieve created datetime for the run""" + """Retrieve created datetime for the run.""" _created: str | None = self._get_attribute("created") return ( datetime.datetime.strptime(_created, DATETIME_FORMAT).replace( - tzinfo=datetime.timezone.utc + tzinfo=datetime.timezone.utc, ) if _created else None @@ -295,7 +312,9 @@ def _set_favourite(self, *, starred: bool) -> dict: """Set starred status.""" _url = self.url / "starred" _response = sv_put( - f"{_url}", headers=self._user_config.headers, data={"starred": starred} + f"{_url}", + headers=self._user_config.headers, + data={"starred": starred}, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], @@ -318,4 +337,4 @@ def get_folder_from_path( if not _folder: raise ObjectNotFoundError(obj_type="folder", name=path) - return _folder # type: ignore + return _folder diff --git a/simvue/api/objects/grids.py b/simvue/api/objects/grids.py index de0f156f..1d3caca2 100644 --- a/simvue/api/objects/grids.py +++ b/simvue/api/objects/grids.py @@ -6,24 +6,29 @@ """ import http -import msgpack -import numpy import typing - -import pydantic - -from simvue.api.url import URL -from simvue.models import GridMetricSet from collections.abc import Generator +import msgpack +import numpy as np +import pydantic -from .base import SimvueObject, write_only from simvue.api.request import ( get as sv_get, - put as sv_put, - post as sv_post, +) +from simvue.api.request import ( get_json_from_response, ) +from simvue.api.request import ( + post as sv_post, +) +from simvue.api.request import ( + put as sv_put, +) +from simvue.api.url import URL +from simvue.models import GridMetricSet + +from .base import SimvueObject, write_only try: from typing import Self, override @@ -34,15 +39,15 @@ def check_ordered_array( - axis_ticks: list[list[float]] | numpy.ndarray, + axis_ticks: list[list[float]] | np.ndarray, ) -> list[list[float]]: """Returns if array is ordered or reverse ordered.""" - if isinstance(axis_ticks, numpy.ndarray): + if isinstance(axis_ticks, np.ndarray): axis_ticks = axis_ticks.tolist() for i, _array in enumerate(axis_ticks): - _array = numpy.array(_array) - if not numpy.all(numpy.sort(_array) == _array) or numpy.all( - reversed(numpy.sort(_array)) == _array + _array = np.array(_array) + if not np.all(np.sort(_array) == _array) or np.all( + reversed(np.sort(_array)) == _array, ): raise ValueError(f"Axis {i} has unordered values.") return axis_ticks @@ -64,7 +69,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Grid + """Initialise a Grid. If an identifier is provided a connection will be made to the object matching the identifier on the target server. @@ -80,6 +85,7 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( identifier, @@ -96,7 +102,7 @@ def attach_metric_for_run(self, run_id: str, metric_name: str) -> None: self._staging.setdefault("runs", []) self._staging["runs"].append((run_id, metric_name)) super().commit() - return + return None _response = sv_put( url=f"{self.run_data_url(run_id)}", @@ -120,6 +126,7 @@ def on_reconnect(self, id_mapping: dict[str, str]) -> None: ---------- id_mapping : dict[str, str] mapping from offline identifier to new online identifier. + """ _online_runs = ( (id_mapping[run_id], metric_name) @@ -130,7 +137,9 @@ def on_reconnect(self, id_mapping: dict[str, str]) -> None: try: self.attach_metric_for_run(run_id=run_id, metric_name=metric_name) except KeyError: - raise RuntimeError("Failed to retrieve online run identifier.") + raise RuntimeError( + "Failed to retrieve online run identifier.", + ) from None @property def grid(self) -> list[list[float]]: @@ -149,7 +158,9 @@ def new( grid: typing.Annotated[ list[list[float]], pydantic.conlist( - pydantic.conlist(float, min_length=1), min_length=1, max_length=2 + pydantic.conlist(float, min_length=1), + min_length=1, + max_length=2, ), pydantic.AfterValidator(check_ordered_array), ], @@ -157,7 +168,7 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new Grid on the Simvue server. @@ -176,17 +187,19 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional initialsation arguments Returns ------- Metrics metrics object - """ + """ if len(labels) != len(grid): raise AssertionError( "Length of argument 'labels' must match first " - f"grid dimension {len(grid)}." + f"grid dimension {len(grid)}.", ) return cls( @@ -208,20 +221,24 @@ def dimensions(self) -> tuple[int, int]: def run_data_url(self, run_id: str) -> URL: """Returns the URL for grid data for a specific run.""" return URL( - f"{self._user_config.server.url}/runs/{run_id}/grids/{self._identifier}" + f"{self._user_config.server.url}/runs/{run_id}/grids/{self._identifier}", ) def run_metric_url(self, run_id: str, metric_name: str) -> URL: """Returns the URL for the values for a given run metric.""" return URL( - f"{self._user_config.server.url}/runs/{run_id}/metrics/{metric_name}/" + f"{self._user_config.server.url}/runs/{run_id}/metrics/{metric_name}/", ) @pydantic.validate_call def get_run_metric_values( - self, *, run_id: str, metric_name: str, step: int + self, + *, + run_id: str, + metric_name: str, + step: int, ) -> dict: - """Retrieve values for this grid from the server for a given run at a given step. + """Retrieve values for grid given run at a given step. Parameters ---------- @@ -233,9 +250,10 @@ def get_run_metric_values( time step to retrieve values for. Returns - ------ + ------- dict[str, list[dict[str, float]] dictionary containing values from this for the run at specified step. + """ _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'values'}", @@ -268,6 +286,7 @@ def get_run_metric_span(self, *, run_id: str, metric_name: str) -> dict: ------- dict[str, list[dict[str, float]] dictionary containing span from this for the run at specified step. + """ _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'span'}", @@ -313,7 +332,7 @@ def __init__( self, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> None: """Initialise a GridMetrics object instance. @@ -323,9 +342,15 @@ def __init__( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval + """ super().__init__( - identifier=None, server_url=server_url, server_token=server_token, **kwargs + identifier=None, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._run_id = self._staging.get("run") self._is_set = True @@ -353,7 +378,7 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new GridMetrics object for n-dimensional metric submission. @@ -369,11 +394,14 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialsation Returns ------- Metrics metrics object + """ return cls( run=run, @@ -398,7 +426,7 @@ def get( spans: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[dict[str, dict[str, list[dict[str, float]]]]]: """Retrieve tensor-metrics from the server for a given set of runs. @@ -420,11 +448,14 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for object retrieval Yields ------ dict[str, dict[str, list[dict[str, float]]] metric set object containing metrics for run. + """ for metric in metrics: for run in runs: @@ -439,16 +470,17 @@ def get( def commit(self) -> dict | None: if not (_run_staging := self._staging.pop("data", None)): - return + return None return self._log_values(_run_staging) def on_reconnect(self, id_mapping: dict[str, str]) -> None: - """Operations performed when this grid metrics object is switched from offline to online mode. + """Operations performed when grid metrics object switched mode switched. Parameters ---------- id_mapping : dict[str, str] mapping from offline identifier to new online identifier. + """ metrics = self._staging.pop("data", []) @@ -470,7 +502,7 @@ def _log_values(self, metrics: list[GridMetricSet]) -> None: self._staging.setdefault("data", []) self._staging["data"] += metrics super().commit() - return + return None _response = sv_post( url=f"{self._user_config.server.url}/{self.run_grids_endpoint(self._run_id)}", diff --git a/simvue/api/objects/metrics.py b/simvue/api/objects/metrics.py index 12b20f2c..0fedbc31 100644 --- a/simvue/api/objects/metrics.py +++ b/simvue/api/objects/metrics.py @@ -6,16 +6,17 @@ """ import http -import typing import json +import typing +from collections.abc import Generator import pydantic -from collections.abc import Generator +from simvue.api.request import get as sv_get +from simvue.api.request import get_json_from_response +from simvue.models import MetricSet from .base import SimvueObject -from simvue.models import MetricSet -from simvue.api.request import get as sv_get, get_json_from_response try: from typing import Self, override @@ -53,9 +54,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier=None, server_url=server_url, server_token=server_token, **kwargs + identifier=identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._run_id = self._staging.get("run") self._is_set = True @@ -71,7 +76,7 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new Metrics entry on the Simvue server. @@ -87,11 +92,14 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for object creation Returns ------- Metrics metrics object + """ return cls( run=run, @@ -115,7 +123,7 @@ def get( offset: pydantic.PositiveInt | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[dict[str, dict[str, list[dict[str, float]]]]]: """Retrieve metrics from the server for a given set of runs. @@ -138,11 +146,14 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ dict[str, dict[str, list[dict[str, float]]] metric set object containing metrics for run. + """ yield from cls._get_all_objects( offset=offset, @@ -157,8 +168,8 @@ def get( @pydantic.validate_call def span(self, run_ids: list[str]) -> dict[str, int | float]: - """Returns the metrics span for the given runs""" - _url = self._base_url / "span" + """Returns the metrics span for the given runs.""" + _url = self.base_url / "span" _response = sv_get(url=f"{_url}", headers=self._headers, json=run_ids) return get_json_from_response( response=_response, @@ -168,10 +179,12 @@ def span(self, run_ids: list[str]) -> dict[str, int | float]: @pydantic.validate_call def names(self, run_ids: list[str]) -> list[str]: - """Returns the metric names for the given runs""" - _url = self._base_url / "names" + """Returns the metric names for the given runs.""" + _url = self.base_url / "names" _response = sv_get( - url=f"{_url}", headers=self._headers, params={"runs": json.dumps(run_ids)} + url=f"{_url}", + headers=self._headers, + params={"runs": json.dumps(run_ids)}, ) return get_json_from_response( response=_response, @@ -184,7 +197,7 @@ def _post_single(self, **kwargs) -> dict[str, typing.Any]: return super()._post_single(is_json=False, **kwargs) def delete(self, **kwargs) -> dict[str, typing.Any]: - """Metrics cannot be deleted""" + """Metrics cannot be deleted.""" raise NotImplementedError("Cannot delete metric set") def on_reconnect(self, id_mapping: dict[str, str]): @@ -206,5 +219,6 @@ def to_dict(self) -> dict[str, typing.Any]: ------- dict[str, Any] dictionary representation of metrics object. + """ return self._staging diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index b4d894f8..e1baa9c0 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -1,5 +1,4 @@ -""" -Simvue Runs +"""Simvue Runs. =========== Contains a class for remotely connecting to Simvue runs, or defining @@ -7,42 +6,54 @@ """ -from collections.abc import Generator, Iterable +import datetime import http +import json +import time import typing +from collections.abc import Generator, Iterable + import pydantic -import datetime -import time -import json try: from typing import Self, override except ImportError: from typing_extensions import Self, override +from simvue.api.request import ( + get as sv_get, +) +from simvue.api.request import ( + get_json_from_response, +) +from simvue.api.request import ( + put as sv_put, +) +from simvue.api.url import URL +from simvue.models import DATETIME_FORMAT, FOLDER_REGEX, NAME_REGEX, simvue_timestamp + from .base import ( ObjectBatchArgs, - VisibilityBatchArgs, SimvueObject, Sort, - staging_check, Visibility, + VisibilityBatchArgs, + staging_check, write_only, ) from .filter import RunsFilter -from simvue.api.request import ( - get as sv_get, - put as sv_put, - get_json_from_response, -) -from simvue.api.url import URL -from simvue.models import FOLDER_REGEX, NAME_REGEX, DATETIME_FORMAT, simvue_timestamp Status = typing.Literal[ - "lost", "failed", "completed", "terminated", "running", "created" + "lost", + "failed", + "completed", + "terminated", + "running", + "created", ] -# Need to use this inside of Generator typing to fix bug present in Python 3.10 - see issue #745 +# Need to use this inside of Generator typing to +# fix bug present in Python 3.10 - see issue #745 T = typing.TypeVar("T", bound="Run") __all__ = ["Run"] @@ -57,7 +68,7 @@ def check_column(cls, column: str) -> str: and column != "name" and not column.startswith("metrics") and not column.startswith("metadata.") - and column not in ("created", "started", "endtime", "modified") + and column not in {"created", "started", "endtime", "modified"} ): raise ValueError(f"Invalid sort column for runs '{column}'") @@ -72,7 +83,12 @@ class RunBatchArgs(ObjectBatchArgs): folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] | None = None system: dict[str, typing.Any] | None = None status: typing.Literal[ - "terminated", "created", "failed", "completed", "lost", "running" + "terminated", + "created", + "failed", + "completed", + "lost", + "running", ] = "created" @@ -108,10 +124,14 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ self.visibility = Visibility(self) super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @classmethod @@ -135,12 +155,17 @@ def new( folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)], system: dict[str, typing.Any] | None = None, status: typing.Literal[ - "terminated", "created", "failed", "completed", "lost", "running" + "terminated", + "created", + "failed", + "completed", + "lost", + "running", ] = "created", offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new Run on the Simvue server. @@ -158,6 +183,8 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialisation Returns ------- @@ -166,7 +193,6 @@ def new( Examples -------- - ```python run = Run.new( folder="/", @@ -176,6 +202,7 @@ def new( ) run.commit() ``` + """ return cls( folder=folder, @@ -200,7 +227,7 @@ def batch_create( metadata: dict[str, str | int | float | bool] | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[str]: """Create a batch of Runs as a single request. @@ -211,18 +238,22 @@ def batch_create( visibility : VisibilityBatchArgs | None, optional specify visibility options for these runs, default is None. folder : str, optional - override folder specification for these runs to be a single folder, default None. + override folder specification for these runs to be + a single folder, default None. metadata : dict[str, int | str | float | bool], optional override metadata specification for these runs, default None. server_url: str | None, optional alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for object creation Yields ------ str identifiers for created runs + """ _data: list[dict[str, object]] = [ entry.model_dump(exclude_none=True) @@ -256,6 +287,7 @@ def name(self) -> str: Returns ------- str + """ return self._get_attribute("name") @@ -263,7 +295,8 @@ def name(self) -> str: @write_only @pydantic.validate_call def name( - self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] + self, + name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)], ) -> None: self._staging["name"] = name @@ -275,6 +308,7 @@ def tags(self) -> list[str]: Returns ------- list[str] + """ return self._get_attribute("tags") @@ -292,6 +326,7 @@ def status(self) -> Status: Returns ------- "lost" | "failed" | "completed" | "terminated" | "running" | "created" + """ return self._get_attribute("status") @@ -309,6 +344,7 @@ def ttl(self) -> int: Returns ------- int + """ return self._get_attribute("ttl") @@ -326,6 +362,7 @@ def folder(self) -> str: Returns ------- str + """ return self._get_attribute("folder") @@ -333,7 +370,8 @@ def folder(self) -> str: @write_only @pydantic.validate_call def folder( - self, folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] + self, + folder: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)], ) -> None: self._staging["folder"] = folder @@ -345,6 +383,7 @@ def metadata(self) -> dict[str, typing.Any]: Returns ------- dict[str, Any] + """ return self._get_attribute("metadata") @@ -367,6 +406,7 @@ def description(self) -> str: Returns ------- str + """ return self._get_attribute("description") @@ -383,6 +423,7 @@ def system(self) -> dict[str, typing.Any]: Returns ------- dict[str, Any] + """ return self._get_attribute("system") @@ -400,6 +441,7 @@ def heartbeat_timeout(self) -> int | None: Returns ------- int | None + """ return self._get_attribute("heartbeat_timeout") @@ -417,6 +459,7 @@ def notifications(self) -> typing.Literal["none", "all", "error", "lost"]: Returns ------- "none" | "all" | "error" | "lost" + """ return self._get_attribute("notifications")["state"] @@ -424,7 +467,8 @@ def notifications(self) -> typing.Literal["none", "all", "error", "lost"]: @write_only @pydantic.validate_call def notifications( - self, notifications: typing.Literal["none", "all", "error", "lost"] + self, + notifications: typing.Literal["none", "all", "error", "lost"], ) -> None: self._staging["notifications"] = {"state": notifications} @@ -436,6 +480,7 @@ def alerts(self) -> list[str]: Returns ------- list[str] + """ if self._offline: return self._get_attribute("alerts") @@ -469,12 +514,15 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ tuple[str, Run] id of run Run object representing object on server + """ _params: dict[str, str] = kwargs @@ -506,10 +554,12 @@ def get_alert_details(self) -> Generator[dict[str, typing.Any]]: Returns ------- Generator[dict[str, Any], None, None] + """ if self._offline: raise RuntimeError( - "Cannot get alert details from an offline run - use .alerts to access a list of IDs instead" + "Cannot get alert details from an offline run - " + "use .alerts to access a list of IDs instead", ) for alert in self._get_attribute("alerts"): yield alert["alert"] @@ -517,15 +567,20 @@ def get_alert_details(self) -> Generator[dict[str, typing.Any]]: @property @staging_check def created(self) -> datetime.datetime | None: - """Set/retrieve created datetime for the run. + """Set/retrieve created datetime in UTC for the run. Returns ------- datetime.datetime + """ _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + datetime.UTC, + ) + if _created + else None ) @created.setter @@ -537,7 +592,7 @@ def created(self, created: datetime.datetime) -> None: @property @staging_check def runtime(self) -> datetime.datetime | None: - """Retrieve execution time for the run""" + """Retrieve execution time for the run.""" _runtime: str | None = self._get_attribute("runtime") return time.strptime(_runtime, "%H:%M:%S.%f") if _runtime else None @@ -549,11 +604,12 @@ def started(self) -> datetime.datetime | None: Returns ------- datetime.datetime + """ _started: str | None = self._get_attribute("started") return ( datetime.datetime.strptime(_started, DATETIME_FORMAT).replace( - tzinfo=datetime.timezone.utc + tzinfo=datetime.timezone.utc, ) if _started else None @@ -568,21 +624,23 @@ def started(self, started: datetime.datetime) -> None: @property @staging_check def star(self) -> bool: - """Return if this folder is starred""" + """Return if this folder is starred.""" return self._get().get("starred", False) @star.setter @write_only @pydantic.validate_call def star(self, is_true: bool = True) -> None: - """Star this folder as a favourite""" + """Star this folder as a favourite.""" self._staging["starred"] = is_true def _set_favourite(self, *, starred: bool) -> dict: """Set starred status.""" _url = self.url / "starred" _response = sv_put( - f"{_url}", headers=self._user_config.headers, data={"starred": starred} + f"{_url}", + headers=self.user_config.headers, + data={"starred": starred}, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], @@ -598,11 +656,12 @@ def endtime(self) -> datetime.datetime | None: Returns ------- datetime.datetime + """ _endtime: str | None = self._get_attribute("endtime") return ( datetime.datetime.strptime(_endtime, DATETIME_FORMAT).replace( - tzinfo=datetime.timezone.utc + tzinfo=datetime.timezone.utc, ) if _endtime else None @@ -628,6 +687,7 @@ def metrics( Returns ------- Generator[tuple[str, dict[str, int | float | bool]] + """ yield from self._get_attribute("metrics").items() @@ -645,6 +705,7 @@ def events( Returns ------- Generator[tuple[str, dict[str, Any]] + """ yield from self._get_attribute("events").items() @@ -668,7 +729,7 @@ def send_heartbeat(self) -> dict[str, typing.Any] | None: _heartbeat_file.touch() return None - _url = self._base_url + _url = self.base_url _url /= f"{self._identifier}/heartbeat" _response = sv_put(f"{_url}", headers=self._headers, data={}) return get_json_from_response( @@ -705,6 +766,7 @@ def abort_trigger(self) -> bool: ------- bool the current state of the abort trigger + """ if self._offline or not self._identifier: return False @@ -725,6 +787,7 @@ def artifacts(self) -> list[dict[str, typing.Any]]: ------- list[dict[str, Any]] the artifacts associated with this run + """ if self._offline or not self._artifact_url: return [] @@ -746,6 +809,7 @@ def grids(self) -> list[dict[str, str]]: ------- list[dict[str, str]] the grids associated with this run + """ if self._offline or not self._grid_url: return [] @@ -778,7 +842,9 @@ def abort(self, reason: str) -> dict[str, typing.Any]: raise RuntimeError("Cannot abort run, no endpoint defined") _response = sv_put( - f"{self._abort_url}", headers=self._headers, data={"reason": reason} + f"{self._abort_url}", + headers=self._headers, + data={"reason": reason}, ) return get_json_from_response( @@ -794,13 +860,14 @@ def on_reconnect(self, id_mapping: dict[str, str]) -> None: ---------- id_mapping: dict[str, str] A mapping from offline identifier to online identifier. + """ online_alert_ids: list[str | None] = list( - set( + { id_mapping.get(_id) for _id in self._staging.get("alerts", []) if _id.startswith("offline") - ) + }, ) if not all(online_alert_ids): raise KeyError("Could not find alert ID in offline to online ID mapping.") diff --git a/simvue/api/objects/stats.py b/simvue/api/objects/stats.py index a121d4db..de4099f2 100644 --- a/simvue/api/objects/stats.py +++ b/simvue/api/objects/stats.py @@ -7,13 +7,15 @@ import http import typing -from pydantic import BaseModel import pydantic +from pydantic import BaseModel -from .base import SimvueObject -from simvue.api.request import get as sv_get, get_json_from_response +from simvue.api.request import get as sv_get +from simvue.api.request import get_json_from_response from simvue.api.url import URL +from .base import SimvueObject + __all__ = ["Stats"] @@ -54,10 +56,13 @@ def __init__( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + """ self.runs = RunStatistics(self) super().__init__( - identifier=None, server_url=server_url, server_token=server_token + identifier=None, + server_url=server_url, + server_token=server_token, ) # Stats is a singular object (i.e. identifier is not applicable) @@ -65,22 +70,24 @@ def __init__( self._identifier = "" @classmethod - def new(cls, **kwargs) -> None: + def new(cls, **_) -> None: """Creation of multiple stats objects is not logical here. Raises ------ AttributeError + """ raise AttributeError("Creation of statistics objects is not supported") @classmethod - def delete(cls, **kwargs) -> None: + def delete(cls, **_) -> None: """Deletion of stats object is not logical here. Raises ------ AttributeError + """ raise AttributeError("Deletion of statistics is not supported") @@ -90,6 +97,7 @@ def read_only(self) -> None: Raises ------ NotImplementedError + """ raise NotImplementedError("Statistics are not modifiable.") @@ -99,32 +107,34 @@ def id(self) -> None: Returns ------- None + """ - return None + return def on_reconnect(self, **_) -> None: """No offline to online reconnect functionality for statistics.""" - pass @classmethod - def get(cls, **kwargs) -> None: + def get(cls, **_) -> None: """Retrieval of multiple stats object is not logical here. Raises ------ AttributeError + """ raise AttributeError( - "Retrieval of multiple of statistics objects is not supported" + "Retrieval of multiple of statistics objects is not supported", ) @classmethod - def ids(cls, **kwargs) -> None: + def ids(cls, **_) -> None: """Retrieval of identifiers is not logical here. Raises ------ AttributeError + """ raise AttributeError("Retrieval of ids for statistics objects is not supported") @@ -135,6 +145,7 @@ def whoami(self) -> dict[str, str]: ------- dict[str, str] server response for 'whomai' query. + """ _url: URL = URL(self._user_config.server.url) / "whoami" _response = sv_get(url=f"{_url}", headers=self._headers) @@ -144,16 +155,16 @@ def whoami(self) -> dict[str, str]: scenario="Retrieving current user", ) - def _get_run_stats(self) -> dict[str, int]: - """Retrieve the run statistics""" + def get_run_stats(self) -> dict[str, int]: + """Retrieve the run statistics.""" return self._get_attribute("runs") def _get_local_staged(self) -> dict[str, typing.Any]: - """No staging for stats so returns empty dict""" + """No staging for stats so returns empty dict.""" return {} def _get_visibility(self) -> dict[str, bool | list[str]]: - """Visibility does not apply here""" + """Visibility does not apply here.""" return {} def to_dict(self) -> dict[str, typing.Any]: @@ -163,52 +174,53 @@ def to_dict(self) -> dict[str, typing.Any]: ------- dict[str, Any] statistics data as dictionary + """ - return {"runs": self._get_run_stats()} + return {"runs": self.get_run_stats()} def admin_stats(self, *, tenant: str | None = None) -> dict[str, dict[str, int]]: return { name: UserStatistics(**entry) for name, entry in self._get( - single=False, **({"tenant": tenant} if tenant else {}) + single=False, + **({"tenant": tenant} if tenant else {}), ).items() } def commit(self) -> None: """Does nothing, no data sendable to server.""" - pass class RunStatistics: - """Interface to the run section of statistics output""" + """Interface to the run section of statistics output.""" def __init__(self, sv_obj: Stats) -> None: self._sv_obj = sv_obj @property def created(self) -> int: - """Number of created runs""" - if (_created := self._sv_obj._get_run_stats().get("created")) is None: + """Number of created runs.""" + if (_created := self._sv_obj.get_run_stats().get("created")) is None: raise RuntimeError("Expected key 'created' in run statistics retrieval") return _created @property def running(self) -> int: - """Number of running runs""" - if (_running := self._sv_obj._get_run_stats().get("running")) is None: + """Number of running runs.""" + if (_running := self._sv_obj.get_run_stats().get("running")) is None: raise RuntimeError("Expected key 'running' in run statistics retrieval") return _running @property def completed(self) -> int: - """Number of completed runs""" - if (_completed := self._sv_obj._get_run_stats().get("running")) is None: + """Number of completed runs.""" + if (_completed := self._sv_obj.get_run_stats().get("running")) is None: raise RuntimeError("Expected key 'completed' in run statistics retrieval") return _completed @property def data(self) -> int: - """Data count""" - if (_data := self._sv_obj._get_run_stats().get("running")) is None: + """Data count.""" + if (_data := self._sv_obj.get_run_stats().get("running")) is None: raise RuntimeError("Expected key 'data' in run statistics retrieval") return _data diff --git a/simvue/api/objects/storage/__init__.py b/simvue/api/objects/storage/__init__.py index 01034513..6a7ea965 100644 --- a/simvue/api/objects/storage/__init__.py +++ b/simvue/api/objects/storage/__init__.py @@ -6,8 +6,8 @@ """ +from .fetch import Storage from .file import FileStorage from .s3 import S3Storage -from .fetch import Storage __all__ = ["FileStorage", "S3Storage", "Storage"] diff --git a/simvue/api/objects/storage/base.py b/simvue/api/objects/storage/base.py index 750451a5..e828b64e 100644 --- a/simvue/api/objects/storage/base.py +++ b/simvue/api/objects/storage/base.py @@ -1,17 +1,16 @@ -""" -Simvue Storage Base +"""Simvue Storage Base. =================== Contains general definitions for Simvue Storage objects. """ +import datetime import typing import pydantic -import datetime from simvue.api.objects.base import SimvueObject, staging_check, write_only -from simvue.models import NAME_REGEX, DATETIME_FORMAT +from simvue.models import DATETIME_FORMAT, NAME_REGEX try: from typing import Self, override @@ -42,79 +41,90 @@ def __init__( ) -> None: """Retrieve a storage instance from the Simvue server by identifier.""" super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @classmethod def new( - cls, *, server_url: str | None, server_token: pydantic.SecretStr | None, **_ + cls, + *, + server_url: str | None, + server_token: pydantic.SecretStr | None, + **_, ) -> Self: - """Create a new instance of a storage type""" - pass + """Create a new instance of a storage type.""" @property @staging_check def name(self) -> str: - """Retrieve the name for this storage""" + """Retrieve the name for this storage.""" return self._get_attribute("name") @name.setter @write_only @pydantic.validate_call def name( - self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] + self, + name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)], ) -> None: - """Set name assigned to this folder""" + """Set name assigned to this folder.""" self._staging["name"] = name @property def backend(self) -> str: - """Retrieve the backend of storage""" + """Retrieve the backend of storage.""" return self._get_attribute("backend") @property @staging_check def is_default(self) -> bool: - """Retrieve if this is the default storage for the user""" + """Retrieve if this is the default storage for the user.""" return self._get_attribute("is_default") @is_default.setter @write_only @pydantic.validate_call def is_default(self, is_default: bool) -> None: - """Set this storage to be the default""" + """Set this storage to be the default.""" self._staging["is_default"] = is_default @property @staging_check def is_tenant_useable(self) -> bool: - """Retrieve if this is usable by the current user tenant""" + """Retrieve if this is usable by the current user tenant.""" return self._get_attribute("is_tenant_useable") @is_tenant_useable.setter @write_only @pydantic.validate_call def is_tenant_useable(self, is_tenant_useable: bool) -> None: - """Set this storage to be usable by the current user tenant""" + """Set this storage to be usable by the current user tenant.""" self._staging["is_tenant_useable"] = is_tenant_useable @property @staging_check def is_enabled(self) -> bool: - """Retrieve if this is enabled""" + """Retrieve if this is enabled.""" return self._get_attribute("is_enabled") @is_enabled.setter @write_only @pydantic.validate_call def is_enabled(self, is_enabled: bool) -> None: - """Set this storage to be usable by the current user tenant""" + """Set this storage to be usable by the current user tenant.""" self._staging["is_enabled"] = is_enabled @property def created(self) -> datetime.datetime | None: - """Retrieve created datetime for the artifact""" + """Retrieve created datetime in UTC for the artifact.""" _created: str | None = self._get_attribute("created") return ( - datetime.datetime.strptime(_created, DATETIME_FORMAT) if _created else None + datetime.datetime.strptime(_created, DATETIME_FORMAT).astimezone( + tz=datetime.UTC, + ) + if _created + else None ) diff --git a/simvue/api/objects/storage/fetch.py b/simvue/api/objects/storage/fetch.py index 1468e9f3..7dd4cfe1 100644 --- a/simvue/api/objects/storage/fetch.py +++ b/simvue/api/objects/storage/fetch.py @@ -5,15 +5,17 @@ """ import http +import typing +from collections.abc import Generator + import pydantic -from simvue.api.request import get_json_from_response from simvue.api.request import get as sv_get -from collections.abc import Generator +from simvue.api.request import get_json_from_response -from .s3 import S3Storage -from .file import FileStorage from .base import StorageBase +from .file import FileStorage +from .s3 import S3Storage class Storage: @@ -29,7 +31,7 @@ def __new__( *, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> S3Storage | FileStorage: """Retrieve an object representing on the server by id. @@ -41,17 +43,35 @@ def __new__( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialisation Returns ------- S3Storage | FileStorage object representing storage + """ - _storage_pre = StorageBase(identifier=identifier, **kwargs) + _storage_pre = StorageBase( + server_token=server_token, + server_url=server_url, + identifier=identifier, + **kwargs, + ) if _storage_pre.backend == "S3": - return S3Storage(identifier=identifier, **kwargs) - elif _storage_pre.backend == "File": - return FileStorage(identifier=identifier, **kwargs) + return S3Storage( + server_token=server_token, + server_url=server_url, + identifier=identifier, + **kwargs, + ) + if _storage_pre.backend == "File": + return FileStorage( + server_token=server_token, + server_url=server_url, + identifier=identifier, + **kwargs, + ) raise RuntimeError(f"Unknown backend '{_storage_pre.backend}'") @@ -64,7 +84,7 @@ def get( offset: int | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, FileStorage | S3Storage]]: """Returns storage systems accessible to the current user. @@ -78,14 +98,16 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for retrieval Yields ------ tuple[str, FileStorage | S3Storage] identifier for a storage the storage itself as a class instance - """ + """ # Currently no storage filters _ = kwargs.pop("filters", None) @@ -95,10 +117,10 @@ def get( server_token=server_token, _local=True, ) - _url = f"{_class_instance._base_url}" + _url = f"{_class_instance.base_url}" _response = sv_get( _url, - headers=_class_instance._headers, + headers=_class_instance.user_config.headers, params={"start": offset, "count": count} | kwargs, ) _label: str = _class_instance.__class__.__name__.lower() @@ -138,5 +160,5 @@ def get( ) else: raise RuntimeError( - f"Unrecognised storage backend '{_entry['backend']}'" + f"Unrecognised storage backend '{_entry['backend']}'", ) diff --git a/simvue/api/objects/storage/file.py b/simvue/api/objects/storage/file.py index 1e8396ba..dcdb0646 100644 --- a/simvue/api/objects/storage/file.py +++ b/simvue/api/objects/storage/file.py @@ -12,9 +12,10 @@ from typing_extensions import Self, override import pydantic -from .base import StorageBase from simvue.models import NAME_REGEX +from .base import StorageBase + class FileStorage(StorageBase): """Simvue File Storage. @@ -33,11 +34,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a File Storage + """Initialise a File Storage. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new FileStorage instance will be created using arguments provided in kwargs. + Else a new FileStorage instance will be created using arguments + provided in kwargs. Parameters ---------- @@ -49,9 +51,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @override @@ -95,6 +101,7 @@ def new( ------- FileStorage instance of storage system with staged changes + """ return cls( name=name, diff --git a/simvue/api/objects/storage/s3.py b/simvue/api/objects/storage/s3.py index 0d7e709c..7541182d 100644 --- a/simvue/api/objects/storage/s3.py +++ b/simvue/api/objects/storage/s3.py @@ -12,10 +12,10 @@ from typing_extensions import Self, override import pydantic -from simvue.api.objects.base import write_only, staging_check +from simvue.api.objects.base import staging_check, write_only +from simvue.models import NAME_REGEX from .base import StorageBase -from simvue.models import NAME_REGEX class S3Storage(StorageBase): @@ -35,11 +35,12 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a S3 Storage + """Initialise a S3 Storage. If an identifier is provided a connection will be made to the object matching the identifier on the target server. - Else a new S3Storage instance will be created using arguments provided in kwargs. + Else a new S3Storage instance will be created using arguments + provided in kwargs. Parameters ---------- @@ -51,10 +52,14 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ self.config = Config(self) super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) self._local_only_args += [ "endpoint_url", @@ -127,7 +132,7 @@ def new( """ _config: dict[str, str] = { - "endpoint_url": endpoint_url.__str__(), + "endpoint_url": str(endpoint_url), "access_key_id": access_key_id, "secret_access_key": secret_access_key.get_secret_value(), "bucket": bucket, @@ -152,12 +157,12 @@ def new( _offline=offline, _read_only=False, ) - _storage._staging |= _config + _storage.append_to_staging(_config) return _storage @staging_check def get_config(self) -> dict[str, typing.Any]: - """Retrieve configuration""" + """Retrieve configuration.""" try: return self._get_attribute("config") except AttributeError: @@ -165,10 +170,10 @@ def get_config(self) -> dict[str, typing.Any]: class Config: - """S3 Configuration interface""" + """S3 Configuration interface.""" def __init__(self, storage: S3Storage) -> None: - """Initialise a new configuration using an S3Storage object""" + """Initialise a new configuration using an S3Storage object.""" self._sv_obj = storage @property @@ -180,55 +185,56 @@ def endpoint_url(self) -> str: ------- str the endpoint for this storage object + """ try: return self._sv_obj.get_config()["endpoint_url"] except KeyError as e: raise RuntimeError( - "Expected key 'endpoint_url' in alert definition retrieval" + "Expected key 'endpoint_url' in alert definition retrieval", ) from e @endpoint_url.setter @write_only @pydantic.validate_call def endpoint_url(self, endpoint_url: pydantic.HttpUrl) -> None: - _config = self._sv_obj.get_config() | {"endpoint_url": endpoint_url.__str__()} - self._sv_obj._staging["config"] = _config + _config = self._sv_obj.get_config() | {"endpoint_url": str(endpoint_url)} + self._sv_obj.append_to_staging({"config": _config}) @property @staging_check def region_name(self) -> str | None: - """Retrieve the region name for this storage""" + """Retrieve the region name for this storage.""" return self._sv_obj.get_config().get("region_name") @region_name.setter @write_only @pydantic.validate_call def region_name(self, region_name: str) -> None: - """Modify the region name for this storage""" + """Modify the region name for this storage.""" _config = self._sv_obj.get_config() | {"region_name": region_name} - self._sv_obj._staging["config"] = _config + self._sv_obj.append_to_staging({"config": _config}) @property @staging_check def bucket(self) -> str: - """Retrieve the bucket label for this storage""" + """Retrieve the bucket label for this storage.""" try: return self._sv_obj.get_config()["bucket"] except KeyError as e: raise RuntimeError( - "Expected key 'bucket' in alert definition retrieval" + "Expected key 'bucket' in alert definition retrieval", ) from e @bucket.setter @write_only @pydantic.validate_call def bucket(self, bucket: str) -> None: - """Modify the bucket label for this storage""" + """Modify the bucket label for this storage.""" if self._sv_obj.type == "file": raise ValueError( - f"Cannot set attribute 'bucket' for storage type '{self._sv_obj.type}'" + f"Cannot set attribute 'bucket' for storage type '{self._sv_obj.type}'", ) _config = self._sv_obj.get_config() | {"bucket": bucket} - self._sv_obj._staging["config"] = _config + self._sv_obj.append_to_staging({"config": _config}) diff --git a/simvue/api/objects/tag.py b/simvue/api/objects/tag.py index a7fc15ba..9b1034bb 100644 --- a/simvue/api/objects/tag.py +++ b/simvue/api/objects/tag.py @@ -5,15 +5,16 @@ """ -import typing -import json import datetime +import json +import typing +from collections.abc import Generator + import pydantic import pydantic_extra_types.color as pyd_color from simvue.api.objects.base import SimvueObject, Sort, staging_check, write_only from simvue.models import DATETIME_FORMAT -from collections.abc import Generator try: from typing import Self, override @@ -27,7 +28,7 @@ class TagSort(Sort): @pydantic.field_validator("column") @classmethod def check_column(cls, column: str) -> str: - if column and column not in ("created", "name"): + if column and column not in {"created", "name"}: raise ValueError(f"Invalid sort column for tags '{column}") return column @@ -48,7 +49,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, **kwargs, ) -> None: - """Initialise a Tag + """Initialise a Tag. If an identifier is provided a connection will be made to the object matching the identifier on the target server. @@ -64,9 +65,13 @@ def __init__( token for alternative server, default None **kwargs : dict any additional arguments to be passed to the object initialiser + """ super().__init__( - identifier, server_url=server_url, server_token=server_token, **kwargs + identifier, + server_url=server_url, + server_token=server_token, + **kwargs, ) @override @@ -79,7 +84,7 @@ def new( offline: bool = False, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Self: """Create a new Tag on the Simvue server. @@ -93,11 +98,14 @@ def new( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for initialisation Returns ------- Tag tag object with staged attributes + """ _data: dict[str, typing.Any] = {"name": name} return cls( @@ -112,49 +120,49 @@ def new( @property @staging_check def name(self) -> str: - """Retrieve the tag name""" + """Retrieve the tag name.""" return self._get_attribute("name") @name.setter @write_only @pydantic.validate_call def name(self, name: str) -> None: - """Set the tag name""" + """Set the tag name.""" self._staging["name"] = name @property @staging_check def colour(self) -> pyd_color.RGBA: - """Retrieve the tag colour""" + """Retrieve the tag colour.""" return pyd_color.parse_str(self._get_attribute("colour")) @colour.setter @write_only @pydantic.validate_call def colour(self, colour: pyd_color.Color) -> None: - """Set the tag colour""" + """Set the tag colour.""" self._staging["colour"] = colour.as_hex() @property @staging_check def description(self) -> str: - """Get description for this tag""" + """Get description for this tag.""" return self._get_attribute("description") @description.setter @write_only @pydantic.validate_call def description(self, description: str) -> None: - """Set the description for this tag""" + """Set the description for this tag.""" self._staging["description"] = description @property def created(self) -> datetime.datetime | None: - """Retrieve created datetime for the run""" + """Retrieve created datetime for the run.""" _created: str | None = self._get_attribute("created") return ( datetime.datetime.strptime(_created, DATETIME_FORMAT).replace( - tzinfo=datetime.timezone.utc + tzinfo=datetime.timezone.utc, ) if _created else None @@ -171,7 +179,7 @@ def get( sorting: list[TagSort] | None = None, server_url: str | None = None, server_token: pydantic.SecretStr | None = None, - **kwargs, + **kwargs: typing.Any, ) -> Generator[tuple[str, "SimvueObject"]]: """Get tags from the server. @@ -187,12 +195,15 @@ def get( alternative server URL, default None server_token : str | None, optional token for alternative server, default None + **kwargs : Any + additional arguments for the request Yields ------ tuple[str, Tag] id of tag Tag object representing object on server + """ # There are currently no tag filters _ = kwargs.pop("filters", None) diff --git a/simvue/api/request.py b/simvue/api/request.py index 5a568588..6ca8bd84 100644 --- a/simvue/api/request.py +++ b/simvue/api/request.py @@ -1,5 +1,4 @@ -""" -Simvue API Connection +"""Simvue API Connection. ===================== Provides methods for interacting with a Simvue server which include retry @@ -8,10 +7,11 @@ """ import copy +import http import json as json_module -import typing import logging -import http +import typing +from collections.abc import Generator import requests from tenacity import ( @@ -20,8 +20,8 @@ stop_after_attempt, wait_exponential, ) + from simvue.utilities import parse_validation_response -from collections.abc import Generator DEFAULT_API_TIMEOUT = 10 RETRY_MULTIPLIER = 1 @@ -31,11 +31,12 @@ MAX_ENTRIES_PER_PAGE: int = 100 RETRY_STATUSES = {502, 503, 504} +logger = logging.getLogger(__name__) + def set_json_header(headers: dict[str, str]) -> dict[str, str]: - """ - Return a copy of the headers with Content-Type set to - application/json + """Return a copy of the headers with Content-Type set to + application/json. """ headers = copy.deepcopy(headers) headers["Content-Type"] = "application/json" @@ -54,12 +55,13 @@ class RetryableHTTPError(Exception): RetryableHTTPError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, - ) + ), ), reraise=True, ) def post( url: str, + *, headers: dict[str, str], params: dict[str, str], data: typing.Any, @@ -67,7 +69,7 @@ def post( timeout: int | None = None, files: dict[str, typing.Any] | None = None, ) -> requests.Response: - """HTTP POST with retries + """HTTP POST with retries. Parameters ---------- @@ -81,6 +83,10 @@ def post( data to post is_json : bool, optional send as JSON string, by default True + timeout : int | None = None + timeout for the request + files : dict[str, Any] | None = None + file data for this request Returns ------- @@ -106,12 +112,13 @@ def post( if response.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY: _parsed_response = parse_validation_response(response.json()) raise ValueError( - f"Validation error for '{url}' [{response.status_code}]:\n{_parsed_response}" + f"Validation error for '{url}' " + f"[{response.status_code}]:\n{_parsed_response}", ) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( - f"Received status code {response.status_code} from server" + f"Received status code {response.status_code} from server", ) return response @@ -124,20 +131,21 @@ def post( RetryableHTTPError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, - ) + ), ), stop=stop_after_attempt(RETRY_STOP), reraise=True, ) def put( url: str, + *, headers: dict[str, str], data: dict[str, typing.Any] | None = None, json: dict[str, typing.Any] | None = None, is_json: bool = True, timeout: int = DEFAULT_API_TIMEOUT, ) -> requests.Response: - """HTTP PUT with retries + """HTTP PUT with retries. Parameters ---------- @@ -158,6 +166,7 @@ def put( ------- requests.Response response from executing PUT + """ if is_json and data: data_sent: str | dict[str, typing.Any] = json_module.dumps(data) @@ -165,15 +174,19 @@ def put( else: data_sent = data - logging.debug(f"PUT: {url}\n\tdata={data_sent}\n\tjson={json}") + logger.debug("PUT: %s\n\tdata=%s\n\tjson=%s", url, data_sent, json) response = requests.put( - url, headers=headers, data=data_sent, timeout=timeout, json=json + url, + headers=headers, + data=data_sent, + timeout=timeout, + json=json, ) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( - f"Received status code {response.status_code} from server" + f"Received status code {response.status_code} from server", ) return response @@ -186,7 +199,7 @@ def put( RetryableHTTPError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, - ) + ), ), stop=stop_after_attempt(RETRY_STOP), reraise=True, @@ -198,14 +211,16 @@ def get( timeout: int = DEFAULT_API_TIMEOUT, json: dict[str, typing.Any] | None = None, ) -> requests.Response: - """HTTP GET + """HTTP GET. Parameters ---------- url : str URL to put to headers : dict[str, str] - headers for the post request + headers for the get request + params : dict[str, Any] + additional parameters for request timeout : int, optional timeout of request, by default DEFAULT_API_TIMEOUT json : dict[str, Any] | None, optional @@ -215,15 +230,20 @@ def get( ------- requests.Response response from executing GET + """ - logging.debug(f"GET: {url}\n\tparams={params}") + logger.debug("GET: %s\n\tparams=%s", url, params) response = requests.get( - url, headers=headers, timeout=timeout, params=params, json=json + url, + headers=headers, + timeout=timeout, + params=params, + json=json, ) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( - f"Received status code {response.status_code} from server" + f"Received status code {response.status_code} from server", ) return response @@ -236,7 +256,7 @@ def get( RetryableHTTPError, requests.exceptions.Timeout, requests.exceptions.ConnectionError, - ) + ), ), stop=stop_after_attempt(RETRY_STOP), reraise=True, @@ -247,7 +267,7 @@ def delete( timeout: int = DEFAULT_API_TIMEOUT, params: dict[str, typing.Any] | None = None, ) -> requests.Response: - """HTTP DELETE + """HTTP DELETE. Parameters ---------- @@ -264,24 +284,26 @@ def delete( ------- requests.Response response from executing DELETE + """ - logging.debug(f"DELETE: {url}\n\tparams={params}") + logger.debug("DELETE: %s\n\tparams=%s", url, params) response = requests.delete(url, headers=headers, timeout=timeout, params=params) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( - f"Received status code {response.status_code} from server" + f"Received status code {response.status_code} from server", ) return response def get_json_from_response( + *, expected_status: list[int], scenario: str, response: requests.Response, allow_parse_failure: bool = False, - expected_type: typing.Type[dict | list] = dict, + expected_type: type[dict | list] = dict, ) -> dict | list: try: json_response = response.json() @@ -296,7 +318,10 @@ def get_json_from_response( if (_status_code := response.status_code) in expected_status: if not isinstance(json_response, expected_type): - details = f"expected type '{expected_type.__name__}' but got '{type(json_response).__name__}'" + details = ( + f"expected type '{expected_type.__name__}' " + f"but got '{type(json_response).__name__}'" + ) elif json_response is not None: return json_response else: @@ -320,11 +345,12 @@ def get_json_from_response( def get_paginated( url: str, + *, headers: dict[str, str] | None = None, timeout: int = DEFAULT_API_TIMEOUT, json: dict[str, typing.Any] | None = None, - offset: int | None = None, count: int | None = None, + offset: int | None = None, **params, ) -> Generator[requests.Response]: """Paginate results of a server query. @@ -339,11 +365,18 @@ def get_paginated( timeout of request, by default DEFAULT_API_TIMEOUT json : dict[str, Any] | None, optional any json to send in request + count: int | None, optional + limit number of objects + offset : int | None, optional + set start index for objects list + **params: Any + additional parameters for request Yield ----- requests.Response server response + """ _offset: int = offset or 0 @@ -371,5 +404,6 @@ def get_paginated( break except json_module.JSONDecodeError: raise RuntimeError( - f"[{_response.status_code}] Failed to retrieve content from server: {_response.text}" - ) + f"[{_response.status_code}] Failed to retrieve content from server: " + + _response.text, + ) from None diff --git a/simvue/api/url.py b/simvue/api/url.py index 92d6a8fd..231e8d70 100644 --- a/simvue/api/url.py +++ b/simvue/api/url.py @@ -1,5 +1,4 @@ -""" -URL Library +"""URL Library. =========== Module contains classes for easier handling of URLs. @@ -10,8 +9,8 @@ from typing import Self except ImportError: from typing_extensions import Self -import urllib.parse import copy +import urllib.parse import pydantic @@ -21,8 +20,8 @@ class URL: @pydantic.validate_call def __init__(self, url: str) -> None: - """Initialise a url from string form""" - url = url[:-1] if url.endswith("/") else url + """Initialise a url from string form.""" + url = url.removesuffix("/") _url = urllib.parse.urlparse(url) self._scheme: str = _url.scheme @@ -32,21 +31,21 @@ def __init__(self, url: str) -> None: self._fragment: str = _url.fragment def __truediv__(self, other: str) -> Self: - """Define URL extension through use of '/'""" + """Define URL extension through use of '/'.""" _new = copy.deepcopy(self) _new /= other return _new def __repr__(self) -> str: - """Representation of URL""" + """Representation of URL.""" _out_str = f"{self.__class__.__module__}.{self.__class__.__qualname__}" return f"{_out_str}(url={self.__str__()!r})" @pydantic.validate_call def __itruediv__(self, other: str) -> Self: - """Define URL extension through use of '/'""" - other = other[1:] if other.startswith("/") else other - other = other[:-1] if other.endswith("/") else other + """Define URL extension through use of '/'.""" + other = other.removeprefix("/") + other = other.removesuffix("/") self._path = f"{self._path}/{other}" if other else self._path return self @@ -72,7 +71,7 @@ def port(self) -> int | None: return self._port def __str__(self) -> str: - """Construct string form of the URL""" + """Construct string form of the URL.""" _out_str: str = "" if self.scheme: _out_str += f"{self.scheme}://" diff --git a/simvue/bin/sender.py b/simvue/bin/sender.py index 73e778de..8c433939 100644 --- a/simvue/bin/sender.py +++ b/simvue/bin/sender.py @@ -2,10 +2,10 @@ import logging import pathlib -import click -from simvue.sender import Sender, UPLOAD_ORDER, UploadItem +import click +from simvue.sender import UPLOAD_ORDER, Sender, UploadItem _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -26,7 +26,8 @@ type=int, required=False, default=10, - help="The number of objects of a given type above which items will be sent to the server in parallel, by default 10", + help="The number of objects of a given type above which items will be " + "sent to the server in parallel, by default 10", ) @click.option( "-o", @@ -67,5 +68,5 @@ def sender_cli( ) _sender.upload(objects_to_upload) except Exception as err: - _logger.critical("Exception running sender: %s", str(err)) - raise click.Abort + _logger.critical("Exception running sender: %s", err) + raise click.Abort from None diff --git a/simvue/client.py b/simvue/client.py index ef5ceef4..c6562b2f 100644 --- a/simvue/client.py +++ b/simvue/client.py @@ -5,52 +5,54 @@ """ import contextlib +import http import json import logging import pathlib import typing -import http -import pydantic -from concurrent.futures import ThreadPoolExecutor, as_completed from collections.abc import Generator -from pandas import DataFrame +from concurrent.futures import ThreadPoolExecutor, as_completed +import pydantic import requests +from pandas import DataFrame from simvue.api.objects.alert.base import AlertBase from simvue.exception import ObjectNotFoundError -from .converters import ( - aggregated_metrics_to_dataframe, - to_dataframe, - parse_run_set_metrics, -) -from .serialization import deserialize_data -from .simvue_types import DeserializedContent -from .utilities import check_extra, prettify_pydantic -from .models import FOLDER_REGEX, NAME_REGEX -from .config.user import SimvueConfiguration -from .api.request import get_json_from_response from .api.objects import ( - Run, - Folder, - Tag, - Artifact, Alert, + Artifact, FileArtifact, + Folder, ObjectArtifact, + Run, + Tag, get_folder_from_path, ) - +from .api.request import DEFAULT_API_TIMEOUT, get_json_from_response +from .config.user import SimvueConfiguration +from .converters import ( + aggregated_metrics_to_dataframe, + parse_run_set_metrics, + to_dataframe, +) +from .models import FOLDER_REGEX, NAME_REGEX +from .serialization import deserialize_data +from .utilities import check_extra, prettify_pydantic CONCURRENT_DOWNLOADS = 10 DOWNLOAD_CHUNK_SIZE = 8192 -logger = logging.getLogger(__file__) +if typing.TYPE_CHECKING: + from .simvue_types import DeserializedContent + +logger = logging.getLogger(__name__) def _download_artifact_to_file( - artifact: FileArtifact | ObjectArtifact, output_dir: pathlib.Path | None + artifact: FileArtifact | ObjectArtifact, + output_dir: pathlib.Path | None, ) -> None: if not artifact.name: raise RuntimeError(f"Expected artifact '{artifact.id}' to have a name") @@ -71,7 +73,7 @@ def __init__( server_token: pydantic.SecretStr | None = None, server_url: str | None = None, ) -> None: - """Initialise an instance of the Simvue client + """Initialise an instance of the Simvue client. Parameters ---------- @@ -79,17 +81,21 @@ def __init__( specify token, if unset this is read from the config file server_url : str, optional specify URL, if unset this is read from the config file + """ self._user_config = SimvueConfiguration.fetch( - server_token=server_token, server_url=server_url, mode="online" + server_token=server_token, + server_url=server_url, + mode="online", ) for label, value in zip( ("URL", "API token"), (self._user_config.server.url, self._user_config.server.url), + strict=True, ): if not value: - logger.warning(f"No {label} specified") + logger.warning("No %s specified", label) self._headers: dict[str, str] = self._user_config.headers | { "Accept-Encoding": "gzip", @@ -98,9 +104,10 @@ def __init__( @prettify_pydantic @pydantic.validate_call def get_run_id_from_name( - self, name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] + self, + name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)], ) -> str: - """Get Run ID from the server matching the specified name + """Get Run ID from the server matching the specified name. Assumes a unique name for this run. If multiple results are found this method will fail. @@ -120,6 +127,7 @@ def get_run_id_from_name( RuntimeError if either information could not be retrieved from the server, or multiple/no runs are found + """ _runs = Run.get(filters=json.dumps([f"name == {name}"])) @@ -130,7 +138,7 @@ def get_run_id_from_name( if next(_runs, None): raise RuntimeError( - "Could not collect ID - more than one run exists with this name." + "Could not collect ID - more than one run exists with this name.", ) return _id @@ -138,7 +146,7 @@ def get_run_id_from_name( @prettify_pydantic @pydantic.validate_call def get_run(self, run_id: str) -> Run | None: - """Retrieve a single run + """Retrieve a single run. Parameters ---------- @@ -154,6 +162,7 @@ def get_run(self, run_id: str) -> Run | None: ------ RuntimeError if retrieval of information from the server on this run failed + """ return Run( identifier=run_id, @@ -165,7 +174,7 @@ def get_run(self, run_id: str) -> Run | None: @prettify_pydantic @pydantic.validate_call def get_run_name_from_id(self, run_id: str) -> str: - """Retrieve the name of a run from its identifier + """Retrieve the name of a run from its identifier. Parameters ---------- @@ -176,6 +185,7 @@ def get_run_name_from_id(self, run_id: str) -> str: ------- str the registered name for the run + """ return Run( identifier=run_id, @@ -260,6 +270,7 @@ def get_runs( if a value outside of 'dict' or 'dataframe' is specified RuntimeError if there was a failure in data retrieval from the server + """ filters = filters or [] if not show_shared: @@ -278,7 +289,10 @@ def get_runs( return_metadata=metadata, server_url=self._user_config.server.url, server_token=self._user_config.server.token, - sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + sorting=[ + dict(zip(("column", "descending"), a, strict=True)) + for a in sort_by_columns + ] if sort_by_columns else None, ) @@ -297,7 +311,7 @@ def get_runs( @prettify_pydantic @pydantic.validate_call def delete_run(self, run_id: str) -> dict | None: - """Delete run by identifier + """Delete run by identifier. Parameters ---------- @@ -313,6 +327,7 @@ def delete_run(self, run_id: str) -> dict | None: ------ RuntimeError if the deletion failed due to server request error + """ return ( Run( @@ -324,7 +339,7 @@ def delete_run(self, run_id: str) -> dict | None: ) def _get_folder_from_path(self, path: str) -> Folder | None: - """Retrieve folder for the specified path if found + """Retrieve folder for the specified path if found. Parameters ---------- @@ -335,6 +350,7 @@ def _get_folder_from_path(self, path: str) -> Folder | None: ------- Folder | None if a match is found, return the folder + """ _folders = Folder.get( filters=json.dumps([f"path == {path}"]), @@ -347,7 +363,7 @@ def _get_folder_from_path(self, path: str) -> Folder | None: return _folder def _get_folder_id_from_path(self, path: str) -> str | None: - """Retrieve folder identifier for the specified path if found + """Retrieve folder identifier for the specified path if found. Parameters ---------- @@ -358,6 +374,7 @@ def _get_folder_id_from_path(self, path: str) -> str | None: ------- str | None if a match is found, return the identifier of the folder + """ _ids = Folder.ids( filters=json.dumps([f"path == {path}"]), @@ -370,7 +387,7 @@ def _get_folder_id_from_path(self, path: str) -> str | None: if next(_ids, None): raise RuntimeError( - f"Expected single folder match for '{path}', but found duplicate." + f"Expected single folder match for '{path}', but found duplicate.", ) return _id @@ -378,9 +395,10 @@ def _get_folder_id_from_path(self, path: str) -> str | None: @prettify_pydantic @pydantic.validate_call def delete_runs( - self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)] + self, + folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)], ) -> list | None: - """Delete runs in a named folder + """Delete runs in a named folder. Parameters ---------- @@ -397,6 +415,7 @@ def delete_runs( ------ RuntimeError if deletion fails due to server request error + """ if not (_folder := self._get_folder_from_path(folder_path)): raise ValueError(f"Could not find a folder matching '{folder_path}'") @@ -413,7 +432,7 @@ def delete_folder( remove_runs: bool = False, allow_missing: bool = False, ) -> list | None: - """Delete a folder by name + """Delete a folder by name. Parameters ---------- @@ -437,17 +456,17 @@ def delete_folder( ------ RuntimeError if deletion of the folder from the server failed + """ folder_id = self._get_folder_id_from_path(folder_path) if not folder_id: if allow_missing: return None - else: - raise ObjectNotFoundError( - name=folder_path, - obj_type="folder", - ) + raise ObjectNotFoundError( + name=folder_path, + obj_type="folder", + ) _response = Folder( identifier=folder_id, server_url=self._user_config.server.url, @@ -466,25 +485,28 @@ def delete_folder( @prettify_pydantic @pydantic.validate_call def delete_alert(self, alert_id: str) -> None: - """Delete an alert from the server by ID + """Delete an alert from the server by ID. Parameters ---------- alert_id : str the unique identifier for the alert + """ Alert( identifier=alert_id, server_url=self._user_config.server.url, server_token=self._user_config.server.token, - ).delete() # type: ignore + ).delete() @prettify_pydantic @pydantic.validate_call def list_artifacts( - self, run_id: str, sort_by_columns: list[tuple[str, bool]] | None = None + self, + run_id: str, + sort_by_columns: list[tuple[str, bool]] | None = None, ) -> Generator[Artifact]: - """Retrieve artifacts for a given run + """Retrieve artifacts for a given run. Parameters ---------- @@ -504,18 +526,24 @@ def list_artifacts( ------ RuntimeError if retrieval of artifacts failed when communicating with the server + """ return Artifact.get( runs=json.dumps([run_id]), server_url=self._user_config.server.url, server_token=self._user_config.server.token, - sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + sorting=[ + dict(zip(("column", "descending"), a, strict=True)) + for a in sort_by_columns + ] if sort_by_columns else None, - ) # type: ignore + ) def _retrieve_artifacts_from_server( - self, run_id: str, name: str + self, + run_id: str, + name: str, ) -> FileArtifact | ObjectArtifact | None: return Artifact.from_name( run_id=run_id, @@ -527,7 +555,7 @@ def _retrieve_artifacts_from_server( @prettify_pydantic @pydantic.validate_call def abort_run(self, run_id: str, reason: str) -> dict | list: - """Abort a currently active run on the server + """Abort a currently active run on the server. Parameters ---------- @@ -540,6 +568,7 @@ def abort_run(self, run_id: str, reason: str) -> dict | list: ------- dict | list response from server + """ return Run( identifier=run_id, @@ -550,9 +579,13 @@ def abort_run(self, run_id: str, reason: str) -> dict | list: @prettify_pydantic @pydantic.validate_call def get_artifact( - self, run_id: str, name: str, allow_pickle: bool = False + self, + run_id: str, + name: str, + *, + allow_pickle: bool = False, ) -> typing.Any: - """Return the contents of a specified artifact + """Return the contents of a specified artifact. Parameters ---------- @@ -573,6 +606,7 @@ def get_artifact( ------ RuntimeError if retrieval of artifact from the server failed + """ _artifact = self._retrieve_artifacts_from_server(run_id, name) @@ -586,7 +620,9 @@ def get_artifact( _content = b"".join(_artifact.download_content()) _deserialized_content: DeserializedContent | None = deserialize_data( - _content, _artifact.mime_type, allow_pickle + _content, + _artifact.mime_type, + allow_pickle=allow_pickle, ) # Numpy array return means just 'if content' will be ambiguous @@ -601,7 +637,7 @@ def get_artifact_as_file( name: str, output_dir: pydantic.DirectoryPath | None = None, ) -> None: - """Retrieve the specified artifact in the form of a file + """Retrieve the specified artifact in the form of a file. Information is saved to a file as opposed to deserialized @@ -620,6 +656,7 @@ def get_artifact_as_file( RuntimeError if there was a failure during retrieval of information from the server + """ _artifact = self._retrieve_artifacts_from_server(run_id, name) @@ -640,7 +677,7 @@ def get_artifacts_as_files( category: typing.Literal["input", "output", "code"] | None = None, output_dir: pydantic.DirectoryPath | None = None, ) -> None: - """Retrieve artifacts from the given run as a set of files + """Retrieve artifacts from the given run as a set of files. Parameters ---------- @@ -659,6 +696,7 @@ def get_artifacts_as_files( ------ RuntimeError if there was a failure retrieving artifacts from the server + """ _artifacts: Generator[tuple[str, Artifact]] = Artifact.from_run( server_url=self._user_config.server.url, @@ -668,19 +706,25 @@ def get_artifacts_as_files( ) with ThreadPoolExecutor( - CONCURRENT_DOWNLOADS, thread_name_prefix=f"get_artifacts_run_{run_id}" + CONCURRENT_DOWNLOADS, + thread_name_prefix=f"get_artifacts_run_{run_id}", ) as executor: - futures = [ - executor.submit(_download_artifact_to_file, artifact, output_dir) + future_artifact_mapping = { + executor.submit( + _download_artifact_to_file, + artifact, + output_dir, + ): artifact for _, artifact in _artifacts - ] - for future, (_, artifact) in zip(as_completed(futures), _artifacts): + } + for future in as_completed(future_artifact_mapping): + _artifact = future_artifact_mapping[future] try: future.result() except Exception as e: raise RuntimeError( - f"Download of file {artifact.storage_url} " - f"failed with exception: {e}" + f"Download of file {_artifact.storage_url} " + f"failed with exception: {e}", ) from e @prettify_pydantic @@ -688,9 +732,10 @@ def get_artifacts_as_files( def get_folder( self, folder_path: typing.Annotated[str, pydantic.Field(pattern=FOLDER_REGEX)], + *, read_only: bool = True, ) -> Folder | None: - """Retrieve a folder by identifier + """Retrieve a folder by identifier. Parameters ---------- @@ -711,6 +756,7 @@ def get_folder( ------ RuntimeError if there was a failure when retrieving information from the server + """ try: _folder = get_folder_from_path(path=folder_path) @@ -728,7 +774,7 @@ def get_folders( start_index: pydantic.NonNegativeInt = 0, sort_by_columns: list[tuple[str, bool]] | None = None, ) -> Generator[tuple[str, Folder]]: - """Retrieve folders from the server + """Retrieve folders from the server. Parameters ---------- @@ -752,6 +798,7 @@ def get_folders( ------ RuntimeError if there was a failure retrieving data from the server + """ return Folder.get( filters=json.dumps(filters or []), @@ -759,15 +806,18 @@ def get_folders( offset=start_index, server_url=self._user_config.server.url, server_token=self._user_config.server.token, - sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + sorting=[ + dict(zip(("column", "descending"), a, strict=True)) + for a in sort_by_columns + ] if sort_by_columns else None, - ) # type: ignore + ) @prettify_pydantic @pydantic.validate_call def get_metrics_names(self, run_id: str) -> Generator[str]: - """Return information on all metrics within a run + """Return information on all metrics within a run. Parameters ---------- @@ -783,14 +833,16 @@ def get_metrics_names(self, run_id: str) -> Generator[str]: ------ RuntimeError if there was a failure retrieving information from the server + """ _run = Run(identifier=run_id) - for id, _ in _run.metrics: - yield id + for _id, _ in _run.metrics: + yield _id def _get_run_metrics_from_server( self, + *, metric_names: list[str], run_ids: list[str], xaxis: str, @@ -809,6 +861,7 @@ def _get_run_metrics_from_server( f"{self._user_config.server.url}/metrics", headers=self._headers, params=params, + timeout=DEFAULT_API_TIMEOUT, ) return get_json_from_response( @@ -831,7 +884,7 @@ def get_metric_values( aggregate: bool = False, max_points: pydantic.PositiveInt | None = None, ) -> dict | DataFrame | None: - """Retrieve the values for a given metric across multiple runs + """Retrieve the values for a given metric across multiple runs. Uses filters to specify which runs should be retrieved. @@ -869,6 +922,7 @@ def get_metric_values( dict or DataFrame or None values for the given metric at each time interval if no runs pass filtering then return None + """ if not metric_names: raise ValueError("No metric names were provided") @@ -876,13 +930,13 @@ def get_metric_values( if run_filters and run_ids: raise AssertionError( "Specification of both 'run_ids' and 'run_filters' " - "in get_metric_values is ambiguous" + "in get_metric_values is ambiguous", ) if xaxis == "timestamp" and aggregate: raise AssertionError( "Cannot return metric values with options 'aggregate=True' and " - "'xaxis=timestamp'" + "'xaxis=timestamp'", ) _args = {"filters": json.dumps(run_filters)} if run_filters else {} @@ -902,12 +956,13 @@ def get_metric_values( return None if aggregate: return aggregated_metrics_to_dataframe( - _run_metrics, xaxis=xaxis, parse_to=output_format + _run_metrics, + xaxis=xaxis, + parse_to=output_format, ) if use_run_names: _run_metrics = { - Run(identifier=key).name: _run_metrics[key] - for key in _run_metrics.keys() + Run(identifier=key).name: _run_metrics[key] for key in _run_metrics } return parse_run_set_metrics( _run_metrics, @@ -927,7 +982,7 @@ def plot_metrics( *, max_points: int | None = None, ) -> typing.Any: - """Plt the time series values for multiple metrics/runs + """Plt the time series values for multiple metrics/runs. Parameters ---------- @@ -949,6 +1004,7 @@ def plot_metrics( ------ ValueError if invalid arguments are provided + """ if not isinstance(run_ids, list): raise ValueError("Invalid runs specified, must be a list of run names.") @@ -956,7 +1012,7 @@ def plot_metrics( if not isinstance(metric_names, list): raise ValueError("Invalid names specified, must be a list of metric names.") - data: DataFrame = self.get_metric_values( # type: ignore + data: DataFrame = self.get_metric_values( run_ids=run_ids, metric_names=metric_names, xaxis=xaxis, @@ -967,7 +1023,8 @@ def plot_metrics( if data is None: raise RuntimeError( - f"Cannot plot metrics {metric_names}, no data found for runs {run_ids}." + f"Cannot plot metrics {metric_names}, " + f"no data found for runs {run_ids}.", ) # Undo multi-indexing @@ -1011,7 +1068,7 @@ def get_events( start_index: pydantic.NonNegativeInt | None = None, count_limit: pydantic.PositiveInt | None = None, ) -> list[dict[str, str]]: - """Return events for a specified run + """Return events for a specified run. Parameters ---------- @@ -1033,8 +1090,8 @@ def get_events( ------ RuntimeError if there was a failure retrieving information from the server - """ + """ msg_filter: str = ( json.dumps([f"event.message contains {message_contains}"]) if message_contains @@ -1052,6 +1109,7 @@ def get_events( f"{self._user_config.server.url}/events", headers=self._headers, params=params, + timeout=DEFAULT_API_TIMEOUT, ) json_response = get_json_from_response( @@ -1074,16 +1132,18 @@ def get_alerts( count_limit: pydantic.PositiveInt | None = None, sort_by_columns: list[tuple[str, bool]] | None = None, ) -> list[AlertBase] | list[str | None]: - """Retrieve alerts for a given run + """Retrieve alerts for a given run. Parameters ---------- run_id : str | None The ID of the run to find alerts for critical_only : bool, optional - If a run is specified, whether to only return details about alerts which are currently critical, by default True + If a run is specified, whether to only return details about alerts + which are currently critical, by default True names_only: bool, optional - Whether to only return the names of the alerts (otherwise return the full details of the alerts), by default True + Whether to only return the names of the alerts (otherwise return + the full details of the alerts), by default True start_index : typing.int, optional slice results returning only those above this index, by default None count_limit : typing.int, optional @@ -1102,29 +1162,32 @@ def get_alerts( ------ RuntimeError if there was a failure retrieving data from the server + """ if not run_id: if critical_only: raise RuntimeError( - "critical_only is ambiguous when returning alerts with no run ID specified." + "critical_only is ambiguous when returning alerts " + "with no run ID specified.", ) return [ alert.name if names_only else alert for _, alert in Alert.get( sorting=[ - dict(zip(("column", "descending"), a)) for a in sort_by_columns + dict(zip(("column", "descending"), a, strict=True)) + for a in sort_by_columns ] if sort_by_columns else None, count=count_limit, offset=start_index, ) - ] # type: ignore + ] if sort_by_columns: logger.warning( "Run identifier specified for alert retrieval," - " argument 'sort_by_columns' will be ignored" + " argument 'sort_by_columns' will be ignored", ) _alerts = [ @@ -1152,7 +1215,7 @@ def get_tags( count_limit: pydantic.PositiveInt | None = None, sort_by_columns: list[tuple[str, bool]] | None = None, ) -> Generator[Tag]: - """Retrieve tags + """Retrieve tags. Parameters ---------- @@ -1175,13 +1238,17 @@ def get_tags( ------ RuntimeError if there was a failure retrieving data from the server + """ return Tag.get( count=count_limit, offset=start_index, server_url=self._user_config.server.url, server_token=self._user_config.server.token, - sorting=[dict(zip(("column", "descending"), a)) for a in sort_by_columns] + sorting=[ + dict(zip(("column", "descending"), a, strict=True)) + for a in sort_by_columns + ] if sort_by_columns else None, ) @@ -1189,7 +1256,7 @@ def get_tags( @prettify_pydantic @pydantic.validate_call def delete_tag(self, tag_id: str) -> None: - """Delete a tag by its identifier + """Delete a tag by its identifier. Parameters ---------- @@ -1200,6 +1267,7 @@ def delete_tag(self, tag_id: str) -> None: ------ RuntimeError if the deletion failed due to a server request error + """ with contextlib.suppress(ValueError): Tag( @@ -1211,7 +1279,7 @@ def delete_tag(self, tag_id: str) -> None: @prettify_pydantic @pydantic.validate_call def get_tag(self, tag_id: str) -> Tag: - """Retrieve a single tag + """Retrieve a single tag. Parameters ---------- @@ -1229,6 +1297,7 @@ def get_tag(self, tag_id: str) -> Tag: if retrieval of information from the server on this tag failed ObjectNotFoundError if tag does not exist + """ return Tag( identifier=tag_id, diff --git a/simvue/config/__init__.py b/simvue/config/__init__.py index 28232a84..80f76828 100644 --- a/simvue/config/__init__.py +++ b/simvue/config/__init__.py @@ -1,5 +1,4 @@ -""" -Simvue Configuration +"""Simvue Configuration. ==================== This module contains definitions for the Simvue configuration options diff --git a/simvue/config/files.py b/simvue/config/files.py index f00bc6dc..cb3ae791 100644 --- a/simvue/config/files.py +++ b/simvue/config/files.py @@ -1,5 +1,4 @@ -""" -Simvue Config File Lists +"""Simvue Config File Lists. ======================== Contains lists of valid Simvue configuration file names. diff --git a/simvue/config/parameters.py b/simvue/config/parameters.py index 67f67d79..65410845 100644 --- a/simvue/config/parameters.py +++ b/simvue/config/parameters.py @@ -1,5 +1,4 @@ -""" -Simvue Configuration File Models +"""Simvue Configuration File Models. ================================ Pydantic models for elements of the Simvue configuration file @@ -8,17 +7,17 @@ import logging import os +import pathlib import time -import pydantic import typing -import pathlib + +import pydantic import simvue.models as sv_models -from simvue.utilities import get_expiry from simvue.api.url import URL +from simvue.utilities import get_expiry - -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) class ServerSpecifications(pydantic.BaseModel): @@ -34,18 +33,18 @@ class ServerSpecifications(pydantic.BaseModel): @classmethod def url_to_api_url(cls, v: typing.Any) -> str | None: if not v: - return + return None if f"{v}".endswith("/api"): return f"{v}" _url = URL(f"{v}") / "api" return f"{_url}" @pydantic.field_validator("token") - def check_token(cls, v: typing.Any) -> str | None: + @classmethod + def check_token(cls, v: pydantic.SecretStr | None) -> pydantic.SecretStr | None: if not v: - return - value = v.get_secret_value() - if not (expiry := get_expiry(value)): + return None + if not (expiry := get_expiry(v.get_secret_value())): raise AssertionError("Failed to parse Simvue token - invalid token form") if time.time() - expiry > 0: raise AssertionError("Simvue token has expired") diff --git a/simvue/config/user.py b/simvue/config/user.py index 067298e7..13398749 100644 --- a/simvue/config/user.py +++ b/simvue/config/user.py @@ -1,21 +1,20 @@ -""" -Simvue Configuration File Model +"""Simvue Configuration File Model. =============================== Pydantic model for the Simvue TOML configuration file """ -from collections.abc import Generator import functools +import http import logging import os -import typing -import http import pathlib +import typing + import pydantic -import toml import semver +import toml try: from typing import Self @@ -23,23 +22,22 @@ from typing_extensions import Self import simvue.utilities as sv_util +from simvue.api.request import get as sv_get +from simvue.api.url import URL +from simvue.config.files import ( + CONFIG_FILE_NAMES, + CONFIG_INI_FILE_NAMES, + DEFAULT_OFFLINE_DIRECTORY, +) from simvue.config.parameters import ( ClientGeneralOptions, DefaultRunSpecifications, MetricsSpecifications, - ServerSpecifications, OfflineSpecifications, + ServerSpecifications, ) - -from simvue.config.files import ( - CONFIG_FILE_NAMES, - CONFIG_INI_FILE_NAMES, - DEFAULT_OFFLINE_DIRECTORY, -) -from simvue.version import __version__ -from simvue.api.request import get as sv_get -from simvue.api.url import URL from simvue.eco.config import EcoConfig +from simvue.version import __version__ logger = logging.getLogger(__name__) @@ -47,6 +45,10 @@ SIMVUE_SERVER_LOWER_CONSTRAINT: semver.Version | None = semver.Version.parse("1.1.0") +if typing.TYPE_CHECKING: + from collections.abc import Generator + + class SimvueConfiguration(pydantic.BaseModel): # Hide values as they contain token and URL model_config = pydantic.ConfigDict( @@ -57,10 +59,11 @@ class SimvueConfiguration(pydantic.BaseModel): ) client: ClientGeneralOptions = ClientGeneralOptions() server: ServerSpecifications = pydantic.Field( - ..., description="Specifications for Simvue server" + ..., + description="Specifications for Simvue server", ) profiles: dict[str, ServerSpecifications] = pydantic.Field( - default_factory=dict[str, ServerSpecifications] + default_factory=dict[str, ServerSpecifications], ) run: DefaultRunSpecifications = DefaultRunSpecifications() offline: OfflineSpecifications = OfflineSpecifications() @@ -78,9 +81,10 @@ def server_version(self) -> semver.Version: @classmethod def _load_pyproject_configs(cls) -> dict | None: - """Recover any Simvue non-authentication configurations from pyproject.toml""" + """Recover any Simvue non-authentication configurations from pyproject.toml.""" _pyproject_toml = sv_util.find_first_instance_of_file( - file_names=["pyproject.toml"], check_user_space=False + file_names=["pyproject.toml"], + check_user_space=False, ) if not _pyproject_toml: @@ -100,10 +104,11 @@ def _load_pyproject_configs(cls) -> dict | None: _server_credentials.get("token"), _server_credentials.get("url"), _offline_credentials.get("cache"), - ] + ], ): raise RuntimeError( - "Provision of Simvue URL, Token or offline directory in pyproject.toml is not allowed." + "Provision of Simvue URL, Token or offline directory in " + "pyproject.toml is not allowed.", ) return _simvue_setup @@ -111,9 +116,12 @@ def _load_pyproject_configs(cls) -> dict | None: @classmethod @functools.lru_cache def _check_server( - cls, token: str, url: str, mode: typing.Literal["offline", "online", "disabled"] + cls, + token: str, + url: str, + mode: typing.Literal["offline", "online", "disabled"], ) -> semver.Version | None: - if mode in ("offline", "disabled"): + if mode in {"offline", "disabled"}: return None headers: dict[str, str] = { @@ -134,7 +142,7 @@ def _check_server( except Exception as err: raise AssertionError( - f"Exception retrieving server version:\n {str(err)}" + f"Exception retrieving server version:\n {err!s}", ) from err _version = semver.Version.parse(_version_str) @@ -144,13 +152,14 @@ def _check_server( and _version >= SIMVUE_SERVER_UPPER_CONSTRAINT ): raise AssertionError( - f"Python API v{_version_str} is not compatible with Simvue server versions " - f">= {SIMVUE_SERVER_UPPER_CONSTRAINT}" + f"Python API v{_version_str} is not compatible " + "with Simvue server versions " + f">= {SIMVUE_SERVER_UPPER_CONSTRAINT}", ) if SIMVUE_SERVER_LOWER_CONSTRAINT and _version < SIMVUE_SERVER_LOWER_CONSTRAINT: raise AssertionError( - f"Python API v{_version_str} is not compatible with Simvue server versions " - f"< {SIMVUE_SERVER_LOWER_CONSTRAINT}" + f"Python API v{_version_str} is not compatible with Simvue " + f"server versions < {SIMVUE_SERVER_LOWER_CONSTRAINT}", ) return _version @@ -168,7 +177,9 @@ def check_valid_server(self) -> Self: raise ValueError("No token provided.") self._server_version = self._check_server( - self.server.token.get_secret_value(), self.server.url, self.run.mode + self.server.token.get_secret_value(), + self.server.url, + self.run.mode, ) return self @@ -182,7 +193,7 @@ def fetch( server_token: str | None = None, profile: str | None = None, ) -> "SimvueConfiguration": - """Retrieve the Simvue configuration from this project + """Retrieve the Simvue configuration from this project. Will retrieve the configuration options set for this project either using local or global configurations. @@ -225,7 +236,7 @@ def fetch( elif not _config_dict.get("profiles", {}).get(profile): raise RuntimeError( f"Cannot load server configuration for '{profile}', " - "profile not found in configurations." + "profile not found in configurations.", ) else: _config_dict["server"] = _config_dict["profiles"][profile] @@ -237,7 +248,8 @@ def fetch( # Allow override of specification of offline directory via environment variable if not (_default_dir := os.environ.get("SIMVUE_OFFLINE_DIRECTORY")): _default_dir = _config_dict["offline"].get( - "cache", DEFAULT_OFFLINE_DIRECTORY + "cache", + DEFAULT_OFFLINE_DIRECTORY, ) pathlib.Path(_default_dir).mkdir(parents=True, exist_ok=True) @@ -247,14 +259,16 @@ def fetch( # Environment Variables > Run Definition > Configuration File _server_url = os.environ.get( - "SIMVUE_URL", server_url or _config_dict["server"].get("url") + "SIMVUE_URL", + server_url or _config_dict["server"].get("url"), ) if isinstance(_server_url, URL): _server_url = str(_server_url) _server_token = os.environ.get( - "SIMVUE_TOKEN", server_token or _config_dict["server"].get("token") + "SIMVUE_TOKEN", + server_token or _config_dict["server"].get("token"), ) _run_mode = mode or _config_dict["run"].get("mode") or "online" @@ -284,18 +298,21 @@ def fetch( @classmethod @functools.lru_cache def config_file(cls) -> pathlib.Path: - """Returns the path of top level configuration file used for the session""" + """Returns the path of top level configuration file used for the session.""" _config_file: pathlib.Path | None = sv_util.find_first_instance_of_file( - CONFIG_FILE_NAMES, check_user_space=True + CONFIG_FILE_NAMES, + check_user_space=True, ) # NOTE: Legacy INI support has been removed if not _config_file and sv_util.find_first_instance_of_file( - CONFIG_INI_FILE_NAMES, check_user_space=True + CONFIG_INI_FILE_NAMES, + check_user_space=True, ): raise RuntimeError( - "Simvue INI configuration file format has been deprecated in simvue>=1.2, " - "please use TOML file" + "Simvue INI configuration file format has been " + "deprecated in simvue>=1.2, " + "please use TOML file", ) if not _config_file: diff --git a/simvue/converters.py b/simvue/converters.py index 944599d1..43389927 100644 --- a/simvue/converters.py +++ b/simvue/converters.py @@ -1,5 +1,4 @@ -""" -Converter Functions +"""Converter Functions. =================== Contains functions for converting objects retrieved from the server between @@ -7,12 +6,12 @@ """ import typing -import pandas -import flatdict +import flatdict +import pandas as pd if typing.TYPE_CHECKING: - from pandas import DataFrame + from pd import DataFrame def aggregated_metrics_to_dataframe( @@ -20,7 +19,7 @@ def aggregated_metrics_to_dataframe( xaxis: str, parse_to: typing.Literal["dict", "dataframe"] = "dict", ) -> typing.Union["DataFrame", dict[str, dict[tuple[float, str], float]] | None]: - """Create data frame for an aggregate of metrics + """Create data frame for an aggregate of metrics. Returns a dataframe with columns being metrics and sub-columns being the minimum, average etc. @@ -34,21 +33,21 @@ def aggregated_metrics_to_dataframe( parse_to : Literal["dict", "dataframe"], optional form of output * dict - dictionary of values. - * dataframe - dataframe (Pandas must be installed). + * dataframe - dataframe (pd must be installed). Returns ------- DataFrame | dict - a Pandas dataframe of the metric set or the data as a dictionary - """ + a pd dataframe of the metric set or the data as a dictionary + """ _all_steps: list[float] = sorted( { d[xaxis] for sublist in request_response_data.values() for d in sublist if xaxis in d - } + }, ) # Get the keys from the aggregate which are not the xaxis label @@ -72,17 +71,16 @@ def aggregated_metrics_to_dataframe( next_item = next(metrics_iterator) for value_type in _value_types: result_dict[metric_name][step, value_type] = next_item.get( - value_type + value_type, ) if parse_to == "dataframe": - _data_frame = pandas.DataFrame(result_dict) + _data_frame = pd.DataFrame(result_dict) _data_frame.index.name = xaxis return _data_frame - elif parse_to == "dict": + if parse_to == "dict": return result_dict - else: - raise ValueError(f"Unrecognised parse format '{parse_to}'") + raise ValueError(f"Unrecognised parse format '{parse_to}'") def parse_run_set_metrics( @@ -90,10 +88,10 @@ def parse_run_set_metrics( xaxis: str, run_labels: list[str], parse_to: typing.Literal["dict", "dataframe"] = "dict", -) -> typing.Union[dict[str, dict[tuple[float, str], float]] | None, "DataFrame"]: - """Parse JSON response metric data from the server into the specified form +) -> "dict[str, dict[tuple[float, str], float]] | DataFrame | None": + """Parse JSON response metric data from the server into the specified form. - Creates either a dictionary or a pandas dataframe of the data collected + Creates either a dictionary or a pd dataframe of the data collected from multiple runs and metrics Parameters @@ -107,20 +105,21 @@ def parse_run_set_metrics( parse_to : Literal["dict", "dataframe"], optional form in which to parse data * dict - return a values dictionary (default). - * dataframe - assembled into dataframe (requires Pandas). + * dataframe - assembled into dataframe (requires pd). Returns ------- dict[str, dict[tuple[float, str], float]] | None | DataFrame - either a dictionary or Pandas DataFrame containing the results + either a dictionary or pd DataFrame containing the results Raises ------ ValueError if an unrecognised parse format is specified + """ if not request_response_data: - return pandas.DataFrame({}) if parse_to == "dataframe" else {} + return pd.DataFrame({}) if parse_to == "dataframe" else {} _all_steps: list[float] = sorted( { @@ -129,11 +128,11 @@ def parse_run_set_metrics( for sublist in run_data.values() for d in sublist if xaxis in d - } + }, ) _all_metrics: list[str] = sorted( - {key for run_data in request_response_data.values() for key in run_data.keys()} + {key for run_data in request_response_data.values() for key in run_data}, ) # Get the keys from the aggregate which are not the xaxis label @@ -147,7 +146,11 @@ def parse_run_set_metrics( metric_name: {} for metric_name in _all_metrics } - for run_label, run_data in zip(run_labels, request_response_data.values()): + for run_label, run_data in zip( + run_labels, + request_response_data.values(), + strict=True, + ): for metric_name in _all_metrics: if metric_name not in run_data: for step in _all_steps: @@ -164,23 +167,20 @@ def parse_run_set_metrics( result_dict[metric_name][step, run_label] = next_item.get("value") if parse_to == "dataframe": - return pandas.DataFrame( + return pd.DataFrame( result_dict, - index=pandas.MultiIndex.from_product( - [_all_steps, run_labels], names=(xaxis, "run") + index=pd.MultiIndex.from_product( + [_all_steps, run_labels], + names=(xaxis, "run"), ), ) - elif parse_to == "dict": + if parse_to == "dict": return result_dict - else: - raise ValueError(f"Unrecognised parse format '{parse_to}'") + raise ValueError(f"Unrecognised parse format '{parse_to}'") -def to_dataframe(data) -> pandas.DataFrame: - """ - Convert runs to dataframe - """ - +def to_dataframe(data) -> pd.DataFrame: + """Convert runs to dataframe.""" metadata = [] system_columns = [] columns = { @@ -200,7 +200,7 @@ def to_dataframe(data) -> pandas.DataFrame: if isinstance(value, dict): system_columns += [ col_name - for sub_item in value.keys() + for sub_item in value if (col_name := f"system.{item}.{sub_item}") not in system_columns ] elif f"system.{item}" not in system_columns: @@ -217,7 +217,7 @@ def to_dataframe(data) -> pandas.DataFrame: except TypeError: value_.append(None) - return pandas.DataFrame(data=columns) + return pd.DataFrame(data=columns) def metric_time_series_to_dataframe( @@ -225,7 +225,7 @@ def metric_time_series_to_dataframe( xaxis: typing.Literal["step", "time", "timestamp"], name: str | None = None, ) -> "DataFrame": - """Convert a single metric value set from a run into a dataframe + """Convert a single metric value set from a run into a dataframe. Parameters ---------- @@ -242,12 +242,12 @@ def metric_time_series_to_dataframe( Returns ------- DataFrame - a Pandas DataFrame containing values for the metric and run at each - """ + a pd DataFrame containing values for the metric and run at each + """ _df_dict: dict[str, list[float]] = { xaxis: [v[xaxis] for v in data], name or "value": [v["value"] for v in data], } - return pandas.DataFrame(_df_dict) + return pd.DataFrame(_df_dict) diff --git a/simvue/dispatch/__init__.py b/simvue/dispatch/__init__.py index a10e2089..639858da 100644 --- a/simvue/dispatch/__init__.py +++ b/simvue/dispatch/__init__.py @@ -1,73 +1,5 @@ -"""Dispatch +"""Dispatch components.""" -Contains factory method for selecting dispatcher type based on Simvue Configuration -""" +from .dispatcher import Dispatcher -import typing -import logging - -if typing.TYPE_CHECKING: - from .base import DispatcherBaseClass - from threading import Event - -from .queued import QueuedDispatcher -from .direct import DirectDispatcher - -logger = logging.getLogger(__name__) - - -def Dispatcher( - mode: typing.Literal["direct", "queued"], - callback: typing.Callable[[list[typing.Any], str], None], - object_types: list[str], - termination_trigger: "Event", - name: str | None = None, - thresholds: dict[str, int | float] | None = None, -) -> "DispatcherBaseClass": - """Returns instance of dispatcher based on configuration - - Options are 'queued' which is the default and adds objects to a queue as well - as restricts the rate of dispatch, and 'prompt' which executes the callback - immediately - - Parameters - ---------- - mode : typing.Literal['prompt', 'queued'] - dispatcher mode - * prompt - execute callback immediately, do not queue. - * queue - execute callback on entries in a queue. - callback : typing.Callable[[list[typing.Any], str, dict[str, typing.Any]], None] - callback to be executed on each item provided - object_types : list[str] - categories, this is mainly used for creation of queues in a QueueDispatcher - termination_trigger : Event - event which triggers termination of the dispatcher - name : str | None, optional - name for the underlying thread, default None - thresholds: dict[str, int | float] | None, default None - if metadata is provided during item addition, specify - thresholds under which dispatch of an item is permitted, - default is None - - Returns - ------- - DispatcherBaseClass - either a DirectDispatcher or QueueDispatcher instance - """ - if mode == "direct": - logger.debug("Using direct dispatch for metric and queue sending") - return DirectDispatcher( - callback=callback, - object_types=object_types, - termination_trigger=termination_trigger, - thresholds=thresholds, - ) - else: - logger.debug("Using queued dispatch for metric and queue sending") - return QueuedDispatcher( - callback=callback, - object_types=object_types, - termination_trigger=termination_trigger, - name=name, - thresholds=thresholds, - ) +__all__ = ["Dispatcher"] diff --git a/simvue/dispatch/base.py b/simvue/dispatch/base.py index de3dea36..03682502 100644 --- a/simvue/dispatch/base.py +++ b/simvue/dispatch/base.py @@ -1,5 +1,5 @@ -import threading import abc +import threading import typing from simvue.exception import ObjectDispatchError @@ -34,6 +34,7 @@ def __init__( any additional thresholds to consider when handling items. This assumes metadata defining the values to compare to such thresholds is included when appending. + """ super().__init__() self._thresholds: dict[str, int | float] = thresholds or {} @@ -60,6 +61,7 @@ def add_item( metadata : dict[str, int | float] | None, optional additional metadata relating to the item to be used for threshold comparisons + """ _ = item _ = object_type @@ -68,36 +70,32 @@ def add_item( for key, threshold in self._thresholds.items(): if key in metadata and metadata[key] > threshold: raise ObjectDispatchError( - label=key, threshold=threshold, value=metadata[key] + label=key, + threshold=threshold, + value=metadata[key], ) @abc.abstractmethod def run(self) -> None: """Start the dispatcher.""" - pass @abc.abstractmethod def start(self) -> None: """Not used, this allows the class to be similar to a thread.""" - pass @abc.abstractmethod def join(self) -> None: """Not used, this allows the class to be similar to a thread.""" - pass @abc.abstractmethod def purge(self) -> None: """Clear the dispatcher of items.""" - pass @abc.abstractmethod def is_alive(self) -> bool: """Whether the dispatcher is operating correctly.""" - pass @property @abc.abstractmethod def empty(self) -> bool: """Whether the dispatcher is empty.""" - pass diff --git a/simvue/dispatch/direct.py b/simvue/dispatch/direct.py index fb96b83f..aaaf80e4 100644 --- a/simvue/dispatch/direct.py +++ b/simvue/dispatch/direct.py @@ -5,7 +5,7 @@ class DirectDispatcher(DispatcherBaseClass): - """The DirectDispatcher executes the provided callback immediately""" + """The DirectDispatcher executes the provided callback immediately.""" def __init__( self, @@ -15,7 +15,7 @@ def __init__( termination_trigger: threading.Event, thresholds: dict[str, int | float] | None = None, ) -> None: - """Initialise a new DirectDispatcher instance + """Initialise a new DirectDispatcher instance. Parameters ---------- @@ -29,6 +29,7 @@ def __init__( if metadata is provided during item addition, specify thresholds under which dispatch of an item is permitted, default is None + """ super().__init__( callback=callback, @@ -45,31 +46,27 @@ def add_item( metadata: dict[str, int | float] | None = None, **__, ) -> None: - """Execute callback on the given item""" + """Execute callback on the given item.""" super().add_item(item, object_type=object_type, metadata=metadata) self._callback([item], object_type) def run(self) -> None: - """Run does not execute anything in this context""" - pass + """Run does not execute anything in this context.""" def start(self) -> None: - """Start does not execute anything in this context""" - pass + """Start does not execute anything in this context.""" def join(self) -> None: - """Join does not execute anything in this context""" - pass + """Join does not execute anything in this context.""" def purge(self) -> None: - """Purge does not execute anything in this context""" - pass + """Purge does not execute anything in this context.""" def is_alive(self) -> bool: - """As unthreaded, state as not alive always""" + """As unthreaded, state as not alive always.""" return False @property def empty(self) -> bool: - """No queue so always empty""" + """No queue so always empty.""" return True diff --git a/simvue/dispatch/dispatcher.py b/simvue/dispatch/dispatcher.py new file mode 100644 index 00000000..cb088157 --- /dev/null +++ b/simvue/dispatch/dispatcher.py @@ -0,0 +1,74 @@ +"""General Dispatcher Initialisation. + +Contains factory method for selecting dispatcher type based on Simvue Configuration +""" + +import logging +import typing + +from .direct import DirectDispatcher +from .queued import QueuedDispatcher + +logger = logging.getLogger(__name__) + +if typing.TYPE_CHECKING: + from threading import Event + + from .base import DispatcherBaseClass + + +def Dispatcher( + mode: typing.Literal["direct", "queued"], + callback: typing.Callable[[list[typing.Any], str], None], + object_types: list[str], + termination_trigger: "Event", + name: str | None = None, + thresholds: dict[str, int | float] | None = None, +) -> "DispatcherBaseClass": + """Returns instance of dispatcher based on configuration. + + Options are 'queued' which is the default and adds objects to a queue as well + as restricts the rate of dispatch, and 'prompt' which executes the callback + immediately + + Parameters + ---------- + mode : typing.Literal['prompt', 'queued'] + dispatcher mode + * prompt - execute callback immediately, do not queue. + * queue - execute callback on entries in a queue. + callback : typing.Callable[[list[typing.Any], str, dict[str, typing.Any]], None] + callback to be executed on each item provided + object_types : list[str] + categories, this is mainly used for creation of queues in a QueueDispatcher + termination_trigger : Event + event which triggers termination of the dispatcher + name : str | None, optional + name for the underlying thread, default None + thresholds: dict[str, int | float] | None, default None + if metadata is provided during item addition, specify + thresholds under which dispatch of an item is permitted, + default is None + + Returns + ------- + DispatcherBaseClass + either a DirectDispatcher or QueueDispatcher instance + + """ + if mode == "direct": + logger.debug("Using direct dispatch for metric and queue sending") + return DirectDispatcher( + callback=callback, + object_types=object_types, + termination_trigger=termination_trigger, + thresholds=thresholds, + ) + logger.debug("Using queued dispatch for metric and queue sending") + return QueuedDispatcher( + callback=callback, + object_types=object_types, + termination_trigger=termination_trigger, + name=name, + thresholds=thresholds, + ) diff --git a/simvue/dispatch/queued.py b/simvue/dispatch/queued.py index bac71bb5..84dd7c56 100644 --- a/simvue/dispatch/queued.py +++ b/simvue/dispatch/queued.py @@ -1,5 +1,4 @@ -""" -Queue Dispatcher +"""Queue Dispatcher. ================ The QueueDispatcher provides a queue based system for execution of a callback on @@ -7,12 +6,12 @@ often the callback can be executed, and the number of items it is called on. """ +import contextlib import logging import queue import threading import time import typing -import contextlib from .base import DispatcherBaseClass @@ -24,11 +23,11 @@ class QueuedDispatcher(threading.Thread, DispatcherBaseClass): - """ - The QueuedDispatcher class enforces a maximum rate of execution for a given function - on items within a queue. Multiple queues can be defined with the dispatch - of each being executed in series. Items are added to a buffer which is handed - to the callback. + """The QueuedDispatcher class enforces a maximum rate of + execution for a given function on items within a queue. + Multiple queues can be defined with the dispatch of each + being executed in series. Items are added to a buffer which + is handed to the callback. """ def __init__( @@ -42,8 +41,7 @@ def __init__( max_read_rate: float = MAX_REQUESTS_PER_SECOND, thresholds: dict[str, int | float] | None = None, ) -> None: - """ - Initialise a new queue based dispatcher + """Initialise a new queue based dispatcher. Parameters ---------- @@ -64,6 +62,7 @@ def __init__( if metadata is provided during item addition, specify thresholds within which a single dispatch is permitted, default is None + """ DispatcherBaseClass.__init__( self, @@ -91,12 +90,12 @@ def add_item( blocking: bool = True, metadata: dict[str, int | float] | None = None, ) -> None: - """Add an item to the specified queue with/without blocking""" + """Add an item to the specified queue with/without blocking.""" super().add_item(item, object_type=object_type, metadata=metadata) if self._termination_trigger.is_set(): raise RuntimeError( f"Cannot append item '{item}' to queue '{object_type}', " - + "termination called." + "termination called.", ) if object_type not in self._queues: raise KeyError(f"No queue '{object_type}' found") @@ -104,11 +103,11 @@ def add_item( @property def empty(self) -> bool: - """Returns if all queues are empty""" + """Returns if all queues are empty.""" return all(queue.empty() for queue in self._queues.values()) def purge(self) -> None: - """Purge all queues""" + """Purge all queues.""" for q in self._queues.values(): while not q.empty(): with contextlib.suppress(queue.Empty): @@ -117,17 +116,17 @@ def purge(self) -> None: @property def _can_send(self) -> bool: - """Returns if time constraints are satisfied, hence the callback can be executed""" + """Returns if time constraints are satisfied, hence callback can be executed.""" return time.time() - self._send_timer >= 1 / self._max_read_rate def _create_buffer(self, queue_label: str) -> list[typing.Any]: - """Assemble queue items into a list as an argument to the callback + """Assemble queue items into a list as an argument to the callback. The length of the buffer is constrained. """ _buffer: list[typing.Any] = [] _criteria: dict[str, int | float] = {} - _threshold_totals: dict[str, float] = {k: 0 for k in self._thresholds} + _threshold_totals: dict[str, float] = dict.fromkeys(self._thresholds, 0) while ( not self._queues[queue_label].empty() @@ -149,7 +148,7 @@ def _create_buffer(self, queue_label: str) -> list[typing.Any]: return _buffer def run(self) -> None: - """Execute the dispatcher action + """Execute the dispatcher action. The action consists of a loop in which each queue is processed to create a buffer with number of entries equal or less than the maximum @@ -167,7 +166,9 @@ def run(self) -> None: for queue_label in self._queues: if _buffer := self._create_buffer(queue_label): logger.debug( - f"Executing '{queue_label}' callback on buffer {_buffer}" + "Executing '%s' callback on buffer %s", + queue_label, + _buffer, ) self._callback(_buffer, queue_label) self._send_timer = time.time() diff --git a/simvue/eco/__init__.py b/simvue/eco/__init__.py index 240c0c06..f3cc6e78 100644 --- a/simvue/eco/__init__.py +++ b/simvue/eco/__init__.py @@ -1,5 +1,4 @@ -""" -Simvue Eco +"""Simvue Eco. ========== Contains functionality for green IT, monitoring emissions etc. diff --git a/simvue/eco/api_client.py b/simvue/eco/api_client.py index 9d3b094f..ee869a86 100644 --- a/simvue/eco/api_client.py +++ b/simvue/eco/api_client.py @@ -1,5 +1,4 @@ -""" -CO2 Signal API Client +"""CO2 Signal API Client. ===================== Provides inteface to the CO2 Signal API, @@ -9,20 +8,25 @@ __date__ = "2025-02-27" -import requests -import pydantic +import datetime import functools import http import logging -import datetime +import typing + import geocoder import geocoder.location -import typing +import pydantic +import requests + +from simvue.api.request import DEFAULT_API_TIMEOUT CO2_SIGNAL_API_ENDPOINT: str = ( "https://api.electricitymap.org/v3/carbon-intensity/latest" ) +logger = logging.getLogger(__name__) + class CO2SignalData(pydantic.BaseModel): datetime: datetime.datetime @@ -38,7 +42,7 @@ class CO2SignalResponse(pydantic.BaseModel): def from_json_response(cls, json_response: dict) -> "CO2SignalResponse": _co2_signal_data = CO2SignalData( datetime=datetime.datetime.fromisoformat( - json_response["datetime"].replace("Z", "+00:00") + json_response["datetime"].replace("Z", "+00:00"), ), carbon_intensity=json_response["carbonIntensity"], ) @@ -49,9 +53,9 @@ def from_json_response(cls, json_response: dict) -> "CO2SignalResponse": ) -@functools.lru_cache() +@functools.lru_cache def _call_geocoder_query() -> typing.Any: - """Call GeoCoder API for IP location + """Call GeoCoder API for IP location. Cached so this API is only called once per session as required. """ @@ -59,47 +63,46 @@ def _call_geocoder_query() -> typing.Any: class APIClient(pydantic.BaseModel): - """ - CO2 Signal API Client + """CO2 Signal API Client. Provides an interface to the Electricity Maps API. + + Parameters + ---------- + co2_api_endpoint : str + endpoint for CO2 signal API + co2_api_token: str + The API token for the ElectricityMaps API, default is None. + timeout : int + timeout for API + """ co2_api_endpoint: pydantic.HttpUrl = pydantic.HttpUrl(CO2_SIGNAL_API_ENDPOINT) co2_api_token: pydantic.SecretStr | None = None timeout: pydantic.PositiveInt = 10 - def __init__(self, *args, **kwargs) -> None: - """Initialise the CO2 Signal API client. - - Parameters - ---------- - co2_api_endpoint : str - endpoint for CO2 signal API - co2_api_token: str - The API token for the ElectricityMaps API, default is None. - timeout : int - timeout for API - """ - super().__init__(*args, **kwargs) - self._logger = logging.getLogger(self.__class__.__name__) - + @pydantic.model_validator(mode="after") + def post_init(self) -> typing.Self: + """Post-initialise the CO2 Signal API client.""" if not self.co2_api_token: raise ValueError("API token is required for ElectricityMaps API.") self._get_user_location_info() + return self + def _get_user_location_info(self) -> None: """Retrieve location information for the current user.""" - self._logger.info("📍 Determining current user location.") + logger.info("📍 Determining current user location.") _current_user_loc_data: geocoder.location.BBox = _call_geocoder_query() self._latitude: float self._longitude: float self._latitude, self._longitude = _current_user_loc_data.latlng - self._two_letter_country_code: str = _current_user_loc_data.country # type: ignore + self._two_letter_country_code: str = _current_user_loc_data.country def get(self) -> CO2SignalResponse: - """Get the current data""" + """Get the current data.""" _params: dict[str, float | str] = { "zone": self._two_letter_country_code, } @@ -107,8 +110,12 @@ def get(self) -> CO2SignalResponse: if self.co2_api_token: _params["auth-token"] = self.co2_api_token.get_secret_value() - self._logger.debug(f"🍃 Retrieving carbon intensity data for: {_params}") - _response = requests.get(f"{self.co2_api_endpoint}", headers=_params) + logger.debug("🍃 Retrieving carbon intensity data for: %s", _params) + _response = requests.get( + f"{self.co2_api_endpoint}", + headers=_params, # FIXME: Should this be params= not headers=? + timeout=DEFAULT_API_TIMEOUT, + ) if _response.status_code != http.HTTPStatus.OK: try: @@ -116,23 +123,24 @@ def get(self) -> CO2SignalResponse: except (AttributeError, KeyError): _error = _response.text raise RuntimeError( - f"[{_response.status_code}] Failed to retrieve current CO2 signal data for" - f" country '{self._two_letter_country_code}': {_error}" + f"[{_response.status_code}] Failed to retrieve " + "current CO2 signal data for" + f" country '{self._two_letter_country_code}': {_error}", ) return CO2SignalResponse.from_json_response(_response.json()) @property def country_code(self) -> str: - """Returns the country code""" + """Returns the country code.""" return self._two_letter_country_code @property def latitude(self) -> float: - """Returns current latitude""" + """Returns current latitude.""" return self._latitude @property def longitude(self) -> float: - """Returns current longitude""" + """Returns current longitude.""" return self._longitude diff --git a/simvue/eco/config.py b/simvue/eco/config.py index 7e855b7c..d1877875 100644 --- a/simvue/eco/config.py +++ b/simvue/eco/config.py @@ -1,5 +1,4 @@ -""" -Eco Config +"""Eco Config. ========== Configuration file extension for configuring the Simvue Eco sub-module. @@ -21,6 +20,7 @@ class EcoConfig(pydantic.BaseModel): the TDP for the CPU gpu_thermal_design_power: int | None, optional the TDP for each GPU + """ co2_signal_api_token: pydantic.SecretStr | None = None @@ -28,6 +28,7 @@ class EcoConfig(pydantic.BaseModel): cpu_n_cores: pydantic.PositiveInt | None = None gpu_thermal_design_power: pydantic.PositiveInt | None = None intensity_refresh_interval: pydantic.PositiveInt | str | None = pydantic.Field( - default="1 hour", gt=2 * 60 + default="1 hour", + gt=2 * 60, ) co2_intensity: float | None = None diff --git a/simvue/eco/emissions_monitor.py b/simvue/eco/emissions_monitor.py index 0dc32d5c..d3da4628 100644 --- a/simvue/eco/emissions_monitor.py +++ b/simvue/eco/emissions_monitor.py @@ -1,5 +1,4 @@ -""" -CO2 Monitor +"""CO2 Monitor. =========== Provides an interface for estimating CO2 usage for processes on the CPU. @@ -8,20 +7,26 @@ __author__ = "Kristian Zarebski" __date__ = "2025-02-27" +import dataclasses import datetime import json -import pydantic -import dataclasses import logging -import humanfriendly -import pathlib import os.path +import typing + +import humanfriendly +import pydantic from simvue.eco.api_client import APIClient, CO2SignalResponse +if typing.TYPE_CHECKING: + import pathlib + TIME_FORMAT: str = "%Y_%m_%d_%H_%M_%S" CO2_SIGNAL_API_INTERVAL_LIMIT: int = 2 * 60 +logger = logging.getLogger(__name__) + @dataclasses.dataclass class ProcessData: @@ -35,113 +40,94 @@ class ProcessData: class CO2Monitor(pydantic.BaseModel): - """ - CO2 Monitor + """CO2 Monitor. Provides an interface for estimating CO2 usage for processes on the CPU. + + Parameters + ---------- + thermal_design_power_per_cpu: float | None + the TDP value for each CPU, default is 80W. + n_cores_per_cpu: int | None + the number of cores in each CPU, default is 4. + thermal_design_power_per_gpu: float | None + the TDP value for each GPU, default is 130W. + local_data_directory: pydantic.DirectoryPath + the directory in which to store CO2 intensity data. + intensity_refresh_interval: int | str | None + the interval in seconds at which to call the CO2 + signal API. The default is once per day, + note the API is restricted to 30 requests per hour + for a given user. Also accepts a + time period as a string, e.g. '1 week' + co2_intensity: float | None + disable using RestAPIs to retrieve CO2 intensity + and instead use this value. + Default is None, use remote data. Value is in kgCO2/kWh + co2_signal_api_token: str + The API token for CO2 signal, default is None. + offline: bool, optional + Run without any server interaction + """ thermal_design_power_per_cpu: pydantic.PositiveFloat | None n_cores_per_cpu: pydantic.PositiveInt | None thermal_design_power_per_gpu: pydantic.PositiveFloat | None local_data_directory: pydantic.DirectoryPath - intensity_refresh_interval: int | None | str + intensity_refresh_interval: int | str | None co2_intensity: float | None co2_signal_api_token: pydantic.SecretStr | None offline: bool = False - def now(self) -> str: - """Return data file timestamp for the current time""" - _now: datetime.datetime = datetime.datetime.now(datetime.timezone.utc) - return _now.strftime(TIME_FORMAT) - - @property - def outdated(self) -> bool: - """Checks if the current data is out of date.""" - if not self.intensity_refresh_interval: - return False - - _now: datetime.datetime = datetime.datetime.now() - _latest_time: datetime.datetime = datetime.datetime.strptime( - self._local_data["last_updated"], TIME_FORMAT - ) - return (_now - _latest_time).seconds > self.intensity_refresh_interval - - def _load_local_data(self) -> dict[str, str | dict[str, str | float]] | None: - """Loads locally stored CO2 intensity data""" - self._data_file_path = self.local_data_directory.joinpath( - "ecoclient_co2_intensity.json" - ) - - if not self._data_file_path.exists(): - return None - - with self._data_file_path.open() as in_f: - _data: dict[str, str | dict[str, str | float]] | None = json.load(in_f) - - return _data or None - - def __init__(self, *args, **kwargs) -> None: - """Initialise a CO2 Monitor. - - Parameters - ---------- - thermal_design_power_per_cpu: float | None - the TDP value for each CPU, default is 80W. - n_cores_per_cpu: int | None - the number of cores in each CPU, default is 4. - thermal_design_power_per_gpu: float | None - the TDP value for each GPU, default is 130W. - local_data_directory: pydantic.DirectoryPath - the directory in which to store CO2 intensity data. - intensity_refresh_interval: int | str | None - the interval in seconds at which to call the CO2 signal API. The default is once per day, - note the API is restricted to 30 requests per hour for a given user. Also accepts a - time period as a string, e.g. '1 week' - co2_intensity: float | None - disable using RestAPIs to retrieve CO2 intensity and instead use this value. - Default is None, use remote data. Value is in kgCO2/kWh - co2_signal_api_token: str - The API token for CO2 signal, default is None. - offline: bool, optional - Run without any server interaction - """ - _logger = logging.getLogger(self.__class__.__name__) + _last_local_write = pydantic.PrivateAttr(datetime.datetime.now(tz=datetime.UTC)) + @pydantic.model_validator(mode="before") + @classmethod + def check_api_arguments(cls, values: dict[str, Any]) -> dict[str, Any]: + """Check Argument Combinations.""" if not ( - kwargs.get("co2_intensity") - or kwargs.get("co2_signal_api_token") - or kwargs.get("offline") + values.get("co2_intensity") + or values.get("co2_signal_api_token") + or values.get("offline") ): raise ValueError( - "ElectricityMaps API token or hardcoeded CO2 intensity value is required for emissions tracking." + "ElectricityMaps API token or hardcoded CO2 " + "intensity value is required " + "for emissions tracking.", ) - - if not isinstance(kwargs.get("thermal_design_power_per_cpu"), float): - kwargs["thermal_design_power_per_cpu"] = 80.0 - _logger.warning( - "⚠️ No TDP value provided for current CPU, will use arbitrary value of 80W." + if not isinstance(values.get("thermal_design_power_per_cpu"), float): + values["thermal_design_power_per_cpu"] = 80.0 + logger.warning( + "⚠️ No TDP value provided for current CPU, will use " + "arbitrary value of 80W.", ) - if not isinstance(kwargs.get("n_cores_per_cpu"), float): - kwargs["n_cores_per_cpu"] = 4 - _logger.warning( - "⚠️ No core count provided for current CPU, will use arbitrary value of 4." + if not isinstance(values.get("n_cores_per_cpu"), float): + values["n_cores_per_cpu"] = 4 + logger.warning( + "⚠️ No core count provided for current CPU, will use " + "arbitrary value of 4.", ) - if not isinstance(kwargs.get("thermal_design_power_per_gpu"), float): - kwargs["thermal_design_power_per_gpu"] = 130.0 - _logger.warning( - "⚠️ No TDP value provided for current GPUs, will use arbitrary value of 130W." + if not isinstance(values.get("thermal_design_power_per_gpu"), float): + values["thermal_design_power_per_gpu"] = 130.0 + logger.warning( + "⚠️ No TDP value provided for current GPUs, " + "will use arbitrary value of 130W.", ) - super().__init__(*args, **kwargs) - self._last_local_write = datetime.datetime.now() + return values + + @pydantic.model_validator(mode="after") + def post_init_setup(self) -> typing.Self: + """Post initialisation setup.""" if self.intensity_refresh_interval and isinstance( - self.intensity_refresh_interval, str + self.intensity_refresh_interval, + str, ): self.intensity_refresh_interval = int( - humanfriendly.parse_timespan(self.intensity_refresh_interval) + humanfriendly.parse_timespan(self.intensity_refresh_interval), ) if ( @@ -149,22 +135,26 @@ def __init__(self, *args, **kwargs) -> None: and self.intensity_refresh_interval <= CO2_SIGNAL_API_INTERVAL_LIMIT ): raise ValueError( - "Invalid intensity refresh rate, CO2 signal API restricted to 30 calls per hour." + "Invalid intensity refresh rate, CO2 signal API restricted " + "to 30 calls per hour.", ) if self.co2_intensity: - _logger.warning( - f"⚠️ Disabling online data retrieval, using {self.co2_intensity} eqCO2g/kwh for CO2 intensity." + logger.warning( + "⚠️ Disabling online data retrieval, using %s " + "eqCO2g/kwh for CO2 intensity.", + self.co2_intensity, ) self._data_file_path: pathlib.Path | None = None - # Load any local data first, if the data is missing or due a refresh this will be None + # Load any local data first, if the data is missing or due a refresh + # this will be None self._local_data: dict[str, str | dict[str, float | str]] | None = ( self._load_local_data() or {} ) - self._measure_time = datetime.datetime.now() - self._logger = _logger + self._measure_time = datetime.datetime.now(datetime.UTC) + self._client: APIClient | None = ( None if self.co2_intensity or self.offline @@ -172,6 +162,40 @@ def __init__(self, *args, **kwargs) -> None: ) self._processes: dict[str, ProcessData] = {} + return self + + def now(self) -> str: + """Return data file timestamp for the current time.""" + _now: datetime.datetime = datetime.datetime.now(datetime.UTC) + return _now.strftime(TIME_FORMAT) + + @property + def outdated(self) -> bool: + """Checks if the current data is out of date.""" + if not self.intensity_refresh_interval: + return False + + _now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC) + _latest_time: datetime.datetime = datetime.datetime.strptime( + self._local_data["last_updated"], + TIME_FORMAT, + ).replace(tzinfo=datetime.UTC) + return (_now - _latest_time).seconds > self.intensity_refresh_interval + + def _load_local_data(self) -> dict[str, str | dict[str, str | float]] | None: + """Loads locally stored CO2 intensity data.""" + self._data_file_path = self.local_data_directory.joinpath( + "ecoclient_co2_intensity.json", + ) + + if not self._data_file_path.exists(): + return None + + with self._data_file_path.open() as in_f: + _data: dict[str, str | dict[str, str | float]] | None = json.load(in_f) + + return _data or None + def check_refresh(self) -> bool: """Check to see if an intensity value refresh is required. @@ -180,14 +204,17 @@ def check_refresh(self) -> bool: bool whether a refresh of the CO2 intensity was requested from the CO2 Signal API. + """ # Need to check if the local cache has been modified # even if running offline if ( - self._data_file_path.exists() + self._data_file_path + and self._data_file_path.exists() and ( _check_write := datetime.datetime.fromtimestamp( - os.path.getmtime(f"{self._data_file_path}") + os.path.getmtime(f"{self._data_file_path}"), + tz=datetime.UTC, ) ) > self._last_local_write @@ -203,7 +230,7 @@ def check_refresh(self) -> bool: not self._local_data.setdefault(self._client.country_code, {}) or self.outdated ): - self._logger.info("🌍 CO2 emission outdated, calling API.") + logger.info("🌍 CO2 emission outdated, calling API.") _data: CO2SignalResponse = self._client.get() self._local_data[self._client.country_code] = _data.model_dump(mode="json") self._local_data["last_updated"] = self.now() @@ -219,12 +246,12 @@ def estimate_co2_emissions( gpu_percent: float | None, measure_interval: float, ) -> None: - """Estimate the CO2 emissions""" - self._logger.debug( + """Estimate the CO2 emissions.""" + logger.debug( f"📐 Estimating CO2 emissions from CPU usage of {cpu_percent}% " f"and GPU usage of {gpu_percent}%" if gpu_percent - else f"in interval {measure_interval}s." + else f"in interval {measure_interval}s.", ) if self._local_data is None: @@ -242,9 +269,9 @@ def estimate_co2_emissions( self.check_refresh() # If no local data yet then return if not (_country_codes := list(self._local_data.keys())): - self._logger.warning( + logger.warning( "No CO2 emission data recorded as no CO2 intensity value " - "has been provided and there is no local intensity data available." + "has been provided and there is no local intensity data available.", ) return False @@ -252,11 +279,13 @@ def estimate_co2_emissions( _country_code = self._client.country_code else: _country_code = _country_codes[0] - self._logger.debug( - f"🗂️ Using data for region '{_country_code}' from local cache for offline estimation." + logger.debug( + "🗂️ Using data for region '%s' from local " + "cache for offline estimation.", + _country_code, ) self._current_co2_data = CO2SignalResponse( - **self._local_data[_country_code] + **self._local_data[_country_code], ) _current_co2_intensity = self._current_co2_data.data.carbon_intensity _process.gpu_percentage = gpu_percent @@ -281,9 +310,17 @@ def estimate_co2_emissions( _process.co2_delta = _process.energy_delta * _carbon_intensity _process.co2_emission += _process.co2_delta - self._logger.debug( - f"📝 For process '{process_id}', in interval {measure_interval}, recorded: CPU={_process.cpu_percentage:.2f}%, " - f"Power={_process.power_usage:.2f}kW, Energy = {_process.energy_delta}kWh, CO2={_process.co2_delta:.2e}kg" + logger.debug( + "📝 For process '%s', in interval %s, " + "recorded: CPU=%s, " + "Power=%skW, Energy = " + "%s, CO2=%s", + process_id, + measure_interval, + f"{_process.cpu_percentage:.2f}%", + f"{_process.power_usage:.2f}kW", + f"{_process.energy_delta}kWh", + f"{_process.co2_delta:.2e}kg", ) return True diff --git a/simvue/exception.py b/simvue/exception.py index 09360a75..addcaec1 100644 --- a/simvue/exception.py +++ b/simvue/exception.py @@ -1,5 +1,4 @@ -""" -Simvue Exception Types +"""Simvue Exception Types. ====================== Custom exceptions for handling of Simvue request scenarions. @@ -8,26 +7,26 @@ class ObjectNotFoundError(Exception): - """For failure retrieving Simvue object from server""" + """For failure retrieving Simvue object from server.""" def __init__(self, obj_type: str, name: str, extra: str | None = None) -> None: super().__init__( f"Failed to retrieve '{name}' of type '{obj_type}' " f"{f'{extra}, ' if extra else ''}" - "no such object" + "no such object", ) class SimvueRunError(RuntimeError): - """A special sub-class of runtime error specifically for Simvue run errors""" + """A special sub-class of runtime error specifically for Simvue run errors.""" class ObjectDispatchError(Exception): """Raised if object dispatch failed due to condition.""" - def __init__(self, label: str, threshold: int | float, value: int | float) -> None: + def __init__(self, label: str, threshold: float, value: float) -> None: self.msg = ( f"Object dispatch failed, {label} " - + f"of {value} exceeds maximum permitted value of {threshold}" + f"of {value} exceeds maximum permitted value of {threshold}" ) super().__init__(self.msg) diff --git a/simvue/executor.py b/simvue/executor.py index ff8be2fc..f57c4d5b 100644 --- a/simvue/executor.py +++ b/simvue/executor.py @@ -1,25 +1,28 @@ """Simvue Job Executor. -Adds functionality for executing commands from the command line as part of a Simvue run, the executor -monitors the exit code of the command setting the status to failure if non-zero. +Adds functionality for executing commands from the command line +as part of a Simvue run, the executor monitors the exit code of +the command setting the status to failure if non-zero. Stdout and Stderr are sent to Simvue as artifacts. """ __author__ = "Kristian Zarebski" __date__ = "2023-11-15" +import contextlib import logging import multiprocessing.synchronize -import sys -import threading import os +import pathlib import shutil -import psutil import subprocess -import contextlib -import pathlib +import sys +import threading import time import typing + +import psutil + from simvue.api.objects.alert.user import UserAlert if typing.TYPE_CHECKING: @@ -52,16 +55,18 @@ def _execute_process( ) -> tuple[subprocess.Popen, threading.Thread | None]: thread_out = None - with open(f"{runner_name}_{proc_id}.err", "w") as err: - with open(f"{runner_name}_{proc_id}.out", "w") as out: - _result = subprocess.Popen( - command, - stdout=out, - stderr=err, - universal_newlines=True, - env=environment, - cwd=cwd, - ) + with ( + pathlib.Path(f"{runner_name}_{proc_id}.err").open("w", encoding="utf-8") as err, + pathlib.Path(f"{runner_name}_{proc_id}.out").open("w", encoding="utf-8") as out, + ): + _result = subprocess.Popen( + command, + stdout=out, + stderr=err, + universal_newlines=True, + env=environment, + cwd=cwd, + ) if completion_callback or completion_trigger: @@ -75,8 +80,12 @@ def trigger_check( if trigger_to_set: trigger_to_set.set() if completion_callback: - std_err = pathlib.Path(f"{runner_name}_{proc_id}.err").read_text() - std_out = pathlib.Path(f"{runner_name}_{proc_id}.out").read_text() + std_err = pathlib.Path(f"{runner_name}_{proc_id}.err").read_text( + encoding="utf-8", + ) + std_out = pathlib.Path(f"{runner_name}_{proc_id}.out").read_text( + encoding="utf-8", + ) completion_callback( status_code=process.returncode, std_out=std_out, @@ -95,15 +104,18 @@ def trigger_check( class Executor: - """Command Line command executor - - Adds execution of command line commands as part of a Simvue run, the status of these commands is monitored - and if non-zero cause the Simvue run to be stated as 'failed'. The executor accepts commands either as a - set of positional arguments or more specifically as components, two of these 'input_file' and 'script' then + """Command Line command executor. + + Adds execution of command line commands as part of a Simvue run, + the status of these commands is monitored + and if non-zero cause the Simvue run to be stated as 'failed'. + The executor accepts commands either as a + set of positional arguments or more specifically as components, + two of these 'input_file' and 'script' then being used to set the relevant metadata within the Simvue run itself. """ - def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None: + def __init__(self, simvue_runner: "simvue.Run", *, keep_logs: bool = True) -> None: """Initialise an instance of the Simvue executor attaching it to a Run. Parameters @@ -112,12 +124,14 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None: An instance of the Simvue runner used to send command execution feedback keep_logs : bool, optional whether to keep the stdout and stderr logs locally, by default False + """ self._runner = simvue_runner self._keep_logs = keep_logs self._completion_callbacks: dict[str, CompletionCallback] | None = {} self._completion_triggers: dict[ - str, multiprocessing.synchronize.Event | None + str, + multiprocessing.synchronize.Event | None, ] = {} self._completion_processes: dict[str, threading.Thread] | None = {} self._alert_ids: dict[str, str] = {} @@ -126,18 +140,18 @@ def __init__(self, simvue_runner: "simvue.Run", keep_logs: bool = True) -> None: self._all_processes: list[psutil.Process] = [] def std_out(self, process_id: str) -> str | None: - if not os.path.exists(out_file := f"{self._runner.name}_{process_id}.out"): + _out_file = pathlib.Path(f"{self._runner.name}_{process_id}.out") + if not _out_file.exists(): return None - with open(out_file) as out: - return out.read() or None + return _out_file.read_text(encoding="utf-8") or None def std_err(self, process_id: str) -> str | None: - if not os.path.exists(err_file := f"{self._runner.name}_{process_id}.err"): + _error_file = pathlib.Path(f"{self._runner.name}_{process_id}.err") + if not _error_file.exists(): return None - with open(err_file) as err: - return err.read() or None + return _error_file.read_text(encoding="utf-8") or None @staticmethod def _kwarg_assembly(kwargs, executable: str | None) -> list[str]: @@ -145,31 +159,31 @@ def _kwarg_assembly(kwargs, executable: str | None) -> list[str]: _shell_is_pwsh: bool = any( shell in get_current_shell() for shell in ("pwsh", "powershell") ) - _exec_is_pwsh: bool = executable in ("pwsh", "powershell", None) + _exec_is_pwsh: bool = executable in {"pwsh", "powershell", None} _use_pwsh: bool = _shell_is_pwsh and _exec_is_pwsh for arg, value in kwargs.items(): if arg.startswith("__"): continue - arg = arg.replace("_", "-") + _arg = arg.replace("_", "-") - if len(arg) == 1 or _use_pwsh: + if len(_arg) == 1 or _use_pwsh: _arguments += ( - [f"-{arg}"] + [f"-{_arg}"] if isinstance(value, bool) and value - else [f"-{arg}", f"{value}"] + else [f"-{_arg}", f"{value}"] ) elif isinstance(value, bool) and value: - _arguments += [f"--{arg}"] + _arguments += [f"--{_arg}"] else: - _arguments += [f"--{arg}", f"{value}"] + _arguments += [f"--{_arg}", f"{value}"] return _arguments def add_process( self, identifier: str, - *args, + *cmd_args, executable: str | None = None, script: pathlib.Path | None = None, input_file: pathlib.Path | None = None, @@ -179,11 +193,12 @@ def add_process( completion_trigger: threading.Event | multiprocessing.synchronize.Event | None = None, - **kwargs, + **cmd_kwargs, ) -> None: """Add a process to be executed to the executor. - This process can take many forms, for example a be a set of positional arguments: + This process can take many forms, for example a be a set + of positional arguments: ```python executor.add_process("my_process", "ls", "-ltr") @@ -192,16 +207,28 @@ def add_process( Provide explicitly the components of the command: ```python - executor.add_process("my_process", executable="bash", debug=True, c="return 1") - executor.add_process("my_process", executable="bash", script="my_script.sh", input="parameters.dat") + executor.add_process( + "my_process", + executable="bash", + debug=True, + c="return 1" + ) + executor.add_process( + "my_process", + executable="bash", + script="my_script.sh", + input="parameters.dat" + ) ``` - or a mixture of both. In the latter case arguments which are not 'executable', 'script', 'input' - are taken to be options to the command, for flags `flag=True` can be used to set the option and - for options taking values `option=value`. + or a mixture of both. In the latter case arguments which are + not 'executable', 'script', 'input' are taken to be options to the command, + for flags `flag=True` can be used to set the option and for options taking + values `option=value`. - When the process has completed if a function has been provided for the `completion_callback` argument - this will be called, this callback is expected to take the following form: + When the process has completed if a function has been provided for the + `completion_callback` argument this will be called, this callback is expected + to take the following form: ```python def callback_function(status_code: int, std_out: str, std_err: str) -> None: @@ -215,24 +242,35 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: identifier : str A unique identifier for this process executable : str | None, optional - the main executable for the command, if not specified this is taken to be the first - positional argument, by default None - script : str | None, optional - the script to run, note this only work if the script is not an option, if this is the case - you should provide it as such and perform the upload manually, by default None - input_file : str | None, optional - the input file to run, note this only work if the input file is not an option, if this is the case - you should provide it as such and perform the upload manually, by default None + the main executable for the command, if not specified this is + taken to be the first positional argument, by default None + *cmd_args: Any, ..., optional + all other positional arguments are taken to be part of the + command to execute + script : pydantic.FilePath | None, optional + the script to run, note this only work if the script is not an option, + if this is the case you should provide it as such and perform the + upload manually, by default None + input_file : pydantic.FilePath | None, optional + the input file to run, note this only work if the input file is not an + option, if this is the case you should provide it as such and perform + the upload manually, by default None + completion_callback : typing.Callable | None, optional + callback to run when process terminates (not supported on Windows) + completion_trigger : threading.Event | None, optional + this trigger event is set when the processes completes env : dict[str, str], optional environment variables for process cwd: pathlib.Path | None, optional - working directory to execute the process within - completion_callback : typing.Callable | None, optional - callback to run when process terminates - completion_trigger : threading.Event | None, optional - this trigger event is set when the processes completes (not supported on Windows) + working directory to execute the process within. Note that executable, + input and script file paths should be absolute or relative to the + directory where this method is called, not relative to the new + working directory. + **cmd_kwargs: Any, ..., optional + all other keyword arguments are interpreted as options to the command + """ - pos_args = list(args) + pos_args = list(cmd_args) if not self._runner.name: raise RuntimeError("Cannot add process, expected Run instance to have name") @@ -240,7 +278,7 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: if sys.platform == "win32" and completion_trigger: logger.warning( "Completion trigger for 'add_process' may fail on Windows " - "due to function pickling restrictions" + "due to function pickling restrictions", ) # To check the executable provided by the user exists combine with environment @@ -253,7 +291,8 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: and not shutil.which(executable, path=_session_path) ): raise FileNotFoundError( - f"Executable '{executable}' does not exist, please check the path/environment." + f"Executable '{executable}' does not exist, please check the " + "path/environment.", ) if script: @@ -278,7 +317,7 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: if input_file: command += [f"{input_file}"] - command += self._kwarg_assembly(kwargs, executable=executable) + command += self._kwarg_assembly(cmd_kwargs, executable=executable) command += pos_args @@ -299,7 +338,7 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ) self._alert_ids[identifier] = self._runner.create_user_alert( - name=f"{identifier}_exit_status" + name=f"{identifier}_exit_status", ) if not self._alert_ids[identifier]: @@ -307,7 +346,7 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: @property def processes(self) -> list[psutil.Process]: - """Create an array containing a list of processes""" + """Create an array containing a list of processes.""" if not self._processes: return [] @@ -323,8 +362,8 @@ def processes(self) -> list[psutil.Process]: if child not in _current_processes: _current_processes.append(child) - _current_pids = set([_process.pid for _process in _current_processes]) - _previous_pids = set([_process.pid for _process in self._all_processes]) + _current_pids = {_process.pid for _process in _current_processes} + _previous_pids = {_process.pid for _process in self._all_processes} # Find processes which used to exist, which are no longer running _expired_process_pids = _previous_pids - _current_pids @@ -344,7 +383,8 @@ def processes(self) -> list[psutil.Process]: if _process.pid in _new_process_pids ] - # Get CPU usage stats for each of those new processes, so that next time it's measured by the heartbeat the value is accurate + # Get CPU usage stats for each of those new processes, so that next time it's + # measured by the heartbeat the value is accurate if _new_processes: [_process.cpu_percent() for _process in _new_processes] time.sleep(0.1) @@ -356,12 +396,12 @@ def processes(self) -> list[psutil.Process]: @property def success(self) -> int: - """Return whether all attached processes completed successfully""" + """Return whether all attached processes completed successfully.""" return all(i.returncode == 0 for i in self._processes.values()) @property def exit_status(self) -> int: - """Returns the first non-zero exit status if applicable""" + """Returns the first non-zero exit status if applicable.""" if _non_zero := [ i.returncode for i in self._processes.values() if i.returncode != 0 ]: @@ -370,7 +410,7 @@ def exit_status(self) -> int: return 0 def get_error_summary(self) -> dict[str, str] | None: - """Returns the summary messages of all errors""" + """Returns the summary messages of all errors.""" return { identifier: self._get_error_status(identifier) for identifier, value in self._processes.items() @@ -389,6 +429,7 @@ def get_command(self, process_id: str) -> str: ------- str command as a string + """ if process_id not in self._processes: raise KeyError(f"Failed to retrieve '{process_id}', no such process") @@ -396,18 +437,23 @@ def get_command(self, process_id: str) -> str: def _get_error_status(self, process_id: str) -> str | None: err_msg: str | None = None + line_length_cutoff: int = 10 # Return last 10 lines of stdout if stderr empty if not (err_msg := self.std_err(process_id)) and ( std_out := self.std_out(process_id) ): err_msg = " Tail STDOUT:\n\n" - start_index = -10 if len(lines := std_out.split("\n")) > 10 else 0 + start_index = ( + -line_length_cutoff + if len(lines := std_out.split("\n")) > line_length_cutoff + else 0 + ) err_msg += "\n".join(lines[start_index:]) return err_msg def _update_alerts(self) -> None: - """Send log events for the result of each process""" + """Send log events for the result of each process.""" # Wait for the dispatcher to send the latest information before # allowing the executor to finish (and as such the run instance to exit) _wait_limit: float = 1 @@ -417,8 +463,8 @@ def _update_alerts(self) -> None: # the user can manually set the correct status depending on logs etc. _alert = UserAlert( identifier=self._alert_ids[proc_id], - server_url=self._runner._user_config.server.url, - server_token=self._runner._user_config.server.token, + server_url=self._runner.user_config.server.url, + server_token=self._runner.user_config.server.token, ) _is_set: bool = False @@ -428,44 +474,50 @@ def _update_alerts(self) -> None: if process.returncode != 0: # If the process fails then purge the dispatcher event queue # and ensure that the stderr event is sent before the run closes - if self._runner._dispatcher: - self._runner._dispatcher.purge() + if self._runner.dispatcher: + self._runner.dispatcher.purge() if not _is_set: self._runner.log_alert( - identifier=self._alert_ids[proc_id], state="critical" + identifier=self._alert_ids[proc_id], + state="critical", ) elif self._runner.mode == "online" and not _is_set: self._runner.log_alert(identifier=self._alert_ids[proc_id], state="ok") _current_time: float = 0 while ( - self._runner._dispatcher - and not self._runner._dispatcher.empty + self._runner.dispatcher + and not self._runner.dispatcher.empty and _current_time < _wait_limit ): - time.sleep((_current_time := _current_time + 0.1)) + time.sleep(_current_time := _current_time + 0.1) def _save_output(self) -> None: - """Save the output to Simvue""" + """Save the output to Simvue.""" if self._runner.status != "running": logger.debug("Run is not active, skipping output save.") return - for proc_id in self._processes.keys(): + for proc_id in self._processes: # Only save the file if the contents are not empty if self.std_err(proc_id): self._runner.save_file( - f"{self._runner.name}_{proc_id}.err", category="output" + f"{self._runner.name}_{proc_id}.err", + category="output", ) if self.std_out(proc_id): self._runner.save_file( - f"{self._runner.name}_{proc_id}.out", category="output" + f"{self._runner.name}_{proc_id}.out", + category="output", ) def kill_process( - self, process_id: int | str, kill_children_only: bool = False + self, + process_id: int | str, + *, + kill_children_only: bool = False, ) -> None: - """Kill a running process by ID + """Kill a running process by ID. If argument is a string this is a process handled by the client, else it is a PID of a external monitored process @@ -477,11 +529,15 @@ def kill_process( of an external process kill_children_only : bool, optional if process_id is an integer, whether to kill only its children + """ + process = None + if isinstance(process_id, str): if (process := self._processes.get(process_id)) is None: logger.error( - f"Failed to terminate process '{process_id}', no such identifier." + "Failed to terminate process '%s' no such identifier.", + process.id, ) return try: @@ -495,34 +551,34 @@ def kill_process( return for child in parent.children(recursive=True): - logger.debug(f"Terminating child process {child.pid}: {child.name()}") + logger.debug("Terminating child process %s: %s", child.pid, child.name()) child.kill() for child in parent.children(recursive=True): child.wait() if not kill_children_only and process: - logger.debug(f"Terminating process {process.pid}: {process.args}") + logger.debug("Terminating process %s: %s", process.pid, process.args) process.kill() process.wait() self._save_output() def kill_all(self) -> None: - """Kill all running processes""" - for process in self._processes.keys(): + """Kill all running processes.""" + for process in self._processes: self.kill_process(process) def _clear_cache_files(self) -> None: - """Clear local log files if required""" + """Clear local log files if required.""" if not self._keep_logs: - for proc_id in self._processes.keys(): - os.remove(f"{self._runner.name}_{proc_id}.err") - os.remove(f"{self._runner.name}_{proc_id}.out") + for proc_id in self._processes: + pathlib.Path(f"{self._runner.name}_{proc_id}.err").unlink() + pathlib.Path(f"{self._runner.name}_{proc_id}.out").unlink() def wait_for_completion(self) -> None: - """Wait for all processes to finish then perform tidy up and upload""" - for identifier, process in self._processes.items(): + """Wait for all processes to finish then perform tidy up and upload.""" + for process in self._processes.values(): process.wait() self._update_alerts() diff --git a/simvue/handler.py b/simvue/handler.py index 95ae7ea1..b388c1ff 100644 --- a/simvue/handler.py +++ b/simvue/handler.py @@ -1,5 +1,6 @@ """Simvue logging handler.""" +import contextlib import logging import typing @@ -22,6 +23,7 @@ def __init__(self, simvue_run: "Run") -> None: ---------- simvue_run: simvue.Run run to attach this handler to + """ logging.Handler.__init__(self) self._run_object: Run = simvue_run @@ -34,10 +36,11 @@ def emit(self, record: logging.LogRecord) -> None: _msg: str = self.format(record) - try: + with contextlib.suppress(Exception): self._run_object.log_event(_msg) - except Exception: - logging.Handler.handleError(self, record) + return + + logging.Handler.handleError(self, record) @override def close(self) -> None: diff --git a/simvue/metadata.py b/simvue/metadata.py index 502f751f..8adcea63 100644 --- a/simvue/metadata.py +++ b/simvue/metadata.py @@ -1,5 +1,4 @@ -""" -Metadata +"""Metadata. ======== Contains functions for extracting additional metadata about the current project @@ -7,22 +6,24 @@ """ import contextlib -import typing +import fnmatch import json +import logging import os -import fnmatch +import pathlib +import typing + import toml import yaml -import logging -import pathlib +from pip._internal.operations.freeze import freeze from simvue.models import simvue_timestamp -logger = logging.getLogger(__file__) +logger = logging.getLogger(__name__) -def git_info(repository: str) -> dict[str, typing.Any]: - """Retrieves metadata for the target git repository +def git_info(repository: pathlib.Path) -> dict[str, typing.Any]: + """Retrieves metadata for the target git repository. This is a passive function which returns an empty dictionary if any metadata is missing. Exceptions are raised only if the repository @@ -37,6 +38,7 @@ def git_info(repository: str) -> dict[str, typing.Any]: ------- dict[str, typing.Any] metadata for the target repository + """ try: import git @@ -73,7 +75,7 @@ def git_info(repository: str) -> dict[str, typing.Any]: "blame": blame, "url": git_repo.remote().url, "dirty": dirty, - } + }, } except (git.InvalidGitRepositoryError, ValueError): return {} @@ -83,20 +85,29 @@ def _conda_dependency_parse(dependency: str) -> tuple[str, str] | None: """Parse a dependency definition into module-version.""" if dependency.startswith("::"): logger.warning( - f"Skipping Conda specific channel definition '{dependency}' in Python environment metadata." + "Skipping Conda specific channel definition '%s'" + "in Python environment metadata.", + dependency, ) return None - elif ">=" in dependency: + if ">=" in dependency: module, version = dependency.split(">=") logger.warning( - f"Ignoring '>=' constraint in Python package version, naively storing '{module}=={version}', " - "for a more accurate record use 'conda env export > environment.yml'" + "Ignoring '>=' constraint in Python package version, " + "naively storing '%s==%s', " + "for a more accurate record use 'conda env " + "export > environment.yml'", + module, + version, ) elif "~=" in dependency: module, version = dependency.split("~=") logger.warning( - f"Ignoring '~=' constraint in Python package version, naively storing '{module}=={version}', " - "for a more accurate record use 'conda env export > environment.yml'" + "Ignoring '~=' constraint in Python package version, " + "naively storing '%s==%s', " + "for a more accurate record use 'conda env export > environment.yml'", + module, + version, ) elif dependency.startswith("-e"): _, version = dependency.split("-e") @@ -114,7 +125,9 @@ def _conda_dependency_parse(dependency: str) -> tuple[str, str] | None: module = version.split("/")[-1].replace(".git", "") elif "==" not in dependency: logger.warning( - f"Ignoring '{dependency}' in Python environment record as no version constraint specified." + "Ignoring '%s' in Python environment record as " + "no version constraint specified.", + dependency, ) return None else: @@ -125,7 +138,7 @@ def _conda_dependency_parse(dependency: str) -> tuple[str, str] | None: def _conda_env(environment_file: pathlib.Path) -> dict[str, str]: """Parse/interpret a Conda environment file.""" - content = yaml.load(environment_file.open(), Loader=yaml.SafeLoader) + content = yaml.load(environment_file.open(encoding="utf-8"), Loader=yaml.SafeLoader) python_environment: dict[str, str] = {} pip_dependencies: list[str] = [] for dependency in content.get("dependencies", []): @@ -142,7 +155,7 @@ def _conda_env(environment_file: pathlib.Path) -> dict[str, str]: def _python_env(repository: pathlib.Path) -> dict[str, typing.Any]: - """Retrieve a dictionary of Python dependencies if lock file is available""" + """Retrieve a dictionary of Python dependencies if lock file is available.""" python_meta: dict[str, dict] = {} if (pyproject_file := pathlib.Path(repository).joinpath("pyproject.toml")).exists(): @@ -168,16 +181,14 @@ def _python_env(repository: pathlib.Path) -> dict[str, typing.Any]: python_meta["environment"] = { package["name"]: package["version"] for package in content } - # Handle Conda case, albeit naively given the user may or may not have used 'conda env' - # to dump their exact dependency versions + # Handle Conda case, albeit naively given the user may or may not + # have used 'conda env' to dump their exact dependency versions elif ( environment_file := pathlib.Path(repository).joinpath("environment.yml") ).exists(): python_meta["environment"] = _conda_env(environment_file) else: with contextlib.suppress((KeyError, ImportError)): - from pip._internal.operations.freeze import freeze - # Conda supports having file names with @ as entries # in the requirements.txt file as opposed to == python_meta["environment"] = {} @@ -197,7 +208,7 @@ def _python_env(repository: pathlib.Path) -> dict[str, typing.Any]: def _rust_env(repository: pathlib.Path) -> dict[str, typing.Any]: - """Retrieve a dictionary of Rust dependencies if lock file available""" + """Retrieve a dictionary of Rust dependencies if lock file available.""" rust_meta: dict[str, dict] = {} if (cargo_file := pathlib.Path(repository).joinpath("Cargo.toml")).exists(): @@ -221,16 +232,14 @@ def _rust_env(repository: pathlib.Path) -> dict[str, typing.Any]: def _julia_env(repository: pathlib.Path) -> dict[str, typing.Any]: - """Retrieve a dictionary of Julia dependencies if a project file is available""" + """Retrieve a dictionary of Julia dependencies if a project file is available.""" julia_meta: dict[str, dict] = {} if (project_file := pathlib.Path(repository).joinpath("Project.toml")).exists(): content = toml.load(project_file) julia_meta["project"] = { key: value for key, value in content.items() if not isinstance(value, dict) } - julia_meta["environment"] = { - key: value for key, value in content.get("compat", {}).items() - } + julia_meta["environment"] = dict(content.get("compat", {})) return julia_meta @@ -240,19 +249,22 @@ def _node_js_env(repository: pathlib.Path) -> dict[str, typing.Any]: project_file := pathlib.Path(repository).joinpath("package-lock.json") ).exists(): content = json.load(project_file.open()) - if (lfv := content["lockfileVersion"]) not in (1, 2, 3): + if (lfv := content["lockfileVersion"]) not in {1, 2, 3}: logger.warning( - f"Unsupported package-lock.json lockfileVersion {lfv}, ignoring JS project metadata" + "Unsupported package-lock.json lockfileVersion %s, " + "ignoring JS project metadata", + lfv, ) return {} js_meta["project"] = { - key: value for key, value in content.items() if key in ("name", "version") + key: value for key, value in content.items() if key in {"name", "version"} } js_meta["environment"] = { key.replace("@", ""): value["version"] for key, value in content.get( - "packages" if lfv in (2, 3) else "dependencies", {} + "packages" if lfv in {2, 3} else "dependencies", + {}, ).items() if key and not value.get("dev", True) } @@ -272,18 +284,19 @@ def _environment_variables(glob_exprs: list[str]) -> dict[str, str]: def environment( - repository: pathlib.Path = pathlib.Path.cwd(), + repository: pathlib.Path | None = None, env_var_glob_exprs: set[str] | None = None, ) -> dict[str, typing.Any]: - """Retrieve environment metadata""" + """Retrieve environment metadata.""" _environment_meta = {} - if _python_meta := _python_env(repository): + _repository: pathlib.Path = repository or pathlib.Path.cwd() + if _python_meta := _python_env(_repository): _environment_meta["python"] = _python_meta - if _rust_meta := _rust_env(repository): + if _rust_meta := _rust_env(_repository): _environment_meta["rust"] = _rust_meta - if _julia_meta := _julia_env(repository): + if _julia_meta := _julia_env(_repository): _environment_meta["julia"] = _julia_meta - if _js_meta := _node_js_env(repository): + if _js_meta := _node_js_env(_repository): _environment_meta["javascript"] = _js_meta if env_var_glob_exprs: _environment_meta["shell"] = _environment_variables(env_var_glob_exprs) diff --git a/simvue/metrics.py b/simvue/metrics.py index 2914b351..eb45b4b3 100644 --- a/simvue/metrics.py +++ b/simvue/metrics.py @@ -1,5 +1,4 @@ -""" -CPU/GPU Metrics +"""CPU/GPU Metrics. =============== Get information relating to the usage of the CPU and GPU (where applicable) @@ -8,8 +7,8 @@ import contextlib import logging -import psutil +import psutil from .pynvml import ( nvmlDeviceGetComputeRunningProcesses, @@ -39,6 +38,7 @@ def get_process_memory(processes: list[psutil.Process]) -> int: ------- int total process memory + """ rss: int = 0 for process in processes: @@ -49,9 +49,10 @@ def get_process_memory(processes: list[psutil.Process]) -> int: def get_process_cpu( - processes: list[psutil.Process], interval: float | None = None + processes: list[psutil.Process], + interval: float | None = None, ) -> float: - """Get the CPU usage + """Get the CPU usage. If first time being called, use a small interval to collect initial CPU metrics. @@ -60,12 +61,14 @@ def get_process_cpu( processes: list[psutil.Process] list of processes to track for CPU usage. interval: float, optional - interval to measure across, default is None, use previous measure time difference. + interval to measure across, default is None, + use previous measure time difference. Returns ------- float CPU percentage usage + """ cpu_percent: int = 0 for process in processes: @@ -89,6 +92,7 @@ def is_gpu_used(handle, processes: list[psutil.Process]) -> bool: ------- bool if GPU is being used + """ pids = [process.pid for process in processes] @@ -113,6 +117,7 @@ def get_gpu_metrics(processes: list[psutil.Process]) -> list[tuple[float, float] For each GPU identified: - gpu_percent - gpu_memory + """ gpu_metrics: list[tuple[float, float]] = [] @@ -148,6 +153,7 @@ def __init__( processes to measure across. interval: float | None interval to measure, if None previous measure time used for interval. + """ self.cpu_percent: float | None = get_process_cpu(processes, interval=interval) self.cpu_memory: float | None = get_process_memory(processes) diff --git a/simvue/models.py b/simvue/models.py index b07f369e..5afc55e6 100644 --- a/simvue/models.py +++ b/simvue/models.py @@ -1,9 +1,8 @@ import datetime import typing -import numpy -import warnings -import pydantic +import numpy as np +import pydantic FOLDER_REGEX: str = r"^/.*" NAME_REGEX: str = r"^[a-zA-Z0-9\-\_\s\/\.:]+$" @@ -12,22 +11,22 @@ OBJECT_ID: str = r"^[A-Za-z0-9]{22}$" MetadataKeyString = typing.Annotated[ - str, pydantic.StringConstraints(pattern=r"^[\w\-\s\.]+$") + str, + pydantic.StringConstraints(pattern=r"^[\w\-\s\.]+$"), ] TagString = typing.Annotated[str, pydantic.StringConstraints(pattern=r"^[\w\-\s\.]+$")] MetricKeyString = typing.Annotated[ - str, pydantic.StringConstraints(pattern=METRIC_KEY_REGEX) + str, + pydantic.StringConstraints(pattern=METRIC_KEY_REGEX), ] ObjectID = typing.Annotated[str, pydantic.StringConstraints(pattern=OBJECT_ID)] LogLevel = typing.Literal["debug", "info", "warning", "error", "critical"] -def validate_timestamp(timestamp: str, raise_except: bool = True) -> bool: - """ - Validate a user-provided timestamp - """ +def validate_timestamp(timestamp: str, *, raise_except: bool = True) -> bool: + """Validate a user-provided timestamp.""" try: - _ = datetime.datetime.strptime(timestamp, DATETIME_FORMAT) + _ = datetime.datetime.strptime(timestamp, DATETIME_FORMAT).astimezone() except ValueError as e: if raise_except: raise e @@ -42,7 +41,7 @@ def simvue_timestamp( | typing.Annotated[str | None, pydantic.BeforeValidator(validate_timestamp)] | None = None, ) -> str: - """Return the Simvue valid timestamp + """Return the Simvue valid timestamp. Parameters ---------- @@ -55,19 +54,17 @@ def simvue_timestamp( ------- str Datetime string valid for the Simvue server + """ - if isinstance(date_time, str): - warnings.warn( - "Timestamps as strings for object recording will be deprecated in Python API >= 2.3" - ) if not date_time: - date_time = datetime.datetime.now(datetime.timezone.utc) + date_time = datetime.datetime.now(datetime.UTC) elif isinstance(date_time, str): - _local_time = datetime.datetime.now().tzinfo + _local_time = datetime.datetime.now(datetime.UTC).astimezone().tzinfo date_time = ( - datetime.datetime.strptime(date_time, DATETIME_FORMAT) + datetime.datetime + .strptime(date_time, DATETIME_FORMAT) .replace(tzinfo=_local_time) - .astimezone(datetime.timezone.utc) + .astimezone(datetime.UTC) ) return date_time.strftime(DATETIME_FORMAT) @@ -94,18 +91,22 @@ class MetricSet(pydantic.BaseModel): class GridMetricSet(pydantic.BaseModel): model_config = pydantic.ConfigDict( - arbitrary_types_allowed=True, extra="forbid", validate_default=True + arbitrary_types_allowed=True, + extra="forbid", + validate_default=True, ) time: float | int timestamp: typing.Annotated[str | None, pydantic.BeforeValidator(simvue_timestamp)] step: pydantic.NonNegativeInt - array: list[float] | list[list[float]] | numpy.ndarray + array: list[float] | list[list[float]] | np.ndarray grid: str metric: str @pydantic.field_serializer("array", when_used="always") def serialize_array( - self, value: numpy.ndarray | list[float] | list[list[float]], *_ + self, + value: np.ndarray | list[float] | list[list[float]], + *_, ) -> list[float] | list[list[float]]: if isinstance(value, list): return value diff --git a/simvue/pynvml.py b/simvue/pynvml.py index ff319334..ff4d218a 100644 --- a/simvue/pynvml.py +++ b/simvue/pynvml.py @@ -65,8 +65,8 @@ ) from functools import wraps -## C Type mappings ## -## Enums +# C Type mappings ## +# Enums _nvmlEnableState_t = c_uint NVML_FEATURE_DISABLED = 0 NVML_FEATURE_ENABLED = 1 @@ -110,7 +110,7 @@ _nvmlComputeMode_t = c_uint NVML_COMPUTEMODE_DEFAULT = 0 -NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 # Support Removed NVML_COMPUTEMODE_PROHIBITED = 2 NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 NVML_COMPUTEMODE_COUNT = 4 @@ -859,7 +859,7 @@ NVML_FI_MAX = 161 # One greater than the largest field ID defined above -## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode +# Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode NVML_GPU_VIRTUALIZATION_MODE_NONE = 0 # Represents Bare Metal GPU NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = ( 1 # Device is associated with GPU-Passthorugh @@ -874,12 +874,12 @@ 4 # Device is associated with VGX hypervisor in vSGA mode ) -## Lib loading ## +# Lib loading ## nvmlLib = None libLoadLock = threading.Lock() _nvmlLib_refcount = 0 # Incremented on each nvmlInit and decremented on nvmlShutdown -## vGPU Management +# vGPU Management _nvmlVgpuTypeId_t = c_uint _nvmlVgpuInstance_t = c_uint @@ -935,7 +935,7 @@ NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40 -## Error Checking ## +# Error Checking ## class NVMLError(Exception): _valClassMapping = dict() # List of currently known error codes @@ -963,9 +963,8 @@ class NVMLError(Exception): } def __new__(typ, value): - """ - Maps value to a proper subclass of NVMLError. - See _extractNVMLErrorsAsClasses function for more details + """Maps value to a proper subclass of NVMLError. + See _extractNVMLErrorsAsClasses function for more details. """ if typ == NVMLError: typ = NVMLError._valClassMapping.get(value, typ) @@ -977,7 +976,7 @@ def __str__(self): try: if self.value not in NVMLError._errcode_to_string: NVMLError._errcode_to_string[self.value] = str( - nvmlErrorString(self.value) + nvmlErrorString(self.value), ) return NVMLError._errcode_to_string[self.value] except NVMLError: @@ -994,8 +993,7 @@ def nvmlExceptionClass(nvmlErrorCode): def _extractNVMLErrorsAsClasses(): - """ - Generates a hierarchy of classes on top of NVMLError class. + """Generates a hierarchy of classes on top of NVMLError class. Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate exceptions more easily. @@ -1008,7 +1006,8 @@ def _extractNVMLErrorsAsClasses(): for err_name in nvmlErrorsNames: # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized class_name = "NVMLError_" + string.capwords( - err_name.replace("NVML_ERROR_", ""), "_" + err_name.replace("NVML_ERROR_", ""), + "_", ).replace("_", "") err_val = getattr(this_module, err_name) @@ -1034,7 +1033,7 @@ def _nvmlCheckReturn(ret): return ret -## Function access ## +# Function access ## _nvmlGetFunctionPointer_cache = ( dict() ) # function pointers are cached to prevent unnecessary libLoadLock locking @@ -1061,11 +1060,11 @@ def _nvmlGetFunctionPointer(name): libLoadLock.release() -## Alternative object +# Alternative object # Allows the object to be printed # Allows mismatched types to be assigned # - like None when the Structure variant requires c_uint -class nvmlFriendlyObject(object): +class nvmlFriendlyObject: def __init__(self, dictionary): for x in dictionary: setattr(self, x, dictionary[x]) @@ -1098,7 +1097,7 @@ def nvmlFriendlyObjectToStruct(obj, model): return model -## Unit structures +# Unit structures class struct_c_nvmlUnit_t(Structure): pass # opaque handle @@ -1107,13 +1106,12 @@ class struct_c_nvmlUnit_t(Structure): class _PrintableStructure(Structure): - """ - Abstract class that produces nicer __str__ output than ctypes.Structure. + """Abstract class that produces nicer __str__ output than ctypes.Structure. e.g. instead of: >>> print str(obj) this class will print - class_name(field_name: formatted_value, field_name: formatted_value) + class_name(field_name: formatted_value, field_name: formatted_value). _fmt_ dictionary of -> e.g. class that has _field_ 'hex_value', c_uint could be formatted with @@ -1142,7 +1140,7 @@ def __str__(self): return self.__class__.__name__ + "(" + ", ".join(result) + ")" def __getattribute__(self, name): - res = super(_PrintableStructure, self).__getattribute__(name) + res = super().__getattribute__(name) # need to convert bytes to unicode for python3 don't need to for python2 # Python 2 strings are of both str and bytes # Python 3 strings are not of type bytes @@ -1158,7 +1156,7 @@ def __setattr__(self, name, value): # encoding a python2 string returns the same value, since python2 strings are bytes already # bytes passed in python3 will be ignored. value = value.encode() - super(_PrintableStructure, self).__setattr__(name, value) + super().__setattr__(name, value) class c_nvmlUnitInfo_t(_PrintableStructure): @@ -1197,7 +1195,7 @@ class c_nvmlUnitFanSpeeds_t(_PrintableStructure): _fields_ = [("fans", c_nvmlUnitFanInfo_t * 24), ("count", c_uint)] -## Device structures +# Device structures class struct_c_nvmlDevice_t(Structure): pass # opaque handle @@ -1575,7 +1573,7 @@ class c_nvmlGridLicensableFeatures_t(_PrintableStructure): ] -## Event structures +# Event structures class struct_c_nvmlEventSet_t(Structure): pass # opaque handle @@ -1601,7 +1599,7 @@ class struct_c_nvmlEventSet_t(Structure): | nvmlEventMigConfigChange ) -## Clock Throttle Reasons defines +# Clock Throttle Reasons defines nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001 nvmlClocksThrottleReasonApplicationsClocksSetting = 0x0000000000000002 nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting @@ -1769,9 +1767,7 @@ class c_nvmlGpuInstanceProfileInfo_v2_t(Structure): ] def __init__(self): - super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__( - version=nvmlGpuInstanceProfileInfo_v2 - ) + super().__init__(version=nvmlGpuInstanceProfileInfo_v2) class c_nvmlGpuInstanceInfo_t(Structure): @@ -1839,9 +1835,7 @@ class c_nvmlComputeInstanceProfileInfo_v2_t(Structure): ] def __init__(self): - super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__( - version=nvmlComputeInstanceProfileInfo_v2 - ) + super().__init__(version=nvmlComputeInstanceProfileInfo_v2) class c_nvmlComputeInstanceInfo_t(Structure): @@ -1962,12 +1956,11 @@ class c_nvmlRowRemapperHistogramValues(Structure): ] -## string/bytes conversion for ease of use +# string/bytes conversion for ease of use def convertStrBytes(func): - """ - In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + """In python 3, strings are unicode instead of bytes, and need to be converted for ctypes Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) - Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)>. ---- Returned from function: b'returned string' Returned to caller: 'returned string' @@ -1992,7 +1985,7 @@ def wrapper(*args, **kwargs): return func -## C function wrappers ## +# C function wrappers ## def nvmlInitWithFlags(flags): _LoadNvmlLibrary() @@ -2008,18 +2001,14 @@ def nvmlInitWithFlags(flags): libLoadLock.acquire() _nvmlLib_refcount += 1 libLoadLock.release() - return None def nvmlInit(): nvmlInitWithFlags(0) - return None def _LoadNvmlLibrary(): - """ - Load the library if it isn't loaded already - """ + """Load the library if it isn't loaded already.""" global nvmlLib if nvmlLib is None: @@ -2038,7 +2027,7 @@ def _LoadNvmlLibrary(): os.path.join( os.getenv("WINDIR", "C:/Windows"), "System32/nvml.dll", - ) + ), ) except OSError: # If nvml.dll is not found in System32, it should be in ProgramFiles @@ -2047,7 +2036,7 @@ def _LoadNvmlLibrary(): os.path.join( os.getenv("ProgramFiles", "C:/Program Files"), "NVIDIA Corporation/NVSMI/nvml.dll", - ) + ), ) else: # assume linux @@ -2072,10 +2061,9 @@ def nvmlShutdown(): # Atomically update refcount global _nvmlLib_refcount libLoadLock.acquire() - if 0 < _nvmlLib_refcount: + if _nvmlLib_refcount > 0: _nvmlLib_refcount -= 1 libLoadLock.release() - return None # Added in 2.285 @@ -2142,7 +2130,7 @@ def nvmlSystemGetHicVersion(): ret = fn(byref(c_count), None) # this should only fail with insufficient size - if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + if ret not in {NVML_SUCCESS, NVML_ERROR_INSUFFICIENT_SIZE}: raise NVMLError(ret) # If there are no hics @@ -2156,7 +2144,7 @@ def nvmlSystemGetHicVersion(): return hics -## Unit get functions +# Unit get functions def nvmlUnitGetCount(): c_count = c_uint() fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") @@ -2236,7 +2224,7 @@ def nvmlUnitGetDevices(unit): return c_devices -## Device get functions +# Device get functions def nvmlDeviceGetCount(): c_count = c_uint() fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") @@ -2366,14 +2354,12 @@ def nvmlDeviceSetCpuAffinity(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceClearCpuAffinity(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetMinorNumber(handle): @@ -2431,7 +2417,6 @@ def nvmlDeviceValidateInforom(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetDisplayMode(handle): @@ -2531,7 +2516,7 @@ def nvmlDeviceGetSupportedMemoryClocks(handle): if ret == NVML_SUCCESS: # special case, no clocks return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case clocks_array = c_uint * c_count.value c_clocks = clocks_array() @@ -2545,9 +2530,8 @@ def nvmlDeviceGetSupportedMemoryClocks(handle): procs.append(c_clocks[i]) return procs - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) # Added in 4.304 @@ -2560,7 +2544,7 @@ def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): if ret == NVML_SUCCESS: # special case, no clocks return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case clocks_array = c_uint * c_count.value c_clocks = clocks_array() @@ -2574,9 +2558,8 @@ def nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): procs.append(c_clocks[i]) return procs - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetFanSpeed(handle): @@ -2639,7 +2622,6 @@ def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) _nvmlCheckReturn(ret) - return None # DEPRECATED use nvmlDeviceGetPerformanceState @@ -2918,7 +2900,7 @@ def nvmlDeviceGetComputeRunningProcesses_v3(handle): for suffix in ("_v3", "_v2", ""): try: fn = _nvmlGetFunctionPointer( - f"nvmlDeviceGetComputeRunningProcesses{suffix}" + f"nvmlDeviceGetComputeRunningProcesses{suffix}", ) break except NVMLError: @@ -2931,7 +2913,7 @@ def nvmlDeviceGetComputeRunningProcesses_v3(handle): if ret == NVML_SUCCESS: # special case, no running processes return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case # oversize the array incase more processes are created c_count.value = c_count.value * 2 + 5 @@ -2952,9 +2934,8 @@ def nvmlDeviceGetComputeRunningProcesses_v3(handle): procs.append(obj) return procs - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetComputeRunningProcesses(handle): @@ -2969,7 +2950,7 @@ def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): for suffix in ("_v3", "_v2", ""): try: fn = _nvmlGetFunctionPointer( - f"nvmlDeviceGetGraphicsRunningProcesses{suffix}" + f"nvmlDeviceGetGraphicsRunningProcesses{suffix}", ) break except NVMLError: @@ -2982,7 +2963,7 @@ def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): if ret == NVML_SUCCESS: # special case, no running processes return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case # oversize the array incase more processes are created c_count.value = c_count.value * 2 + 5 @@ -3003,9 +2984,8 @@ def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): procs.append(obj) return procs - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetGraphicsRunningProcesses(handle): @@ -3025,7 +3005,7 @@ def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): if ret == NVML_SUCCESS: # special case, no running processes return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case # oversize the array incase more processes are created c_count.value = c_count.value * 2 + 5 @@ -3046,9 +3026,8 @@ def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): procs.append(obj) return procs - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetAutoBoostedClocksEnabled(handle): @@ -3061,54 +3040,47 @@ def nvmlDeviceGetAutoBoostedClocksEnabled(handle): # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks -## Set functions +# Set functions def nvmlUnitSetLedState(unit, color): fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") ret = fn(unit, _nvmlLedColor_t(color)) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetPersistenceMode(handle, mode): fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") ret = fn(handle, _nvmlEnableState_t(mode)) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetComputeMode(handle, mode): fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") ret = fn(handle, _nvmlComputeMode_t(mode)) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetEccMode(handle, mode): fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") ret = fn(handle, _nvmlEnableState_t(mode)) _nvmlCheckReturn(ret) - return None def nvmlDeviceClearEccErrorCounts(handle, counterType): fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") ret = fn(handle, _nvmlEccCounterType_t(counterType)) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetDriverModel(handle, model): fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") ret = fn(handle, _nvmlDriverModel_t(model)) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") ret = fn(handle, _nvmlEnableState_t(enabled)) _nvmlCheckReturn(ret) - return None # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks @@ -3116,7 +3088,6 @@ def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) _nvmlCheckReturn(ret) - return None # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks @@ -3124,28 +3095,24 @@ def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) _nvmlCheckReturn(ret) - return None def nvmlDeviceResetGpuLockedClocks(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) _nvmlCheckReturn(ret) - return None def nvmlDeviceResetMemoryLockedClocks(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo): @@ -3159,7 +3126,6 @@ def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz) fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) _nvmlCheckReturn(ret) - return None # Added in 4.304 @@ -3167,7 +3133,6 @@ def nvmlDeviceResetApplicationsClocks(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") ret = fn(handle) _nvmlCheckReturn(ret) - return None # Added in 4.304 @@ -3175,7 +3140,6 @@ def nvmlDeviceSetPowerManagementLimit(handle, limit): fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") ret = fn(handle, c_uint(limit)) _nvmlCheckReturn(ret) - return None # Added in 4.304 @@ -3183,7 +3147,6 @@ def nvmlDeviceSetGpuOperationMode(handle, mode): fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") ret = fn(handle, _nvmlGpuOperationMode_t(mode)) _nvmlCheckReturn(ret) - return None # Added in 2.285 @@ -3200,7 +3163,6 @@ def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") ret = fn(handle, c_ulonglong(eventTypes), eventSet) _nvmlCheckReturn(ret) - return None # Added in 2.285 @@ -3230,7 +3192,6 @@ def nvmlEventSetFree(eventSet): fn = _nvmlGetFunctionPointer("nvmlEventSetFree") ret = fn(eventSet) _nvmlCheckReturn(ret) - return None # Added in 3.295 @@ -3318,14 +3279,12 @@ def nvmlDeviceSetAccountingMode(handle, mode): fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") ret = fn(handle, _nvmlEnableState_t(mode)) _nvmlCheckReturn(ret) - return None def nvmlDeviceClearAccountingPids(handle): fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") ret = fn(handle) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetAccountingStats(handle, pid): @@ -3365,7 +3324,7 @@ def nvmlDeviceGetRetiredPages(device, sourceFilter): ret = fn(device, c_source, byref(c_count), None) # this should only fail with insufficient size - if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + if ret not in {NVML_SUCCESS, NVML_ERROR_INSUFFICIENT_SIZE}: raise NVMLError(ret) # call again with a buffer @@ -3388,7 +3347,7 @@ def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): ret = fn(device, c_source, byref(c_count), None) # this should only fail with insufficient size - if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + if ret not in {NVML_SUCCESS, NVML_ERROR_INSUFFICIENT_SIZE}: raise NVMLError(ret) # call again with a buffer @@ -3427,7 +3386,6 @@ def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetBridgeChipInfo(handle): @@ -3445,7 +3403,7 @@ def nvmlDeviceGetSamples(device, sampling_type, timeStamp): c_sample_value_type = _nvmlValueType_t() fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") - ## First Call gets the size + # First Call gets the size ret = fn( device, c_sampling_type, @@ -3478,7 +3436,7 @@ def nvmlDeviceGetViolationStatus(device, perfPolicyType): c_violTime = c_nvmlViolationTime_t() fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") - ## Invoke the method to get violation time + # Invoke the method to get violation time ret = fn(device, c_perfPolicy_type, byref(c_violTime)) _nvmlCheckReturn(ret) return c_violTime @@ -3548,21 +3506,18 @@ def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") ret = fn(device, link, counter, freeze) _nvmlCheckReturn(ret) - return None def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") ret = fn(device, link, counter) _nvmlCheckReturn(ret) - return None def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") ret = fn(device, link, counter, byref(control), reset) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): @@ -3593,7 +3548,6 @@ def nvmlDeviceResetNvLinkErrorCounters(device, link): fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") ret = fn(device, link) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetNvLinkRemotePciInfo(device, link): @@ -3632,7 +3586,6 @@ def nvmlDeviceModifyDrainState(pciInfo, newState): fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") ret = fn(pointer(pciInfo), newState) _nvmlCheckReturn(ret) - return None def nvmlDeviceQueryDrainState(pciInfo): @@ -3647,14 +3600,12 @@ def nvmlDeviceRemoveGpu(pciInfo): fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") ret = fn(pointer(pciInfo)) _nvmlCheckReturn(ret) - return None def nvmlDeviceDiscoverGpus(pciInfo): fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") ret = fn(pointer(pciInfo)) _nvmlCheckReturn(ret) - return None def nvmlDeviceGetFieldValues(handle, fieldIds): @@ -3696,7 +3647,7 @@ def nvmlDeviceGetSupportedVgpus(handle): if ret == NVML_SUCCESS: # special case, no supported vGPUs return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value c_vgpu_type_ids = vgpu_type_ids_array() @@ -3708,9 +3659,8 @@ def nvmlDeviceGetSupportedVgpus(handle): for i in range(c_vgpu_count.value): vgpus.append(c_vgpu_type_ids[i]) return vgpus - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetCreatableVgpus(handle): @@ -3723,7 +3673,7 @@ def nvmlDeviceGetCreatableVgpus(handle): if ret == NVML_SUCCESS: # special case, no supported vGPUs return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value c_vgpu_type_ids = vgpu_type_ids_array() @@ -3735,9 +3685,8 @@ def nvmlDeviceGetCreatableVgpus(handle): for i in range(c_vgpu_count.value): vgpus.append(c_vgpu_type_ids[i]) return vgpus - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): @@ -3846,7 +3795,7 @@ def nvmlDeviceGetActiveVgpus(handle): if ret == NVML_SUCCESS: # special case, no active vGPUs return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value c_vgpu_instances = vgpu_instance_array() @@ -3858,9 +3807,8 @@ def nvmlDeviceGetActiveVgpus(handle): for i in range(c_vgpu_count.value): vgpus.append(c_vgpu_instances[i]) return vgpus - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) @convertStrBytes @@ -3989,7 +3937,9 @@ def nvmlVgpuInstanceGetGpuPciId(vgpuInstance): c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") ret = fn( - vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE)) + vgpuInstance, + c_vgpuPciId, + byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE)), ) _nvmlCheckReturn(ret) return c_vgpuPciId.value @@ -4003,13 +3953,17 @@ def nvmlDeviceGetVgpuUtilization(handle, timeStamp): fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") ret = fn( - handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None + handle, + c_time_stamp, + byref(c_sample_value_type), + byref(c_vgpu_count), + None, ) if ret == NVML_SUCCESS: # special case, no active vGPUs return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t c_samples = sampleArray() @@ -4025,9 +3979,8 @@ def nvmlDeviceGetVgpuUtilization(handle, timeStamp): _nvmlCheckReturn(ret) return c_samples[0 : c_vgpu_count.value] - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): @@ -4086,7 +4039,7 @@ def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): if ret == NVML_SUCCESS: # special case, no active vGPUs return [] - elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + if ret == NVML_ERROR_INSUFFICIENT_SIZE: # typical case sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t c_samples = sampleArray() @@ -4096,9 +4049,8 @@ def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): _nvmlCheckReturn(ret) return c_samples[0 : c_vgpu_count.value] - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlDeviceGetEncoderStats(handle): @@ -4131,11 +4083,9 @@ def nvmlDeviceGetEncoderSessions(handle): for i in range(c_session_count.value): sessions.append(c_sessions[i]) return sessions - else: - return [] # no active sessions - else: - # error case - raise NVMLError(ret) + return [] # no active sessions + # error case + raise NVMLError(ret) def nvmlDeviceGetFBCStats(handle): @@ -4166,11 +4116,9 @@ def nvmlDeviceGetFBCSessions(handle): for i in range(c_session_count.value): sessions.append(c_sessions[i]) return sessions - else: - return [] # no active sessions - else: - # error case - raise NVMLError(ret) + return [] # no active sessions + # error case + raise NVMLError(ret) def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): @@ -4179,7 +4127,10 @@ def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): c_encoderLatency = c_ulonglong(0) fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") ret = fn( - vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency) + vgpuInstance, + byref(c_encoderCount), + byref(c_encodeFps), + byref(c_encoderLatency), ) _nvmlCheckReturn(ret) return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) @@ -4205,11 +4156,9 @@ def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): for i in range(c_session_count.value): sessions.append(c_sessions[i]) return sessions - else: - return [] # no active sessions - else: - # error case - raise NVMLError(ret) + return [] # no active sessions + # error case + raise NVMLError(ret) def nvmlVgpuInstanceGetFBCStats(vgpuInstance): @@ -4240,11 +4189,9 @@ def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): for i in range(c_session_count.value): sessions.append(c_sessions[i]) return sessions - else: - return [] # no active sessions - else: - # error case - raise NVMLError(ret) + return [] # no active sessions + # error case + raise NVMLError(ret) def nvmlDeviceGetProcessUtilization(handle, timeStamp): @@ -4265,9 +4212,8 @@ def nvmlDeviceGetProcessUtilization(handle, timeStamp): _nvmlCheckReturn(ret) return c_samples[0 : c_count.value] - else: - # error case - raise NVMLError(ret) + # error case + raise NVMLError(ret) def nvmlVgpuInstanceGetMetadata(vgpuInstance): @@ -4444,7 +4390,10 @@ def nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): def nvmlDeviceGetGpuInstancePossiblePlacements( - device, profileId, placementsRef, countRef + device, + profileId, + placementsRef, + countRef, ): fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") ret = fn(device, profileId, placementsRef, countRef) @@ -4499,7 +4448,10 @@ def nvmlGpuInstanceGetInfo(gpuInstance): def nvmlGpuInstanceGetComputeInstanceProfileInfo( - device, profile, engProfile, version=2 + device, + profile, + engProfile, + version=2, ): if version == 2: c_info = c_nvmlComputeInstanceProfileInfo_v2_t() @@ -4544,7 +4496,10 @@ def nvmlComputeInstanceDestroy(computeInstance): def nvmlGpuInstanceGetComputeInstances( - gpuInstance, profileId, computeInstancesRef, countRef + gpuInstance, + profileId, + computeInstancesRef, + countRef, ): fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) diff --git a/simvue/run.py b/simvue/run.py index ca17edb7..03213458 100644 --- a/simvue/run.py +++ b/simvue/run.py @@ -5,83 +5,81 @@ """ import contextlib +import datetime +import functools import logging -import pathlib import mimetypes import multiprocessing.synchronize -import shlex -import threading -import warnings -import humanfriendly -import datetime import os -from unyt import unyt_quantity -from unyt.exceptions import UnitParseError - -import pydantic +import pathlib +import platform import re +import shlex import sys -import traceback as tb +import threading import time +import traceback as tb import types -import functools -import platform import typing import uuid -import numpy -import randomname + import click +import humanfriendly +import numpy as np import psutil +import pydantic +import randomname +from unyt import unyt_quantity +from unyt.exceptions import UnitParseError -from simvue.api.objects.alert.base import AlertBase from simvue.api.objects.alert.fetch import Alert from simvue.api.objects.folder import Folder from simvue.api.objects.grids import GridMetrics -from simvue.exception import ObjectNotFoundError, SimvueRunError, ObjectDispatchError +from simvue.exception import ObjectDispatchError, ObjectNotFoundError, SimvueRunError from simvue.utilities import prettify_pydantic - +from .api.objects import ( + Events, + EventsAlert, + FileArtifact, + Grid, + Metrics, + MetricsRangeAlert, + MetricsThresholdAlert, + ObjectArtifact, + UserAlert, +) +from .api.objects import ( + Run as RunObject, +) from .config.user import SimvueConfiguration - from .dispatch import Dispatcher +from .dispatch.base import DispatcherBaseClass +from .eco import CO2Monitor from .executor import Executor, get_current_shell +from .metadata import environment, git_info from .metrics import SystemResourceMeasurement from .models import ( FOLDER_REGEX, NAME_REGEX, + LogLevel, MetricKeyString, - validate_timestamp, simvue_timestamp, - LogLevel, + validate_timestamp, ) from .system import get_system -from .metadata import git_info, environment -from .eco import CO2Monitor from .utilities import ( skip_if_failed, ) -from .api.objects import ( - Run as RunObject, - FileArtifact, - ObjectArtifact, - MetricsThresholdAlert, - MetricsRangeAlert, - UserAlert, - EventsAlert, - Events, - Metrics, - Grid, -) - try: from typing import Self except ImportError: - from typing_extensions import Self # noqa: F401 - + from typing_extensions import Self if typing.TYPE_CHECKING: - from .dispatch import DispatcherBaseClass + from simvue.api.objects.alert.base import AlertBase + HEARTBEAT_INTERVAL: int = 60 RESOURCES_METRIC_PREFIX: str = "resources" @@ -108,7 +106,7 @@ def _wrapper(self: Self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: if not self._sv_obj: raise RuntimeError( - f"Simvue Run must be initialised before calling '{function.__name__}'" + f"Simvue Run must be initialised before calling '{function.__name__}'", ) return _function(self, *args, **kwargs) @@ -134,7 +132,7 @@ def __init__( debug: bool = False, server_profile: str | None = None, ) -> None: - """Initialise a new Simvue run + """Initialise a new Simvue run. If `abort_callback` is provided the first argument must be this Run instance @@ -160,11 +158,11 @@ def __init__( Examples -------- - ```python with simvue.Run() as run: ... ``` + """ self._uuid: str = f"{uuid.uuid4()}" @@ -194,7 +192,12 @@ def __init__( self._failed_metric_counter: int = 0 self._status: ( typing.Literal[ - "created", "running", "completed", "failed", "terminated", "lost" + "created", + "running", + "completed", + "failed", + "terminated", + "lost", ] | None ) = None @@ -212,7 +215,7 @@ def __init__( logging.DEBUG if (debug is not None and debug) or (debug is None and self._user_config.client.debug) - else logging.INFO + else logging.INFO, ) self._aborted: bool = False @@ -296,7 +299,7 @@ def __exit__( @property def duration(self) -> float: - """Return current run duration""" + """Return current run duration.""" return time.time() - self._start_time @property @@ -306,8 +309,7 @@ def mode(self) -> typing.Literal["offline", "online", "disabled"]: @property def processes(self) -> list[psutil.Process]: - """Create an array containing a list of processes""" - + """Create an array containing a list of processes.""" process_list = self._executor.processes if not self._parent_process: @@ -318,8 +320,14 @@ def processes(self) -> list[psutil.Process]: return list(set(process_list)) + @property + def user_config(self) -> SimvueConfiguration: + """Return current user configuration.""" + return self._sv_obj.user_config + def _terminate_run( self, + *, abort_callback: typing.Callable[[Self], None] | None, force_exit: bool = True, ) -> None: @@ -334,12 +342,13 @@ def _terminate_run( the callback to execute on the termination else None force_exit: bool, optional whether to close Python itself, the default is True + """ self._alert_raised_trigger.set() logger.debug("Received abort request from server") if abort_callback is not None: - abort_callback(self) # type: ignore + abort_callback(self) if self._abort_on_alert != "ignore": self.kill_all_processes() @@ -376,8 +385,8 @@ def _get_internal_metrics( tuple[float, float] new resource metric measure time new emissions metric measure time - """ + """ # In order to get a resource metric reading at t=0 # because there is no previous CPU reading yet we cannot # use the default of None for the interval here, so we measure @@ -400,7 +409,8 @@ def _get_internal_metrics( ) # For the first emissions metrics reading, the time interval to use - # Is the time since the run started, otherwise just use the time between readings + # Is the time since the run started, otherwise just use the time + # between readings if self._emissions_monitor: _estimated = self._emissions_monitor.estimate_co2_emissions( process_id=f"{self._sv_obj.name}", @@ -482,12 +492,11 @@ def _heartbeat( def _create_dispatch_callback( self, ) -> typing.Callable: - """Generates the relevant callback for posting of metrics and events + """Generates the relevant callback for posting of metrics and events. The generated callback is assigned to the dispatcher instance and is executed on metrics and events objects held in a buffer. """ - if self._user_config.run.mode == "online" and not self.id: raise RuntimeError("Expected identifier for run") @@ -509,7 +518,7 @@ def _dispatch_callback( events=buffer, ) return _events.commit() - elif category == "metrics_tensor": + if category == "metrics_tensor": _grid_metrics = GridMetrics.new( run=self.id, data=buffer, @@ -518,25 +527,25 @@ def _dispatch_callback( offline=self.mode == "offline", ) return _grid_metrics.commit() - else: - _metrics = Metrics.new( - run=self.id, - offline=self.mode == "offline", - server_url=self._user_config.server.url, - server_token=self._user_config.server.token, - metrics=buffer, - ) - return _metrics.commit() + _metrics = Metrics.new( + run=self.id, + offline=self.mode == "offline", + server_url=self._user_config.server.url, + server_token=self._user_config.server.token, + metrics=buffer, + ) + return _metrics.commit() return _dispatch_callback def _start(self) -> bool: - """Start a run + """Start a run. Returns ------- bool if successful + """ if self._user_config.run.mode == "disabled": return True @@ -576,7 +585,7 @@ def _start(self) -> bool: mode=self._dispatch_mode, termination_trigger=self._shutdown_event, object_types=["events", "metrics_regular", "metrics_tensor"], - thresholds=dict(object_size=TOTAL_GRID_METRIC_SIZE), + thresholds={"object_size": TOTAL_GRID_METRIC_SIZE}, callback=self._create_dispatch_callback(), ) @@ -597,8 +606,8 @@ def _start(self) -> bool: return True - def _error(self, message: str, join_threads: bool = True) -> None: - """Raise an exception if necessary and log error + def _error(self, message: str, *, join_threads: bool = True) -> None: + """Raise an exception if necessary and log error. Parameters ---------- @@ -612,6 +621,7 @@ def _error(self, message: str, join_threads: bool = True) -> None: ------ RuntimeError exception throw + """ # Finish stopping all threads if self._shutdown_event: @@ -642,18 +652,17 @@ def _error(self, message: str, join_threads: bool = True) -> None: self._aborted = True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @pydantic.validate_call def init( self, name: typing.Annotated[str | None, pydantic.Field(pattern=NAME_REGEX)] = None, *, - metadata: dict[str, typing.Any] = None, + metadata: dict[str, typing.Any] | None = None, tags: list[str] | None = None, description: str | None = None, - folder: typing.Annotated[ - str, pydantic.Field(None, pattern=FOLDER_REGEX) - ] = None, + folder: typing.Annotated[str, pydantic.Field(None, pattern=FOLDER_REGEX)] + | None = None, notification: typing.Literal["none", "all", "error", "lost"] = "none", running: bool = True, retention_period: str | None = None, @@ -662,7 +671,7 @@ def init( no_color: bool = False, record_shell_vars: set[str] | None = None, ) -> bool: - """Initialise a Simvue run + """Initialise a Simvue run. Parameters ---------- @@ -706,10 +715,12 @@ def init( ------- bool whether the initialisation was successful + """ if self._user_config.run.mode == "disabled": logger.warning( - "Simvue monitoring has been deactivated for this run, metrics and artifacts will not be recorded." + "Simvue monitoring has been deactivated for this run, metrics " + "and artifacts will not be recorded.", ) return True @@ -728,9 +739,9 @@ def init( server_url=self._user_config.server.url, server_token=self._user_config.server.token, ) - self._folder.commit() # type: ignore + self._folder.commit() - if self._user_config.run.mode not in ("online", "offline"): + if self._user_config.run.mode not in {"online", "offline"}: self._error("invalid mode specified, must be online, offline or disabled") return False @@ -738,14 +749,14 @@ def init( not self._user_config.server.token or not self._user_config.server.url ): self._error( - "Unable to get URL and token from environment variables or config file" + "Unable to get URL and token from environment variables or config file", ) return False if name and not re.match(r"^[a-zA-Z0-9\-\_\s\/\.:]+$", name): self._error("specified name is invalid") return False - elif not name and self.mode == "offline": + if not name and self.mode == "offline": name = randomname.get_name() self._status = "running" if running else "created" @@ -754,7 +765,7 @@ def init( try: if retention_period: self._retention: int | None = int( - humanfriendly.parse_timespan(retention_period) + humanfriendly.parse_timespan(retention_period), ) else: self._retention = None @@ -787,7 +798,7 @@ def init( self._sv_obj.tags = tags self._sv_obj.metadata = ( (metadata or {}) - | git_info(os.getcwd()) + | git_info(pathlib.Path.cwd()) | environment(env_var_glob_exprs=record_shell_vars) ) self._sv_obj.heartbeat_timeout = timeout @@ -798,7 +809,7 @@ def init( if self._status == "running": self._sv_obj.system = get_system() - self._data = self._sv_obj._staging + self._data = self._sv_obj.staging self._sv_obj.commit() if not self.name: @@ -814,14 +825,16 @@ def init( fg="green" if self._term_color else None, ) click.secho( - f"[simvue] Monitor in the UI at {self._user_config.server.url.rsplit('/api', 1)[0]}/dashboard/runs/run/{self.id}", + "[simvue] Monitor in the UI at " + f"{self._user_config.server.url.rsplit('/api', 1)[0]}" + f"/dashboard/runs/run/{self.id}", bold=self._term_color, fg="green" if self._term_color else None, ) return True - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call(config={"arbitrary_types_allowed": True}) def add_process( self, @@ -830,9 +843,7 @@ def add_process( executable: str | pathlib.Path | None = None, script: pydantic.FilePath | None = None, input_file: pydantic.FilePath | None = None, - completion_callback: typing.Optional[ - typing.Callable[[int, str, str], None] - ] = None, + completion_callback: typing.Callable[[int, str, str], None] | None = None, completion_trigger: threading.Event | multiprocessing.synchronize.Event | None = None, @@ -842,7 +853,8 @@ def add_process( ) -> None: """Add a process to be executed to the executor. - This process can take many forms, for example a be a set of positional arguments: + This process can take many forms, for example a be a set of + positional arguments: ```python executor.add_process("my_process", "ls", "-ltr") @@ -865,12 +877,14 @@ def add_process( ) ``` - or a mixture of both. In the latter case arguments which are not 'executable', 'script', 'input' - are taken to be options to the command, for flags `flag=True` can be used to set the option and - for options taking values `option=value`. + or a mixture of both. In the latter case arguments which are not + 'executable', 'script', 'input' are taken to be options to the command, + for flags `flag=True` can be used to set the option and for options + taking values `option=value`. - When the process has completed if a function has been provided for the `completion_callback` argument - this will be called, this callback is expected to take the following form: + When the process has completed if a function has been provided for the + `completion_callback` argument this will be called, this callback is + expected to take the following form: ```python def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... @@ -878,24 +892,27 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... Note `completion_callback` is not supported on Windows operating systems. - Alternatively you can use `completion_trigger` to create a multiprocessing event which will be set - when the process has completed. + Alternatively you can use `completion_trigger` to create a multiprocessing + event which will be set when the process has completed. Parameters ---------- identifier : str A unique identifier for this process executable : str | None, optional - the main executable for the command, if not specified this is taken to be the first - positional argument, by default None - *positional_arguments : Any, ..., optional - all other positional arguments are taken to be part of the command to execute + the main executable for the command, if not specified this is + taken to be the first positional argument, by default None + *cmd_args: Any, ..., optional + all other positional arguments are taken to be part of the + command to execute script : pydantic.FilePath | None, optional - the script to run, note this only work if the script is not an option, if this is the case - you should provide it as such and perform the upload manually, by default None + the script to run, note this only work if the script is not an option, + if this is the case you should provide it as such and perform the + upload manually, by default None input_file : pydantic.FilePath | None, optional - the input file to run, note this only work if the input file is not an option, if this is the case - you should provide it as such and perform the upload manually, by default None + the input file to run, note this only work if the input file is not an + option, if this is the case you should provide it as such and perform + the upload manually, by default None completion_callback : typing.Callable | None, optional callback to run when process terminates (not supported on Windows) completion_trigger : threading.Event | None, optional @@ -903,14 +920,15 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... env : dict[str, str], optional environment variables for process cwd: pathlib.Path | None, optional - working directory to execute the process within. Note that executable, input and script file paths should - be absolute or relative to the directory where this method is called, not relative to the new working directory. - **kwargs : Any, ..., optional + working directory to execute the process within. Note that executable, + input and script file paths should be absolute or relative to the + directory where this method is called, not relative to the new + working directory. + **cmd_kwargs: Any, ..., optional all other keyword arguments are interpreted as options to the command Examples -------- - `run_count.sh` ```sh #!/bin/bash @@ -934,17 +952,13 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... script="run_count.sh" ) ``` - """ - if isinstance(completion_trigger, multiprocessing.synchronize.Event): - warnings.warn( - "Use of a 'multiprocessing.Event' as a termination trigger will be deprecated in v2.5, " - + "use an instance of 'threading.Event' instead." - ) + """ if platform.system() == "Windows" and completion_trigger: raise RuntimeError( - "Use of 'completion_trigger' on Windows based operating systems is unsupported " - "due to function pickling restrictions for multiprocessing" + "Use of 'completion_trigger' on Windows based operating systems " + "is unsupported due to function pickling restrictions for " + "multiprocessing", ) if isinstance(executable, pathlib.Path) and not executable.is_file(): @@ -973,11 +987,11 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... else: cmd_list += [f"-{kwarg}{(f' {_quoted_val}') if val else ''}"] else: - kwarg = kwarg.replace("_", "-") + _kwarg = kwarg.replace("_", "-") if isinstance(val, bool) and val: - cmd_list += [f"--{kwarg}"] + cmd_list += [f"--{_kwarg}"] else: - cmd_list += [f"--{kwarg}{(f' {_quoted_val}') if val else ''}"] + cmd_list += [f"--{_kwarg}{(f' {_quoted_val}') if val else ''}"] cmd_list += pos_args cmd_str = shlex.join(cmd_list) @@ -992,7 +1006,7 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... executable=executable_str, script=script, input_file=input_file, - completion_callback=completion_callback, # type: ignore + completion_callback=completion_callback, completion_trigger=completion_trigger, env=env, cwd=cwd, @@ -1001,12 +1015,13 @@ def callback_function(status_code: int, std_out: str, std_err: str) -> None: ... @pydantic.validate_call def kill_process(self, process_id: str) -> None: - """Kill a running process by ID + """Kill a running process by ID. Parameters ---------- process_id : str the unique identifier for the added process + """ self._executor.kill_process(process_id) @@ -1034,15 +1049,20 @@ def _get_child_processes(self) -> list[psutil.Process]: @property def executor(self) -> Executor: - """Return the executor for this run""" + """Return the executor for this run.""" return self._executor + @property + def dispatcher(self) -> DispatcherBaseClass | None: + """Return the dispatcher for this run.""" + return self._dispatcher + @property def name(self) -> str | None: - """Return the name of the run""" + """Return the name of the run.""" if not self._sv_obj: logger.warning( - "Attempted to get name on non initialized run - returning None" + "Attempted to get name on non initialized run - returning None", ) return None return self._sv_obj.name @@ -1052,37 +1072,42 @@ def status( self, ) -> ( typing.Literal[ - "created", "running", "completed", "failed", "terminated", "lost" + "created", + "running", + "completed", + "failed", + "terminated", + "lost", ] | None ): - """Return the status of the run""" + """Return the status of the run.""" if not self._sv_obj: logger.warning( - "Attempted to get name on non initialized run - returning cached value" + "Attempted to get name on non initialized run - returning cached value", ) return self._status return self._sv_obj.status @property def uid(self) -> str: - """Return the local unique identifier of the run""" + """Return the local unique identifier of the run.""" return self._uuid @property def id(self) -> str | None: - """Return the unique id of the run""" + """Return the unique id of the run.""" if not self._sv_obj: logger.warning( - "Attempted to get name on non initialized run - returning None" + "Attempted to get name on non initialized run - returning None", ) return None return self._sv_obj.id - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @pydantic.validate_call def reconnect(self, run_id: str) -> bool: - """Reconnect to a run in the created state + """Reconnect to a run in the created state. Parameters ---------- @@ -1093,6 +1118,7 @@ def reconnect(self, run_id: str) -> bool: ------- bool whether reconnection succeeded + """ self._status = "running" @@ -1110,10 +1136,10 @@ def reconnect(self, run_id: str) -> bool: return True - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call def set_pid(self, pid: int) -> None: - """Set pid of process to be monitored + """Set pid of process to be monitored. Parameters ---------- @@ -1122,7 +1148,6 @@ def set_pid(self, pid: int) -> None: Examples -------- - ```python import subprocess @@ -1139,17 +1164,19 @@ def set_pid(self, pid: int) -> None: with simvue.Run() as run: run.init("pid_track") run.set_pid(process_pid) + """ self._pid = pid self._parent_process = psutil.Process(self._pid) self._child_processes = self._get_child_processes() - # Get CPU usage stats for each of those new processes, so that next time it's measured by the heartbeat the value is accurate + # Get CPU usage stats for each of those new processes, so that next time it's + # measured by the heartbeat the value is accurate [ _process.cpu_percent() - for _process in self._child_processes + [self._parent_process] + for _process in (*self._child_processes, self._parent_process) ] - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @pydantic.validate_call def config( self, @@ -1162,7 +1189,7 @@ def config( storage_id: str | None = None, abort_on_alert: typing.Literal["run", "terminate", "ignore"] | None = None, ) -> bool: - """Optional configuration + """Optional configuration. Parameters ---------- @@ -1189,8 +1216,8 @@ def config( ------- bool if configuration was successful - """ + """ with self._configuration_lock: if suppress_errors is not None: self._suppress_errors = suppress_errors @@ -1200,7 +1227,8 @@ def config( if system_metrics_interval and disable_resources_metrics: self._error( - "Setting of resource metric interval and disabling resource metrics is ambiguous" + "Setting of resource metric interval and disabling " + "resource metrics is ambiguous", ) return False @@ -1210,7 +1238,7 @@ def config( if disable_resources_metrics: if self._emissions_monitor: self._error( - "Emissions metrics require resource metrics collection." + "Emissions metrics require resource metrics collection.", ) return False self._pid = None @@ -1219,7 +1247,8 @@ def config( if enable_emission_metrics: if not self._system_metrics_interval: self._error( - "Emissions metrics require resource metrics collection - make sure resource metrics are enabled!" + "Emissions metrics require resource metrics collection " + "- make sure resource metrics are enabled!", ) return False if self.mode == "offline": @@ -1247,13 +1276,6 @@ def config( self._error("Cannot disable emissions monitor once it has been started") if abort_on_alert is not None: - if isinstance(abort_on_alert, bool): - raise ( - TypeError( - "Use of type bool for argument 'abort_on_alert' has been removed, " - "please use either 'run', 'all' or 'ignore'" - ) - ) self._abort_on_alert = abort_on_alert if storage_id: @@ -1261,11 +1283,11 @@ def config( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def update_metadata(self, metadata: dict[str, typing.Any]) -> bool: - """Update metadata for this run + """Update metadata for this run. Parameters ---------- @@ -1276,6 +1298,7 @@ def update_metadata(self, metadata: dict[str, typing.Any]) -> bool: ------- bool if the update was successful + """ if not self._sv_obj: self._error("Cannot update metadata, run not initialised") @@ -1292,11 +1315,11 @@ def update_metadata(self, metadata: dict[str, typing.Any]) -> bool: return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def set_tags(self, tags: list[str]) -> bool: - """Set tags for this run + """Set tags for this run. Parameters ---------- @@ -1316,6 +1339,7 @@ def set_tags(self, tags: list[str]) -> bool: run.init(tags=["old", "tag", "set"]) run.set_tags(["new", "tag", "set"]) ``` + """ if not self._sv_obj: self._error("Cannot update tags, run not initialised") @@ -1326,11 +1350,11 @@ def set_tags(self, tags: list[str]) -> bool: return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def update_tags(self, tags: list[str]) -> bool: - """Add additional tags to this run without duplication + """Add additional tags to this run without duplication. Parameters ---------- @@ -1350,6 +1374,7 @@ def update_tags(self, tags: list[str]) -> bool: run.init(tags=["current_tag"]) run.update_tags(["additional_tag"]) ``` + """ if not self._sv_obj: return False @@ -1368,7 +1393,7 @@ def update_tags(self, tags: list[str]) -> bool: return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call(config={"validate_default": True}) def log_event( @@ -1376,11 +1401,12 @@ def log_event( message: str, *, timestamp: typing.Annotated[ - datetime.datetime | str | None, pydantic.BeforeValidator(simvue_timestamp) + datetime.datetime | str | None, + pydantic.BeforeValidator(simvue_timestamp), ] = None, log_level: LogLevel | None = None, ) -> bool: - """Log event to the server + """Log event to the server. Parameters ---------- @@ -1413,6 +1439,7 @@ def log_event( log_level="debug" ) ``` + """ if self._aborted: return False @@ -1447,7 +1474,9 @@ def log_event( "log_level": log_level or "info", } self._dispatcher.add_item( - _data, object_type="events", blocking=self._queue_blocking + _data, + object_type="events", + blocking=self._queue_blocking, ) return True @@ -1478,7 +1507,8 @@ def _add_metrics_to_dispatch( if self._status != "running": self._error( - "Cannot log metrics when not in the running state", join_on_fail + "Cannot log metrics when not in the running state", + join_on_fail, ) return False @@ -1498,17 +1528,17 @@ def _add_metrics_to_dispatch( _data, object_type="metrics_regular", blocking=self._queue_blocking, - metadata=dict(object_size=len(metrics)), + metadata={"object_size": len(metrics)}, ) except ObjectDispatchError as e: - logger.warning(f"Failed to log metric {id(_data)}: {e.msg}") + logger.warning("Failed to log metric %s: %s", id(_data), e.msg) self._failed_metric_counter += 1 return True def _add_tensors_to_dispatch( self, - tensors: dict[str, numpy.ndarray], + tensors: dict[str, np.ndarray], *, step: int | None = None, time: float | None = None, @@ -1532,7 +1562,8 @@ def _add_tensors_to_dispatch( if self._status != "running": self._error( - "Cannot log tensors when not in the running state", join_on_fail + "Cannot log tensors when not in the running state", + join_on_fail, ) return False @@ -1555,15 +1586,15 @@ def _add_tensors_to_dispatch( _data, object_type="metrics_tensor", blocking=self._queue_blocking, - metadata=dict(object_size=array.size), + metadata={"object_size": array.size}, ) except ObjectDispatchError as e: - logger.warning(f"Failed to grid metric {id(_data)}: {e.msg}") + logger.warning("Failed to grid metric %s: %s", id(_data), e.msg) self._failed_metric_counter += 1 return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call(config={"arbitrary_types_allowed": True}) def assign_metric_to_grid( @@ -1571,7 +1602,7 @@ def assign_metric_to_grid( *, metric_name: str, grid_name: str | None = None, - axes_ticks: numpy.ndarray | list[list[float]] | None = None, + axes_ticks: np.ndarray | list[list[float]] | None = None, axes_labels: list[str] | None = None, ) -> bool: """Assign a metric to a new/existing tensor-based metric grid. @@ -1599,7 +1630,6 @@ def assign_metric_to_grid( Examples -------- - ```python with simvue.Run() as run: @@ -1615,8 +1645,9 @@ def assign_metric_to_grid( run.log_metrics({"G": numpy.random.random(10000).reshape((100, 100))}) ``` + """ - if isinstance(axes_ticks, numpy.ndarray): + if isinstance(axes_ticks, np.ndarray): axes_ticks = axes_ticks.tolist() grid_name = grid_name or metric_name @@ -1657,27 +1688,28 @@ def assign_metric_to_grid( server_url=self._user_config.server.url, server_token=self._user_config.server.token, ) - _grid_attach.read_only(False) + _grid_attach.read_only(is_read_only=False) _grid_attach.attach_metric_for_run(self.id, metric_name) self._grids[metric_name] = self._grids[grid_name] except (RuntimeError, ObjectNotFoundError) as e: self._error( - f"Failed to attach run '{self.id}' to grid '{grid_name}': {e.args[0]}" + f"Failed to attach run '{self.id}' to grid '{grid_name}': {e.args[0]}", ) return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call( - config={"arbitrary_types_allowed": True, "validate_default": True} + config={"arbitrary_types_allowed": True, "validate_default": True}, ) def log_metrics( self, - metrics: dict[MetricKeyString, int | float | numpy.ndarray], + metrics: dict[MetricKeyString, int | float | np.ndarray], step: int | None = None, time: float | None = None, timestamp: typing.Annotated[ - datetime.datetime | str | None, pydantic.BeforeValidator(simvue_timestamp) + datetime.datetime | str | None, + pydantic.BeforeValidator(simvue_timestamp), ] = None, ) -> bool: """Log metrics to Simvue server. @@ -1723,8 +1755,8 @@ def log_metrics( time=-10, ) ``` - """ + """ # If there are any metric units to be uploaded do so now if _units := self._meta_cache.get("metrics"): self.update_metadata({"simvue": {"metrics": _units}}) @@ -1733,27 +1765,28 @@ def log_metrics( # TODO: When metrics and grids are combined into a single entity # this can be removed. For now need to separate tensor based metrics # from regular - _tensor_metrics: dict[str, numpy.ndarray] = {} + _tensor_metrics: dict[str, np.ndarray] = {} _regular_metrics: dict[str, int | float] = {} # Classify metrics into regular and tensor based for label, metric in metrics.items(): - if isinstance(metric, numpy.ndarray): + if isinstance(metric, np.ndarray): if metric.size > MAXIMUM_GRID_METRIC_SIZE: logger.warning( - f"Cannot log grid metric {label}, " - + f"size {metric.size} exceeds limit of {MAXIMUM_GRID_METRIC_SIZE}" + "Cannot log grid metric %s, size %d exceeds limit of %d", + label, + metric.size, + MAXIMUM_GRID_METRIC_SIZE, ) continue if label not in self._grids: logger.warning( - f"Metric '{label}' is not assigned to a grid, " - + "using default axis range [0, 1] for all axes " - + "and assuming constant interval." + "Metric '%s' is not assigned to a grid, " + "using default axis range [0, 1] for all axes " + "and assuming constant interval.", + label, ) - _axes_ticks = [ - numpy.linspace(0, 1, n) for n in reversed(metric.shape) - ] + _axes_ticks = [np.linspace(0, 1, n) for n in reversed(metric.shape)] self.assign_metric_to_grid( metric_name=label, grid_name=label, @@ -1763,7 +1796,7 @@ def log_metrics( if metric.ndim != (_ndims := self._grids[label]["dimensionality"]): self._error( f"Cannot log tensor '{label}', " - + f"dimensionality incompatibility: {metric.ndim} != {_ndims}" + f"dimensionality incompatibility: {metric.ndim} != {_ndims}", ) _tensor_metrics[label] = metric else: @@ -1782,7 +1815,8 @@ def log_metrics( if self._status != "running": self._error( - "Cannot log metrics when not in the running state", join_threads=True + "Cannot log metrics when not in the running state", + join_threads=True, ) return False @@ -1797,23 +1831,27 @@ def log_metrics( timestamp=timestamp, ) _regular_dispatch = self._add_metrics_to_dispatch( - metrics=_regular_metrics, step=step, time=time, timestamp=timestamp + metrics=_regular_metrics, + step=step, + time=time, + timestamp=timestamp, ) self._step += 1 return _tensor_add_dispatch and _regular_dispatch - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def save_object( self, obj: typing.Any, + *, category: typing.Literal["input", "output", "code"], name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] | None = None, allow_pickle: bool = False, metadata: dict[str, typing.Any] | None = None, ) -> bool: - """Save an object to the Simvue server + """Save an object to the Simvue server. Parameters ---------- @@ -1827,7 +1865,8 @@ def save_object( name : str, optional name to associate with this object, by default None allow_pickle : bool, optional - whether to allow pickling if all other serialization types fail, by default False + whether to allow pickling if all other serialization + types fail, by default False metadata : str | None, optional any metadata to attach to the artifact @@ -1848,6 +1887,7 @@ def save_object( name="x" ) ``` + """ if not self._sv_obj or not self.id: self._error("Cannot save files, run not initialised") @@ -1873,12 +1913,13 @@ def save_object( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def save_file( self, file_path: pydantic.FilePath, + *, category: typing.Literal["input", "output", "code"], file_type: str | None = None, preserve_path: bool = False, @@ -1886,7 +1927,7 @@ def save_file( name: typing.Annotated[str, pydantic.Field(pattern=NAME_REGEX)] | None = None, metadata: dict[str, typing.Any] | None = None, ) -> bool: - """Upload file to the server + """Upload file to the server. Parameters ---------- @@ -1912,6 +1953,7 @@ def save_file( ------- bool whether the upload was successful + """ if not self._sv_obj or not self.id: self._error("Cannot save files, run not initialised") @@ -1926,7 +1968,7 @@ def save_file( if preserve_path and stored_file_name.startswith("./"): stored_file_name = stored_file_name[2:] elif not preserve_path: - stored_file_name = os.path.basename(file_path) + stored_file_name = file_path.name try: # Register file @@ -1948,17 +1990,18 @@ def save_file( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def save_directory( self, directory: pydantic.DirectoryPath, + *, category: typing.Literal["output", "input", "code"], file_type: str | None = None, preserve_path: bool = False, ) -> bool: - """Upload files from a whole directory + """Upload files from a whole directory. Parameters ---------- @@ -1978,6 +2021,7 @@ def save_directory( ------- bool if the directory save was successful + """ if not self._sv_obj: self._error("Cannot save directory, run not inirialised") @@ -1997,17 +2041,18 @@ def save_directory( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def save_all( self, items: list[pydantic.FilePath | pydantic.DirectoryPath], + *, category: typing.Literal["input", "output", "code"], file_type: str | None = None, preserve_path: bool = False, ) -> bool: - """Save a set of files and directories + """Save a set of files and directories. Parameters ---------- @@ -2027,13 +2072,17 @@ def save_all( ------- bool whether the save was successful + """ for item in items: if item.is_file(): save_file = self.save_file(item, category, file_type, preserve_path) elif item.is_dir(): save_file = self.save_directory( - item, category, file_type, preserve_path + item, + category, + file_type, + preserve_path, ) else: self._error(f"{item}: No such file or directory") @@ -2043,13 +2092,14 @@ def save_all( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def set_status( - self, status: typing.Literal["completed", "failed", "terminated"] + self, + status: typing.Literal["completed", "failed", "terminated"], ) -> bool: - """Set run status + """Set run status. status to assign to this run once finished @@ -2065,6 +2115,7 @@ def set_status( ------- bool if status update was successful + """ if not self._active: self._error("Run is not active") @@ -2101,7 +2152,8 @@ def _tidy_run(self) -> None: if self._sv_obj and self.mode == "offline" and self._status != "created": self._user_config.offline.cache.joinpath( - "runs", f"{self.id}.closed" + "runs", + f"{self.id}.closed", ).touch() if _non_zero := self.executor.exit_status: @@ -2113,31 +2165,33 @@ def _tidy_run(self) -> None: _error_msg = f":\n{_error_msg}" click.secho( "[simvue] Process executor terminated with non-zero exit status " - + f"{_non_zero}{_error_msg}", + f"{_non_zero}{_error_msg}", fg="red" if self._term_color else None, bold=self._term_color, ) sys.exit(_non_zero) if self._failed_metric_counter: click.secho( - "[simvue] Run completed with {self._failed_metric_counter} failed metrics.", + f"[simvue] Run completed with {self._failed_metric_counter} " + "failed metrics.", fg="yellow" if self._term_color else None, bold=self._term_color, ) sys.exit(1) - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) def close(self) -> bool: - """Close the run + """Close the run. Returns ------- bool whether close was successful + """ if self._context_manager_called: self._error("Cannot call close method in context manager.") - return + return None self._executor.wait_for_completion() @@ -2153,7 +2207,7 @@ def close(self) -> bool: return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def set_folder_details( @@ -2162,7 +2216,7 @@ def set_folder_details( tags: list[str] | None = None, description: str | None = None, ) -> bool: - """Add metadata to the specified folder + """Add metadata to the specified folder. Parameters ---------- @@ -2177,6 +2231,7 @@ def set_folder_details( ------- bool returns True if update was successful + """ if not self._folder: self._error("Cannot update folder details, run was not initialised") @@ -2200,7 +2255,7 @@ def set_folder_details( return True - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def add_alerts( @@ -2208,7 +2263,7 @@ def add_alerts( ids: list[str] | None = None, names: list[str] | None = None, ) -> bool: - """Add a set of existing alerts to this run by name or id + """Add a set of existing alerts to this run by name or id. Parameters ---------- @@ -2221,6 +2276,7 @@ def add_alerts( ------- bool returns True if successful + """ if not self._sv_obj: self._error("Cannot add alerts, run not initialised") @@ -2232,7 +2288,8 @@ def add_alerts( if names and not ids: if self.mode == "offline": self._error( - "Cannot retrieve alerts based on names in offline mode - please use IDs instead." + "Cannot retrieve alerts based on names in offline mode " + "- please use IDs instead.", ) return False try: @@ -2241,7 +2298,7 @@ def add_alerts( server_url=self._user_config.server.url, server_token=self._user_config.server.token, ): - ids += [id for id, alert in alerts if alert.name in names] + ids += [_id for _id, alert in alerts if alert.name in names] else: self._error("No existing alerts") return False @@ -2271,7 +2328,7 @@ def _check_if_alert_exists(self, alert: "AlertBase") -> str | None: return _id return None - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call def create_metric_range_alert( self, @@ -2285,7 +2342,10 @@ def create_metric_range_alert( window: pydantic.PositiveInt | None = None, frequency: pydantic.PositiveInt = 1, aggregation: typing.Literal[ - "average", "sum", "at least one", "all" + "average", + "sum", + "at least one", + "all", ] = "average", notification: typing.Literal["email", "none"] = "none", trigger_abort: bool = False, @@ -2319,7 +2379,7 @@ def create_metric_range_alert( method to use when aggregating metrics within time window * average - average across all values in window (default). * sum - take the sum of all values within window. - * at least one - returns if at least one value in window satisfy condition. + * at least one - returns if at least window value satisfies condition. * all - returns if all values in window satisfy condition. notification : Literal['email', 'none'], optional whether to notify on trigger @@ -2364,7 +2424,7 @@ def create_metric_range_alert( self.add_alerts(ids=[_alert.id]) return _alert.id - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call def create_metric_threshold_alert( self, @@ -2377,15 +2437,18 @@ def create_metric_threshold_alert( window: pydantic.PositiveInt | None = None, frequency: pydantic.PositiveInt = 1, aggregation: typing.Literal[ - "average", "sum", "at least one", "all" + "average", + "sum", + "at least one", + "all", ] = "average", notification: typing.Literal["email", "none"] = "none", trigger_abort: bool = False, attach_to_run: bool = True, ) -> str | None: - """Creates a metric threshold alert with the specified name (if it doesn't exist) - and applies it to the current run. If alert already exists it will - not be duplicated. + """Creates a metric threshold alert with the specified name + (if it doesn't exist) and applies it to the current run. + If alert already exists it will not be duplicated. Parameters ---------- @@ -2410,7 +2473,7 @@ def create_metric_threshold_alert( method to use when aggregating metrics within time window * average - average across all values in window (default). * sum - take the sum of all values within window. - * at least one - returns if at least one value in window satisfy condition. + * at least one - returns if at least window value satisfies condition. * all - returns if all values in window satisfy condition. notification : Literal['email', 'none'], optional whether to notify on trigger @@ -2454,7 +2517,7 @@ def create_metric_threshold_alert( self.add_alerts(ids=[_alert.id]) return _alert.id - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call def create_event_alert( self, @@ -2477,6 +2540,8 @@ def create_event_alert( name of alert pattern : str, optional for event based alerts pattern to look for, by default None + description : str, optional + one line description for this alert frequency : PositiveInt, optional frequency at which to check alert condition in seconds, by default None notification : Literal['email', 'none'], optional @@ -2519,7 +2584,7 @@ def create_event_alert( return _alert.id - @skip_if_failed("_aborted", "_suppress_errors", None) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=None) @pydantic.validate_call def create_user_alert( self, @@ -2555,7 +2620,6 @@ def create_user_alert( returns the created alert ID if successful """ - _alert = UserAlert.new( name=name, notification=notification, @@ -2577,7 +2641,7 @@ def create_user_alert( self.add_alerts(ids=[_alert.id]) return _alert.id - @skip_if_failed("_aborted", "_suppress_errors", False) + @skip_if_failed("_aborted", "_suppress_errors", on_failure_return=False) @check_run_initialised @pydantic.validate_call def log_alert( @@ -2603,8 +2667,9 @@ def log_alert( ------- bool whether alert state update was successful + """ - if state not in ("ok", "critical"): + if state not in {"ok", "critical"}: self._error('state must be either "ok" or "critical"') return False @@ -2614,7 +2679,8 @@ def log_alert( if name and self.mode == "offline": self._error( - "Cannot retrieve alerts based on names in offline mode - please use IDs instead." + "Cannot retrieve alerts based on names in offline mode " + "- please use IDs instead.", ) return False @@ -2622,7 +2688,8 @@ def log_alert( try: if alerts := Alert.get(offline=self.mode == "offline"): identifier = next( - (id for id, alert in alerts if alert.name == name), None + (_id for _id, alert in alerts if alert.name == name), + None, ) else: self._error("No existing alerts") @@ -2642,10 +2709,10 @@ def log_alert( if not isinstance(_alert, UserAlert): self._error( f"Cannot update state for alert '{identifier}' " - f"of type '{_alert.__class__.__name__.lower()}'" + f"of type '{_alert.__class__.__name__.lower()}'", ) return False - _alert.read_only(False) + _alert.read_only(is_read_only=False) _alert.set_status(run_id=self.id, status=state) _alert.commit() @@ -2669,6 +2736,25 @@ def set_metric_units( name of metric to assign units to units : str unit symbol + mks_unit : str | None, optional + relevant MKS unit for this unit, default is to infer + mks_conversion: float | None = None, optional + if providing custom unit, this is the conversion to the MKS unit + + Examples + -------- + ```python + with simvue.Run() as run: + run.init() + run.set_metric_units( + 'dimension_0', + units='Å', + mks_unit='metre', + mks_conversion=1e-10 + ) + run.log_metrics({'dimension_0', 2}) + ``` + """ self._meta_cache.setdefault("metrics", {}) @@ -2679,7 +2765,7 @@ def set_metric_units( "mks_conversion": mks_conversion or float(_unit_obj.in_mks().value), "mks_units": mks_unit or f"{_unit_obj.in_mks().units}", } - except UnitParseError: + except (UnitParseError, ValueError): self._meta_cache["metrics"][metric_name] = { "units": units, "mks_conversion": mks_conversion, diff --git a/simvue/sender/__init__.py b/simvue/sender/__init__.py index eb89cd52..4e600170 100644 --- a/simvue/sender/__init__.py +++ b/simvue/sender/__init__.py @@ -1,5 +1,5 @@ """Simvue sender for sending locally cached data to the server.""" -from .base import Sender, UPLOAD_ORDER, UploadItem +from .base import UPLOAD_ORDER, Sender, UploadItem -__all__ = ["Sender", "UPLOAD_ORDER", "UploadItem"] +__all__ = ["UPLOAD_ORDER", "Sender", "UploadItem"] diff --git a/simvue/sender/actions.py b/simvue/sender/actions.py index a01949b1..8c0d9853 100644 --- a/simvue/sender/actions.py +++ b/simvue/sender/actions.py @@ -1,16 +1,17 @@ """Upload actions for cached files.""" import abc -from collections.abc import Generator -from concurrent.futures import ThreadPoolExecutor import http import json import logging import pathlib import threading import typing +from collections.abc import Generator +from concurrent.futures import ThreadPoolExecutor -import requests +if typing.TYPE_CHECKING: + import requests from simvue.api.objects import ( Alert, @@ -37,16 +38,17 @@ from simvue.api.objects.alert.fetch import AlertType from simvue.api.objects.artifact.base import ArtifactBase from simvue.api.objects.base import SimvueObject -from simvue.api.request import put as sv_put, get_json_from_response -from simvue.models import ObjectID +from simvue.api.request import get_json_from_response +from simvue.api.request import put as sv_put from simvue.config.user import SimvueConfiguration from simvue.eco import CO2Monitor +from simvue.models import ObjectID from simvue.run import Run as SimvueRun try: from typing import override except ImportError: - from typing_extensions import override # noqa: UP035 + from typing_extensions import override class UploadAction: @@ -71,12 +73,16 @@ def json_file(cls, cache_directory: pathlib.Path, offline_id: str) -> pathlib.Pa ------- pathlib.Path path of local JSON file + """ return cache_directory.joinpath(f"{cls.object_type}", f"{offline_id}.json") @classmethod def _log_upload_failed( - cls, cache_directory: pathlib.Path, offline_id: str, data: dict[str, typing.Any] + cls, + cache_directory: pathlib.Path, + offline_id: str, + data: dict[str, typing.Any], ) -> None: """Log a failing upload to the local cache.""" data["upload_failed"] = True @@ -96,19 +102,23 @@ def count(cls, cache_directory: pathlib.Path) -> int: ------- int the number of objects of this type pending upload. + """ return len(list(cls.uploadable_objects(cache_directory))) @classmethod def pre_tasks( - cls, offline_id: str, data: dict[str, typing.Any], cache_directory: pathlib.Path + cls, + offline_id: str, + data: dict[str, typing.Any], + cache_directory: pathlib.Path, ) -> None: """Pre-upload actions. For this object type no pre-actions are performed. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -117,11 +127,11 @@ def pre_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ _ = offline_id _ = data _ = cache_directory - pass @classmethod def post_tasks( @@ -136,7 +146,7 @@ def post_tasks( Removes local JSON data on successful upload. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -145,6 +155,7 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ _ = data _ = online_id @@ -153,7 +164,9 @@ def post_tasks( @classmethod @abc.abstractmethod def initialise_object( - cls, online_id: ObjectID | None, **data + cls, + online_id: ObjectID | None, + **data, ) -> SimvueObject | None: """Initialise an instance of an object.""" _ = online_id @@ -175,6 +188,7 @@ def uploadable_objects(cls, cache_directory: pathlib.Path) -> Generator[str]: ------ str offline identifier + """ for file in cache_directory.glob(f"{cls.object_type}/*.json"): yield file.stem @@ -206,11 +220,14 @@ def _single_item_upload( try: cls.pre_tasks( - offline_id=identifier, data=_data, cache_directory=cache_directory + offline_id=identifier, + data=_data, + cache_directory=cache_directory, ) _object = cls.initialise_object( - online_id=id_mapping.get(identifier), **_data + online_id=id_mapping.get(identifier), + **_data, ) if not _object: @@ -223,7 +240,7 @@ def _single_item_upload( if not isinstance(_object, ArtifactBase): _object.commit() - _object.read_only(True) + _object.read_only(is_read_only=True) except Exception as err: if throw_exceptions: @@ -234,7 +251,8 @@ def _single_item_upload( if simvue_monitor_run: simvue_monitor_run.log_event(_exception_msg) simvue_monitor_run.log_alert( - name="sender_object_upload_failure", state="critical" + name="sender_object_upload_failure", + state="critical", ) cls.logger.error(_exception_msg) cls._log_upload_failed(cache_directory, identifier, _data) @@ -261,7 +279,9 @@ def _single_item_upload( id_mapping[identifier] = _object.id else: cls.logger.info( - "%s %s", "Updated" if id_mapping.get(identifier) else "Created", _label + "%s %s", + "Updated" if id_mapping.get(identifier) else "Created", + _label, ) if upload_status is not None: @@ -270,7 +290,7 @@ def _single_item_upload( upload_status[cls.object_type] += 1 if simvue_monitor_run: simvue_monitor_run.log_metrics( - {f"uploads.{cls.object_type}": upload_status[cls.object_type]} + {f"uploads.{cls.object_type}": upload_status[cls.object_type]}, ) cls.post_tasks( @@ -317,6 +337,7 @@ def upload( whether to retry failed uploads, default True. upload_status : dict[str, int | float] | None, optional a mapping which will be updated with upload status, default None. + """ _iterable = cls.uploadable_objects(cache_directory) if cls.count(cache_directory) < threading_threshold: @@ -371,7 +392,7 @@ def pre_tasks( preparation for the upload. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -380,11 +401,12 @@ def pre_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ if data["obj_type"] != "ObjectArtifact": return with cache_directory.joinpath(cls.object_type, f"{offline_id}.object").open( - "rb" + "rb", ) as in_f: data["serialized"] = in_f.read() @@ -403,7 +425,7 @@ def post_tasks( is object-based the locally serialized data is removed. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -412,6 +434,7 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ _ = online_id super().post_tasks( @@ -427,7 +450,9 @@ def post_tasks( @override @classmethod def initialise_object( - cls, online_id: ObjectID | None, **data + cls, + online_id: ObjectID | None, + **data, ) -> FileArtifact | ObjectArtifact: """Initialise/update an Artifact object. @@ -442,6 +467,7 @@ def initialise_object( ------- simvue.api.objects.FileArtifact | simvue.api.objects.ObjectArtifact a local representation of the server object. + """ if not online_id: if data.get("file_path"): @@ -449,8 +475,7 @@ def initialise_object( return ObjectArtifact.new(**data) - _sv_obj = Artifact(identifier=online_id, _read_only=False, **data) - return _sv_obj + return Artifact(identifier=online_id, _read_only=False, **data) class RunUploadAction(UploadAction): @@ -472,6 +497,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Run: ------- simvue.api.objects.Run a local representation of the server object. + """ if not online_id: return Run.new(**data) @@ -493,7 +519,7 @@ def post_tasks( of additional files defining related identifiers. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -502,6 +528,7 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ super().post_tasks( offline_id=offline_id, @@ -511,11 +538,12 @@ def post_tasks( ) _ = cache_directory.joinpath("server_ids", f"{offline_id}.txt").write_text( - online_id + online_id, ) if not cache_directory.joinpath( - cls.object_type, f"{offline_id}.closed" + cls.object_type, + f"{offline_id}.closed", ).exists(): return @@ -551,6 +579,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Folder: ------- simvue.api.objects.Folder a local representation of the server object. + """ if not online_id: return Folder.new(**data) @@ -572,7 +601,7 @@ def post_tasks( of additional files defining related identifiers. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -581,6 +610,7 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ super().post_tasks( offline_id=offline_id, @@ -590,7 +620,7 @@ def post_tasks( ) _ = cache_directory.joinpath("server_ids", f"{offline_id}.txt").write_text( - online_id + online_id, ) @@ -613,6 +643,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Tenant: ------- simvue.api.objects.administrator.Tenant a local representation of the server object. + """ if not online_id: return Tenant.new(**data) @@ -639,6 +670,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> User: ------- simvue.api.objects.administrator.User a local representation of the server object. + """ if not online_id: return User.new(**data) @@ -665,6 +697,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Tag: ------- simvue.api.objects.Tag a local representation of the server object. + """ if not online_id: return Tag.new(**data) @@ -686,7 +719,7 @@ def post_tasks( of additional files defining related identifiers. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -695,10 +728,11 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ super().post_tasks(offline_id, online_id, data, cache_directory) _ = cache_directory.joinpath("server_ids", f"{offline_id}.txt").write_text( - online_id + online_id, ) @@ -721,6 +755,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> AlertType: ------- simvue.api.objects.AlertType a local representation of the server object. + """ if not online_id: _source: str = data["source"] @@ -731,12 +766,11 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> AlertType: if _source == "events": return EventsAlert.new(**data) - elif _source == "metrics" and data.get("threshold"): + if _source == "metrics" and data.get("threshold"): return MetricsThresholdAlert.new(**data) - elif _source == "metrics": + if _source == "metrics": return MetricsRangeAlert.new(**data) - else: - return UserAlert.new(**data) + return UserAlert.new(**data) return Alert(identifier=online_id, _read_only=False, **data) @@ -755,7 +789,7 @@ def post_tasks( of additional files defining related identifiers. Parameters - ----------- + ---------- offline_id : str the offline identifier for the upload. online_id : str @@ -764,10 +798,11 @@ def post_tasks( the data sent during upload. cache_directory : pathlib.Path the local cache directory to read from. + """ super().post_tasks(offline_id, online_id, data, cache_directory) _ = cache_directory.joinpath("server_ids", f"{offline_id}.txt").write_text( - online_id + online_id, ) @@ -777,7 +812,9 @@ class StorageUploadAction(UploadAction): @classmethod @override def initialise_object( - cls, online_id: ObjectID | None, **data + cls, + online_id: ObjectID | None, + **data, ) -> S3Storage | FileStorage: """Initialise/update an Storage object. @@ -792,6 +829,7 @@ def initialise_object( ------- simvue.api.objects.S3Storage | simvue.api.objects.FileStorage a local representation of the server object. + """ if not online_id: if data.get("config", {}).get("endpoint_url"): @@ -821,6 +859,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Grid: ------- simvue.api.objects.Grid a local representation of the server object. + """ if not online_id: return Grid.new(**data) @@ -849,6 +888,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Metrics: ------- simvue.api.objects.Grid a local representation of the server object. + """ _ = online_id return Metrics.new(**data) @@ -875,6 +915,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> GridMetrics: ------- simvue.api.objects.GridMetrics a local representation of the server object. + """ _ = online_id return GridMetrics.new(**data) @@ -901,6 +942,7 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> Events: ------- simvue.api.objects.Events a local representation of the server object. + """ _ = online_id return Events.new(**data) @@ -920,13 +962,15 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> None: @override @classmethod def pre_tasks( - cls, offline_id: str, data: dict[str, typing.Any], cache_directory: pathlib.Path + cls, + offline_id: str, + data: dict[str, typing.Any], + cache_directory: pathlib.Path, ) -> None: """No pre-tasks for this action.""" _ = offline_id _ = data _ = cache_directory - pass @override @classmethod @@ -990,7 +1034,6 @@ def post_tasks( _ = data _ = cache_directory _ = online_id - pass class CO2IntensityUploadAction(UploadAction): @@ -1006,7 +1049,10 @@ def initialise_object(cls, online_id: ObjectID | None, **data) -> None: @override @classmethod def pre_tasks( - cls, offline_id: str, data: dict[str, typing.Any], cache_directory: pathlib.Path + cls, + offline_id: str, + data: dict[str, typing.Any], + cache_directory: pathlib.Path, ) -> None: """No pre-tasks for this action.""" _ = offline_id diff --git a/simvue/sender/base.py b/simvue/sender/base.py index c76dbf0c..ee7293c0 100644 --- a/simvue/sender/base.py +++ b/simvue/sender/base.py @@ -8,12 +8,13 @@ import logging import threading import typing -import pydantic + import psutil +import pydantic -from simvue.sender.actions import UPLOAD_ACTION_ORDER from simvue.config.user import SimvueConfiguration from simvue.run import Run +from simvue.sender.actions import UPLOAD_ACTION_ORDER logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ class Sender: @pydantic.validate_call def __init__( self, + *, cache_directory: pydantic.DirectoryPath | None = None, max_workers: pydantic.PositiveInt = 5, threading_threshold: pydantic.PositiveInt = 10, @@ -65,8 +67,13 @@ def __init__( default is False (exceptions will be logged) retry_failed_uploads : bool, optional Whether to retry sending objects which previously failed, by default False + run_notification : 'none' | 'all' | 'email', optional + Notification setting for the sender session, default is no notifications. + run_retention_period : str | None, optional + Specify the retention period as a string, default of None sets no limit. monitor_uploads : bool, optional Whether to track uploads as a Simvue run, by default False + """ _local_config: SimvueConfiguration = SimvueConfiguration.fetch(mode="online") self._cache_directory = cache_directory or _local_config.offline.cache @@ -92,7 +99,7 @@ def locked(self) -> bool: if not self._lock_path: raise RuntimeError("Expected lock file path, but none initialised.") return self._lock_path.exists() and psutil.pid_exists( - int(self._lock_path.read_text()) + int(self._lock_path.read_text()), ) @property @@ -113,7 +120,7 @@ def _release(self) -> None: def _initialise_monitor_run(self) -> Run: """Create a Simvue run for monitoring upload.""" _time_stamp: str = datetime.datetime.now(tz=datetime.UTC).strftime( - "%Y_%m_%d_%H_%M_%S" + "%Y_%m_%d_%H_%M_%S", ) _run = Run(mode="online") _ = _run.init( @@ -149,7 +156,9 @@ def upload(self, objects_to_upload: list[UploadItem] | None = None) -> None: Parameters ---------- objects_to_upload : list[str] - Types of objects to upload, by default uploads all types of objects present in cache + Types of objects to upload, by default uploads all types + of objects present in cache + """ self._lock() diff --git a/simvue/serialization.py b/simvue/serialization.py index 024c2efc..4b0c1cb6 100644 --- a/simvue/serialization.py +++ b/simvue/serialization.py @@ -1,33 +1,31 @@ -""" -Object Serialization +"""Object Serialization. ==================== Contains serializers for storage of objects on the Simvue server """ import contextlib -import typing -import pickle -import pandas import json -import numpy - +import pickle +import typing from io import BytesIO +import numpy as np +import pandas as pd + if typing.TYPE_CHECKING: - from pandas import DataFrame + from pd import DataFrame from plotly.graph_objects import Figure from torch import Tensor from typing_extensions import Buffer + from .types import DeserializedContent from .utilities import check_extra def _is_torch_tensor(data: typing.Any) -> bool: - """ - Check if value is a PyTorch tensor or state dict - """ + """Check if value is a PyTorch tensor or state dict.""" module_name = data.__class__.__module__ class_name = data.__class__.__name__ @@ -46,8 +44,8 @@ def _is_torch_tensor(data: typing.Any) -> bool: return False -def serialize_object(data: typing.Any, allow_pickle: bool) -> tuple[str, str] | None: - """Determine which serializer to use for the given object +def serialize_object(data: typing.Any, *, allow_pickle: bool) -> tuple[str, str] | None: + """Determine which serializer to use for the given object. Parameters ---------- @@ -60,25 +58,26 @@ def serialize_object(data: typing.Any, allow_pickle: bool) -> tuple[str, str] | ------- Callable[[typing.Any], tuple[str, str]] the serializer to use + """ module_name = data.__class__.__module__ class_name = data.__class__.__name__ if module_name == "plotly.graph_objs._figure" and class_name == "Figure": return _serialize_plotly_figure(data) - elif module_name == "matplotlib.figure" and class_name == "Figure": + if module_name == "matplotlib.figure" and class_name == "Figure": return _serialize_matplotlib_figure(data) - elif module_name == "numpy" and class_name == "ndarray": - return _serialize_numpy_array(data) - elif module_name == "pandas.core.frame" and class_name == "DataFrame": + if module_name == "numpy" and class_name == "ndarray": + return _serialize_np_array(data) + if module_name == "pandas.core.frame" and class_name == "DataFrame": return _serialize_dataframe(data) - elif _is_torch_tensor(data): + if _is_torch_tensor(data): return _serialize_torch_tensor(data) - elif module_name == "builtins" and class_name == "module" and not allow_pickle: + if module_name == "builtins" and class_name == "module" and not allow_pickle: with contextlib.suppress(ImportError): - import matplotlib.pyplot + import matplotlib.pyplot as plt - if data == matplotlib.pyplot: + if data == plt: return _serialize_matplotlib(data) elif serialized := _serialize_json(data): return serialized @@ -131,10 +130,10 @@ def _serialize_matplotlib_figure(data: typing.Any) -> tuple[str, str] | None: return data, mimetype -def _serialize_numpy_array(data: typing.Any) -> tuple[str, str] | None: +def _serialize_np_array(data: typing.Any) -> tuple[str, str] | None: mimetype = "application/vnd.simvue.numpy.v1" mfile = BytesIO() - numpy.save(mfile, data, allow_pickle=False) + np.save(mfile, data, allow_pickle=False) mfile.seek(0) data = mfile.read() return data, mimetype @@ -184,61 +183,59 @@ def _serialize_pickle(data: typing.Any) -> tuple[str, str] | None: def deserialize_data( - data: "Buffer", mimetype: str, allow_pickle: bool -) -> typing.Optional["DeserializedContent"]: - """ - Determine which deserializer to use - """ + data: "Buffer", + mimetype: str, + *, + allow_pickle: bool, +) -> "DeserializedContent | None": + """Determine which deserializer to use.""" if mimetype == "application/vnd.plotly.v1+json": return _deserialize_plotly_figure(data) - elif mimetype == "application/vnd.simvue.numpy.v1": + if mimetype == "application/vnd.simvue.numpy.v1": return _deserialize_numpy_array(data) - elif mimetype == "application/vnd.simvue.df.v1": + if mimetype == "application/vnd.simvue.df.v1": return _deserialize_dataframe(data) - elif mimetype == "application/vnd.simvue.torch.v1": + if mimetype == "application/vnd.simvue.torch.v1": return _deserialize_torch_tensor(data) - elif mimetype == "application/json": + if mimetype == "application/json": return _deserialize_json(data) - elif mimetype == "application/octet-stream" and allow_pickle: + if mimetype == "application/octet-stream" and allow_pickle: return _deserialize_pickle(data) return None @check_extra("plot") -def _deserialize_plotly_figure(data: "Buffer") -> typing.Optional["Figure"]: +def _deserialize_plotly_figure(data: "Buffer") -> "Figure | None": try: import plotly except ImportError: return None - data = plotly.io.from_json(data) - return data + return plotly.io.from_json(data) @check_extra("plot") -def _deserialize_matplotlib_figure(data: "Buffer") -> typing.Optional["Figure"]: +def _deserialize_matplotlib_figure(data: "Buffer") -> "Figure | None": try: import plotly except ImportError: return None - data = plotly.io.from_json(data) - return data + return plotly.io.from_json(data) def _deserialize_numpy_array(data: "Buffer") -> typing.Any | None: mfile = BytesIO(data) mfile.seek(0) - data = numpy.load(mfile, allow_pickle=False) - return data + return np.load(mfile, allow_pickle=False) -def _deserialize_dataframe(data: "Buffer") -> typing.Optional["DataFrame"]: +def _deserialize_dataframe(data: "Buffer") -> "DataFrame | None": mfile = BytesIO(data) mfile.seek(0) - return pandas.read_csv(mfile, index_col=0) + return pd.read_csv(mfile, index_col=0) @check_extra("torch") -def _deserialize_torch_tensor(data: "Buffer") -> typing.Optional["Tensor"]: +def _deserialize_torch_tensor(data: "Buffer") -> "Tensor | None": try: import torch except ImportError: @@ -251,10 +248,8 @@ def _deserialize_torch_tensor(data: "Buffer") -> typing.Optional["Tensor"]: def _deserialize_pickle(data) -> typing.Any | None: - data = pickle.loads(data) - return data + return pickle.loads(data) def _deserialize_json(data) -> typing.Any | None: - data = json.loads(data) - return data + return json.loads(data) diff --git a/simvue/simvue_types.py b/simvue/simvue_types.py index 95b3c46c..cb6e4592 100644 --- a/simvue/simvue_types.py +++ b/simvue/simvue_types.py @@ -3,7 +3,7 @@ try: from typing import TypeAlias except ImportError: - from typing_extensions import TypeAlias + from typing import TypeAlias if typing.TYPE_CHECKING: @@ -15,5 +15,10 @@ DeserializedContent: TypeAlias = typing.Union[ - "DataFrame", "ndarray", "Tensor", "Figure", "FigureWidget", "Buffer" + "DataFrame", + "ndarray", + "Tensor", + "Figure", + "FigureWidget", + "Buffer", ] diff --git a/simvue/system.py b/simvue/system.py index 84ce016b..d1db782f 100644 --- a/simvue/system.py +++ b/simvue/system.py @@ -1,37 +1,34 @@ -import os +import contextlib +import pathlib import platform +import shutil import socket import subprocess -import shutil import sys -import contextlib import typing def get_cpu_info(): - """ - Get CPU info - """ + """Get CPU info.""" model_name = "" arch = "" - if shutil.which("lscpu"): + if _lscpu := shutil.which("lscpu"): with contextlib.suppress(subprocess.CalledProcessError): - info = subprocess.check_output("lscpu").decode().strip() + info = subprocess.check_output(_lscpu).decode().strip() for line in info.split("\n"): if "Model name" in line: model_name = line.split(":")[1].strip() if "Architecture" in line: arch = line.split(":")[1].strip() - # TODO: Try /proc/cpuinfo if process fails - arch = arch or platform.machine() - if not model_name and shutil.which("sysctl"): + if not model_name and (_sysctl := shutil.which("sysctl")): with contextlib.suppress(subprocess.CalledProcessError): info = ( - subprocess.check_output(["sysctl", "machdep.cpu.brand_string"]) + subprocess + .check_output([_sysctl, "machdep.cpu.brand_string"]) .decode() .strip() ) @@ -42,32 +39,30 @@ def get_cpu_info(): def get_gpu_info(): - """ - Get GPU info - """ + """Get GPU info.""" _gpu_info: dict[str, str] = {"name": "", "driver_version": ""} - if shutil.which("nvidia-smi"): - with contextlib.suppress(subprocess.CalledProcessError, IndexError): - output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=name,driver_version", "--format=csv"] - ) - lines = output.split(b"\n") - tokens = lines[1].split(b", ") - _gpu_info["name"] = tokens[0].decode() - _gpu_info["driver_version"] = tokens[1].decode() + if not (_nvidia_smi := shutil.which("nvidia-smi")): + return _gpu_info + + with contextlib.suppress(subprocess.CalledProcessError, IndexError): + output = subprocess.check_output( + [_nvidia_smi, "--query-gpu=name,driver_version", "--format=csv"], + ) + lines = output.split(b"\n") + tokens = lines[1].split(b", ") + _gpu_info["name"] = tokens[0].decode() + _gpu_info["driver_version"] = tokens[1].decode() return _gpu_info def get_system() -> dict[str, typing.Any]: - """ - Get system details - """ + """Get system details.""" cpu = get_cpu_info() gpu = get_gpu_info() - system: dict[str, typing.Any] = {"cwd": os.getcwd()} + system: dict[str, typing.Any] = {"cwd": f"{pathlib.Path.cwd()}"} system["hostname"] = socket.gethostname() system["pythonversion"] = ( f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" diff --git a/simvue/utilities.py b/simvue/utilities.py index ce2bff29..ee7bbed8 100644 --- a/simvue/utilities.py +++ b/simvue/utilities.py @@ -1,19 +1,19 @@ +import contextlib +import functools import hashlib -import logging +import importlib.util import json +import logging import mimetypes -import tabulate -import pydantic -import importlib.util -import functools -import contextlib import os import pathlib import typing + import jwt +import pydantic +import tabulate from deepmerge import Merger - CHECKSUM_BLOCK_SIZE = 4096 EXTRAS: tuple[str, ...] = ("plot", "torch") @@ -24,9 +24,11 @@ def find_first_instance_of_file( - file_names: list[str] | str, check_user_space: bool = True + file_names: list[str] | str, + *, + check_user_space: bool = True, ) -> pathlib.Path | None: - """Traverses a file hierarchy from bottom upwards to find file + """Traverses a file hierarchy from bottom upwards to find file. Returns the first instance of 'file_names' found when moving upward from the current directory. @@ -43,6 +45,7 @@ def find_first_instance_of_file( ------- pathlib.Path | None first matching file if found + """ if isinstance(file_names, str): file_names = [file_names] @@ -73,7 +76,7 @@ def find_first_instance_of_file( def parse_validation_response( response: dict[str, list[dict[str, str]]], ) -> str: - """Parse ValidationError response from server + """Parse ValidationError response from server. Reformats the error information from a validation error into a human readable table. Checks if 'body' exists within response to determine @@ -88,20 +91,24 @@ def parse_validation_response( ------- str return the validation information + """ if not (issues := response.get("detail")): raise RuntimeError( - "Expected key 'detail' in server response during validation failure" + "Expected key 'detail' in server response during validation failure", ) out: list[list[str]] = [] + error_string_cutoff: int = 60 if isinstance(issues, str): - return tabulate.tabulate( - [["Unknown", "N/A", issues]], - headers=["Type", "Location", "Message"], - tablefmt="fancy_grid", - ).__str__() + return str( + tabulate.tabulate( + [["Unknown", "N/A", issues]], + headers=["Type", "Location", "Message"], + tablefmt="fancy_grid", + ), + ) for issue in issues: obj_type: str = issue["type"] @@ -125,12 +132,12 @@ def parse_validation_response( for loc in location: if loc in input_arg: input_arg = input_arg[loc] - if len(str(input_arg)) > 60 and input_arg: - input_arg = f"{str(input_arg)[:60]}..." + if len(str(input_arg)) > error_string_cutoff and input_arg: + input_arg = f"{str(input_arg)[:error_string_cutoff]}..." information.append(input_arg) # Limit message to be 60 characters - msg: str = issue["msg"][:60] + msg: str = issue["msg"][:error_string_cutoff] information.append(msg) out.append(information) @@ -153,20 +160,20 @@ def wrapper(self, *args, **kwargs) -> typing.Any: [ importlib.util.find_spec("matplotlib"), importlib.util.find_spec("plotly"), - ] + ], ): raise RuntimeError( - f"Plotting features require the '{extra_name}' extension to Simvue" + f"Plotting features require the '{extra_name}' extension to Simvue", ) - elif extra_name == "eco": + if extra_name == "eco": if not importlib.util.find_spec("geocoder"): raise RuntimeError( - f"Eco features require the '{extra_name}' extenstion to Simvue" + f"Eco features require the '{extra_name}' extenstion to Simvue", ) elif extra_name == "torch": if not importlib.util.find_spec("torch"): raise RuntimeError( - "PyTorch features require the 'torch' module to be installed" + "PyTorch features require the 'torch' module to be installed", ) elif extra_name not in EXTRAS: raise RuntimeError(f"Unrecognised extra '{extra_name}'") @@ -179,19 +186,20 @@ def wrapper(self, *args, **kwargs) -> typing.Any: def parse_pydantic_error(error: pydantic.ValidationError) -> str: out_table: list[str] = [] + error_string_cutoff: int = 50 for data in json.loads(error.json()): _input = data.get("input") if data["input"] is not None else "None" if isinstance(_input, dict): _input_str = json.dumps(_input, indent=2) _input_str = "\n".join( - f"{line[:47]}..." if len(line) > 50 else line + f"{line[:47]}..." if len(line) > error_string_cutoff else line for line in _input_str.split("\n") ) else: _input_str = ( _input_str - if len((_input_str := f"{_input}")) < 50 - else f"{_input_str[:50]}..." + if len(_input_str := f"{_input}") < error_string_cutoff + else f"{_input_str[:error_string_cutoff]}..." ) _type: str = data["type"] @@ -216,7 +224,7 @@ def parse_pydantic_error(error: pydantic.ValidationError) -> str: data["loc"], _type, data["msg"], - ] + ], ) err_table = tabulate.tabulate( out_table, @@ -252,17 +260,20 @@ def skip_if_failed( ------- typing.Callable wrapped class method + """ def decorator(class_func: typing.Callable) -> typing.Callable: @functools.wraps(class_func) def wrapper(self: "Run", *args, **kwargs) -> typing.Any: if getattr(self, failure_attr, None) and getattr( - self, ignore_exc_attr, None + self, + ignore_exc_attr, + None, ): logger.debug( - f"Skipping call to '{class_func.__name__}', " - f"client in fail state (see logs)." + "Skipping call to '%s', client in fail state (see logs).", + class_func.__name__, ) return on_failure_return @@ -284,7 +295,7 @@ def wrapper(self: "Run", *args, **kwargs) -> typing.Any: def prettify_pydantic(func: typing.Callable) -> typing.Callable: - """Converts pydantic validation errors to a table + """Converts pydantic validation errors to a table. Parameters ---------- @@ -300,6 +311,7 @@ def prettify_pydantic(func: typing.Callable) -> typing.Callable: ------ RuntimeError the formatted validation error + """ @functools.wraps(func) @@ -308,37 +320,13 @@ def wrapper(*args, **kwargs) -> typing.Any: return func(*args, **kwargs) except pydantic.ValidationError as e: error_str = parse_pydantic_error(e) - raise RuntimeError(error_str) + raise RuntimeError(error_str) from None return wrapper -def create_file(filename: str) -> None: - """ - Create an empty file - """ - try: - with open(filename, "w") as fh: - fh.write("") - except Exception as err: - logger.error("Unable to write file %s due to: %s", filename, str(err)) - - -def remove_file(filename: str) -> None: - """ - Remove file - """ - if os.path.isfile(filename): - try: - os.remove(filename) - except Exception as err: - logger.error("Unable to remove file %s due to: %s", filename, str(err)) - - -def get_expiry(token) -> int | None: - """ - Get expiry date from a JWT token - """ +def get_expiry(token: str) -> int | None: + """Get expiry date from a JWT token.""" expiry: int | None = None with contextlib.suppress(jwt.DecodeError): expiry = jwt.decode(token, options={"verify_signature": False})["exp"] @@ -346,42 +334,29 @@ def get_expiry(token) -> int | None: return expiry -def prepare_for_api(data_in, all=True): - """ - Remove references to pickling - """ - data = data_in.copy() - if "pickled" in data: - del data["pickled"] - if "pickledFile" in data and all: - del data["pickledFile"] - return data +def calculate_object_sha256(file_data: typing.Any) -> str: + """Calculate the hash for data.""" + _sha256_hash = hashlib.sha256() + if isinstance(file_data, str): + _sha256_hash.update(bytes(file_data, "utf-8")) + else: + _sha256_hash.update(bytes(file_data)) + return _sha256_hash.hexdigest() -def calculate_sha256(filename: str | typing.Any, is_file: bool) -> str | None: - """ - Calculate sha256 checksum of the specified file - """ - sha256_hash = hashlib.sha256() - if is_file: - try: - with open(filename, "rb") as fd: - for byte_block in iter(lambda: fd.read(CHECKSUM_BLOCK_SIZE), b""): - sha256_hash.update(byte_block) - return sha256_hash.hexdigest() - except Exception: - return None - - if isinstance(filename, str): - sha256_hash.update(bytes(filename, "utf-8")) - else: - sha256_hash.update(bytes(filename)) - return sha256_hash.hexdigest() +def calculate_file_sha256(file: pathlib.Path) -> str | None: + """Calculate the hash for a file.""" + _sha256_hash = hashlib.sha256() + with contextlib.suppress(Exception), file.open("rb") as fd: + for byte_block in iter(lambda: fd.read(CHECKSUM_BLOCK_SIZE), b""): + _sha256_hash.update(byte_block) + return _sha256_hash.hexdigest() + return None @functools.lru_cache def get_mimetypes() -> list[str]: - """Returns a list of allowed MIME types""" + """Returns a list of allowed MIME types.""" mimetypes.init() _valid_mimetypes = ["application/vnd.plotly.v1+json"] _valid_mimetypes += list(mimetypes.types_map.values()) @@ -389,7 +364,7 @@ def get_mimetypes() -> list[str]: def get_mimetype_for_file(file_path: pathlib.Path) -> str: - """Return MIME type for the given file""" + """Return MIME type for the given file.""" _guess, *_ = mimetypes.guess_type(file_path) return _guess or "application/octet-stream" diff --git a/simvue/version.py b/simvue/version.py index 6c2f0889..9d18289f 100644 --- a/simvue/version.py +++ b/simvue/version.py @@ -1,5 +1,4 @@ import importlib.metadata -import os.path import pathlib import toml @@ -7,10 +6,14 @@ try: __version__ = importlib.metadata.version("simvue") except importlib.metadata.PackageNotFoundError: - _metadata = os.path.join( - pathlib.Path(os.path.dirname(__file__)).parents[1], "pyproject.toml" + _metadata = ( + pathlib.Path(__file__) + .parents[2] + .joinpath( + "pyproject.toml", + ) ) - if os.path.exists(_metadata): + if _metadata.exists(): __version__ = toml.load(_metadata)["project"]["version"] else: __version__ = "" diff --git a/tests/conftest.py b/tests/conftest.py index 75662bf3..4ed8d58e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,6 @@ import numpy import pytest import pytest_mock -import typing import uuid import tempfile import os @@ -91,7 +90,6 @@ def _mock_get(*args, **kwargs) -> requests.Response: else: return _req_get(*args, **kwargs) def _mock_location_info(self) -> None: - self._logger.info("📍 Determining current user location.") self._latitude: float self._longitude: float self._latitude, self._longitude = (-1, -1) @@ -191,11 +189,12 @@ def create_pending_run(request, prevent_script_exit) -> Generator[tuple[sv_run.R @pytest.fixture -def create_plain_run_offline(request,prevent_script_exit,monkeypatch) -> Generator[tuple[sv_run.Run, dict]]: +def create_plain_run_offline(request,prevent_script_exit,monkeypatch, mocker) -> Generator[tuple[sv_run.Run, dict]]: _ = prevent_script_exit with tempfile.TemporaryDirectory() as temp_d: monkeypatch.setenv("SIMVUE_OFFLINE_DIRECTORY", temp_d) with sv_run.Run(mode="offline") as run: + run.metric_spy = mocker.spy(run, "_get_internal_metrics") _temporary_directory = pathlib.Path(temp_d) yield run, setup_test_run(run, temp_dir=_temporary_directory, create_objects=False, request=request) clear_out_files() diff --git a/tests/functional/test_run_class.py b/tests/functional/test_run_class.py index aee52d33..a6d6f04b 100644 --- a/tests/functional/test_run_class.py +++ b/tests/functional/test_run_class.py @@ -77,11 +77,11 @@ def test_check_run_initialised_decorator() -> None: def test_run_with_emissions_online(speedy_heartbeat, mock_co2_signal, create_plain_run: tuple[sv_run.Run, ...], mocker) -> None: run_created, _ = create_plain_run metric_interval = 1 - run_created._user_config.eco.co2_signal_api_token = "test_token" + run_created.user_config.eco.co2_signal_api_token = "test_token" run_created.config(enable_emission_metrics=True, system_metrics_interval=metric_interval) while ( "sustainability.emissions.total" not in requests.get( - url=f"{run_created._user_config.server.url}/metrics/names", + url=f"{run_created.user_config.server.url}/metrics/names", headers=run_created._headers, params={"runs": json.dumps([run_created.id])}).json() and run_created.metric_spy.call_count < 4 @@ -112,24 +112,28 @@ def test_run_with_emissions_online(speedy_heartbeat, mock_co2_signal, create_pla @pytest.mark.offline def test_run_with_emissions_offline(speedy_heartbeat, mock_co2_signal, create_plain_run_offline, monkeypatch) -> None: run_created, _ = create_plain_run_offline - run_created.config(enable_emission_metrics=True) + metric_interval = 1 + run_created.config(enable_emission_metrics=True, system_metrics_interval=metric_interval) time.sleep(5) # Run should continue, but fail to log metrics until sender runs and creates file _sender = Sender(cache_directory=os.environ["SIMVUE_OFFLINE_DIRECTORY"], throw_exceptions=True) _sender.upload() id_mapping = _sender.id_mapping _run = RunObject(identifier=id_mapping[run_created.id]) - _metric_names = [item[0] for item in _run.metrics] - for _metric in ["emissions", "energy_consumed"]: - _total_metric_name = f"sustainability.{_metric}.total" - _delta_metric_name = f"sustainability.{_metric}.delta" - assert _total_metric_name not in _metric_names - assert _delta_metric_name not in _metric_names + _run.read_only(False) # Sender should now have made a local file, and the run should be able to use it to create emissions metrics time.sleep(5) _sender = Sender(cache_directory=os.environ["SIMVUE_OFFLINE_DIRECTORY"], throw_exceptions=True) _sender.upload() id_mapping = _sender.id_mapping + while ( + "sustainability.emissions.total" not in requests.get( + url=f"{run_created.user_config.server.url}/metrics/names", + headers=run_created._headers, + params={"runs": json.dumps([run_created.id])}).json() + and run_created.metric_spy.call_count < 4 + ): + time.sleep(metric_interval) _run.refresh() _metric_names = [item[0] for item in _run.metrics] client = sv_cl.Client() @@ -1094,7 +1098,7 @@ def test_save_object( except ImportError: pytest.skip("Numpy is not installed") save_obj = array([1, 2, 3, 4]) - simvue_run.save_object(save_obj, "input", f"test_object_{object_type}") + simvue_run.save_object(save_obj, category="input", name=f"test_object_{object_type}") @pytest.mark.run @@ -1245,7 +1249,7 @@ def test_add_alerts_offline(monkeypatch) -> None: rule="is inside range", ) - _sender = Sender(os.environ["SIMVUE_OFFLINE_DIRECTORY"], 2, 10, throw_exceptions=True) + _sender = Sender(cache_directory=os.environ["SIMVUE_OFFLINE_DIRECTORY"], max_workers=2, threading_threshold=10, throw_exceptions=True) _sender.upload() _online_run = RunObject(identifier=_sender.id_mapping.get(run.id)) @@ -1254,7 +1258,7 @@ def test_add_alerts_offline(monkeypatch) -> None: # Create another run without adding to run _id = run.create_user_alert(name=f"user_alert_{_uuid}", attach_to_run=False) - _sender = Sender(os.environ["SIMVUE_OFFLINE_DIRECTORY"], 2, 10, throw_exceptions=True) + _sender = Sender(cache_directory=os.environ["SIMVUE_OFFLINE_DIRECTORY"], max_workers=2, threading_threshold=10, throw_exceptions=True) _sender.upload() # Check alert is not added @@ -1264,7 +1268,7 @@ def test_add_alerts_offline(monkeypatch) -> None: # Try adding alerts with IDs, check there is no duplication _expected_alerts.append(_id) run.add_alerts(ids=_expected_alerts) - _sender = Sender(os.environ["SIMVUE_OFFLINE_DIRECTORY"], 2, 10, throw_exceptions=True) + _sender = Sender(cache_directory=os.environ["SIMVUE_OFFLINE_DIRECTORY"], max_workers=2, threading_threshold=10, throw_exceptions=True) _sender.upload() _online_run.refresh() diff --git a/tests/functional/test_run_execute_process.py b/tests/functional/test_run_execute_process.py index 11aed318..b33ea944 100644 --- a/tests/functional/test_run_execute_process.py +++ b/tests/functional/test_run_execute_process.py @@ -24,7 +24,7 @@ def test_monitor_processes(create_plain_run_offline: tuple[Run, dict]): _run.add_process(f"process_1_{os.environ.get('PYTEST_XDIST_WORKER', 0)}", Command="Write-Output 'Hello World!'", executable="powershell") _run.add_process(f"process_2_{os.environ.get('PYTEST_XDIST_WORKER', 0)}", Command="Get-ChildItem", executable="powershell") _run.add_process(f"process_3_{os.environ.get('PYTEST_XDIST_WORKER', 0)}", Command="exit 0", executable="powershell") - _sender = Sender(_run._sv_obj._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_run._sv_obj._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["folders", "alerts", "runs"], ) @@ -115,7 +115,7 @@ def test_processes_cwd(create_plain_run: dict[Run, dict]) -> None: cwd=temp_dir ) time.sleep(1) - run.save_file(os.path.join(temp_dir, f"new_file_{os.environ.get('PYTEST_XDIST_WORKER', 0)}.txt"), 'output') + run.save_file(os.path.join(temp_dir, f"new_file_{os.environ.get('PYTEST_XDIST_WORKER', 0)}.txt"), category='output') client = Client() diff --git a/tests/functional/test_utilities.py b/tests/functional/test_utilities.py index 1d8e0abd..9d9f5233 100644 --- a/tests/functional/test_utilities.py +++ b/tests/functional/test_utilities.py @@ -18,11 +18,11 @@ def test_calculate_hash(is_file: bool, hash: str) -> None: if is_file: with tempfile.TemporaryDirectory() as tempd: - with open(out_file := os.path.join(tempd, "temp.txt"), "w") as out_f: + with (out_file := pathlib.Path(tempd).joinpath("temp.txt")).open("w") as out_f: out_f.write("This is a test") - assert sv_util.calculate_sha256(filename=out_file, is_file=is_file) == hash + assert sv_util.calculate_file_sha256(out_file) == hash else: - assert sv_util.calculate_sha256(filename="temp.txt", is_file=is_file) == hash + assert sv_util.calculate_object_sha256("temp.txt") == hash @pytest.mark.config @pytest.mark.parametrize( diff --git a/tests/unit/test_event_alert.py b/tests/unit/test_event_alert.py index 764f7b54..7bad46c7 100644 --- a/tests/unit/test_event_alert.py +++ b/tests/unit/test_event_alert.py @@ -85,7 +85,7 @@ def test_event_alert_modification_online() -> None: _alert.commit() time.sleep(1) _new_alert = Alert(_alert.id) - _new_alert.read_only(False) + _new_alert.read_only(is_read_only=False) assert isinstance(_new_alert, EventsAlert) _new_alert.description = "updated!" assert _new_alert.description != "updated!" diff --git a/tests/unit/test_matplotlib_figure_mime_type.py b/tests/unit/test_matplotlib_figure_mime_type.py index d984f676..42817a54 100644 --- a/tests/unit/test_matplotlib_figure_mime_type.py +++ b/tests/unit/test_matplotlib_figure_mime_type.py @@ -16,6 +16,6 @@ def test_matplotlib_figure_mime_type() -> None: plt.plot([1, 2, 3, 4]) figure = plt.gcf() - _, mime_type = serialize_object(figure, False) + _, mime_type = serialize_object(figure, allow_pickle=False) assert (mime_type == 'application/vnd.plotly.v1+json') diff --git a/tests/unit/test_metric_range_alert.py b/tests/unit/test_metric_range_alert.py index a7efbaff..892e7d33 100644 --- a/tests/unit/test_metric_range_alert.py +++ b/tests/unit/test_metric_range_alert.py @@ -61,7 +61,7 @@ def test_metric_range_alert_creation_offline(offline_cache_setup) -> None: assert _local_data.get("name") == f"metrics_range_alert_{_uuid}" assert _local_data.get("notification") == "none" assert _local_data.get("alert").get("range_low") == 10 - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -123,7 +123,7 @@ def test_metric_range_alert_modification_offline(offline_cache_setup) -> None: offline=True ) _alert.commit() - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -148,7 +148,7 @@ def test_metric_range_alert_modification_offline(offline_cache_setup) -> None: _local_data = json.load(in_f) assert _local_data.get("description") == "updated!" - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) diff --git a/tests/unit/test_metric_threshold_alert.py b/tests/unit/test_metric_threshold_alert.py index dfd1209e..9845bf1a 100644 --- a/tests/unit/test_metric_threshold_alert.py +++ b/tests/unit/test_metric_threshold_alert.py @@ -61,7 +61,7 @@ def test_metric_threshold_alert_creation_offline(offline_cache_setup) -> None: assert _local_data.get("notification") == "none" assert _local_data.get("alert").get("threshold") == 10 - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -123,7 +123,7 @@ def test_metric_threshold_alert_modification_offline(offline_cache_setup) -> Non ) _alert.commit() - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -149,7 +149,7 @@ def test_metric_threshold_alert_modification_offline(offline_cache_setup) -> Non _local_data = json.load(in_f) assert _local_data.get("description") == "updated!" - Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True).upload(["alerts"]) + Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True).upload(["alerts"]) time.sleep(1) _online_alert.refresh() diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 74a5c0aa..b0aa6e3d 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -88,7 +88,7 @@ def test_metrics_creation_offline(offline_cache_setup) -> None: assert _local_data.get("metrics")[0].get("step") == _step assert _local_data.get("metrics")[0].get("time") == _time - _sender = Sender(_metrics._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_metrics._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload( ["folders", "runs", "metrics"]) time.sleep(1) diff --git a/tests/unit/test_numpy_array_mime_type.py b/tests/unit/test_numpy_array_mime_type.py index 01c47295..234e0f9c 100644 --- a/tests/unit/test_numpy_array_mime_type.py +++ b/tests/unit/test_numpy_array_mime_type.py @@ -8,6 +8,6 @@ def test_numpy_array_mime_type() -> None: Check that the mimetype for numpy arrays is correct """ array = np.array([1, 2, 3, 4, 5]) - _, mime_type = serialize_object(array, False) + _, mime_type = serialize_object(array, allow_pickle=False) assert (mime_type == 'application/vnd.simvue.numpy.v1') diff --git a/tests/unit/test_numpy_array_serialization.py b/tests/unit/test_numpy_array_serialization.py index 52d6e6d0..e4471521 100644 --- a/tests/unit/test_numpy_array_serialization.py +++ b/tests/unit/test_numpy_array_serialization.py @@ -9,7 +9,7 @@ def test_numpy_array_serialization() -> None: """ array = np.array([1, 2, 3, 4, 5]) - serialized, mime_type = serialize_object(array, False) - array_out = deserialize_data(serialized, mime_type, False) + serialized, mime_type = serialize_object(array, allow_pickle=False) + array_out = deserialize_data(serialized, mime_type, allow_pickle=False) assert (array == array_out).all() diff --git a/tests/unit/test_object_artifact.py b/tests/unit/test_object_artifact.py index 1a60dda4..8aa474c5 100644 --- a/tests/unit/test_object_artifact.py +++ b/tests/unit/test_object_artifact.py @@ -63,7 +63,7 @@ def test_object_artifact_creation_offline(offline_cache_setup) -> None: assert _local_data.get("mime_type") == "application/vnd.simvue.numpy.v1" assert _local_data.get("runs") == {_run.id: "input"} - _sender = Sender(pathlib.Path(offline_cache_setup.name), 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=pathlib.Path(offline_cache_setup.name), max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload() time.sleep(1) diff --git a/tests/unit/test_pandas_dataframe_mimetype.py b/tests/unit/test_pandas_dataframe_mimetype.py index 1d3de890..ecd67e1a 100644 --- a/tests/unit/test_pandas_dataframe_mimetype.py +++ b/tests/unit/test_pandas_dataframe_mimetype.py @@ -16,6 +16,6 @@ def test_pandas_dataframe_mimetype() -> None: data = {'col1': [1, 2], 'col2': [3, 4]} df = pd.DataFrame(data=data) - _, mime_type = serialize_object(df, False) + _, mime_type = serialize_object(df, allow_pickle=False) assert (mime_type == 'application/vnd.simvue.df.v1') diff --git a/tests/unit/test_pandas_dataframe_serialization.py b/tests/unit/test_pandas_dataframe_serialization.py index 79c524e1..313f014e 100644 --- a/tests/unit/test_pandas_dataframe_serialization.py +++ b/tests/unit/test_pandas_dataframe_serialization.py @@ -16,7 +16,7 @@ def test_pandas_dataframe_serialization() -> None: data = {'col1': [1, 2], 'col2': [3, 4]} df = pd.DataFrame(data=data) - serialized, mime_type = serialize_object(df, False) - df_out = deserialize_data(serialized, mime_type, False) + serialized, mime_type = serialize_object(df, allow_pickle=False) + df_out = deserialize_data(serialized, mime_type, allow_pickle=False) assert (df.equals(df_out)) diff --git a/tests/unit/test_plotly_figure_mime_type.py b/tests/unit/test_plotly_figure_mime_type.py index 6884a440..3cdc63bb 100644 --- a/tests/unit/test_plotly_figure_mime_type.py +++ b/tests/unit/test_plotly_figure_mime_type.py @@ -24,6 +24,6 @@ def test_plotly_figure_mime_type() -> None: figure = plt.gcf() plotly_figure = plotly.tools.mpl_to_plotly(figure) - _, mime_type = serialize_object(plotly_figure, False) + _, mime_type = serialize_object(plotly_figure, allow_pickle=False) assert (mime_type == 'application/vnd.plotly.v1+json') diff --git a/tests/unit/test_pytorch_tensor_mime_type.py b/tests/unit/test_pytorch_tensor_mime_type.py index 35391850..ceb274be 100644 --- a/tests/unit/test_pytorch_tensor_mime_type.py +++ b/tests/unit/test_pytorch_tensor_mime_type.py @@ -15,6 +15,6 @@ def test_pytorch_tensor_mime_type() -> None: """ torch.manual_seed(1724) array = torch.rand(2, 3) - _, mime_type = serialize_object(array, False) + _, mime_type = serialize_object(array, allow_pickle=False) assert (mime_type == 'application/vnd.simvue.torch.v1') diff --git a/tests/unit/test_pytorch_tensor_serialization.py b/tests/unit/test_pytorch_tensor_serialization.py index 18a36e1b..785383fa 100644 --- a/tests/unit/test_pytorch_tensor_serialization.py +++ b/tests/unit/test_pytorch_tensor_serialization.py @@ -15,7 +15,7 @@ def test_pytorch_tensor_serialization() -> None: torch.manual_seed(1724) array = torch.rand(2, 3) - serialized, mime_type = serialize_object(array, False) - array_out = deserialize_data(serialized, mime_type, False) + serialized, mime_type = serialize_object(array, allow_pickle=False) + array_out = deserialize_data(serialized, mime_type, allow_pickle=False) assert (array == array_out).all() diff --git a/tests/unit/test_run.py b/tests/unit/test_run.py index 72875f4e..c33271fc 100644 --- a/tests/unit/test_run.py +++ b/tests/unit/test_run.py @@ -41,7 +41,7 @@ def test_run_creation_offline(offline_cache_setup) -> None: assert _local_data.get("name") == f"simvue_offline_run_{_uuid}" assert _local_data.get("folder") == _folder_name - _sender = Sender(_run._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_run._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["folders", "runs"]) time.sleep(1) @@ -118,7 +118,7 @@ def test_run_modification_offline(offline_cache_setup) -> None: assert _new_run.description == "Simvue test run" assert _new_run.name == "simvue_test_run" - _sender = Sender(_run._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_run._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["folders", "runs"]) time.sleep(1) @@ -138,7 +138,7 @@ def test_run_modification_offline(offline_cache_setup) -> None: _online_run.refresh() assert _online_run.tags == [] - _sender = Sender(_run._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_run._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["folders", "runs"]) time.sleep(1) diff --git a/tests/unit/test_s3_storage.py b/tests/unit/test_s3_storage.py index c22d5b85..ee957b31 100644 --- a/tests/unit/test_s3_storage.py +++ b/tests/unit/test_s3_storage.py @@ -72,7 +72,7 @@ def test_create_s3_offline(offline_cache_setup) -> None: assert not _local_data.get("user", None) assert not _local_data.get("usage", None) - _sender = Sender(_storage._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_storage._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["storage"]) _online_id = _sender.id_mapping[_storage.id] time.sleep(1) diff --git a/tests/unit/test_stats.py b/tests/unit/test_stats.py index cb42650f..6a385b58 100644 --- a/tests/unit/test_stats.py +++ b/tests/unit/test_stats.py @@ -6,7 +6,7 @@ @pytest.mark.online def test_stats() -> None: _statistics = Stats() - assert f"{_statistics.url}" == f"{_statistics._base_url}" + assert f"{_statistics.url}" == f"{_statistics.base_url}" assert isinstance(_statistics.runs.created, int) assert isinstance(_statistics.runs.running, int) assert isinstance(_statistics.runs.completed, int) diff --git a/tests/unit/test_tag.py b/tests/unit/test_tag.py index c91af8c1..dc8cf36e 100644 --- a/tests/unit/test_tag.py +++ b/tests/unit/test_tag.py @@ -35,7 +35,7 @@ def test_tag_creation_offline(offline_cache_setup) -> None: assert _local_data.get("name") == f"test_tag_{_uuid}" - _sender = Sender(_tag._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_tag._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["tags"]) time.sleep(1) @@ -79,7 +79,7 @@ def test_tag_modification_offline(offline_cache_setup) -> None: assert _local_data.get("name") == f"test_tag_{_uuid}" - _sender = Sender(_tag._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_tag._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["tags"]) _online_id = _sender.id_mapping.get(_tag.id) _online_tag = Tag(_online_id) @@ -103,7 +103,7 @@ def test_tag_modification_offline(offline_cache_setup) -> None: assert pydantic.color.parse_str(_local_data.get("colour")).r == 250 / 255 assert _local_data.get("description") == "modified test tag" - _sender = Sender(_tag._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_tag._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["tags"]) time.sleep(1) diff --git a/tests/unit/test_tenant.py b/tests/unit/test_tenant.py index 73117ddd..5b6bd5b0 100644 --- a/tests/unit/test_tenant.py +++ b/tests/unit/test_tenant.py @@ -40,7 +40,7 @@ def test_create_tenant_offline(offline_cache_setup) -> None: assert _local_data.get("name") == _uuid assert _local_data.get("is_enabled") == True - _sender = Sender(_new_tenant._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_new_tenant._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["tenants"]) time.sleep(1) _online_user = Tenant(_sender.id_mapping.get(_new_tenant.id)) diff --git a/tests/unit/test_user.py b/tests/unit/test_user.py index 5aac3c11..22afaa58 100644 --- a/tests/unit/test_user.py +++ b/tests/unit/test_user.py @@ -62,7 +62,7 @@ def test_create_user_offline(offline_cache_setup) -> None: assert _local_data.get("fullname") == "Joe Bloggs" assert _local_data.get("email") == "jbloggs@simvue.io" - _sender = Sender(_user._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_user._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["users"]) time.sleep(1) _online_user = User(_sender.id_mapping.get(_user.id)) diff --git a/tests/unit/test_user_alert.py b/tests/unit/test_user_alert.py index f13248c3..d31c8902 100644 --- a/tests/unit/test_user_alert.py +++ b/tests/unit/test_user_alert.py @@ -46,7 +46,7 @@ def test_user_alert_creation_offline(offline_cache_setup) -> None: assert _local_data.get("name") == f"users_alert_{_uuid}" assert _local_data.get("notification") == "none" - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -94,7 +94,7 @@ def test_user_alert_modification_offline(offline_cache_setup) -> None: ) _alert.commit() - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -118,7 +118,7 @@ def test_user_alert_modification_offline(offline_cache_setup) -> None: with _alert._local_staging_file.open() as in_f: _local_data = json.load(in_f) assert _local_data.get("description") == "updated!" - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) @@ -193,7 +193,7 @@ def test_user_alert_status_offline(offline_cache_setup) -> None: _run.alerts = [_alert.id] _run.commit() - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["folders", "runs", "alerts"]) time.sleep(1) @@ -209,7 +209,7 @@ def test_user_alert_status_offline(offline_cache_setup) -> None: _online_alert.refresh() assert not _online_alert.get_status(run_id=_sender.id_mapping.get(_run.id)) - _sender = Sender(_alert._local_staging_file.parents[1], 1, 10, throw_exceptions=True) + _sender = Sender(cache_directory=_alert._local_staging_file.parents[1], max_workers=1, threading_threshold=10, throw_exceptions=True) _sender.upload(["alerts"]) time.sleep(1) diff --git a/uv.lock b/uv.lock index edf1afe5..9a8c5877 100644 --- a/uv.lock +++ b/uv.lock @@ -1370,6 +1370,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" }, ] +[[package]] +name = "pip" +version = "26.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/91/47e7d486260f618783899587af63ccf7980fb60245c3e63dd4571c6b57ad/pip-26.1.2.tar.gz", hash = "sha256:f49cd134c61cf2fd75e0ce2676db03e4054504a5a4986d00f8299ae632dc4605", size = 1840799, upload-time = "2026-05-31T17:33:58.56Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/95/6b5cb3461ea5673ba0995989746db58eb18b91b54dbf331e72f569540946/pip-26.1.2-py3-none-any.whl", hash = "sha256:382ff9f685ee3bc25864f820aa50505825f10f5458ffff07e30a6d96e5715cab", size = 1813144, upload-time = "2026-05-31T17:33:56.772Z" }, +] + [[package]] name = "plotly" version = "6.7.0" @@ -1838,7 +1847,7 @@ wheels = [ [[package]] name = "simvue" -version = "2.5.4" +version = "2.5.5" source = { editable = "." } dependencies = [ { name = "click" }, @@ -1852,6 +1861,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas" }, + { name = "pip" }, { name = "psutil" }, { name = "pydantic" }, { name = "pydantic-extra-types" }, @@ -1900,6 +1910,7 @@ requires-dist = [ { name = "msgpack", specifier = ">=1.1.0,<2.0.0" }, { name = "numpy", specifier = ">=2.0.0,<3.0.0" }, { name = "pandas", specifier = ">=2.2.3,<3.0.0" }, + { name = "pip", specifier = ">=26.1.2" }, { name = "plotly", marker = "extra == 'plot'", specifier = ">=6.0.0,<7.0.0" }, { name = "psutil", specifier = ">=6.1.1,<8.0.0" }, { name = "pydantic", specifier = ">=2.11,<3.0.0" },