diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..73bd326 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,9 @@ +channels: + - conda-forge +dependencies: + - conda-forge::python=3.7.6 + - conda-forge::biopython=1.76 + - conda-forge::matplotlib=3.1.2 + - conda-forge::numba=0.47.0 + - conda-forge::numpy=1.17.5 + - conda-forge::scipy=1.4.1 \ No newline at end of file diff --git a/inference.py b/inference.py index cd9009a..e1d60b6 100644 --- a/inference.py +++ b/inference.py @@ -90,14 +90,19 @@ def parse_args(): parser.add_argument('--tSkip',type=int,default=1) parser.add_argument('--df',type=int,default=150) parser.add_argument('--betaParam',type=float,default=1.0) + parser.add_argument('--lik',action='store_true',help='saves likelihood function, used by PALM. **Only compatible when a single selection coefficient is estimated.**') + parser.add_argument('--w',type=float,default=0.01,help='width used to estimate likelihood function (--lik)') return parser.parse_args() + def load_normal_tables(): + import os + dname = os.path.dirname(os.path.abspath(__file__)) # read in global Phi(z) lookups - z_bins = np.genfromtxt('utils/z_bins.txt') - z_logcdf = np.genfromtxt('utils/z_logcdf.txt') - z_logsf = np.genfromtxt('utils/z_logsf.txt') + z_bins = np.genfromtxt(os.path.join(dname, 'utils/z_bins.txt')) + z_logcdf = np.genfromtxt(os.path.join(dname, 'utils/z_logcdf.txt')) + z_logsf = np.genfromtxt(os.path.join(dname, 'utils/z_logsf.txt')) return z_bins,z_logcdf,z_logsf def load_times(args): @@ -214,7 +219,6 @@ def likelihood_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,anc return np.inf sel = Sprime[np.digitize(epochs,timeBins,right=False)-1] - tShape = times.shape if tShape[2] == 0: t = np.zeros((2,0)) @@ -308,6 +312,10 @@ def traj_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,ancHapGLs S0 = 0.0 * np.ones(T-1) opts = {'xatol':1e-4} + if args.lik and T > 2: + print('ERROR: Option --lik incompatible with estimating >1 selection coefficient.') + raise ValueError + if T == 2: Simplex = np.reshape(np.array([-0.05,0.05]),(2,1)) elif T > 2: @@ -323,40 +331,56 @@ def traj_wrapper(theta,timeBins,N,freqs,z_bins,z_logcdf,z_logsf,ancGLs,ancHapGLs opts['initial_simplex']=Simplex #for tup in product(*[[-1,1] for i in range(3)]): - logL0 = likelihood_wrapper(S0,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) + logL0 = -likelihood_wrapper(S0,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) print('Optimizing likelihood surface using Nelder-Mead...') if times.shape[2] > 1: print('\t(Importance sampling with M = %d Relate samples)'%(times.shape[2])) print() minargs = (timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) + + res = minimize(likelihood_wrapper, - S0, - args=minargs, - options=opts, - #bounds=bounds, - method='Nelder-Mead') + S0, + args=minargs, + options=opts, + #bounds=bounds, + method='Nelder-Mead') S = res.x L = res.fun - #Hinv = np.linalg.inv(res.hess) - #se = np.sqrt(np.diag(Hinv)) + + if args.lik: + print('Fitting likelihood function...') + print() + Svec = np.linspace(max(S[0]-args.w,-sMax),min(S[0]+args.w,sMax),20) + Lvec = [] + for s in Svec: + l = likelihood_wrapper(np.array([s]),timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) + Lvec.append(l) + Lvec = np.array(Lvec) + S = np.array([Svec[np.argmax(-Lvec)]]) + L = np.max(-Lvec) + p = np.polyfit(Svec,-Lvec,deg=2) + if args.out != None: + np.save(args.out+'.quad_fit.npy',p) + print('#'*10) print() - print('logLR: %.4f'%(-res.fun+logL0)) + print('logLR: %.4f'%(L-logL0)) print() print('MLE:') print('========') print('epoch\tselection') for s,t,u in zip(S,timeBins[:-1],timeBins[1:]): - print('%d-%d\t%.5f'%(t,u,s)) + print('%d-%d\t%.7f'%(t,u,s)) # infer trajectory @ MLE of selection parameter print(noCoals) - post = traj_wrapper(res.x,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) + post = traj_wrapper(S,timeBins,Ne,freqs,z_bins,z_logcdf,z_logsf,ancientGLs,ancientHapGLs,epochs,noCoals,currFreq,h,sMax,changePts) if args.out != None: out(args,epochs,freqs,post)