Skip to content

Commit b4e3212

Browse files
committed
Fix json result processor
1 parent e5fe140 commit b4e3212

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

chebai/result/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _generate_predictions(self, data_path, raw=False, **kwargs):
5353
for x in self.dataset._load_dict(data_path)
5454
]
5555
else:
56-
data_tuples = torch.load(data_path)
56+
data_tuples = [(x.get("raw_features", x["ident"]), x["ident"], x) for x in torch.load(data_path)]
5757

5858
for raw_features, ident, row in tqdm.tqdm(data_tuples):
5959
raw_labels = row.get("labels")

chebai/result/prediction_json.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def close(self):
1616
json.dump(self.data, fout)
1717
del self.data
1818

19-
def process_prediction(self, proc_id, features, labels, pred, ident):
19+
def process_prediction(self, proc_id, raw_features, labels, preds, ident, **kwargs):
2020
self.data.append(
2121
dict(
2222
ident=ident,
2323
labels=labels if labels is not None else None,
24-
prediction=pred.tolist(),
24+
prediction=preds.tolist(),
2525
)
2626
)

0 commit comments

Comments
 (0)