Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions bec_code_review_report.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
\documentclass[12pt]{article}
\usepackage[utf8]{inputenc}
\usepackage{geometry}
\geometry{a4paper, margin=1in}
\usepackage{hyperref}
\usepackage{minted}
\usepackage{enumitem}

\title{Code Review Report: Born Effective Charges (BEC) Implementation}
\author{Reviewer}
\date{\today}

\begin{document}

\maketitle

\section{Introduction}
This report provides a comprehensive review of the new functionality implemented in the \texttt{feature/add-bec-support-14624376731387974145} branch to allow SevenNet to predict Born Effective Charges (BEC). The review evaluates the rigorousness of the approach, identifies potential bottlenecks, highlights dangerous approximations, and discusses general aspects evaluated in these types of equivariant neural network models.

The implementation broadly follows the ``Data $\to$ Model $\to$ Loss $\to$ Trainer'' architectural pipeline required for adding a new property in SevenNet, tracking dataset statistics correctly in \texttt{sevenn/train/graph\_dataset.py}. However, several critical issues related to loss scaling, model architecture configuration, and memory management need to be addressed before merging.

\section{Identified Issues and Mitigations}

\subsection{1. Loss Scaling and Dimensionality (Dangerous Approximation)}
\textbf{Context:} In \texttt{sevenn/train/loss.py}, the \texttt{BECLoss} class calculates the loss by first flattening both the predicted and reference 9-component tensors before passing them to the loss criterion (e.g., PyTorch's \texttt{nn.MSELoss} with \texttt{reduction='mean'}).
\begin{minted}{python}
pred = torch.reshape(pred, (-1,))
ref = torch.reshape(ref_irreps, (-1,))
\end{minted}

\textbf{Explanation of the Error:} By flattening the 9-component tensor (shape $[N_{atoms}, 9]$ to $[N_{atoms} \times 9]$) before applying \texttt{reduction='mean'}, the calculated loss is averaged over all $N_{atoms} \times 9$ components. This effectively scales the base loss value downward by a factor of 9 compared to a per-atom scalar loss (where the sum of squared errors would only be divided by $N_{atoms}$).

\textbf{Impact / Importance:} This is a dangerous approximation because it artificially reduces the gradient magnitude for the BEC property relative to other properties (like Energy or Force) by an order of magnitude. This disproportionate weighting will lead to suboptimal multi-task training dynamics, causing the model to underfit the BEC predictions unless the user manually compensates by increasing the \texttt{BEC\_WEIGHT} by a factor of 9.

\textbf{Mitigation:} To ensure consistent gradient magnitudes across tensor bases and maintain the Mean Squared Error (L2 norm) consistency between Cartesian and Irreps representations, the loss calculation should average over atoms, not components.
Instead of flattening, maintain the shape $[N_{atoms}, 9]$ and compute the loss such that it averages over the first dimension ($N_{atoms}$) but sums (or manages appropriately) over the 9 components. Alternatively, if flattening is kept, explicitly multiply the resulting loss by 9 to restore the per-atom scaling.

\subsection{2. Model Architecture Configuration (\texttt{lmax} Override)}
\textbf{Context:} In \texttt{sevenn/model\_build.py}, the model constructs the final interaction layer. For predicting non-scalar tensors like BEC (which requires the irreps \texttt{1x0e+1x1e+1x2e}), the final layer must support higher-order spherical harmonics.
\begin{minted}{python}
if t == num_convolution_layer - 1:
if config.get(KEY.IS_TRAIN_BEC, False):
# We need at least L=1 and L=2 for vectors and tensors.
# `lmax_node` is already resolved at the top of the function.
pass
else:
lmax_node = 0
parity_mode = 'even'
\end{minted}

\textbf{Explanation of the Error:} The code relies on the global \texttt{lmax\_node} configured by the user. If a user sets the default \texttt{lmax=0} in their configuration, the \texttt{pass} statement does nothing, and the final layer will be built with \texttt{lmax=0} and \texttt{parity\_mode='even'}.

\textbf{Impact / Importance:} This is a critical failure point. Predicting the BEC tensor requires the final layer to output irreps of at least $L=1$ and $L=2$ with full parity. If \texttt{lmax=0}, the model fundamentally lacks the capacity (equivariant feature propagation) to predict the non-scalar BEC tensor, leading to shape mismatches or zeroed outputs during the \texttt{predict\_bec} linear projection.

\textbf{Mitigation:} Forcefully override the parameters to ensure mathematical correctness.
\begin{minted}{python}
if config.get(KEY.IS_TRAIN_BEC, False):
lmax_node = max(lmax_node, 2)
parity_mode = 'full'
\end{minted}

\subsection{3. Memory Management: \texttt{CartesianTensor} Instantiation (Bottleneck Prevention)}
\textbf{Context:} The conversion between Irreps and Cartesian representations relies on \texttt{e3nn.io.CartesianTensor}. In SevenNet, dynamically instantiating \texttt{CartesianTensor} and calling its \texttt{.from\_cartesian()} or \texttt{.to\_cartesian()} methods without caching the \texttt{ReducedTensorProducts} (RTP) causes severe O(N) memory leaks and time inflation due to dynamic \texttt{fx.GraphModule} compilation under the hood.

\textbf{Review Findings:} The current implementation in \texttt{BECLoss}, \texttt{ErrorMetric}, and \texttt{SevenNetCalculator} correctly caches both \texttt{\_ct} (the \texttt{CartesianTensor}) and \texttt{\_rtp} (the \texttt{ReducedTensorProducts}) at the object instance level.
\begin{minted}{python}
def _get_cartesian_tensor(self):
if getattr(self, '_ct', None) is None:
self._ct = CartesianTensor('ij')
self._rtp = self._ct.reduced_tensor_products()
return self._ct, self._rtp
\end{minted}
\textbf{Conclusion:} This approach is rigorous and correctly avoids multi-GPU device pollution (by avoiding class-level caching) while preventing the known memory leak bottleneck. No changes are required here, but the implementation is commended for its correctness.

\subsection{4. Error Recording and Terminal Logging}
\textbf{Context:} The default error metric for \texttt{vdim: 9} (used for BEC) in \texttt{sevenn/error\_recorder.py} is \texttt{RMSError}, which computes the L2 vector norm per atom.

\textbf{Review Findings:} This is mathematically correct for minimizing the overall vector distance. However, it should be documented or noted that if users expect the component-wise average error (which matches external numpy calculations commonly used in the literature), they must explicitly specify \texttt{ComponentRMSE} or \texttt{MAE} in their training configuration instead of the default \texttt{RMSE}.

\section{Summary and Recommendations}
The implementation is mostly sound and integrates well with the SevenNet framework. The caching of \texttt{CartesianTensor} correctly avoids a known memory bottleneck. However, the PR should not be merged until the following changes are made:
\begin{enumerate}[label=\arabic*.]
\item \textbf{Fix the Loss Scaling in \texttt{BECLoss}:} Adjust the tensor reshaping or multiply the loss by 9 to prevent the base loss value from being artificially scaled down, which disrupts multi-task learning dynamics.
\item \textbf{Enforce \texttt{lmax} and \texttt{parity\_mode} in \texttt{model\_build.py}:} Replace the \texttt{pass} statement with explicit overrides (\texttt{lmax\_node = max(lmax\_node, 2)} and \texttt{parity\_mode = 'full'}) to guarantee the model can propagate the necessary equivariant features for tensor prediction regardless of the base configuration.
\end{enumerate}

Addressing these two dangerous approximations will ensure the rigorousness of the model and robust training performance.

\end{document}
Loading