One stereotype of “AI for Science” is to fit transformers on scientific datasets. But that’s not the only way AI could benefit science.
In Sea AI Lab we aim to
Rebuild density functional theory (DFT) software with modern tools created in the AI community.
Below are a few reasons why the team at Sea AI Lab is currently working in this direction.
At the core of density functional theory is an optimization problem that finds the lowest energy of a system. It can be solved with gradient descent, just like machine learning models. We enforce the orthogonality constraint in DFT via reparameterization and implement energy minimization as plain gradient descent.
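As a minimal sketch of this idea (a toy Hamiltonian and made-up names, not the actual code): parameterize the orbital coefficients by an unconstrained matrix `W`, map it to an orthonormal set with a QR decomposition, and run plain gradient descent on the energy.

import jax
import jax.numpy as jnp

def orthonormal_orbitals(W):
    Q, _ = jnp.linalg.qr(W)           # reparameterization: columns of Q are orthonormal
    return Q

def energy(W, H):
    C = orthonormal_orbitals(W)
    return jnp.trace(C.T @ H @ C)     # toy quadratic energy standing in for the DFT energy

H = jnp.diag(jnp.arange(1.0, 9.0))    # toy Hamiltonian over 8 basis functions
W = jax.random.normal(jax.random.PRNGKey(0), (8, 3))   # 3 occupied orbitals

for _ in range(200):                  # plain gradient descent on the unconstrained W
    W = W - 0.1 * jax.grad(energy)(W, H)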
Direct minimization also brings new opportunities. SCF optimization solves an eigenvalue problem that only works for linear coefficients, while direct optimization enables us to use more complex ansatzes (described below). Stochastic gradients are another attractive point: we can replace the integrals and summations in DFT with randomized estimates to save overall computation.
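A minimal sketch of the stochastic flavor (toy integrand and made-up names, not the actual code): the quadrature sum over a dense real-space grid is replaced by a mean over a random minibatch of grid points, so each gradient step only touches a small fraction of the grid.

import jax
import jax.numpy as jnp

def energy_density(params, r):
    return jnp.sum(params * jnp.cos(params * r))      # stand-in for the DFT integrand

def minibatch_energy(params, r_batch):
    # Unbiased estimate of the grid average from a random minibatch of points.
    return jnp.mean(jax.vmap(lambda r: energy_density(params, r))(r_batch))

grid = jnp.linspace(0.0, 1.0, 100_000)                # dense integration grid
params = jnp.ones(4)
key = jax.random.PRNGKey(0)

for _ in range(100):
    key, sub = jax.random.split(key)
    batch = jax.random.choice(sub, grid, shape=(256,))    # random grid points
    params = params - 1e-2 * jax.grad(minibatch_energy)(params, batch)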
Beyond the ground-state energy calculation, various properties of materials are related to derivatives of the energy. For example, forces are first-order derivatives with respect to the atom coordinates, and phonon spectra are computed from second-order derivatives. Without AD, the formulas for these properties need to be derived by hand and implemented separately. Not long ago in machine learning we had to program our own backpropagation; today that is handled entirely by AD tools. We expect AD tools to help transform scientific computing in a similar way.
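With AD this becomes a one-liner. The sketch below uses a toy pair-potential energy in place of the DFT energy, only to show the pattern.

import jax
import jax.numpy as jnp

def total_energy(coords):
    # Toy pair potential over atoms; a stand-in for the DFT total energy.
    n = coords.shape[0]
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1) + jnp.eye(n))   # avoid 0 on the diagonal
    return jnp.sum(jnp.where(jnp.eye(n, dtype=bool), 0.0, 1.0 / dist))

coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.5, 0.0]])

forces = -jax.grad(total_energy)(coords)       # first-order derivatives
hessian = jax.hessian(total_energy)(coords)    # second-order derivatives, the ingredient for phonon spectra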
To make it more interesting, we extend AD to support automatic differentiation of functionals and operators. We found this useful for aligning our code with the math derivation. Physicists often communicate in succinct mathematical language; for example, below is the math for calculating a material's band structure.
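In its standard Kohn-Sham form, the band energies $\epsilon_{nk}$ and wave functions $\psi_{nk}$ satisfy

$$ \Big[-\tfrac{1}{2}\nabla^2 + \frac{\delta E[\rho]}{\delta \rho}(r)\Big]\psi_{nk}(r) = \epsilon_{nk}\,\psi_{nk}(r) $$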
However, the above is not directly implementable. To implement it, one needs to hand-derive $\frac{\delta E[\rho]}{\delta \rho}$, replace $\psi$ and $\rho$ with linear combinations of basis functions, and convert the entire calculation into coefficient space. With the introduction of automatic functional derivatives, $\frac{\delta E[\rho]}{\delta \rho}$ becomes directly implementable as a linearization of the energy functional. We show a code snippet for calculating the band structure:
import jax
import jax.numpy as jnp
import jax_xc

# `e` (energy integrals), `o` (operator utilities), `wave_ansatz`, `psi_to_rho`,
# `crystal`, `rho0`, `param` and `k_vectors` come from the surrounding DFT code.
def potential_and_xc_energy(rho):
    return (
        e.hartree(rho) +
        e.external(rho, crystal) +
        jax_xc.energy.lda_x(rho)
    )

# The functional derivative δE[ρ]/δρ at rho0, i.e. the effective potential.
Veff = jax.grad(potential_and_xc_energy)(rho0)
# Its linearization, used directly as the effective-potential energy term.
_, Eeff = jax.linearize(potential_and_xc_energy, rho0)

def energy_under_veff(psi):
    return e.kinetic(psi) + Eeff(psi_to_rho(psi))

def energy_in_param(param, k):
    psi = o.partial(wave_ansatz, args=(param, k), argnums=(0, 1))
    return energy_under_veff(psi)

bands = {}
for k in k_vectors:
    # Band energies at k: eigenvalues of the Hessian of the energy w.r.t. the parameters.
    bands[k] = jnp.linalg.eigh(jax.hessian(energy_in_param)(param, k))[0]
In DFT, electron wave functions are approximated as linear combinations of basis functions, which are analytically simple functions. When the target function is complex, we increase the number of basis functions, e.g. we raise the cutoff energy in solid-state calculations. One key insight from deep learning is that increasing the depth is exponentially more efficient than increasing the width.
To introduce deep models as wave functions, we need to consider the following constraints.
We managed to introduce deep components into the wave functions while satisfying all of the above constraints (Section 4 in DF4T). We are not the first to explore this idea; previous work dates back to 1993. However, powerful normalizing flows have only come into existence in the past few years, so it is worthwhile to re-explore these ideas with the tools developed in the AI era.
Thanks to the separation of frontend and backend, deep learning frameworks abstract away the details of hardware. Code written in JAX runs on all kinds of hardware without any customization, and today it is even possible to scale out on a cluster transparently via the parallel `jit` API. In our own DFT code, we commit to the same idea of separating the main logic from the optimization details: we keep the frontend as simple as the math and hide the optimization and implementation details behind the powerful compiler stack in JAX. One such example is the use of FFT in the solid-state code.
The wave functions are represented as a linear summation over a plane-wave basis.
$$ \psi(r)=\sum_G c_{G}\, e^{iG\cdot r} $$
Although we can evaluate this function explicitly at each coordinate to obtain its value, it turns out that the fast Fourier transform (FFT) is much more efficient when we want to evaluate the function on a grid: explicit evaluation takes $O(N^2)$ while the FFT takes $O(N\log N)$. However, this optimization would break the functional syntax that we're promoting.
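As a small self-contained 1-D check of this claim (sizes and coefficients made up for illustration), the explicit $O(N^2)$ evaluation and the $O(N\log N)$ FFT produce the same grid values:

import jax.numpy as jnp

N = 64
c = jnp.ones(N, dtype=jnp.complex64) / N        # plane-wave coefficients c_G
G = 2 * jnp.pi * jnp.arange(N)                  # 1-D reciprocal lattice vectors
r = jnp.arange(N) / N                           # uniform real-space grid of a unit cell

# Explicit evaluation: build an N x N phase matrix, O(N^2).
psi_explicit = jnp.exp(1j * jnp.outer(r, G)) @ c

# FFT evaluation: the same sum is an inverse DFT, O(N log N).
psi_fft = N * jnp.fft.ifft(c)

assert jnp.allclose(psi_explicit, psi_fft, atol=1e-4)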
For example, the exchange-correlation functional is defined as $$ E_{\text{xc}}[\rho]=\int\epsilon_{\text{xc}}[\rho](r)\rho(r)dr $$
The corresponding Python code is:
Exc = o.integrate(jax_xc.gga_x_pbe(rho) * rho)
There are a few difficulties:

- `jax_xc.gga_x_pbe` takes `rho` as input and evaluates `jax.grad(rho)` internally, so `rho` has to be a function that can be differentiated w.r.t. `r`, while the FFT only depends on the linear coefficients. Schematically,

  def gga_x_pbe(rho):
      def epsilon_xc(r):
          return some_other_function(rho(r), jax.grad(rho)(r))
      return epsilon_xc

- Feeding a grid of `r` into the `epsilon_xc` function will result in pointwise evaluation of `rho`, leading to $O(N^2)$ complexity.

Thanks to the internal design of JAX, we can register custom JVP and batching rules for wave functions that are composed of linearly combined plane waves. Via the batching rule, we trigger the FFT acceleration when a grid of inputs is received, while the custom JVP rules guarantee that the gradient is still computed correctly.
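As a rough illustration of the batching-rule half of this idea (not the actual implementation), one can use `jax.custom_batching.custom_vmap`; the 1-D setting, the names `coeff` and `psi`, and the shortcut of assuming the batch is exactly the uniform grid are all illustrative assumptions.

import jax
import jax.numpy as jnp

N = 64
coeff = jnp.ones(N, dtype=jnp.complex64) / N    # plane-wave coefficients c_G
G = 2 * jnp.pi * jnp.arange(N)                  # 1-D reciprocal lattice vectors

@jax.custom_batching.custom_vmap
def psi(r):
    # Pointwise evaluation of sum_G c_G exp(iGr): O(N) per point.
    return jnp.sum(coeff * jnp.exp(1j * G * r))

@psi.def_vmap
def psi_batched_rule(axis_size, in_batched, r):
    # Batching rule: when psi is vmapped over the full uniform grid, the sum
    # over G for all points at once is an inverse DFT, so switch to the FFT.
    del in_batched
    if axis_size == N:   # illustrative shortcut: assume r is the uniform grid
        return N * jnp.fft.ifft(coeff), True
    return jax.vmap(lambda x: jnp.sum(coeff * jnp.exp(1j * G * x)))(r), True

grid = jnp.arange(N) / N
values = jax.vmap(psi)(grid)    # a grid of inputs triggers the FFT path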