diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/README.md b/README.md index 873a5c8..50929ec 100644 --- a/README.md +++ b/README.md @@ -5,14 +5,14 @@ ### Overview - - (0) Test script: `test_SfSNet.m` + - (0) Test script: `test_SfSNet.m`, `test_SfSNet.py` - (1) Test images along with mask: Images_mask - (2) Test images without mask: Images -Run 'test_SfSNet' on Matlab to run SfSNet on the supplied test images. +Run 'test_SfSNet' on Matlab or 'test_SfSNet.py' in Python to run SfSNet on the supplied test images. ### Dependencies ### -This code requires a working installation of [Caffe](http://caffe.berkeleyvision.org/) and Matlab interface for Caffe. For guidelines and help with installation of Caffe, consult the [installation guide](http://caffe.berkeleyvision.org/) and [Caffe users group](https://groups.google.com/forum/#!forum/caffe-users). +This code requires a working installation of [Caffe](http://caffe.berkeleyvision.org/) and Matlab interface for Caffe or Python interface for Caffe. For guidelines and help with installation of Caffe, consult the [installation guide](http://caffe.berkeleyvision.org/) and [Caffe users group](https://groups.google.com/forum/#!forum/caffe-users). 
import numpy as np


def normalize(arr, axis=0):
    """
    Normalize an array to unit length along `axis`.

    A small epsilon guards against division by zero, so all-zero vectors
    simply stay (near) zero instead of producing NaNs.
    """
    norms = np.expand_dims(np.linalg.norm(arr, axis=axis), axis=axis)
    return arr / (norms + 1e-10)


def _sphere_grid(image_size, sphere_radius):
    """
    Shared helper: coordinate grids for a sphere of `sphere_radius` pixels
    centered in an `image_size` x `image_size` image.

    Returns (x, y, disc, inside) where `x` varies along axis 0 and `y`
    along axis 1 (matching the original row/column loops), `disc` is
    1 - x^2 - y^2 and `inside` marks pixels strictly inside the disk.
    """
    center = image_size // 2
    u = (np.arange(image_size) - center) / sphere_radius
    x, y = np.meshgrid(u, u, indexing='ij')
    disc = 1.0 - x * x - y * y
    return x, y, disc, disc > 0


def get_sphere_normals(image_size=512, sphere_radius=224):
    """
    Return the (image_size, image_size, 3) unit surface normals of a
    front-facing sphere rendered into a 2D image; pixels outside the
    projected disk get the zero vector.

    Vectorized replacement for the original per-pixel double loop —
    identical output, dramatically faster at 512x512.
    """
    x, y, disc, inside = _sphere_grid(image_size, sphere_radius)
    normals = np.zeros((image_size, image_size, 3))
    normals[..., 0] = np.where(inside, x, 0.0)
    normals[..., 1] = np.where(inside, y, 0.0)
    # clip keeps sqrt away from the (excluded) negative region outside the disk
    normals[..., 2] = np.where(inside, np.sqrt(np.clip(disc, 0.0, None)), 0.0)
    return normals


def get_sphere_mask(image_size=512, sphere_radius=224):
    """
    Return an (image_size, image_size, 3) float mask that is 1 inside the
    projected sphere and 0 outside (same disk as `get_sphere_normals`).
    """
    _, _, _, inside = _sphere_grid(image_size, sphere_radius)
    return np.repeat(inside[:, :, np.newaxis].astype(float), 3, axis=2)


def compute_shading(lights, normals=None, mode='sfsnet'):
    """
    Render per-pixel shading from 2nd-order spherical-harmonic lighting.

    Adapted from SfSNet/SfSNet_train/python/Shading_Layer.py.

    Parameters
    ----------
    lights : (N, 27) array
        9 SH coefficients per channel, ordered R[0:9], G[9:18], B[18:27].
    normals : (N, H, W, 3) array or None
        Surface normals; when None the lighting is rendered on a canonical
        sphere (`get_sphere_normals`), one render per row of `lights`.
    mode : str
        'sfsnet' uses SfSNet's SH basis ordering; any other value uses the
        LDAN / DirectX / Google ordering.

    Returns
    -------
    (N, H, W, 3) shading image.
    """
    if normals is None:
        # Render the lighting on a canonical sphere, tiled per light.
        normals = get_sphere_normals()
        normals = np.tile(normals, (lights.shape[0], 1, 1, 1))

    normals = normalize(normals, axis=3)        # unit length per pixel
    normals = normals.transpose((0, 3, 1, 2))   # -> (N, 3, H, W)

    shading = np.zeros(normals.shape)
    sz = normals.shape

    # Per-band attenuation of the Lambertian kernel (pi * [1, 2/3, 1/4]).
    att = np.pi * np.array([1, 2.0 / 3, 0.25])

    c1 = att[0] * (1.0 / np.sqrt(4 * np.pi))             # 1 * 0.282095 = 0.282095
    c2 = att[1] * (np.sqrt(3.0 / (4 * np.pi)))           # 2/3 * 0.488602 = 0.325735
    c3 = att[2] * 0.5 * (np.sqrt(5.0 / (4 * np.pi)))     # 1/4 * 0.315392 = 0.078848
    c4 = att[2] * (3.0 * (np.sqrt(5.0 / (12 * np.pi))))  # 1/4 * 1.092548 = 0.273137
    c5 = att[2] * (3.0 * (np.sqrt(5.0 / (48 * np.pi))))  # 1/4 * 0.546274 = 0.136568 = c4/2

    for i in range(sz[0]):
        nx = normals[i, 0, ...]
        ny = normals[i, 1, ...]
        nz = normals[i, 2, ...]

        if mode == 'sfsnet':
            # SH basis ordering used by SfSNet.
            H1 = c1 * np.ones((sz[2], sz[3]))
            H2 = c2 * nz
            H3 = c2 * nx
            H4 = c2 * ny
            H5 = c3 * (2 * nz * nz - nx * nx - ny * ny)
            H6 = c4 * nx * nz
            H7 = c4 * ny * nz
            H8 = c5 * (nx * nx - ny * ny)
            H9 = c4 * nx * ny
        else:
            # Ordering used by LDAN, DirectX SH, Google SH, etc.
            H1 = c1 * np.ones((sz[2], sz[3]))
            H2 = -c2 * ny
            H3 = c2 * nz
            H4 = -c2 * nx
            H5 = c4 * nx * ny
            H6 = -c4 * ny * nz
            H7 = c3 * (2 * nz * nz - nx * nx - ny * ny)  # == c3*(3*nz*nz - 1) for unit normals
            H8 = -c4 * nx * nz
            H9 = c5 * (nx * nx - ny * ny)

        basis = (H1, H2, H3, H4, H5, H6, H7, H8, H9)
        for j in range(3):
            L = lights[i, j * 9:(j + 1) * 9]
            shading[i, j, :, :] = sum(L[k] * basis[k] for k in range(9))

    return shading.transpose((0, 2, 3, 1))
"""
Run SfSNet (Caffe) on the supplied test images and visualize, per image:
input, predicted normals, predicted albedo, the predicted SH lighting
rendered on a sphere, the shading image, and the reconstruction.

Python port of `test_SfSNet.m`.
"""
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
plt.ion()
import cv2
import caffe

import spherical_harmonics

# --- Network setup -------------------------------------------------------
# caffe.set_mode_cpu()
caffe.set_device(0)
caffe.set_mode_gpu()

model = 'SfSNet_deploy.prototxt'
weights = 'SfSNet.caffemodel.h5'
net = caffe.Net(model, weights, caffe.TEST)

# --- Choose dataset ------------------------------------------------------
dat_idx = int(input('Please enter 1 for images with masks and 0 for images without mask: '))
if dat_idx == 1:
    # Images and masks are provided.
    list_im = sorted(f for f in os.listdir('Images_mask/') if f.endswith('_face.png'))
elif dat_idx == 0:
    # No mask provided (need to use your own mask).
    list_im = sorted(f for f in os.listdir('Images/') if f.endswith('.png'))
else:
    print('Wrong Option!')
    sys.exit(1)

# --- Figure layout -------------------------------------------------------
fig, axs = plt.subplots(2, 3, figsize=(24, 16))
for row in axs:
    for ax in row:
        ax.axis('off')
axs[0][0].set_title('Image')
axs[0][1].set_title('Normal')
axs[0][2].set_title('Albedo')
axs[1][1].set_title('Shading')
axs[1][2].set_title('Recon')

M = 128  # SfSNet input resolution

# The canonical-sphere normals used to visualize the predicted lighting are
# loop-invariant; compute them once here (the original recomputed the full
# 512x512 sphere render on every iteration of the image loop).
sphere_normals = spherical_harmonics.get_sphere_normals()
sphere_normals_data = sphere_normals.reshape((1,) + sphere_normals.shape)

for name in list_im:
    if dat_idx == 1:
        im = cv2.imread(os.path.join('Images_mask/', name))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(os.path.join('Images_mask/', name.replace('face', 'mask')))
        mask = cv2.resize(mask, (M, M))
        mask = mask / 255
    else:
        im = cv2.imread(os.path.join('Images/', name))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Prepare the network input: M x M, float in [0, 1], BGR, NCHW.
    im = cv2.resize(im, (M, M))
    im = im.astype(np.float32) / 255
    im_data = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    im_data = im_data.reshape((1,) + im_data.shape)
    im_data = im_data.transpose((0, 3, 1, 2))

    # Forward pass.
    net.blobs['data'].data[...] = im_data
    output = net.forward()
    al_out = output['Aconv0'].transpose((0, 2, 3, 1))[0]  # albedo, HWC
    n_out = output['Nconv0'].transpose((0, 2, 3, 1))[0]   # normals, HWC
    light_out = output['fc_light'][0]

    # light_out is a 27-dimensional vector: 9 SH coefficients per RGB
    # channel. Within each 9, the 1st dimension is ambient illumination
    # (0th order), the next 3 are directional (1st order), and the last 5
    # are the 2nd-order approximation. The 27-dimensional vector can be
    # used directly as a lighting representation.

    # Decode the network outputs.
    n_out2 = cv2.cvtColor(n_out, cv2.COLOR_BGR2RGB)
    n_out2 = 2 * n_out2 - 1  # map [0, 1] -> [-1, 1]
    n_out2 = n_out2 / (np.expand_dims(np.linalg.norm(n_out2, axis=2), axis=2) + 1e-10)

    al_out2 = cv2.cvtColor(al_out, cv2.COLOR_BGR2RGB)

    # Note: n_out2, al_out2, light_out are the actual outputs.

    # Reconstruction = shading (from SH lights + normals) * albedo.
    light_out_data = light_out.reshape((1,) + light_out.shape)
    n_out2_data = n_out2.reshape((1,) + n_out2.shape)
    Ishd = spherical_harmonics.compute_shading(light_out_data, n_out2_data)[0]
    Irec = Ishd * al_out2

    # Visualize light_out on a sphere (not included in test_SfSNet.m).
    Ilight = spherical_harmonics.compute_shading(light_out_data, sphere_normals_data)[0]

    if dat_idx == 1:
        axs[0][0].imshow(mask * im)
        axs[0][1].imshow(mask * ((1 + n_out2) / 2).clip(0, 1))
        axs[0][2].imshow(mask * al_out2.clip(0, 1))
        axs[1][0].imshow(Ilight.clip(0, 1))
        axs[1][1].imshow(mask * 200 / 255 * Ishd.clip(0, 1))
        axs[1][2].imshow(mask * Irec.clip(0, 1))
    else:
        axs[0][0].imshow(im)
        axs[0][1].imshow(((1 + n_out2) / 2).clip(0, 1))
        axs[0][2].imshow(al_out2.clip(0, 1))
        axs[1][0].imshow(Ilight.clip(0, 1))
        axs[1][1].imshow(200 / 255 * Ishd.clip(0, 1))
        axs[1][2].imshow(Irec.clip(0, 1))

    input('Press Enter to Continue')