20 changes: 12 additions & 8 deletions robomimic/algo/diffusion_policy.py
@@ -2,14 +2,12 @@
Implementation of Diffusion Policy https://diffusion-policy.cs.columbia.edu/ by Cheng Chi
"""
from typing import Callable, Union
import math
from copy import deepcopy
from collections import OrderedDict, deque
from packaging.version import parse as parse_version
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
# requires diffusers==0.11.1
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.training_utils import EMAModel
@@ -22,7 +20,6 @@

from robomimic.algo import register_algo_factory_func, PolicyAlgo

import random
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
@@ -109,13 +106,17 @@ def _create_networks(self):

# setup EMA
ema = None
ema_nets = None
if self.algo_config.ema.enabled:
ema = EMAModel(model=nets, power=self.algo_config.ema.power)
ema = EMAModel(parameters=nets.parameters(), power=self.algo_config.ema.power)
ema_nets = deepcopy(nets)

# set attrs
self.nets = nets
self.noise_scheduler = noise_scheduler
self.ema = ema
self.ema_nets = ema_nets
self.ema_nets.eval()
self.action_check_done = False
self.obs_queue = None
self.action_queue = None
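Note on this hunk: in recent diffusers releases, EMAModel tracks an iterable of parameters rather than wrapping a whole module, so the policy keeps a deep copy of its networks (ema_nets) as the module the averaged weights are copied into at inference time. A minimal sketch of that pattern, assuming a placeholder ModuleDict and power value in place of the real robomimic networks and config:

# Sketch only: "nets" and power=0.75 are placeholders, not robomimic code.
from copy import deepcopy
import torch.nn as nn
from diffusers.training_utils import EMAModel

nets = nn.ModuleDict({"policy": nn.Linear(10, 7)})
ema = EMAModel(parameters=nets.parameters(), power=0.75)  # tracks parameter tensors, not the module
ema_nets = deepcopy(nets)  # shadow module that receives the averaged weights for inference
ema_nets.eval()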
@@ -232,7 +233,7 @@ def train_on_batch(self, batch, epoch, validate=False):

# update Exponential Moving Average of the model weights
if self.ema is not None:
self.ema.step(self.nets)
self.ema.step(self.nets.parameters())

step_info = {
"policy_grad_norms": policy_grad_norms
@@ -330,7 +331,7 @@ def _get_action_trajectory(self, obs_dict, goal_dict=None):
# select network
nets = self.nets
if self.ema is not None:
nets = self.ema.averaged_model
nets = self.ema_nets

# encode obs
inputs = {
@@ -382,6 +383,8 @@ def serialize(self):
"""
Get dictionary of current model parameters.
"""
if self.ema is not None:
self.ema.copy_to(self.ema_nets.parameters())
return {
"nets": self.nets.state_dict(),
"optimizers": { k : self.optimizers[k].state_dict() for k in self.optimizers },
@@ -408,7 +411,8 @@ def deserialize(self, model_dict, load_optimizers=False):
model_dict["lr_schedulers"] = {}

if model_dict.get("ema", None) is not None:
self.ema.averaged_model.load_state_dict(model_dict["ema"])
self.ema.load_state_dict(model_dict["ema"])
self.ema_nets.load_state_dict(model_dict["nets"])

if load_optimizers:
for k in model_dict["optimizers"]:
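Note on the serialize/deserialize hunks: with the parameter-based EMAModel there is no averaged_model attribute, so the averaged weights are materialized into the ema_nets copy via copy_to before a state dict is taken, and the EMA tracker itself is restored with load_state_dict. A hedged sketch of that round trip, with placeholder modules standing in for the real networks:

# Sketch only: placeholder module; mirrors the save/load pattern in the diff above.
from copy import deepcopy
import torch.nn as nn
from diffusers.training_utils import EMAModel

nets = nn.ModuleDict({"policy": nn.Linear(10, 7)})
ema = EMAModel(parameters=nets.parameters(), power=0.75)
ema_nets = deepcopy(nets)

# serialize: push averaged weights into the shadow module, then snapshot everything
ema.copy_to(ema_nets.parameters())
ckpt = {"nets": nets.state_dict(), "ema": ema.state_dict()}

# deserialize: restore the raw nets and the EMA tracker, then rebuild the shadow module
nets.load_state_dict(ckpt["nets"])
ema.load_state_dict(ckpt["ema"])
ema_nets.load_state_dict(ckpt["nets"])

# inference: prefer the EMA shadow module when EMA is enabled, as in _get_action_trajectory
inference_nets = ema_nets if ema is not None else nets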
4 changes: 2 additions & 2 deletions setup.py
@@ -29,9 +29,9 @@
"egl_probe>=1.0.1",
"torch",
"torchvision",
"huggingface_hub==0.23.4",
"huggingface_hub",
"transformers==4.41.2",
"diffusers==0.11.1",
"diffusers==0.34.0",
],
eager_resources=['*'],
include_package_data=True,