import click
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
from torchvision.transforms import Compose, Lambda

# NOTE: _transforms_video is a private torchvision module; newer torchvision
# releases warn that it is deprecated and may remove it.
from torchvision.transforms._transforms_video import CenterCropVideo, NormalizeVideo

from tubevit.model import TubeViTLightningModule


@click.command()
@click.argument("video-path", type=click.Path(exists=True))
@click.option("-m", "--model-path", type=click.Path(exists=True), required=True, help="Path to the model weights.")
@click.option("--label-path", type=click.Path(exists=True), required=True, help="Path to classInd.txt.")
@click.option("-f", "--frames-per-clip", type=int, default=32, help="Number of frames per clip.")
@click.option("-v", "--video-size", type=click.Tuple([int, int]), default=(224, 224), help="Frame size (height, width) after scaling and cropping.")
def main(
    video_path,
    model_path,
    label_path,
    frames_per_clip,
    video_size,
):
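    """Classify a single video with a trained TubeViT checkpoint.

    Example invocation (the script name and file paths here are illustrative):

        python predict.py video.avi -m tubevit_ucf101.ckpt --label-path classInd.txt
    """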
    # Each line of classInd.txt looks like "1 ApplyEyeMakeup"; keep the name only.
    with open(label_path, "r") as f:
        labels = f.read().splitlines()
    labels = [line.split(" ")[-1] for line in labels]

    # Compose video data transforms: subsample to frames_per_clip frames,
    # rescale pixel values to [0, 1], normalize with ImageNet statistics,
    # then scale the short side and center-crop to video_size.
    transform = ApplyTransformToKey(
        key="video",
        transform=Compose(
            [
                UniformTemporalSubsample(frames_per_clip),
                Lambda(lambda x: x / 255.0),
                NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ShortSideScale(size=video_size[0]),
                CenterCropVideo(crop_size=video_size),
            ]
        ),
    )
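    # After these transforms each clip tensor has shape (C, T, H, W),
    # i.e. (3, frames_per_clip, *video_size) for RGB input.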

    # Load the video and sample up to ten consecutive, non-overlapping clips.
    video = EncodedVideo.from_path(video_path)
    clip_start_sec = 0.0  # secs
    clip_duration = 2.0  # secs
    duration = video.duration
    video_data = []
    for i in range(10):
        if clip_start_sec + clip_duration * (i + 1) <= duration:
            data = video.get_clip(
                start_sec=clip_start_sec + clip_duration * i,
                end_sec=clip_start_sec + clip_duration * (i + 1),
            )
            data = transform(data)
            video_data.append(data["video"])
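
    # Stacking the transformed clips yields a batch of shape
    # (num_clips, C, T, H, W); the model scores each clip independently.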
    if not video_data:
        raise click.ClickException("Video is shorter than one clip; nothing to predict.")
    video_data = torch.stack(video_data)

    model = TubeViTLightningModule.load_from_checkpoint(model_path)
    model.eval()
    with torch.no_grad():
        prediction = model.predict_step(batch=(video_data.to(model.device), None), batch_idx=0)

    print(video_data.shape)
    # Sum the per-clip class probabilities and report the most likely class.
    print("Predict:", labels[torch.argmax(torch.sum(prediction["y_prob"], dim=0)).item()])


if __name__ == "__main__":
    main()