From 1ce03995ba3591a9e4885881149cb4b02948c81a Mon Sep 17 00:00:00 2001 From: Radha Mastandrea Date: Wed, 10 Nov 2021 08:44:50 -0800 Subject: [PATCH 1/2] col split handling for small jets --- scripts/modules/jet_augs.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/scripts/modules/jet_augs.py b/scripts/modules/jet_augs.py index 6c6d3a5..ebbc4b9 100755 --- a/scripts/modules/jet_augs.py +++ b/scripts/modules/jet_augs.py @@ -10,7 +10,7 @@ import torch.nn.functional as F -def translate_jets( batch, width=1.0 ): +def translate_jets( batch, width=0.5 ): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) @@ -30,12 +30,17 @@ def translate_jets( batch, width=1.0 ): return shifted_batch -def rotate_jets( batch ): +def rotate_jets( batch, num_jets = 1 ): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) Output: batch of jets rotated independently in eta-phi, same shape as input ''' + """ + print(batch.shape) + split_batch = np.split(batch, num_jets, axis = 2) + print([x.shape for x in split_batch]) + """ rot_angle = np.random.rand(batch.shape[0])*2*np.pi c = np.cos(rot_angle) s = np.sin(rot_angle) @@ -54,14 +59,14 @@ def normalise_pts( batch ): batch_norm[:,0,:] = np.nan_to_num(batch_norm[:,0,:]/np.sum(batch_norm[:,0,:], axis=1)[:, np.newaxis], posinf = 0.0, neginf = 0.0 ) return batch_norm -def rescale_pts( batch ): +def rescale_pts( batch, pt_rescale_denom ): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) - Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input + Output: batch of pT-rescaled jets, each constituent pT is rescaled by pt_rescale_denom, same shape as input ''' batch_rscl = batch.copy() - batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600, posinf = 0.0, neginf = 0.0 ) + batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/pt_rescale_denom, posinf = 0.0, neginf = 0.0 ) return batch_rscl def crop_jets( batch, nc ): @@ -96,13 +101,17 @@ def collinear_fill_jets( batch ): nc = batch.shape[2] nzs = np.array( [ np.where( batch[:,0,:][i]>0.0)[0].shape[0] for i in range(len(batch)) ] ) for k in range(len(batch)): - nzs1 = np.max( [ nzs[k], int(nc/2) ] ) - zs1 = int(nc-nzs1) - els = np.random.choice( np.linspace(0,nzs1-1,nzs1), size=zs1, replace=False ) - rs = np.random.uniform( size=zs1 ) - for j in range(zs1): + zs1 = int(nc-nzs[k]) + nfill = np.min( [ zs1, nzs[k] ] ) + els = np.random.choice( np.linspace(0,nzs[k]-1,nzs[k]), size=nfill, replace=False ) + rs = np.random.uniform( size=nfill ) + for j in range(nfill): batchb[k,0,int(els[j])] = rs[j]*batch[k,0,int(els[j])] batchb[k,0,int(nzs[k]+j)] = (1-rs[j])*batch[k,0,int(els[j])] batchb[k,1,int(nzs[k]+j)] = batch[k,1,int(els[j])] batchb[k,2,int(nzs[k]+j)] = batch[k,2,int(els[j])] + return batchb + + + From c7f69ce783fc6c8bb49171f9757395f1978178c5 Mon Sep 17 00:00:00 2001 From: Radha Date: Mon, 20 Dec 2021 09:34:07 -0800 Subject: [PATCH 2/2] removed my extra changes to jet_augs --- scripts/modules/jet_augs.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/scripts/modules/jet_augs.py b/scripts/modules/jet_augs.py index ebbc4b9..2b92bad 100755 --- a/scripts/modules/jet_augs.py +++ b/scripts/modules/jet_augs.py @@ -30,17 +30,12 @@ def translate_jets( batch, width=0.5 ): return shifted_batch -def rotate_jets( batch, num_jets = 1 ): +def rotate_jets( batch ): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) Output: batch of jets rotated independently in eta-phi, same shape as input ''' - """ - print(batch.shape) - split_batch = np.split(batch, num_jets, axis = 2) - print([x.shape for x in split_batch]) - """ rot_angle = np.random.rand(batch.shape[0])*2*np.pi c = np.cos(rot_angle) s = np.sin(rot_angle) @@ -59,14 +54,14 @@ def normalise_pts( batch ): batch_norm[:,0,:] = np.nan_to_num(batch_norm[:,0,:]/np.sum(batch_norm[:,0,:], axis=1)[:, np.newaxis], posinf = 0.0, neginf = 0.0 ) return batch_norm -def rescale_pts( batch, pt_rescale_denom ): +def rescale_pts( batch): ''' Input: batch of jets, shape (batchsize, 3, n_constit) dim 1 ordering: (pT, eta, phi) - Output: batch of pT-rescaled jets, each constituent pT is rescaled by pt_rescale_denom, same shape as input + Output: batch of pT-rescaled jets, each constituent pT is rescaled by 600, same shape as input ''' batch_rscl = batch.copy() - batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/pt_rescale_denom, posinf = 0.0, neginf = 0.0 ) + batch_rscl[:,0,:] = np.nan_to_num(batch_rscl[:,0,:]/600., posinf = 0.0, neginf = 0.0 ) return batch_rscl def crop_jets( batch, nc ):