Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__/
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
<img src="https://github.com/senguptaumd/SfSNet/blob/gh-pages/resources/Teaser1.png" width="500px" >

### Overview
- (0) Test script: `test_SfSNet.m`
- (0) Test script: `test_SfSNet.m`, `test_SfSNet.py`
- (1) Test images along with mask: Images_mask
- (2) Test images without mask: Images

Run 'test_SfSNet' on Matlab to run SfSNet on the supplied test images.
Run 'test_SfSNet' on Matlab or 'test_SfSNet.py' in Python to run SfSNet on the supplied test images.

### Dependencies ###
This code requires a working installation of [Caffe](http://caffe.berkeleyvision.org/) and Matlab interface for Caffe. For guidelines and help with installation of Caffe, consult the [installation guide](http://caffe.berkeleyvision.org/) and [Caffe users group](https://groups.google.com/forum/#!forum/caffe-users).
This code requires a working installation of [Caffe](http://caffe.berkeleyvision.org/) and Matlab interface for Caffe or Python interface for Caffe. For guidelines and help with installation of Caffe, consult the [installation guide](http://caffe.berkeleyvision.org/) and [Caffe users group](https://groups.google.com/forum/#!forum/caffe-users).

Please set the variable `PATH_TO_CAFFE_MATLAB`, in line 3 of `test_SfSNet.m` as `$PATH_TO_CAFFE/matlab` (path to matlab folder for the caffe installation)

Expand Down
105 changes: 105 additions & 0 deletions spherical_harmonics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np

def normalize(arr, axis=0):
"""
Normalize an array along a certain axis
"""
return arr/(np.expand_dims(np.linalg.norm(arr,axis=axis),axis=axis)+1e-10)

def get_sphere_normals(image_size=512, sphere_radius=224):
"""
Return the normals for a 2D image of a sphere
"""
center_x, center_y = image_size//2, image_size//2
sphere_normals = np.zeros((image_size,image_size,3))
for i in range(image_size):
for j in range(image_size):
x = (i-center_x)/sphere_radius
y = (j-center_y)/sphere_radius
discriminant = 1-x*x-y*y
if discriminant > 0:
z = np.sqrt(discriminant)
sphere_normals[i, j, 0] = x
sphere_normals[i, j, 1] = y
sphere_normals[i, j, 2] = z
else:
sphere_normals[i, j, :] = 0
return sphere_normals

def get_sphere_mask(image_size=512, sphere_radius=224):
"""
Return the mask for a 2D image of a sphere
"""
center_x, center_y = image_size//2, image_size//2
sphere_mask = np.zeros((image_size,image_size,3))
for i in range(image_size):
for j in range(image_size):
x = (i-center_x)/sphere_radius
y = (j-center_y)/sphere_radius
discriminant = 1-x*x-y*y
if discriminant > 0:
sphere_mask[i, j, :] = 1
else:
sphere_mask[i, j, :] = 0
return sphere_mask

def compute_shading(lights, normals=None, mode='sfsnet'):
"""
Adapted from SfSNet/SfSNet_train/python/Shading_Layer.py

Assumes lights is N x 27 and normals is N x H x W x 3
"""
if normals is None:
# render SH on sphere
normals = get_sphere_normals()
normals = np.tile(normals, (lights.shape[0],1,1,1))

normals = normalize(normals, axis=3)
normals = normals.transpose((0,3,1,2))

shading = np.zeros(normals.shape)

sz = normals.shape

att = np.pi*np.array([1, 2.0/3, 0.25])

c1 = att[0]*(1.0/np.sqrt(4*np.pi)) # 1 * 0.282095 = 0.282095
c2 = att[1]*(np.sqrt(3.0/(4*np.pi))) # 2/3 * 0.488602 = 0.325735
c3 = att[2]*0.5*(np.sqrt(5.0/(4*np.pi))) # 1/4 * 0.315392 = 0.078848
c4 = att[2]*(3.0*(np.sqrt(5.0/(12*np.pi)))) # 1/4 * 1.092548 = 0.273137
c5 = att[2]*(3.0*(np.sqrt(5.0/(48*np.pi)))) # 1/4 * 0.546274 = 0.136568 = c4/2.0

for i in range(0, sz[0]):
nx = normals[i,0,...]
ny = normals[i,1,...]
nz = normals[i,2,...]

if mode == 'sfsnet':
# SH representation used by SfSNet
H1 = c1*np.ones((sz[2],sz[3]))
H2 = c2*nz
H3 = c2*nx
H4 = c2*ny
H5 = c3*(2*nz*nz - nx*nx -ny*ny)
H6 = c4*nx*nz
H7 = c4*ny*nz
H8 = c5*(nx*nx - ny*ny)
H9 = c4*nx*ny
else:
# SH representation used by LDAN, DirectX SH, Google SH, etc.
H1 = c1*np.ones((sz[2],sz[3]))
H2 = -c2*ny
H3 = c2*nz
H4 = -c2*nx
H5 = c4*nx*ny
H6 = -c4*ny*nz
H7 = c3*(2*nz*nz - nx*nx -ny*ny) # equivalent to c3*(3*nz*nz - 1)
H8 = -c4*nx*nz
H9 = c5*(nx*nx - ny*ny)

for j in range(0,3) :
L=lights[i,j*9:(j+1)*9]
shading[i,j,:,:]=L[0]*H1+L[1]*H2+L[2]*H3+L[3]*H4+L[4]*H5+L[5]*H6+L[6]*H7+L[7]*H8+L[8]*H9

shading = shading.transpose((0,2,3,1))
return shading
113 changes: 113 additions & 0 deletions test_SfSNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
plt.ion()
import cv2
import caffe
import spherical_harmonics

# caffe.set_mode_cpu();
caffe.set_device(0);
caffe.set_mode_gpu();

model = 'SfSNet_deploy.prototxt'
weights = 'SfSNet.caffemodel.h5'
net = caffe.Net(model, weights, caffe.TEST)

# Choose Dataset
dat_idx = input('Please enter 1 for images with masks and 0 for images without mask: ')
dat_idx = int(dat_idx)
if dat_idx == 1:
# Images and masks are provided
list_im = os.listdir('Images_mask/');
list_im = list(filter(lambda f: f.endswith('_face.png'), list_im))
list_im.sort()
elif dat_idx == 0:
# No mask provided (Need to use your own mask).
list_im = os.listdir('Images/');
list_im = list(filter(lambda f: f.endswith('.png'), list_im))
list_im.sort()
else:
print('Wrong Option!');
sys.exit(1)

fig, axs = plt.subplots(2, 3, figsize=(24,16))
for i in range(2):
for j in range(3):
axs[i][j].axis('off')
axs[0][0].set_title('Image')
axs[0][1].set_title('Normal')
axs[0][2].set_title('Albedo')
axs[1][1].set_title('Shading')
axs[1][2].set_title('Recon')

M = 128; # size of input for SfSNet
for i in range(len(list_im)):
if dat_idx == 1:
im = cv2.imread(os.path.join('Images_mask/', list_im[i]));
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

mask = cv2.imread(os.path.join('Images_mask/', list_im[i].replace('face', 'mask')));
mask = cv2.resize(mask, (M, M));
mask = mask/255;
else:
im = cv2.imread(os.path.join('Images/', list_im[i]));
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# Prepare images
im = cv2.resize(im, (M, M))
im = im.astype(np.float32)/255
im_data = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
im_data = im_data.reshape((1,) + im_data.shape)
im_data = im_data.transpose((0,3,1,2))

# Pass images
net.blobs['data'].data[...] = im_data
output = net.forward();
al_out = output['Aconv0'].transpose((0,2,3,1))[0]
n_out = output['Nconv0'].transpose((0,2,3,1))[0]
light_out = output['fc_light'][0]

# light_out is a 27 dimensional vector. 9 dimension for each channel of
# RGB. For every 9 dimensional, 1st dimension is ambient illumination
# (0th order), next 3 dimension is directional (1st order), next 5
# dimension is 2nd order approximation. You can simply use 27
# dimensional feature vector as lighting representation.

# Transform
n_out2 = cv2.cvtColor(n_out, cv2.COLOR_BGR2RGB)
n_out2 = 2*n_out2-1; # [-1 1]
n_out2 = n_out2/(np.expand_dims(np.linalg.norm(n_out2,axis=2),axis=2)+1e-10)

al_out2 = cv2.cvtColor(al_out, cv2.COLOR_BGR2RGB)

# Note: n_out2, al_out2, light_out is the actual output

# Create reconstruction and shading image
light_out_data = light_out.reshape((1,) + light_out.shape)
n_out2_data = n_out2.reshape((1,) + n_out2.shape)
Ishd = spherical_harmonics.compute_shading(light_out_data, n_out2_data)[0]
Irec = Ishd*al_out2

# Visualize light_out on a sphere (not included in test_SfSNet.m)
sphere_normals = spherical_harmonics.get_sphere_normals()
sphere_normals_data = sphere_normals.reshape((1,) + sphere_normals.shape)
Ilight = spherical_harmonics.compute_shading(light_out_data, sphere_normals_data)[0]

if dat_idx == 1:
axs[0][0].imshow(mask*im)
axs[0][1].imshow(mask*((1+n_out2)/2).clip(0,1))
axs[0][2].imshow(mask*al_out2.clip(0,1))
axs[1][0].imshow(Ilight.clip(0,1))
axs[1][1].imshow(mask*200/255*Ishd.clip(0,1))
axs[1][2].imshow(mask*Irec.clip(0,1))
else:
axs[0][0].imshow(im)
axs[0][1].imshow(((1+n_out2)/2).clip(0,1))
axs[0][2].imshow(al_out2.clip(0,1))
axs[1][0].imshow(Ilight.clip(0,1))
axs[1][1].imshow(200/255*Ishd.clip(0,1))
axs[1][2].imshow(Irec.clip(0,1))

input('Press Enter to Continue')