# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
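"""Fill a [MASK] token in the input text with a BERT masked-language model,
compiled and executed through the MAX engine."""
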
# Suppress extraneous logging from the Hugging Face libraries. These
# environment variables must be set before `transformers` is imported.
import os

os.environ["TRANSFORMERS_VERBOSITY"] = "critical"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import platform
import signal
from argparse import ArgumentParser

import torch
from max import engine
from max.dtype import DType
from transformers import BertTokenizer

# Fixed batch size and sequence length that the TorchScript model expects.
BATCH = 1
SEQLEN = 128

DEFAULT_MODEL_PATH = "../../models/bert-mlm.torchscript"
DESCRIPTION = "BERT masked-language model inference"
HF_MODEL_NAME = "bert-base-uncased"


def execute(model_path, text, input_dict):
    session = engine.InferenceSession()

    # Compile the TorchScript model, declaring the shape and dtype of each
    # input tensor to the engine.
    input_spec_list = [
        engine.TorchInputSpec(shape=tensor.size(), dtype=DType.int64)
        for tensor in input_dict.values()
    ]
    model = session.load(model_path, input_specs=input_spec_list)

    tokenizer = BertTokenizer.from_pretrained(HF_MODEL_NAME)

    print("Processing input...")
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=SEQLEN,
    )
    print("Input processed.\n")

    # Locate the position of the [MASK] token in the tokenized input.
    masked_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(
        as_tuple=True
    )[1]

    # Run inference; the engine returns NumPy arrays keyed by output name.
    outputs = model.execute_legacy(**inputs)["result0"]
    logits = torch.from_numpy(outputs[0, masked_index, :])

    # Take the highest-scoring vocabulary token at the masked position and
    # decode it back to text.
    predicted_token_id = logits.argmax(dim=-1)
    predicted_tokens = tokenizer.decode(
        [predicted_token_id],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return predicted_tokens


def main():
    # Parse args
    parser = ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--text",
        type=str,
        metavar="<text>",
        required=True,
        help="Input text containing a [MASK] token to fill.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help="Path to the downloaded TorchScript model file.",
    )
    args = parser.parse_args()

    # Improves model compilation speed dramatically on Intel CPUs
    if "Intel" in platform.processor():
        os.environ["OMP_NUM_THREADS"] = "1"
        os.environ["MKL_NUM_THREADS"] = "1"

    signal.signal(signal.SIGINT, signal.SIG_DFL)
    torch.set_default_device("cpu")

    # Placeholder tensors that fix the input shapes for model compilation;
    # the real token ids are produced by the tokenizer inside execute().
    input_dict = {
        "input_ids": torch.zeros((BATCH, SEQLEN), dtype=torch.int64),
        "attention_mask": torch.zeros((BATCH, SEQLEN), dtype=torch.int64),
        "token_type_ids": torch.zeros((BATCH, SEQLEN), dtype=torch.int64),
    }
    outputs = execute(args.model_path, args.text, input_dict)

    # Print the prediction for the masked token
    print(f"input text: {args.text}")
    print(f"filled mask: {args.text.replace('[MASK]', outputs)}")


if __name__ == "__main__":
    main()
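

# Example invocation (a sketch; it assumes the TorchScript model has already
# been exported to DEFAULT_MODEL_PATH, and the exact predicted word depends
# on the model weights):
#
#   python3 simple-inference.py --text "Paris is the [MASK] of France."
#
# which should print something like:
#
#   input text: Paris is the [MASK] of France.
#   filled mask: Paris is the capital of France.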