Skip to content
This repository was archived by the owner on May 6, 2022. It is now read-only.

Commit 0491348

Browse files
will-ricespace-pope
authored andcommitted
Add Pytorch model abstraction
1 parent 754a884 commit 0491348

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ jobs:
1919
pip install --upgrade pip pip-tools
2020
pip install -r requirements.txt
2121
pip install tensorflow
22+
pip install torch
2223
python -m spacy download en_core_web_sm
2324
- save_cache:
2425
key: build-{{checksum "requirements.txt"}}

spokestack/models/pytorch.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
"""Pipeline compatible abstraction for Pytorch jit models."""
2+
3+
import numpy as np
4+
import torch
5+
6+
7+
class PyTorchModel:
8+
"""Pytorch JIT Model."""
9+
10+
def __init__(self, model_path: str, device: str = "cpu") -> None:
11+
self.model = torch.jit.load(model_path, map_location=device)
12+
self.model.eval()
13+
self.device = device
14+
15+
def __call__(self, inputs: np.ndarray) -> np.ndarray:
16+
with torch.no_grad():
17+
inputs = torch.from_numpy(inputs).to(self.device)
18+
out = self.model(inputs)
19+
20+
return out

tests/models/test_pytorch.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Tests for Pytorch model base class
3+
"""
4+
from unittest import mock
5+
6+
import numpy as np
7+
8+
from spokestack.models.pytorch import PyTorchModel
9+
10+
11+
@mock.patch("spokestack.models.pytorch.torch")
12+
def test_inputs(*args):
13+
sample = np.random.rand(1, 128, 80).astype(np.float32)
14+
model = PyTorchModel(model_path="torch_model")
15+
16+
output = model(sample)
17+
18+
assert output

0 commit comments

Comments
 (0)