CHANGELOG.md (20 additions, 0 deletions)
@@ -6,6 +6,26 @@ The changelog format is based on [Keep a Changelog](https://keepachangelog.com/e
 ## [Unreleased]
 
 ### Added
+* `gamma` parameter on `ProximalPolicyOptimizationAgent` (default 0.99) and `--gamma` CLI flag in
+  `train_ppo.py` to configure the PPO discount factor without editing source code.
+
+### Fixed
+* `ProximalPolicyOptimizationAgent.load()` now applies a `TimeLimit` wrapper (max 3000 steps),
+  matching the training configuration. Without it, `play_ppo.py` ran indefinitely on a converged
+  model whose near-zero reward never crossed the termination threshold.
+* `AntiPendulumEnv.render()` now handles `render_mode='plot'` by calling `show_plot()` directly,
+  so the episode plot appears when running `play_ppo.py --render-mode plot`.
+* `show_plot()` legend now includes lines from twin y-axes (load speed, crane speed, damping) by
+  combining handles from both axes with `get_legend_handles_labels()`.
+* `show_plot()` title moved from `plt.title()` (attached to the last axes) to `plt.suptitle()`
+  (figure-level), preventing the title from appearing between subplots.
+* `show_plot()` switched from a 2×2 grid to a 4×1 vertical layout (16×12 in) so all subplots share
+  a common time axis and each has full width.
+* Disabled the explicit time penalty (`reward_fac[2] = 0.0`) in the PPO training and playback scripts.
+  The term `−self.time × 0.001` uses hidden state absent from the observation, violating the
+  Markov property and destabilising PPO's value function. Time preference is already encoded
+  implicitly through the discount factor γ.
+
 * `ProximalPolicyOptimizationAgent.resume()` classmethod to continue training from a saved checkpoint.
   Restores VecNormalize statistics and keeps normalization in training mode, consistent with SB3's
   `PPO.load()` + `.learn(reset_num_timesteps=False)` pattern.
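A note on the time-penalty fix above: the claim that the discount factor already encodes time preference can be checked directly. A minimal, self-contained sketch in plain Python (no project code involved):

```python
# With discount factor gamma, a reward collected k steps in the future is
# worth gamma**k of the same reward collected now, so the agent already
# prefers to finish early. No explicit -0.001 * t penalty, which would read
# hidden state outside the observation, is needed in the reward itself.
gamma = 0.99
for k in (1, 10, 100, 1000):
    print(f"reward {k} steps ahead is worth {gamma**k:.4f} of its face value")
```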
scripts/play_ppo.py (1 addition, 0 deletions)
@@ -34,6 +34,7 @@ def main() -> None:
"crane": build_crane,
"start_speed": 1.0,
"render_mode": args.render_mode,
"reward_fac": (1.0, 0.0015, 0.0),
},
)

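For orientation, a hypothetical sketch of how `AntiPendulumEnv` might combine the `reward_fac` weights. Only the third component, the time penalty zeroed here and described in the changelog entry above, is confirmed by this PR; the roles of the first two weights are assumptions for illustration.

```python
# Hypothetical reward shape: reward_fac[2] multiplying elapsed time is the
# term the PR disables; the first two weights are illustrative assumptions.
def reward(load_angle: float, load_speed: float, t: float,
           reward_fac: tuple[float, float, float] = (1.0, 0.0015, 0.0)) -> float:
    return -(reward_fac[0] * abs(load_angle)    # swing penalty (assumed)
             + reward_fac[1] * abs(load_speed)  # speed penalty (assumed)
             + reward_fac[2] * t)               # time penalty, now disabled
```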
scripts/train_ppo.py (9 additions, 0 deletions)
@@ -40,6 +40,12 @@ def main() -> None:
         default=None,
         help="Path to a saved model zip to resume training from.",
     )
+    _ = parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.99,
+        help="Discount factor for future rewards (default 0.99). Try 0.999 for longer planning horizon.",
+    )
     _ = parser.add_argument(
         "--dry-run",
         action="store_true",
@@ -67,6 +73,7 @@ def main() -> None:
"crane": build_crane,
"start_speed": 1.0,
"render_mode": args.render_mode,
"reward_fac": (1.0, 0.0015, 0.0),
},
save_path=args.save_path,
n_envs=args.n_envs,
@@ -84,8 +91,10 @@ def main() -> None:
"crane": build_crane,
"start_speed": 1.0,
"render_mode": args.render_mode,
"reward_fac": (1.0, 0.0015, 0.0),
},
save_path=args.save_path,
gamma=args.gamma,
)
agent.do_training(args.steps)
vecnorm_path = Path(args.save_path).parent / f"{Path(args.save_path).stem}_vecnorm.pkl"
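The `--gamma` help text's suggestion of 0.999 follows from the usual rule of thumb that a discounted objective looks roughly 1/(1 − γ) steps ahead:

```python
# Effective planning horizon of a discounted return, roughly 1 / (1 - gamma).
for gamma in (0.99, 0.999):
    print(f"gamma={gamma}: horizon ~ {1 / (1 - gamma):.0f} steps")
```

At γ = 0.99 the horizon is about 100 steps; 0.999 stretches it to about 1000, much closer to the 3000-step episode cap.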
src/crane_controller/envs/controlled_crane_pendulum.py (11 additions, 5 deletions)
@@ -242,7 +242,7 @@ def show_plot(self, episode: int) -> None:
         episode : int
             Episode number used in the plot title.
         """
-        _, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
+        _, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(16, 12))
         times = self.dt * np.arange(len(self.traces["c_x"]))
         damping = self.traces["l_v"][0] * np.exp(-times / self.wire.damping_time)
         ax1.plot(times, self.traces["l_x"], label="load angle", color="blue")
@@ -254,11 +254,15 @@ def show_plot(self, episode: int) -> None:
         ax2y2.plot(times, self.traces["c_v"], label="crane speed", color="red")
         ax3.plot(times[: len(self.rewards)], self.rewards, label="rewards")
         ax4.plot(times, self.traces["acc"], label="x-acceleration", color="green")
-        _ = ax1.legend()
-        _ = ax2.legend()
+        lines1, labels1 = ax1.get_legend_handles_labels()
+        lines2, labels2 = ax1y2.get_legend_handles_labels()
+        ax1.legend(lines1 + lines2, labels1 + labels2)
+        lines3, labels3 = ax2.get_legend_handles_labels()
+        lines4, labels4 = ax2y2.get_legend_handles_labels()
+        ax2.legend(lines3 + lines4, labels3 + labels4, loc="upper left")
         _ = ax3.legend()
         _ = ax4.legend()
-        _ = plt.title(f"Detailed plot of episode {episode}, reward:{self.reward}")
+        _ = plt.suptitle(f"Detailed plot of episode {episode}, reward:{self.reward}")
         plt.show()
         for key in self.traces:
             self.traces[key] = []
@@ -485,5 +489,7 @@ def step(self, action: int) -> tuple[tuple[int, ...] | np.ndarray, float, bool,

     def render(self) -> None:
         """Render the current episode."""
-        if self.render_mode == "play-back":  # show the animation
+        if self.render_mode == "play-back":
             self.show_animation()
+        elif self.render_mode == "plot":
+            self.show_plot(self.nresets)
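The legend change in `show_plot()` above is easiest to see in isolation: `Axes.legend()` only collects artists drawn on its own axes, so lines plotted on a `twinx()` axis must be merged in by hand. A standalone matplotlib sketch of the pattern, using synthetic data rather than the crane traces:

```python
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0.0, 10.0, 200)
fig, ax = plt.subplots()
ax.plot(t, np.sin(t), label="load angle", color="blue")
ax_y2 = ax.twinx()  # second y-axis sharing the same x-axis
ax_y2.plot(t, np.cos(t), label="load speed", color="orange")
# Merge handles from both axes so the twin-axis line appears in the legend.
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax_y2.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2)
fig.suptitle("Twin-axis legend demo")  # figure-level title, as in the fix
plt.show()
```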
src/crane_controller/ppo_agent.py (13 additions, 2 deletions)
@@ -48,6 +48,10 @@ class ProximalPolicyOptimizationAgent:
         Maximum steps per episode enforced via a TimeLimit wrapper (default 3000).
         Ensures episodes always end, even when a plateau agent never triggers the
         environment's own termination condition.
+    gamma : float, optional
+        Discount factor for future rewards (default 0.99). Higher values (e.g. 0.999)
+        extend the effective planning horizon, which can improve policy quality on
+        long episodes at the cost of slower value function convergence.
     """
 
     def __init__(
@@ -57,6 +61,7 @@ def __init__(
         env_kwargs: dict[str, Any] | None = None,
         save_path: str | None = None,
         max_episode_steps: int = 3000,
+        gamma: float = 0.99,
     ) -> None:
         """Set up the agent for training. Use :meth:`load` for inference."""
         self.save_path = save_path
@@ -68,7 +73,7 @@ def __init__(
             wrapper_kwargs={"max_episode_steps": max_episode_steps},
         )
         self.vec_env = VecNormalize(raw_vec_env, norm_obs=True, norm_reward=True)
-        self.model = PPO("MlpPolicy", self.vec_env, verbose=1 if n_envs == 1 else 0)
+        self.model = PPO("MlpPolicy", self.vec_env, gamma=gamma, verbose=1 if n_envs == 1 else 0)
         self.env: AntiPendulumEnv = self.vec_env.venv.envs[0]  # type: ignore[attr-defined]
 
     @classmethod
@@ -95,7 +100,13 @@ def load(
         Agent configured for inference with VecNormalize in evaluation mode.
         """
         instance = object.__new__(cls)
-        raw_vec_env = make_vec_env(env_id=env, n_envs=1, env_kwargs=env_kwargs)
+        raw_vec_env = make_vec_env(
+            env_id=env,
+            n_envs=1,
+            env_kwargs=env_kwargs,
+            wrapper_class=TimeLimit,  # type: ignore[arg-type]
+            wrapper_kwargs={"max_episode_steps": 3000},
+        )
         stats_path = cls._stats_path(str(model_path))
         if stats_path.exists():
             instance.vec_env = VecNormalize.load(str(stats_path), raw_vec_env)
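For readers unfamiliar with the pattern `load()` follows here: Stable-Baselines3 stores VecNormalize's running statistics separately from the policy, and inference should restore them frozen. A sketch using the standard SB3 API; the file paths and the CartPole stand-in environment are placeholders:

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

raw_vec_env = make_vec_env("CartPole-v1", n_envs=1)  # stand-in for AntiPendulumEnv
vec_env = VecNormalize.load("model_vecnorm.pkl", raw_vec_env)  # restore running stats
vec_env.training = False     # evaluation mode: freeze the running statistics
vec_env.norm_reward = False  # report raw rewards during playback
model = PPO.load("model.zip", env=vec_env)
```

The `resume()` classmethod loads the same statistics but leaves `training` set to `True`, matching the changelog's `PPO.load()` + `.learn(reset_num_timesteps=False)` note.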
stubs/matplotlib-stubs/pyplot.pyi (1 addition, 0 deletions)
@@ -66,6 +66,7 @@ def title(
     y: float | None = None,
     **kwargs: Any,
 ) -> Text: ...
+def suptitle(t: str, **kwargs: Any) -> Text: ...
 def plot(
     *args: Any,
     scalex: bool = True,