diff --git a/app.py b/app.py index ed5de36..4eb793f 100644 --- a/app.py +++ b/app.py @@ -81,7 +81,7 @@ def get_nearest_neighbors(self, smile, ms): id_top = np.argmax(np.array(scores)) - return id_top + return id_top, max(scores) def extract_model_info(directory): models_info = [] @@ -193,6 +193,24 @@ def predict(): all_predictions = {} all_ads = {} model_info_list = [] + data_dict = { + "Model": {}, + "Input": {'SMILES': {}, 'Structure': {}}, + "Nearest Neighbor": {'Structure': {}, 'Similarity': {}, 'Source': {}, 'Experimental value': {}}, + "Output": {'Predicted pChEMBL Value': {}, 'Within Applicability Domain': {}, 'Download QPRF': {}} + } + tooltip_dict = {"Model": 'Unique identifier of the model that made the prediction', + "Structure": '2D depiction of molecule', + "SMILES": 'Line representation of molecule', + "Nearest Neighbor": 'Molecule from the model training set that is the most similar to input molecule', + "Similarity": 'Tanimoto similarity score based on same chemical descriptor as used for model', + "Source": 'Document(s) containing experimental data for nearest neighbor', + "Experimental value": 'Average of all reported experimental values', + "Predicted pChEMBL Value": 'Model prediction for input molecule. pChEMBL is defined as -log(response). More information available in QMRF & QPRF', + "Within Applicability Domain": 'AD is based on descriptors of training set. An input molecule is within AD if the distance to the training set is lower than a set threshold. More information available in QMRF & QPRF', + "Download QPRF": 'Automatically filled in report about the prediction' + } + for model_name in model_names: logging.debug(f"Processing model: {model_name}") model_path = os.path.join(MODELS_DIR, model_name, f"{model_name}_meta.json") @@ -244,19 +262,9 @@ def predict(): table_data.append(row) # Update headers - table_data_extensive = [] headers = ['Structure', 'SMILES'] tooltips = ['2D depiction of input molecule', 'Line representation of input molecule'] - headers_extensive = ['Model', 'Structure', 'SMILES', 'Nearest Neighbor', 'Source', 'Predicted pChEMBL Value', 'Within Applicability Domain'] - tooltips_extensive = [ - 'Unique identifier of the model that made the prediction', - '2D depiction of input molecule', - 'Line representation of input molecule', - '2D depiction of nearest neighbor of input molecule in model training set. More information available in QPRF', - 'Document(s) containing experimental data for nearest neighbor', - 'Model prediction for input molecule. pChEMBL is defined as -log(response). More information available in QMRF & QPRF', - 'AD is based on descriptors of training set. An input molecule is within AD if the distance to the training set is lower than a set threshold. More information available in QMRF & QPRF', - ] + searcher = SimilaritySearcher() for model_name in model_names: accession = model_name.split("_")[0] @@ -288,7 +296,7 @@ def predict(): for i, smile in enumerate(smiles_list): image_data = smiles_to_image(smile) - id_top = searcher.get_nearest_neighbors(smile, ms) + id_top, score = searcher.get_nearest_neighbors(smile, ms) nearest_neighbor = {} nn_smiles = train_df.iloc[id_top]['SMILES'] nearest_neighbor["smiles"] = nn_smiles @@ -298,20 +306,33 @@ def predict(): nearest_neighbor["predicted_value"] = model.predictMols([nn_smiles])[0][0] nearest_neighbor["similarity"] = f"Nearest neighbor was found using {searcher.scorer.__name__} based on {searcher.descgen.__class__.__name__}" image_data_nn = smiles_to_image(nn_smiles) - if getattr(model, 'applicabilityDomain', None): - row = [model_name] + [image_data] + [smile] + [image_data_nn] + [nn_smiles] + [doi_nn] + [all_predictions[model_name][i]] + [all_ads[model_name][i]] - else: - row = [model_name] + [image_data] + [smile] + [image_data_nn] + [nn_smiles] + [doi_nn] + [all_predictions[model_name][i]] - - table_data_extensive.append(row) render_qprf(smile, model, predictions[i], ad[i], nearest_neighbor) - + data_dict["Model"].setdefault("value", []).append(model_name) + data_dict["Input"]["Structure"].setdefault("image", []).append(image_data) + data_dict["Input"]["SMILES"].setdefault("value", []).append(smile) + data_dict["Nearest Neighbor"]["Structure"].setdefault("image", []).append(image_data_nn) + data_dict["Nearest Neighbor"]["Similarity"].setdefault("value", []).append(f"{score:.2f}") + data_dict["Nearest Neighbor"]["Source"].setdefault("value", []).append(doi_nn) + data_dict["Nearest Neighbor"]["Experimental value"].setdefault("value", []).append(nearest_neighbor["value"]) + data_dict["Output"]["Predicted pChEMBL Value"].setdefault("value", []).append(all_predictions[model_name][i]) + data_dict["Output"]["Within Applicability Domain"].setdefault("value", []).append(all_ads[model_name][i]) + data_dict["Output"]["Download QPRF"].setdefault("url", []).append("hi") error_message = None if invalid_smiles: error_message = f"Invalid SMILES, could not be processed: {', '.join(invalid_smiles)}" # Mention invalid SMILES in error message - return render_template('index.html', models=available_models, headers=headers, tooltips=tooltips, data=table_data, headers_extensive=headers_extensive, tooltips_extensive = tooltips_extensive, data_extensive=table_data_extensive, smiles_input=smiles_input, model_names=model_names, file_name=file_name, error=error_message) + return render_template('index.html', + models=available_models, + data_dict=data_dict, + headers=headers, + tooltips=tooltips, + data=table_data, + tooltips_extensive = tooltip_dict, + smiles_input=smiles_input, + model_names=model_names, + file_name=file_name, + error=error_message) except Exception: logging.exception("An error occurred while processing the request.") return render_template('index.html', models=available_models, error="An error occurred while processing the request.") diff --git a/static/styles.css b/static/styles.css index 6a76835..17656b0 100644 --- a/static/styles.css +++ b/static/styles.css @@ -38,6 +38,16 @@ header { } +.prediction-container { + max-width: 1200px; + margin: 20px auto; + padding: 20px; + background-color: white; + border-radius: 10px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + overflow: hidden; + } + .table-container { max-height: 400px; overflow-y: auto; diff --git a/templates/index.html b/templates/index.html index 621f33d..354a33a 100644 --- a/templates/index.html +++ b/templates/index.html @@ -18,12 +18,12 @@