AI for Science

Feb. 22, 2024

One stereotype of “AI for Science” is fitting transformers to scientific datasets. But that’s not the only way AI can benefit science.

At Sea AI Lab we aim to

Rebuild density functional theory software with modern tools created in the AI community.

Why not blackbox fitting?

I provide a few reasons why the team at Sea AI Lab is currently not working in this direction.

What do we do?

Direct Optimization

At the core of density functional theory is an optimization problem that finds the lowest energy of a system. It can be solved via gradient descent, just like machine learning models. We enforce the orthogonality constraint in DFT via reparameterization and implement energy minimization as gradient descent.
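
As a minimal sketch (not our production code), the idea can be illustrated with a toy quadratic energy: unconstrained parameters are mapped to orthonormal orbital coefficients via a QR reparameterization, and the energy is minimized with plain gradient descent. The matrix `H` below is a stand-in for the real Kohn–Sham energy terms.

    import jax
    import jax.numpy as jnp

    def orthonormalize(W):
      # QR reparameterization: the columns of Q are orthonormal for any full-rank W
      Q, _ = jnp.linalg.qr(W)
      return Q

    def energy(W, H):
      C = orthonormalize(W)          # orthonormal orbital coefficients
      return jnp.trace(C.T @ H @ C)  # toy quadratic energy

    H = jnp.diag(jnp.arange(8.0))                          # toy Hamiltonian
    W = jax.random.normal(jax.random.PRNGKey(0), (8, 3))   # unconstrained parameters

    grad_fn = jax.jit(jax.grad(energy))
    for _ in range(200):
      W = W - 0.1 * grad_fn(W, H)    # gradient descent on the unconstrained parameters

After convergence the columns of `orthonormalize(W)` approximately span the lowest-energy subspace, without ever projecting back onto the constraint set.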

Direct minimization also brings new opportunities. SCF optimization solves an eigenvalue problem, which only works for linear coefficients, while direct optimization enables us to use more complex ansätze (described below). Stochastic gradients are also very attractive: we can replace the integrals and summations in DFT with randomized versions to save overall computation.
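
To illustrate the stochastic-gradient idea (the `integrand` function and the grid below are hypothetical placeholders, not our actual energy terms), a quadrature over the full grid can be replaced by an unbiased minibatch estimate, and `jax.grad` of the minibatch energy is then a stochastic gradient of the full energy.

    import jax
    import jax.numpy as jnp

    def integrand(params, r):
      # hypothetical placeholder for an energy density evaluated at points r
      return jnp.sum(params) * jnp.exp(-jnp.sum(r ** 2, axis=-1))

    def energy_full(params, grid, weights):
      # deterministic quadrature over the whole grid: sum_i w_i f(r_i)
      return jnp.sum(weights * integrand(params, grid))

    def energy_minibatch(params, grid, weights, key, batch=256):
      # unbiased estimate from a random subset of grid points
      idx = jax.random.choice(key, grid.shape[0], (batch,))
      scale = grid.shape[0] / batch
      return scale * jnp.sum(weights[idx] * integrand(params, grid[idx]))

    # jax.grad of the minibatch energy gives a stochastic gradient of the full energy
    stochastic_grad = jax.grad(energy_minibatch)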

Automatic differentiation (AD)

Beyond the ground-state energy calculation, various properties of materials are related to derivatives of the energy. For example, forces are first-order derivatives with respect to the atom coordinates, and phonon spectra are computed from second-order derivatives. Without AD, the formulas for these properties need to be derived by hand and implemented separately. Not long ago in machine learning we had to program our own backpropagation; today that is entirely handled by AD tools. We expect AD tools to transform scientific computing in a similar way.
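
A minimal sketch of what this buys us, using a toy harmonic-bond energy in place of a real DFT total energy:

    import jax
    import jax.numpy as jnp

    def total_energy(coords):
      # toy stand-in for a DFT total energy: a harmonic bond with equilibrium length 1.0
      return (jnp.linalg.norm(coords[0] - coords[1]) - 1.0) ** 2

    coords = jnp.array([[0.0, 0.0, 0.0],
                        [0.0, 0.0, 1.5]])

    forces = -jax.grad(total_energy)(coords)     # first-order derivative: forces
    hessian = jax.hessian(total_energy)(coords)  # second-order derivative: force constants

Both quantities come for free once the energy is written down; no hand-derived formula is implemented.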

Automatic Functional Differentiation (AutoFD)

To make it more interesting, we extend AD to support automatic differentiation of functionals and operators. We found it useful for aligning our code with the mathematical derivation. Physicists often communicate in succinct mathematical language; for example, below is the math for calculating a material's bandstructure.

  1. We first solve for the ground-state density $\rho_0$.
  2. We linearize the potential and exchange-correlation energy at $\rho_0$ and call the result the effective potential.
    $$V_{\text{eff}}=\left.\frac{\delta (E_{\text{pot}}[\rho] + E_{\text{xc}}[\rho])}{\delta \rho}\right\vert_{\rho_0}$$
  3. With the effective potential, we then solve the eigenvalue problem
    $$\left(-\frac{1}{2}\nabla^2 + \hat{V}_\text{eff}\right) \psi = \epsilon \psi$$

However, the above is not directly implementable. To implement it, one needs to hand-derive $\frac{\delta E[\rho]}{\delta \rho}$, replace $\psi$ and $\rho$ with linear combinations of basis functions, and convert the entire calculation into coefficient space. With automatic functional differentiation, $\frac{\delta E[\rho]}{\delta \rho}$ is directly implementable as a linearization of the energy functional. We show a code snippet for calculating the bandstructure.

  1. We first define the potential and xc energy functional
    def potential_and_xc_energy(rho):
      return (
        e.hartree(rho) +
        e.external(rho, crystal) +
        jax_xc.energy.lda_x(rho)
      )
    
  2. The $V_\text{eff}$ is the first-order derivative of the functional; we can compute it via
    Veff = jax.grad(potential_and_xc_energy)(rho0)
    
    Since we subsequently use $V_\text{eff}$ only in inner products, we can instead directly construct $E_\text{eff}: \rho \mapsto \langle V_\text{eff}|\rho\rangle$
    _, Eeff = jax.linearize(potential_and_xc_energy, rho0)
    
  3. We construct the energy under the effective potential as $\langle\psi|-\frac{1}{2}\nabla^2+\hat{V}_\text{eff}|\psi\rangle$ and diagonalize its Hessian.
    def energy_under_veff(psi):
      return e.kinetic(psi) + Eeff(psi_to_rho(psi))
    
  4. Numerical computation needs to happen in the parameter space.
    def energy_in_param(param, k):
      psi = o.partial(wave_ansatz, args=(param, k), argnums=(0, 1))
      return energy_under_veff(psi)
    
    bands = {}
    for k in k_vectors:
      bands[k] = jnp.linalg.eigh(jax.hessian(energy_in_param)(param, k))[0]
    

More powerful ansatz

In DFT, electron wave functions are approximated as linear combinations of basis functions, which are analytically simple functions. When the target function is complex, we increase the number of basis functions, e.g. by increasing the cutoff energy in solid-state calculations. One key insight from deep learning is that increasing the depth is exponentially more efficient than increasing the width.

To introduce deep models as wave functions, we need to consider the following constraints.

We managed to introduce deep components into the wave functions while satisfying all the above constraints (Section 4 in DF4T). We are not the first to explore this idea; previous work dates back to 1993. However, powerful normalizing flows have only come into existence in the past few years. It is therefore worthwhile to re-explore these ideas with the tools developed in the AI era.

Acceleration and scaling

Thanks to the separation of frontend and backend, deep learning frameworks abstract away the details of the hardware. Code written in JAX runs on all kinds of hardware without any customization. Today it is even possible to scale out on a cluster transparently via the parallel jit API. In our own DFT code, we commit to the idea of separating the main logic from the optimization details: we keep the frontend as simple as the math and hide the optimization and implementation details behind the powerful compiler stack in JAX. One such example is the use of FFT in the solid-state code.
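
For instance (a generic JAX sketch, not our exact setup), a batch of k-points can be sharded across all available devices, and a jit-compiled function then runs on the sharded data without any change to its code:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # shard the leading (k-point) axis across all available devices
    mesh = Mesh(np.array(jax.devices()), axis_names=("k",))
    sharding = NamedSharding(mesh, P("k"))

    @jax.jit
    def energies(coeffs):
      # placeholder per-k-point energy
      return jnp.sum(jnp.abs(coeffs) ** 2, axis=-1)

    coeffs = jax.device_put(jnp.ones((8, 1024)), sharding)  # sharded input
    out = energies(coeffs)  # the computation follows the input sharding transparently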

The wave functions are represented as linear combinations of planewave basis functions.

$$ \psi(r)=\sum_G c_{G}\, e^{iG\cdot r} $$

Although we can evaluate this function explicitly at given coordinates to obtain its value, it turns out that using the fast Fourier transform (FFT) is much more efficient when we want to evaluate the function on a grid. Explicit evaluation takes $O(N^2)$ while the FFT takes $O(N\log N)$. However, this optimization would break the functional syntax that we’re promoting.
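
A small 1D sketch of the equivalence (with toy coefficients, not our actual grids): the planewave sum evaluated explicitly on a uniform grid matches an inverse FFT of the coefficients, up to the $1/N$ normalization convention.

    import jax.numpy as jnp

    n = 64
    c = jnp.zeros(n, dtype=jnp.complex64).at[3].set(1.0)  # toy planewave coefficients
    r = jnp.arange(n) * (2 * jnp.pi / n)                  # uniform grid on [0, 2*pi)
    G = jnp.fft.fftfreq(n, d=1.0 / n)                     # integer reciprocal vectors

    # O(N^2): explicit evaluation of psi(r) = sum_G c_G exp(i G r) on the grid
    psi_explicit = jnp.einsum("g,gr->r", c, jnp.exp(1j * G[:, None] * r[None, :]))

    # O(N log N): the inverse FFT gives the same values (up to the 1/N convention)
    psi_fft = n * jnp.fft.ifft(c)

    assert jnp.allclose(psi_explicit, psi_fft, atol=1e-4)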

For example, the exchange-correlation functional is defined as $$ E_{\text{xc}}[\rho]=\int\epsilon_{\text{xc}}[\rho](r)\,\rho(r)\,dr $$

The corresponding Python code is

Exc = o.integrate(jax_xc.gga_x_pbe(rho) * rho)

There are a few difficulties,