diff --git a/models/q_anti-pendulum.json b/models/q_anti-pendulum.json new file mode 100644 index 0000000..e5468f5 --- /dev/null +++ b/models/q_anti-pendulum.json @@ -0,0 +1,238 @@ +{ + "date": "29.04.2026 06:53:49", + "pendulum": { + "start_speed": "1.0", + "render_mode": "none", + "reward_limit": "0.0" + }, + "q_agent": { + "use_trained": "True", + "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_trained.json", + "episodes": "6000", + "steps": "30006000", + "learning_rate": "0.1", + "discount_factor": "0.95" + }, + "q_values": { + "(0, 0, 0, 1, 1)": [ + -0.7127999098776818, + -0.7204031480080048, + -0.7412199306092444 + ], + "(0, 0, 0, 0, 0)": [ + -0.39882058710042717, + -0.43180066871885253, + -0.36789168710316733 + ], + "(0, 0, 0, 1, 0)": [ + -0.49644387699601483, + -0.476832491697103, + -0.4808429865234992 + ], + "(0, 0, 1, 1, 0)": [ + -0.3645407668594485, + -0.34941296493381946, + -0.19816466371287308 + ], + "(0, 1, 1, 1, 0)": [ + -0.27335521972045285, + -0.31790379183091966, + -0.2034872242749172 + ], + "(0, 1, 0, 1, 0)": [ + -0.6087312502217588, + -0.6754940805998394, + -0.6535399837815964 + ], + "(0, 1, 0, 2, 0)": [ + -0.3949208069494323, + -0.4631147311708943, + -0.4776412098931605 + ], + "(0, 0, 0, 2, 0)": [ + -0.2712981364627587, + -0.17654049561221843, + -0.2508326205458066 + ], + "(0, 0, 1, 2, 0)": [ + -0.48498802268403846, + -0.4395215670942963, + -0.4722944229115358 + ], + "(0, 1, 1, 2, 0)": [ + -0.4990197857734165, + -0.5327879501542365, + -0.4941112778002146 + ], + "(0, 0, 0, 3, 0)": [ + -0.5387880192090378, + -0.4784712525147248, + -0.517952873228921 + ], + "(0, 0, 1, 3, 0)": [ + -0.9096919737168466, + -0.9127452620968198, + -0.9420045562304813 + ], + "(0, 1, 1, 3, 0)": [ + -0.8667830882886871, + -0.8757935260356404, + -0.8758730446292982 + ], + "(0, 1, 0, 3, 0)": [ + -0.43694463364384284, + -0.32317351557573637, + -0.37671045435695716 + ], + "(0, 0, 1, 4, 0)": [ + -0.9306066987431686, + -0.9344436974699757, + -0.9655748330701439 + ], + "(0, 1, 1, 4, 0)": [ + -1.009825188680733, + -1.016868953064729, + -0.9957558578136729 + ], + "(0, 1, 0, 4, 0)": [ + -1.0369014461449961, + -1.0197443807886495, + -0.9976879186706535 + ], + "(0, 0, 0, 4, 0)": [ + -0.8997996021390884, + -0.8913054855090478, + -0.8905808268698551 + ], + "(0, 1, 0, 5, 0)": [ + -1.5826848759769268, + -2.7445575744690793, + -2.418367271258592 + ], + "(0, 0, 0, 5, 0)": [ + -1.6683128963172371, + -2.251311793007559, + -2.3851805260338708 + ], + "(0, 0, 1, 5, 0)": [ + -2.2373735329153392, + -1.540495957571949, + -2.76737403946213 + ], + "(0, 1, 1, 5, 0)": [ + -2.3314084980057013, + -1.5138347954693265, + -3.01917937692279 + ], + "(0, 1, 0, 1, 1)": [ + -0.261055476317184, + -0.27846357082465933, + -0.2774020188584469 + ], + "(0, 0, 1, 1, 1)": [ + -0.1342852547795026, + -0.2470169520959763, + -0.2126480225045462 + ], + "(0, 1, 1, 1, 1)": [ + -0.2797380701937499, + -0.2945979948624919, + -0.24076792235861236 + ], + "(0, 1, 1, 2, 1)": [ + -0.2728844361768713, + -0.13972184601087123, + -0.3275551939924255 + ], + "(0, 1, 0, 2, 1)": [ + -0.455039876626072, + -0.43296129057936117, + -0.4316836608934558 + ], + "(0, 0, 0, 2, 1)": [ + -0.5686315534261919, + -0.5538077807341172, + -0.5496103554748915 + ], + "(0, 0, 0, 3, 1)": [ + -0.3966116309786595, + -0.3317568230750391, + -0.30858102473455185 + ], + "(0, 0, 1, 3, 1)": [ + -0.8315655206780324, + -0.8333814435452305, + -0.7756715486070018 + ], + "(0, 1, 1, 3, 1)": [ + -0.9831842749294987, + 
-0.9782200055298229, + -0.9930127344944251 + ], + "(0, 1, 0, 3, 1)": [ + -0.6539627507980373, + -0.5852573030361744, + -0.6763166111516389 + ], + "(0, 1, 0, 4, 1)": [ + -0.9026947631459618, + -0.9407969095587947, + -0.965156293555089 + ], + "(0, 0, 0, 4, 1)": [ + -1.0295607537932239, + -1.0110078430982656, + -1.0217477290303583 + ], + "(0, 0, 1, 4, 1)": [ + -1.0198724273521527, + -1.0273376264121616, + -1.0372176976745933 + ], + "(0, 1, 1, 4, 1)": [ + -0.8367514960793316, + -0.8072150440466215, + -0.7999324124740165 + ], + "(0, 0, 1, 5, 1)": [ + -2.826052632887979, + -2.8216800501972115, + -2.8339346737355218 + ], + "(0, 1, 1, 5, 1)": [ + -2.769926115050464, + -2.7896094930838444, + -2.784544336312168 + ], + "(0, 1, 0, 5, 1)": [ + -2.7928766121585795, + -2.7887103458307694, + -2.7938313222414037 + ], + "(0, 0, 0, 5, 1)": [ + -2.8104767031828572, + -2.8062208641671322, + -2.8101316497840574 + ], + "(0, 0, 1, 2, 1)": [ + -0.5213427535062244, + -0.5910962186115551, + -0.47068284911064734 + ], + "(0, 0, 1, 0, 0)": [ + -0.41286468970384216, + -0.42684989269569895, + -169.29165129913554 + ], + "(0, 1, 0, 0, 0)": [ + -0.11933872055277986, + -0.1195010821067479, + -0.1499420269134481 + ], + "(0, 1, 1, 0, 0)": [ + -0.20624980702980253, + -0.22645535674627537, + -0.18943028486599597 + ] + } +} \ No newline at end of file diff --git a/models/q_trained.json b/models/q_trained.json new file mode 100644 index 0000000..e5468f5 --- /dev/null +++ b/models/q_trained.json @@ -0,0 +1,238 @@ +{ + "date": "29.04.2026 06:53:49", + "pendulum": { + "start_speed": "1.0", + "render_mode": "none", + "reward_limit": "0.0" + }, + "q_agent": { + "use_trained": "True", + "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_trained.json", + "episodes": "6000", + "steps": "30006000", + "learning_rate": "0.1", + "discount_factor": "0.95" + }, + "q_values": { + "(0, 0, 0, 1, 1)": [ + -0.7127999098776818, + -0.7204031480080048, + -0.7412199306092444 + ], + "(0, 0, 0, 0, 0)": [ + -0.39882058710042717, + -0.43180066871885253, + -0.36789168710316733 + ], + "(0, 0, 0, 1, 0)": [ + -0.49644387699601483, + -0.476832491697103, + -0.4808429865234992 + ], + "(0, 0, 1, 1, 0)": [ + -0.3645407668594485, + -0.34941296493381946, + -0.19816466371287308 + ], + "(0, 1, 1, 1, 0)": [ + -0.27335521972045285, + -0.31790379183091966, + -0.2034872242749172 + ], + "(0, 1, 0, 1, 0)": [ + -0.6087312502217588, + -0.6754940805998394, + -0.6535399837815964 + ], + "(0, 1, 0, 2, 0)": [ + -0.3949208069494323, + -0.4631147311708943, + -0.4776412098931605 + ], + "(0, 0, 0, 2, 0)": [ + -0.2712981364627587, + -0.17654049561221843, + -0.2508326205458066 + ], + "(0, 0, 1, 2, 0)": [ + -0.48498802268403846, + -0.4395215670942963, + -0.4722944229115358 + ], + "(0, 1, 1, 2, 0)": [ + -0.4990197857734165, + -0.5327879501542365, + -0.4941112778002146 + ], + "(0, 0, 0, 3, 0)": [ + -0.5387880192090378, + -0.4784712525147248, + -0.517952873228921 + ], + "(0, 0, 1, 3, 0)": [ + -0.9096919737168466, + -0.9127452620968198, + -0.9420045562304813 + ], + "(0, 1, 1, 3, 0)": [ + -0.8667830882886871, + -0.8757935260356404, + -0.8758730446292982 + ], + "(0, 1, 0, 3, 0)": [ + -0.43694463364384284, + -0.32317351557573637, + -0.37671045435695716 + ], + "(0, 0, 1, 4, 0)": [ + -0.9306066987431686, + -0.9344436974699757, + -0.9655748330701439 + ], + "(0, 1, 1, 4, 0)": [ + -1.009825188680733, + -1.016868953064729, + -0.9957558578136729 + ], + "(0, 1, 0, 4, 0)": [ + -1.0369014461449961, + -1.0197443807886495, + 
-0.9976879186706535
+      ],
+      "(0, 0, 0, 4, 0)": [
+         -0.8997996021390884,
+         -0.8913054855090478,
+         -0.8905808268698551
+      ],
+      "(0, 1, 0, 5, 0)": [
+         -1.5826848759769268,
+         -2.7445575744690793,
+         -2.418367271258592
+      ],
+      "(0, 0, 0, 5, 0)": [
+         -1.6683128963172371,
+         -2.251311793007559,
+         -2.3851805260338708
+      ],
+      "(0, 0, 1, 5, 0)": [
+         -2.2373735329153392,
+         -1.540495957571949,
+         -2.76737403946213
+      ],
+      "(0, 1, 1, 5, 0)": [
+         -2.3314084980057013,
+         -1.5138347954693265,
+         -3.01917937692279
+      ],
+      "(0, 1, 0, 1, 1)": [
+         -0.261055476317184,
+         -0.27846357082465933,
+         -0.2774020188584469
+      ],
+      "(0, 0, 1, 1, 1)": [
+         -0.1342852547795026,
+         -0.2470169520959763,
+         -0.2126480225045462
+      ],
+      "(0, 1, 1, 1, 1)": [
+         -0.2797380701937499,
+         -0.2945979948624919,
+         -0.24076792235861236
+      ],
+      "(0, 1, 1, 2, 1)": [
+         -0.2728844361768713,
+         -0.13972184601087123,
+         -0.3275551939924255
+      ],
+      "(0, 1, 0, 2, 1)": [
+         -0.455039876626072,
+         -0.43296129057936117,
+         -0.4316836608934558
+      ],
+      "(0, 0, 0, 2, 1)": [
+         -0.5686315534261919,
+         -0.5538077807341172,
+         -0.5496103554748915
+      ],
+      "(0, 0, 0, 3, 1)": [
+         -0.3966116309786595,
+         -0.3317568230750391,
+         -0.30858102473455185
+      ],
+      "(0, 0, 1, 3, 1)": [
+         -0.8315655206780324,
+         -0.8333814435452305,
+         -0.7756715486070018
+      ],
+      "(0, 1, 1, 3, 1)": [
+         -0.9831842749294987,
+         -0.9782200055298229,
+         -0.9930127344944251
+      ],
+      "(0, 1, 0, 3, 1)": [
+         -0.6539627507980373,
+         -0.5852573030361744,
+         -0.6763166111516389
+      ],
+      "(0, 1, 0, 4, 1)": [
+         -0.9026947631459618,
+         -0.9407969095587947,
+         -0.965156293555089
+      ],
+      "(0, 0, 0, 4, 1)": [
+         -1.0295607537932239,
+         -1.0110078430982656,
+         -1.0217477290303583
+      ],
+      "(0, 0, 1, 4, 1)": [
+         -1.0198724273521527,
+         -1.0273376264121616,
+         -1.0372176976745933
+      ],
+      "(0, 1, 1, 4, 1)": [
+         -0.8367514960793316,
+         -0.8072150440466215,
+         -0.7999324124740165
+      ],
+      "(0, 0, 1, 5, 1)": [
+         -2.826052632887979,
+         -2.8216800501972115,
+         -2.8339346737355218
+      ],
+      "(0, 1, 1, 5, 1)": [
+         -2.769926115050464,
+         -2.7896094930838444,
+         -2.784544336312168
+      ],
+      "(0, 1, 0, 5, 1)": [
+         -2.7928766121585795,
+         -2.7887103458307694,
+         -2.7938313222414037
+      ],
+      "(0, 0, 0, 5, 1)": [
+         -2.8104767031828572,
+         -2.8062208641671322,
+         -2.8101316497840574
+      ],
+      "(0, 0, 1, 2, 1)": [
+         -0.5213427535062244,
+         -0.5910962186115551,
+         -0.47068284911064734
+      ],
+      "(0, 0, 1, 0, 0)": [
+         -0.41286468970384216,
+         -0.42684989269569895,
+         -169.29165129913554
+      ],
+      "(0, 1, 0, 0, 0)": [
+         -0.11933872055277986,
+         -0.1195010821067479,
+         -0.1499420269134481
+      ],
+      "(0, 1, 1, 0, 0)": [
+         -0.20624980702980253,
+         -0.22645535674627537,
+         -0.18943028486599597
+      ]
+   }
+}
\ No newline at end of file
diff --git a/scripts/train_q.py b/scripts/train_q.py
index 26091b2..baf6dcd 100644
--- a/scripts/train_q.py
+++ b/scripts/train_q.py
@@ -19,6 +19,7 @@
 from crane_controller.envs.controlled_crane_pendulum import AntiPendulumEnv
 from crane_controller.q_agent import QLearningAgent
 
+logging.basicConfig(level=logging.INFO, format="%(message)s")
 LOGGER = logging.getLogger(__name__)
 
 
@@ -52,8 +53,6 @@ def main() -> None:
     )
     args = parser.parse_args()
 
-    logging.basicConfig(level=logging.INFO, format="%(message)s")
-
     env = AntiPendulumEnv(
         build_crane,
         start_speed=args.v0,
@@ -66,22 +65,12 @@ def main() -> None:
-        agent = QLearningAgent(env, trained=None)
+        agent = QLearningAgent(env, filename=None)
         agent.do_episodes(n_episodes=50, max_steps=1000)
-    elif args.intervals > 0:
-        Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
-        agent = QLearningAgent(env, trained=(args.save_path, False))
-        for i in range(args.intervals):
-            _ = env.reset(seed=i + 1)
-            agent.do_episodes(n_episodes=10)
-            if i == 0:
-                agent = QLearningAgent(env, trained=(args.save_path, True))
-            LOGGER.info("Model saved to %s", args.save_path)
-
     else:
         Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
-        trained = (args.trained, True) if args.trained else (args.save_path, False)
-        agent = QLearningAgent(env, trained=trained)
+        filename = Path(args.trained) if args.trained else Path(args.save_path)
+        agent = QLearningAgent(env, filename=filename, use_trained=bool(args.trained))
         agent.do_episodes(n_episodes=args.episodes, max_steps=5000)
-        LOGGER.info("Model saved to %s", args.save_path)
+        LOGGER.info(f"Model saved to {args.save_path}")
 
 
 if __name__ == "__main__":
diff --git a/scripts/use_q_ide.py b/scripts/use_q_ide.py
new file mode 100644
index 0000000..fa3c4c7
--- /dev/null
+++ b/scripts/use_q_ide.py
@@ -0,0 +1,74 @@
+"""Train a Q-learning agent on the AntiPendulumEnv. A variant of train_q.py that runs directly in the IDE.
+
+Examples
+--------
+See the commented-out calls at the end of the file.
+"""
+
+import logging
+from pathlib import Path
+from typing import Any
+
+from crane_controller.crane_factory import build_crane
+from crane_controller.envs.controlled_crane_pendulum import AntiPendulumEnv
+from crane_controller.q_agent import QLearningAgent
+
+logging.basicConfig(level=logging.INFO, format="%(message)s")
+LOGGER = logging.getLogger(__name__)
+
+
+def do_use(kwargs: dict[str, Any]) -> None:
+    """Perform training on the (Anti-)Pendulum environment using Q-learning.
+
+    Args:
+        dry_train (bool): if present, overwrite some parameters and perform only a short training run with plotting
+        dry_do (bool): if present, run a few episodes on pre-trained data with plotting (file can be set by the caller)
+        v0 (float)=1.0: start speed of the load in x-direction. 0: pendulum mode; >0: same start at every episode; <0: random start at every episode
+        render (str)='none': render mode of the environment
+        reward (float)=0.0: reward limit at which an episode is terminated
+        file (str|Path): optional path of the model save file
+        use_trained (bool)=False: load pre-trained data from `file`?
+        episodes (int)=100: number of episodes run in the training
+        steps (int)=5000: number of steps per episode (if not terminated or truncated)
+
+    """
+    if "dry_train" in kwargs:  # Check the training setup (overwrite some parameters)
+        kwargs.update({"render": "plot", "file": None, "use_trained": False, "episodes": 10, "steps": 1000})
+    elif "dry_do" in kwargs:  # Run a few episodes on trained data (file can be set by the caller)
+        kwargs.update({"render": "plot", "use_trained": True, "episodes": 10, "steps": 1000})
+    env = AntiPendulumEnv(
+        build_crane,
+        seed=1,
+        dt=0.1,
+        start_speed=kwargs.get("v0", 1.0),
+        render_mode=kwargs.get("render", "none"),
+        reward_limit=kwargs.get("reward", 0.0),
+        discrete=QLearningAgent.DEFAULT_DISCRETE.copy(),
+    )
+
+    filename = kwargs.get("file")
+    if filename is not None:
+        Path(filename).parent.mkdir(parents=True, exist_ok=True)
+    use_trained = kwargs.get("use_trained", False)
+    agent = QLearningAgent(env, filename=filename, use_trained=use_trained)
+    agent.do_episodes(n_episodes=kwargs.get("episodes", 100), max_steps=kwargs.get("steps", 5000))
+    if filename is not None:
+        LOGGER.info(f"Model saved to {filename}")
+
+
+if __name__ == "__main__":
+    models = Path(__file__).parent.resolve().parent / "models"
+    args = {
+        "v0": 1.0,
+        "render": "none",
+        "reward": 0.0,
+        "file": models / "q_anti-pendulum.json",
+        "use_trained": True,
+        "episodes": 1000,
+        "steps": 5000,
+    }
+    # args.update({'episodes': 6000, 'use_trained': True})  # noqa: ERA001 ## do a major training run, adding to existing data
+    args.update({"episodes": 10, "render": "plot"})
+    # args.update({'dry_train': True})  # noqa: ERA001 ## check the setup before a long training
+    # args.update({'dry_do': True})  # noqa: ERA001
+    do_use(args)
diff --git a/src/crane_controller/envs/controlled_crane_pendulum.py b/src/crane_controller/envs/controlled_crane_pendulum.py
index 760aa65..78c6c69 100644
--- a/src/crane_controller/envs/controlled_crane_pendulum.py
+++ b/src/crane_controller/envs/controlled_crane_pendulum.py
@@ -114,6 +114,7 @@ def __init__(  # noqa: PLR0913 - environment API needs explicit parameters
         reward_limit: float = 50.0,
         dt: float = 1.0,
         discrete: dict[str, tuple[float | int, ...]] | None = None,
+        reward_fac: tuple[float, float, float] = (1.0, 0.0015, 0.001),
     ) -> None:
         """Initialize the anti-pendulum environment.
 
@@ -126,13 +127,16 @@ def __init__(  # noqa: PLR0913 - environment API needs explicit parameters
         self.wire: Wire = wire  # type: ignore[assignment] # boom_by_name returns Boom; at runtime this is Wire
         assert render_mode in self.metadata["render_modes"], f"render_mode: {render_mode}"  # type: ignore[operator] # metadata values are typed as object
         self.render_mode = render_mode
+        self.reward_fac = reward_fac  # weights for the (energy, positional, elapsed-time) reward terms in _get_obs
         self.reward_stats: list[list[float]] = []
         self._playback: list[list[float]] = []
         self.rewards: list[float] = []
         if render_mode == "reward-tracking":
             self._reward_point = self._reward_plot_init()
         elif render_mode == "plot":
-            self.traces: dict[str, list[float]] = {"c_x": [], "c_v": [], "l_x": [], "l_v": []}
+            self.traces: dict[str, list[float]] = {"c_x": [], "c_v": [], "l_x": [], "l_v": [], "acc": []}
         self.obeservation_space: spaces.Box | spaces.Discrete  # pyright: ignore[reportMissingTypeArgument] # Discrete type arg not needed here
 
         # Continuous observations are crane position, crane velocity, wire polar angle, and load x-velocity.
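+        # In discrete mode these are binned into the 5-int state tuples that key the
+        # dumped q_values JSON (e.g. "(0, 1, 0, 2, 1)"); see _get_discrete_obs.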
@@ -160,6 +162,7 @@
         self.action_space = spaces.Discrete(3, start=0, seed=42, dtype=np.int64)
         self.action_to_acc = {0: -self.acc, 1: 0.0, 2: self.acc}
         self.steps: int = 0
+        self.time: float = 0.0
         _ = super().reset(seed=seed)
 
     def _init_discrete(
@@ -239,8 +242,8 @@ def show_plot(self, episode: int) -> None:
         episode : int
             Episode number used in the plot title.
         """
-        _, ((ax1, ax2), (ax3, _)) = plt.subplots(2, 2)
-        times = np.arange(len(self.traces["c_x"]))
+        _, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
+        times = self.dt * np.arange(len(self.traces["c_x"]))
         damping = self.traces["l_v"][0] * np.exp(-times / self.wire.damping_time)
         ax1.plot(times, self.traces["l_x"], label="load angle", color="blue")
         ax1y2 = ax1.twinx()
@@ -249,9 +252,12 @@ def show_plot(self, episode: int) -> None:
         ax2.plot(times, self.traces["c_x"], label="crane pos", color="blue")
         ax2y2 = ax2.twinx()
         ax2y2.plot(times, self.traces["c_v"], label="crane speed", color="red")
-        ax3.plot(list(range(len(self.rewards))), self.rewards, label="rewards")
+        ax3.plot(times[: len(self.rewards)], self.rewards, label="rewards")
+        ax4.plot(times, self.traces["acc"], label="x-acceleration", color="green")
         _ = ax1.legend()
         _ = ax2.legend()
+        _ = ax3.legend()
+        _ = ax4.legend()
         _ = plt.title(f"Detailed plot of episode {episode}, reward:{self.reward}")
         plt.show()
         for key in self.traces:
@@ -309,7 +315,7 @@ def _get_discrete_obs(self, energy: float) -> tuple[int, ...]:
             int(self.crane.position[0] < 0.0),
         )
 
-    def _get_obs(self) -> tuple[np.ndarray | tuple[int, ...], float, int]:
+    def _get_obs(self, acc: float = 0.0) -> tuple[np.ndarray | tuple[int, ...], float, int]:
         """Compute the current observation and reward from the crane state.
 
         In discrete mode the observation keys are::
@@ -325,16 +331,13 @@ def _get_obs(self, acc: float = 0.0) -> tuple[np.ndarray | tuple[int, ...], float, int]:
             ``(observation, reward, error_flag)``.
         """
         energy = 9.81 * self.wire.end[2] + 0.5 * np.dot(self.wire.cm_v, self.wire.cm_v)  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
-        if self.start_speed == 0.0:  # start pendulum mode
-            reward = energy
-        else:  # stop pendulum mode
-            reward = -energy
-            if np.sign(self.crane.position[0]) == np.sign(self.crane.velocity[0]):  # moving away from origo
-                reward -= (
-                    0.0015 * self.wire.length * (abs(self.crane.position[0]) + self.crane.velocity[0] ** 2 / self.acc)
-                )
-            # if the crane moves towards the origo we do not add 'energy'
-        self.reward = reward
+        positional = 0.0  # no positional penalty when the crane moves towards the origo
+        if self.start_speed != 0.0:  # anti-pendulum mode
+            energy = -energy
+            if np.sign(self.crane.position[0]) == np.sign(self.crane.velocity[0]):  # moving away from origo
+                positional = -self.wire.length * (abs(self.crane.position[0]) + self.crane.velocity[0] ** 2 / self.acc)
+        # weighted sum of the energy, positional, and elapsed-time terms (weights from reward_fac)
+        self.reward = sum(f * r for f, r in zip(self.reward_fac, (energy, positional, -self.time), strict=True))
 
         obs: tuple[int, ...] | np.ndarray
         if len(self.discrete):
@@ -349,8 +352,9 @@ def _get_obs(self, acc: float = 0.0) -> tuple[np.ndarray | tuple[int, ...], float, int]:
         self.traces["c_v"].append(self.crane.velocity[0])
         self.traces["l_x"].append(self.wire.c_m[0])
         self.traces["l_v"].append(self.wire.cm_v[0])  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
+        self.traces["acc"].append(acc)
 
-        return (obs, reward, err)
+        return (obs, self.reward, err)
 
     def low_reward(self) -> float:
         """Return the lowest possible reward for the current mode.
@@ -392,16 +396,12 @@ def reset(
     ) -> tuple[tuple[int, ...] | np.ndarray, dict[str, float | int]]:
         """Reset the environment for a new episode.
 
-        Parameters
-        ----------
-        seed : int or None, optional
-            Random seed (default None).
-        options : dict[str, object] or None, optional
-            Additional reset options (default None).
+        Args:
+            seed (int): Optional random seed (default None).
+            options (dict[str, object]): Optional additional arguments passed to super().reset() (default None).
 
-        Returns
-        -------
-            tuple[tuple[int, ...] | np.ndarray, dict[str, float | int]]
+        Returns:
+            tuple[tuple[int, ...] | np.ndarray, dict[str, float | int]]:
                 Initial observation and info dict.
         """
         self.reset_crane()
@@ -432,13 +433,14 @@ def reset(
                 (-self.start_speed - self.min_speed),
             )
             speed = speed + self.min_speed if speed >= 0 else speed - self.min_speed
-            self.wire.cm_v[0] = np.radians(speed)  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
+            self.wire.cm_v[0] = speed  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
         else:  # fixed speed in 'stop' mode (more control)
-            self.wire.cm_v[0] = np.radians(self.start_speed)  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
+            self.wire.cm_v[0] = self.start_speed  # pyright: ignore[reportUnknownMemberType] # dynamic attr on Wire
         obs, self.reward, _ = self._get_obs()
         if self.render_mode == "play-back":
             self._append_playback(0.0)
         self.steps = 0
+        self.time = 0.0
         info = self._get_info(self.reward, self.steps)
         return obs, info
 
@@ -458,11 +460,13 @@ def step(self, action: int) -> tuple[tuple[int, ...] | np.ndarray, float, bool, bool, dict[str, float | int]]:
         action_idx = action
         if action_idx not in self.action_to_acc:
             action_idx += 1
-        self.crane.d_velocity[0] = self.action_to_acc[action_idx]
+        acc = self.action_to_acc[action_idx]
+        self.crane.d_velocity[0] = acc
         self.steps += 1
-        _ = self.crane.do_step(self.steps, self.dt)
+        _ = self.crane.do_step(self.time, self.dt)
+        self.time += self.dt
 
-        obs, self.reward, truncated = self._get_obs()
+        obs, self.reward, truncated = self._get_obs(acc)
 
         if self.render_mode != "none":
             self.rewards.append(self.reward)
diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py
index 27ac78e..fbf5529 100644
--- a/src/crane_controller/q_agent.py
+++ b/src/crane_controller/q_agent.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import datetime as dt
 import json
 import logging
 from ast import literal_eval
@@ -63,10 +64,10 @@ class QLearningAgent:
         Minimum exploration rate (default 0.1).
     discount_factor : float, optional
         How much to value future rewards, in the range [0, 1] (default 0.95).
-    trained : tuple[str | Path, bool] or None, optional
-        Optional path and flag for pre-trained Q-values.
-        ``(filename, False)`` trains from scratch and saves;
-        ``(filename, True)`` loads pre-trained values (default None).
+    filename : Path or None, optional
+        Path used for loading pre-trained Q-values and for saving results (default None).
+    use_trained : bool, optional
+        Whether to load pre-trained Q-values from ``filename`` if it exists (default False).
     """
 
     DEFAULT_DISCRETE: ClassVar[dict[str, tuple[float | int, ...]]] = {
@@ -84,17 +83,25 @@ def __init__(
         initial_epsilon: float = 1.0,
         final_epsilon: float = 0.1,
         discount_factor: float = 0.95,
-        trained: tuple[str | Path, bool] | None = None,
+        filename: Path | None = None,
+        *,
+        use_trained: bool = False,
     ) -> None:
         """Initialize the Q-learning agent.
 
         See the class docstring for parameter descriptions.
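+
+        A minimal usage sketch of the keyword API (given an existing ``env``;
+        the path here is hypothetical)::
+
+            agent = QLearningAgent(env, filename=Path("models/q.json"), use_trained=True)
+            agent.do_episodes(n_episodes=100, max_steps=5000)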
""" self.env = env - _filename, self.use_pre_trained = trained if trained is not None else (None, False) - self.filename: Path | None = Path(_filename) if _filename is not None else None + self.filename = Path(filename) if filename is not None else None + self.use_trained = use_trained self.q_values: defaultdict[tuple[int, ...], np.ndarray] - if self.use_pre_trained and self.filename is not None: + if self.use_trained and self.filename is not None and self.filename.exists(): self.q_values = self.read_dumped(self.filename) self.epsilon = final_epsilon # assume that we are fully learned else: # start from scratch, but save the q_values afterwards @@ -206,10 +207,11 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = Visualization mode - 0 for none, 1 for training summary, 2 for per-episode analysis (default 0). """ - if self.use_pre_trained: + if self.use_trained: logger.info("Starting %s episodes, using pre-trained values from %s", n_episodes, self.filename) else: logger.info("Starting new training with %s episodes.", n_episodes) + total_steps = 0 for _episode in tqdm(range(n_episodes)): # Start a new episode obs, _ = self.env.reset() @@ -229,42 +231,53 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = self.analyse_episode() nsteps += 1 truncated |= nsteps > max_steps + total_steps += nsteps # Reduce exploration rate (agent becomes less random over time): self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon / (n_episodes / 2)) if show == SHOW_TRAINING_SUMMARY: self.analyse_training() if self.filename: - self.dump_results() + self.dump_results(episodes=n_episodes, steps=total_steps) - def dump_results(self, filename: str | Path = "") -> None: + def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int = -1) -> None: """Dump the Q-values to a JSON file. - Parameters - ---------- - filename : str or Path, optional - Target file path. When empty, the filename provided at - construction time is used (default ""). + Args: + filename (str|Path): Optional target file path. + When empty, the filename provided at construction time is used (default ""). + episodes (int): the number of episodes which have been run + steps (int): the limiting number of steps per episode """ if not filename: # automatic file name if self.filename is None: logger.warning("No base file name provided. 
Aborting dump to file.") return - if self.use_pre_trained: # do not overwrite pre-trained data - if len(self.filename.stem.split("_")) == 1: - _filename = self.filename.parent / f"{self.filename.stem}_1{self.filename.suffix}" - else: - stem, version = self.filename.stem.split("_") - _filename = self.filename.parent / f"{stem}_{int(version) + 1}{self.filename.suffix}" - else: - _filename = self.filename + _filename = self.filename else: _filename = Path(filename) converted: dict[str, list[float]] = {} for k, v in self.q_values.items(): converted |= {str(k): list(v)} + content = { + "date": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), + "pendulum": { + "start_speed": str(self.env.start_speed), + "render_mode": str(self.env.render_mode), + "reward_limit": str(self.env.reward_limit), + }, + "q_agent": { + "use_trained": str(self.use_trained), + "filename": str(self.filename), + "episodes": str(episodes), + "steps": str(steps), + "learning_rate": str(self.lr), + "discount_factor": str(self.discount_factor), + }, + "q_values": converted, + } with _filename.open("w", encoding="utf-8") as _f: - json.dump(converted, _f, indent=3) + json.dump(content, _f, indent=3) logger.info("Updated q_values saved to %s", _filename.resolve()) def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.ndarray]: @@ -286,7 +299,8 @@ def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.n q_values: defaultdict[tuple[int, ...], np.ndarray] = defaultdict( lambda: np.array((0.0,) * self.env.action_space.n, float) # type: ignore[attr-defined,type-var] ) - for k, v in from_dump.items(): + assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" + for k, v in from_dump["q_values"].items(): q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v}) return q_values diff --git a/tests/conftest.py b/tests/conftest.py index 8ef92cb..95b77e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,6 +38,15 @@ def test_dir() -> Path: return Path(__file__).parent.absolute() +@pytest.fixture(scope="package", autouse=True) +def model_dir() -> Path: + """ + Fixture that returns the absolute path of the directory containing the trained model files. + This fixture is automatically used for the entire package. 
+ """ + return Path(__file__).parent.absolute().parent / "models" + + output_dirs: list[str] = [ "results", "data", diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index be54126..d83862e 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -61,7 +61,7 @@ def test_algorithm(crane: Callable[..., Crane], *, show: bool) -> None: from crane_controller.crane_factory import build_crane # noqa: F401 - retcode = pytest.main(["-rP -s -v", "--show", "False", __file__]) + retcode = pytest.main(["-rP -s -v", __file__]) assert retcode == 0, f"Return code {retcode}" os.chdir(Path(__file__).parent.absolute() / "test_working_directory") diff --git a/tests/test_environment.py b/tests/test_environment.py index 784915d..43a5d0e 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np -import pytest # noqa: F401 +import pytest from py_crane.crane import Crane from crane_controller.envs.controlled_crane_pendulum import AntiPendulumEnv @@ -79,15 +79,15 @@ def test_environment( def test_init(crane: Callable[..., Crane], *, show: bool) -> None: """Test the initialization of the environment.""" - env = AntiPendulumEnv(crane, seed=1, start_speed=1.0, render_mode="play-back" if show else "data") + env = AntiPendulumEnv(crane, seed=1, start_speed=-1.0, render_mode="play-back" if show else "data") rnd_u = env.np_random.uniform(2, 8) rnd_r = env.np_random.random() assert rnd_u == 5.07092974820154, f"Returns pseudo-random numbers when seed is given. Got {rnd_u} for seed 1" assert rnd_r == 0.9504636963259353, f"Returns pseudo-random numbers when seed is given. Got {rnd_r} for seed 1" obs, inf = env.reset(seed=1) - assert np.allclose(obs, [0.0, 0.0, np.pi, 0.017453292519943295]), f"Found {obs[3]}" + assert np.allclose(obs, [0.0, 0.0, np.pi, 0.1212789244604621]), f"Found {obs[3]}" assert inf["steps"] == 0 - assert abs(inf["reward"] + 0.5 * 0.017453292519943295**2) < 1e-9, f"Found initial reward {inf['reward']}" + assert abs(inf["reward"] + 0.5 * 0.1212789244604621**2) < 1e-9, f"Found initial reward {inf['reward']}" obs, reward, terminated, truncated, _ = env.step(-1) assert obs[0] == -0.1 assert obs[1] == -0.1 @@ -116,3 +116,17 @@ def test_observations_are_float(crane: Callable[..., Crane]) -> None: assert isinstance(obs, np.ndarray) assert obs.dtype == np.float64 assert not np.all(obs == obs.astype(int)) # sub-integer precision is preserved + + +if __name__ == "__main__": + import os + from pathlib import Path + + import pytest + + retcode = pytest.main(["-rP -s -v", __file__]) + assert retcode == 0, f"Return code {retcode}" + os.chdir(Path(__file__).parent.absolute() / "test_working_directory") + # test_init(build_crane, show=True) + # test_observation_space_dtype(build_crane) + # test_observations_are_float(build_crane) diff --git a/tests/test_q.py b/tests/test_q.py index 705ce1a..832621d 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -1,5 +1,7 @@ import logging +import shutil from collections.abc import Callable +from pathlib import Path import numpy as np from py_crane.crane import Crane @@ -18,16 +20,19 @@ def test_smoke(crane: Callable[..., Crane], *, show: bool) -> None: reward_limit=-0.05, discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, trained=None) + agent = QLearningAgent(env, filename=None) agent.do_episodes(n_episodes=5, max_steps=200) -def test_q_analyse(crane: Callable[..., Crane], *, trained: tuple[str, bool]) -> None: +def test_q_analyse(crane: 
Callable[..., Crane], *, show: bool) -> None: + models = Path(__file__).parent.resolve().parent / "models" + assert (models / "q_trained.json").exists(), "Expect a file 'q_trained.json' in the models directory. Not found" + _ = shutil.copy2(models / "q_trained.json", ".") # copy to working_directory env = AntiPendulumEnv( crane, discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, trained=trained) + agent = QLearningAgent(env, filename=Path("q_trained.json"), use_trained=True) for k, v in agent.q_values.items(): assert len(k) == 5, len(v) == 3 for pos in (0, 1): @@ -39,3 +44,38 @@ def test_q_analyse(crane: Callable[..., Crane], *, trained: tuple[str, bool]) -> col = [x[i] for x in res.values()] acc.append(np.average(col)) logger.info(f"averages: {acc}") + + +def test_intervals(crane: Callable[..., Crane]): + """Test that learning / saving / resuming learning works:""" + save_path = Path.cwd() / "q_interval_training.json" + env = AntiPendulumEnv( + crane, + start_speed=-1.0, + render_mode="none", + reward_limit=-0.05, + discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), + ) + + agent = QLearningAgent(env, filename=save_path, use_trained=False) + for i in range(10): + _ = env.reset(seed=i + 1) + agent.do_episodes(n_episodes=2, max_steps=100) + if i == 0: + agent = QLearningAgent(env, filename=save_path, use_trained=True) + logger.info(f"Model saved to {save_path}") + + +if __name__ == "__main__": + import os + from pathlib import Path + + import pytest + + retcode = pytest.main(["-rP -s -v", __file__]) + assert retcode == 0, f"Return code {retcode}" + os.chdir(Path(__file__).parent.absolute() / "test_working_directory") + + # test_smoke(build_crane, show=True) + # test_q_analyse(build_crane, show=True) + # test_intervals(build_crane)
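+    # NOTE: running the commented-out calls above directly requires importing the crane factory
+    # first, e.g. `from crane_controller.crane_factory import build_crane` (as test_algorithm.py does).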