diff --git a/scripts/modules/jet_augs.py b/scripts/modules/jet_augs.py index 6c6d3a5..2b92bad 100755 --- a/scripts/modules/jet_augs.py +++ b/scripts/modules/jet_augs.py @@ -10,7 +10,7 @@ import torch.nn.functional as F -def translate_jets( batch, width=1.0 ): +def translate_jets( batch, width=0.5 ): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) @@ -54,14 +54,14 @@ def normalise_pts( batch ): batch_norm[:,0,:] = np.nan_to_num(batch_norm[:,0,:]/np.sum(batch_norm[:,0,:], axis=1)[:, np.newaxis], posinf = 0.0, neginf = 0.0 ) return batch_norm -def rescale_pts( batch ): +def rescale_pts( batch): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input ''' batch_rscl = batch.copy() - batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600, posinf = 0.0, neginf = 0.0 ) + batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600., posinf = 0.0, neginf = 0.0 ) return batch_rscl def crop_jets( batch, nc ): @@ -96,13 +96,17 @@ def collinear_fill_jets( batch ): nc = batch.shape[2] nzs = np.array( [ np.where( batch[:,0,:][i]>0.0)[0].shape[0] for i in range(len(batch)) ] ) for k in range(len(batch)): - nzs1 = np.max( [ nzs[k], int(nc/2) ] ) - zs1 = int(nc-nzs1) - els = np.random.choice( np.linspace(0,nzs1-1,nzs1), size=zs1, replace=False ) - rs = np.random.uniform( size=zs1 ) - for j in range(zs1): + zs1 = int(nc-nzs[k]) + nfill = np.min( [ zs1, nzs[k] ] ) + els = np.random.choice( np.linspace(0,nzs[k]-1,nzs[k]), size=nfill, replace=False ) + rs = np.random.uniform( size=nfill ) + for j in range(nfill): batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])] batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])] batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])] batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])] + return batchb + + +