diff --git a/hummingbird/ml/containers/_sklearn_api_containers.py b/hummingbird/ml/containers/_sklearn_api_containers.py index 16dc63a73..6eecca12e 100644 --- a/hummingbird/ml/containers/_sklearn_api_containers.py +++ b/hummingbird/ml/containers/_sklearn_api_containers.py @@ -64,7 +64,7 @@ def _run(self, function, *inputs): splits = [inputs[input_names[idx]] for idx in range(len(input_names))] inputs = [df.to_numpy().reshape(-1, 1) for df in splits] - return function(*inputs) + return function(inputs) class SklearnContainerTransformer(SklearnContainer): diff --git a/hummingbird/ml/containers/sklearn/pytorch_containers.py b/hummingbird/ml/containers/sklearn/pytorch_containers.py index 7745e5681..4b66c0b04 100644 --- a/hummingbird/ml/containers/sklearn/pytorch_containers.py +++ b/hummingbird/ml/containers/sklearn/pytorch_containers.py @@ -186,8 +186,6 @@ def _predict(self, *inputs): return output else: return output.ravel() - elif self._is_anomaly_detection: - return self.model.forward(*inputs)[0].cpu().numpy().ravel() else: return self.model.forward(*inputs)[0].cpu().numpy().ravel() @@ -282,7 +280,7 @@ class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassific def predict(self, *inputs): device = get_device(self.model) f = super(TorchScriptSklearnContainerClassification, self)._predict - f_wrapped = lambda x: _torchscript_wrapper(device, f, x, extra_config=self._extra_config) # noqa: E731 + f_wrapped = lambda x: _torchscript_wrapper(device, f, *x, extra_config=self._extra_config) # noqa: E731 return self._run(f_wrapped, *inputs)