Commit 8df0d99

Merge remote-tracking branch 'origin/main'

2 parents: 776b1a4 + e952056

File tree: 2 files changed, +67 -1 lines changed

scripts/infernce.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
import click
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample, ShortSideScale
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import NormalizeVideo, CenterCropVideo

from tubevit.model import TubeViTLightningModule


@click.command()
@click.argument("video-path")
@click.option("-m", "--model-path", type=click.Path(exists=True), required=True, help="path to model weight.")
@click.option("--label-path", type=click.Path(exists=True), required=True, help="path to classInd.txt.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="frames per clip.")
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="video frame size.")
def main(
    video_path,
    model_path,
    label_path,
    frames_per_clip,
    video_size,
):
    # Read class names from classInd.txt (one "<index> <name>" per line) and keep only the names.
    with open(label_path, "r") as f:
        labels = f.read().splitlines()
    labels = list(map(lambda x: x.split(" ")[-1], labels))

    # Compose video data transforms
    transform = ApplyTransformToKey(
        key="video",
        transform=Compose(
            [
                UniformTemporalSubsample(frames_per_clip),
                Lambda(lambda x: x / 255.0),
                NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ShortSideScale(size=video_size[0]),
                CenterCropVideo(crop_size=video_size),
            ]
        ),
    )

    # Load video
    video = EncodedVideo.from_path(video_path)

    # Get clips: cut up to ten consecutive 2-second clips and transform each one.
    clip_start_sec = 0.0  # secs
    clip_duration = 2.0  # secs
    duration = video.duration
    video_data = []
    for i in range(10):
        if clip_start_sec + clip_duration * (i + 1) <= duration:
            data = video.get_clip(start_sec=clip_start_sec + clip_duration * i,
                                  end_sec=clip_start_sec + clip_duration * (i + 1))
            data = transform(data)
            video_data.append(data['video'])

    # Stack clips into a batch, run the model, and pick the class with the highest summed probability.
    video_data = torch.stack(video_data)
    model = TubeViTLightningModule.load_from_checkpoint(model_path)
    prediction = model.predict_step(batch=(video_data, None), batch_idx=0)
    print(video_data.shape)
    print('Predict:', labels[torch.argmax(torch.sum(prediction['y_prob'], dim=0)).to('cpu').item()])


if __name__ == "__main__":
    main()
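
The script classifies the whole video by aggregating per-clip predictions: prediction['y_prob'] is expected to have shape (num_clips, num_classes), so summing over dim 0 and taking the argmax picks the class with the largest total probability across the sampled clips. A minimal sketch of that aggregation step, using made-up probabilities and hypothetical UCF101 class names (not output from this model):

import torch

# Toy per-clip class probabilities: 3 clips over 4 classes (made-up numbers).
y_prob = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                       [0.2, 0.5, 0.2, 0.1],
                       [0.3, 0.3, 0.3, 0.1]])
labels = ['ApplyEyeMakeup', 'Archery', 'Basketball', 'Biking']  # hypothetical classInd.txt names

# Sum each class's probability over clips, then take the class with the largest total.
video_level = torch.argmax(torch.sum(y_prob, dim=0)).item()
print('Predict:', labels[video_level])  # Archery (column sum 1.4 is the largest)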

tubevit/model.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def _calc_conv_shape(self, kernel_size, stride, offset) -> np.ndarray:
         kernel_size = np.array(kernel_size)
         stride = np.array(stride)
         offset = np.array(offset)
-        output = np.ceil((self.video_shape[[1, 2, 3]] - offset - kernel_size + 1) / stride).astype(int)
+        output = np.floor(((self.video_shape[[1, 2, 3]] - offset - kernel_size) / stride) + 1).astype(int)
         return output

     def _generate_position_embedding(self) -> torch.nn.Parameter:
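
The new expression matches the conventional convolution output-size formula floor((L - K) / S) + 1, applied per dimension of video_shape[[1, 2, 3]] with each extent reduced by the tube offset. For non-negative integer inputs the two forms agree (ceil((m + 1) / s) equals floor(m / s) + 1), so the rewrite reads as aligning the code with the standard formulation rather than changing results. A quick sanity check against torch.nn.Conv3d, using hypothetical video and tube-kernel sizes rather than values taken from this repo:

import numpy as np
import torch

# Hypothetical (T, H, W) video shape and tube kernel/stride/offset -- not values from this repo.
thw = np.array([32, 224, 224])
kernel_size = np.array([8, 8, 8])
stride = np.array([16, 32, 32])
offset = np.array([0, 0, 0])

old = np.ceil((thw - offset - kernel_size + 1) / stride).astype(int)
new = np.floor(((thw - offset - kernel_size) / stride) + 1).astype(int)

# Reference: let Conv3d report its own output shape for the same kernel and stride.
conv = torch.nn.Conv3d(3, 1, kernel_size=tuple(map(int, kernel_size)), stride=tuple(map(int, stride)))
ref = conv(torch.zeros(1, 3, *map(int, thw))).shape[2:]

print(old, new, tuple(ref))  # all three give (2, 7, 7)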

0 commit comments