Architectural highlights of AlphaFold3

DeepMind and Isomorphic Labs recently published the methods behind AlphaFold3, the successor to the famous AlphaFold2. The involvement of Isomorphic Labs signals that Alphabet is getting serious about drug design. To this end, AlphaFold3 substantially improves the prediction of biomolecular complex structures, a major piece of the computational drug design pipeline.

The goal of this post is to provide a middle ground between popular science articles and the AlphaFold3 Supplementary Information (SI). The material is largely drawn from our mini reading group on AlphaFold3. I won’t go into detail on the results, although they are very impressive!

Broadly, AlphaFold3 is structured similarly to AlphaFold2. It is still split into a main trunk, which processes MSA, template, and sequence information to generate an initial structural hypothesis, and a structure module, which turns that hypothesis into the final structure. As in AlphaFold2, the pair representation in the main trunk is processed with triangle updates (including triangle attention) to encourage a geometrically consistent hypothesis. So what are the big differences that make AlphaFold3 another giant leap?

A conceptual shift

Whereas AlphaFold2 targeted the protein folding problem, AlphaFold3 is primarily designed for modelling protein-ligand complexes, that is, the structure of proteins bound to other things. Ligands in AlphaFold3 are defined very broadly: other protein chains, nucleic acids, or even arbitrary small molecules. Given that most drugs are small molecules that bind to specific proteins, reliably predicting how molecules bind to proteins is of significant interest.

AlphaFold3 models complexes – proteins, nucleic acids, and small molecules

Modelling complexes

To accommodate arbitrary molecular complexes, AlphaFold3 introduces a more general tokenization scheme as well as a few architectural changes. The new tokenization can describe other types of molecules, and it changes the granularity of modelling depending on the molecule: standard amino acids and nucleotides are each represented by a single token, while ligands and modified residues are tokenized atom by atom.
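As a concrete illustration of mixed-granularity tokenization, here is a minimal Python sketch assuming a simple rule: standard amino acids and nucleotides get one token each, and every other molecule is split into one token per atom. The `Residue` record, the `STANDARD_RESIDUES` set, and the ligand name `LIG` are hypothetical, chosen for illustration; AlphaFold3's real featurization carries far more information per token.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Illustrative set of "standard" residues that get one token each:
# the 20 amino acids plus RNA (A, C, G, U) and DNA (DA, DC, DG, DT) nucleotides.
STANDARD_RESIDUES = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    "A", "C", "G", "U", "DA", "DC", "DG", "DT",
}


@dataclass
class Residue:
    """Hypothetical record: a residue or ligand and the names of its atoms."""
    name: str
    atoms: List[str]


def tokenize(residues: List[Residue]) -> List[Tuple[int, str]]:
    """Return (residue_index, token_label) pairs.

    Standard residues become a single token; anything else (ligands,
    modified residues) contributes one token per atom.
    """
    tokens: List[Tuple[int, str]] = []
    for idx, res in enumerate(residues):
        if res.name in STANDARD_RESIDUES:
            tokens.append((idx, res.name))
        else:
            tokens.extend((idx, f"{res.name}:{atom}") for atom in res.atoms)
    return tokens


# Two amino acids followed by a hypothetical three-atom ligand "LIG".
complex_ = [
    Residue("GLY", ["N", "CA", "C", "O"]),
    Residue("ALA", ["N", "CA", "C", "O", "CB"]),
    Residue("LIG", ["C1", "C2", "O1"]),
]
tokens = tokenize(complex_)
print(tokens)
# [(0, 'GLY'), (1, 'ALA'), (2, 'LIG:C1'), (2, 'LIG:C2'), (2, 'LIG:O1')]
```

The protein residues stay coarse (one token each), while the ligand is resolved atom by atom, so the sequence length depends on what the complex contains.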

Some of the architectural changes are there to handle larger structures, since complexes are bigger than single chains, especially when they are modelled at the all-atom level. The biggest innovation seems to be the new structure module, which uses a diffusion Transformer.

Overview of AlphaFold3’s architecture. The main trunk generates a structural hypothesis through the Single and Pair representations which are passed to the Diffusion module to generate a complete structure.

The main trunk

The main trunk is broadly similar to that of AlphaFold2. It still takes as input a multiple sequence alignment (MSA), matching structural templates, and the tokenized sequences. The main difference is that it is more efficient. In particular, MSA processing has been simplified and no longer includes column attention, so each sequence in the MSA can be processed independently, which is both simpler and cheaper.

Multiple multiple sequence alignments

Because AlphaFold3 is built to solve different kinds of problems, it can process different kinds of MSAs. For protein-protein interactions, AF3 uses an MSA from UniProt and then augments it with protein hits from other sequence searches. The UniProt MSA is especially important because it contains a canonical sequence of each chain per species, so sequences from different chains can be paired by species. Co-varying mutations across species and across chain pairs can then inform the prediction of the binding interface, similar to the MSA usage in AlphaFold2 and earlier approaches. The remaining sequences come from searches on individual chains and so cannot provide information about inter-chain contacts, but they can still guide intra-chain structure prediction. Because AlphaFold3 also explicitly models nucleic acids, it also supports RNA and DNA MSAs, analogously to protein MSAs.
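Pairing by species is what lets an MSA carry inter-chain signal. The sketch below assumes each MSA hit is tagged with its species and that hits are ranked best-first; it concatenates the top hit per shared species from two chains into a single paired row. The data format and the `pair_by_species` helper are hypothetical simplifications, not AlphaFold3's actual pipeline.

```python
from typing import Dict, List, Tuple

# Hypothetical MSA format: (species, aligned_sequence), ranked best hit first.
Hit = Tuple[str, str]


def pair_by_species(msa_a: List[Hit], msa_b: List[Hit]) -> List[str]:
    """Build a paired MSA by concatenating the top hit per species from each chain."""
    best_b: Dict[str, str] = {}
    for species, seq in msa_b:
        best_b.setdefault(species, seq)  # keep only the first (best-ranked) hit

    paired: List[str] = []
    used = set()
    for species, seq in msa_a:
        if species in best_b and species not in used:
            paired.append(seq + best_b[species])  # one row spans both chains
            used.add(species)
    return paired


msa_a = [("human", "MKV"), ("mouse", "MKI"), ("yeast", "MRV")]
msa_b = [("mouse", "GGS"), ("human", "GAS"), ("human", "GAT")]
print(pair_by_species(msa_a, msa_b))  # ['MKVGAS', 'MKIGGS']
```

Species found in only one chain's MSA (here, yeast) cannot be paired, which is why the unpaired per-chain hits only help with intra-chain structure.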

Pairformer

In AlphaFold3, the PairFormer replaces AlphaFold2's Evoformer. MSA and template information are first folded into the pair representation by smaller, dedicated modules, and the PairFormer then refines only the pair and single representations; the MSA representation is no longer carried through the main stack. As in AlphaFold2, the PairFormer uses triangle updates, in which each edge (i, j) is updated using the other two edges of every triangle (i, j, k) it belongs to. This encourages, but does not guarantee, a structural hypothesis that is geometrically consistent, for example one that respects the triangle inequality. The output of the PairFormer is an N×N pair representation, with one feature vector per pair of tokens, which abstractly represents plausible relationships in the structure, as well as a single representation with one feature vector per token. These are passed to the new diffusion-based structure module to be resolved into a predicted structure.
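The triangle idea can be illustrated with one of these updates, the triangle multiplicative update on "outgoing" edges (AlphaFold2 and AlphaFold3 also use triangle attention, which is longer to write out). Edge (i, j) is updated from the edges (i, k) and (j, k) for every third token k. This sketch assumes scalar pair features and omits the LayerNorm, gating, and multiple channels of the real layers; the weights `w_a` and `w_b` are arbitrary placeholders.

```python
from typing import List

Matrix = List[List[float]]


def triangle_update_outgoing(z: Matrix, w_a: float = 0.5, w_b: float = 0.5) -> Matrix:
    """Simplified triangle multiplicative update (outgoing edges).

    Edge (i, j) receives a sum over third tokens k of a(i, k) * b(j, k),
    so information flows around every triangle i-j-k.
    Scalar features, no LayerNorm, no gating: an illustration only.
    """
    n = len(z)
    a = [[w_a * z[i][k] for k in range(n)] for i in range(n)]  # projection "a"
    b = [[w_b * z[j][k] for k in range(n)] for j in range(n)]  # projection "b"
    out = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Residual connection plus the sum over third tokens k.
            out[i][j] = z[i][j] + sum(a[i][k] * b[j][k] for k in range(n))
    return out


# A path 0-1-2: edges (0,1) and (1,2) are present, (0,2) is not.
z = [[0.0, 1.0, 0.0],
     [1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0]]
z_new = triangle_update_outgoing(z)
print(z_new[0][2])  # 0.25: edge (0,2) now "hears" about the triangle through token 1
```

After one update, the missing edge (0, 2) picks up a signal purely because both of its endpoints are connected to token 1, which is the kind of consistency the triangle updates are meant to encourage.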

The structure module

Probably the most exciting architectural development in AlphaFold3 is the completely redesigned structure module. The old “Invariant Point Attention” is gone, in favour of a diffusion model wrapped around a Transformer. In some ways, this is a simplification: positions are now represented only as raw coordinates in ℝ³, and the model, conspicuously, no longer explicitly guarantees SE(3) equivariance.

On Transformers

Since their introduction in 2017, Transformers have undergone a slew of architectural changes, to the point where the definition of a Transformer has become murky. Broadly, a Transformer can be thought of as a stack of blocks, each composed of an attention layer and a processing (feedforward) layer. In the original Transformer, these were standard dot-product attention and simple feedforward layers. The AlphaFold3 structure module uses a slightly different flavour of both. The key motivation behind most of these changes is simply that we want to condition the Transformer's output on the output of the preceding module, the PairFormer.

Attention with pair bias

In many cases, the PairFormer has probably learned a fairly specific hypothesis about the distances between tokens in the final structure. So how is this hypothesis communicated to the structure module? In AlphaFold3, the pair representation for tokens i and j is linearly projected to a scalar (one per attention head), which is added as a bias term to the attention logits before the softmax. Intuitively, this allows each attention head to attend to tokens that the PairFormer identified as carrying relevant information. At the extreme, if the PairFormer had a complete structural hypothesis, the structure module's Transformer could defer entirely to the pair representation, and each token's position could slowly converge to a weighted average of the positions of the tokens it should be close to.
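A single-head sketch of attention with pair bias follows. It assumes the projection from pair features to a scalar is a plain dot product with a weight vector `w_bias` (a hypothetical name) and leaves out multi-head splitting, gating, and normalization. The usage example shows that, when queries and keys carry no information, the pair bias alone decides which token is attended to.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def softmax(xs: Vector) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_pair_bias(q: Matrix, k: Matrix, v: Matrix,
                             pair: List[List[Vector]], w_bias: Vector) -> Matrix:
    """Single-head scaled dot-product attention with an additive pair bias.

    pair[i][j] is the pair feature vector for tokens i and j. It is projected
    to a scalar with w_bias and added to the attention logit before softmax.
    """
    n, d = len(q), len(q[0])
    out: Matrix = []
    for i in range(n):
        logits = []
        for j in range(n):
            dot = sum(q[i][c] * k[j][c] for c in range(d)) / math.sqrt(d)
            bias = sum(p * w for p, w in zip(pair[i][j], w_bias))
            logits.append(dot + bias)
        weights = softmax(logits)
        out.append([sum(weights[j] * v[j][c] for j in range(n))
                    for c in range(len(v[0]))])
    return out


# With uninformative queries/keys, the pair bias alone decides where to attend.
n = 3
q = k = [[0.0], [0.0], [0.0]]
v = [[0.0], [1.0], [0.0]]
pair = [[[10.0] if j == 1 else [0.0] for j in range(n)] for _ in range(n)]
no_pair = [[[0.0] for _ in range(n)] for _ in range(n)]
print(attention_with_pair_bias(q, k, v, pair, w_bias=[1.0])[0])     # close to [1.0]
print(attention_with_pair_bias(q, k, v, no_pair, w_bias=[1.0])[0])  # close to [0.333]
```

Because the bias enters before the softmax, a strongly confident pair representation can dominate the attention pattern without the queries and keys having to rediscover the same relationships.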

Transition blocks

The Transition block is almost identical to a standard feedforward layer except for two changes: LayerNorm is replaced with “Adaptive Normalization” (AdaNorm), and ReLU is replaced with SwiGLU. AdaNorm is a slightly strange way of conditioning the Transition block on the PairFormer’s single (per-token) representation: the normalization's scale and shift are computed from that representation instead of being fixed learned parameters. The only justification given for ReLU → SwiGLU is that SwiGLU is used in modern Transformers. Perhaps tellingly, the paper which introduced SwiGLU concludes “We offer no explanation as to why these architectures seem to work; we attribute their success, as all else, to divine benevolence”. One plausible explanation for SwiGLU's effectiveness is that it is smooth and, thanks to its multiplicative gate, highly non-linear.
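The two changes can be sketched together. This assumes one common form of adaptive normalization, in which a conditioning vector `cond` (standing in for the per-token single representation) produces a sigmoid-gated scale and an additive shift for a parameter-free LayerNorm. SwiGLU is written as Swish(W_g h) * (W_v h), followed by an output projection. All weight names and values are hypothetical toy choices, and the residual connection around the block is omitted.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # each row produces one output channel


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def ada_norm(x: Vector, cond: Vector, w_scale: Matrix, w_shift: Matrix) -> Vector:
    """LayerNorm whose scale and shift are computed from a conditioning vector."""
    normed = layer_norm(x)  # no learned affine parameters of its own
    scale = [sigmoid(s) for s in matvec(w_scale, cond)]
    shift = matvec(w_shift, cond)
    return [s * n + b for s, n, b in zip(scale, normed, shift)]


def swiglu_transition(x: Vector, cond: Vector, w_scale: Matrix, w_shift: Matrix,
                      w_gate: Matrix, w_value: Matrix, w_out: Matrix) -> Vector:
    """Conditioned transition: AdaNorm, then Swish(W_g h) * (W_v h), then W_out."""
    h = ada_norm(x, cond, w_scale, w_shift)
    gate = matvec(w_gate, h)
    value = matvec(w_value, h)
    hidden = [g * sigmoid(g) * v for g, v in zip(gate, value)]  # Swish(g) = g * sigmoid(g)
    return matvec(w_out, hidden)


# Toy weights (hypothetical): zero conditioning weights give scale sigmoid(0) = 0.5.
zeros2 = [[0.0, 0.0], [0.0, 0.0]]
eye2 = [[1.0, 0.0], [0.0, 1.0]]
out = swiglu_transition(x=[1.0, -1.0], cond=[0.3, 0.7], w_scale=zeros2, w_shift=zeros2,
                        w_gate=eye2, w_value=eye2, w_out=[[1.0, 1.0]])
print(out)  # approximately [0.25]
```

Changing `w_shift` or `w_scale` changes the output for the same `x`, which is exactly how the conditioning signal reaches the feedforward path.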

Wrapped in diffusion

Armed with these variants of attention and transitions, the AlphaFold3 structure module is relatively straightforward. Take the rescaled noisy coordinates, the single representation, and the pair representation, and pass them through a series of alternating AttentionPairBias and Transition blocks. At the end, the network outputs its prediction of the denoised coordinates. The diffusion sampler then takes a step from the current noisy coordinates towards this prediction, with a step size set by the current noise level. Fresh noise is injected at each step while the overall noise level is lowered, and the process is repeated a fixed number of times, after which the model should have converged on a reasonable structure for the complex.
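The sampling loop can be sketched as a generic stochastic sampler in the style of EDM (Karras et al., 2022): at each step, inject a little noise, ask the denoiser for its estimate of the clean coordinates, and take an Euler step towards that estimate. The noise schedule, the noise-injection factor `gamma`, and the oracle denoiser standing in for the trained network are all illustrative assumptions rather than AlphaFold3's actual settings.

```python
import math
import random
from typing import Callable, List

Coords = List[List[float]]  # one [x, y, z] triple per atom
Denoiser = Callable[[Coords, float], Coords]


def noise_schedule(n_steps: int, sigma_max: float = 80.0,
                   sigma_min: float = 0.01) -> List[float]:
    """Geometrically decreasing noise levels, with a final 0 appended."""
    ratio = (sigma_min / sigma_max) ** (1.0 / (n_steps - 1))
    return [sigma_max * ratio ** i for i in range(n_steps)] + [0.0]


def sample(denoise: Denoiser, n_atoms: int, n_steps: int = 20,
           gamma: float = 0.1, seed: int = 0) -> Coords:
    rng = random.Random(seed)
    sigmas = noise_schedule(n_steps)
    # Start from pure Gaussian noise at the largest noise level.
    x = [[sigmas[0] * rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(n_atoms)]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # Inject fresh noise, raising the noise level from sigma to sigma_hat.
        sigma_hat = sigma * (1.0 + gamma)
        extra = math.sqrt(sigma_hat ** 2 - sigma ** 2)
        x = [[xi + extra * rng.gauss(0.0, 1.0) for xi in atom] for atom in x]
        # The network predicts the clean (denoised) coordinates.
        x_hat = denoise(x, sigma_hat)
        # Euler step from sigma_hat to sigma_next along (x - x_hat) / sigma_hat.
        x = [[xi + (sigma_next - sigma_hat) * (xi - hi) / sigma_hat
              for xi, hi in zip(atom, atom_hat)]
             for atom, atom_hat in zip(x, x_hat)]
    return x


# Stand-in "network": an oracle that always returns a fixed target structure.
target = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.5, 0.0]]
result = sample(lambda x, sigma: target, n_atoms=len(target))
```

Because the oracle always returns the target, the final step to a noise level of zero lands on the target; a trained network instead returns an imperfect estimate that improves as the noise level drops.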

What about SE(3)?

In AlphaFold2, the structure module takes great care to be SE(3) equivariant: rotating or translating the input rotates or translates the prediction in the same way, leaving relative distances unchanged. AlphaFold3 scraps this, instead simply linearly embedding the coordinates and adding them to the embeddings, to be processed like any other feature. For a model to have learned any physics, and there are some indications that AF3 has, it must be able to measure distances, which means it must have learned something about SE(3) symmetry. To encourage this, all structures are recentred and then randomly rotated and translated at every diffusion step. In this way the model must learn to use positional information that has been subject to arbitrary SE(3) transformations. Just because a model hasn’t been designed to be SE(3) equivariant doesn’t mean it hasn’t learned to be approximately SE(3) equivariant.
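A minimal sketch of this kind of augmentation follows, assuming uniformly random rotations generated from a random unit quaternion (Shoemake's method) and Gaussian translations. The function names and the `translation_scale` parameter are hypothetical. The usage example checks the key property: the coordinates change, but pairwise distances do not.

```python
import math
import random
from typing import List

Coords = List[List[float]]
Matrix3 = List[List[float]]


def random_rotation(rng: random.Random) -> Matrix3:
    """Uniformly random rotation via a random unit quaternion (Shoemake's method)."""
    u1, u2, u3 = rng.random(), rng.random(), rng.random()
    a, b = math.sqrt(1.0 - u1), math.sqrt(u1)
    w, x = a * math.sin(2 * math.pi * u2), a * math.cos(2 * math.pi * u2)
    y, z = b * math.sin(2 * math.pi * u3), b * math.cos(2 * math.pi * u3)
    # Standard rotation matrix for the unit quaternion w + xi + yj + zk.
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def recentre_and_randomly_transform(coords: Coords, rng: random.Random,
                                    translation_scale: float = 1.0) -> Coords:
    """Centre the structure, then apply a random rotation and a random translation."""
    n = len(coords)
    centroid = [sum(p[c] for p in coords) / n for c in range(3)]
    centred = [[p[c] - centroid[c] for c in range(3)] for p in coords]
    r = random_rotation(rng)
    t = [translation_scale * rng.gauss(0.0, 1.0) for _ in range(3)]
    return [[sum(r[i][j] * p[j] for j in range(3)) + t[i] for i in range(3)]
            for p in centred]


rng = random.Random(0)
structure = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [-1.0, 0.5, 2.0]]
augmented = recentre_and_randomly_transform(structure, rng)
# Coordinates change, but every pairwise distance is preserved.
print(math.dist(structure[0], structure[1]), math.dist(augmented[0], augmented[1]))
```

Since only distances and relative arrangements survive this augmentation, a model trained on such inputs is pushed towards features that do not depend on the arbitrary global frame.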

Looking forward

I think the biggest takeaway from AlphaFold3 is that attention is still all you need. Transformers are essentially stable, memory-efficient, fully connected GNNs. If Transformers can learn some of the benefits of GNNs, such as distance encodings, then they can be an effective architecture even for structural data. That said, AlphaFold3 still generates an explicit edge representation via the PairFormer, which may be important because, for instance, the pair representation is initialized with relative position encodings. Regardless, AlphaFold3 is an exciting step towards modelling complexes rather than single chains. Tools like this will be essential for designing drugs more quickly and cheaply.
