Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.pyc
selection
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@ SONIA takes as input TCR CDR3 amino acid sequences, with or without per sequence

An example script is provided that reads in selected and pre-selected sequences from supplied text files and infers selection factors on any amino acid / position / CDR3 length combinations and V/J identity, saving the inferred model to a file. Then the model is loaded into the EvaluateModel to generate sequences before and after selection, and to calculate probabilities and energies for the generated sequences.

Free use of SONIA is granted under the terms of the GNU General Public License version 3 (GPLv3).
## Installation

The provided `environment.yml` file describes the SONIA dependencies.

---

Free use of SONIA is granted under the terms of the GNU General Public License version 3 (GPLv3).
11 changes: 11 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: sonia
channels:
- anaconda
dependencies:
- python=2.7
- matplotlib
- numpy
- pip
- pip:
- olga
- tensorflow==1.15
26 changes: 13 additions & 13 deletions sonia.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Sonia(object):
'humanTRB' (default), 'humanIGH', and 'mouseTRB'.
l2_reg : float or None
L2 regularization. If None (default) then no regularization.

Methods
----------
seq_feature_proj(feature, seq)
Expand Down Expand Up @@ -121,12 +121,12 @@ def __init__(self, features = [], data_seqs = [], gen_seqs = [], load_model = No
self.update_model(add_data_seqs = data_seqs, add_gen_seqs = gen_seqs)
self.update_model_structure(initialize=True)
self.L1_converge_history = []

if seed is not None:
np.random.seed(seed = seed)

self.amino_acids = 'ACDEFGHIKLMNPQRSTVWY'

def seq_feature_proj(self, feature, seq):
"""Checks if a sequence matches all subfeatures of the feature list

Expand Down Expand Up @@ -211,12 +211,12 @@ def compute_seq_energy(self, seq = None, seq_features = None):

def compute_energy(self,seqs_features):
"""Computes the energy of a list of sequences according to the model.

Parameters
----------
seqs_features : list
list of encoded sequences into sonia features.

Returns
-------
E : float
Expand All @@ -236,7 +236,7 @@ def _encode_data(self,seq_features):
for i in range(len(data_enc)): data_enc[i][data[i]] = 1
return data_enc


def compute_marginals(self, features = None, seq_model_features = None, seqs = None, use_flat_distribution = False, output_dict = False):
"""Computes the marginals of each feature over sequences.
Computes marginals either with a flat distribution over the sequences
Expand Down Expand Up @@ -430,7 +430,7 @@ def update_model(self, add_data_seqs = [], add_gen_seqs = [], add_features = [],
self.features = np.append(self.features, add_features)
self.update_model_structure(initialize=True)
self.feature_dict = {tuple(f): i for i, f in enumerate(self.features)}

if (len(add_data_seqs + add_features + remove_features) > 0 or auto_update_seq_features) and len(self.features)>0:
self.data_seq_features = [self.find_seq_features(seq) for seq in self.data_seqs]

Expand Down Expand Up @@ -535,7 +535,7 @@ def plot_model_learning(self, save_name = None):

plt.legend(frameon = False, loc = 2)
plt.title('L1 Distance convergence', fontsize = 15)

fig.add_subplot(132)

plt.loglog(self.data_marginals, self.gen_marginals, 'r.', alpha = 0.2, markersize=1)
Expand All @@ -551,7 +551,7 @@ def plot_model_learning(self, save_name = None):
plt.ylabel('Marginals over generated sequences', fontsize = 13)
plt.legend(loc = 2, fontsize = 10)
plt.title('Marginal Scatter', fontsize = 15)

fig.add_subplot(133)
plt.title('Likelihood', fontsize = 15)
plt.plot(self.learning_history.history['likelihood'],label='train',c='k')
Expand Down Expand Up @@ -654,7 +654,7 @@ def load_model(self, load_dir, load_seqs = True):
split_line = line.split('\t')
self.data_seqs.append(split_line[0].split(';'))
self.data_seq_features.append([self.feature_dict[tuple(f.split(','))] for f in split_line[2].split(';') if tuple(f.split(',')) in self.feature_dict])
else:
else:
print 'Cannot find data_seqs.tsv -- no data seqs loaded.'

if os.path.isfile(os.path.join(load_dir, 'gen_seqs.tsv')) and load_seqs:
Expand Down Expand Up @@ -698,7 +698,7 @@ def likelihood(y_true, y_pred):
return gen-data

class computeL1(keras.callbacks.Callback):

def __init__(self, sonia):
self.data_marginals = sonia.data_marginals
self.sonia=sonia
Expand All @@ -713,7 +713,7 @@ def on_train_begin(self, logs={}):

def return_model_marginals(self):
marginals = np.zeros(self.len_features)
Qs = np.exp(-self.model.predict(self.encoded_data)[:, 0])
Qs = np.exp(-self.model.predict(self.encoded_data)[:, 0])
for i in range(len(self.gen_enc)):
marginals[self.gen_enc[i]] += Qs[i]
return marginals / np.sum(Qs)
Expand All @@ -722,4 +722,4 @@ def on_epoch_end(self, epoch, logs={}):
curr_loss = logs.get('loss')
curr_loss_val = logs.get('val_loss')
self.L1_history.append(np.sum(np.abs(self.return_model_marginals() - self.data_marginals)))
print "epoch = ", epoch, " loss = ", np.around(curr_loss, decimals=4) , " val_loss = ", np.around(curr_loss_val, decimals=4), " L1 dist: ", np.around(self.L1_history[-1], decimals=4)
print "epoch = ", epoch, " loss = ", np.around(curr_loss, decimals=4) , " val_loss = ", np.around(curr_loss_val, decimals=4), " L1 dist: ", np.around(self.L1_history[-1], decimals=4)