From bd03e64d1cbef95ad8acb9717144f2e40cc61c6f Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Wed, 5 May 2021 09:43:24 -0700 Subject: [PATCH 1/2] play with * --- hummingbird/ml/containers/_sklearn_api_containers.py | 2 +- hummingbird/ml/containers/sklearn/pytorch_containers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hummingbird/ml/containers/_sklearn_api_containers.py b/hummingbird/ml/containers/_sklearn_api_containers.py index 16dc63a73..6eecca12e 100644 --- a/hummingbird/ml/containers/_sklearn_api_containers.py +++ b/hummingbird/ml/containers/_sklearn_api_containers.py @@ -64,7 +64,7 @@ def _run(self, function, *inputs): splits = [inputs[input_names[idx]] for idx in range(len(input_names))] inputs = [df.to_numpy().reshape(-1, 1) for df in splits] - return function(*inputs) + return function(inputs) class SklearnContainerTransformer(SklearnContainer): diff --git a/hummingbird/ml/containers/sklearn/pytorch_containers.py b/hummingbird/ml/containers/sklearn/pytorch_containers.py index 7745e5681..ab440e2b1 100644 --- a/hummingbird/ml/containers/sklearn/pytorch_containers.py +++ b/hummingbird/ml/containers/sklearn/pytorch_containers.py @@ -282,7 +282,7 @@ class TorchScriptSklearnContainerClassification(PyTorchSklearnContainerClassific def predict(self, *inputs): device = get_device(self.model) f = super(TorchScriptSklearnContainerClassification, self)._predict - f_wrapped = lambda x: _torchscript_wrapper(device, f, x, extra_config=self._extra_config) # noqa: E731 + f_wrapped = lambda x: _torchscript_wrapper(device, f, *x, extra_config=self._extra_config) # noqa: E731 return self._run(f_wrapped, *inputs) From cf7c525ee54b46c8480e89861494ff09f5e9fb5d Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Wed, 5 May 2021 09:46:59 -0700 Subject: [PATCH 2/2] remove unecesary code --- hummingbird/ml/containers/sklearn/pytorch_containers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/hummingbird/ml/containers/sklearn/pytorch_containers.py b/hummingbird/ml/containers/sklearn/pytorch_containers.py index ab440e2b1..4b66c0b04 100644 --- a/hummingbird/ml/containers/sklearn/pytorch_containers.py +++ b/hummingbird/ml/containers/sklearn/pytorch_containers.py @@ -186,8 +186,6 @@ def _predict(self, *inputs): return output else: return output.ravel() - elif self._is_anomaly_detection: - return self.model.forward(*inputs)[0].cpu().numpy().ravel() else: return self.model.forward(*inputs)[0].cpu().numpy().ravel()