6 changes: 3 additions & 3 deletions .github/workflows/test.yml
@@ -100,9 +100,9 @@ jobs:
           wget https://zenodo.org/records/4735647/files/resnet50_v1.onnx > /dev/null 2>&1
           IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/resnet_50_v1/run.py -m resnet50_v1.onnx -p fp32 -f ort

-          wget https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16/vgg16.tar.gz > /dev/null 2>&1
-          tar -xf vgg16.tar.gz > /dev/null
-          IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort
+          # wget https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16/vgg16.tar.gz > /dev/null 2>&1
+          # tar -xf vgg16.tar.gz > /dev/null
+          # IGNORE_DATASET_LIMITS=1 python3 computer_vision/classification/vgg_16/run.py -m vgg16/vgg16.onnx -p fp32 -f ort

   test_arm64:
     runs-on: self-hosted
@@ -43,6 +43,8 @@ def parse_args():
parser.add_argument("--squad_path",
type=str,
help="path to directory with ImageNet validation images")
parser.add_argument("--fixed_input_size", type=int,
help='size of the input')
parser.add_argument("--disable_jit_freeze", action='store_true',
help="if true model will be run not in jit freeze mode")
return parser.parse_args()
@@ -93,7 +95,7 @@ def run_tf_fp16(model_path, batch_size, num_runs, timeout, squad_path, **kwargs)
     return run_tf_fp(model_path, batch_size, num_runs, timeout, squad_path)


-def run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, disable_jit_freeze=False):
+def run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, fixed_input_size, disable_jit_freeze=False):
     from utils.benchmark import run_model
     from utils.nlp.squad import Squad_v1_1
     from transformers import AutoTokenizer, BertConfig, BertForQuestionAnswering
@@ -117,7 +119,11 @@ def run_single_pass(pytorch_runner, squad):
                                               padding=True, truncation=True, model_max_length=512)

     def tokenize(question, text):
-        return tokenizer(question, text, padding=True, truncation=True, return_tensors="pt")
+        if fixed_input_size is not None:
+            return tokenizer(question, text, padding="max_length", truncation=True,
+                             max_length=fixed_input_size, return_tensors="pt")
+        else:
+            return tokenizer(question, text, padding=True, truncation=True, return_tensors="pt")

     def detokenize(answer):
         return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer))
@@ -199,8 +205,9 @@ def detokenize(answer):
     return run_model(run_single_pass, runner, dataset, batch_size, num_runs, timeout)


-def run_pytorch_fp32(model_path, batch_size, num_runs, timeout, squad_path, disable_jit_freeze, **kwargs):
-    return run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, disable_jit_freeze)
+def run_pytorch_fp32(model_path, batch_size, num_runs, timeout, squad_path, fixed_input_size, disable_jit_freeze,
+                     **kwargs):
+    return run_pytorch_fp(model_path, batch_size, num_runs, timeout, squad_path, fixed_input_size, disable_jit_freeze)


def main():
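
Note (not part of the diff): a minimal sketch of what the new fixed-size path changes in tokenization, assuming a stock Hugging Face tokenizer and PyTorch installed; the model name and the value 384 below are illustrative and not taken from this PR, which loads the tokenizer from model_path instead.

from transformers import AutoTokenizer

# Illustrative checkpoint; the script in this PR loads the tokenizer from model_path.
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

question = "Who wrote Hamlet?"
text = "Hamlet is a tragedy written by William Shakespeare."

# Default path (fixed_input_size is None): pad only up to the longest sequence
# in the batch, so the input shape varies from batch to batch.
dynamic = tokenizer(question, text, padding=True, truncation=True, return_tensors="pt")

# Fixed path (--fixed_input_size given): pad/truncate every sample to the same
# length, so the input shape stays constant across batches.
fixed_input_size = 384  # assumed value supplied via the new --fixed_input_size flag
fixed = tokenizer(question, text, padding="max_length", truncation=True,
                  max_length=fixed_input_size, return_tensors="pt")

print(dynamic["input_ids"].shape)  # varies with the example text, e.g. torch.Size([1, 18])
print(fixed["input_ids"].shape)    # always torch.Size([1, 384])
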
3 changes: 2 additions & 1 deletion tests/test_pytorch_models.py
@@ -222,7 +222,8 @@ def wrapper(**kwargs):

         exact_match_ref, f1_ref = 0.750, 0.817
         acc = run_process(wrapper, {"model_path": self.model_path, "squad_path": self.dataset_path,
-                                    "batch_size": 1, "num_runs": 24, "timeout": None, "disable_jit_freeze": False})
+                                    "batch_size": 1, "num_runs": 24, "timeout": None,
+                                    "fixed_input_size": None, "disable_jit_freeze": False})
         self.assertTrue(acc["exact_match"] / exact_match_ref > 0.95)
         self.assertTrue(acc["f1"] / f1_ref > 0.95)
