Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion legged_gym/legged_gym/envs/base/legged_robot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class runner:
max_iterations = 200000 # number of policy updates

# logging
save_interval = 20 # check for potential saves every this many iterations
save_interval = 10 # check for potential saves every this many iterations
experiment_name = 'test'
run_name = ''
# load and resume
Expand Down
13 changes: 8 additions & 5 deletions rsl_rl/rsl_rl/algorithms/him_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def update(self):
mean_value_loss = 0
mean_surrogate_loss = 0
mean_estimation_loss = 0
mean_swap_loss = 0
mean_kl_loss = 0
mean_recon_loss = 0

generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)

Expand Down Expand Up @@ -150,7 +151,7 @@ def update(self):
param_group['lr'] = self.learning_rate

#Estimator Update
estimation_loss, swap_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate)
estimation_loss, kl_loss, recon_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate)

# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
Expand Down Expand Up @@ -180,13 +181,15 @@ def update(self):
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
mean_estimation_loss += estimation_loss
mean_swap_loss += swap_loss
mean_kl_loss += kl_loss
mean_recon_loss += recon_loss

num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
mean_estimation_loss /= num_updates
mean_swap_loss /= num_updates
mean_kl_loss /= num_updates
mean_recon_loss /= num_updates
self.storage.clear()

return mean_value_loss, mean_surrogate_loss, estimation_loss, swap_loss
return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss, mean_recon_loss
119 changes: 38 additions & 81 deletions rsl_rl/rsl_rl/modules/him_estimator.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,91 @@
import copy
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributions as torchd
from torch.distributions import Normal, Categorical


class HIMEstimator(nn.Module):
def __init__(self,
temporal_steps,
num_one_step_obs,
enc_hidden_dims=[128, 64, 16],
tar_hidden_dims=[128, 64],
activation='elu',
learning_rate=1e-3,
max_grad_norm=10.0,
num_prototype=32,
temperature=3.0,
kl_weight=1.0,
**kwargs):
if kwargs:
print("Estimator_CL.__init__ got unexpected arguments, which will be ignored: " + str(
[key for key in kwargs.keys()]))
print("HIMEstimator.__init__ got unexpected arguments, which will be ignored: " +
str([key for key in kwargs.keys()]))
super(HIMEstimator, self).__init__()
activation = get_activation(activation)
activation_fn = get_activation(activation)

self.temporal_steps = temporal_steps
self.num_one_step_obs = num_one_step_obs
self.num_latent = enc_hidden_dims[-1]
self.max_grad_norm = max_grad_norm
self.temperature = temperature
self.kl_weight = kl_weight

# Encoder
# Encoder: outputs vel(3) + mu(num_latent) + logvar(num_latent)
enc_input_dim = self.temporal_steps * self.num_one_step_obs
enc_layers = []
for l in range(len(enc_hidden_dims) - 1):
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation]
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation_fn]
enc_input_dim = enc_hidden_dims[l]
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[-1] + 3)]
enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)]
self.encoder = nn.Sequential(*enc_layers)

# Target
tar_input_dim = self.num_one_step_obs
tar_layers = []
for l in range(len(tar_hidden_dims)):
tar_layers += [nn.Linear(tar_input_dim, tar_hidden_dims[l]), activation]
tar_input_dim = tar_hidden_dims[l]
tar_layers += [nn.Linear(tar_input_dim, enc_hidden_dims[-1])]
self.target = nn.Sequential(*tar_layers)

# Prototype
self.proto = nn.Embedding(num_prototype, enc_hidden_dims[-1])

# Optimizer
self.learning_rate = learning_rate
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std

def get_latent(self, obs_history):
vel, z = self.encode(obs_history)
return vel.detach(), z.detach()
"""Inference: use mu directly (no sampling noise)."""
out = self.encoder(obs_history.detach())
vel = out[..., :3]
mu = out[..., 3:3 + self.num_latent]
return vel.detach(), mu.detach()

def forward(self, obs_history):
parts = self.encoder(obs_history.detach())
vel, z = parts[..., :3], parts[..., 3:]
z = F.normalize(z, dim=-1, p=2)
return vel.detach(), z.detach()
return self.get_latent(obs_history)

def encode(self, obs_history):
parts = self.encoder(obs_history.detach())
vel, z = parts[..., :3], parts[..., 3:]
z = F.normalize(z, dim=-1, p=2)
return vel, z
"""Training: sample z via reparameterization."""
out = self.encoder(obs_history.detach())
vel = out[..., :3]
mu = out[..., 3:3 + self.num_latent]
logvar = out[..., 3 + self.num_latent:]
z = self.reparameterize(mu, logvar)
return vel, mu, logvar, z

def update(self, obs_history, next_critic_obs, lr=None):
if lr is not None:
self.learning_rate = lr
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.learning_rate

vel = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs+3].detach()
next_obs = next_critic_obs.detach()[:, 3:self.num_one_step_obs+3]

z_s = self.encoder(obs_history)
z_t = self.target(next_obs)
pred_vel, z_s = z_s[..., :3], z_s[..., 3:]

z_s = F.normalize(z_s, dim=-1, p=2)
z_t = F.normalize(z_t, dim=-1, p=2)

with torch.no_grad():
w = self.proto.weight.data.clone()
w = F.normalize(w, dim=-1, p=2)
self.proto.weight.copy_(w)
# Ground-truth velocity from privileged obs
vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach()

score_s = z_s @ self.proto.weight.T
score_t = z_t @ self.proto.weight.T
pred_vel, mu, logvar, _ = self.encode(obs_history)

with torch.no_grad():
q_s = sinkhorn(score_s)
q_t = sinkhorn(score_t)
estimation_loss = F.mse_loss(pred_vel, vel_gt)

log_p_s = F.log_softmax(score_s / self.temperature, dim=-1)
log_p_t = F.log_softmax(score_t / self.temperature, dim=-1)
# KL divergence: D_KL( N(mu, sigma) || N(0,1) )
kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

swap_loss = -0.5 * (q_s * log_p_t + q_t * log_p_s).mean()
estimation_loss = F.mse_loss(pred_vel, vel)
losses = estimation_loss + swap_loss
loss = estimation_loss + self.kl_weight * kl_loss

self.optimizer.zero_grad()
losses.backward()
loss.backward()
nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
self.optimizer.step()

return estimation_loss.item(), swap_loss.item()


@torch.no_grad()
def sinkhorn(out, eps=0.05, iters=3):
Q = torch.exp(out / eps).T
K, B = Q.shape[0], Q.shape[1]
Q /= Q.sum()

for it in range(iters):
# normalize each row: total weight per prototype must be 1/K
Q /= torch.sum(Q, dim=1, keepdim=True)
Q /= K

# normalize each column: total weight per sample must be 1/B
Q /= torch.sum(Q, dim=0, keepdim=True)
Q /= B
return (Q * B).T
return estimation_loss.item(), kl_loss.item()


def get_activation(act_name):
Expand All @@ -152,4 +107,6 @@ def get_activation(act_name):
return nn.Sigmoid()
else:
print("invalid activation function!")
return None
return None


44 changes: 37 additions & 7 deletions rsl_rl/rsl_rl/runners/him_on_policy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(self,
self.tot_timesteps = 0
self.tot_time = 0
self.current_learning_iteration = 0
self.best_reward = -float('inf')
self.best_model_path = None

_, _ = self.env.reset()

Expand Down Expand Up @@ -139,15 +141,33 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False):
start = stop
self.alg.compute_returns(critic_obs)

mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_swap_loss = self.alg.update()
mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss, mean_recon_loss = self.alg.update()
stop = time.time()
learn_time = stop - start
if self.log_dir is not None:
self.log(locals())
if it % self.save_interval == 0:
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
cur_path = os.path.join(self.log_dir, 'model_{}.pt'.format(it))
self.save(cur_path)

# Delete previous checkpoint (keep only latest + best)
prev_it = it - self.save_interval
if prev_it > 0:
prev_path = os.path.join(self.log_dir, 'model_{}.pt'.format(prev_it))
if os.path.exists(prev_path) and prev_path != self.best_model_path:
os.remove(prev_path)

# Track best model by mean reward
if len(rewbuffer) > 0:
cur_reward = statistics.mean(rewbuffer)
if cur_reward > self.best_reward:
self.best_reward = cur_reward
self.best_model_path = os.path.join(self.log_dir, 'model_best.pt')
self.save(self.best_model_path)
print(f' ** New best model saved: reward={cur_reward:.2f} at iter {it}')

ep_infos.clear()

self.current_learning_iteration += num_learning_iterations
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration)))

Expand Down Expand Up @@ -176,7 +196,8 @@ def log(self, locs, width=80, pad=35):
self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
self.writer.add_scalar('Loss/Estimation Loss', locs['mean_estimation_loss'], locs['it'])
self.writer.add_scalar('Loss/Swap Loss', locs['mean_swap_loss'], locs['it'])
self.writer.add_scalar('Loss/KL Loss', locs['mean_kl_loss'], locs['it'])
self.writer.add_scalar('Loss/Reconstruction Loss', locs['mean_recon_loss'], locs['it'])
self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
Expand All @@ -198,7 +219,8 @@ def log(self, locs, width=80, pad=35):
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n"""
f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n"""
f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n"""
f"""{'Reconstruction loss:':>{pad}} {locs['mean_recon_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
Expand All @@ -212,7 +234,8 @@ def log(self, locs, width=80, pad=35):
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n"""
f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n"""
f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n"""
f"""{'Reconstruction loss:':>{pad}} {locs['mean_recon_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
Expand All @@ -237,7 +260,14 @@ def save(self, path, infos=None):

def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
# strict=False: allows loading partial weights
# old checkpoints (no decoder) load encoder+PPO, decoder stays random
missing, unexpected = self.alg.actor_critic.load_state_dict(
loaded_dict['model_state_dict'], strict=False)
if missing:
print(f'[load] Missing keys (will init randomly): {missing}')
if unexpected:
print(f'[load] Unexpected keys (ignored): {unexpected}')
if load_optimizer:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
self.alg.actor_critic.estimator.optimizer.load_state_dict(loaded_dict['estimator_optimizer_state_dict'])
Expand Down