diff --git a/app.py b/app.py
index ed5de36..4eb793f 100644
--- a/app.py
+++ b/app.py
@@ -81,7 +81,7 @@ def get_nearest_neighbors(self, smile, ms):
id_top = np.argmax(np.array(scores))
- return id_top
+ return id_top, max(scores)
def extract_model_info(directory):
models_info = []
@@ -193,6 +193,24 @@ def predict():
all_predictions = {}
all_ads = {}
model_info_list = []
+ data_dict = {
+ "Model": {},
+ "Input": {'SMILES': {}, 'Structure': {}},
+ "Nearest Neighbor": {'Structure': {}, 'Similarity': {}, 'Source': {}, 'Experimental value': {}},
+ "Output": {'Predicted pChEMBL Value': {}, 'Within Applicability Domain': {}, 'Download QPRF': {}}
+ }
+ tooltip_dict = {"Model": 'Unique identifier of the model that made the prediction',
+ "Structure": '2D depiction of molecule',
+ "SMILES": 'Line representation of molecule',
+ "Nearest Neighbor": 'Molecule from the model training set that is the most similar to input molecule',
+ "Similarity": 'Tanimoto similarity score based on same chemical descriptor as used for model',
+ "Source": 'Document(s) containing experimental data for nearest neighbor',
+ "Experimental value": 'Average of all reported experimental values',
+ "Predicted pChEMBL Value": 'Model prediction for input molecule. pChEMBL is defined as -log(response). More information available in QMRF & QPRF',
+ "Within Applicability Domain": 'AD is based on descriptors of training set. An input molecule is within AD if the distance to the training set is lower than a set threshold. More information available in QMRF & QPRF',
+ "Download QPRF": 'Automatically filled in report about the prediction'
+ }
+
for model_name in model_names:
logging.debug(f"Processing model: {model_name}")
model_path = os.path.join(MODELS_DIR, model_name, f"{model_name}_meta.json")
@@ -244,19 +262,9 @@ def predict():
table_data.append(row)
# Update headers
- table_data_extensive = []
headers = ['Structure', 'SMILES']
tooltips = ['2D depiction of input molecule', 'Line representation of input molecule']
- headers_extensive = ['Model', 'Structure', 'SMILES', 'Nearest Neighbor', 'Source', 'Predicted pChEMBL Value', 'Within Applicability Domain']
- tooltips_extensive = [
- 'Unique identifier of the model that made the prediction',
- '2D depiction of input molecule',
- 'Line representation of input molecule',
- '2D depiction of nearest neighbor of input molecule in model training set. More information available in QPRF',
- 'Document(s) containing experimental data for nearest neighbor',
- 'Model prediction for input molecule. pChEMBL is defined as -log(response). More information available in QMRF & QPRF',
- 'AD is based on descriptors of training set. An input molecule is within AD if the distance to the training set is lower than a set threshold. More information available in QMRF & QPRF',
- ]
+
searcher = SimilaritySearcher()
for model_name in model_names:
accession = model_name.split("_")[0]
@@ -288,7 +296,7 @@ def predict():
for i, smile in enumerate(smiles_list):
image_data = smiles_to_image(smile)
- id_top = searcher.get_nearest_neighbors(smile, ms)
+ id_top, score = searcher.get_nearest_neighbors(smile, ms)
nearest_neighbor = {}
nn_smiles = train_df.iloc[id_top]['SMILES']
nearest_neighbor["smiles"] = nn_smiles
@@ -298,20 +306,33 @@ def predict():
nearest_neighbor["predicted_value"] = model.predictMols([nn_smiles])[0][0]
nearest_neighbor["similarity"] = f"Nearest neighbor was found using {searcher.scorer.__name__} based on {searcher.descgen.__class__.__name__}"
image_data_nn = smiles_to_image(nn_smiles)
- if getattr(model, 'applicabilityDomain', None):
- row = [model_name] + [image_data] + [smile] + [image_data_nn] + [nn_smiles] + [doi_nn] + [all_predictions[model_name][i]] + [all_ads[model_name][i]]
- else:
- row = [model_name] + [image_data] + [smile] + [image_data_nn] + [nn_smiles] + [doi_nn] + [all_predictions[model_name][i]]
-
- table_data_extensive.append(row)
render_qprf(smile, model, predictions[i], ad[i], nearest_neighbor)
-
+ data_dict["Model"].setdefault("value", []).append(model_name)
+ data_dict["Input"]["Structure"].setdefault("image", []).append(image_data)
+ data_dict["Input"]["SMILES"].setdefault("value", []).append(smile)
+ data_dict["Nearest Neighbor"]["Structure"].setdefault("image", []).append(image_data_nn)
+ data_dict["Nearest Neighbor"]["Similarity"].setdefault("value", []).append(f"{score:.2f}")
+ data_dict["Nearest Neighbor"]["Source"].setdefault("value", []).append(doi_nn)
+ data_dict["Nearest Neighbor"]["Experimental value"].setdefault("value", []).append(nearest_neighbor["value"])
+ data_dict["Output"]["Predicted pChEMBL Value"].setdefault("value", []).append(all_predictions[model_name][i])
+ data_dict["Output"]["Within Applicability Domain"].setdefault("value", []).append(all_ads[model_name][i])
+ data_dict["Output"]["Download QPRF"].setdefault("url", []).append("hi")
error_message = None
if invalid_smiles:
error_message = f"Invalid SMILES, could not be processed: {', '.join(invalid_smiles)}" # Mention invalid SMILES in error message
- return render_template('index.html', models=available_models, headers=headers, tooltips=tooltips, data=table_data, headers_extensive=headers_extensive, tooltips_extensive = tooltips_extensive, data_extensive=table_data_extensive, smiles_input=smiles_input, model_names=model_names, file_name=file_name, error=error_message)
+ return render_template('index.html',
+ models=available_models,
+ data_dict=data_dict,
+ headers=headers,
+ tooltips=tooltips,
+ data=table_data,
+ tooltips_extensive = tooltip_dict,
+ smiles_input=smiles_input,
+ model_names=model_names,
+ file_name=file_name,
+ error=error_message)
except Exception:
logging.exception("An error occurred while processing the request.")
return render_template('index.html', models=available_models, error="An error occurred while processing the request.")
diff --git a/static/styles.css b/static/styles.css
index 6a76835..17656b0 100644
--- a/static/styles.css
+++ b/static/styles.css
@@ -38,6 +38,16 @@ header {
}
+.prediction-container {
+ max-width: 1200px;
+ margin: 20px auto;
+ padding: 20px;
+ background-color: white;
+ border-radius: 10px;
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+ overflow: hidden;
+ }
+
.table-container {
max-height: 400px;
overflow-y: auto;
diff --git a/templates/index.html b/templates/index.html
index 621f33d..354a33a 100644
--- a/templates/index.html
+++ b/templates/index.html
@@ -18,12 +18,12 @@