From a5a66966413bbab25d8d9a2c77c82895e56fa15c Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Sun, 10 May 2026 12:28:36 +0530 Subject: [PATCH 1/6] Replace HIM dual-encoder + terrain table with DreamWaQ-style VAE The original HIMEstimator used two encoders (primary + target) with a prototype embedding table and Sinkhorn-based contrastive (swap) loss to learn terrain-aware latent representations. Following DreamWaQ (Nahrendra et al., 2023), this replaces that architecture with a simple Variational Autoencoder: - Encoder now outputs vel(3) + mu(16) + logvar(16) instead of vel(3) + z(16) - Reparameterization trick (z = mu + eps*sigma) is used during training; inference uses mu directly (deterministic, no noise) - Loss = MSE velocity estimation + beta * KL divergence to N(0,1) - Removes: target encoder, prototype embedding table, Sinkhorn algorithm - Old config keys (tar_hidden_dims, num_prototype, temperature) are kept as no-op params so existing config files remain compatible - him_ppo.py and him_on_policy_runner.py updated: swap_loss -> kl_loss --- rsl_rl/rsl_rl/algorithms/him_ppo.py | 10 +- rsl_rl/rsl_rl/modules/him_estimator.py | 120 ++++++------------ rsl_rl/rsl_rl/runners/him_on_policy_runner.py | 8 +- 3 files changed, 48 insertions(+), 90 deletions(-) diff --git a/rsl_rl/rsl_rl/algorithms/him_ppo.py b/rsl_rl/rsl_rl/algorithms/him_ppo.py index 42f1bdd..2b2c82a 100644 --- a/rsl_rl/rsl_rl/algorithms/him_ppo.py +++ b/rsl_rl/rsl_rl/algorithms/him_ppo.py @@ -120,7 +120,7 @@ def update(self): mean_value_loss = 0 mean_surrogate_loss = 0 mean_estimation_loss = 0 - mean_swap_loss = 0 + mean_kl_loss = 0 generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) @@ -150,7 +150,7 @@ def update(self): param_group['lr'] = self.learning_rate #Estimator Update - estimation_loss, swap_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) + estimation_loss, kl_loss = 
self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) # Surrogate loss ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) @@ -180,13 +180,13 @@ def update(self): mean_value_loss += value_loss.item() mean_surrogate_loss += surrogate_loss.item() mean_estimation_loss += estimation_loss - mean_swap_loss += swap_loss + mean_kl_loss += kl_loss num_updates = self.num_learning_epochs * self.num_mini_batches mean_value_loss /= num_updates mean_surrogate_loss /= num_updates mean_estimation_loss /= num_updates - mean_swap_loss /= num_updates + mean_kl_loss /= num_updates self.storage.clear() - return mean_value_loss, mean_surrogate_loss, estimation_loss, swap_loss + return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss diff --git a/rsl_rl/rsl_rl/modules/him_estimator.py b/rsl_rl/rsl_rl/modules/him_estimator.py index 4bd920a..6a13f7e 100644 --- a/rsl_rl/rsl_rl/modules/him_estimator.py +++ b/rsl_rl/rsl_rl/modules/him_estimator.py @@ -1,11 +1,7 @@ -import copy -import math import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F -import torch.distributions as torchd -from torch.distributions import Normal, Categorical class HIMEstimator(nn.Module): @@ -13,124 +9,86 @@ def __init__(self, temporal_steps, num_one_step_obs, enc_hidden_dims=[128, 64, 16], - tar_hidden_dims=[128, 64], + tar_hidden_dims=[128, 64], # kept for config compatibility, unused activation='elu', learning_rate=1e-3, max_grad_norm=10.0, - num_prototype=32, - temperature=3.0, + num_prototype=32, # kept for config compatibility, unused + temperature=3.0, # kept for config compatibility, unused + kl_weight=1.0, **kwargs): if kwargs: - print("Estimator_CL.__init__ got unexpected arguments, which will be ignored: " + str( - [key for key in kwargs.keys()])) + print("HIMEstimator.__init__ got unexpected arguments, which will be ignored: " + + str([key for key in kwargs.keys()])) 
super(HIMEstimator, self).__init__() - activation = get_activation(activation) + activation_fn = get_activation(activation) self.temporal_steps = temporal_steps self.num_one_step_obs = num_one_step_obs self.num_latent = enc_hidden_dims[-1] self.max_grad_norm = max_grad_norm - self.temperature = temperature + self.kl_weight = kl_weight - # Encoder + # Encoder: outputs vel(3) + mu(num_latent) + logvar(num_latent) enc_input_dim = self.temporal_steps * self.num_one_step_obs enc_layers = [] for l in range(len(enc_hidden_dims) - 1): - enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation] + enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation_fn] enc_input_dim = enc_hidden_dims[l] - enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[-1] + 3)] + enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)] self.encoder = nn.Sequential(*enc_layers) - # Target - tar_input_dim = self.num_one_step_obs - tar_layers = [] - for l in range(len(tar_hidden_dims)): - tar_layers += [nn.Linear(tar_input_dim, tar_hidden_dims[l]), activation] - tar_input_dim = tar_hidden_dims[l] - tar_layers += [nn.Linear(tar_input_dim, enc_hidden_dims[-1])] - self.target = nn.Sequential(*tar_layers) - - # Prototype - self.proto = nn.Embedding(num_prototype, enc_hidden_dims[-1]) - - # Optimizer self.learning_rate = learning_rate self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + def get_latent(self, obs_history): - vel, z = self.encode(obs_history) - return vel.detach(), z.detach() + """Inference: use mu directly (no sampling noise).""" + out = self.encoder(obs_history.detach()) + vel = out[..., :3] + mu = out[..., 3:3 + self.num_latent] + return vel.detach(), mu.detach() def forward(self, obs_history): - parts = self.encoder(obs_history.detach()) - vel, z = parts[..., :3], parts[..., 3:] - z = F.normalize(z, 
dim=-1, p=2) - return vel.detach(), z.detach() + return self.get_latent(obs_history) def encode(self, obs_history): - parts = self.encoder(obs_history.detach()) - vel, z = parts[..., :3], parts[..., 3:] - z = F.normalize(z, dim=-1, p=2) - return vel, z + """Training: sample z via reparameterization.""" + out = self.encoder(obs_history.detach()) + vel = out[..., :3] + mu = out[..., 3:3 + self.num_latent] + logvar = out[..., 3 + self.num_latent:] + z = self.reparameterize(mu, logvar) + return vel, mu, logvar, z def update(self, obs_history, next_critic_obs, lr=None): if lr is not None: self.learning_rate = lr for param_group in self.optimizer.param_groups: param_group['lr'] = self.learning_rate - - vel = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs+3].detach() - next_obs = next_critic_obs.detach()[:, 3:self.num_one_step_obs+3] - - z_s = self.encoder(obs_history) - z_t = self.target(next_obs) - pred_vel, z_s = z_s[..., :3], z_s[..., 3:] - - z_s = F.normalize(z_s, dim=-1, p=2) - z_t = F.normalize(z_t, dim=-1, p=2) - with torch.no_grad(): - w = self.proto.weight.data.clone() - w = F.normalize(w, dim=-1, p=2) - self.proto.weight.copy_(w) + # Ground-truth velocity from privileged obs + vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() - score_s = z_s @ self.proto.weight.T - score_t = z_t @ self.proto.weight.T + pred_vel, mu, logvar, _ = self.encode(obs_history) - with torch.no_grad(): - q_s = sinkhorn(score_s) - q_t = sinkhorn(score_t) + estimation_loss = F.mse_loss(pred_vel, vel_gt) - log_p_s = F.log_softmax(score_s / self.temperature, dim=-1) - log_p_t = F.log_softmax(score_t / self.temperature, dim=-1) + # KL divergence: D_KL( N(mu, sigma) || N(0,1) ) + kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) - swap_loss = -0.5 * (q_s * log_p_t + q_t * log_p_s).mean() - estimation_loss = F.mse_loss(pred_vel, vel) - losses = estimation_loss + swap_loss + loss = estimation_loss + self.kl_weight * kl_loss 
self.optimizer.zero_grad() - losses.backward() + loss.backward() nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) self.optimizer.step() - return estimation_loss.item(), swap_loss.item() - - -@torch.no_grad() -def sinkhorn(out, eps=0.05, iters=3): - Q = torch.exp(out / eps).T - K, B = Q.shape[0], Q.shape[1] - Q /= Q.sum() - - for it in range(iters): - # normalize each row: total weight per prototype must be 1/K - Q /= torch.sum(Q, dim=1, keepdim=True) - Q /= K - - # normalize each column: total weight per sample must be 1/B - Q /= torch.sum(Q, dim=0, keepdim=True) - Q /= B - return (Q * B).T + return estimation_loss.item(), kl_loss.item() def get_activation(act_name): @@ -152,4 +110,4 @@ def get_activation(act_name): return nn.Sigmoid() else: print("invalid activation function!") - return None \ No newline at end of file + return None diff --git a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py index 5fa28b2..35a553f 100644 --- a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py +++ b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py @@ -139,7 +139,7 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False): start = stop self.alg.compute_returns(critic_obs) - mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_swap_loss = self.alg.update() + mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss = self.alg.update() stop = time.time() learn_time = stop - start if self.log_dir is not None: @@ -176,7 +176,7 @@ def log(self, locs, width=80, pad=35): self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it']) self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it']) self.writer.add_scalar('Loss/Estimation Loss', locs['mean_estimation_loss'], locs['it']) - self.writer.add_scalar('Loss/Swap Loss', locs['mean_swap_loss'], locs['it']) + self.writer.add_scalar('Loss/KL Loss', locs['mean_kl_loss'], locs['it']) 
self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it']) self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it']) self.writer.add_scalar('Perf/total_fps', fps, locs['it']) @@ -198,7 +198,7 @@ def log(self, locs, width=80, pad=35): f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" - f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n""" + f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""") @@ -212,7 +212,7 @@ def log(self, locs, width=80, pad=35): f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" - f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n""" + f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""") # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") From a82fedb4a75e90b9a87a3cf78c934412405a0b69 Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Sun, 7 Jun 2026 15:49:52 +0530 Subject: [PATCH 2/6] Remove unused HIM params (tar_hidden_dims, num_prototype, temperature) --- rsl_rl/rsl_rl/modules/him_estimator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/rsl_rl/rsl_rl/modules/him_estimator.py b/rsl_rl/rsl_rl/modules/him_estimator.py index 6a13f7e..f6d7585 100644 --- a/rsl_rl/rsl_rl/modules/him_estimator.py +++ b/rsl_rl/rsl_rl/modules/him_estimator.py @@ -9,12 +9,9 @@ def __init__(self, 
temporal_steps, num_one_step_obs, enc_hidden_dims=[128, 64, 16], - tar_hidden_dims=[128, 64], # kept for config compatibility, unused activation='elu', learning_rate=1e-3, max_grad_norm=10.0, - num_prototype=32, # kept for config compatibility, unused - temperature=3.0, # kept for config compatibility, unused kl_weight=1.0, **kwargs): if kwargs: From 7b55d931c7055e96677db6132cd1079e4725d215 Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Mon, 8 Jun 2026 16:03:59 +0530 Subject: [PATCH 3/6] Add VAE decoder + reconstruction loss (complete DreamWaQ implementation) - him_estimator.py: add decoder NN (19->64->128->45), keep z from encode(), compute recon_loss = MSE(pred_next_obs, actual_next_obs), total loss = vel_loss + kl_loss + recon_loss, return 3 values - him_ppo.py: unpack 3 return values, track mean_recon_loss - him_on_policy_runner.py: log Reconstruction Loss to TensorBoard, print in training output Co-Authored-By: Claude Sonnet 4.6 --- rsl_rl/rsl_rl/algorithms/him_ppo.py | 7 +++-- rsl_rl/rsl_rl/modules/him_estimator.py | 30 +++++++++++++++---- rsl_rl/rsl_rl/runners/him_on_policy_runner.py | 5 +++- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/rsl_rl/rsl_rl/algorithms/him_ppo.py b/rsl_rl/rsl_rl/algorithms/him_ppo.py index 2b2c82a..0eaf0ce 100644 --- a/rsl_rl/rsl_rl/algorithms/him_ppo.py +++ b/rsl_rl/rsl_rl/algorithms/him_ppo.py @@ -121,6 +121,7 @@ def update(self): mean_surrogate_loss = 0 mean_estimation_loss = 0 mean_kl_loss = 0 + mean_recon_loss = 0 generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) @@ -150,7 +151,7 @@ def update(self): param_group['lr'] = self.learning_rate #Estimator Update - estimation_loss, kl_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) + estimation_loss, kl_loss, recon_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) # Surrogate loss ratio = torch.exp(actions_log_prob_batch - 
torch.squeeze(old_actions_log_prob_batch)) @@ -181,12 +182,14 @@ def update(self): mean_surrogate_loss += surrogate_loss.item() mean_estimation_loss += estimation_loss mean_kl_loss += kl_loss + mean_recon_loss += recon_loss num_updates = self.num_learning_epochs * self.num_mini_batches mean_value_loss /= num_updates mean_surrogate_loss /= num_updates mean_estimation_loss /= num_updates mean_kl_loss /= num_updates + mean_recon_loss /= num_updates self.storage.clear() - return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss + return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss, mean_recon_loss diff --git a/rsl_rl/rsl_rl/modules/him_estimator.py b/rsl_rl/rsl_rl/modules/him_estimator.py index f6d7585..7865be4 100644 --- a/rsl_rl/rsl_rl/modules/him_estimator.py +++ b/rsl_rl/rsl_rl/modules/him_estimator.py @@ -35,6 +35,14 @@ def __init__(self, enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)] self.encoder = nn.Sequential(*enc_layers) + # Decoder: vel(3) + z(num_latent) -> next_obs(num_one_step_obs) + dec_input_dim = 3 + self.num_latent # 19 + self.decoder = nn.Sequential( + nn.Linear(dec_input_dim, 64), activation_fn, + nn.Linear(64, 128), activation_fn, + nn.Linear(128, self.num_one_step_obs) # 45 + ) + self.learning_rate = learning_rate self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) @@ -69,23 +77,33 @@ def update(self, obs_history, next_critic_obs, lr=None): param_group['lr'] = self.learning_rate # Ground-truth velocity from privileged obs - vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() + vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() + + # Ground-truth next observation (first num_one_step_obs dims of next_critic_obs) + next_obs_gt = next_critic_obs[:, :self.num_one_step_obs].detach() - pred_vel, mu, logvar, _ = self.encode(obs_history) + pred_vel, mu, logvar, z = self.encode(obs_history) - 
estimation_loss = F.mse_loss(pred_vel, vel_gt) + # Decode: reconstruct next observation from vel + z + dec_input = torch.cat([pred_vel, z], dim=-1) # (batch, 19) + pred_next_obs = self.decoder(dec_input) # (batch, 45) + + estimation_loss = F.mse_loss(pred_vel, vel_gt) # KL divergence: D_KL( N(mu, sigma) || N(0,1) ) - kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) + kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) + + # Reconstruction loss: predicted next obs vs actual next obs + recon_loss = F.mse_loss(pred_next_obs, next_obs_gt) - loss = estimation_loss + self.kl_weight * kl_loss + loss = estimation_loss + self.kl_weight * kl_loss + recon_loss self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) self.optimizer.step() - return estimation_loss.item(), kl_loss.item() + return estimation_loss.item(), kl_loss.item(), recon_loss.item() def get_activation(act_name): diff --git a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py index 35a553f..9cf20e7 100644 --- a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py +++ b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py @@ -139,7 +139,7 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False): start = stop self.alg.compute_returns(critic_obs) - mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss = self.alg.update() + mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss, mean_recon_loss = self.alg.update() stop = time.time() learn_time = stop - start if self.log_dir is not None: @@ -177,6 +177,7 @@ def log(self, locs, width=80, pad=35): self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it']) self.writer.add_scalar('Loss/Estimation Loss', locs['mean_estimation_loss'], locs['it']) self.writer.add_scalar('Loss/KL Loss', locs['mean_kl_loss'], locs['it']) + self.writer.add_scalar('Loss/Reconstruction Loss', 
locs['mean_recon_loss'], locs['it']) self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it']) self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it']) self.writer.add_scalar('Perf/total_fps', fps, locs['it']) @@ -199,6 +200,7 @@ def log(self, locs, width=80, pad=35): f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" + f"""{'Reconstruction loss:':>{pad}} {locs['mean_recon_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""") @@ -213,6 +215,7 @@ def log(self, locs, width=80, pad=35): f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" + f"""{'Reconstruction loss:':>{pad}} {locs['mean_recon_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""") # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") From 3c3026e3f88fcbe52541be6608e871893b664361 Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Mon, 8 Jun 2026 20:46:36 +0530 Subject: [PATCH 4/6] Save every 10 iters, delete old checkpoints, track best model - save_interval respected (every 10 iters via config) - previous checkpoint deleted after each save (saves disk space) - best model saved separately as model_best.pt whenever reward improves - strict=False on load to support M4->decoder transfer Co-Authored-By: Claude Sonnet 4.6 --- rsl_rl/rsl_rl/runners/him_on_policy_runner.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git 
a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py index 9cf20e7..d5ec019 100644 --- a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py +++ b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py @@ -80,6 +80,8 @@ def __init__(self, self.tot_timesteps = 0 self.tot_time = 0 self.current_learning_iteration = 0 + self.best_reward = -float('inf') + self.best_model_path = None _, _ = self.env.reset() @@ -145,9 +147,27 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False): if self.log_dir is not None: self.log(locals()) if it % self.save_interval == 0: - self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it))) + cur_path = os.path.join(self.log_dir, 'model_{}.pt'.format(it)) + self.save(cur_path) + + # Delete previous checkpoint (keep only latest + best) + prev_it = it - self.save_interval + if prev_it > 0: + prev_path = os.path.join(self.log_dir, 'model_{}.pt'.format(prev_it)) + if os.path.exists(prev_path) and prev_path != self.best_model_path: + os.remove(prev_path) + + # Track best model by mean reward + if len(rewbuffer) > 0: + cur_reward = statistics.mean(rewbuffer) + if cur_reward > self.best_reward: + self.best_reward = cur_reward + self.best_model_path = os.path.join(self.log_dir, 'model_best.pt') + self.save(self.best_model_path) + print(f' ** New best model saved: reward={cur_reward:.2f} at iter {it}') + ep_infos.clear() - + self.current_learning_iteration += num_learning_iterations self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration))) @@ -240,7 +260,14 @@ def save(self, path, infos=None): def load(self, path, load_optimizer=True): loaded_dict = torch.load(path) - self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict']) + # strict=False: allows loading partial weights + # old checkpoints (no decoder) load encoder+PPO, decoder stays random + missing, unexpected = self.alg.actor_critic.load_state_dict( + loaded_dict['model_state_dict'], strict=False) 
+ if missing: + print(f'[load] Missing keys (will init randomly): {missing}') + if unexpected: + print(f'[load] Unexpected keys (ignored): {unexpected}') if load_optimizer: self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict']) self.alg.actor_critic.estimator.optimizer.load_state_dict(loaded_dict['estimator_optimizer_state_dict']) From 35fde29711efa79058f1187cbc2df489727dfd4a Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Mon, 8 Jun 2026 20:47:14 +0530 Subject: [PATCH 5/6] Set save_interval=10 for more frequent checkpointing Co-Authored-By: Claude Sonnet 4.6 --- legged_gym/legged_gym/envs/base/legged_robot_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/legged_gym/legged_gym/envs/base/legged_robot_config.py b/legged_gym/legged_gym/envs/base/legged_robot_config.py index 41a1c1e..151faeb 100644 --- a/legged_gym/legged_gym/envs/base/legged_robot_config.py +++ b/legged_gym/legged_gym/envs/base/legged_robot_config.py @@ -267,7 +267,7 @@ class runner: max_iterations = 200000 # number of policy updates # logging - save_interval = 20 # check for potential saves every this many iterations + save_interval = 10 # check for potential saves every this many iterations experiment_name = 'test' run_name = '' # load and resume From e218cc925814282e66b80d41d5cf3e267f672b17 Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Wed, 10 Jun 2026 21:03:41 +0530 Subject: [PATCH 6/6] Remove VAE decoder and reconstruction loss from HIMEstimator --- rsl_rl/rsl_rl/modules/him_estimator.py | 32 +++++++------------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/rsl_rl/rsl_rl/modules/him_estimator.py b/rsl_rl/rsl_rl/modules/him_estimator.py index 7865be4..f14df6b 100644 --- a/rsl_rl/rsl_rl/modules/him_estimator.py +++ b/rsl_rl/rsl_rl/modules/him_estimator.py @@ -35,14 +35,6 @@ def __init__(self, enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)] self.encoder = nn.Sequential(*enc_layers) - # Decoder: vel(3) + z(num_latent) -> next_obs(num_one_step_obs) - 
dec_input_dim = 3 + self.num_latent # 19 - self.decoder = nn.Sequential( - nn.Linear(dec_input_dim, 64), activation_fn, - nn.Linear(64, 128), activation_fn, - nn.Linear(128, self.num_one_step_obs) # 45 - ) - self.learning_rate = learning_rate self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) @@ -77,33 +69,23 @@ def update(self, obs_history, next_critic_obs, lr=None): param_group['lr'] = self.learning_rate # Ground-truth velocity from privileged obs - vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() - - # Ground-truth next observation (first num_one_step_obs dims of next_critic_obs) - next_obs_gt = next_critic_obs[:, :self.num_one_step_obs].detach() - - pred_vel, mu, logvar, z = self.encode(obs_history) + vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() - # Decode: reconstruct next observation from vel + z - dec_input = torch.cat([pred_vel, z], dim=-1) # (batch, 19) - pred_next_obs = self.decoder(dec_input) # (batch, 45) + pred_vel, mu, logvar, _ = self.encode(obs_history) - estimation_loss = F.mse_loss(pred_vel, vel_gt) + estimation_loss = F.mse_loss(pred_vel, vel_gt) # KL divergence: D_KL( N(mu, sigma) || N(0,1) ) - kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) + kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) - # Reconstruction loss: predicted next obs vs actual next obs - recon_loss = F.mse_loss(pred_next_obs, next_obs_gt) - - loss = estimation_loss + self.kl_weight * kl_loss + recon_loss + loss = estimation_loss + self.kl_weight * kl_loss self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) self.optimizer.step() - return estimation_loss.item(), kl_loss.item(), recon_loss.item() + return estimation_loss.item(), kl_loss.item() def get_activation(act_name): @@ -126,3 +108,5 @@ def get_activation(act_name): else: print("invalid activation function!") return None + +