diff --git a/ibc/agents/ibc_policy.py b/ibc/agents/ibc_policy.py index 4371a23..72e98f4 100644 --- a/ibc/agents/ibc_policy.py +++ b/ibc/agents/ibc_policy.py @@ -65,7 +65,7 @@ def _create_variables(specs, training, step_type, network_state): return _create_variables -@tfp.experimental.register_composite +@tfp.experimental.auto_composite_tensor class MappedCategorical(tfp.distributions.Categorical): """Categorical distribution that maps classes to specific values."""