This repository was archived by the owner on May 6, 2022. It is now read-only.
File tree Expand file tree Collapse file tree 3 files changed +39
-0
lines changed
Expand file tree Collapse file tree 3 files changed +39
-0
lines changed Original file line number Diff line number Diff line change 1919 pip install --upgrade pip pip-tools
2020 pip install -r requirements.txt
2121 pip install tensorflow
22+ pip install torch
2223 python -m spacy download en_core_web_sm
2324 - save_cache :
2425 key : build-{{checksum "requirements.txt"}}
Original file line number Diff line number Diff line change 1+ """Pipeline compatible abstraction for Pytorch jit models."""
2+
3+ import numpy as np
4+ import torch
5+
6+
7+ class PyTorchModel :
8+ """Pytorch JIT Model."""
9+
10+ def __init__ (self , model_path : str , device : str = "cpu" ) -> None :
11+ self .model = torch .jit .load (model_path , map_location = device )
12+ self .model .eval ()
13+ self .device = device
14+
15+ def __call__ (self , inputs : np .ndarray ) -> np .ndarray :
16+ with torch .no_grad ():
17+ inputs = torch .from_numpy (inputs ).to (self .device )
18+ out = self .model (inputs )
19+
20+ return out
Original file line number Diff line number Diff line change 1+ """
2+ Tests for Pytorch model base class
3+ """
4+ from unittest import mock
5+
6+ import numpy as np
7+
8+ from spokestack .models .pytorch import PyTorchModel
9+
10+
11+ @mock .patch ("spokestack.models.pytorch.torch" )
12+ def test_inputs (* args ):
13+ sample = np .random .rand (1 , 128 , 80 ).astype (np .float32 )
14+ model = PyTorchModel (model_path = "torch_model" )
15+
16+ output = model (sample )
17+
18+ assert output
You can’t perform that action at this time.
0 commit comments