Initialization error of diffusion policy #17

@zichunxx

Hi, Albert! Thanks for your generous sharing!

I found that the Hydra parameters for the diffusion policy's EMA in diffusion_policy.yaml

ema_factory:
  _target_: diffusers.training_utils.EMAModel
  _partial_: true
  decay: 0.9999
  use_ema_warmup: false
  inv_gamma: 1.0
  power: 0.75

do not match the constructor arguments of EMAModel:

class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999,
        device=None,
    ):

This mismatch triggers the initialization error.
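To make the mismatch concrete, here is a minimal sketch that compares keyword names against a constructor signature with `inspect`. The `EMAModel` below is only a stand-in that copies the signature quoted above, and `unexpected_kwargs` is a hypothetical helper name, not part of diffusers or Hydra.

```python
import inspect


class EMAModel:
    """Stand-in with the constructor signature quoted above (not the real class)."""

    def __init__(self, model, update_after_step=0, inv_gamma=1.0,
                 power=2 / 3, min_value=0.0, max_value=0.9999, device=None):
        self.model = model


def unexpected_kwargs(target, kwargs):
    """Return the keyword names that target's constructor does not accept."""
    params = inspect.signature(target).parameters
    # If the constructor takes **kwargs, every keyword is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(name for name in kwargs if name not in params)


# Keys from the ema_factory config, plus the `parameters` keyword
# that the policy passes at the call site.
config_kwargs = {
    "decay": 0.9999,
    "use_ema_warmup": False,
    "inv_gamma": 1.0,
    "power": 0.75,
    "parameters": None,
}

bad = unexpected_kwargs(EMAModel, config_kwargs)
print(bad)
```

Against this signature, `decay` and `use_ema_warmup` are rejected as well as `parameters`, so the config keys may also need adjusting (or a different diffusers version may be expected). Running the same check on the class that is actually imported would confirm which case applies.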

Replacing

self.ema: EMAModel = ema_factory(parameters=self.networks.parameters())

with

self.ema: EMAModel = ema_factory(model=self.networks)

could resolve this error.

Besides, for the following snippet,

def train(self, mode=True):
    """Override train method to manage EMA parameters."""
    if mode and self._using_ema_params:
        # Switching to train mode, restore original parameters
        self.ema.restore(self.networks.parameters())
        self._using_ema_params = False
    elif not mode and not self._using_ema_params:
        # Switching to eval mode, use EMA parameters
        self.ema.store(self.networks.parameters())
        self.ema.copy_to(self.networks.parameters())
        self._using_ema_params = True
    # Call parent train method
    return super().train(mode)

the call fails with: 'EMAModel' object has no attribute 'store'. Thus,

self.ema.step(self.networks)

may be the right solution.
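For reference, here is a minimal sketch of what the store / copy_to / restore pattern in the snippet above is meant to do, next to what a step call does. `TinyEMA` is a hypothetical helper over a plain list of floats, written only for illustration; it is not the diffusers class and its method signatures are assumptions.

```python
class TinyEMA:
    """Hypothetical EMA helper over a list of floats (illustration only)."""

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = list(params)  # averaged copy of the parameters
        self.backup = None          # filled by store()

    def step(self, params):
        """Move the averaged copy toward the current parameters."""
        self.shadow = [self.decay * s + (1 - self.decay) * p
                       for s, p in zip(self.shadow, params)]

    def store(self, params):
        """Remember the current (training) parameters."""
        self.backup = list(params)

    def copy_to(self, params):
        """Overwrite the parameters in place with the averaged copy."""
        params[:] = self.shadow

    def restore(self, params):
        """Put the remembered training parameters back in place."""
        params[:] = self.backup
        self.backup = None


weights = [1.0, 2.0]
ema = TinyEMA(weights, decay=0.5)

weights[:] = [3.0, 4.0]  # pretend an optimizer step changed the weights
ema.step(weights)        # updates the average only; weights are untouched

ema.store(weights)       # entering eval: remember training weights
ema.copy_to(weights)     # evaluate with the averaged weights
eval_weights = list(weights)

ema.restore(weights)     # back to train: training weights return
```

In this sketch, step only updates the running average, while store, copy_to, and restore swap weights in and out. If the real class behaves similarly, calling step inside train() would avoid the AttributeError but might not switch to the EMA weights for evaluation, so that part may need a separate fix.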

Thanks for your time.
