Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/virtual_stain_flow/vsf_logging/MlflowLogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import mlflow
from torch import nn
from torch.optim import Optimizer

from ..models.model import BaseModel
from ..trainers.trainer_protocol import TrainerProtocol
Expand All @@ -34,6 +35,7 @@ class MlflowLogger:
"""
def __init__(
self,
*,
name: str,
experiment_name: str,
tracking_uri: Optional[path_type] = None,
Expand Down Expand Up @@ -228,6 +230,33 @@ def on_train_start(self):
except Exception as e:
print(f"Fail to log model config as artifact: {e}")

optimizers = self._get_optimizers()
for idx, optimizer in enumerate(optimizers):
if not isinstance(optimizer, Optimizer):
continue
try:
opt_config = {
"class_path": f"{optimizer.__class__.__module__}.{optimizer.__class__.__name__}",
"defaults": dict(optimizer.defaults),
}
except Exception as e:
print(f"Could not get optimizer config for logging: {e}")
opt_config = None

if opt_config:
mlflow.set_tag(
f"optimizer.{idx}.class_path",
str(opt_config.get("class_path"))
)
try:
self.log_config(
tag=f"optimizer_{idx}",
config=opt_config,
stage=None
)
except Exception as e:
print(f"Fail to log optimizer config as artifact: {e}")

self._log_loss_groups_config_and_tags()


Expand Down Expand Up @@ -540,6 +569,26 @@ def _get_loss_groups(self) -> Dict[str, Any]:

return loss_groups

def _get_optimizers(self) -> List[Optimizer]:
"""
Discover optimizer(s) attached to the bound trainer.
"""

if self.trainer is None:
return []

optimizers: List[Optimizer] = []

explicit_optimizers = getattr(self.trainer, 'optimizers', None)
if isinstance(explicit_optimizers, list):
optimizers.extend(explicit_optimizers)

explicit_optimizer = getattr(self.trainer, 'optimizer', None)
if explicit_optimizer is not None:
optimizers.append(explicit_optimizer)

return optimizers

def _log_loss_groups_config_and_tags(self) -> None:
"""
Log loss item names and weights as flattened string mlflow tags and
Expand Down
Loading