From 3162ee52725312b917bcb17a121f4149b2d191b1 Mon Sep 17 00:00:00 2001 From: Eisinger Date: Thu, 30 Apr 2026 10:43:41 +0200 Subject: [PATCH 1/4] Optimisations in q_agent. Added reporting of parameters to pendulum environment. --- models/q_anti-pendulum.json | 303 +++++++++--------- scripts/use_q_ide.py | 44 ++- .../envs/controlled_crane_pendulum.py | 15 +- src/crane_controller/q_agent.py | 58 ++-- 4 files changed, 226 insertions(+), 194 deletions(-) diff --git a/models/q_anti-pendulum.json b/models/q_anti-pendulum.json index e5468f5..80143f0 100644 --- a/models/q_anti-pendulum.json +++ b/models/q_anti-pendulum.json @@ -1,238 +1,231 @@ { - "date": "29.04.2026 06:53:49", + "start-training": "30.04.2026 08:29:32", + "end-training": "30.04.2026 08:29:54", "pendulum": { - "start_speed": "1.0", - "render_mode": "none", + "wire-length": "10.0", + "wire-q-factor": "50.0", + "reward-factors": "(1.0, 0.0015, 0.0)", + "acceleration": "0.1", + "step-size": "0.1", + "observations-discretization": "{'pos': (0, 1), 'speed': (0, 1), 'distance': (0.0, 1.0, 2.0, 5.0, 10.0, 20.0), 'sector': (0, 1), 'energies': (np.float64(0.0), np.float64(0.014941105158016455), np.float64(0.373300117199762), np.float64(1.4903594295023934), np.float64(5.916153900902383), np.float64(13.142907888746564), np.float64(98.1))}", "reward_limit": "0.0" }, "q_agent": { - "use_trained": "True", - "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_trained.json", - "episodes": "6000", - "steps": "30006000", + "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_anti-pendulum.json", + "use_file": "rw", + "episodes": "10", + "steps": "40020", "learning_rate": "0.1", - "discount_factor": "0.95" + "discount_factor": "0.95", + "epsilon-decay": "0.001", + "final-epsilon": "0.1", + "epsilon": "0.98" }, "q_values": { "(0, 0, 0, 1, 1)": [ - -0.7127999098776818, - -0.7204031480080048, - -0.7412199306092444 + -0.16705581244794693, + -0.1282253386578821, + -0.09732839899973421 ], "(0, 0, 0, 0, 0)": [ - -0.39882058710042717, - -0.43180066871885253, - -0.36789168710316733 + -0.37890841049813784, + -0.3221144173089857, + -0.3983817365055842 ], "(0, 0, 0, 1, 0)": [ - -0.49644387699601483, - -0.476832491697103, - -0.4808429865234992 + -0.2664435349978729, + -0.17001690959731303, + -0.2325842595231777 ], "(0, 0, 1, 1, 0)": [ - -0.3645407668594485, - -0.34941296493381946, - -0.19816466371287308 + -0.15064415891825617, + -0.2511495308113964, + -0.25691496256409363 ], "(0, 1, 1, 1, 0)": [ - -0.27335521972045285, - -0.31790379183091966, - -0.2034872242749172 + -0.20723301719553744, + -0.2388453016403667, + -0.16750979906585578 ], "(0, 1, 0, 1, 0)": [ - -0.6087312502217588, - -0.6754940805998394, - -0.6535399837815964 + -0.21424728224181977, + -0.22083742688011837, + -0.2194005545131396 ], "(0, 1, 0, 2, 0)": [ - -0.3949208069494323, - -0.4631147311708943, - -0.4776412098931605 + -0.048696104408199575, + -0.05034525692645991, + -0.05258803579860084 ], "(0, 0, 0, 2, 0)": [ - -0.2712981364627587, - -0.17654049561221843, - -0.2508326205458066 + -0.052053372911474195, + -0.05245228578367619, + -0.056669454625284825 ], "(0, 0, 1, 2, 0)": [ - -0.48498802268403846, - -0.4395215670942963, - -0.4722944229115358 + -0.04772137252269057, + -0.05881755142692484, + -0.054423243178985496 ], "(0, 1, 1, 2, 0)": [ - -0.4990197857734165, - -0.5327879501542365, - -0.4941112778002146 + -0.05278649224808442, + -0.07632187326419965, + -0.060938333453418286 ], "(0, 0, 0, 3, 0)": [ - -0.5387880192090378, - -0.4784712525147248, - -0.517952873228921 + -0.24369975025229282, + -0.24695892939305783, + -0.24911621770298206 ], "(0, 0, 1, 3, 0)": [ - -0.9096919737168466, - -0.9127452620968198, - -0.9420045562304813 + -0.24027345169278405, + -0.24031167906999237, + -0.24314101938021784 ], "(0, 1, 1, 3, 0)": [ - -0.8667830882886871, - -0.8757935260356404, - -0.8758730446292982 + -0.24861198983323804, + -0.23256743828903015, + -0.23006936132916378 ], "(0, 1, 0, 3, 0)": [ - -0.43694463364384284, - -0.32317351557573637, - -0.37671045435695716 + -0.23803043434121016, + -0.23300086647016396, + -0.252096348380666 ], "(0, 0, 1, 4, 0)": [ - -0.9306066987431686, - -0.9344436974699757, - -0.9655748330701439 + -0.29577889270947405, + -0.2981742809792015, + -0.29791983068372924 ], "(0, 1, 1, 4, 0)": [ - -1.009825188680733, - -1.016868953064729, - -0.9957558578136729 + -0.2958234315862551, + -0.30645418768275684, + -0.30661740666730875 ], "(0, 1, 0, 4, 0)": [ - -1.0369014461449961, - -1.0197443807886495, - -0.9976879186706535 + -0.3029623394714542, + -0.3018348837313605, + -0.3286817686032363 ], "(0, 0, 0, 4, 0)": [ - -0.8997996021390884, - -0.8913054855090478, - -0.8905808268698551 + -0.29781883698756567, + -0.2940971933423739, + -0.29733487917694584 ], "(0, 1, 0, 5, 0)": [ - -1.5826848759769268, - -2.7445575744690793, - -2.418367271258592 + -0.6329437113580243, + -0.6297747244548322, + -0.628882264517245 ], "(0, 0, 0, 5, 0)": [ - -1.6683128963172371, - -2.251311793007559, - -2.3851805260338708 + -0.6277272058045739, + -0.6384550155569291, + -0.6303149829373899 ], "(0, 0, 1, 5, 0)": [ - -2.2373735329153392, - -1.540495957571949, - -2.76737403946213 + -0.6396081577986387, + -0.6334955382726669, + -0.6407924717823796 ], "(0, 1, 1, 5, 0)": [ - -2.3314084980057013, - -1.5138347954693265, - -3.01917937692279 - ], - "(0, 1, 0, 1, 1)": [ - -0.261055476317184, - -0.27846357082465933, - -0.2774020188584469 + -0.6400694360313026, + -0.633976678537013, + -0.6330865725095266 ], "(0, 0, 1, 1, 1)": [ - -0.1342852547795026, - -0.2470169520959763, - -0.2126480225045462 + -0.10369497244449818, + -0.11119982054398986, + -0.12993117280633998 ], "(0, 1, 1, 1, 1)": [ - -0.2797380701937499, - -0.2945979948624919, - -0.24076792235861236 + -0.10542409331231692, + -0.11768594196503504, + -0.07456918970206615 ], - "(0, 1, 1, 2, 1)": [ - -0.2728844361768713, - -0.13972184601087123, - -0.3275551939924255 + "(0, 1, 0, 1, 1)": [ + -0.14656019585858246, + -0.1326055167833163, + -0.12885325016738228 ], "(0, 1, 0, 2, 1)": [ - -0.455039876626072, - -0.43296129057936117, - -0.4316836608934558 + -0.11967009468971596, + -0.12216127394629554, + -0.1429122098640283 ], "(0, 0, 0, 2, 1)": [ - -0.5686315534261919, - -0.5538077807341172, - -0.5496103554748915 + -0.128347455422042, + -0.1487387200161241, + -0.16517981243223914 + ], + "(0, 0, 1, 2, 1)": [ + -0.2682879902850379, + -0.19492371557724883, + -0.19236122464754474 + ], + "(0, 1, 1, 2, 1)": [ + -0.15105097271312573, + -0.20727413471853204, + -0.20105859767753256 ], "(0, 0, 0, 3, 1)": [ - -0.3966116309786595, - -0.3317568230750391, - -0.30858102473455185 + -0.11746821209617946, + -0.12239144215834397, + -0.13720619491970695 ], "(0, 0, 1, 3, 1)": [ - -0.8315655206780324, - -0.8333814435452305, - -0.7756715486070018 + -0.11651570970036011, + -0.11862687420946329, + -0.1260539064368246 ], "(0, 1, 1, 3, 1)": [ - -0.9831842749294987, - -0.9782200055298229, - -0.9930127344944251 + -0.1450901027213944, + -0.12420631784497436, + -0.12104519333576104 ], "(0, 1, 0, 3, 1)": [ - -0.6539627507980373, - -0.5852573030361744, - -0.6763166111516389 - ], - "(0, 1, 0, 4, 1)": [ - -0.9026947631459618, - -0.9407969095587947, - -0.965156293555089 - ], - "(0, 0, 0, 4, 1)": [ - -1.0295607537932239, - -1.0110078430982656, - -1.0217477290303583 + -0.13278128970879577, + -0.14433471293373654, + -0.14787321638372558 ], "(0, 0, 1, 4, 1)": [ - -1.0198724273521527, - -1.0273376264121616, - -1.0372176976745933 + -0.1720169164100566, + -0.1697573068991259, + -0.18563861409468002 ], "(0, 1, 1, 4, 1)": [ - -0.8367514960793316, - -0.8072150440466215, - -0.7999324124740165 + -0.19331603592176078, + -0.16465692724409528, + -0.1870968740626517 ], - "(0, 0, 1, 5, 1)": [ - -2.826052632887979, - -2.8216800501972115, - -2.8339346737355218 + "(0, 1, 0, 4, 1)": [ + -0.16839244888515778, + -0.15548076071108868, + -0.15890884729639157 + ], + "(0, 0, 0, 4, 1)": [ + -0.16755688928521928, + -0.1598773466377167, + -0.17156379062820235 ], "(0, 1, 1, 5, 1)": [ - -2.769926115050464, - -2.7896094930838444, - -2.784544336312168 + -1.0698086689883646, + -1.095931248568509, + -1.1160177171578312 ], "(0, 1, 0, 5, 1)": [ - -2.7928766121585795, - -2.7887103458307694, - -2.7938313222414037 + -1.1027542984743757, + -1.0928909342292816, + -1.0833097769293856 ], "(0, 0, 0, 5, 1)": [ - -2.8104767031828572, - -2.8062208641671322, - -2.8101316497840574 + -1.088841306513128, + -1.1035400078582027, + -1.1176895277220635 ], - "(0, 0, 1, 2, 1)": [ - -0.5213427535062244, - -0.5910962186115551, - -0.47068284911064734 - ], - "(0, 0, 1, 0, 0)": [ - -0.41286468970384216, - -0.42684989269569895, - -169.29165129913554 - ], - "(0, 1, 0, 0, 0)": [ - -0.11933872055277986, - -0.1195010821067479, - -0.1499420269134481 - ], - "(0, 1, 1, 0, 0)": [ - -0.20624980702980253, - -0.22645535674627537, - -0.18943028486599597 + "(0, 0, 1, 5, 1)": [ + -1.0994216179219447, + -1.1237160556971877, + -1.1090547917156828 ] } } \ No newline at end of file diff --git a/scripts/use_q_ide.py b/scripts/use_q_ide.py index fa3c4c7..03c8e72 100644 --- a/scripts/use_q_ide.py +++ b/scripts/use_q_ide.py @@ -26,15 +26,15 @@ def do_use(kwargs: dict[str, Any]) -> None: render (str)='none': render mode of environment reward (float)=-0.1: reward limit at which episode is terminated file (str): Optional definition of model-save file - use_trained (bool): Use pre-trained data? + use_file (str): How 'file' is used (if exists): 'r', 'w', 'rw' episodes (int)=10000: nnumber of episodes run in the training steps (int)=5000: number of steps per episodes (if not terminated or truncated) - + t_fac (float)=0.001 """ if "dry-train" in kwargs: # Check training setup (over-write some parameters) - kwargs.update({"render": "plot", "file": None, "use_trained": False, "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "file": None, "use_file": 'r', "episodes": 10, "steps": 1000}) elif "dry_do" in kwargs: # Run a few episodes on trained data (file can be set by caller) - kwargs.update({"render": "plot", "use_trained": True, "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "use_file": 'r', "episodes": 10, "steps": 1000}) env = AntiPendulumEnv( build_crane, seed=1, @@ -43,31 +43,49 @@ def do_use(kwargs: dict[str, Any]) -> None: render_mode=kwargs.get("render", "none"), reward_limit=kwargs.get("reward", 0.0), discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), + reward_fac = (1.0, 0.0015,kwargs.get('t_fac',0.0)), ) filename = kwargs.get("file") if filename is not None: Path(filename).parent.mkdir(parents=True, exist_ok=True) - use_trained = kwargs.get("use_trained", False) - agent = QLearningAgent(env, filename=filename, use_trained=use_trained) + use_file = kwargs.get("use_file", 'r') + agent = QLearningAgent(env, filename=filename, use_file=use_file) agent.do_episodes(n_episodes=kwargs.get("episodes", 100), max_steps=kwargs.get("steps", 5000)) if filename is not None: LOGGER.info(f"Model saved to {filename}") if __name__ == "__main__": + # ruff: disable[ERA001] ## we intentionally work with commenting out lines here + def _args( base:dict[str,Any], upd:dict[str,Any])-> dict[str,Any]: + base.update(upd) + return base + models = Path(__file__).parent.resolve().parent / "models" - args = { + anti = { # anti-pendulum settings "v0": 1.0, "render": "none", "reward": 0.0, "file": models / "q_anti-pendulum.json", - "use_trained": True, + "use_file": 'rw', + "episodes": 1000, + "steps": 2000, + "t_fac":0.0, + } + pend = { # start pendulum settings + "v0": 0.0, + "render": "none", + "reward": 200.0, + "file": models / "q_pendulum.json", + "use_file": 'rw', "episodes": 1000, - "steps": 5000, + "steps": 2000, + "t_fac":0.0, } - # args.update({'episodes':6000, 'use_trained':True}) # noqa: ERA001 ## do a mayor training adding to data - args.update({"episodes": 10, "render": "plot"}) - # args.update({'dry-train':True,}) # noqa: ERA001 ## check the setup before a long training - # args.update({'dry_do':True}) # noqa: ERA001 + args = _args(anti, {'episodes':10}) # anti-pendulum training + # args = _args(pend, {'episodes':10000}) # pendulum training + # args = _args( anti, {"episodes": 10, "render": "plot","use_file":'r'}) # show anti-pendulum results + # args = _args( pend, {"episodes": 10, "render": "plot", "use_file":'r'}) # show start pendulum results + # args = args.update(_args(anti, {'dry-train':True,})) # check the setup before a long training do_use(args) diff --git a/src/crane_controller/envs/controlled_crane_pendulum.py b/src/crane_controller/envs/controlled_crane_pendulum.py index 78c6c69..6658ae1 100644 --- a/src/crane_controller/envs/controlled_crane_pendulum.py +++ b/src/crane_controller/envs/controlled_crane_pendulum.py @@ -87,6 +87,7 @@ class AntiPendulumEnv(gym.Env[AntiPendulumObs, int]): When provided, activates discrete observation mode with the given category boundaries. Expected keys: ``"angles"``, ``"pos"``, ``"speed"``, ``"distance"``, ``"sector"`` (default None). + reward_fac (tuple[float,...])=(1.0,0.0015,0.001): Weights between reward contributions """ metadata: ClassVar[dict[str, object]] = { # pyright: ignore[reportIncompatibleVariableOverride] # Gymnasium metadata typing is loose @@ -102,7 +103,7 @@ class AntiPendulumEnv(gym.Env[AntiPendulumObs, int]): "show-len-1": False, "x-max": None, } - + def __init__( # noqa: PLR0913 - environment API needs explicit parameters self, crane: Callable[..., Crane], @@ -487,3 +488,15 @@ def render(self) -> None: """Render the current episode.""" if self.render_mode == "play-back": # show the animation self.show_animation() + + def get_parameters(self) -> dict[str,Any]: + """Return the environment parameter settings as dict.""" + return { + 'wire-length':self.wire.length, + 'wire-q-factor':self.wire.q_factor, + 'reward-factors': self.reward_fac, + 'acceleration':self.acc, + 'step-size':self.dt, + 'observations-discretization':None if not hasattr(self,'discrete') else self.discrete, + 'reward_limit':self.reward_limit + } diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py index fbf5529..ee26090 100644 --- a/src/crane_controller/q_agent.py +++ b/src/crane_controller/q_agent.py @@ -65,7 +65,7 @@ class QLearningAgent: discount_factor : float, optional How much to value future rewards, in the range [0, 1] (default 0.95). filename (Path): Optional path to filename for pre-trained data and saving of results - use_trained (bool) = False: load pre-trained values? + use_file (str) = 'r': How to use filename. 'r', 'w', 'rw'. File is not read when not found! """ DEFAULT_DISCRETE: ClassVar[dict[str, tuple[float | int, ...]]] = { @@ -81,11 +81,11 @@ def __init__( env: AntiPendulumEnv, learning_rate: float = 0.1, initial_epsilon: float = 1.0, + epsilon_decay:float = 1e-3, final_epsilon: float = 0.1, discount_factor: float = 0.95, filename: Path | None = None, - *, - use_trained: bool = False, + use_file: str = 'r', ) -> None: """Initialize the Q-learning agent. @@ -93,23 +93,20 @@ def __init__( """ self.env = env self.filename = Path(filename) if filename is not None else None - self.use_trained = use_trained + self.use_file = use_file self.q_values: defaultdict[tuple[int, ...], np.ndarray] - if self.use_trained and self.filename is not None and self.filename.exists(): - self.q_values = self.read_dumped(self.filename) - self.epsilon = final_epsilon # assume that we are fully learned - else: # start from scratch, but save the q_values afterwards - self.q_values = defaultdict(lambda: np.array((0.0,) * env.action_space.n, float)) # type: ignore[attr-defined,type-var] - self.epsilon = initial_epsilon # start from scratch self.lr = learning_rate self.discount_factor = discount_factor # How much we care about future rewards # Exploration parameters + self.epsilon = initial_epsilon + self.epsilon_decay = epsilon_decay self.final_epsilon = final_epsilon # Track learning progress self.training_error: list[float] = [] + self.previous_steps = 0 def analyse_q(self, obs: tuple[int, ...]) -> None: """Log Q-table entries matching an observation pattern. @@ -207,10 +204,13 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = Visualization mode - 0 for none, 1 for training summary, 2 for per-episode analysis (default 0). """ - if self.use_trained: + if 'r' in self.use_file and self.filename is not None and self.filename.exists(): + self.q_values = self.read_dumped(self.filename) logger.info("Starting %s episodes, using pre-trained values from %s", n_episodes, self.filename) - else: + else: # start from scratch + self.q_values = defaultdict(lambda: np.array((0.0,) * self.env.action_space.n, float)) # type: ignore[attr-defined,type-var] logger.info("Starting new training with %s episodes.", n_episodes) + start_time = dt.datetime.now(dt.UTC) total_steps = 0 for _episode in tqdm(range(n_episodes)): # Start a new episode @@ -233,13 +233,17 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = truncated |= nsteps > max_steps total_steps += nsteps # Reduce exploration rate (agent becomes less random over time): - self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon / (n_episodes / 2)) + self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay) if show == SHOW_TRAINING_SUMMARY: self.analyse_training() - if self.filename: - self.dump_results(episodes=n_episodes, steps=total_steps) - - def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int = -1) -> None: + if self.filename and 'w' in self.use_file: + self.dump_results(episodes=n_episodes, steps=total_steps, start_time=start_time) + + def dump_results(self, + filename: str | Path = "", + episodes: int = -1, + steps: int = -1, + start_time:dt.datetime|None=None) -> None: """Dump the Q-values to a JSON file. Args: @@ -247,6 +251,7 @@ def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int When empty, the filename provided at construction time is used (default ""). episodes (int): the number of episodes which have been run steps (int): the limiting number of steps per episode + start_time (dt.datetime): clock-time when the training started """ if not filename: # automatic file name if self.filename is None: @@ -259,20 +264,21 @@ def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int converted: dict[str, list[float]] = {} for k, v in self.q_values.items(): converted |= {str(k): list(v)} + env_parameters = { k:str(v) for k,v in self.env.get_parameters().items()} content = { - "date": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), - "pendulum": { - "start_speed": str(self.env.start_speed), - "render_mode": str(self.env.render_mode), - "reward_limit": str(self.env.reward_limit), - }, + "start-training":"unknown" if start_time is None else start_time.strftime("%d.%m.%Y %H:%M:%S"), + "end-training": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), + "pendulum": env_parameters, "q_agent": { - "use_trained": str(self.use_trained), "filename": str(self.filename), + "use_file": self.use_file, "episodes": str(episodes), - "steps": str(steps), + "steps": str(steps+self.previous_steps), "learning_rate": str(self.lr), "discount_factor": str(self.discount_factor), + "epsilon-decay":str(self.epsilon_decay), + "final-epsilon":str(self.final_epsilon), + "epsilon":str(self.epsilon), }, "q_values": converted, } @@ -296,6 +302,8 @@ def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.n path = Path(filename) with path.open(encoding="utf-8") as _f: from_dump = json.load(_f) + self.previous_steps = int(from_dump["q_agent"]["steps"]) + self.epsilon = float(from_dump["q_agent"].get("epsilon", 1.0)) q_values: defaultdict[tuple[int, ...], np.ndarray] = defaultdict( lambda: np.array((0.0,) * self.env.action_space.n, float) # type: ignore[attr-defined,type-var] ) From 8d5cc936266ac1b9e0295a189f8502d36dcc1c8c Mon Sep 17 00:00:00 2001 From: Eisinger Date: Thu, 30 Apr 2026 10:55:21 +0200 Subject: [PATCH 2/4] Fixed quality issues --- scripts/use_q_ide.py | 30 +++++++++--------- .../envs/controlled_crane_pendulum.py | 22 ++++++------- src/crane_controller/q_agent.py | 31 +++++++++---------- tests/test_q.py | 6 ++-- 4 files changed, 44 insertions(+), 45 deletions(-) diff --git a/scripts/use_q_ide.py b/scripts/use_q_ide.py index 03c8e72..4edb17e 100644 --- a/scripts/use_q_ide.py +++ b/scripts/use_q_ide.py @@ -32,9 +32,9 @@ def do_use(kwargs: dict[str, Any]) -> None: t_fac (float)=0.001 """ if "dry-train" in kwargs: # Check training setup (over-write some parameters) - kwargs.update({"render": "plot", "file": None, "use_file": 'r', "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "file": None, "use_file": "r", "episodes": 10, "steps": 1000}) elif "dry_do" in kwargs: # Run a few episodes on trained data (file can be set by caller) - kwargs.update({"render": "plot", "use_file": 'r', "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "use_file": "r", "episodes": 10, "steps": 1000}) env = AntiPendulumEnv( build_crane, seed=1, @@ -43,13 +43,13 @@ def do_use(kwargs: dict[str, Any]) -> None: render_mode=kwargs.get("render", "none"), reward_limit=kwargs.get("reward", 0.0), discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), - reward_fac = (1.0, 0.0015,kwargs.get('t_fac',0.0)), + reward_fac=(1.0, 0.0015, kwargs.get("t_fac", 0.0)), ) filename = kwargs.get("file") if filename is not None: Path(filename).parent.mkdir(parents=True, exist_ok=True) - use_file = kwargs.get("use_file", 'r') + use_file = kwargs.get("use_file", "r") agent = QLearningAgent(env, filename=filename, use_file=use_file) agent.do_episodes(n_episodes=kwargs.get("episodes", 100), max_steps=kwargs.get("steps", 5000)) if filename is not None: @@ -57,35 +57,37 @@ def do_use(kwargs: dict[str, Any]) -> None: if __name__ == "__main__": - # ruff: disable[ERA001] ## we intentionally work with commenting out lines here - def _args( base:dict[str,Any], upd:dict[str,Any])-> dict[str,Any]: + + def _args(base: dict[str, Any], upd: dict[str, Any]) -> dict[str, Any]: base.update(upd) return base - + models = Path(__file__).parent.resolve().parent / "models" - anti = { # anti-pendulum settings + anti = { # anti-pendulum settings "v0": 1.0, "render": "none", "reward": 0.0, "file": models / "q_anti-pendulum.json", - "use_file": 'rw', + "use_file": "rw", "episodes": 1000, "steps": 2000, - "t_fac":0.0, + "t_fac": 0.0, } - pend = { # start pendulum settings + pend = { # start pendulum settings "v0": 0.0, "render": "none", "reward": 200.0, "file": models / "q_pendulum.json", - "use_file": 'rw', + "use_file": "rw", "episodes": 1000, "steps": 2000, - "t_fac":0.0, + "t_fac": 0.0, } - args = _args(anti, {'episodes':10}) # anti-pendulum training + # ruff: disable[ERA001] ## we intentionally work with commenting out lines here + args = _args(anti, {"episodes": 10}) # anti-pendulum training # args = _args(pend, {'episodes':10000}) # pendulum training # args = _args( anti, {"episodes": 10, "render": "plot","use_file":'r'}) # show anti-pendulum results # args = _args( pend, {"episodes": 10, "render": "plot", "use_file":'r'}) # show start pendulum results # args = args.update(_args(anti, {'dry-train':True,})) # check the setup before a long training + # ruff: enable[ERA001] do_use(args) diff --git a/src/crane_controller/envs/controlled_crane_pendulum.py b/src/crane_controller/envs/controlled_crane_pendulum.py index 09454fe..49f5ff2 100644 --- a/src/crane_controller/envs/controlled_crane_pendulum.py +++ b/src/crane_controller/envs/controlled_crane_pendulum.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import gymnasium as gym import matplotlib.pyplot as plt @@ -103,7 +103,7 @@ class AntiPendulumEnv(gym.Env[AntiPendulumObs, int]): "show-len-1": False, "x-max": None, } - + def __init__( # noqa: PLR0913 - environment API needs explicit parameters self, crane: Callable[..., Crane], @@ -495,14 +495,14 @@ def render(self) -> None: elif self.render_mode == "plot": self.show_plot(self.nresets) - def get_parameters(self) -> dict[str,Any]: + def get_parameters(self) -> dict[str, Any]: """Return the environment parameter settings as dict.""" return { - 'wire-length':self.wire.length, - 'wire-q-factor':self.wire.q_factor, - 'reward-factors': self.reward_fac, - 'acceleration':self.acc, - 'step-size':self.dt, - 'observations-discretization':None if not hasattr(self,'discrete') else self.discrete, - 'reward_limit':self.reward_limit - } + "wire-length": self.wire.length, + "wire-q-factor": self.wire.q_factor, + "reward-factors": self.reward_fac, + "acceleration": self.acc, + "step-size": self.dt, + "observations-discretization": None if not hasattr(self, "discrete") else self.discrete, + "reward_limit": self.reward_limit, + } diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py index ee26090..dd115f7 100644 --- a/src/crane_controller/q_agent.py +++ b/src/crane_controller/q_agent.py @@ -80,12 +80,11 @@ def __init__( self, env: AntiPendulumEnv, learning_rate: float = 0.1, - initial_epsilon: float = 1.0, - epsilon_decay:float = 1e-3, + epsilon_decay: float = 1e-3, final_epsilon: float = 0.1, discount_factor: float = 0.95, filename: Path | None = None, - use_file: str = 'r', + use_file: str = "r", ) -> None: """Initialize the Q-learning agent. @@ -100,7 +99,7 @@ def __init__( self.discount_factor = discount_factor # How much we care about future rewards # Exploration parameters - self.epsilon = initial_epsilon + self.epsilon = 1.0 self.epsilon_decay = epsilon_decay self.final_epsilon = final_epsilon @@ -204,7 +203,7 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = Visualization mode - 0 for none, 1 for training summary, 2 for per-episode analysis (default 0). """ - if 'r' in self.use_file and self.filename is not None and self.filename.exists(): + if "r" in self.use_file and self.filename is not None and self.filename.exists(): self.q_values = self.read_dumped(self.filename) logger.info("Starting %s episodes, using pre-trained values from %s", n_episodes, self.filename) else: # start from scratch @@ -236,14 +235,12 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay) if show == SHOW_TRAINING_SUMMARY: self.analyse_training() - if self.filename and 'w' in self.use_file: + if self.filename and "w" in self.use_file: self.dump_results(episodes=n_episodes, steps=total_steps, start_time=start_time) - def dump_results(self, - filename: str | Path = "", - episodes: int = -1, - steps: int = -1, - start_time:dt.datetime|None=None) -> None: + def dump_results( + self, filename: str | Path = "", episodes: int = -1, steps: int = -1, start_time: dt.datetime | None = None + ) -> None: """Dump the Q-values to a JSON file. Args: @@ -264,21 +261,21 @@ def dump_results(self, converted: dict[str, list[float]] = {} for k, v in self.q_values.items(): converted |= {str(k): list(v)} - env_parameters = { k:str(v) for k,v in self.env.get_parameters().items()} + env_parameters = {k: str(v) for k, v in self.env.get_parameters().items()} content = { - "start-training":"unknown" if start_time is None else start_time.strftime("%d.%m.%Y %H:%M:%S"), + "start-training": "unknown" if start_time is None else start_time.strftime("%d.%m.%Y %H:%M:%S"), "end-training": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), "pendulum": env_parameters, "q_agent": { "filename": str(self.filename), "use_file": self.use_file, "episodes": str(episodes), - "steps": str(steps+self.previous_steps), + "steps": str(steps + self.previous_steps), "learning_rate": str(self.lr), "discount_factor": str(self.discount_factor), - "epsilon-decay":str(self.epsilon_decay), - "final-epsilon":str(self.final_epsilon), - "epsilon":str(self.epsilon), + "epsilon-decay": str(self.epsilon_decay), + "final-epsilon": str(self.final_epsilon), + "epsilon": str(self.epsilon), }, "q_values": converted, } diff --git a/tests/test_q.py b/tests/test_q.py index 832621d..fd6fa05 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -32,7 +32,7 @@ def test_q_analyse(crane: Callable[..., Crane], *, show: bool) -> None: crane, discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=Path("q_trained.json"), use_trained=True) + agent = QLearningAgent(env, filename=Path("q_trained.json"), use_file='r') for k, v in agent.q_values.items(): assert len(k) == 5, len(v) == 3 for pos in (0, 1): @@ -57,12 +57,12 @@ def test_intervals(crane: Callable[..., Crane]): discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=save_path, use_trained=False) + agent = QLearningAgent(env, filename=save_path, use_file='w') for i in range(10): _ = env.reset(seed=i + 1) agent.do_episodes(n_episodes=2, max_steps=100) if i == 0: - agent = QLearningAgent(env, filename=save_path, use_trained=True) + agent = QLearningAgent(env, filename=save_path, use_file='rw') logger.info(f"Model saved to {save_path}") From 7578bb9ded864026f04b60ed65702239407e6914 Mon Sep 17 00:00:00 2001 From: Eisinger Date: Thu, 30 Apr 2026 12:37:23 +0200 Subject: [PATCH 3/4] Fixed remaining quality issues --- models/q_pendulum.json | 1258 +++++++++++++++++++++++++++++++ src/crane_controller/q_agent.py | 25 +- tests/test_q.py | 10 +- 3 files changed, 1281 insertions(+), 12 deletions(-) create mode 100644 models/q_pendulum.json diff --git a/models/q_pendulum.json b/models/q_pendulum.json new file mode 100644 index 0000000..24b3925 --- /dev/null +++ b/models/q_pendulum.json @@ -0,0 +1,1258 @@ +{ + "date": "30.04.2026 05:15:03", + "pendulum": { + "start_speed": "0.0", + "render_mode": "none", + "reward_limit": "1000.0" + }, + "q_agent": { + "use_trained": "True", + "filename": "/home/se/osp/packages/crane-controller/models/q_pendulum.json", + "episodes": "10000", + "steps": "19991482", + "learning_rate": "0.1", + "discount_factor": "0.95" + }, + "q_values": { + "(1, 0, 0, 1, 1)": [ + -0.0015285453382582262, + -0.0019110699513047503, + -0.000421685773718873 + ], + "(1, 0, 0, 0, 0)": [ + 0.0016483778628561532, + 0.0009174558585016858, + 0.0001847120557717193 + ], + "(1, 1, 1, 1, 0)": [ + -0.00011618763900436013, + 0.0021056726444264756, + -0.0009250088093969632 + ], + "(2, 0, 0, 1, 1)": [ + 0.30381525310187946, + 0.13268574959167037, + 0.08968956867959964 + ], + "(2, 0, 1, 1, 1)": [ + 0.07021166494175266, + 0.11544068028302096, + 0.2655879939625879 + ], + "(2, 1, 1, 2, 1)": [ + 0.0468700802191448, + 0.24296967824838114, + 0.057465378203944484 + ], + "(2, 1, 0, 2, 1)": [ + 0.247316722251548, + 0.07724841470718162, + 0.04444885410818501 + ], + "(2, 0, 0, 2, 1)": [ + 0.03151075349886775, + 0.1198098266764795, + 0.026761926901275757 + ], + "(2, 0, 0, 3, 1)": [ + 0.24750790139852671, + -0.00742543468925339, + 0.003439802770232106 + ], + "(2, 0, 1, 3, 1)": [ + 0.013165205849978824, + 0.020782332264133313, + 0.16600804643720815 + ], + "(2, 1, 1, 3, 1)": [ + 0.006877018618653728, + 0.017372356548709385, + 0.1822037784645431 + ], + "(2, 1, 0, 3, 1)": [ + 0.011260461429967967, + -0.06408511029671254, + 0.16832760241231254 + ], + "(2, 1, 0, 4, 1)": [ + 0.30877072872600403, + 0.17712588766549373, + -0.11992985889469697 + ], + "(2, 0, 0, 4, 1)": [ + 0.07171333312506538, + 0.3196709073875426, + -0.010305375893922684 + ], + "(2, 0, 1, 4, 1)": [ + 0.0460495219813757, + -0.025981954700818086, + 0.35798806472060796 + ], + "(2, 1, 1, 4, 1)": [ + 0.28024071171042864, + -0.08081625384759063, + -0.06021051977246365 + ], + "(2, 0, 1, 5, 1)": [ + 0.11199272160151952, + 0.34003948184320887, + 0.15555212679650068 + ], + "(2, 1, 1, 5, 1)": [ + -0.06388301321252567, + -0.27230305038177305, + 0.3227408520413516 + ], + "(2, 1, 0, 5, 1)": [ + 0.33544847358707, + -0.014778838303531092, + -0.13036007932481597 + ], + "(2, 0, 0, 5, 1)": [ + 0.3368562070401561, + -0.06442413772899995, + 0.015851813972102308 + ], + "(1, 1, 1, 5, 1)": [ + -0.873610573166523, + -0.8795035200229879, + -0.8813170678799008 + ], + "(1, 1, 0, 5, 1)": [ + -0.8588580712968287, + -0.8603630825591916, + -0.8545640919171729 + ], + "(1, 0, 0, 5, 1)": [ + -0.876058331400781, + -0.8812733294953802, + -0.8781109627654614 + ], + "(1, 0, 1, 5, 1)": [ + -0.8754681221299438, + -0.8811483806653537, + -0.8795342715024759 + ], + "(2, 1, 1, 1, 1)": [ + 0.2575141349768899, + 0.14721086451092427, + 0.13938193635472462 + ], + "(2, 0, 1, 2, 1)": [ + 0.049090745483137355, + 0.035345162836350655, + 0.1850890582095332 + ], + "(3, 0, 0, 3, 1)": [ + 0.408660876661018, + 0.42698247924321825, + 0.40792534233546174 + ], + "(3, 0, 1, 3, 1)": [ + 0.3600676621922624, + 0.3929021532144681, + 0.461063559575569 + ], + "(3, 1, 1, 3, 1)": [ + 0.43092477422897313, + 0.4760286469676715, + 0.5329087094566657 + ], + "(3, 1, 0, 3, 1)": [ + 0.4099435292952729, + 0.43676458251571, + 0.5251269059135935 + ], + "(3, 0, 0, 4, 1)": [ + 0.4048621627470891, + 0.4066271465074262, + 0.4043352559362854 + ], + "(3, 0, 1, 4, 1)": [ + 0.39328686645976846, + 0.39556738781196227, + 0.3959743788190199 + ], + "(3, 1, 0, 5, 1)": [ + 0.2918677426270124, + 0.2808605809982236, + 0.3364418740886684 + ], + "(3, 0, 0, 5, 1)": [ + 0.015025569672587752, + 0.08196942899783732, + 0.3327158807318435 + ], + "(3, 0, 1, 5, 1)": [ + 0.15452953096605598, + 0.32706016277878186, + 0.19602635793745557 + ], + "(2, 1, 0, 1, 1)": [ + 0.28308955078023956, + 0.16197169711955506, + 0.11997498874359913 + ], + "(3, 1, 1, 5, 1)": [ + 0.04055530756891616, + 0.2713994534576125, + 0.39683102819432226 + ], + "(4, 0, 0, 5, 1)": [ + 4.061952905462236, + 1.8796422036374005, + 2.6392271976381876 + ], + "(4, 0, 1, 5, 1)": [ + 1.8612065508222904, + 2.103909066858137, + 4.240572243261516 + ], + "(4, 1, 1, 5, 1)": [ + 2.7147234118909207, + 4.2477919190765006, + 2.532801305018678 + ], + "(4, 1, 0, 5, 1)": [ + 5.335646400509759, + 2.3417395159744956, + 2.8731429101445904 + ], + "(1, 0, 1, 1, 1)": [ + -0.003409519735202579, + -0.0013383819217571565, + 0.0002056980625474665 + ], + "(3, 1, 1, 4, 1)": [ + 0.38529100751464557, + 0.3942557035549322, + 0.4431156152070642 + ], + "(3, 0, 0, 2, 1)": [ + 0.5090782560731483, + 0.4798070339288543, + 0.4440706186764382 + ], + "(3, 0, 1, 2, 1)": [ + 0.4031617038672731, + 0.41820257002525474, + 0.4254288577893825 + ], + "(3, 1, 0, 2, 1)": [ + 0.5383778912624413, + 0.5816045650663727, + 0.44907309167118936 + ], + "(3, 0, 0, 1, 1)": [ + 0.3820396000856337, + 0.4244240876422671, + 0.4809947245965916 + ], + "(3, 0, 1, 1, 1)": [ + 0.48309577607918963, + 0.4359746680582087, + 0.47844205657321176 + ], + "(3, 1, 1, 2, 1)": [ + 0.4510595166267413, + 0.5413951132995184, + 0.5396546938623438 + ], + "(3, 1, 0, 4, 1)": [ + 0.4911755694970669, + 0.4273041013300876, + 0.4113438183162773 + ], + "(3, 1, 0, 1, 1)": [ + 0.41206946560587604, + 0.4276691064016972, + 0.3891151329059017 + ], + "(3, 0, 0, 1, 0)": [ + 0.595092804268447, + 0.5709074723626882, + 0.4357015522877032 + ], + "(3, 0, 1, 1, 0)": [ + 0.5048087137826464, + 0.5131171175030076, + 0.5824969529639638 + ], + "(2, 0, 1, 1, 0)": [ + 0.16890785501899347, + 0.07849898959202127, + 0.13065691356841727 + ], + "(2, 1, 1, 1, 0)": [ + -0.0019192156044434555, + 0.11937402910019926, + 0.06218611939997548 + ], + "(2, 1, 0, 1, 0)": [ + 0.15656530313154501, + 0.05852070858351637, + 0.05917317431102258 + ], + "(2, 0, 0, 1, 0)": [ + 0.11925737271077642, + 0.06270232978106371, + 0.016145004918409697 + ], + "(3, 1, 1, 1, 1)": [ + 0.415554313371666, + 0.4251120871709033, + 0.41738534617858564 + ], + "(3, 1, 0, 1, 0)": [ + 0.5175638162159195, + 0.4768692444308892, + 0.4092874554117059 + ], + "(1, 1, 0, 1, 0)": [ + 0.007559388663013664, + -0.07141246975944493, + -0.03218011823701119 + ], + "(1, 1, 1, 1, 1)": [ + -0.006336313500421211, + -0.00722006182105164, + -0.015438778111669823 + ], + "(1, 1, 0, 1, 1)": [ + -0.011637555446352072, + -0.03469500547718182, + -0.06946810336387721 + ], + "(1, 1, 0, 2, 1)": [ + -0.02059287808812973, + -0.020207666113278937, + -0.023751195001154267 + ], + "(1, 0, 0, 2, 1)": [ + -0.03136891469721944, + -0.018845272510248604, + -0.025639212801306065 + ], + "(2, 0, 0, 2, 0)": [ + -0.007405306099820059, + 0.13416722893659513, + -0.003863979660054713 + ], + "(3, 0, 0, 2, 0)": [ + 0.4968529980315005, + 0.6083878273580536, + 0.5107250954251326 + ], + "(3, 0, 1, 2, 0)": [ + 0.4572105904076367, + 0.48251636216964044, + 0.5371402979328573 + ], + "(3, 1, 1, 2, 0)": [ + 0.5834210121161674, + 0.5447607381569435, + 0.6522461652145303 + ], + "(3, 1, 1, 3, 0)": [ + 0.4515300417767746, + 0.45066875186289224, + 0.6830079417945389 + ], + "(3, 1, 0, 3, 0)": [ + 0.5338016897375181, + 0.5414142734086245, + 0.6346264676956274 + ], + "(3, 0, 0, 3, 0)": [ + 0.561690114055678, + 0.5920536892988973, + 0.5543067752647609 + ], + "(3, 0, 1, 3, 0)": [ + 0.5489920980754698, + 0.5660011505184062, + 0.5167225836179886 + ], + "(2, 1, 1, 3, 0)": [ + 0.1494485963494498, + -0.03830396513846329, + 0.015657549848606813 + ], + "(2, 1, 0, 3, 0)": [ + 0.20593271843361607, + 0.09479740257152519, + 0.0959316384523621 + ], + "(2, 1, 0, 2, 0)": [ + 0.13616448787873342, + 0.019560267471690474, + -0.005724736049839602 + ], + "(3, 1, 0, 2, 0)": [ + 0.683590879589438, + 0.5526158769509005, + 0.46954602590689115 + ], + "(1, 0, 0, 1, 0)": [ + 0.03733255481235882, + -0.02849128758664682, + -0.09589936241358511 + ], + "(1, 1, 1, 2, 1)": [ + -0.006083016288581055, + -0.03523377742840035, + -0.019203666274691625 + ], + "(1, 0, 1, 2, 1)": [ + -0.01686144825783039, + -0.008924069806829863, + 0.009536322538022453 + ], + "(1, 0, 1, 3, 1)": [ + -0.034434632147796705, + 0.00024112079172242465, + -0.03197066856336422 + ], + "(1, 1, 1, 3, 1)": [ + -0.01895592863353697, + -0.04411521018370294, + 0.01775191105812379 + ], + "(1, 1, 0, 3, 1)": [ + 0.04225706506102414, + -0.0367929078115716, + -0.04368072481979268 + ], + "(1, 0, 0, 3, 1)": [ + -0.007437638387498356, + -0.044450404397953926, + -0.04469796462300711 + ], + "(1, 1, 1, 4, 1)": [ + -0.09646741076342318, + -0.09875923262296923, + -0.07611277807180256 + ], + "(1, 1, 0, 4, 1)": [ + 0.004840534632955751, + -0.09842756570162219, + -0.09016058064434311 + ], + "(1, 0, 0, 4, 1)": [ + -0.09004032812642203, + -0.07791186134035129, + -0.09719535819232616 + ], + "(1, 0, 1, 4, 1)": [ + -0.05966874855523139, + 0.00806343851278417, + -0.0359359869255538 + ], + "(1, 0, 1, 1, 0)": [ + -0.020070667836140862, + -0.00023703353233769307, + -0.019699742890545824 + ], + "(1, 0, 1, 2, 0)": [ + -0.08614089449335621, + -0.07468383088946322, + -0.028281994617095248 + ], + "(1, 1, 1, 2, 0)": [ + -0.020887757859586534, + -0.05007737176424587, + 0.013443648795931901 + ], + "(1, 1, 0, 2, 0)": [ + -0.25387563088246234, + -0.008011626613088369, + -0.18283538351700016 + ], + "(1, 1, 0, 3, 0)": [ + -0.07464578664914494, + -0.043449407581580696, + -0.11463260589847274 + ], + "(1, 0, 0, 3, 0)": [ + -0.14350968802347358, + -0.04083978797784146, + -0.09909284163171486 + ], + "(1, 0, 1, 3, 0)": [ + -0.1370625750166114, + -0.0724585837043522, + -0.0009240473895465921 + ], + "(1, 1, 1, 3, 0)": [ + -0.03626126792752658, + -0.03710154477294096, + 0.0032683572316288146 + ], + "(1, 0, 1, 4, 0)": [ + 0.0014061010064767788, + -0.04452670073239082, + -0.07821196453222311 + ], + "(1, 1, 1, 4, 0)": [ + -0.17474951197312322, + -0.5414578832524665, + -0.22510473565580394 + ], + "(1, 1, 0, 4, 0)": [ + -0.1328897390182134, + -0.034781929591368305, + -0.16080111281963685 + ], + "(1, 0, 0, 4, 0)": [ + -0.1475330350868065, + -0.18227419935411793, + 0.001484675759361425 + ], + "(1, 0, 0, 5, 0)": [ + -0.9545247785708878, + -0.9514659970621401, + -0.90452487601084 + ], + "(1, 0, 1, 5, 0)": [ + -0.8391556754298948, + -0.7701522287778763, + -0.8233480591505404 + ], + "(1, 1, 1, 5, 0)": [ + -1.093243461459317, + -1.1130313370385934, + -1.117598650488551 + ], + "(1, 1, 0, 5, 0)": [ + -1.060052207589391, + -1.0520482450609705, + -1.0541775063616219 + ], + "(2, 0, 1, 2, 0)": [ + 0.0019578641658367663, + 0.025663384896343185, + 0.1263307771397902 + ], + "(2, 0, 1, 3, 0)": [ + 0.023588571331503803, + 0.12134317356917385, + 0.2760718263547822 + ], + "(2, 0, 0, 3, 0)": [ + 0.1161666059796964, + 0.2235732549534617, + 0.07430355449847614 + ], + "(2, 1, 1, 2, 0)": [ + 0.03458905003176114, + 0.0007093755508323987, + 0.17288176348599543 + ], + "(1, 0, 1, 0, 0)": [ + 0.0011937985794666583, + 0.0006577923193096143, + -8.048310965484833e-05 + ], + "(1, 1, 1, 0, 0)": [ + 0.0015724376946350388, + 0.0007011948995402796, + -7.399361560918754e-05 + ], + "(2, 0, 0, 4, 0)": [ + -0.05649924448818272, + -0.09895201781380289, + 0.010597525817515366 + ], + "(2, 0, 1, 4, 0)": [ + -0.08358002678014068, + -0.09810695789263416, + -0.00041284702559314396 + ], + "(2, 1, 1, 4, 0)": [ + 0.09217839560450072, + -0.017320517480429995, + -0.03878804552748378 + ], + "(2, 1, 0, 4, 0)": [ + 0.1925844035349752, + 0.012835236335966602, + -0.042144040424107326 + ], + "(2, 1, 0, 5, 0)": [ + -1.4444514200032612, + -1.445392257262375, + -1.4481597245532174 + ], + "(2, 0, 0, 5, 0)": [ + -1.4712160523544338, + -1.470824695284624, + -1.3327809396730481 + ], + "(2, 0, 1, 5, 0)": [ + -1.4469254687808646, + -1.5299065594273575, + -1.5281209393177229 + ], + "(2, 1, 1, 5, 0)": [ + -1.3128182215345525, + -1.4477609597552596, + -1.4392818680144006 + ], + "(3, 1, 1, 1, 0)": [ + 0.45806194208553763, + 0.46093042135097084, + 0.474475379606745 + ], + "(4, 0, 1, 4, 1)": [ + 1.4164846977098555, + 1.4459868052131215, + 1.4616804923252622 + ], + "(4, 1, 1, 4, 1)": [ + 1.4129498511120588, + 1.4025285190325392, + 1.4393948253927626 + ], + "(1, 1, 0, 0, 0)": [ + 0.0010373859416610812, + 0.0011703526531334982, + 0.0007493443523265281 + ], + "(1, 0, 0, 2, 0)": [ + -0.10532136935062486, + -0.03584791600007468, + -0.11854695228463348 + ], + "(3, 0, 1, 5, 0)": [ + 0.02285934492582601, + -0.07103977867880287, + 0.019668059978482654 + ], + "(3, 1, 1, 5, 0)": [ + 0.03469354291860075, + 0.044038976296537, + 0.013242232441255847 + ], + "(3, 1, 0, 5, 0)": [ + -0.020181802884764344, + 0.06649176778651363, + 0.06509248190323355 + ], + "(3, 0, 0, 5, 0)": [ + 0.028824321631020555, + 0.029977268479488877, + 0.026545109024266954 + ], + "(3, 1, 0, 4, 0)": [ + 0.6296155314616831, + 0.4336608270212875, + 0.39276605272085297 + ], + "(3, 0, 0, 4, 0)": [ + 0.7770847854086554, + 0.5242727900048311, + 0.5206170117506856 + ], + "(3, 0, 1, 4, 0)": [ + 0.5314214160012682, + 0.7844820846414253, + 0.5074484536943829 + ], + "(3, 1, 1, 4, 0)": [ + 0.3754666321452126, + 0.7229753308135625, + 0.3995770969535575 + ], + "(4, 1, 0, 4, 1)": [ + 1.4963873160873034, + 1.467341373698872, + 1.4384319896246058 + ], + "(4, 1, 0, 3, 1)": [ + 1.5752245110522871, + 1.5602010532290964, + 1.492566856395298 + ], + "(4, 0, 0, 3, 1)": [ + 1.519151228092836, + 1.5227125117824454, + 1.4903275891051206 + ], + "(4, 0, 0, 4, 1)": [ + 1.6473288659015743, + 1.7275422975605321, + 1.5048793902151347 + ], + "(4, 0, 1, 3, 1)": [ + 1.4129889090497951, + 1.452240570193209, + 1.4625105948676116 + ], + "(4, 1, 1, 3, 1)": [ + 1.4646407295985004, + 1.4732073804218064, + 1.511095490878426 + ], + "(4, 1, 0, 2, 1)": [ + 2.106400147645301, + 1.6356823386754527, + 1.6995586432240426 + ], + "(4, 0, 0, 2, 1)": [ + 2.161395368605027, + 1.5237311046844422, + 1.8173060474051923 + ], + "(4, 0, 0, 1, 1)": [ + 2.3493776509278357, + 1.7130386451873514, + 1.8850438496049542 + ], + "(4, 0, 1, 1, 1)": [ + 2.1151543614085875, + 1.6543782490876844, + 2.2703735803195966 + ], + "(4, 1, 1, 1, 1)": [ + 1.918925083602411, + 1.74973854680099, + 2.5088229704031644 + ], + "(4, 1, 1, 1, 0)": [ + 1.7371298756417777, + 1.962531149214229, + 2.264878889586818 + ], + "(4, 1, 0, 1, 0)": [ + 2.020061243093679, + 1.8524421509129365, + 1.9286353167781694 + ], + "(4, 0, 0, 1, 0)": [ + 1.6817658822452428, + 1.6421444458419396, + 2.3586227153499326 + ], + "(4, 0, 1, 1, 0)": [ + 1.78836241223083, + 1.848369564939383, + 2.453235378342548 + ], + "(4, 0, 1, 2, 1)": [ + 1.5272002349823173, + 1.774756715807296, + 2.2631339867701725 + ], + "(4, 1, 0, 1, 1)": [ + 1.9677313784165418, + 2.433707323529033, + 1.7422610937368759 + ], + "(4, 1, 1, 2, 1)": [ + 1.490825730025895, + 1.6990871895059803, + 2.1039556756595323 + ], + "(5, 1, 0, 5, 1)": [ + 6.123750372525588, + 8.70575880253172, + 6.260498065482083 + ], + "(5, 0, 0, 5, 1)": [ + 7.308611772562493, + 5.240209399773695, + 5.920309823658606 + ], + "(5, 0, 1, 5, 1)": [ + 6.96203305289973, + 4.890002993836127, + 11.78691576141891 + ], + "(5, 1, 1, 5, 1)": [ + 6.713821193839766, + 9.316098624644017, + 5.311107957816669 + ], + "(6, 1, 1, 5, 1)": [ + 15.047660106650373, + 16.821457503565934, + 27.71374620203335 + ], + "(6, 1, 0, 5, 1)": [ + 16.384453849684327, + 28.353343581666334, + 20.687843932731177 + ], + "(6, 0, 0, 5, 1)": [ + 30.01197247830154, + 18.86235490305763, + 19.344672586909752 + ], + "(6, 0, 1, 5, 1)": [ + 26.558921440960397, + 18.868850427579066, + 16.828520283531955 + ], + "(4, 0, 1, 3, 0)": [ + 2.1925640020257102, + 2.95691906279798, + 3.4723878527181853 + ], + "(4, 1, 1, 3, 0)": [ + 2.198797063712139, + 2.21724652824473, + 4.963837212356006 + ], + "(4, 1, 0, 3, 0)": [ + 3.886035854924468, + 2.4084368048531406, + 2.8785873750775908 + ], + "(4, 0, 0, 3, 0)": [ + 2.4390869073407133, + 3.1202053796624596, + 3.440342237225939 + ], + "(4, 1, 1, 2, 0)": [ + 1.7198249761230124, + 1.7264336589470106, + 2.55030041839504 + ], + "(4, 1, 0, 2, 0)": [ + 2.857534274934687, + 1.8560661471861224, + 1.7892755484782197 + ], + "(4, 1, 0, 4, 0)": [ + 5.208515876962349, + 4.076311172617219, + 3.1256266560673907 + ], + "(4, 0, 0, 4, 0)": [ + 5.359684776234321, + 2.8072297945126206, + 3.9345761324135915 + ], + "(4, 0, 1, 4, 0)": [ + 3.2715111780893653, + 5.189496212005362, + 4.395806899462235 + ], + "(4, 1, 1, 4, 0)": [ + 3.536815804580715, + 3.9015860629086827, + 5.135748652698586 + ], + "(4, 0, 1, 5, 0)": [ + -0.8730882647726272, + -0.8192135917755976, + -0.8542523954122301 + ], + "(4, 1, 1, 5, 0)": [ + -0.8245347871065521, + -0.8326455650944975, + 0.9991824483293552 + ], + "(4, 0, 1, 2, 0)": [ + 1.9753930346045618, + 1.9649213400482635, + 1.9809836107767345 + ], + "(4, 1, 0, 5, 0)": [ + -0.6336931013817524, + 1.110335983119446, + 0.010250432901592221 + ], + "(4, 0, 0, 5, 0)": [ + -0.5955443430025976, + -0.5479728074067413, + -0.5137459310071231 + ], + "(4, 0, 0, 2, 0)": [ + 2.306030089074062, + 2.0513914882428184, + 1.7358723965792504 + ], + "(5, 0, 1, 2, 1)": [ + 5.938331314830866, + 5.923671010454649, + 6.322198810091958 + ], + "(5, 0, 1, 1, 1)": [ + 6.3713621509072205, + 4.278415316748207, + 5.217314857720529 + ], + "(5, 1, 0, 1, 1)": [ + 6.271680380359423, + 6.118414945090114, + 5.863101745784359 + ], + "(5, 0, 0, 1, 1)": [ + 6.343186747568721, + 5.938795353886484, + 6.033491402716604 + ], + "(5, 1, 1, 1, 1)": [ + 5.855501877516305, + 5.146222646836102, + 6.270943831177874 + ], + "(5, 1, 1, 2, 1)": [ + 5.965178381510156, + 4.951535473108393, + 6.335327059170063 + ], + "(5, 1, 0, 2, 1)": [ + 6.101590240378415, + 5.943638909366957, + 6.166377646864589 + ], + "(5, 1, 0, 3, 1)": [ + 9.399681750212956, + 6.046443374728028, + 5.472047895906149 + ], + "(5, 0, 0, 3, 1)": [ + 7.264852310738755, + 6.100976298895757, + 10.121931565466037 + ], + "(5, 0, 1, 3, 1)": [ + 6.058497642090931, + 6.64055605254837, + 12.820985400788588 + ], + "(5, 1, 1, 3, 1)": [ + 6.295127566601765, + 8.86480034761974, + 6.525986911785163 + ], + "(5, 0, 0, 4, 1)": [ + 11.620473329796793, + 5.852233709190965, + 6.172684561311558 + ], + "(5, 0, 1, 4, 1)": [ + 5.853903246175382, + 7.465633418352998, + 14.203272816632714 + ], + "(5, 1, 1, 4, 1)": [ + 8.411250446883901, + 6.821876889600531, + 10.420607535997695 + ], + "(5, 1, 0, 4, 1)": [ + 10.978468161677114, + 6.095771009776962, + 7.448828231947645 + ], + "(5, 0, 0, 5, 0)": [ + 4.168560040349653, + 4.1193373810885126, + 4.189200305683654 + ], + "(5, 0, 1, 5, 0)": [ + 3.9037126196648213, + 3.952725477446389, + 3.925950696362765 + ], + "(5, 0, 0, 4, 0)": [ + 6.989124827386673, + 6.709668768383307, + 6.164502340428653 + ], + "(5, 0, 1, 4, 0)": [ + 6.470192212401894, + 6.206583505620243, + 6.219341416738382 + ], + "(5, 1, 1, 4, 0)": [ + 6.44804048006831, + 6.470249716808012, + 6.499092139332864 + ], + "(5, 1, 0, 4, 0)": [ + 6.60974475903126, + 5.809620089830466, + 6.38325714226122 + ], + "(5, 0, 1, 3, 0)": [ + 5.902298693239425, + 6.513993949524509, + 6.050931365245501 + ], + "(5, 1, 1, 3, 0)": [ + 5.7219438824090165, + 6.654252839175707, + 5.854631300383308 + ], + "(5, 1, 0, 3, 0)": [ + 7.060124662461811, + 6.133863344872211, + 5.887388222544612 + ], + "(5, 1, 0, 2, 0)": [ + 6.420811088645497, + 5.484591712572488, + 5.344338090966727 + ], + "(5, 1, 0, 1, 0)": [ + 7.60161765328693, + 5.415707409756238, + 5.625087230101972 + ], + "(5, 0, 0, 1, 0)": [ + 6.362344373068486, + 6.250819401040948, + 5.0532246397904155 + ], + "(5, 0, 0, 2, 1)": [ + 6.389023020089472, + 5.921270205667463, + 5.710168906134173 + ], + "(2, 1, 1, 0, 0)": [ + 0.007418340685015537, + 0.0, + 0.0 + ], + "(5, 1, 1, 5, 0)": [ + 3.993346037475164, + 4.02595780895983, + 3.958113797811793 + ], + "(5, 1, 0, 5, 0)": [ + 3.8704975586426262, + 3.829895425894513, + 3.8484935557684774 + ], + "(5, 0, 0, 3, 0)": [ + 7.206895605095726, + 6.361422712879813, + 6.382153739961362 + ], + "(6, 1, 1, 3, 0)": [ + 39.16516514968102, + 0.0, + 13.874993945259016 + ], + "(6, 1, 0, 3, 0)": [ + 189.25859455848735, + 0.0, + 29.91897931545399 + ], + "(6, 1, 0, 2, 0)": [ + 17.158262188709795, + 1.9514931582028616, + 0.0 + ], + "(6, 0, 0, 1, 0)": [ + 25.614850876813907, + 0.0, + 6.3154150574104975 + ], + "(6, 0, 0, 1, 1)": [ + 17.352813028917932, + 2.580182984411188, + 2.140465724900898 + ], + "(6, 0, 1, 1, 1)": [ + 31.2334214211808, + 4.9985873791887645, + 0.0 + ], + "(6, 0, 1, 2, 1)": [ + 29.335273825224643, + 14.06297569453392, + 21.55271341124275 + ], + "(6, 1, 1, 3, 1)": [ + 23.25511720982458, + 16.649397764488587, + 19.90803592789075 + ], + "(6, 1, 0, 3, 1)": [ + 26.80171996379373, + 0.0, + 18.557566348591283 + ], + "(5, 0, 0, 2, 0)": [ + 6.128584915294166, + 4.74501528893155, + 5.71323506758402 + ], + "(5, 0, 1, 1, 0)": [ + 5.433105157444771, + 6.061572284699453, + 5.812568088606896 + ], + "(5, 0, 1, 2, 0)": [ + 9.299494331173243, + 4.798479530028196, + 5.0529263281147205 + ], + "(5, 1, 1, 1, 0)": [ + 5.411730171435421, + 5.336890832671549, + 3.6940890253938563 + ], + "(6, 1, 0, 4, 1)": [ + 46.0361726410884, + 185.7281789946067, + 24.898756482539923 + ], + "(6, 0, 0, 4, 1)": [ + 240.5192583338726, + 19.671210374109826, + 32.44011863831897 + ], + "(6, 0, 1, 4, 1)": [ + 217.5041757792318, + 36.04632868615748, + 35.48728683017791 + ], + "(6, 1, 0, 5, 0)": [ + 43.831876602918115, + 27.8889817150394, + 41.01229025098874 + ], + "(6, 0, 0, 5, 0)": [ + 37.31076342912965, + 72.17808547747593, + 93.69829260586646 + ], + "(6, 0, 1, 5, 0)": [ + 30.225634154900494, + 57.043973946193184, + 19.12245464571805 + ], + "(6, 1, 1, 5, 0)": [ + 22.31531135227297, + 31.702532410298858, + 24.52593989472061 + ], + "(6, 1, 0, 4, 0)": [ + 51.69128813332314, + 43.10494568818312, + 31.630799280419943 + ], + "(6, 0, 0, 4, 0)": [ + 118.94112981801939, + 21.91343878965308, + 0.0 + ], + "(6, 0, 1, 4, 0)": [ + 43.63117424617634, + 72.1459389593239, + 244.531009560538 + ], + "(6, 1, 1, 4, 1)": [ + 40.439091364812, + 62.86471099901328, + 28.21579086570153 + ], + "(5, 1, 1, 2, 0)": [ + 6.0763041725669105, + 3.2256356687419245, + 1.9883725448540135 + ], + "(2, 1, 0, 0, 0)": [ + 0.04124880442854498, + 0.0, + 0.0 + ], + "(6, 1, 1, 4, 0)": [ + 33.12059537789176, + 36.79265587145829, + 33.0038579587303 + ], + "(6, 0, 0, 3, 0)": [ + 6.610554709899747, + 0.0, + 0.0 + ], + "(6, 0, 0, 2, 0)": [ + 19.103288763514552, + 1.7250724331149645, + 0.0 + ], + "(6, 0, 1, 2, 0)": [ + 11.45538182633565, + 0.0, + 0.0 + ], + "(6, 0, 1, 1, 0)": [ + 36.245136836621846, + 0.0, + 0.0 + ], + "(6, 1, 1, 1, 0)": [ + 51.319980637299345, + 10.223271940420567, + 3.1856529540347966 + ], + "(6, 1, 1, 1, 1)": [ + 28.43873600008985, + 3.5188606871528423, + 5.206377154699371 + ], + "(6, 1, 0, 1, 1)": [ + 254.43695151567147, + 0.0, + 4.383567406279885 + ], + "(6, 1, 0, 2, 1)": [ + 4.645323738237858, + 41.06083981486784, + 0.0 + ], + "(6, 0, 0, 2, 1)": [ + 34.28350285678632, + 34.97694231722596, + 0.0 + ], + "(6, 0, 1, 3, 1)": [ + 23.303998337366593, + 23.665273191046108, + 23.58001108802269 + ], + "(4, 1, 1, 0, 0)": [ + 0.32071104927797534, + 0.0, + 0.0 + ], + "(6, 0, 0, 3, 1)": [ + 29.960345351525937, + 12.223389700753652, + 14.174347922430263 + ], + "(2, 0, 1, 0, 0)": [ + 0.023708233987275087, + 0.0, + 0.0 + ], + "(6, 1, 1, 2, 1)": [ + 32.770471384636984, + 0.0, + 0.0 + ], + "(6, 1, 1, 2, 0)": [ + 0.0, + 8.425163932932508, + 0.0 + ], + "(6, 1, 0, 1, 0)": [ + 11.400227838868048, + 2.5621867842074355, + 1.6266017364584542 + ], + "(6, 0, 1, 3, 0)": [ + 410.21798252237346, + 103.80547176610862, + 0.0 + ] + } +} \ No newline at end of file diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py index dd115f7..0ad178a 100644 --- a/src/crane_controller/q_agent.py +++ b/src/crane_controller/q_agent.py @@ -283,7 +283,7 @@ def dump_results( json.dump(content, _f, indent=3) logger.info("Updated q_values saved to %s", _filename.resolve()) - def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.ndarray]: + def read_dumped(self, filename: str | Path | None = None) -> defaultdict[tuple[int, ...], np.ndarray]: """Read a Q-values dict from a JSON file. Parameters @@ -296,17 +296,24 @@ def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.n defaultdict[tuple[int, ...], np.ndarray] Loaded Q-values mapping observation tuples to action-value arrays. """ - path = Path(filename) - with path.open(encoding="utf-8") as _f: - from_dump = json.load(_f) - self.previous_steps = int(from_dump["q_agent"]["steps"]) - self.epsilon = float(from_dump["q_agent"].get("epsilon", 1.0)) q_values: defaultdict[tuple[int, ...], np.ndarray] = defaultdict( lambda: np.array((0.0,) * self.env.action_space.n, float) # type: ignore[attr-defined,type-var] ) - assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" - for k, v in from_dump["q_values"].items(): - q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v}) + if filename is None and self.filename is None: # there is no file to read. Return empty defautdict + pass + else: + if filename is not None: + path = Path(filename) + elif self.filename is not None: + path = Path(self.filename) + + with path.open(encoding="utf-8") as _f: + from_dump = json.load(_f) + self.previous_steps = int(from_dump["q_agent"]["steps"]) + self.epsilon = float(from_dump["q_agent"].get("epsilon", 1.0)) + assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" + for k, v in from_dump["q_values"].items(): + q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v}) return q_values def analyse_training(self, window: int = 500) -> None: diff --git a/tests/test_q.py b/tests/test_q.py index fd6fa05..6a62eca 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -32,7 +32,9 @@ def test_q_analyse(crane: Callable[..., Crane], *, show: bool) -> None: crane, discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=Path("q_trained.json"), use_file='r') + assert Path("q_trained.json").exists(), "File 'q_trained.json' not found" + agent = QLearningAgent(env, filename=Path("q_trained.json"), use_file="r") + agent.q_values = agent.read_dumped() for k, v in agent.q_values.items(): assert len(k) == 5, len(v) == 3 for pos in (0, 1): @@ -57,12 +59,12 @@ def test_intervals(crane: Callable[..., Crane]): discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=save_path, use_file='w') + agent = QLearningAgent(env, filename=save_path, use_file="w") for i in range(10): _ = env.reset(seed=i + 1) agent.do_episodes(n_episodes=2, max_steps=100) if i == 0: - agent = QLearningAgent(env, filename=save_path, use_file='rw') + agent = QLearningAgent(env, filename=save_path, use_file="rw") logger.info(f"Model saved to {save_path}") @@ -72,6 +74,8 @@ def test_intervals(crane: Callable[..., Crane]): import pytest + from crane_controller.crane_factory import build_crane # noqa: F401 + retcode = pytest.main(["-rP -s -v", __file__]) assert retcode == 0, f"Return code {retcode}" os.chdir(Path(__file__).parent.absolute() / "test_working_directory") From b9ff6d97aa79ab67321619382a50bf423fa23313 Mon Sep 17 00:00:00 2001 From: Eisinger Date: Thu, 30 Apr 2026 15:53:01 +0200 Subject: [PATCH 4/4] Include read epsilon_decay when continuing a training, overwriting the value from the (new) agent. epsilon_decay default changed to 1e-4. New results from q-learning included. --- models/q_anti-pendulum.json | 271 +++++++++++++++++--------------- scripts/use_q_ide.py | 4 +- src/crane_controller/q_agent.py | 3 +- 3 files changed, 147 insertions(+), 131 deletions(-) diff --git a/models/q_anti-pendulum.json b/models/q_anti-pendulum.json index 80143f0..85ab37d 100644 --- a/models/q_anti-pendulum.json +++ b/models/q_anti-pendulum.json @@ -1,6 +1,6 @@ { - "start-training": "30.04.2026 08:29:32", - "end-training": "30.04.2026 08:29:54", + "start-training": "30.04.2026 11:49:22", + "end-training": "30.04.2026 13:40:30", "pendulum": { "wire-length": "10.0", "wire-q-factor": "50.0", @@ -13,219 +13,234 @@ "q_agent": { "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_anti-pendulum.json", "use_file": "rw", - "episodes": "10", - "steps": "40020", + "episodes": "2000", + "steps": "6043020", "learning_rate": "0.1", "discount_factor": "0.95", "epsilon-decay": "0.001", "final-epsilon": "0.1", - "epsilon": "0.98" + "epsilon": "0.1" }, "q_values": { "(0, 0, 0, 1, 1)": [ - -0.16705581244794693, - -0.1282253386578821, - -0.09732839899973421 + -0.13665139042486474, + -0.10179172074284769, + -0.03615620048379082 ], "(0, 0, 0, 0, 0)": [ - -0.37890841049813784, - -0.3221144173089857, - -0.3983817365055842 + -0.3855117367492396, + -0.3977828071547589, + -0.3073444787070865 ], "(0, 0, 0, 1, 0)": [ - -0.2664435349978729, - -0.17001690959731303, - -0.2325842595231777 + -0.29353749709276, + -0.01245142162171615, + -0.18353348032889846 ], "(0, 0, 1, 1, 0)": [ - -0.15064415891825617, - -0.2511495308113964, - -0.25691496256409363 + -0.22450214202690819, + -0.011561431340910916, + -0.1607899174877803 ], "(0, 1, 1, 1, 0)": [ - -0.20723301719553744, - -0.2388453016403667, - -0.16750979906585578 + -0.18635102620612973, + -0.014918333953243068, + -0.19796550577122185 ], "(0, 1, 0, 1, 0)": [ - -0.21424728224181977, - -0.22083742688011837, - -0.2194005545131396 + -0.2145730606952124, + -0.08559682487920808, + -0.23241014511502056 ], "(0, 1, 0, 2, 0)": [ - -0.048696104408199575, - -0.05034525692645991, - -0.05258803579860084 + -0.15422746939849802, + -0.10111065548903554, + -0.14008905307453026 ], "(0, 0, 0, 2, 0)": [ - -0.052053372911474195, - -0.05245228578367619, - -0.056669454625284825 + -0.11621741035295982, + -0.136189021105371, + -0.04124338407154954 ], "(0, 0, 1, 2, 0)": [ - -0.04772137252269057, - -0.05881755142692484, - -0.054423243178985496 + -0.02277006644413333, + -0.13713080583764461, + -0.14179060788269038 ], "(0, 1, 1, 2, 0)": [ - -0.05278649224808442, - -0.07632187326419965, - -0.060938333453418286 + -0.1349583618360337, + -0.16924737718439545, + -0.1680842291860991 ], "(0, 0, 0, 3, 0)": [ - -0.24369975025229282, - -0.24695892939305783, - -0.24911621770298206 + -0.10858577021970485, + -0.12418990186300555, + -0.09049991808047572 ], "(0, 0, 1, 3, 0)": [ - -0.24027345169278405, - -0.24031167906999237, - -0.24314101938021784 + -0.07941711160006829, + -0.0857258888169698, + -0.08216417483037675 ], "(0, 1, 1, 3, 0)": [ - -0.24861198983323804, - -0.23256743828903015, - -0.23006936132916378 + -0.09088298902159943, + -0.08990627639447457, + -0.09077903505343787 ], "(0, 1, 0, 3, 0)": [ - -0.23803043434121016, - -0.23300086647016396, - -0.252096348380666 + -0.05470242304134086, + -0.05140006893860253, + -0.04592915973935223 ], "(0, 0, 1, 4, 0)": [ - -0.29577889270947405, - -0.2981742809792015, - -0.29791983068372924 + -0.1435545874574983, + -0.15231057248465216, + -0.15086957427182499 ], "(0, 1, 1, 4, 0)": [ - -0.2958234315862551, - -0.30645418768275684, - -0.30661740666730875 + -0.12139056771543123, + -0.21535329215297502, + -7.412694388042563 ], "(0, 1, 0, 4, 0)": [ - -0.3029623394714542, - -0.3018348837313605, - -0.3286817686032363 + -0.134709312919711, + -0.2503011744582585, + -47.30780376798799 ], "(0, 0, 0, 4, 0)": [ - -0.29781883698756567, - -0.2940971933423739, - -0.29733487917694584 + -0.14326072107267074, + -0.1371132606497332, + -0.14073483467765396 ], "(0, 1, 0, 5, 0)": [ - -0.6329437113580243, - -0.6297747244548322, - -0.628882264517245 + -0.14921092460534552, + -0.5336663925124036, + -0.413160457318822 ], "(0, 0, 0, 5, 0)": [ - -0.6277272058045739, - -0.6384550155569291, - -0.6303149829373899 + -0.16817182471742664, + -0.1831098089787067, + -0.16880169485719884 ], "(0, 0, 1, 5, 0)": [ - -0.6396081577986387, - -0.6334955382726669, - -0.6407924717823796 + -0.22330769106781206, + -0.164930407464341, + -0.1585054303950049 ], "(0, 1, 1, 5, 0)": [ - -0.6400694360313026, - -0.633976678537013, - -0.6330865725095266 + -0.1408499847124823, + -0.23484133236611604, + -0.3455137865525159 ], "(0, 0, 1, 1, 1)": [ - -0.10369497244449818, - -0.11119982054398986, - -0.12993117280633998 + -0.10777995324012056, + -0.01942676901122676, + -0.08416528524778964 ], "(0, 1, 1, 1, 1)": [ - -0.10542409331231692, - -0.11768594196503504, - -0.07456918970206615 + -0.020055193708830747, + -0.1009643749752495, + -0.1819727036838127 ], "(0, 1, 0, 1, 1)": [ - -0.14656019585858246, - -0.1326055167833163, - -0.12885325016738228 + -0.09581720731770141, + -0.17129694552682964, + -0.007358568090630821 ], "(0, 1, 0, 2, 1)": [ - -0.11967009468971596, - -0.12216127394629554, - -0.1429122098640283 + -0.05629874689761112, + -0.05997232601255042, + -0.013786458351433332 ], "(0, 0, 0, 2, 1)": [ - -0.128347455422042, - -0.1487387200161241, - -0.16517981243223914 + -0.19923861861268571, + -0.02587201671587543, + -0.13903271453636865 ], "(0, 0, 1, 2, 1)": [ - -0.2682879902850379, - -0.19492371557724883, - -0.19236122464754474 + -0.1204099926263145, + -0.02639568487280513, + -0.09331258152285944 ], "(0, 1, 1, 2, 1)": [ - -0.15105097271312573, - -0.20727413471853204, - -0.20105859767753256 + -0.014954709519526446, + -0.08954397697589433, + -0.10256389706199417 ], "(0, 0, 0, 3, 1)": [ - -0.11746821209617946, - -0.12239144215834397, - -0.13720619491970695 + -0.07597091419941952, + -0.08521568787348147, + -0.0767067975458054 ], "(0, 0, 1, 3, 1)": [ - -0.11651570970036011, - -0.11862687420946329, - -0.1260539064368246 + -0.07768760972828949, + -0.15556947861941556, + -0.10213663795530971 ], "(0, 1, 1, 3, 1)": [ - -0.1450901027213944, - -0.12420631784497436, - -0.12104519333576104 + -0.006677023045033242, + -0.12168865598235325, + -0.11006819308277363 ], "(0, 1, 0, 3, 1)": [ - -0.13278128970879577, - -0.14433471293373654, - -0.14787321638372558 + -0.0697119229410046, + -0.09423853202784556, + -0.006601675203503434 ], "(0, 0, 1, 4, 1)": [ - -0.1720169164100566, - -0.1697573068991259, - -0.18563861409468002 + -0.18849044357597325, + -0.06091103734585244, + -0.16438848786875762 ], "(0, 1, 1, 4, 1)": [ - -0.19331603592176078, - -0.16465692724409528, - -0.1870968740626517 + -0.019168554778457022, + -0.18565904626703944, + -0.16910122873025832 ], "(0, 1, 0, 4, 1)": [ - -0.16839244888515778, - -0.15548076071108868, - -0.15890884729639157 + -0.12186343148374591, + -0.1353245425973399, + -0.010326133307475296 ], "(0, 0, 0, 4, 1)": [ - -0.16755688928521928, - -0.1598773466377167, - -0.17156379062820235 + -0.18803040883070493, + -0.18806111288250146, + -0.07370698948029014 ], "(0, 1, 1, 5, 1)": [ - -1.0698086689883646, - -1.095931248568509, - -1.1160177171578312 + -0.28044334846311725, + -0.34174178181292564, + -0.08571555802193095 ], "(0, 1, 0, 5, 1)": [ - -1.1027542984743757, - -1.0928909342292816, - -1.0833097769293856 + -0.3678985661744148, + -0.09066726725097075, + -0.506598781689325 ], "(0, 0, 0, 5, 1)": [ - -1.088841306513128, - -1.1035400078582027, - -1.1176895277220635 + -0.4346946110999083, + -0.08965319306320729, + -0.39860453597868206 ], "(0, 0, 1, 5, 1)": [ - -1.0994216179219447, - -1.1237160556971877, - -1.1090547917156828 + -0.5390323341639348, + -0.07573563164410979, + -0.4365671199679866 + ], + "(0, 1, 0, 0, 0)": [ + -0.10155517152045257, + -0.09703513305872391, + -0.13317800430561305 + ], + "(0, 0, 1, 0, 0)": [ + -0.3468722158158501, + -0.3601257867654982, + -0.35361639690141167 + ], + "(0, 1, 1, 0, 0)": [ + -0.1740876940193742, + -0.16639211647507102, + -0.1595306814502831 ] } } \ No newline at end of file diff --git a/scripts/use_q_ide.py b/scripts/use_q_ide.py index 4edb17e..883abc7 100644 --- a/scripts/use_q_ide.py +++ b/scripts/use_q_ide.py @@ -52,7 +52,7 @@ def do_use(kwargs: dict[str, Any]) -> None: use_file = kwargs.get("use_file", "r") agent = QLearningAgent(env, filename=filename, use_file=use_file) agent.do_episodes(n_episodes=kwargs.get("episodes", 100), max_steps=kwargs.get("steps", 5000)) - if filename is not None: + if filename is not None and 'w' in agent.use_file: LOGGER.info(f"Model saved to {filename}") @@ -84,7 +84,7 @@ def _args(base: dict[str, Any], upd: dict[str, Any]) -> dict[str, Any]: "t_fac": 0.0, } # ruff: disable[ERA001] ## we intentionally work with commenting out lines here - args = _args(anti, {"episodes": 10}) # anti-pendulum training + args = _args(anti, {"episodes": 2000}) # anti-pendulum (additional) training # args = _args(pend, {'episodes':10000}) # pendulum training # args = _args( anti, {"episodes": 10, "render": "plot","use_file":'r'}) # show anti-pendulum results # args = _args( pend, {"episodes": 10, "render": "plot", "use_file":'r'}) # show start pendulum results diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py index 0ad178a..2b824d1 100644 --- a/src/crane_controller/q_agent.py +++ b/src/crane_controller/q_agent.py @@ -80,7 +80,7 @@ def __init__( self, env: AntiPendulumEnv, learning_rate: float = 0.1, - epsilon_decay: float = 1e-3, + epsilon_decay: float = 1e-4, final_epsilon: float = 0.1, discount_factor: float = 0.95, filename: Path | None = None, @@ -311,6 +311,7 @@ def read_dumped(self, filename: str | Path | None = None) -> defaultdict[tuple[i from_dump = json.load(_f) self.previous_steps = int(from_dump["q_agent"]["steps"]) self.epsilon = float(from_dump["q_agent"].get("epsilon", 1.0)) + self.epsilon_decay = float(from_dump["q_agent"].get("epsilon", 1e-4)) assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" for k, v in from_dump["q_values"].items(): q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v})