Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions scripts/modules/jet_augs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn.functional as F


def translate_jets( batch, width=1.0 ):
def translate_jets( batch, width=0.5 ):
'''
Input: batch of jets, shape (batchsize, 3, n_constit)
dim 1 ordering: (pT, eta, phi)
Expand Down Expand Up @@ -54,14 +54,14 @@ def normalise_pts( batch ):
batch_norm[:,0,:] = np.nan_to_num(batch_norm[:,0,:]/np.sum(batch_norm[:,0,:], axis=1)[:, np.newaxis], posinf = 0.0, neginf = 0.0 )
return batch_norm

def rescale_pts( batch ):
def rescale_pts( batch):
'''
Input: batch of jets, shape (batchsize, 3, n_constit)
dim 1 ordering: (pT, eta, phi)
Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input
'''
batch_rscl = batch.copy()
batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600, posinf = 0.0, neginf = 0.0 )
batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600., posinf = 0.0, neginf = 0.0 )
return batch_rscl

def crop_jets( batch, nc ):
Expand Down Expand Up @@ -96,13 +96,17 @@ def collinear_fill_jets( batch ):
nc = batch.shape[2]
nzs = np.array( [ np.where( batch[:,0,:][i]>0.0)[0].shape[0] for i in range(len(batch)) ] )
for k in range(len(batch)):
nzs1 = np.max( [ nzs[k], int(nc/2) ] )
zs1 = int(nc-nzs1)
els = np.random.choice( np.linspace(0,nzs1-1,nzs1), size=zs1, replace=False )
rs = np.random.uniform( size=zs1 )
for j in range(zs1):
zs1 = int(nc-nzs[k])
nfill = np.min( [ zs1, nzs[k] ] )
els = np.random.choice( np.linspace(0,nzs[k]-1,nzs[k]), size=nfill, replace=False )
rs = np.random.uniform( size=nfill )
for j in range(nfill):
batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])]
batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])]
batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])]
batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])]

return batchb