Skip to content
9 changes: 9 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
channels:
- conda-forge
dependencies:
- conda-forge::python=3.7.6
- conda-forge::biopython=1.76
- conda-forge::matplotlib=3.1.2
- conda-forge::numba=0.47.0
- conda-forge::numpy=1.17.5
- conda-forge::scipy=1.4.1
54 changes: 39 additions & 15 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,19 @@ def parse_args():
parser.add_argument('--tSkip',type=int,default=1)
parser.add_argument('--df',type=int,default=150)
parser.add_argument('--betaParam',type=float,default=1.0)
parser.add_argument('--lik',action='store_true',help='saves likelihood function, used by PALM. **Only compatible when a single selection coefficient is estimated.**')
parser.add_argument('--w',type=float,default=0.01,help='width used to estimate likelihood function (--lik)')
return parser.parse_args()



def load_normal_tables():
import os
dname = os.path.dirname(os.path.abspath(__file__))
# read in global Phi(z) lookups
z_bins = np.genfromtxt('utils/z_bins.txt')
z_logcdf = np.genfromtxt('utils/z_logcdf.txt')
z_logsf = np.genfromtxt('utils/z_logsf.txt')
z_bins = np.genfromtxt(os.path.join(dname, 'utils/z_bins.txt'))
z_logcdf = np.genfromtxt(os.path.join(dname, 'utils/z_logcdf.txt'))
z_logsf = np.genfromtxt(os.path.join(dname, 'utils/z_logsf.txt'))
return z_bins,z_logcdf,z_logsf

def load_times(args):
Expand Down Expand Up @@ -214,7 +219,6 @@ def likelihood_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,anc
return np.inf

sel = Sprime[np.digitize(epochs,timeBins,right=False)-1]

tShape = times.shape
if tShape[2] == 0:
t = np.zeros((2,0))
Expand Down Expand Up @@ -308,6 +312,10 @@ def traj_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,ancHapGLs
S0 = 0.0 * np.ones(T-1)
opts = {'xatol':1e-4}

if args.lik and T > 2:
print('ERROR: Option --lik incompatible with estimating >1 selection coefficient.')
raise ValueError

if T == 2:
Simplex = np.reshape(np.array([-0.05,0.05]),(2,1))
elif T > 2:
Expand All @@ -323,40 +331,56 @@ def traj_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,ancHapGLs
opts['initial_simplex']=Simplex

#for tup in product(*[[-1,1] for i in range(3)]):
logL0 = likelihood_wrapper(S0,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)
logL0 = -likelihood_wrapper(S0,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)

print('Optimizing likelihood surface using Nelder-Mead...')
if times.shape[2] > 1:
print('\t(Importance sampling with M = %d Relate samples)'%(times.shape[2]))
print()
minargs = (timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)


res = minimize(likelihood_wrapper,
S0,
args=minargs,
options=opts,
#bounds=bounds,
method='Nelder-Mead')
S0,
args=minargs,
options=opts,
#bounds=bounds,
method='Nelder-Mead')

S = res.x
L = res.fun
#Hinv = np.linalg.inv(res.hess)
#se = np.sqrt(np.diag(Hinv))

if args.lik:
print('Fitting likelihood function...')
print()
Svec = np.linspace(max(S[0]-args.w,-sMax),min(S[0]+args.w,sMax),20)
Lvec = []
for s in Svec:
l = likelihood_wrapper(np.array([s]),timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)
Lvec.append(l)
Lvec = np.array(Lvec)
S = np.array([Svec[np.argmax(-Lvec)]])
L = np.max(-Lvec)
p = np.polyfit(Svec,-Lvec,deg=2)
if args.out != None:
np.save(args.out+'.quad_fit.npy',p)


print('#'*10)
print()
print('logLR: %.4f'%(-res.fun+logL0))
print('logLR: %.4f'%(L-logL0))
print()
print('MLE:')
print('========')
print('epoch\tselection')
for s,t,u in zip(S,timeBins[:-1],timeBins[1:]):
print('%d-%d\t%.5f'%(t,u,s))
print('%d-%d\t%.7f'%(t,u,s))


# infer trajectory @ MLE of selection parameter
print(noCoals)

post = traj_wrapper(res.x,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)
post = traj_wrapper(S,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts)

if args.out != None:
out(args,epochs,freqs,post)
Expand Down