diff --git a/models/q_anti-pendulum.json b/models/q_anti-pendulum.json index e5468f5..85ab37d 100644 --- a/models/q_anti-pendulum.json +++ b/models/q_anti-pendulum.json @@ -1,238 +1,246 @@ { - "date": "29.04.2026 06:53:49", + "start-training": "30.04.2026 11:49:22", + "end-training": "30.04.2026 13:40:30", "pendulum": { - "start_speed": "1.0", - "render_mode": "none", + "wire-length": "10.0", + "wire-q-factor": "50.0", + "reward-factors": "(1.0, 0.0015, 0.0)", + "acceleration": "0.1", + "step-size": "0.1", + "observations-discretization": "{'pos': (0, 1), 'speed': (0, 1), 'distance': (0.0, 1.0, 2.0, 5.0, 10.0, 20.0), 'sector': (0, 1), 'energies': (np.float64(0.0), np.float64(0.014941105158016455), np.float64(0.373300117199762), np.float64(1.4903594295023934), np.float64(5.916153900902383), np.float64(13.142907888746564), np.float64(98.1))}", "reward_limit": "0.0" }, "q_agent": { - "use_trained": "True", - "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_trained.json", - "episodes": "6000", - "steps": "30006000", + "filename": "C:\\Users\\eis\\Documents\\Projects\\Simulation_Model_Assurance\\osp\\packages\\crane-controller\\models\\q_anti-pendulum.json", + "use_file": "rw", + "episodes": "2000", + "steps": "6043020", "learning_rate": "0.1", - "discount_factor": "0.95" + "discount_factor": "0.95", + "epsilon-decay": "0.001", + "final-epsilon": "0.1", + "epsilon": "0.1" }, "q_values": { "(0, 0, 0, 1, 1)": [ - -0.7127999098776818, - -0.7204031480080048, - -0.7412199306092444 + -0.13665139042486474, + -0.10179172074284769, + -0.03615620048379082 ], "(0, 0, 0, 0, 0)": [ - -0.39882058710042717, - -0.43180066871885253, - -0.36789168710316733 + -0.3855117367492396, + -0.3977828071547589, + -0.3073444787070865 ], "(0, 0, 0, 1, 0)": [ - -0.49644387699601483, - -0.476832491697103, - -0.4808429865234992 + -0.29353749709276, + -0.01245142162171615, + -0.18353348032889846 ], "(0, 0, 1, 1, 0)": [ - -0.3645407668594485, - -0.34941296493381946, - -0.19816466371287308 + -0.22450214202690819, + -0.011561431340910916, + -0.1607899174877803 ], "(0, 1, 1, 1, 0)": [ - -0.27335521972045285, - -0.31790379183091966, - -0.2034872242749172 + -0.18635102620612973, + -0.014918333953243068, + -0.19796550577122185 ], "(0, 1, 0, 1, 0)": [ - -0.6087312502217588, - -0.6754940805998394, - -0.6535399837815964 + -0.2145730606952124, + -0.08559682487920808, + -0.23241014511502056 ], "(0, 1, 0, 2, 0)": [ - -0.3949208069494323, - -0.4631147311708943, - -0.4776412098931605 + -0.15422746939849802, + -0.10111065548903554, + -0.14008905307453026 ], "(0, 0, 0, 2, 0)": [ - -0.2712981364627587, - -0.17654049561221843, - -0.2508326205458066 + -0.11621741035295982, + -0.136189021105371, + -0.04124338407154954 ], "(0, 0, 1, 2, 0)": [ - -0.48498802268403846, - -0.4395215670942963, - -0.4722944229115358 + -0.02277006644413333, + -0.13713080583764461, + -0.14179060788269038 ], "(0, 1, 1, 2, 0)": [ - -0.4990197857734165, - -0.5327879501542365, - -0.4941112778002146 + -0.1349583618360337, + -0.16924737718439545, + -0.1680842291860991 ], "(0, 0, 0, 3, 0)": [ - -0.5387880192090378, - -0.4784712525147248, - -0.517952873228921 + -0.10858577021970485, + -0.12418990186300555, + -0.09049991808047572 ], "(0, 0, 1, 3, 0)": [ - -0.9096919737168466, - -0.9127452620968198, - -0.9420045562304813 + -0.07941711160006829, + -0.0857258888169698, + -0.08216417483037675 ], "(0, 1, 1, 3, 0)": [ - -0.8667830882886871, - -0.8757935260356404, - -0.8758730446292982 + -0.09088298902159943, + -0.08990627639447457, + -0.09077903505343787 ], "(0, 1, 0, 3, 0)": [ - -0.43694463364384284, - -0.32317351557573637, - -0.37671045435695716 + -0.05470242304134086, + -0.05140006893860253, + -0.04592915973935223 ], "(0, 0, 1, 4, 0)": [ - -0.9306066987431686, - -0.9344436974699757, - -0.9655748330701439 + -0.1435545874574983, + -0.15231057248465216, + -0.15086957427182499 ], "(0, 1, 1, 4, 0)": [ - -1.009825188680733, - -1.016868953064729, - -0.9957558578136729 + -0.12139056771543123, + -0.21535329215297502, + -7.412694388042563 ], "(0, 1, 0, 4, 0)": [ - -1.0369014461449961, - -1.0197443807886495, - -0.9976879186706535 + -0.134709312919711, + -0.2503011744582585, + -47.30780376798799 ], "(0, 0, 0, 4, 0)": [ - -0.8997996021390884, - -0.8913054855090478, - -0.8905808268698551 + -0.14326072107267074, + -0.1371132606497332, + -0.14073483467765396 ], "(0, 1, 0, 5, 0)": [ - -1.5826848759769268, - -2.7445575744690793, - -2.418367271258592 + -0.14921092460534552, + -0.5336663925124036, + -0.413160457318822 ], "(0, 0, 0, 5, 0)": [ - -1.6683128963172371, - -2.251311793007559, - -2.3851805260338708 + -0.16817182471742664, + -0.1831098089787067, + -0.16880169485719884 ], "(0, 0, 1, 5, 0)": [ - -2.2373735329153392, - -1.540495957571949, - -2.76737403946213 + -0.22330769106781206, + -0.164930407464341, + -0.1585054303950049 ], "(0, 1, 1, 5, 0)": [ - -2.3314084980057013, - -1.5138347954693265, - -3.01917937692279 - ], - "(0, 1, 0, 1, 1)": [ - -0.261055476317184, - -0.27846357082465933, - -0.2774020188584469 + -0.1408499847124823, + -0.23484133236611604, + -0.3455137865525159 ], "(0, 0, 1, 1, 1)": [ - -0.1342852547795026, - -0.2470169520959763, - -0.2126480225045462 + -0.10777995324012056, + -0.01942676901122676, + -0.08416528524778964 ], "(0, 1, 1, 1, 1)": [ - -0.2797380701937499, - -0.2945979948624919, - -0.24076792235861236 + -0.020055193708830747, + -0.1009643749752495, + -0.1819727036838127 ], - "(0, 1, 1, 2, 1)": [ - -0.2728844361768713, - -0.13972184601087123, - -0.3275551939924255 + "(0, 1, 0, 1, 1)": [ + -0.09581720731770141, + -0.17129694552682964, + -0.007358568090630821 ], "(0, 1, 0, 2, 1)": [ - -0.455039876626072, - -0.43296129057936117, - -0.4316836608934558 + -0.05629874689761112, + -0.05997232601255042, + -0.013786458351433332 ], "(0, 0, 0, 2, 1)": [ - -0.5686315534261919, - -0.5538077807341172, - -0.5496103554748915 + -0.19923861861268571, + -0.02587201671587543, + -0.13903271453636865 + ], + "(0, 0, 1, 2, 1)": [ + -0.1204099926263145, + -0.02639568487280513, + -0.09331258152285944 + ], + "(0, 1, 1, 2, 1)": [ + -0.014954709519526446, + -0.08954397697589433, + -0.10256389706199417 ], "(0, 0, 0, 3, 1)": [ - -0.3966116309786595, - -0.3317568230750391, - -0.30858102473455185 + -0.07597091419941952, + -0.08521568787348147, + -0.0767067975458054 ], "(0, 0, 1, 3, 1)": [ - -0.8315655206780324, - -0.8333814435452305, - -0.7756715486070018 + -0.07768760972828949, + -0.15556947861941556, + -0.10213663795530971 ], "(0, 1, 1, 3, 1)": [ - -0.9831842749294987, - -0.9782200055298229, - -0.9930127344944251 + -0.006677023045033242, + -0.12168865598235325, + -0.11006819308277363 ], "(0, 1, 0, 3, 1)": [ - -0.6539627507980373, - -0.5852573030361744, - -0.6763166111516389 - ], - "(0, 1, 0, 4, 1)": [ - -0.9026947631459618, - -0.9407969095587947, - -0.965156293555089 - ], - "(0, 0, 0, 4, 1)": [ - -1.0295607537932239, - -1.0110078430982656, - -1.0217477290303583 + -0.0697119229410046, + -0.09423853202784556, + -0.006601675203503434 ], "(0, 0, 1, 4, 1)": [ - -1.0198724273521527, - -1.0273376264121616, - -1.0372176976745933 + -0.18849044357597325, + -0.06091103734585244, + -0.16438848786875762 ], "(0, 1, 1, 4, 1)": [ - -0.8367514960793316, - -0.8072150440466215, - -0.7999324124740165 + -0.019168554778457022, + -0.18565904626703944, + -0.16910122873025832 ], - "(0, 0, 1, 5, 1)": [ - -2.826052632887979, - -2.8216800501972115, - -2.8339346737355218 + "(0, 1, 0, 4, 1)": [ + -0.12186343148374591, + -0.1353245425973399, + -0.010326133307475296 + ], + "(0, 0, 0, 4, 1)": [ + -0.18803040883070493, + -0.18806111288250146, + -0.07370698948029014 ], "(0, 1, 1, 5, 1)": [ - -2.769926115050464, - -2.7896094930838444, - -2.784544336312168 + -0.28044334846311725, + -0.34174178181292564, + -0.08571555802193095 ], "(0, 1, 0, 5, 1)": [ - -2.7928766121585795, - -2.7887103458307694, - -2.7938313222414037 + -0.3678985661744148, + -0.09066726725097075, + -0.506598781689325 ], "(0, 0, 0, 5, 1)": [ - -2.8104767031828572, - -2.8062208641671322, - -2.8101316497840574 + -0.4346946110999083, + -0.08965319306320729, + -0.39860453597868206 ], - "(0, 0, 1, 2, 1)": [ - -0.5213427535062244, - -0.5910962186115551, - -0.47068284911064734 - ], - "(0, 0, 1, 0, 0)": [ - -0.41286468970384216, - -0.42684989269569895, - -169.29165129913554 + "(0, 0, 1, 5, 1)": [ + -0.5390323341639348, + -0.07573563164410979, + -0.4365671199679866 ], "(0, 1, 0, 0, 0)": [ - -0.11933872055277986, - -0.1195010821067479, - -0.1499420269134481 + -0.10155517152045257, + -0.09703513305872391, + -0.13317800430561305 + ], + "(0, 0, 1, 0, 0)": [ + -0.3468722158158501, + -0.3601257867654982, + -0.35361639690141167 ], "(0, 1, 1, 0, 0)": [ - -0.20624980702980253, - -0.22645535674627537, - -0.18943028486599597 + -0.1740876940193742, + -0.16639211647507102, + -0.1595306814502831 ] } } \ No newline at end of file diff --git a/models/q_pendulum.json b/models/q_pendulum.json new file mode 100644 index 0000000..24b3925 --- /dev/null +++ b/models/q_pendulum.json @@ -0,0 +1,1258 @@ +{ + "date": "30.04.2026 05:15:03", + "pendulum": { + "start_speed": "0.0", + "render_mode": "none", + "reward_limit": "1000.0" + }, + "q_agent": { + "use_trained": "True", + "filename": "/home/se/osp/packages/crane-controller/models/q_pendulum.json", + "episodes": "10000", + "steps": "19991482", + "learning_rate": "0.1", + "discount_factor": "0.95" + }, + "q_values": { + "(1, 0, 0, 1, 1)": [ + -0.0015285453382582262, + -0.0019110699513047503, + -0.000421685773718873 + ], + "(1, 0, 0, 0, 0)": [ + 0.0016483778628561532, + 0.0009174558585016858, + 0.0001847120557717193 + ], + "(1, 1, 1, 1, 0)": [ + -0.00011618763900436013, + 0.0021056726444264756, + -0.0009250088093969632 + ], + "(2, 0, 0, 1, 1)": [ + 0.30381525310187946, + 0.13268574959167037, + 0.08968956867959964 + ], + "(2, 0, 1, 1, 1)": [ + 0.07021166494175266, + 0.11544068028302096, + 0.2655879939625879 + ], + "(2, 1, 1, 2, 1)": [ + 0.0468700802191448, + 0.24296967824838114, + 0.057465378203944484 + ], + "(2, 1, 0, 2, 1)": [ + 0.247316722251548, + 0.07724841470718162, + 0.04444885410818501 + ], + "(2, 0, 0, 2, 1)": [ + 0.03151075349886775, + 0.1198098266764795, + 0.026761926901275757 + ], + "(2, 0, 0, 3, 1)": [ + 0.24750790139852671, + -0.00742543468925339, + 0.003439802770232106 + ], + "(2, 0, 1, 3, 1)": [ + 0.013165205849978824, + 0.020782332264133313, + 0.16600804643720815 + ], + "(2, 1, 1, 3, 1)": [ + 0.006877018618653728, + 0.017372356548709385, + 0.1822037784645431 + ], + "(2, 1, 0, 3, 1)": [ + 0.011260461429967967, + -0.06408511029671254, + 0.16832760241231254 + ], + "(2, 1, 0, 4, 1)": [ + 0.30877072872600403, + 0.17712588766549373, + -0.11992985889469697 + ], + "(2, 0, 0, 4, 1)": [ + 0.07171333312506538, + 0.3196709073875426, + -0.010305375893922684 + ], + "(2, 0, 1, 4, 1)": [ + 0.0460495219813757, + -0.025981954700818086, + 0.35798806472060796 + ], + "(2, 1, 1, 4, 1)": [ + 0.28024071171042864, + -0.08081625384759063, + -0.06021051977246365 + ], + "(2, 0, 1, 5, 1)": [ + 0.11199272160151952, + 0.34003948184320887, + 0.15555212679650068 + ], + "(2, 1, 1, 5, 1)": [ + -0.06388301321252567, + -0.27230305038177305, + 0.3227408520413516 + ], + "(2, 1, 0, 5, 1)": [ + 0.33544847358707, + -0.014778838303531092, + -0.13036007932481597 + ], + "(2, 0, 0, 5, 1)": [ + 0.3368562070401561, + -0.06442413772899995, + 0.015851813972102308 + ], + "(1, 1, 1, 5, 1)": [ + -0.873610573166523, + -0.8795035200229879, + -0.8813170678799008 + ], + "(1, 1, 0, 5, 1)": [ + -0.8588580712968287, + -0.8603630825591916, + -0.8545640919171729 + ], + "(1, 0, 0, 5, 1)": [ + -0.876058331400781, + -0.8812733294953802, + -0.8781109627654614 + ], + "(1, 0, 1, 5, 1)": [ + -0.8754681221299438, + -0.8811483806653537, + -0.8795342715024759 + ], + "(2, 1, 1, 1, 1)": [ + 0.2575141349768899, + 0.14721086451092427, + 0.13938193635472462 + ], + "(2, 0, 1, 2, 1)": [ + 0.049090745483137355, + 0.035345162836350655, + 0.1850890582095332 + ], + "(3, 0, 0, 3, 1)": [ + 0.408660876661018, + 0.42698247924321825, + 0.40792534233546174 + ], + "(3, 0, 1, 3, 1)": [ + 0.3600676621922624, + 0.3929021532144681, + 0.461063559575569 + ], + "(3, 1, 1, 3, 1)": [ + 0.43092477422897313, + 0.4760286469676715, + 0.5329087094566657 + ], + "(3, 1, 0, 3, 1)": [ + 0.4099435292952729, + 0.43676458251571, + 0.5251269059135935 + ], + "(3, 0, 0, 4, 1)": [ + 0.4048621627470891, + 0.4066271465074262, + 0.4043352559362854 + ], + "(3, 0, 1, 4, 1)": [ + 0.39328686645976846, + 0.39556738781196227, + 0.3959743788190199 + ], + "(3, 1, 0, 5, 1)": [ + 0.2918677426270124, + 0.2808605809982236, + 0.3364418740886684 + ], + "(3, 0, 0, 5, 1)": [ + 0.015025569672587752, + 0.08196942899783732, + 0.3327158807318435 + ], + "(3, 0, 1, 5, 1)": [ + 0.15452953096605598, + 0.32706016277878186, + 0.19602635793745557 + ], + "(2, 1, 0, 1, 1)": [ + 0.28308955078023956, + 0.16197169711955506, + 0.11997498874359913 + ], + "(3, 1, 1, 5, 1)": [ + 0.04055530756891616, + 0.2713994534576125, + 0.39683102819432226 + ], + "(4, 0, 0, 5, 1)": [ + 4.061952905462236, + 1.8796422036374005, + 2.6392271976381876 + ], + "(4, 0, 1, 5, 1)": [ + 1.8612065508222904, + 2.103909066858137, + 4.240572243261516 + ], + "(4, 1, 1, 5, 1)": [ + 2.7147234118909207, + 4.2477919190765006, + 2.532801305018678 + ], + "(4, 1, 0, 5, 1)": [ + 5.335646400509759, + 2.3417395159744956, + 2.8731429101445904 + ], + "(1, 0, 1, 1, 1)": [ + -0.003409519735202579, + -0.0013383819217571565, + 0.0002056980625474665 + ], + "(3, 1, 1, 4, 1)": [ + 0.38529100751464557, + 0.3942557035549322, + 0.4431156152070642 + ], + "(3, 0, 0, 2, 1)": [ + 0.5090782560731483, + 0.4798070339288543, + 0.4440706186764382 + ], + "(3, 0, 1, 2, 1)": [ + 0.4031617038672731, + 0.41820257002525474, + 0.4254288577893825 + ], + "(3, 1, 0, 2, 1)": [ + 0.5383778912624413, + 0.5816045650663727, + 0.44907309167118936 + ], + "(3, 0, 0, 1, 1)": [ + 0.3820396000856337, + 0.4244240876422671, + 0.4809947245965916 + ], + "(3, 0, 1, 1, 1)": [ + 0.48309577607918963, + 0.4359746680582087, + 0.47844205657321176 + ], + "(3, 1, 1, 2, 1)": [ + 0.4510595166267413, + 0.5413951132995184, + 0.5396546938623438 + ], + "(3, 1, 0, 4, 1)": [ + 0.4911755694970669, + 0.4273041013300876, + 0.4113438183162773 + ], + "(3, 1, 0, 1, 1)": [ + 0.41206946560587604, + 0.4276691064016972, + 0.3891151329059017 + ], + "(3, 0, 0, 1, 0)": [ + 0.595092804268447, + 0.5709074723626882, + 0.4357015522877032 + ], + "(3, 0, 1, 1, 0)": [ + 0.5048087137826464, + 0.5131171175030076, + 0.5824969529639638 + ], + "(2, 0, 1, 1, 0)": [ + 0.16890785501899347, + 0.07849898959202127, + 0.13065691356841727 + ], + "(2, 1, 1, 1, 0)": [ + -0.0019192156044434555, + 0.11937402910019926, + 0.06218611939997548 + ], + "(2, 1, 0, 1, 0)": [ + 0.15656530313154501, + 0.05852070858351637, + 0.05917317431102258 + ], + "(2, 0, 0, 1, 0)": [ + 0.11925737271077642, + 0.06270232978106371, + 0.016145004918409697 + ], + "(3, 1, 1, 1, 1)": [ + 0.415554313371666, + 0.4251120871709033, + 0.41738534617858564 + ], + "(3, 1, 0, 1, 0)": [ + 0.5175638162159195, + 0.4768692444308892, + 0.4092874554117059 + ], + "(1, 1, 0, 1, 0)": [ + 0.007559388663013664, + -0.07141246975944493, + -0.03218011823701119 + ], + "(1, 1, 1, 1, 1)": [ + -0.006336313500421211, + -0.00722006182105164, + -0.015438778111669823 + ], + "(1, 1, 0, 1, 1)": [ + -0.011637555446352072, + -0.03469500547718182, + -0.06946810336387721 + ], + "(1, 1, 0, 2, 1)": [ + -0.02059287808812973, + -0.020207666113278937, + -0.023751195001154267 + ], + "(1, 0, 0, 2, 1)": [ + -0.03136891469721944, + -0.018845272510248604, + -0.025639212801306065 + ], + "(2, 0, 0, 2, 0)": [ + -0.007405306099820059, + 0.13416722893659513, + -0.003863979660054713 + ], + "(3, 0, 0, 2, 0)": [ + 0.4968529980315005, + 0.6083878273580536, + 0.5107250954251326 + ], + "(3, 0, 1, 2, 0)": [ + 0.4572105904076367, + 0.48251636216964044, + 0.5371402979328573 + ], + "(3, 1, 1, 2, 0)": [ + 0.5834210121161674, + 0.5447607381569435, + 0.6522461652145303 + ], + "(3, 1, 1, 3, 0)": [ + 0.4515300417767746, + 0.45066875186289224, + 0.6830079417945389 + ], + "(3, 1, 0, 3, 0)": [ + 0.5338016897375181, + 0.5414142734086245, + 0.6346264676956274 + ], + "(3, 0, 0, 3, 0)": [ + 0.561690114055678, + 0.5920536892988973, + 0.5543067752647609 + ], + "(3, 0, 1, 3, 0)": [ + 0.5489920980754698, + 0.5660011505184062, + 0.5167225836179886 + ], + "(2, 1, 1, 3, 0)": [ + 0.1494485963494498, + -0.03830396513846329, + 0.015657549848606813 + ], + "(2, 1, 0, 3, 0)": [ + 0.20593271843361607, + 0.09479740257152519, + 0.0959316384523621 + ], + "(2, 1, 0, 2, 0)": [ + 0.13616448787873342, + 0.019560267471690474, + -0.005724736049839602 + ], + "(3, 1, 0, 2, 0)": [ + 0.683590879589438, + 0.5526158769509005, + 0.46954602590689115 + ], + "(1, 0, 0, 1, 0)": [ + 0.03733255481235882, + -0.02849128758664682, + -0.09589936241358511 + ], + "(1, 1, 1, 2, 1)": [ + -0.006083016288581055, + -0.03523377742840035, + -0.019203666274691625 + ], + "(1, 0, 1, 2, 1)": [ + -0.01686144825783039, + -0.008924069806829863, + 0.009536322538022453 + ], + "(1, 0, 1, 3, 1)": [ + -0.034434632147796705, + 0.00024112079172242465, + -0.03197066856336422 + ], + "(1, 1, 1, 3, 1)": [ + -0.01895592863353697, + -0.04411521018370294, + 0.01775191105812379 + ], + "(1, 1, 0, 3, 1)": [ + 0.04225706506102414, + -0.0367929078115716, + -0.04368072481979268 + ], + "(1, 0, 0, 3, 1)": [ + -0.007437638387498356, + -0.044450404397953926, + -0.04469796462300711 + ], + "(1, 1, 1, 4, 1)": [ + -0.09646741076342318, + -0.09875923262296923, + -0.07611277807180256 + ], + "(1, 1, 0, 4, 1)": [ + 0.004840534632955751, + -0.09842756570162219, + -0.09016058064434311 + ], + "(1, 0, 0, 4, 1)": [ + -0.09004032812642203, + -0.07791186134035129, + -0.09719535819232616 + ], + "(1, 0, 1, 4, 1)": [ + -0.05966874855523139, + 0.00806343851278417, + -0.0359359869255538 + ], + "(1, 0, 1, 1, 0)": [ + -0.020070667836140862, + -0.00023703353233769307, + -0.019699742890545824 + ], + "(1, 0, 1, 2, 0)": [ + -0.08614089449335621, + -0.07468383088946322, + -0.028281994617095248 + ], + "(1, 1, 1, 2, 0)": [ + -0.020887757859586534, + -0.05007737176424587, + 0.013443648795931901 + ], + "(1, 1, 0, 2, 0)": [ + -0.25387563088246234, + -0.008011626613088369, + -0.18283538351700016 + ], + "(1, 1, 0, 3, 0)": [ + -0.07464578664914494, + -0.043449407581580696, + -0.11463260589847274 + ], + "(1, 0, 0, 3, 0)": [ + -0.14350968802347358, + -0.04083978797784146, + -0.09909284163171486 + ], + "(1, 0, 1, 3, 0)": [ + -0.1370625750166114, + -0.0724585837043522, + -0.0009240473895465921 + ], + "(1, 1, 1, 3, 0)": [ + -0.03626126792752658, + -0.03710154477294096, + 0.0032683572316288146 + ], + "(1, 0, 1, 4, 0)": [ + 0.0014061010064767788, + -0.04452670073239082, + -0.07821196453222311 + ], + "(1, 1, 1, 4, 0)": [ + -0.17474951197312322, + -0.5414578832524665, + -0.22510473565580394 + ], + "(1, 1, 0, 4, 0)": [ + -0.1328897390182134, + -0.034781929591368305, + -0.16080111281963685 + ], + "(1, 0, 0, 4, 0)": [ + -0.1475330350868065, + -0.18227419935411793, + 0.001484675759361425 + ], + "(1, 0, 0, 5, 0)": [ + -0.9545247785708878, + -0.9514659970621401, + -0.90452487601084 + ], + "(1, 0, 1, 5, 0)": [ + -0.8391556754298948, + -0.7701522287778763, + -0.8233480591505404 + ], + "(1, 1, 1, 5, 0)": [ + -1.093243461459317, + -1.1130313370385934, + -1.117598650488551 + ], + "(1, 1, 0, 5, 0)": [ + -1.060052207589391, + -1.0520482450609705, + -1.0541775063616219 + ], + "(2, 0, 1, 2, 0)": [ + 0.0019578641658367663, + 0.025663384896343185, + 0.1263307771397902 + ], + "(2, 0, 1, 3, 0)": [ + 0.023588571331503803, + 0.12134317356917385, + 0.2760718263547822 + ], + "(2, 0, 0, 3, 0)": [ + 0.1161666059796964, + 0.2235732549534617, + 0.07430355449847614 + ], + "(2, 1, 1, 2, 0)": [ + 0.03458905003176114, + 0.0007093755508323987, + 0.17288176348599543 + ], + "(1, 0, 1, 0, 0)": [ + 0.0011937985794666583, + 0.0006577923193096143, + -8.048310965484833e-05 + ], + "(1, 1, 1, 0, 0)": [ + 0.0015724376946350388, + 0.0007011948995402796, + -7.399361560918754e-05 + ], + "(2, 0, 0, 4, 0)": [ + -0.05649924448818272, + -0.09895201781380289, + 0.010597525817515366 + ], + "(2, 0, 1, 4, 0)": [ + -0.08358002678014068, + -0.09810695789263416, + -0.00041284702559314396 + ], + "(2, 1, 1, 4, 0)": [ + 0.09217839560450072, + -0.017320517480429995, + -0.03878804552748378 + ], + "(2, 1, 0, 4, 0)": [ + 0.1925844035349752, + 0.012835236335966602, + -0.042144040424107326 + ], + "(2, 1, 0, 5, 0)": [ + -1.4444514200032612, + -1.445392257262375, + -1.4481597245532174 + ], + "(2, 0, 0, 5, 0)": [ + -1.4712160523544338, + -1.470824695284624, + -1.3327809396730481 + ], + "(2, 0, 1, 5, 0)": [ + -1.4469254687808646, + -1.5299065594273575, + -1.5281209393177229 + ], + "(2, 1, 1, 5, 0)": [ + -1.3128182215345525, + -1.4477609597552596, + -1.4392818680144006 + ], + "(3, 1, 1, 1, 0)": [ + 0.45806194208553763, + 0.46093042135097084, + 0.474475379606745 + ], + "(4, 0, 1, 4, 1)": [ + 1.4164846977098555, + 1.4459868052131215, + 1.4616804923252622 + ], + "(4, 1, 1, 4, 1)": [ + 1.4129498511120588, + 1.4025285190325392, + 1.4393948253927626 + ], + "(1, 1, 0, 0, 0)": [ + 0.0010373859416610812, + 0.0011703526531334982, + 0.0007493443523265281 + ], + "(1, 0, 0, 2, 0)": [ + -0.10532136935062486, + -0.03584791600007468, + -0.11854695228463348 + ], + "(3, 0, 1, 5, 0)": [ + 0.02285934492582601, + -0.07103977867880287, + 0.019668059978482654 + ], + "(3, 1, 1, 5, 0)": [ + 0.03469354291860075, + 0.044038976296537, + 0.013242232441255847 + ], + "(3, 1, 0, 5, 0)": [ + -0.020181802884764344, + 0.06649176778651363, + 0.06509248190323355 + ], + "(3, 0, 0, 5, 0)": [ + 0.028824321631020555, + 0.029977268479488877, + 0.026545109024266954 + ], + "(3, 1, 0, 4, 0)": [ + 0.6296155314616831, + 0.4336608270212875, + 0.39276605272085297 + ], + "(3, 0, 0, 4, 0)": [ + 0.7770847854086554, + 0.5242727900048311, + 0.5206170117506856 + ], + "(3, 0, 1, 4, 0)": [ + 0.5314214160012682, + 0.7844820846414253, + 0.5074484536943829 + ], + "(3, 1, 1, 4, 0)": [ + 0.3754666321452126, + 0.7229753308135625, + 0.3995770969535575 + ], + "(4, 1, 0, 4, 1)": [ + 1.4963873160873034, + 1.467341373698872, + 1.4384319896246058 + ], + "(4, 1, 0, 3, 1)": [ + 1.5752245110522871, + 1.5602010532290964, + 1.492566856395298 + ], + "(4, 0, 0, 3, 1)": [ + 1.519151228092836, + 1.5227125117824454, + 1.4903275891051206 + ], + "(4, 0, 0, 4, 1)": [ + 1.6473288659015743, + 1.7275422975605321, + 1.5048793902151347 + ], + "(4, 0, 1, 3, 1)": [ + 1.4129889090497951, + 1.452240570193209, + 1.4625105948676116 + ], + "(4, 1, 1, 3, 1)": [ + 1.4646407295985004, + 1.4732073804218064, + 1.511095490878426 + ], + "(4, 1, 0, 2, 1)": [ + 2.106400147645301, + 1.6356823386754527, + 1.6995586432240426 + ], + "(4, 0, 0, 2, 1)": [ + 2.161395368605027, + 1.5237311046844422, + 1.8173060474051923 + ], + "(4, 0, 0, 1, 1)": [ + 2.3493776509278357, + 1.7130386451873514, + 1.8850438496049542 + ], + "(4, 0, 1, 1, 1)": [ + 2.1151543614085875, + 1.6543782490876844, + 2.2703735803195966 + ], + "(4, 1, 1, 1, 1)": [ + 1.918925083602411, + 1.74973854680099, + 2.5088229704031644 + ], + "(4, 1, 1, 1, 0)": [ + 1.7371298756417777, + 1.962531149214229, + 2.264878889586818 + ], + "(4, 1, 0, 1, 0)": [ + 2.020061243093679, + 1.8524421509129365, + 1.9286353167781694 + ], + "(4, 0, 0, 1, 0)": [ + 1.6817658822452428, + 1.6421444458419396, + 2.3586227153499326 + ], + "(4, 0, 1, 1, 0)": [ + 1.78836241223083, + 1.848369564939383, + 2.453235378342548 + ], + "(4, 0, 1, 2, 1)": [ + 1.5272002349823173, + 1.774756715807296, + 2.2631339867701725 + ], + "(4, 1, 0, 1, 1)": [ + 1.9677313784165418, + 2.433707323529033, + 1.7422610937368759 + ], + "(4, 1, 1, 2, 1)": [ + 1.490825730025895, + 1.6990871895059803, + 2.1039556756595323 + ], + "(5, 1, 0, 5, 1)": [ + 6.123750372525588, + 8.70575880253172, + 6.260498065482083 + ], + "(5, 0, 0, 5, 1)": [ + 7.308611772562493, + 5.240209399773695, + 5.920309823658606 + ], + "(5, 0, 1, 5, 1)": [ + 6.96203305289973, + 4.890002993836127, + 11.78691576141891 + ], + "(5, 1, 1, 5, 1)": [ + 6.713821193839766, + 9.316098624644017, + 5.311107957816669 + ], + "(6, 1, 1, 5, 1)": [ + 15.047660106650373, + 16.821457503565934, + 27.71374620203335 + ], + "(6, 1, 0, 5, 1)": [ + 16.384453849684327, + 28.353343581666334, + 20.687843932731177 + ], + "(6, 0, 0, 5, 1)": [ + 30.01197247830154, + 18.86235490305763, + 19.344672586909752 + ], + "(6, 0, 1, 5, 1)": [ + 26.558921440960397, + 18.868850427579066, + 16.828520283531955 + ], + "(4, 0, 1, 3, 0)": [ + 2.1925640020257102, + 2.95691906279798, + 3.4723878527181853 + ], + "(4, 1, 1, 3, 0)": [ + 2.198797063712139, + 2.21724652824473, + 4.963837212356006 + ], + "(4, 1, 0, 3, 0)": [ + 3.886035854924468, + 2.4084368048531406, + 2.8785873750775908 + ], + "(4, 0, 0, 3, 0)": [ + 2.4390869073407133, + 3.1202053796624596, + 3.440342237225939 + ], + "(4, 1, 1, 2, 0)": [ + 1.7198249761230124, + 1.7264336589470106, + 2.55030041839504 + ], + "(4, 1, 0, 2, 0)": [ + 2.857534274934687, + 1.8560661471861224, + 1.7892755484782197 + ], + "(4, 1, 0, 4, 0)": [ + 5.208515876962349, + 4.076311172617219, + 3.1256266560673907 + ], + "(4, 0, 0, 4, 0)": [ + 5.359684776234321, + 2.8072297945126206, + 3.9345761324135915 + ], + "(4, 0, 1, 4, 0)": [ + 3.2715111780893653, + 5.189496212005362, + 4.395806899462235 + ], + "(4, 1, 1, 4, 0)": [ + 3.536815804580715, + 3.9015860629086827, + 5.135748652698586 + ], + "(4, 0, 1, 5, 0)": [ + -0.8730882647726272, + -0.8192135917755976, + -0.8542523954122301 + ], + "(4, 1, 1, 5, 0)": [ + -0.8245347871065521, + -0.8326455650944975, + 0.9991824483293552 + ], + "(4, 0, 1, 2, 0)": [ + 1.9753930346045618, + 1.9649213400482635, + 1.9809836107767345 + ], + "(4, 1, 0, 5, 0)": [ + -0.6336931013817524, + 1.110335983119446, + 0.010250432901592221 + ], + "(4, 0, 0, 5, 0)": [ + -0.5955443430025976, + -0.5479728074067413, + -0.5137459310071231 + ], + "(4, 0, 0, 2, 0)": [ + 2.306030089074062, + 2.0513914882428184, + 1.7358723965792504 + ], + "(5, 0, 1, 2, 1)": [ + 5.938331314830866, + 5.923671010454649, + 6.322198810091958 + ], + "(5, 0, 1, 1, 1)": [ + 6.3713621509072205, + 4.278415316748207, + 5.217314857720529 + ], + "(5, 1, 0, 1, 1)": [ + 6.271680380359423, + 6.118414945090114, + 5.863101745784359 + ], + "(5, 0, 0, 1, 1)": [ + 6.343186747568721, + 5.938795353886484, + 6.033491402716604 + ], + "(5, 1, 1, 1, 1)": [ + 5.855501877516305, + 5.146222646836102, + 6.270943831177874 + ], + "(5, 1, 1, 2, 1)": [ + 5.965178381510156, + 4.951535473108393, + 6.335327059170063 + ], + "(5, 1, 0, 2, 1)": [ + 6.101590240378415, + 5.943638909366957, + 6.166377646864589 + ], + "(5, 1, 0, 3, 1)": [ + 9.399681750212956, + 6.046443374728028, + 5.472047895906149 + ], + "(5, 0, 0, 3, 1)": [ + 7.264852310738755, + 6.100976298895757, + 10.121931565466037 + ], + "(5, 0, 1, 3, 1)": [ + 6.058497642090931, + 6.64055605254837, + 12.820985400788588 + ], + "(5, 1, 1, 3, 1)": [ + 6.295127566601765, + 8.86480034761974, + 6.525986911785163 + ], + "(5, 0, 0, 4, 1)": [ + 11.620473329796793, + 5.852233709190965, + 6.172684561311558 + ], + "(5, 0, 1, 4, 1)": [ + 5.853903246175382, + 7.465633418352998, + 14.203272816632714 + ], + "(5, 1, 1, 4, 1)": [ + 8.411250446883901, + 6.821876889600531, + 10.420607535997695 + ], + "(5, 1, 0, 4, 1)": [ + 10.978468161677114, + 6.095771009776962, + 7.448828231947645 + ], + "(5, 0, 0, 5, 0)": [ + 4.168560040349653, + 4.1193373810885126, + 4.189200305683654 + ], + "(5, 0, 1, 5, 0)": [ + 3.9037126196648213, + 3.952725477446389, + 3.925950696362765 + ], + "(5, 0, 0, 4, 0)": [ + 6.989124827386673, + 6.709668768383307, + 6.164502340428653 + ], + "(5, 0, 1, 4, 0)": [ + 6.470192212401894, + 6.206583505620243, + 6.219341416738382 + ], + "(5, 1, 1, 4, 0)": [ + 6.44804048006831, + 6.470249716808012, + 6.499092139332864 + ], + "(5, 1, 0, 4, 0)": [ + 6.60974475903126, + 5.809620089830466, + 6.38325714226122 + ], + "(5, 0, 1, 3, 0)": [ + 5.902298693239425, + 6.513993949524509, + 6.050931365245501 + ], + "(5, 1, 1, 3, 0)": [ + 5.7219438824090165, + 6.654252839175707, + 5.854631300383308 + ], + "(5, 1, 0, 3, 0)": [ + 7.060124662461811, + 6.133863344872211, + 5.887388222544612 + ], + "(5, 1, 0, 2, 0)": [ + 6.420811088645497, + 5.484591712572488, + 5.344338090966727 + ], + "(5, 1, 0, 1, 0)": [ + 7.60161765328693, + 5.415707409756238, + 5.625087230101972 + ], + "(5, 0, 0, 1, 0)": [ + 6.362344373068486, + 6.250819401040948, + 5.0532246397904155 + ], + "(5, 0, 0, 2, 1)": [ + 6.389023020089472, + 5.921270205667463, + 5.710168906134173 + ], + "(2, 1, 1, 0, 0)": [ + 0.007418340685015537, + 0.0, + 0.0 + ], + "(5, 1, 1, 5, 0)": [ + 3.993346037475164, + 4.02595780895983, + 3.958113797811793 + ], + "(5, 1, 0, 5, 0)": [ + 3.8704975586426262, + 3.829895425894513, + 3.8484935557684774 + ], + "(5, 0, 0, 3, 0)": [ + 7.206895605095726, + 6.361422712879813, + 6.382153739961362 + ], + "(6, 1, 1, 3, 0)": [ + 39.16516514968102, + 0.0, + 13.874993945259016 + ], + "(6, 1, 0, 3, 0)": [ + 189.25859455848735, + 0.0, + 29.91897931545399 + ], + "(6, 1, 0, 2, 0)": [ + 17.158262188709795, + 1.9514931582028616, + 0.0 + ], + "(6, 0, 0, 1, 0)": [ + 25.614850876813907, + 0.0, + 6.3154150574104975 + ], + "(6, 0, 0, 1, 1)": [ + 17.352813028917932, + 2.580182984411188, + 2.140465724900898 + ], + "(6, 0, 1, 1, 1)": [ + 31.2334214211808, + 4.9985873791887645, + 0.0 + ], + "(6, 0, 1, 2, 1)": [ + 29.335273825224643, + 14.06297569453392, + 21.55271341124275 + ], + "(6, 1, 1, 3, 1)": [ + 23.25511720982458, + 16.649397764488587, + 19.90803592789075 + ], + "(6, 1, 0, 3, 1)": [ + 26.80171996379373, + 0.0, + 18.557566348591283 + ], + "(5, 0, 0, 2, 0)": [ + 6.128584915294166, + 4.74501528893155, + 5.71323506758402 + ], + "(5, 0, 1, 1, 0)": [ + 5.433105157444771, + 6.061572284699453, + 5.812568088606896 + ], + "(5, 0, 1, 2, 0)": [ + 9.299494331173243, + 4.798479530028196, + 5.0529263281147205 + ], + "(5, 1, 1, 1, 0)": [ + 5.411730171435421, + 5.336890832671549, + 3.6940890253938563 + ], + "(6, 1, 0, 4, 1)": [ + 46.0361726410884, + 185.7281789946067, + 24.898756482539923 + ], + "(6, 0, 0, 4, 1)": [ + 240.5192583338726, + 19.671210374109826, + 32.44011863831897 + ], + "(6, 0, 1, 4, 1)": [ + 217.5041757792318, + 36.04632868615748, + 35.48728683017791 + ], + "(6, 1, 0, 5, 0)": [ + 43.831876602918115, + 27.8889817150394, + 41.01229025098874 + ], + "(6, 0, 0, 5, 0)": [ + 37.31076342912965, + 72.17808547747593, + 93.69829260586646 + ], + "(6, 0, 1, 5, 0)": [ + 30.225634154900494, + 57.043973946193184, + 19.12245464571805 + ], + "(6, 1, 1, 5, 0)": [ + 22.31531135227297, + 31.702532410298858, + 24.52593989472061 + ], + "(6, 1, 0, 4, 0)": [ + 51.69128813332314, + 43.10494568818312, + 31.630799280419943 + ], + "(6, 0, 0, 4, 0)": [ + 118.94112981801939, + 21.91343878965308, + 0.0 + ], + "(6, 0, 1, 4, 0)": [ + 43.63117424617634, + 72.1459389593239, + 244.531009560538 + ], + "(6, 1, 1, 4, 1)": [ + 40.439091364812, + 62.86471099901328, + 28.21579086570153 + ], + "(5, 1, 1, 2, 0)": [ + 6.0763041725669105, + 3.2256356687419245, + 1.9883725448540135 + ], + "(2, 1, 0, 0, 0)": [ + 0.04124880442854498, + 0.0, + 0.0 + ], + "(6, 1, 1, 4, 0)": [ + 33.12059537789176, + 36.79265587145829, + 33.0038579587303 + ], + "(6, 0, 0, 3, 0)": [ + 6.610554709899747, + 0.0, + 0.0 + ], + "(6, 0, 0, 2, 0)": [ + 19.103288763514552, + 1.7250724331149645, + 0.0 + ], + "(6, 0, 1, 2, 0)": [ + 11.45538182633565, + 0.0, + 0.0 + ], + "(6, 0, 1, 1, 0)": [ + 36.245136836621846, + 0.0, + 0.0 + ], + "(6, 1, 1, 1, 0)": [ + 51.319980637299345, + 10.223271940420567, + 3.1856529540347966 + ], + "(6, 1, 1, 1, 1)": [ + 28.43873600008985, + 3.5188606871528423, + 5.206377154699371 + ], + "(6, 1, 0, 1, 1)": [ + 254.43695151567147, + 0.0, + 4.383567406279885 + ], + "(6, 1, 0, 2, 1)": [ + 4.645323738237858, + 41.06083981486784, + 0.0 + ], + "(6, 0, 0, 2, 1)": [ + 34.28350285678632, + 34.97694231722596, + 0.0 + ], + "(6, 0, 1, 3, 1)": [ + 23.303998337366593, + 23.665273191046108, + 23.58001108802269 + ], + "(4, 1, 1, 0, 0)": [ + 0.32071104927797534, + 0.0, + 0.0 + ], + "(6, 0, 0, 3, 1)": [ + 29.960345351525937, + 12.223389700753652, + 14.174347922430263 + ], + "(2, 0, 1, 0, 0)": [ + 0.023708233987275087, + 0.0, + 0.0 + ], + "(6, 1, 1, 2, 1)": [ + 32.770471384636984, + 0.0, + 0.0 + ], + "(6, 1, 1, 2, 0)": [ + 0.0, + 8.425163932932508, + 0.0 + ], + "(6, 1, 0, 1, 0)": [ + 11.400227838868048, + 2.5621867842074355, + 1.6266017364584542 + ], + "(6, 0, 1, 3, 0)": [ + 410.21798252237346, + 103.80547176610862, + 0.0 + ] + } +} \ No newline at end of file diff --git a/scripts/use_q_ide.py b/scripts/use_q_ide.py index fa3c4c7..883abc7 100644 --- a/scripts/use_q_ide.py +++ b/scripts/use_q_ide.py @@ -26,15 +26,15 @@ def do_use(kwargs: dict[str, Any]) -> None: render (str)='none': render mode of environment reward (float)=-0.1: reward limit at which episode is terminated file (str): Optional definition of model-save file - use_trained (bool): Use pre-trained data? + use_file (str): How 'file' is used (if exists): 'r', 'w', 'rw' episodes (int)=10000: nnumber of episodes run in the training steps (int)=5000: number of steps per episodes (if not terminated or truncated) - + t_fac (float)=0.001 """ if "dry-train" in kwargs: # Check training setup (over-write some parameters) - kwargs.update({"render": "plot", "file": None, "use_trained": False, "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "file": None, "use_file": "r", "episodes": 10, "steps": 1000}) elif "dry_do" in kwargs: # Run a few episodes on trained data (file can be set by caller) - kwargs.update({"render": "plot", "use_trained": True, "episodes": 10, "steps": 1000}) + kwargs.update({"render": "plot", "use_file": "r", "episodes": 10, "steps": 1000}) env = AntiPendulumEnv( build_crane, seed=1, @@ -43,31 +43,51 @@ def do_use(kwargs: dict[str, Any]) -> None: render_mode=kwargs.get("render", "none"), reward_limit=kwargs.get("reward", 0.0), discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), + reward_fac=(1.0, 0.0015, kwargs.get("t_fac", 0.0)), ) filename = kwargs.get("file") if filename is not None: Path(filename).parent.mkdir(parents=True, exist_ok=True) - use_trained = kwargs.get("use_trained", False) - agent = QLearningAgent(env, filename=filename, use_trained=use_trained) + use_file = kwargs.get("use_file", "r") + agent = QLearningAgent(env, filename=filename, use_file=use_file) agent.do_episodes(n_episodes=kwargs.get("episodes", 100), max_steps=kwargs.get("steps", 5000)) - if filename is not None: + if filename is not None and 'w' in agent.use_file: LOGGER.info(f"Model saved to {filename}") if __name__ == "__main__": + + def _args(base: dict[str, Any], upd: dict[str, Any]) -> dict[str, Any]: + base.update(upd) + return base + models = Path(__file__).parent.resolve().parent / "models" - args = { + anti = { # anti-pendulum settings "v0": 1.0, "render": "none", "reward": 0.0, "file": models / "q_anti-pendulum.json", - "use_trained": True, + "use_file": "rw", + "episodes": 1000, + "steps": 2000, + "t_fac": 0.0, + } + pend = { # start pendulum settings + "v0": 0.0, + "render": "none", + "reward": 200.0, + "file": models / "q_pendulum.json", + "use_file": "rw", "episodes": 1000, - "steps": 5000, + "steps": 2000, + "t_fac": 0.0, } - # args.update({'episodes':6000, 'use_trained':True}) # noqa: ERA001 ## do a mayor training adding to data - args.update({"episodes": 10, "render": "plot"}) - # args.update({'dry-train':True,}) # noqa: ERA001 ## check the setup before a long training - # args.update({'dry_do':True}) # noqa: ERA001 + # ruff: disable[ERA001] ## we intentionally work with commenting out lines here + args = _args(anti, {"episodes": 2000}) # anti-pendulum (additional) training + # args = _args(pend, {'episodes':10000}) # pendulum training + # args = _args( anti, {"episodes": 10, "render": "plot","use_file":'r'}) # show anti-pendulum results + # args = _args( pend, {"episodes": 10, "render": "plot", "use_file":'r'}) # show start pendulum results + # args = args.update(_args(anti, {'dry-train':True,})) # check the setup before a long training + # ruff: enable[ERA001] do_use(args) diff --git a/src/crane_controller/envs/controlled_crane_pendulum.py b/src/crane_controller/envs/controlled_crane_pendulum.py index 04992e1..49f5ff2 100644 --- a/src/crane_controller/envs/controlled_crane_pendulum.py +++ b/src/crane_controller/envs/controlled_crane_pendulum.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import gymnasium as gym import matplotlib.pyplot as plt @@ -87,6 +87,7 @@ class AntiPendulumEnv(gym.Env[AntiPendulumObs, int]): When provided, activates discrete observation mode with the given category boundaries. Expected keys: ``"angles"``, ``"pos"``, ``"speed"``, ``"distance"``, ``"sector"`` (default None). + reward_fac (tuple[float,...])=(1.0,0.0015,0.001): Weights between reward contributions """ metadata: ClassVar[dict[str, object]] = { # pyright: ignore[reportIncompatibleVariableOverride] # Gymnasium metadata typing is loose @@ -493,3 +494,15 @@ def render(self) -> None: self.show_animation() elif self.render_mode == "plot": self.show_plot(self.nresets) + + def get_parameters(self) -> dict[str, Any]: + """Return the environment parameter settings as dict.""" + return { + "wire-length": self.wire.length, + "wire-q-factor": self.wire.q_factor, + "reward-factors": self.reward_fac, + "acceleration": self.acc, + "step-size": self.dt, + "observations-discretization": None if not hasattr(self, "discrete") else self.discrete, + "reward_limit": self.reward_limit, + } diff --git a/src/crane_controller/q_agent.py b/src/crane_controller/q_agent.py index fbf5529..2b824d1 100644 --- a/src/crane_controller/q_agent.py +++ b/src/crane_controller/q_agent.py @@ -65,7 +65,7 @@ class QLearningAgent: discount_factor : float, optional How much to value future rewards, in the range [0, 1] (default 0.95). filename (Path): Optional path to filename for pre-trained data and saving of results - use_trained (bool) = False: load pre-trained values? + use_file (str) = 'r': How to use filename. 'r', 'w', 'rw'. File is not read when not found! """ DEFAULT_DISCRETE: ClassVar[dict[str, tuple[float | int, ...]]] = { @@ -80,12 +80,11 @@ def __init__( self, env: AntiPendulumEnv, learning_rate: float = 0.1, - initial_epsilon: float = 1.0, + epsilon_decay: float = 1e-4, final_epsilon: float = 0.1, discount_factor: float = 0.95, filename: Path | None = None, - *, - use_trained: bool = False, + use_file: str = "r", ) -> None: """Initialize the Q-learning agent. @@ -93,23 +92,20 @@ def __init__( """ self.env = env self.filename = Path(filename) if filename is not None else None - self.use_trained = use_trained + self.use_file = use_file self.q_values: defaultdict[tuple[int, ...], np.ndarray] - if self.use_trained and self.filename is not None and self.filename.exists(): - self.q_values = self.read_dumped(self.filename) - self.epsilon = final_epsilon # assume that we are fully learned - else: # start from scratch, but save the q_values afterwards - self.q_values = defaultdict(lambda: np.array((0.0,) * env.action_space.n, float)) # type: ignore[attr-defined,type-var] - self.epsilon = initial_epsilon # start from scratch self.lr = learning_rate self.discount_factor = discount_factor # How much we care about future rewards # Exploration parameters + self.epsilon = 1.0 + self.epsilon_decay = epsilon_decay self.final_epsilon = final_epsilon # Track learning progress self.training_error: list[float] = [] + self.previous_steps = 0 def analyse_q(self, obs: tuple[int, ...]) -> None: """Log Q-table entries matching an observation pattern. @@ -207,10 +203,13 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = Visualization mode - 0 for none, 1 for training summary, 2 for per-episode analysis (default 0). """ - if self.use_trained: + if "r" in self.use_file and self.filename is not None and self.filename.exists(): + self.q_values = self.read_dumped(self.filename) logger.info("Starting %s episodes, using pre-trained values from %s", n_episodes, self.filename) - else: + else: # start from scratch + self.q_values = defaultdict(lambda: np.array((0.0,) * self.env.action_space.n, float)) # type: ignore[attr-defined,type-var] logger.info("Starting new training with %s episodes.", n_episodes) + start_time = dt.datetime.now(dt.UTC) total_steps = 0 for _episode in tqdm(range(n_episodes)): # Start a new episode @@ -233,13 +232,15 @@ def do_episodes(self, n_episodes: int = 1000, max_steps: int = 5000, show: int = truncated |= nsteps > max_steps total_steps += nsteps # Reduce exploration rate (agent becomes less random over time): - self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon / (n_episodes / 2)) + self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay) if show == SHOW_TRAINING_SUMMARY: self.analyse_training() - if self.filename: - self.dump_results(episodes=n_episodes, steps=total_steps) + if self.filename and "w" in self.use_file: + self.dump_results(episodes=n_episodes, steps=total_steps, start_time=start_time) - def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int = -1) -> None: + def dump_results( + self, filename: str | Path = "", episodes: int = -1, steps: int = -1, start_time: dt.datetime | None = None + ) -> None: """Dump the Q-values to a JSON file. Args: @@ -247,6 +248,7 @@ def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int When empty, the filename provided at construction time is used (default ""). episodes (int): the number of episodes which have been run steps (int): the limiting number of steps per episode + start_time (dt.datetime): clock-time when the training started """ if not filename: # automatic file name if self.filename is None: @@ -259,20 +261,21 @@ def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int converted: dict[str, list[float]] = {} for k, v in self.q_values.items(): converted |= {str(k): list(v)} + env_parameters = {k: str(v) for k, v in self.env.get_parameters().items()} content = { - "date": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), - "pendulum": { - "start_speed": str(self.env.start_speed), - "render_mode": str(self.env.render_mode), - "reward_limit": str(self.env.reward_limit), - }, + "start-training": "unknown" if start_time is None else start_time.strftime("%d.%m.%Y %H:%M:%S"), + "end-training": dt.datetime.now(dt.UTC).strftime("%d.%m.%Y %H:%M:%S"), + "pendulum": env_parameters, "q_agent": { - "use_trained": str(self.use_trained), "filename": str(self.filename), + "use_file": self.use_file, "episodes": str(episodes), - "steps": str(steps), + "steps": str(steps + self.previous_steps), "learning_rate": str(self.lr), "discount_factor": str(self.discount_factor), + "epsilon-decay": str(self.epsilon_decay), + "final-epsilon": str(self.final_epsilon), + "epsilon": str(self.epsilon), }, "q_values": converted, } @@ -280,7 +283,7 @@ def dump_results(self, filename: str | Path = "", episodes: int = -1, steps: int json.dump(content, _f, indent=3) logger.info("Updated q_values saved to %s", _filename.resolve()) - def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.ndarray]: + def read_dumped(self, filename: str | Path | None = None) -> defaultdict[tuple[int, ...], np.ndarray]: """Read a Q-values dict from a JSON file. Parameters @@ -293,15 +296,25 @@ def read_dumped(self, filename: str | Path) -> defaultdict[tuple[int, ...], np.n defaultdict[tuple[int, ...], np.ndarray] Loaded Q-values mapping observation tuples to action-value arrays. """ - path = Path(filename) - with path.open(encoding="utf-8") as _f: - from_dump = json.load(_f) q_values: defaultdict[tuple[int, ...], np.ndarray] = defaultdict( lambda: np.array((0.0,) * self.env.action_space.n, float) # type: ignore[attr-defined,type-var] ) - assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" - for k, v in from_dump["q_values"].items(): - q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v}) + if filename is None and self.filename is None: # there is no file to read. Return empty defautdict + pass + else: + if filename is not None: + path = Path(filename) + elif self.filename is not None: + path = Path(self.filename) + + with path.open(encoding="utf-8") as _f: + from_dump = json.load(_f) + self.previous_steps = int(from_dump["q_agent"]["steps"]) + self.epsilon = float(from_dump["q_agent"].get("epsilon", 1.0)) + self.epsilon_decay = float(from_dump["q_agent"].get("epsilon", 1e-4)) + assert "q_values" in from_dump, f"Key 'q_values' not found in file {filename}" + for k, v in from_dump["q_values"].items(): + q_values.update({literal_eval(k): np.array(v) if isinstance(v, list) else v}) return q_values def analyse_training(self, window: int = 500) -> None: diff --git a/tests/test_q.py b/tests/test_q.py index 832621d..6a62eca 100644 --- a/tests/test_q.py +++ b/tests/test_q.py @@ -32,7 +32,9 @@ def test_q_analyse(crane: Callable[..., Crane], *, show: bool) -> None: crane, discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=Path("q_trained.json"), use_trained=True) + assert Path("q_trained.json").exists(), "File 'q_trained.json' not found" + agent = QLearningAgent(env, filename=Path("q_trained.json"), use_file="r") + agent.q_values = agent.read_dumped() for k, v in agent.q_values.items(): assert len(k) == 5, len(v) == 3 for pos in (0, 1): @@ -57,12 +59,12 @@ def test_intervals(crane: Callable[..., Crane]): discrete=QLearningAgent.DEFAULT_DISCRETE.copy(), ) - agent = QLearningAgent(env, filename=save_path, use_trained=False) + agent = QLearningAgent(env, filename=save_path, use_file="w") for i in range(10): _ = env.reset(seed=i + 1) agent.do_episodes(n_episodes=2, max_steps=100) if i == 0: - agent = QLearningAgent(env, filename=save_path, use_trained=True) + agent = QLearningAgent(env, filename=save_path, use_file="rw") logger.info(f"Model saved to {save_path}") @@ -72,6 +74,8 @@ def test_intervals(crane: Callable[..., Crane]): import pytest + from crane_controller.crane_factory import build_crane # noqa: F401 + retcode = pytest.main(["-rP -s -v", __file__]) assert retcode == 0, f"Return code {retcode}" os.chdir(Path(__file__).parent.absolute() / "test_working_directory")