Skip to content

Commit a6e1c7f

Browse files
committed
Update
[ghstack-poisoned]
2 parents 3c7552e + 384b592 commit a6e1c7f

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

test/test_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def forward(self, observation):
162162
output = self.linear(observation)
163163
if self.multiple_outputs:
164164
return output, output.sum(), output.min(), output.max()
165-
return self.linear(observation)
165+
return output
166166

167167

168168
class UnwrappablePolicy(nn.Module):

torchrl/collectors/collectors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,9 @@ def update_policy_weights_(
486486

487487
strategy = WeightStrategy(extract_as="tensordict")
488488
weights = strategy.extract_weights(self._original_policy)
489+
# Cast weights to the policy device before applying
490+
if self.policy_device is not None:
491+
weights = weights.to(self.policy_device)
489492
strategy.apply_weights(self.policy, weights)
490493
# Otherwise, no action needed - policy is local and changes are immediately visible
491494

0 commit comments

Comments
 (0)