GemNet: Universal Directional Graph Neural Networks for Molecules
1. Context: Spherical Harmonics
Spherical harmonics \(Y_l^m(\theta, \phi)\) appear as solutions to the angular part of Laplace’s equation in spherical coordinates.
l = degree (non-negative integer, analogous to the angular momentum quantum number)
m = order (integer with \(-l \le m \le l\), so each degree has \(2l+1\) harmonics; it relates to the projection/orientation)
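The following is a minimal numerical check of two properties used throughout these notes. First, the real spherical harmonics of each degree \(l\) form a block of \(2l+1\) functions. Second, all of them are orthonormal on the sphere.

The closed forms below cover only \(l \le 2\) and follow one common sign convention; other references may differ by signs. The helper name `real_sph_harm_l2` is hypothetical, not a library API. The integral uses Gauss–Legendre quadrature in \(\cos\theta\) and a uniform grid in \(\phi\), which is exact for these low-degree products.

```python
import numpy as np

def real_sph_harm_l2(u):
    """Real spherical harmonics Y_lm for l = 0, 1, 2 at unit vectors u of shape (..., 3).

    Output shape (..., 9), ordered (l, m) = (0,0), (1,-1), (1,0), (1,1),
    (2,-2), (2,-1), (2,0), (2,1), (2,2): one block of 2l + 1 functions per degree.
    """
    x, y, z = u[..., 0], u[..., 1], u[..., 2]
    c0 = 0.5 * np.sqrt(1 / np.pi)
    c1 = np.sqrt(3 / (4 * np.pi))
    c2 = 0.5 * np.sqrt(15 / np.pi)
    c20 = 0.25 * np.sqrt(5 / np.pi)
    return np.stack([
        np.full_like(x, c0),                              # l = 0
        c1 * y, c1 * z, c1 * x,                           # l = 1, m = -1, 0, 1
        c2 * x * y, c2 * y * z, c20 * (3 * z**2 - 1),     # l = 2, m = -2, -1, 0
        c2 * x * z, 0.5 * c2 * (x**2 - y**2),             # l = 2, m = 1, 2
    ], axis=-1)

# Quadrature on the sphere: Gauss-Legendre nodes in z = cos(theta), uniform grid in phi.
z, wz = np.polynomial.legendre.leggauss(10)
n_phi = 16
phi = 2 * np.pi * np.arange(n_phi) / n_phi
Z, PHI = np.meshgrid(z, phi, indexing="ij")                      # (10, 16)
S = np.sqrt(1 - Z**2)                                             # sin(theta)
u = np.stack([S * np.cos(PHI), S * np.sin(PHI), Z], axis=-1)      # unit vectors, (10, 16, 3)
w = np.broadcast_to(wz[:, None] * (2 * np.pi / n_phi), Z.shape)  # area weights, (10, 16)

Y = real_sph_harm_l2(u)                                           # (10, 16, 9)
gram = np.einsum("ij,ijk,ijl->kl", w, Y, Y)                       # inner products <Y_k, Y_l>
print(np.allclose(gram, np.eye(9)))                               # True: orthonormal
```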
2. Motivation
Standard message-passing GNNs are at most as powerful as the 1-Weisfeiler-Lehman (1-WL) test of graph isomorphism, so they cannot distinguish between certain molecules. Moreover, they require a large number of training samples to achieve good accuracy.
GemNet addresses this limited expressiveness in two steps. It proves sufficient conditions for the universality of GNNs that are invariant to translations and rotations and equivariant to permutations. It then extends this result to rotationally equivariant predictions.
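The 1-WL limitation can be made concrete with color refinement on two carbon skeletons, with hydrogens omitted. The first is a six-membered ring, as in cyclohexane. The second is two separate three-membered rings, as in two cyclopropane molecules.

Both graphs are 2-regular with six atoms, so every atom keeps the same color in every round and 1-WL sees no difference. 3D geometry separates them immediately, since the angles inside a three-membered ring are 60°. This is a hypothetical illustration in plain Python, not code from GemNet.

```python
from collections import Counter

def wl_histogram(adj, iterations=3):
    """1-WL color refinement; adj maps each node to its list of neighbours.

    Every node starts with the same color (all atoms are carbon). In each round, a node's
    new color is its old color plus the sorted multiset of its neighbours' colors.
    Colors are kept as nested tuples, so histograms of different graphs are comparable.
    """
    colors = {v: "C" for v in adj}
    for _ in range(iterations):
        colors = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}
    return Counter(colors.values())

# Carbon skeletons without hydrogens.
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}                      # one 6-ring
two_triangles = {i: [3 * (i // 3) + (i + 1) % 3, 3 * (i // 3) + (i + 2) % 3]
                 for i in range(6)}                                              # two 3-rings
chain = {i: [j for j in (i - 1, i + 1) if 0 <= j < 6] for i in range(6)}         # open chain

print(wl_histogram(hexagon) == wl_histogram(two_triangles))  # True: 1-WL cannot tell them apart
print(wl_histogram(hexagon) == wl_histogram(chain))          # False: end atoms have one neighbour
```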
3. Preliminaries
GNNs for molecules typically incorporate directional information in one of two ways: via \(\text{SO}(3)\) representations or by using directions in real space.
Mathematical terminology
Assume a point cloud with \(n\) points (atoms). Each point has a position, collected in the coordinate matrix \(\mathbf{X}\in \mathbb{R}^{3\times n}\). It also has a set of rotationally invariant features (e.g. atom types), collected in \(\mathbf{H}_{\text{in}}\in\mathbb{R}^{h\times n}\).
We define model classes by sets of functions \(\mathcal{F}\). To show that a model class is universal, we need to prove that \(\mathcal{F}\) can approximate, to arbitrary accuracy, every function in the full set \(\mathcal{G}'\). This set contains the functions that are invariant to the group of translations \(\mathbb{T}^3\) and the group of rotations \(\text{SO}(3)\), and equivariant to the group of permutations \(S_n\).

We denote a vector’s norm by \(x = \|\mathbf{x}\|_2\) and its direction by \(\hat{\mathbf{x}}=\mathbf{x}/x\). The relative position between points \(a\) and \(b\) is \(\mathbf{x}_{ba}=\mathbf{x}_b-\mathbf{x}_a\).
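A quick way to see what membership in \(\mathcal{G}'\) requires is a toy function. The one below is hypothetical, not a GemNet component. It sums \(\exp(-x_{ba})\) over all atom pairs, so it depends only on interatomic distances. Its output must therefore not change when the atoms are permuted, rotated and translated.

The check uses NumPy and a Rodrigues-formula rotation. It illustrates the symmetry requirements only, not expressiveness.

```python
import numpy as np

def rotation(axis, angle):
    """Rotation matrix in SO(3) via Rodrigues' formula (angle in radians about `axis`)."""
    a = np.asarray(axis, dtype=float) / np.linalg.norm(axis)
    K = np.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K

def toy_energy(X):
    """Hypothetical invariant model: sum over atom pairs of exp(-x_ba), x_ba = ||x_b - x_a||.

    X has shape (3, n), one column per atom, as in the text.
    """
    diff = X[:, None, :] - X[:, :, None]      # diff[:, a, b] = x_b - x_a
    dist = np.linalg.norm(diff, axis=0)       # (n, n) matrix of x_ba
    a, b = np.triu_indices(X.shape[1], k=1)   # each unordered pair once
    return np.exp(-dist[a, b]).sum()

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 5))
R = rotation([1.0, 2.0, 3.0], 0.7)
t = rng.normal(size=(3, 1))                    # translation, broadcast over all atoms
perm = rng.permutation(5)

e = toy_energy(X)
e_moved = toy_energy(R @ X[:, perm] + t)       # permute, rotate and translate the atoms
print(np.isclose(e, e_moved))                  # True
```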
Tensor field network
We define the set of pooling functions \(\mathcal{F}^\text{TFN}_\text{pool}\) as all rotationally equivariant linear functions that map the intermediate \(\text{SO}(3)\)-representation features to the output.
The embedding functions are \(\mathcal{F}^\text{TFN}_\text{feat}(D)=\{\pi_2 \circ f^{(2D)} \circ \cdots \circ f^{(1)}\mid f^{(i)}\in\mathcal{F}_\text{prod}^\text{TFN}\}\). Each \(f^{(i)}\) is one tensor product layer of the TFN; this is where spherical harmonics and Clebsch–Gordan coefficients mix spatial and feature information. The final projection \(\pi_2(\mathbf{X},\mathbf{H})=\mathbf{H}\) discards the coordinates and returns only the features.
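The composition \(\pi_2 \circ f^{(2D)} \circ \cdots \circ f^{(1)}\) can be sketched with the simplest possible layer. Assume every feature has degree \(l=0\). Then the only Clebsch–Gordan coefficient (for \(0\otimes 0\to 0\)) equals 1, and \(Y_{00}\) is a constant. A tensor product layer therefore reduces to \(\tilde{H}_a = \sum_{b} R(\|\mathbf{x}_b-\mathbf{x}_a\|)\,Y_{00}\,H_b\).

The following are assumptions for illustration: the names `make_l0_layer` and `embed`, the made-up polynomial coefficients, and the choice of every other atom as a neighbor.

```python
import numpy as np

Y00 = 0.5 / np.sqrt(np.pi)  # the constant degree-0 spherical harmonic

def rotation(axis, angle):
    """Rotation matrix in SO(3) via Rodrigues' formula (angle in radians about `axis`)."""
    a = np.asarray(axis, dtype=float) / np.linalg.norm(axis)
    K = np.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K

def make_l0_layer(coeffs):
    """Hypothetical degree-0-only TFN layer f(X, H) = (X, H_tilde), where
    H_tilde[:, a] = sum_{b != a} R(||x_b - x_a||) * Y00 * H[:, b] and
    R is a polynomial with ascending coefficients `coeffs`."""
    def layer(X, H):
        diff = X[:, None, :] - X[:, :, None]     # diff[:, a, b] = x_b - x_a
        dist = np.linalg.norm(diff, axis=0)      # (n, n) interatomic distances
        radial = np.polyval(coeffs[::-1], dist)  # R(||x_b - x_a||) for every pair
        np.fill_diagonal(radial, 0.0)            # neighbours: every atom except a itself
        return X, Y00 * H @ radial.T             # coordinates pass through unchanged
    return layer

def pi2(X, H):
    """Projection onto the second argument: drop the coordinates, keep the features."""
    return H

def embed(layers, X, H):
    """Apply pi_2 o f^(2D) o ... o f^(1) to (X, H)."""
    for f in layers:
        X, H = f(X, H)
    return pi2(X, H)

D = 2
rng = np.random.default_rng(1)
layers = [make_l0_layer(rng.normal(size=D + 1)) for _ in range(2 * D)]  # 2D layers
X, H = rng.normal(size=(3, 4)), rng.normal(size=(2, 4))                  # n = 4 atoms, h = 2

out = embed(layers, X, H)
perm = rng.permutation(4)
R, t = rotation([0.0, 1.0, 1.0], 1.2), rng.normal(size=(3, 1))

print(np.allclose(embed(layers, X[:, perm], H[:, perm]), out[:, perm]))  # True
print(np.allclose(embed(layers, R @ X + t, H), out))                     # True
```

The layers only see pairwise distances. As a result, permuting the atoms permutes the output columns, while rotating or translating the input leaves the output unchanged.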
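Here, \(\mathcal{F}_\text{prod}^\text{TFN}=\{f\mid f(\mathbf{X},\mathbf{H})=(\mathbf{X},\mathbf{\tilde{H}}^\text{TFN}(\mathbf{X},\mathbf{H})) \}\) is the set of feature update functions. These functions leave the coordinates unchanged and update only the features. The intermediate representations are \(\mathbf{H}\in W^n_\text{feat}\), where \(W_\text{feat}\) is a representation of \(\text{SO}(3)\) indexed by the degree \(l\) and the order \(m\). The updated feature of atom \(a\) with output type \((l_o, m_o)\) is
\(\tilde{H}^{\text{TFN}}_{a,(l_o,m_o)}(\mathbf{X},\mathbf{H}) = \sum_{(l_f,m_f),(l_i,m_i)} C^{(l_o, m_o)}_{(l_f, m_f), (l_i, m_i)} \sum_{b\in\mathcal{N}_a} F^{(l_f)}_{\text{TFN}, m_f}(\mathbf{x}_b - \mathbf{x}_a)\, H^{(l_i)}_{b m_i}\)
with the following terms: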
\(C^{(l_o, m_o)}_{(l_f, m_f), (l_i, m_i)}\) are the Clebsch–Gordan coefficients. They decompose the tensor product of two input \(\text{SO}(3)\) representations (the filter and the input feature) into a sum of output representations.
\(\sum_{b\in\mathcal{N}_a}\) sums over all neighbor atoms \(b\) of atom \(a\).
\(F^{(l_f)}_{\text{TFN}, m_f}(\mathbf{x}_b - \mathbf{x}_a)\) is a rotationally equivariant filter applied to the vector from atom \(a\) to atom \(b\).
\(H^{(l_i)}_{b m_i}\) is neighbor \(b\)’s feature of type \((l_i, m_i)\).
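For two degree-1 inputs (\(l_f = l_i = 1\)), the Clebsch–Gordan decomposition is \(1 \otimes 1 = 0 \oplus 1 \oplus 2\). The 9 products of components split into 1 + 3 + 5 independent components.

The sketch below substitutes the equivalent Cartesian form of this decomposition for the tabulated coefficients. That form has three parts: the trace, the antisymmetric part written as a cross product, and the symmetric traceless part. The sketch checks that each part transforms with its own degree under a rotation. The function name `decompose_1x1` is hypothetical.

```python
import numpy as np

def rotation(axis, angle):
    """Rotation matrix in SO(3) via Rodrigues' formula (angle in radians about `axis`)."""
    a = np.asarray(axis, dtype=float) / np.linalg.norm(axis)
    K = np.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K

def decompose_1x1(u, v):
    """Split the 3 x 3 tensor product of two degree-1 features into degree-0, 1 and 2 parts.

    Cartesian form of the Clebsch-Gordan decomposition 1 (x) 1 = 0 (+) 1 (+) 2.
    """
    T = np.outer(u, v)
    l0 = np.trace(T) / 3                    # (u . v) / 3: 1 component
    l1 = np.cross(u, v)                     # antisymmetric part as a vector: 3 components
    l2 = 0.5 * (T + T.T) - l0 * np.eye(3)   # symmetric traceless part: 5 independent components
    return l0, l1, l2

rng = np.random.default_rng(2)
u, v = rng.normal(size=3), rng.normal(size=3)
R = rotation([1.0, -1.0, 0.5], 2.0)

l0, l1, l2 = decompose_1x1(u, v)
r0, r1, r2 = decompose_1x1(R @ u, R @ v)    # rotate both inputs first

print(np.isclose(r0, l0))                   # degree 0: unchanged -> True
print(np.allclose(r1, R @ l1))              # degree 1: rotates like a vector -> True
print(np.allclose(r2, R @ l2 @ R.T))        # degree 2: rotates like a traceless matrix -> True
```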
[!NOTE]
\(F^{(l)}_{\text{TFN}, m}(\mathbf{x}) = R^{(l)}(\|\mathbf{x}\|) \; Y_{l m}(\hat{\mathbf{x}})\) is a rotationally equivariant filter. It has a (learned) radial part \(R^{(l)}\), which is any polynomial of degree \(\le D\). Its angular part is the real spherical harmonics \(Y_{lm}\) of degree \(l\) and order \(m\).
\(R^{(l)}\) = learned radial function of the bond length \(\|x\|\)
\(Y_{l m}\) = spherical harmonic capturing the angular dependence
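Here is a sketch of this filter for \(l = 1\), assuming a made-up radial polynomial with \(D = 2\). The real \(Y_{1,-1}, Y_{1,0}, Y_{1,1}\) are proportional to the \(y\), \(z\), \(x\) components of \(\hat{\mathbf{x}}\). Rotating the input vector by \(R\) should therefore multiply the filter output by the \(l=1\) Wigner matrix, which in this basis is \(R\) with its rows and columns reordered. The radial factor is unaffected because \(\|R\mathbf{x}\| = \|\mathbf{x}\|\). The function name `tfn_filter_l1` is hypothetical.

```python
import numpy as np

def rotation(axis, angle):
    """Rotation matrix in SO(3) via Rodrigues' formula (angle in radians about `axis`)."""
    a = np.asarray(axis, dtype=float) / np.linalg.norm(axis)
    K = np.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K

C1 = np.sqrt(3 / (4 * np.pi))
# Real Y_{1,-1}, Y_{1,0}, Y_{1,1} equal C1 * (y, z, x); P reorders (x, y, z) into (y, z, x).
P = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=float)

def tfn_filter_l1(x, coeffs):
    """F^(1)_{TFN,m}(x) = R^(1)(||x||) * Y_{1m}(x / ||x||) for m = -1, 0, 1.

    R^(1) is a hypothetical polynomial radial part with ascending coefficients `coeffs`.
    """
    r = np.linalg.norm(x)
    radial = np.polyval(coeffs[::-1], r)
    return radial * C1 * (P @ (x / r))

coeffs = np.array([0.3, -1.0, 0.5])    # made-up radial polynomial of degree D = 2
x = np.array([0.4, -1.1, 0.8])         # x_b - x_a for some atom pair
R = rotation([0.2, 1.0, -0.3], 0.9)
D1 = P @ R @ P.T                       # l = 1 Wigner matrix in the (m = -1, 0, 1) real basis

print(np.allclose(tfn_filter_l1(R @ x, coeffs), D1 @ tfn_filter_l1(x, coeffs)))  # True
```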
Spherical networks
Instead of intermediate \(\text{SO}(3)\) representations, spherical networks use spherical representations, which are functions on the sphere \(H : S^2 \to \mathbb{R}\). We define the set of functions \(\mathcal{F}^\text{sphere}_{ K(D),D}\) analogously to \(\mathcal{F}^\text{TFN}_{ K(D),D}\). However, for \(\mathcal{F}^\text{sphere}_\text{feat}(D)\) we use the following filter:
[!NOTE]
\(F^{\text{sphere}}(x,\hat r) = \sum_{l,m} R^{(l)}(x)\;\Re\!\big[\,Y^{(l)}_m(\widehat{x})^{*}\;Y^{(l)}_m(\hat r)\,\big]\)
\(R^{(l)}(x)\): learned radial function of the distance \(\|x\|\) (one per degree \(l\)).
\(Y^{(l)}_m\): (complex) spherical harmonics (the angular basis on the sphere).
\(\widehat{x}=x/\|x\|\): the direction from \(a\) to \(b\).
\(\hat r \in S^2\): the point on the sphere at which the filter is evaluated.
\(\Re[\cdot]\): take the real part → yields a real-valued filter.
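The addition theorem of spherical harmonics states that \(\sum_m Y^{(l)}_m(\hat{\mathbf{x}})^* Y^{(l)}_m(\hat r) = \frac{2l+1}{4\pi} P_l(\hat{\mathbf{x}}\cdot\hat r)\), where \(P_l\) is the Legendre polynomial. The spherical filter therefore depends only on the distance and on the angle between \(\hat{\mathbf{x}}\) and \(\hat r\).

The sketch below checks this for \(l \le 2\) using real harmonics. Per degree, the sum over \(m\) is the same for real and complex harmonics, since the two bases differ by a unitary change of basis. The radial functions are made up, and the helper names are hypothetical.

```python
import numpy as np

def real_sph_harm_l2(u):
    """Real spherical harmonics for l = 0, 1, 2 at a unit vector u; returns 9 values."""
    x, y, z = u
    c0, c1 = 0.5 * np.sqrt(1 / np.pi), np.sqrt(3 / (4 * np.pi))
    c2, c20 = 0.5 * np.sqrt(15 / np.pi), 0.25 * np.sqrt(5 / np.pi)
    return np.array([c0,
                     c1 * y, c1 * z, c1 * x,
                     c2 * x * y, c2 * y * z, c20 * (3 * z**2 - 1), c2 * x * z,
                     0.5 * c2 * (x**2 - y**2)])

def rotation(axis, angle):
    """Rotation matrix in SO(3) via Rodrigues' formula (angle in radians about `axis`)."""
    a = np.asarray(axis, dtype=float) / np.linalg.norm(axis)
    K = np.array([[0, -a[2], a[1]], [a[2], 0, -a[0]], [-a[1], a[0], 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K @ K

DEGREE_SLICES = [slice(0, 1), slice(1, 4), slice(4, 9)]   # entries of degree l = 0, 1, 2

def sphere_filter(x, r_hat, radial_fns):
    """F^sphere(x, r_hat) = sum_l R^(l)(||x||) sum_m Y_lm(x_hat) Y_lm(r_hat), real basis, l <= 2."""
    d = np.linalg.norm(x)
    yx, yr = real_sph_harm_l2(x / d), real_sph_harm_l2(r_hat)
    return sum(R_l(d) * np.dot(yx[s], yr[s]) for R_l, s in zip(radial_fns, DEGREE_SLICES))

rng = np.random.default_rng(3)
x = rng.normal(size=3)                                     # x_b - x_a
r_hat = rng.normal(size=3)
r_hat /= np.linalg.norm(r_hat)                             # a point on the sphere S^2
radial_fns = [lambda d: np.exp(-d), lambda d: d, lambda d: d**2]  # made-up R^(0), R^(1), R^(2)

# Closed form from the addition theorem, with P_0 = 1, P_1 = t, P_2 = (3 t^2 - 1) / 2.
d = np.linalg.norm(x)
t = np.dot(x / d, r_hat)
legendre = [1.0, t, 0.5 * (3 * t**2 - 1)]
closed_form = sum(R_l(d) * (2 * l + 1) / (4 * np.pi) * legendre[l]
                  for l, R_l in enumerate(radial_fns))

R = rotation([1.0, 0.0, 2.0], 0.5)
print(np.isclose(sphere_filter(x, r_hat, radial_fns), closed_form))    # True
print(np.isclose(sphere_filter(R @ x, R @ r_hat, radial_fns),
                 sphere_filter(x, r_hat, radial_fns)))                 # True: rotation invariant
```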
The set of pooling functions for invariant predictions is:
\(\theta_\text{pool}\) is a learnable parameter