Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ This repository contains the demo for the audio-to-video synchronisation network
Please cite the paper below if you make use of the software.

## Dependencies

```
pip install -r requirements.txt
conda env create -f environment.yml
```

In addition, `ffmpeg` must be installed and available on your `PATH`, since the scripts invoke it as a subprocess.

## Getting Started

Download the pretrained model:
```
sh download_model.sh
```

## Demo

Expand All @@ -21,16 +27,17 @@ SyncNet demo:
python demo_syncnet.py --videofile data/example.avi --tmp_dir /path/to/temp/directory
```

Check that this script returns:
Check that this script returns approximately the following values (minor differences are expected depending on your platform and package versions):
```
AV offset: 3
Min dist: 5.353
Confidence: 10.021
```

Full pipeline:
## Full Pipeline

Run the three stages — face detection and tracking, sync offset estimation, and visualisation:
```
sh download_model.sh
python run_pipeline.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output
python run_syncnet.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output
python run_visualise.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output
Expand All @@ -39,7 +46,6 @@ python run_visualise.py --videofile /path/to/video.mp4 --reference name_of_video
Outputs:
```
$DATA_DIR/pycrop/$REFERENCE/*.avi - cropped face tracks
$DATA_DIR/pywork/$REFERENCE/offsets.txt - audio-video offset values
$DATA_DIR/pyavi/$REFERENCE/video_out.avi - output video (as shown below)
```
<p align="center">
Expand Down
89 changes: 49 additions & 40 deletions SyncNetInstance.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/python
#!/usr/bin/env python3
#-*- coding: utf-8 -*-
# Video 25 FPS, Audio 16000HZ

import torch
import numpy
import time, pdb, argparse, subprocess, os, math, glob
import time, pdb, argparse, subprocess, os, math, glob, logging
import cv2
import python_speech_features

Expand All @@ -13,11 +13,13 @@
from SyncNetModel import *
from shutil import rmtree

logger = logging.getLogger(__name__)


# ==================== Get OFFSET ====================

def calc_pdist(feat1, feat2, vshift=10):

win_size = vshift*2+1

feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))
Expand All @@ -34,14 +36,16 @@ def calc_pdist(feat1, feat2, vshift=10):

class SyncNetInstance(torch.nn.Module):

def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024):
super(SyncNetInstance, self).__init__();
def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024, device=None):
super().__init__()

self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda();
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
logger.info('Using device: %s', self.device)
self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).to(self.device)

def evaluate(self, opt, videofile):

self.__S__.eval();
self.__S__.eval()

# ========== ==========
# Convert files
Expand All @@ -52,18 +56,21 @@ def evaluate(self, opt, videofile):

os.makedirs(os.path.join(opt.tmp_dir,opt.reference))

command = ("ffmpeg -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg')))
output = subprocess.call(command, shell=True, stdout=None)
command = ["ffmpeg", "-y", "-i", videofile, "-threads", "1", "-f", "image2",
os.path.join(opt.tmp_dir, opt.reference, '%06d.jpg')]
subprocess.run(command, check=True)

command = ["ffmpeg", "-y", "-i", videofile, "-async", "1", "-ac", "1", "-vn",
"-acodec", "pcm_s16le", "-ar", "16000",
os.path.join(opt.tmp_dir, opt.reference, 'audio.wav')]
subprocess.run(command, check=True)

command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav')))
output = subprocess.call(command, shell=True, stdout=None)

# ========== ==========
# Load video
# Load video
# ========== ==========

images = []

flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg'))
flist.sort()

Expand All @@ -74,7 +81,7 @@ def evaluate(self, opt, videofile):
im = numpy.expand_dims(im,axis=0)
im = numpy.transpose(im,(0,3,4,1,2))

imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
imtv = torch.from_numpy(im.astype(float)).float()

# ========== ==========
# Load audio
Expand All @@ -85,17 +92,17 @@ def evaluate(self, opt, videofile):
mfcc = numpy.stack([numpy.array(i) for i in mfcc])

cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float())
cct = torch.from_numpy(cc.astype(float)).float()

# ========== ==========
# Check audio and video input length
# ========== ==========

if (float(len(audio))/16000) != (float(len(images))/25) :
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Comparing floating-point values for exact equality is generally unreliable due to precision issues. It is safer to use math.isclose with an appropriate tolerance, especially when comparing durations derived from different sources (audio samples vs. video frames).

Suggested change
if (float(len(audio))/16000) != (float(len(images))/25) :
if not math.isclose(len(audio) / 16000.0, len(images) / 25.0, abs_tol=1e-4):

print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))
logger.warning("Audio (%.4fs) and video (%.4fs) lengths are different.",float(len(audio))/16000,float(len(images))/25)

min_length = min(len(images),math.floor(len(audio)/640))

# ========== ==========
# Generate video and audio feats
# ========== ==========
Expand All @@ -106,15 +113,15 @@ def evaluate(self, opt, videofile):

tS = time.time()
for i in range(0,lastframe,opt.batch_size):

im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
im_in = torch.cat(im_batch,0)
im_out = self.__S__.forward_lip(im_in.cuda());
im_out = self.__S__.forward_lip(im_in.to(self.device))
im_feat.append(im_out.data.cpu())

cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
cc_in = torch.cat(cc_batch,0)
cc_out = self.__S__.forward_aud(cc_in.cuda())
cc_out = self.__S__.forward_aud(cc_in.to(self.device))
cc_feat.append(cc_out.data.cpu())

im_feat = torch.cat(im_feat,0)
Expand All @@ -123,8 +130,8 @@ def evaluate(self, opt, videofile):
# ========== ==========
# Compute offset
# ========== ==========
print('Compute time %.3f sec.' % (time.time()-tS))

logger.info('Compute time %.3f sec.', time.time()-tS)

dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift)
mdist = torch.mean(torch.stack(dists,1),1)
Expand All @@ -138,25 +145,27 @@ def evaluate(self, opt, videofile):
# fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
fconf = torch.median(mdist).numpy() - fdist
fconfm = signal.medfilt(fconf,kernel_size=9)

numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
print('Framewise conf: ')
print(fconfm)
print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf))
logger.info('Framewise conf: ')
logger.info(fconfm)
logger.info('AV offset: \t%d', offset.item())
logger.info('Min dist: \t%.3f', minval.item())
logger.info('Confidence: \t%.3f', conf.item())

dists_npy = numpy.array([ dist.numpy() for dist in dists ])
return offset.numpy(), conf.numpy(), dists_npy

def extract_feature(self, opt, videofile):

self.__S__.eval();
self.__S__.eval()

# ========== ==========
# Load video
# Load video
# ========== ==========
cap = cv2.VideoCapture(videofile)

frame_num = 1;
frame_num = 1
images = []
while frame_num:
frame_num += 1
Expand All @@ -170,8 +179,8 @@ def extract_feature(self, opt, videofile):
im = numpy.expand_dims(im,axis=0)
im = numpy.transpose(im,(0,3,4,1,2))

imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float())
imtv = torch.from_numpy(im.astype(float)).float()

# ========== ==========
# Generate video feats
# ========== ==========
Expand All @@ -181,28 +190,28 @@ def extract_feature(self, opt, videofile):

tS = time.time()
for i in range(0,lastframe,opt.batch_size):

im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ]
im_in = torch.cat(im_batch,0)
im_out = self.__S__.forward_lipfeat(im_in.cuda());
im_out = self.__S__.forward_lipfeat(im_in.to(self.device))
im_feat.append(im_out.data.cpu())

im_feat = torch.cat(im_feat,0)

# ========== ==========
# Compute offset
# ========== ==========
print('Compute time %.3f sec.' % (time.time()-tS))

logger.info('Compute time %.3f sec.', time.time()-tS)

return im_feat


def loadParameters(self, path):
loaded_state = torch.load(path, map_location=lambda storage, loc: storage);
loaded_state = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Setting weights_only=True may break loading for existing pretrained models if they were saved as full model objects (pickles) rather than state dictionaries. Given that SyncNetModel.save defaults to saving the full model, it is likely that many .model files in this ecosystem are full pickles. If you want to ensure backward compatibility with such files, set weights_only=False.

Suggested change
loaded_state = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)
loaded_state = torch.load(path, map_location=lambda storage, loc: storage, weights_only=False)


self_state = self.__S__.state_dict();
self_state = self.__S__.state_dict()

for name, param in loaded_state.items():

self_state[name].copy_(param);
self_state[name].copy_(param)
51 changes: 21 additions & 30 deletions SyncNetModel.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,16 @@
#!/usr/bin/python
#!/usr/bin/env python3
#-*- coding: utf-8 -*-

import torch
import torch.nn as nn

def save(model, filename):
with open(filename, "wb") as f:
torch.save(model, f);
print("%s saved."%filename);

def load(filename):
net = torch.load(filename)
return net;

class S(nn.Module):
def __init__(self, num_layers_in_fc_layers = 1024):
super(S, self).__init__();
super().__init__()

self.__nFeatures__ = 24;
self.__nChs__ = 32;
self.__midChs__ = 32;
self.__nFeatures__ = 24
self.__nChs__ = 32
self.__midChs__ = 32

self.netcnnaud = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
Expand All @@ -44,25 +35,25 @@ def __init__(self, num_layers_in_fc_layers = 1024):
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),

nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)),
nn.BatchNorm2d(512),
nn.ReLU(),
);
)

self.netfcaud = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, num_layers_in_fc_layers),
);
)

self.netfclip = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, num_layers_in_fc_layers),
);
)

self.netcnnlip = nn.Sequential(
nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0),
Expand Down Expand Up @@ -91,27 +82,27 @@ def __init__(self, num_layers_in_fc_layers = 1024):
nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0),
nn.BatchNorm3d(512),
nn.ReLU(inplace=True),
);
)

def forward_aud(self, x):

mid = self.netcnnaud(x); # N x ch x 24 x M
mid = mid.view((mid.size()[0], -1)); # N x (ch x 24)
out = self.netfcaud(mid);
mid = self.netcnnaud(x) # N x ch x 24 x M
mid = mid.view((mid.size(0), -1)) # N x (ch x 24)
out = self.netfcaud(mid)

return out;
return out

def forward_lip(self, x):

mid = self.netcnnlip(x);
mid = mid.view((mid.size()[0], -1)); # N x (ch x 24)
out = self.netfclip(mid);
mid = self.netcnnlip(x)
mid = mid.view((mid.size(0), -1)) # N x (ch x 24)
out = self.netfclip(mid)

return out;
return out

def forward_lipfeat(self, x):

mid = self.netcnnlip(x);
out = mid.view((mid.size()[0], -1)); # N x (ch x 24)
mid = self.netcnnlip(x)
out = mid.view((mid.size(0), -1)) # N x (ch x 24)

return out;
return out
Loading