32 changes: 32 additions & 0 deletions Dockerfile
@@ -0,0 +1,32 @@
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

ENV DEBIAN_FRONTEND=noninteractive

# Install Python and the system libraries needed by ffmpeg/OpenCV
RUN apt-get update && apt-get install -y \
python3 \
python3-pip \
ffmpeg \
libsm6 \
libxext6 \
git \
wget \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements
COPY requirements.txt /app/

# Install PyTorch
RUN pip3 install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

# Install other requirements
RUN pip3 install -r requirements.txt

# Copy application
COPY . /app/

# Expose Gradio port
EXPOSE 7860

CMD ["python3", "gradio_app.py", "--server_name", "0.0.0.0", "--port", "7860"]
8 changes: 6 additions & 2 deletions download_checkpoints.sh
@@ -1,5 +1,9 @@
pip install gdown

gdown --id 1rvWuM12cyvNvBQNCLmG4Fr2L1rpjQBF0

mv float.pth checkpoints/
hf download r-f/wav2vec-english-speech-emotion-recognition \
--local-dir ./checkpoints/wav2vec-english-speech-emotion-recognition \
--include "*"
hf download facebook/wav2vec2-base-960h \
--local-dir ./checkpoints/wav2vec2-base-960h \
--include "*"
10 changes: 7 additions & 3 deletions generate.py
@@ -1,7 +1,7 @@
"""
Inference Stage 2
"""

import os
import os, torch, random, cv2, torchvision, subprocess, librosa, datetime, tempfile, face_alignment
import numpy as np
import albumentations as A
@@ -25,7 +25,10 @@ def __init__(self, opt):
self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)

# wav2vec2 audio preprocessor
self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True)
if os.path.exists(opt.wav2vec_model_path):
    self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path, local_files_only=True)
else:
    self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(opt.wav2vec_model_path)

# image transform
self.transform = A.Compose([
@@ -173,9 +176,10 @@ def initialize(self, parser):
parser.add_argument('--res_video_path',
default=None, type=str, help='res video path')
parser.add_argument('--ckpt_path',
default="/home/nvadmin/workspace/taek/float-pytorch/checkpoints/float.pth", type=str, help='checkpoint path')
default="./checkpoints/float.pth", type=str, help='checkpoint path')
parser.add_argument('--res_dir',
default="./results", type=str, help='result dir')

return parser
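
A hedged invocation sketch using only the flags visible in this hunk; the full parser presumably also defines the reference image and driving audio arguments, which are omitted here:

python3 generate.py \
    --ckpt_path ./checkpoints/float.pth \
    --res_dir ./results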

