diff --git a/notebooks/00_method/C2H4_mol.png b/notebooks/00_method/C2H4_mol.png new file mode 100644 index 0000000..6ea31d4 Binary files /dev/null and b/notebooks/00_method/C2H4_mol.png differ diff --git a/notebooks/00_method/H.png b/notebooks/00_method/H.png new file mode 100644 index 0000000..545319d Binary files /dev/null and b/notebooks/00_method/H.png differ diff --git a/notebooks/00_method/Li.png b/notebooks/00_method/Li.png new file mode 100644 index 0000000..12ca1c4 Binary files /dev/null and b/notebooks/00_method/Li.png differ diff --git a/notebooks/00_method/NH3_mol.png b/notebooks/00_method/NH3_mol.png new file mode 100644 index 0000000..02fcd8d Binary files /dev/null and b/notebooks/00_method/NH3_mol.png differ diff --git a/notebooks/00_method/atoms.png b/notebooks/00_method/atoms.png new file mode 100644 index 0000000..0bae1f2 Binary files /dev/null and b/notebooks/00_method/atoms.png differ diff --git a/notebooks/00_method/bbmep_1106_mols/mol_00.png b/notebooks/00_method/bbmep_1106_mols/mol_00.png new file mode 100644 index 0000000..a8b87c2 Binary files /dev/null and b/notebooks/00_method/bbmep_1106_mols/mol_00.png differ diff --git a/notebooks/00_method/bbmep_1106_mols/mol_04.png b/notebooks/00_method/bbmep_1106_mols/mol_04.png new file mode 100644 index 0000000..364aeec Binary files /dev/null and b/notebooks/00_method/bbmep_1106_mols/mol_04.png differ diff --git a/notebooks/00_method/bbmep_1106_mols/mol_09.png b/notebooks/00_method/bbmep_1106_mols/mol_09.png new file mode 100644 index 0000000..c569ef8 Binary files /dev/null and b/notebooks/00_method/bbmep_1106_mols/mol_09.png differ diff --git a/notebooks/00_method/bbmep_1106_mols/mol_14.png b/notebooks/00_method/bbmep_1106_mols/mol_14.png new file mode 100644 index 0000000..21dfa5b Binary files /dev/null and b/notebooks/00_method/bbmep_1106_mols/mol_14.png differ diff --git a/notebooks/00_method/bbmep_1106_mols/mol_19.png b/notebooks/00_method/bbmep_1106_mols/mol_19.png new file mode 100644 index 0000000..32ca3a8 Binary files /dev/null and b/notebooks/00_method/bbmep_1106_mols/mol_19.png differ diff --git a/notebooks/00_method/bending.png b/notebooks/00_method/bending.png new file mode 100644 index 0000000..0e06518 Binary files /dev/null and b/notebooks/00_method/bending.png differ diff --git a/notebooks/00_method/breaking.png b/notebooks/00_method/breaking.png new file mode 100644 index 0000000..a86beac Binary files /dev/null and b/notebooks/00_method/breaking.png differ diff --git a/notebooks/00_method/create_figures.ipynb b/notebooks/00_method/create_figures.ipynb new file mode 100644 index 0000000..ed8447c --- /dev/null +++ b/notebooks/00_method/create_figures.ipynb @@ -0,0 +1,1180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0dcbc5ce", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from scipy.interpolate import CubicSpline\n", + "\n", + "import matplotlib as mpl\n", + "import matplotlib.font_manager as fm\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "import matplotlib.gridspec as gridspec\n", + "import matplotlib.patches as patches\n", + "\n", + "from matplotlib.offsetbox import OffsetImage, AnnotationBbox" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9b83cf8", + "metadata": {}, + "outputs": [], + "source": [ + "# On Linux: Manually download segoeui font and load and register here\n", + "\n", + "from pathlib import Path\n", + "\n", + "font_path = f'{Path.home()}/.local/share/fonts/SEGOEUI.TTF'\n", + "segoe_prop = fm.FontProperties(fname=font_path)\n", + "\n", + "mpl.rcParams['font.family'] = segoe_prop.get_name()\n", + "fm.fontManager.addfont(font_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b03dd7fb", + "metadata": {}, + "outputs": [], + "source": [ + "# Matplotlib settings\n", + "\n", + "mpl.rcParams[\"text.usetex\"] = False\n", + "mpl.rcParams['mathtext.fontset'] = 'stix'\n", + "mpl.rcParams['mathtext.cal'] = 'stix'\n", + "mpl.rcParams['mathtext.rm'] = 'stix'\n", + "mpl.rcParams['mathtext.it'] = 'stix:italic'\n", + "mpl.rcParams['mathtext.bf'] = 'stix:bold'\n", + "mpl.rcParams[\"font.size\"] = 10\n", + "\n", + "%config InlineBackend.print_figure_kwargs={'bbox_inches': None}\n", + "\n", + "model_color = \"#dff2f4\"\n", + "array_color = \"#e7dff6\"\n", + "dataset_color = \"#fff1ee\"\n", + "result_color = \"#ffd566\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c417f1bd", + "metadata": {}, + "outputs": [], + "source": [ + "TEXTWIDTH = 7.08\n", + "\n", + "fig = plt.figure(figsize=(TEXTWIDTH, 0.5 * TEXTWIDTH))\n", + "spec = gridspec.GridSpec(ncols=2, nrows=1, figure=fig)\n", + "ax1 = fig.add_subplot(spec[0, 0])\n", + "ax2 = fig.add_subplot(spec[0, 1])\n", + "\n", + "lw = 0.5\n", + "lw_bold = 2 * lw\n", + "lw_box = lw\n", + "\n", + "\n", + "def annotate(\n", + " ax,\n", + " text,\n", + " xy,\n", + " xytext=None,\n", + " ha='center',\n", + " va='center',\n", + " arrowprops={},\n", + " arrow=False,\n", + " **kwargs,\n", + "):\n", + " ap = {\n", + " 'color': 'k',\n", + " 'arrowstyle': '-|>, head_width=0.15, head_length=0.3',\n", + " 'linewidth': lw,\n", + " }\n", + " arrowprops = {**ap, **arrowprops} if (arrow or arrowprops) else None\n", + " ax.annotate(\n", + " text,\n", + " xy,\n", + " xytext,\n", + " xycoords='figure fraction',\n", + " textcoords='figure fraction',\n", + " ha=ha,\n", + " va=va,\n", + " arrowprops=arrowprops,\n", + " **kwargs,\n", + " )\n", + "\n", + "\n", + "ax1.axis('off')\n", + "ax2.axis('off')\n", + "\n", + "# SCAFFOLDING\n", + "for l in [\n", + " ((0.0015, -0.1), (0.0015, 1.1)),\n", + " ((0.5, -0.1), (0.5, 1.1)),\n", + " ((0.26, 0.44), (0.502, 0.44)),\n", + " ((0.05, 0.44), (0.23, 0.44)),\n", + " ((-0.1, 0.723), (0.35, 0.723)),\n", + " ((0.6, 0.66), (1.1, 0.66)),\n", + " ((0.498, 0.39), (0.66, 0.39)),\n", + "]:\n", + " annotate(ax1, None, *l, arrowprops={'color': 'lightgrey', 'arrowstyle': '-'})\n", + "\n", + "\n", + "# SUBFIGURE PRETRAIN\n", + "annotate(ax1, 'Chemical Pretraining', (0.25, 0.95), fontsize=10)\n", + "\n", + "# DATASET CONSTRUCTION\n", + "\n", + "arrimg = mpimg.imread(f\"atoms.png\")\n", + "imagebox = OffsetImage(arrimg, zoom=0.09)\n", + "ab = AnnotationBbox(\n", + " imagebox,\n", + " (0.1, 0.77),\n", + " frameon=False,\n", + " xycoords='figure fraction',\n", + " box_alignment=(0.5, 0.0),\n", + ")\n", + "ax1.add_artist(ab)\n", + "\n", + "\n", + "arrimg = mpimg.imread(\"NH3_mol.png\")\n", + "imagebox = OffsetImage(arrimg, zoom=0.07)\n", + "ab = AnnotationBbox(imagebox, (0.25, 0.8), frameon=False, xycoords='figure fraction')\n", + "ax1.add_artist(ab)\n", + "\n", + "for i, action in enumerate(['Bending', 'Stretching', 'Breaking']):\n", + " arrimg = mpimg.imread(f\"{action.lower()}.png\")\n", + " imagebox = OffsetImage(arrimg, zoom=0.045)\n", + " ab = AnnotationBbox(\n", + " imagebox,\n", + " (0.35 + 0.055 * i, 0.78),\n", + " frameon=False,\n", + " xycoords='figure fraction',\n", + " box_alignment=(0.5, 0.0),\n", + " )\n", + " ax1.add_artist(ab)\n", + " annotate(ax1, action, (0.35 + 0.055 * i, 0.755), fontsize=6)\n", + "\n", + "annotate(ax1, None, (0.2, 0.8), (0.17, 0.8), arrow=True)\n", + "annotate(ax1, None, (0.32, 0.8), (0.29, 0.8), arrow=True)\n", + "\n", + "annotate(ax1, 'Light Atom Species', (0.1, 0.89), fontsize=8)\n", + "annotate(ax1, 'Assemble Molecules', (0.25, 0.89), fontsize=8)\n", + "annotate(ax1, 'Distort Geometries', (0.4, 0.89), fontsize=8)\n", + "\n", + "annotate(ax1, 'H', (0.056, 0.812), fontsize=6)\n", + "annotate(ax1, 'Li', (0.085, 0.812), fontsize=6)\n", + "annotate(ax1, 'B', (0.1135, 0.812), fontsize=6)\n", + "annotate(ax1, 'C', (0.1437, 0.812), fontsize=6)\n", + "annotate(ax1, 'N', (0.07, 0.75), fontsize=6)\n", + "annotate(ax1, 'O', (0.1, 0.75), fontsize=6)\n", + "annotate(ax1, 'F', (0.129, 0.75), fontsize=6)\n", + "\n", + "# LAC DATASET\n", + "\n", + "(x, y) = (0, 0)\n", + "annotate(ax1, None, (0.39 + x, 0.69 + y), (0.39, 0.74), arrow=True)\n", + "annotate(ax1, 'LAC Dataset', (0.31 + x, 0.69 + y), fontsize=8)\n", + "empty_box = patches.FancyBboxPatch(\n", + " (0.58 + x / 2, 0.465 + y),\n", + " 0.48,\n", + " 0.2,\n", + " facecolor=dataset_color,\n", + " edgecolor='k',\n", + " linewidth=lw_box,\n", + " zorder=0,\n", + " clip_on=False,\n", + " boxstyle=\"round,pad=0.0,rounding_size=0.02\",\n", + ")\n", + "ax1.add_patch(empty_box)\n", + "annotate(ax1, '• 88 distinct molecules', (0.275 + x, 0.64 + y), ha=\"left\", fontsize=6)\n", + "annotate(ax1, '• 22350 structures', (0.275 + x, 0.61 + y), ha=\"left\", fontsize=6)\n", + "annotate(ax1, '• up to 22 electrons ', (0.275 + x, 0.58 + y), ha=\"left\", fontsize=6)\n", + "annotate(\n", + " ax1, '• organic & inorganic chemistry', (0.275 + x, 0.55 + y), ha=\"left\", fontsize=6\n", + ")\n", + "annotate(\n", + " ax1, '• non-equilibrium & reactions', (0.275 + x, 0.52 + y), ha=\"left\", fontsize=6\n", + ")\n", + "annotate(\n", + " ax1, '• 45% multi-reference character', (0.275 + x, 0.49 + y), ha=\"left\", fontsize=6\n", + ")\n", + "\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.225 + x, 0.57 + y),\n", + " (0.255 + x, 0.57 + y),\n", + " arrowprops={\"arrowstyle\": \"-\"},\n", + ")\n", + "\n", + "# AUGMENTATION\n", + "\n", + "(x, y) = (-0.11, -0.01)\n", + "\n", + "annotate(ax1, 'Augmentation', (0.24 + x, 0.68 + y), fontsize=8)\n", + "\n", + "# Rotation\n", + "arrimg = mpimg.imread(\"C2H4_mol.png\")\n", + "imagebox = OffsetImage(arrimg, zoom=0.05)\n", + "ab = AnnotationBbox(\n", + " imagebox, (0.205 + x, 0.58 + y), frameon=False, xycoords='figure fraction'\n", + ")\n", + "ax1.add_artist(ab)\n", + "\n", + "for a in [\n", + " ((0.21 + x, 0.52 + y), (0.17 + x, 0.61 + y)),\n", + " ((0.20 + x, 0.64 + y), (0.24 + x, 0.55 + y)),\n", + "]:\n", + " annotate(\n", + " ax1,\n", + " None,\n", + " *a,\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=18\",\n", + " \"arrowstyle\": \"-|>, head_width=0.08, head_length=0.2\",\n", + " },\n", + " )\n", + "\n", + "annotate(ax1, 'Rotation', (0.205 + x, 0.49 + y), fontsize=7)\n", + "\n", + "# Fuzz\n", + "ab = AnnotationBbox(\n", + " imagebox, (0.28 + x, 0.58 + y), frameon=False, xycoords='figure fraction'\n", + ")\n", + "ax1.add_artist(ab)\n", + "for a in [\n", + " ((0.27 + x, 0.64 + y), (0.288 + x, 0.61 + y)),\n", + " ((0.265 + x, 0.62 + y), (0.256 + x, 0.58 + y)),\n", + " ((0.25 + x, 0.56 + y), (0.273 + x, 0.545 + y)),\n", + " ((0.32 + x, 0.57 + y), (0.2976 + x, 0.5845 + y)),\n", + " ((0.3 + x, 0.62 + y), (0.287 + x, 0.587 + y)),\n", + " ((0.295 + x, 0.555 + y), (0.268 + x, 0.576 + y)),\n", + "]:\n", + " annotate(\n", + " ax1,\n", + " None,\n", + " *a,\n", + " arrowprops={\"arrowstyle\": \"-|>, head_width=0.05, head_length=0.1\"},\n", + " )\n", + "\n", + "annotate(ax1, 'Fuzz', (0.28 + x, 0.49 + y), fontsize=7)\n", + "\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.03, 0.425),\n", + " (0.16 + x, 0.58 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\"},\n", + ")\n", + "\n", + "# PRETRAINING\n", + "x, y = (0, -0.25)\n", + "\n", + "# Arrow to stacked electron positions r\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.055 + x, 0.51 + y),\n", + " (0.03 + x, 0.55 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\"},\n", + ")\n", + "\n", + "# Arrow from r to Orbformer\n", + "annotate(ax1, None, (0.13 + x, 0.51 + y), (0.105 + x, 0.51 + y), arrow=True)\n", + "\n", + "# Arrow from molecular configuration M to Orbformer\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.13 + x, 0.455 + y),\n", + " (0.03 + x, 0.625 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\"},\n", + ")\n", + "\n", + "# Plot M and r boxes\n", + "empty_box = patches.FancyBboxPatch(\n", + " (0.035 + x / 2, 0.62 + y),\n", + " 0.06,\n", + " 0.05,\n", + " facecolor=array_color,\n", + " edgecolor='k',\n", + " linewidth=lw_box,\n", + " clip_on=False,\n", + " boxstyle=\"round,pad=0.0,rounding_size=0.01\",\n", + ")\n", + "ax1.add_patch(empty_box)\n", + "annotate(ax1, r'$\\mathbf{M}$', (0.03 + x, 0.64 + y), fontsize=9)\n", + "annotate(\n", + " ax1,\n", + " r\"$\\{\\mathbf{x}\\}_i$\",\n", + " (0.08 + x, 0.508 + y),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.3\", linewidth=lw_box, edgecolor=\"k\", facecolor=array_color\n", + " ),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# Orbfrormer box\n", + "annotate(\n", + " ax1,\n", + " \"Orbformer\\nWavefunction\",\n", + " (0.2 + x, 0.48 + y),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", linewidth=lw_box, edgecolor=\"k\", facecolor=model_color\n", + " ),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# MCMC arrows and text\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.3 + x, 0.56 + y),\n", + " (0.27 + x, 0.475 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\"},\n", + ")\n", + "ax1.annotate(\n", + " r\"$\\rho_\\mathbf{M}(\\mathbf{x}) = \\left|\\Psi(\\mathbf{x}\\mid\\mathbf{M})\\right|^2$\",\n", + " xycoords='figure fraction',\n", + " textcoords='figure fraction',\n", + " xy=(0.19 + x, 0.58 + y),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.08 + x, 0.54 + y),\n", + " (0.18 + x, 0.59 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\"},\n", + ")\n", + "ax1.annotate(\n", + " r\"ULA/MALA\",\n", + " xycoords='figure fraction',\n", + " textcoords='figure fraction',\n", + " xy=(0.1 + x, 0.605 + y),\n", + " fontsize=7,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# SGD arrows and text\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.3 + x, 0.39 + y),\n", + " (0.27 + x, 0.475 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\"},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " r\"$\\mathcal{E}_{\\mathbf{M}}\"\n", + " r\" \\approx{\\sum}_i~\\,\\frac{H_\\mathbf{M}\\Psi(\\mathbf{x}_i\\mid\\mathbf{M})}{\\Psi(\\mathbf{x}_i\\mid\\mathbf{M})}$\",\n", + " (0.3 + x, 0.349 + y),\n", + " fontsize=10,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.2 + x, 0.41 + y),\n", + " (0.22 + x, 0.35 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\"},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " r\"$\\nabla_\\boldsymbol{\\theta}~\\mathcal{E}_\\mathbf{M} \\rightarrow\"\n", + " r\" \\boldsymbol{\\theta}'$\",\n", + " xy=(0.14 + x, 0.37 + y),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# Final parameters\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (x + 0.205, y + 0.285),\n", + " (x + 0.14, y + 0.34),\n", + " fontsize=7,\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\"},\n", + ")\n", + "annotate(\n", + " ax1, r\"$\\boldsymbol{\\theta}^*$\", (x + 0.22, y + 0.285), fontsize=10, color=\"black\"\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \", Foundation Model Parameters\",\n", + " (x + 0.34, y + 0.285),\n", + " fontsize=8,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# Training progress subfigure\n", + "axp = ax1.inset_axes((0.772 + x, 0.44 + y, 0.3, 0.13))\n", + "x_plot = np.linspace(0, 1, 50)\n", + "y_plot = (\n", + " 20 - 15 * x_plot + np.random.normal(0, 1, 50) + np.random.normal(0, 5, 5).repeat(10)\n", + ")\n", + "axp.plot(x_plot, y_plot, c=result_color)\n", + "for i in range(1, 6):\n", + " axp.axvline(i * 0.2, color='k', lw=lw)\n", + "axp.set_xlabel('Training Iteration', fontsize=7)\n", + "axp.set_ylabel('Variance', fontsize=7)\n", + "axp.set_xticks([], [])\n", + "axp.set_yticks([])\n", + "axp.set_zorder(0)\n", + "axp.spines['top'].set_visible(False)\n", + "axp.spines['right'].set_visible(False)\n", + "\n", + "# Arrows from training back to M\n", + "for a in [\n", + " ((0.3 + x, 0.64 + y), (0.245 + x, 0.5)),\n", + " ((0.293 + x, 0.64 + y), (0.4811 + x, 0.562 + y)),\n", + " ((0.427 + x, 0.64 + y), (0.4563 + x, 0.562 + y)),\n", + " ((0.402 + x, 0.64 + y), (0.4315 + x, 0.562 + y)),\n", + " ((0.377 + x, 0.64 + y), (0.4067 + x, 0.562 + y)),\n", + " ((0.351 + x, 0.64 + y), (0.3819 + x, 0.562 + y)),\n", + "]:\n", + " annotate(\n", + " ax1,\n", + " None,\n", + " *a,\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\",\n", + " 'arrowstyle': '-',\n", + " },\n", + " )\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.215, 0.57),\n", + " (0.245 + x, 0.485),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\"},\n", + ")\n", + "annotate(ax1, r\"draw molecule\", (0.34 + x, 0.658 + y), fontsize=7, color=\"black\")\n", + "\n", + "# SUBFIGURE FINETUNE\n", + "\n", + "annotate(ax2, 'Transferable Finetuning', (0.75, 0.95), fontsize=10)\n", + "annotate(ax2, 'Targeted Chemical Process', (0.62, 0.89), fontsize=8)\n", + "\n", + "# Plot molecules of chemical process\n", + "for i, j in enumerate([0, 4, 9, 14, 19]):\n", + " arrimg = mpimg.imread(f\"bbmep_1106_mols/mol_{(str(0) + str(j))[-2:]}.png\")\n", + " imagebox = OffsetImage(arrimg, zoom=0.06)\n", + " ab = AnnotationBbox(\n", + " imagebox, (0.58 + i * 0.08, 0.8), frameon=False, xycoords='figure fraction'\n", + " )\n", + " ax2.add_artist(ab)\n", + " annotate(\n", + " ax2,\n", + " None,\n", + " (0.55, 0.7009 - i * 0.00047),\n", + " (0.58 + i * 0.08, 0.74),\n", + " arrowprops={\n", + " 'arrowstyle': '-',\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\",\n", + " },\n", + " )\n", + "\n", + "# Arrow to M block\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.53, 0.63),\n", + " (0.558, 0.7),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=0,angleB=-90,rad=8\",\n", + " },\n", + ")\n", + "\n", + "\n", + "# FINETUNE\n", + "x, y = (0.5, -0.01)\n", + "\n", + "# Arrow to stacked electron positions r\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.08 + x, 0.51 + y),\n", + " (0.03 + x, 0.55 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\",\n", + " },\n", + ")\n", + "# Arrow from r to Orbformer\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.185 + x, 0.51 + y),\n", + " (0.14 + x, 0.51 + y),\n", + " arrowprops={'linewidth': lw_bold},\n", + ")\n", + "\n", + "# Arrow from molecular configuration M to Orbformer\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.185 + x, 0.44 + y),\n", + " (0.03 + x, 0.575 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=8\",\n", + " },\n", + ")\n", + "\n", + "# Plot M and r boxes\n", + "for i in range(5):\n", + " empty_box = patches.FancyBboxPatch(\n", + " (-0.05 - i * 0.005, 0.56 + i * 0.005),\n", + " 0.06,\n", + " 0.05,\n", + " facecolor=array_color,\n", + " edgecolor='k',\n", + " linewidth=lw_box,\n", + " clip_on=False,\n", + " boxstyle=\"round,pad=0.0,rounding_size=0.01\",\n", + " )\n", + " ax2.add_patch(empty_box)\n", + " annotate(\n", + " ax2,\n", + " r\"$\\{\\mathbf{x}\\}_i$\",\n", + " (0.115 + x - i * 0.0025, 0.495 + y + i * 0.005),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.3\",\n", + " linewidth=lw_box,\n", + " edgecolor=\"k\",\n", + " facecolor=array_color,\n", + " ),\n", + " fontsize=9,\n", + " color=\"black\",\n", + " )\n", + "\n", + "annotate(ax2, r'$\\mathbf{M}$', (0.0275 + x, 0.61 + y), fontsize=9)\n", + "\n", + "# Orbfrormer box\n", + "annotate(\n", + " ax2,\n", + " r\"Pretrained\" + \"\\nOrbformer\\nWavefunction\",\n", + " (0.255 + x, 0.48 + y),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", linewidth=lw_box, edgecolor=\"k\", facecolor=model_color\n", + " ),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# MCMC arrows and text\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.4 + x, 0.58 + y),\n", + " (0.33 + x, 0.48 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\",\n", + " },\n", + ")\n", + "annotate(\n", + " ax2,\n", + " r\"$\\rho_\\mathbf{M}(\\mathbf{x}) = \\left|\\Psi(\\mathbf{x}\\mid\\mathbf{M})\\right|^2$\",\n", + " (0.41 + x, 0.615 + y),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.107 + x, 0.55 + y),\n", + " (0.33 + x, 0.615 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\",\n", + " },\n", + ")\n", + "annotate(ax2, r\"MALA\", (0.22 + x, 0.635 + y), fontsize=7, color=\"black\")\n", + "\n", + "# SGD arrows and text\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.4 + x, 0.32 + y),\n", + " (0.33 + x, 0.48 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\",\n", + " },\n", + ")\n", + "annotate(\n", + " ax2,\n", + " r\"$\\mathcal{E}_{\\mathbf{M}}\"\n", + " r\" \\approx{\\sum}_{i}~\\,\\frac{H_\\mathbf{M}\\Psi(\\mathbf{x}_i\\mid\\mathbf{M})}{\\Psi(\\mathbf{x}_i\\mid\\mathbf{M})}$\",\n", + " (0.4 + x, 0.27 + y),\n", + " fontsize=10,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.25 + x, 0.39 + y),\n", + " (0.32 + x, 0.3 - 0.02 + y),\n", + " arrowprops={\n", + " 'linewidth': lw_bold,\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=8\",\n", + " },\n", + ")\n", + "annotate(\n", + " ax2,\n", + " r\"$\\nabla_\\boldsymbol{\\theta}~{\\sum}_{\\mathbf{M}}~\\,\\mathcal{E}_\\mathbf{M}\"\n", + " r\" \\rightarrow\\boldsymbol{\\theta}'$\",\n", + " (0.18 + x, 0.33 + y),\n", + " fontsize=9,\n", + " color=\"black\",\n", + ")\n", + "\n", + "# Arrows to PES datapoints\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.06 + x, 0.08),\n", + " (0.32 + x, 0.29 - 0.02 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=-90,rad=20\"},\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.125 + x, 0.08),\n", + " (0.32 + x, 0.286 - 0.02 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=-90,rad=16\"},\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.207 + x, 0.08),\n", + " (0.32 + x, 0.282 - 0.02 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=-90,rad=12\"},\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.288 + x, 0.14),\n", + " (0.32 + x, 0.278 - 0.02 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=-90,rad=8\"},\n", + ")\n", + "annotate(\n", + " ax2,\n", + " None,\n", + " (0.37 + x, 0.17),\n", + " (0.37 + x, 0.255 - 0.02 + y),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=-90,rad=0\"},\n", + ")\n", + "\n", + "axp = ax2.inset_axes((-0.002, 0.07 + y, 0.75, 0.16))\n", + "x = np.arange(0, 20, 1)\n", + "y = np.array([\n", + " -323.76290,\n", + " -323.77085,\n", + " -323.76719,\n", + " -323.76160,\n", + " -323.76803,\n", + " -323.77344,\n", + " -323.75211,\n", + " -323.75864,\n", + " -323.76477,\n", + " -323.76331,\n", + " -323.75992,\n", + " -323.76476,\n", + " -323.75962,\n", + " -323.76217,\n", + " -323.66576,\n", + " -323.60473,\n", + " -323.52371,\n", + " -323.60698,\n", + " -323.60705,\n", + " -323.61342,\n", + "])\n", + "y_err = np.array([\n", + " 22e-5,\n", + " 26e-5,\n", + " 29e-5,\n", + " 28e-5,\n", + " 33e-5,\n", + " 10e-5,\n", + " 29e-5,\n", + " 3e-5,\n", + " 25e-5,\n", + " 23e-5,\n", + " 11e-5,\n", + " 15e-5,\n", + " 11e-5,\n", + " 5e-5,\n", + " 19e-5,\n", + " 8e-5,\n", + " 5e-5,\n", + " 9e-5,\n", + " 35e-5,\n", + " 9e-5,\n", + "])\n", + "idxs = [0, 4, 9, 14, 19]\n", + "xnew = np.linspace(0, 19, num=1000)\n", + "ynew = np.interp(xnew, x, y)\n", + "ynew = CubicSpline(x, y)(xnew)\n", + "axp.plot(xnew, ynew, ls=':', c='grey', lw=1)\n", + "axp.errorbar(x[idxs], y[idxs], y_err[idxs], marker='.', ms=7, ls='', c=result_color)\n", + "axp.set_xlabel('Reaction Coordinate', fontsize=7)\n", + "axp.set_ylabel('Energy', fontsize=7)\n", + "axp.set_xticks([])\n", + "axp.set_yticks([])\n", + "axp.spines['top'].set_visible(False)\n", + "axp.spines['right'].set_visible(False)\n", + "axp.set_zorder(0)\n", + "\n", + "fig.subplots_adjust(left=0, right=1, bottom=0, top=1)\n", + "fig.savefig('method.pdf', dpi=600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d73ccab5", + "metadata": {}, + "outputs": [], + "source": [ + "nuc_color = \"#dff2f4\"\n", + "elec_color = 'lightyellow'\n", + "array_color = 'lightgrey'\n", + "\n", + "fig = plt.figure(figsize=(0.5 * TEXTWIDTH, 0.5 * TEXTWIDTH))\n", + "spec = gridspec.GridSpec(ncols=1, nrows=1, figure=fig)\n", + "ax1 = fig.add_subplot(spec[0, 0])\n", + "\n", + "ax1.axis('off')\n", + "ax1.annotate(\n", + " 'Orbformer Wavefunction',\n", + " (0.25 * 2, 0.95),\n", + " xycoords='figure fraction',\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " fontsize=10,\n", + ")\n", + "\n", + "elecs = np.array([[0.46, 0.91], [0.39, 0.73], [0.52, 0.86], [0.33, 0.96]])\n", + "ax1.scatter(*elecs[:2].T, c='r', s=5)\n", + "ax1.scatter(*elecs[2:].T, c='b', s=5)\n", + "ax1.set_ylim(0, 1)\n", + "ax1.set_xlim(0, 1)\n", + "\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.6, 0.8),\n", + " (0.495, 0.85),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.6, 0.76),\n", + " (0.485, 0.72),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.8, 0.78),\n", + " (0.6, 0.82),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.8, 0.77),\n", + " (0.6, 0.73),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.19, 0.79),\n", + " (0.3, 0.82),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.3, 0.8),\n", + " (0.37, 0.85),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.19, 0.78),\n", + " (0.35, 0.805),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.35, 0.79),\n", + " (0.47, 0.81),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.19, 0.77),\n", + " (0.52, 0.77),\n", + " arrowprops={'arrowstyle': '-', 'color': 'lightgray'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.19, 0.76),\n", + " (0.32, 0.7),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.32, 0.73),\n", + " (0.42, 0.67),\n", + " arrowprops={\n", + " 'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\",\n", + " 'arrowstyle': '-',\n", + " 'color': 'lightgray',\n", + " },\n", + ")\n", + "\n", + "# ________________________________________________\n", + "\n", + "annotate(ax1, None, (0.13, 0.13), (0.13, 0.74), arrow=True)\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.23, 0.68),\n", + " (0.17, 0.74),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.25, 0.635),\n", + " (0.21, 0.68),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\"},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.86, 0.68),\n", + " (0.79, 0.74),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.88, 0.635),\n", + " (0.84, 0.68),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\"},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.73, 0.68),\n", + " (0.77, 0.74),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.7, 0.635),\n", + " (0.75, 0.68),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\"},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.645, 0.7),\n", + " (0.75, 0.74),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.61, 0.635),\n", + " (0.66, 0.7),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.328, 0.6),\n", + " (0.61, 0.65),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\"},\n", + ")\n", + "annotate(ax1, None, (0.328, 0.56), (0.65, 0.56), arrow=True)\n", + "annotate(ax1, None, (0.81, 0.58), (0.75, 0.58), arrow=True)\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.324, 0.49),\n", + " (0.25, 0.54),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.352, 0.49),\n", + " (0.86, 0.54),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.634, 0.45),\n", + " (0.15, 0.74),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.662, 0.45),\n", + " (0.9, 0.54),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(ax1, None, (0.338, 0.405), (0.338, 0.473), arrow=True)\n", + "annotate(ax1, None, (0.649, 0.379), (0.649, 0.435), arrow=True)\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.34, 0.3),\n", + " (0.482, 0.26),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.65, 0.3),\n", + " (0.512, 0.26),\n", + " arrowprops={'connectionstyle': \"angle,angleA=0,angleB=90,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(ax1, None, (0.497, 0.19), (0.497, 0.242), arrow=True)\n", + "annotate(ax1, None, (0.632, 0.15), (0.578, 0.15), arrowprops={'arrowstyle': '-'})\n", + "annotate(\n", + " ax1,\n", + " None,\n", + " (0.663, 0.15),\n", + " (0.719, 0.097),\n", + " arrowprops={'connectionstyle': \"angle,angleA=90,angleB=0,rad=5\", 'arrowstyle': '-'},\n", + ")\n", + "annotate(ax1, None, (0.703, 0.085), (0.23, 0.085), arrowprops={'arrowstyle': '-'})\n", + "annotate(ax1, None, (0.81, 0.085), (0.732, 0.085), arrow=True)\n", + "\n", + "circ1 = patches.Circle((0.275, 0.49), 0.02, lw=lw, ec='k', fill=False)\n", + "circ2 = patches.Circle((0.675, 0.44), 0.02, lw=lw, ec='k', fill=False)\n", + "circ3 = patches.Circle((0.675, 0.05), 0.02, lw=lw, ec='k', fill=False)\n", + "circ4 = patches.Circle((0.765, -0.035), 0.02, lw=lw, ec='k', fill=False)\n", + "circ5 = patches.Circle((0.48, 0.19), 0.02, lw=lw, ec='k', fill=False)\n", + "ax1.add_patch(circ1)\n", + "ax1.add_patch(circ2)\n", + "ax1.add_patch(circ3)\n", + "ax1.add_patch(circ4)\n", + "ax1.add_patch(circ5)\n", + "circ4.set_clip_box(None)\n", + "\n", + "annotate(ax1, r\"$\\times$\", (0.34, 0.484), fontsize=8, color=\"black\")\n", + "annotate(ax1, r\"$\\times$\", (0.65, 0.446), fontsize=8, color=\"black\")\n", + "annotate(ax1, r\"$\\bullet$\", (0.4975, 0.253), fontsize=8, color=\"black\")\n", + "annotate(ax1, r\"$\\sum$\", (0.649, 0.147), fontsize=4.5, color=\"black\")\n", + "annotate(ax1, r\"$\\bullet$\", (0.7185, 0.08), fontsize=8, color=\"black\")\n", + "\n", + "arrimg = mpimg.imread(\"H.png\")\n", + "imagebox = OffsetImage(arrimg, zoom=0.2)\n", + "ab = AnnotationBbox(imagebox, (0.46, 0.72), frameon=False, xycoords='figure fraction')\n", + "ax1.add_artist(ab)\n", + "\n", + "arrimg = mpimg.imread(\"Li.png\")\n", + "imagebox = OffsetImage(arrimg, zoom=0.2)\n", + "ab = AnnotationBbox(imagebox, (0.46, 0.85), frameon=False, xycoords='figure fraction')\n", + "ax1.add_artist(ab)\n", + "\n", + "annotate(\n", + " ax1,\n", + " \"\\n \",\n", + " (0.15, 0.775),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", linewidth=lw, edgecolor=\"k\", facecolor=elec_color\n", + " ),\n", + " fontsize=7,\n", + " color=\"black\",\n", + ")\n", + "annotate(ax1, r\"$[E,4]$\", (0.15, 0.76), fontsize=6, color=\"black\")\n", + "annotate(ax1, r\"$\\mathbf{x}$\", (0.15, 0.79), fontsize=9, color=\"black\")\n", + "annotate(\n", + " ax1,\n", + " \"\\n \",\n", + " (0.77, 0.775),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", linewidth=lw, edgecolor=\"k\", facecolor=nuc_color\n", + " ),\n", + " fontsize=7,\n", + " color=\"black\",\n", + ")\n", + "annotate(ax1, r\"$[N,4]$\", (0.77, 0.76), fontsize=6, color=\"black\")\n", + "annotate(ax1, r\"$\\mathbf{M}$\", (0.77, 0.79), fontsize=9, color=\"black\")\n", + "\n", + "annotate(\n", + " ax1,\n", + " \"Electron\\nTransformer\\n\" + r\"$[E,F_\\text{rep}]$\",\n", + " (0.25, 0.58),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", linewidth=lw, edgecolor=\"k\", facecolor=elec_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Nuclei\\nMPNN\\n\" + r\"$[N,F_\\text{nuc}]$\",\n", + " (0.7, 0.58),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=nuc_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Orbital\\nGenerator\\n\" + r\"$[O,N,F_\\text{orb}]$\",\n", + " (0.88, 0.58),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=nuc_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Jastrow factor\\n\" + r\"$[1]$\",\n", + " (0.15, 0.09),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=array_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Envelopes\\n\" + r\"$[E,O,D]$\",\n", + " (0.65, 0.34),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=dataset_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Projected electron\\n representations\\n\" + r\"$[E,O,D]$\",\n", + " (0.35, 0.355),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=dataset_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " \"Determinant\\n\" + r\"$[D]$\",\n", + " (0.5, 0.15),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.8\", edgecolor=\"k\", linewidth=lw, facecolor=array_color\n", + " ),\n", + " fontsize=6,\n", + " color=\"black\",\n", + ")\n", + "annotate(\n", + " ax1,\n", + " r\"$\\Psi$\",\n", + " (0.85, 0.085),\n", + " bbox=dict(\n", + " boxstyle=\"round,pad=0.5\", edgecolor=\"k\", linewidth=lw, facecolor=array_color\n", + " ),\n", + " fontsize=11,\n", + " color=\"black\",\n", + ")\n", + "\n", + "fig.savefig('architecture.pdf', dpi=600)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "oneqmc", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/00_method/stretching.png b/notebooks/00_method/stretching.png new file mode 100644 index 0000000..58fef72 Binary files /dev/null and b/notebooks/00_method/stretching.png differ