Description
What are the expected shapes of the inputs to the forward function of SpearmanLoss? Does it accept batch data?
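I think `SpearmanLoss.forward` may need to be modified to support batched input, e.g.:
```python
def forward(self, mem_pred, mem_gt, pr=False):
    # Rank the ground truth along the last dimension, so each row of a batch is ranked independently
    rank_gt = get_rank(mem_gt, -1)
    if len(mem_pred.shape) == 1:
        # Unbatched input: add a batch dimension for the sorter, then flatten back to 1-D
        rank_pred = self.sorter(mem_pred.unsqueeze(0)).view(-1)
    else:
        # Batched input (assumed shape: batch x seq_len) goes to the sorter as-is
        rank_pred = self.sorter(mem_pred)
    return self.criterion_mse(rank_pred, rank_gt) + self.lbd * self.criterionl1(mem_pred, mem_gt)
```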
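The snippet above depends on the repository's `sorter` and `get_rank`, so it cannot run on its own. As a hedged, dependency-free illustration of the batch semantics the proposal aims for (ranks computed independently per row along the last dimension, and a 1-D input treated as a batch of one, like `unsqueeze(0)`), here is a sketch in plain Python. It uses hard, non-differentiable ranks in place of the learned sorter, so it only shows the expected shapes and ranking behavior, not a trainable loss. The helper names `as_batch`, `ranks_along_last_dim` and `batched_rank_mse` are hypothetical and not part of the repository.

```python
from typing import List, Sequence, Union

Row = Sequence[float]


def as_batch(x: Union[Row, Sequence[Row]]) -> List[List[float]]:
    """Treat a flat list of scores as a batch of one row (like unsqueeze(0))."""
    if len(x) > 0 and isinstance(x[0], (int, float)):
        return [list(x)]
    return [list(row) for row in x]


def ranks_along_last_dim(scores: Row) -> List[float]:
    """Hard rank of each score within its row (0 = smallest).

    Equivalent to argsort(argsort(scores)); ties are broken by position.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i])  # argsort
    ranks = [0.0] * len(scores)
    for rank, idx in enumerate(order):  # invert the permutation
        ranks[idx] = float(rank)
    return ranks


def batched_rank_mse(pred, gt) -> float:
    """MSE between hard ranks of pred and gt, ranking each row separately.

    Averages over all elements, like the default 'mean' reduction of MSELoss.
    """
    pred_rows, gt_rows = as_batch(pred), as_batch(gt)
    if len(pred_rows) != len(gt_rows):
        raise ValueError("pred and gt must have the same batch size")
    total, count = 0.0, 0
    for p_row, g_row in zip(pred_rows, gt_rows):
        if len(p_row) != len(g_row):
            raise ValueError("rows of pred and gt must have the same length")
        rp, rg = ranks_along_last_dim(p_row), ranks_along_last_dim(g_row)
        total += sum((a - b) ** 2 for a, b in zip(rp, rg))
        count += len(rp)
    if count == 0:
        raise ValueError("empty input")
    return total / count


if __name__ == "__main__":
    # Batch of two rows: the first row is ranked identically, the second is not
    print(batched_rank_mse([[0.1, 0.5, 0.3], [3.0, 1.0, 2.0]],
                           [[0.2, 0.9, 0.4], [1.0, 3.0, 2.0]]))  # 1.3333333333333333
    # A single 1-D row is handled as a batch of one
    print(batched_rank_mse([0.1, 0.5, 0.3], [0.2, 0.9, 0.4]))  # 0.0
```

Ranking per row (rather than over the flattened batch) is what keeps samples in a batch from being compared against each other, which is the same reason the proposal passes `-1` as the ranking dimension.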