Protein sequence modelling with Bayesian flow networks

Bayesian flow networks for discrete data

Overview

Let p(x) be a distribution over D discrete variables that we wish to approximate. For some x ~ p(x) we have that x = [x1, …, xD] ∈ V^D, where V = {v1, …, vK} is a vocabulary with |V| = K. The BFN approximation to p(x) is constructed through the iterative revealing of information about x via a sequence of noisy observations y^(1:i) = {y^(1), …, y^(i)}.

A noisy observation at step i, \({{{{\bf{y}}}}}^{(i)}=[{y}_{1}^{(i)},\ldots,{y}_{D}^{(i)}]\in {{\mathbb{R}}}^{D\times K}\) is drawn from the sender distribution as a normally distributed vector for each variable,

$${p}_{s}\left({y}_{j}^{(i)}| {{{\bf{x}}}};{\alpha }_{i}\right)={{{\mathcal{N}}}}\left({\alpha }_{i}(K{{{\bf{e}}}}{({{{\bf{x}}}})}_{j}-1),{\alpha }_{i}K{{{\bf{I}}}}\right),$$

(2)

$${p}_{s}\left({{{{\bf{y}}}}}^{(i)}| {{{\bf{x}}}};{\alpha }_{i}\right)=\mathop{\prod }_{j=1}^{D}{p}_{s}\left({y}_{j}^{(i)}| {{{\bf{x}}}};{\alpha }_{i}\right),$$

(3)

where \({\alpha }_{i}\in {{\mathbb{R}}}^{+}\) is the accuracy at step i, such that a higher αi reveals more information about the true value of x, and \({{{\bf{e}}}}({{{\bf{x}}}})\in {{\mathbb{R}}}^{D\times K}\) is a one-hot encoding of x,

$${{{\bf{e}}}}{({{{\bf{x}}}})}_{i,j}=\left\{\begin{array}{ll}1,&\,{{\mbox{if}}}\,\,{x}_{i}={v}_{j}\\ 0,&\,{\mbox{otherwise}}\,\end{array}\right\}.$$

(4)

Starting from some prior distribution θ(0) over the value of x and taking into account all observations so far, the input distribution θ(i) is derived, providing the optimal estimate of the likelihood of each variable xj when each variable is treated independently,

$${{{{\boldsymbol{\theta }}}}}_{j,k}^{(i)} =p({x}_{j}={v}_{k} \, | \, {{{{\bf{y}}}}}_{j}^{(1:i)})\\ =\frac{{{{{\boldsymbol{\theta }}}}}_{j,k}^{(i-1)}\cdot \exp {y}_{j,k}^{(i)}}{{\sum }_{q=1}^{K}{{{{\boldsymbol{\theta }}}}}_{j,q}^{(i-1)}\cdot \exp {y}_{j,q}^{(i)}}\\ =\frac{{{{{\boldsymbol{\theta }}}}}_{j,k}^{(0)}\cdot \exp {\sum }_{p=1}^{i}{y}_{j,k}^{(p)}}{\mathop{\sum }_{q=1}^{K}{{{{\boldsymbol{\theta }}}}}_{j,q}^{(0)}\cdot \exp \mathop{\sum }_{p=1}^{i}{y}_{j,q}^{(p)}}$$

(5)
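
To make these updates concrete, the following is a minimal NumPy sketch (not the authors' implementation) of drawing a noisy observation from the sender distribution of Equation (2) and applying the Bayesian update of Equation (5); the array shapes, helper names and the example at the end are illustrative assumptions.

```python
import numpy as np

def sender_sample(e_x, alpha, rng):
    """Draw y ~ N(alpha * (K * e(x) - 1), alpha * K * I) per Equation (2).

    e_x: (D, K) one-hot encoding of the sequence x; alpha: accuracy for this step.
    """
    D, K = e_x.shape
    mean = alpha * (K * e_x - 1.0)
    return mean + rng.normal(size=(D, K)) * np.sqrt(alpha * K)

def bayesian_update(theta, y):
    """Update the input distribution with a noisy observation, per Equation (5).

    theta: (D, K) current per-variable categorical distributions; y: (D, K) observation.
    """
    logits = np.log(theta) + y
    logits -= logits.max(axis=-1, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=-1, keepdims=True)

# Example: uniform prior over K = 20 amino acids for a length-8 sequence.
rng = np.random.default_rng(0)
D, K = 8, 20
x = rng.integers(K, size=D)
e_x = np.eye(K)[x]
theta = np.full((D, K), 1.0 / K)
theta = bayesian_update(theta, sender_sample(e_x, alpha=0.05, rng=rng))
```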

As θ(i) is constructed under an independence assumption about x1, …, xD, it is clearly suboptimal whenever these variables are interdependent. The neural network ϕ produces the output distribution, which aims to provide a better estimate by taking the interdependencies between variables into account. Specifically,

$$\phi {\left({{{{\boldsymbol{\theta }}}}}^{(i)}\right)}_{j,k}\approx p({x}_{j}={v}_{k} \, | \, {{{{\bf{y}}}}}^{(1:i)})$$

(6)

This approximation implies an equivalent approximation of the distribution of the next noisy observation y(i+1). This approximation, referred to as the receiver distribution, is a mixture of Gaussians for each variable,

$${p}_{r}\left({y}_{j}^{(i+1)}\,|\, \phi \left({{{{\boldsymbol{\theta }}}}}^{(i)}\right);{\alpha }_{i+1}\right)=\mathop{\sum }_{k=1}^{K}\phi {\left({{{{\boldsymbol{\theta }}}}}^{(i)}\right)}_{j,k}\cdot {{{\mathcal{N}}}}\left({\alpha }_{i+1}(K{{{\bf{e}}}}({v}_{k})-1),{\alpha }_{i+1}K{{{\bf{I}}}}\right),$$

(7)

$${p}_{r}\left({{{{\bf{y}}}}}^{(i+1)}\,|\, \phi \left({{{{\boldsymbol{\theta }}}}}^{(i)}\right);{\alpha }_{i+1}\right)=\mathop{\prod }_{j=1}^{D}{p}_{r}\left({y}_{j}^{(i+1)}\,|\, \phi \left({{{{\boldsymbol{\theta }}}}}^{(i)}\right);{\alpha }_{i+1}\right),$$

(8)

The error in ϕ is measured through the KL-divergence between ps and pr at each step. For N steps,

$${L}^{N}={\sum }_{i=1}^{N}{D}_{{{{\rm{KL}}}}}\left({p}_{{{{\rm{s}}}}}(\cdot \,|\, {{{\bf{x}}}};{\alpha }_{i})\parallel {p}_{{{{\rm{r}}}}}(\cdot | \phi \left({{{{\boldsymbol{\theta }}}}}^{(i-1)}\right);{\alpha }_{i})\right).$$

(9)

In practice, it is possible to derive a continuous-time loss L∞ in the limit N → ∞. Under the assumption of a uniform prior θ(0), the loss becomes remarkably simple and easy to approximate through Monte Carlo integration,

$${L}^{\infty }={{\mathbb{E}}}_{t \sim U(0,1)}\left[{{\mathbb{E}}}_{{{{{\bf{y}}}}}^{(t)} \sim {p}_{s}(\cdot \, | \, {{{\bf{x}}}};\beta (t))}\left[\frac{\alpha (t)}{2}{\left| \left| {{{\bf{e}}}}({{{\bf{x}}}})-\phi \left({{{{\boldsymbol{\theta }}}}}^{(t)}\right)\right| \right| }^{2}\right]\right].$$

(10)

This loss function reflects the transition from an N-step process to a continuous-time process, where t moves from 0 to 1. The sequence of accuracies α1αN is replaced with the monotonically increasing accuracy schedule,

$$\beta (t)={\beta }_{1}{t}^{2},$$

(11)

where \({\beta }_{1}\in {{\mathbb{R}}}^{+}\) is the final accuracy at t = 1, and its derivative scales the loss through time,

$$\alpha (t)=\frac{d}{dt}\beta (t)=2{\beta }_{1}t.$$

(12)

The continuous-time loss corresponds to a negative variational lower bound28, and therefore by optimising the parameters of ϕ to minimise it, we arrive at an approximation of the true distribution p(x) from which x was drawn. To train our models, we follow the general procedure outlined in ref. 28. Specifically, the computation of the loss is described in Box 1.
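
As a rough illustration only (the authors' actual loss computation is given in Box 1), the sketch below shows how the continuous-time loss of Equation (10) can be estimated by Monte Carlo under a uniform prior and the accuracy schedule of Equations (11) and (12); the output_network callable and all other names are placeholders we introduce here.

```python
import numpy as np

def accuracy_schedule(t, beta_1):
    return beta_1 * t ** 2                       # beta(t) = beta_1 * t^2, Equation (11)

def accuracy_rate(t, beta_1):
    return 2.0 * beta_1 * t                      # alpha(t) = d/dt beta(t), Equation (12)

def continuous_time_loss(e_x, output_network, beta_1, rng):
    """Single-sample Monte Carlo estimate of the loss in Equation (10).

    e_x: (D, K) one-hot encoding of x.
    output_network: maps a (D, K) input distribution to a (D, K) output distribution.
    """
    D, K = e_x.shape
    t = rng.uniform()                            # t ~ U(0, 1)
    beta_t = accuracy_schedule(t, beta_1)
    # y^(t) ~ p_s(. | x; beta(t)): the accumulated noisy observation at time t.
    y = beta_t * (K * e_x - 1.0) + rng.normal(size=(D, K)) * np.sqrt(beta_t * K)
    # Input distribution under a uniform prior is the softmax of y.
    logits = y - y.max(axis=-1, keepdims=True)
    theta = np.exp(logits)
    theta /= theta.sum(axis=-1, keepdims=True)
    phi = output_network(theta)                  # output distribution phi(theta^(t))
    return 0.5 * accuracy_rate(t, beta_1) * np.sum((e_x - phi) ** 2)
```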

Entropy encoding

In ref. 28, the current time t is presented to the output network alongside the input distribution θ(t). During initial sampling experiments, we observed that the entropy of the input distribution at a given time t was noticeably higher than that observed during training. We believe this occurs because, during sampling, the input distribution θ(t) contains additional entropy arising from uncertainty in the output distribution \(\phi ({{{{\boldsymbol{\theta }}}}}^{(t)})\). When time t is presented as an additional input to the network, this mismatch is, essentially, out of distribution for the output network, hampering performance. To resolve this, we replace the conventional Fourier encoding of time t used in ref. 28 with a Fourier encoding of the entropy of each variable, appended to its corresponding input distribution before being passed into the network. This effectively makes the model invariant to t, and such an approach mirrors similar techniques used in diffusion models22,67. In general, the use of entropy encoding is a design choice, and alternative representations of progress through the sampling process could suffice.
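
As an illustrative sketch of this design choice, the per-variable entropy can be passed through a Fourier feature encoding and appended to the input distribution; the specific frequencies, normalisation and function names below are our assumptions rather than the authors' implementation.

```python
import numpy as np

def entropy_fourier_features(theta, n_freqs=8, eps=1e-9):
    """Encode the per-variable entropy of the input distribution with Fourier features.

    theta: (D, K) per-variable categorical distributions.
    Returns: (D, K + 2 * n_freqs) network input, i.e. theta with the encoded
    entropy of each variable appended.
    """
    K = theta.shape[-1]
    # Normalised entropy in [0, 1] for each of the D variables.
    entropy = -np.sum(theta * np.log(theta + eps), axis=-1) / np.log(K)
    freqs = 2.0 ** np.arange(n_freqs) * np.pi                  # geometric frequencies (assumed)
    angles = entropy[:, None] * freqs[None, :]
    features = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return np.concatenate([theta, features], axis=-1)
```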

Sampling

We explored a variety of alternative sampling methods that reduce the overall temperature of sample generation from ProtBFN. In the conventional discrete-data sampling method described in ref. 28, and restricting ourselves to logit space, the sample generation process moves from y(t) at time t to y(s) at time s according to the equation,

$${{{{\bf{y}}}}}^{(s)}={{{{\bf{y}}}}}^{(t)}+{\alpha }_{s}\left(K\phi \left({{{{\boldsymbol{\theta }}}}}^{(t)}\right)-1\right)+\sqrt{{\alpha }_{s}K}{{{{\bf{z}}}}}^{(t)},$$

(13)

where \({{{{\boldsymbol{\theta }}}}}^{(t)}={\mathtt{softmax}}({{{{\bf{y}}}}}^{(t)})\), \({{{{\bf{z}}}}}^{(t)} \sim {{{\mathcal{N}}}}({{{\bf{0}}}},{{{\bf{I}}}})\) is isotropic noise and αs = β(s) − β(t) is the change in accuracy. By the additivity of normally distributed variables and with y(0) = 0, the distribution of y(s) given the preceding steps at times i = 0, …, t is,

$${{{{\bf{y}}}}}^{(s)} \sim {{{\mathcal{N}}}}\left({\sum}_{i}{\alpha }_{i}\left(K\phi \left({{{{\boldsymbol{\theta }}}}}^{(i)}\right)-1\right),K\beta (s)\right),$$

(14)

or equivalently,

$${{{{\bf{y}}}}}^{(s)}=\left({\sum}_{i}{\alpha }_{i}\left(K\phi \left({{{{\boldsymbol{\theta }}}}}^{(i)}\right)-1\right)\right)+\sqrt{K\beta (s)}{{{\bf{z}}}},$$

(15)

where z ~ N(0, I) is a fixed isotropic noise. Sampling according to Equation (15) would yield an ODE similar to that described in ref. 29. However, we observed that simply replacing the summation of previous predictions with the most recent prediction, i.e.,

$${{{{\bf{y}}}}}^{(s)}=\beta (s)\left(K\phi \left({{{{\boldsymbol{\theta }}}}}^{(t)}\right)-1\right)+\sqrt{K\beta (s)}{{{\bf{z}}}},$$

(16)

substantially reduced the perplexity of generated samples. Effectively, this method is equivalent to taking our most recent prediction \(\phi ({{{{\boldsymbol{\theta }}}}}^{(t)})\) and supposing that we had predicted it at every step 0, …, s. It is further motivated under the assumption that the most recent prediction from the model is the ‘best guess’ for the centre of the input distribution; by reconstructing the input distribution around that point, the input distributions encountered during sampling more closely match the flow distribution seen during training. We note the similarity of our method to the sampling method proposed in ref. 63, where at each step the input distribution is reconstructed from the most recent prediction. The main distinction between the two methods is that the sampler proposed in ref. 63 resamples the underlying isotropic noise at each step (SDE-like), whereas our method uses a fixed isotropic noise throughout the process (ODE-like). Our overall sampling method, which we refer to as Reconstructed ODE (R-ODE), is described in Box 2. As the choice of sampling method can dramatically change performance, we include an empirical comparison of various sampling methods in Supplementary Note 4.
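
The following is a minimal sketch of the update in Equation (16), assuming the quadratic accuracy schedule of Equation (11), a placeholder output_network and a final argmax decoding; the authors' full R-ODE procedure is given in Box 2.

```python
import numpy as np

def r_ode_sample(output_network, D, K, beta_1, n_steps, rng):
    """Reconstructed ODE (R-ODE) sampling sketch based on Equation (16)."""
    z = rng.normal(size=(D, K))                  # fixed isotropic noise for the whole trajectory
    theta = np.full((D, K), 1.0 / K)             # uniform prior input distribution
    for i in range(1, n_steps + 1):
        s = i / n_steps
        beta_s = beta_1 * s ** 2                 # beta(s) = beta_1 * s^2
        phi = output_network(theta)              # most recent prediction phi(theta^(t))
        y = beta_s * (K * phi - 1.0) + np.sqrt(K * beta_s) * z   # Equation (16)
        # Reconstruct the input distribution as the softmax of y.
        logits = y - y.max(axis=-1, keepdims=True)
        theta = np.exp(logits)
        theta /= theta.sum(axis=-1, keepdims=True)
    return output_network(theta).argmax(axis=-1) # final discrete sequence (argmax is assumed)
```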

Inpainting

Consider the task of inpainting some sequence x according to a binary mask m ∈ {0, 1}^D, where mi = 0 indicates that the ith element of x should be used as conditioning information, and mi = 1 indicates that the ith element of x should be inpainted. Inspired by SMCDiff68, we treat conditional generation of the inpainted regions as a sequential Monte Carlo (SMC) problem, which we may solve with particle filtering69,70. Given an output network ϕ and input distribution θ(t) at time t, the KL divergence between the sender and receiver distributions, with accuracy α, for the conditioning subset xm is,

$${D}_{KL}\left({p}_{s}^{{{{\bf{m}}}}}(\cdot \,|\, {{{\bf{x}}}};\alpha )\,| |\, {p}_{r}^{{{{\bf{m}}}}}\left(\cdot | \phi \left({{{{\boldsymbol{\theta }}}}}^{(t)}\right);\alpha \right)\right)={\sum }_{i=1}^{D}{m}_{i}\frac{{\alpha }_{t}K}{2}{\left| \left| {{{\bf{e}}}}{({{{\bf{x}}}})}_{i}-\phi {\left({{{{\boldsymbol{\theta }}}}}^{(t)}\right)}_{i}\right| \right| }^{2},$$

(17)

which we will denote D(x, m, θ(t), α). Given q particles at time t with input distributions \({{{{\boldsymbol{\theta }}}}}_{1}^{(t)}\ldots {{{{\boldsymbol{\theta }}}}}_{q}^{(t)}\), we resample each particle with probability proportional to \({e}^{D({{{\bf{x}}}},{{{\bf{m}}}},{{{{\boldsymbol{\theta }}}}}_{i}^{(t)},\alpha )}\). Combining SMC with our sampling method, we arrive at the algorithm detailed in Box 3.
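
To illustrate the resampling step only, the sketch below computes the per-particle divergence of Equation (17) and resamples with weights proportional to exp(D), as stated above; the multinomial resampling scheme and all names are our assumptions, and the complete algorithm is given in Box 3.

```python
import numpy as np

def inpainting_resample(thetas, phis, e_x, mask, alpha, rng):
    """Resample particles for inpainting, weighting each by Equation (17).

    thetas: (q, D, K) input distributions of the q particles.
    phis:   (q, D, K) corresponding output distributions phi(theta).
    e_x:    (D, K) one-hot encoding of the conditioning sequence x.
    mask:   (D,) binary mask m.
    alpha:  accuracy for this step.
    """
    # D(x, m, theta_i, alpha) for each particle, per Equation (17).
    sq_err = np.sum((e_x[None] - phis) ** 2, axis=-1)                    # (q, D)
    divergences = 0.5 * alpha * e_x.shape[-1] * np.sum(mask[None] * sq_err, axis=-1)
    # Resampling probabilities proportional to exp(D), computed stably in log-space.
    log_w = divergences - divergences.max()
    weights = np.exp(log_w)
    weights /= weights.sum()
    idx = rng.choice(len(thetas), size=len(thetas), p=weights)           # multinomial resampling
    return thetas[idx], phis[idx]
```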

Pre-training ProtBFN

Model

ProtBFN is based on the 650-million parameter architecture used in ref. 6, which is a BERT-style encoder-only transformer71 with rotary positional embeddings72. The architecture consists of 33 layers, each consisting of 20-head multi-head self-attention73 in 1280-dimensional space followed by a single-layer MLP with GeLU activation74 and a hidden dimension of 5120. The only substantial difference between ProtBFN’s architecture and that used in ref. 6 is that the initial token embedding is replaced with a linear projection, as our network’s input θ(t) is a distribution over possible token values.

Data

ProtBFN is trained on data obtained from the January 2024 release of UniProtKB31. The data is filtered according to the Protein Existence (PE) property, including only those proteins that are inferred from homology (PE = 3), have evidence at the transcript level (PE = 2) or have evidence at the protein level (PE = 1). By removing those proteins which are known to be hypothetical (PE = 4) or are of unknown existence (PE = 5), ProtBFN is restricted to modelling the distribution of proteins that are very likely to exist, meaning that greater confidence can be placed in the sequences it generates. We also found that training was substantially faster once hypothetical proteins were removed, indicating that they introduce a considerable amount of additional entropy, which may reflect their generally lower quality. Additionally, ProtBFN is trained only on those sequences with length l < 512, with the final token used to encode an end-of-sequence (EOS) token. After filtering by PE and length, the final training set contains 71 million sequences. Where clusters are used to reweight and debias the data, the clusters are obtained from UniRef5075. Each sequence is represented by its amino acids followed by an EOS token. All tokens after EOS are PAD tokens, which are treated as normal tokens for the purposes of noisy observations, predictions and loss. As this dataset is a (C)urated and (C)lustered subset of UniProt, we refer to it for the rest of the text as UniProtCC. As the choice of data is essential to the overall performance of any generative model, we provide a thorough ablation of UniProtCC against UniRef50 in Supplementary Note 5.

Training

ProtBFN is first pre-trained for 250,000 training steps with a batch size of 8192. We observed that a large batch size was necessary to obtain stable gradient estimates. Adam76 is used with β1 = 0.9 and β2 = 0.98 as in ref. 28. The learning rate is initialised to 0 and linearly increased to 10⁻⁴ at step 10,000, after which it is held constant. Throughout training, the norm of the gradient is clipped to 500. A copy of the network's parameters is maintained as an exponential moving average of the weights with decay rate 0.999. During this first phase of training, samples are drawn uniformly at random from all training data.
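
For concreteness, a framework-agnostic sketch of the learning-rate warmup, gradient clipping and parameter averaging described above (not the authors' training code; gradients and parameters are assumed to be lists of NumPy arrays):

```python
def learning_rate(step, peak_lr=1e-4, warmup_steps=10_000):
    """Linear warmup from 0 to peak_lr over warmup_steps, then constant."""
    return peak_lr * min(step / warmup_steps, 1.0)

def clip_by_global_norm(grads, max_norm=500.0):
    """Rescale the gradients so that their global norm does not exceed max_norm."""
    global_norm = sum(float((g ** 2).sum()) for g in grads) ** 0.5
    scale = min(1.0, max_norm / max(global_norm, 1e-12))
    return [g * scale for g in grads]

def ema_update(ema_params, params, decay=0.999):
    """Exponential moving average copy of the network parameters."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]
```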

Next, the model is trained for a further 250,000 steps with clustered data. Specifically, each cluster is constructed by taking all samples within the corresponding UniRef50 cluster that pass both the PE and length filters. During this training phase, each cluster is sampled with probability proportional to the square root of its size, so that ProtBFN is debiased away from those proteins most heavily studied by humans, without being overly focused on very data-sparse single-member clusters, as would happen when sampling clusters uniformly or training only on the UniRef50 cluster centres as done in ref. 15. Once a cluster has been sampled, a sequence contained within it is chosen uniformly at random. During this second phase of training, the optimiser is completely reset and, again, the learning rate is linearly increased from 0 to 10⁻⁵ over the first 10,000 steps. The main focus of this article is to provide evidence that BFNs can be used in the protein space; for more in-depth analysis of the biases present in protein data, please refer to refs. 53,77 and 78.
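
A minimal sketch of this square-root cluster sampling, assuming the filtered clusters are available as in-memory lists of sequences (names are illustrative):

```python
import numpy as np

def sample_sequence(clusters, rng):
    """Sample a training sequence: a cluster with probability proportional to the
    square root of its size, then a member of that cluster uniformly at random.

    clusters: list of lists of sequences (one inner list per UniRef50 cluster
              after PE and length filtering).
    """
    sizes = np.array([len(c) for c in clusters], dtype=float)
    probs = np.sqrt(sizes)
    probs /= probs.sum()
    cluster = clusters[rng.choice(len(clusters), p=probs)]
    return cluster[rng.integers(len(cluster))]
```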

The training curve of ProtBFN is shown in Fig. 5. The loss decreases monotonically over the first 250,000 steps. At the point at which the cluster sample weighting is introduced, the loss increases significantly, reflecting the distribution shift of the underlying training data. The loss continues to monotonically decrease over the next 250,000 steps, although overall loss remains higher than during the initial training stage, indicating that the introduction of the weighted cluster sampling leads to a substantially higher underlying entropy of training data.

Fig. 5: Training curve for the pre-training of ProtBFN.
figure 5

The loss is averaged over every 1000 consecutive steps to produce a smooth plot. The dashed line indicates the step at which the weighting of clusters is changed. Source data are provided as a Source Data file.

Sampling and filtering

To generate the 10,000 samples used for the ProtBFN de novo generation results, we use the sampling algorithm described in Box 2 with N = 10,000, i.e., 10,000 sampling steps, and with the weights of the model obtained from the exponential moving average. We found that the lower-temperature sampling method occasionally produced pathological sequences that are highly repetitive. To remedy this, we counted the number of sequential repeats of any given amino acid and considered an amino acid repetitive if it repeats more than 3 times. For any given sequence, the repetitiveness score was calculated as the total number of repetitive amino acids divided by the sequence length. We discarded any sequence within the top 20% of repetitiveness scores. Additionally, we found that the sampling method occasionally generated sequences with very high perplexity. We therefore also discarded any sequence within the top 30% of perplexities. This process rejected approximately 44% of generated samples.
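
A sketch of this filtering heuristic is given below; the exact counting convention for repetitive residues and the helper names are our interpretation of the description above.

```python
import numpy as np

def repetitiveness_score(sequence, max_run=3):
    """Fraction of residues belonging to runs of a single amino acid longer than max_run."""
    repetitive = 0
    run = 1
    for prev, cur in zip(sequence, sequence[1:]):
        run = run + 1 if cur == prev else 1
        if run > max_run:
            repetitive += 1
    return repetitive / len(sequence)

def filter_samples(sequences, perplexities, rep_quantile=0.80, ppl_quantile=0.70):
    """Discard the top 20% by repetitiveness score and the top 30% by perplexity."""
    rep = np.array([repetitiveness_score(s) for s in sequences])
    ppl = np.asarray(perplexities)
    keep = (rep <= np.quantile(rep, rep_quantile)) & (ppl <= np.quantile(ppl, ppl_quantile))
    return [s for s, k in zip(sequences, keep) if k]
```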

Coverage score

For a given number of sequences, we can estimate the expected number of unique clusters to which the sequences would be assigned. This estimate assumes that a given sequence can be assigned to multiple clusters, which is the case when clustering according to 50% sequence identity. If we sample n sequences randomly and independently from a set of clusters with normalised weighting ωi, the expected number of unique clusters sampled can be calculated as

$${\mathbb{E}}\left[{N}_{{{{\rm{hits}}}}}\right]={\sum}_{i}\left(1-{(1-{\omega }_{i})}^{n}\right).$$

(18)

For ProtGPT2 and EvoDiff we use an equal weight for each cluster, since that corresponds to the UniRef50 dataset, while for ProtBFN we use cluster weights inversely proportional to the cluster size, as that corresponds to our training data. For our diversity score, we divide the number of unique clusters hit by the expected number of cluster hits as calculated above. This estimate does over-count the expected number of unique clusters since it treats the clusters independently; however, we compensate for that by normalising the diversity scores by the equivalent metric calculated on a large subsample of our training data.
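
A small sketch of Equation (18) and the resulting coverage score (function names are illustrative):

```python
import numpy as np

def expected_cluster_hits(weights, n):
    """Expected number of unique clusters hit when drawing n sequences
    independently with (normalised) cluster weights, per Equation (18)."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    return np.sum(1.0 - (1.0 - weights) ** n)

def coverage_score(n_unique_clusters_hit, weights, n):
    """Observed unique cluster hits divided by the expectation of Equation (18)."""
    return n_unique_clusters_hit / expected_cluster_hits(weights, n)
```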

Fine-tuning AbBFN

Data

We downloaded the heavy-chain unpaired OAS dataset43 on 21st February 2024. We filtered the data with a procedure similar to that of ref. 79, as follows:

  1. Filter out the studies: Bonsignori et al.80, Halliley et al.81, Thornqvist et al.82.

  2. Filter out studies originating from immature B cells (BType is “Immature-B-Cells” or “Pre-B-Cells”).

  3. Filter out studies originating from B cell cancers (Disease is “Light Chain Amyloidosis” or “CLL”).

Then for each remaining study, filter the sequences as follows:

  1. Filter if: the sequence contains a stop codon; the sequence is non-productive; the V and J regions are out of frame; framework region 2 is missing; framework region 3 is missing; the CDR3 is longer than 37 amino acids; the J region identity is less than 50%; a CDR edge is missing an amino acid; the locus does not match the chain type.

  2. Remove sequences that appear only once in a study, then deduplicate the remainder.

  3. Filter if the conserved cysteine residue is not present, or is misnumbered in the ANARCI numbering.

  4. Create the near-full-length sequence from IMGT positions 21 through 128 (127 for light chains and for heavy chains from rabbits and camels). Filter if the framework 1 region is <21 amino acids.

  5. Remove duplicate near-full-length sequences and tally their counts; filter out sequences that appeared only once, on the grounds of insufficient evidence that they are genuine biological sequences rather than sequencing errors, following ref. 79.

  6. Filter out sequences which contain any amino acids that are not the standard 20.

  7. Use the sequence from the full ANARCI83 numbering, not the near-full-length sequence as in ref. 79.

This corresponds to all of the filters/preprocessing described in ref. 79, but using the full variable domain sequence (insofar as it was present in the original sequence) instead of the near-full-length sequence. We applied one additional filter, removing any sequences whose ANARCI numbering contained an empty region.

To create the SAbDab test set, we used SAbDab data downloaded on 29th February 2024, specifically the summary table, to select only the paired chains, ensuring that we exclude single-domain antibodies; most of these are of camelid origin and, therefore, belong to germline genes that were not present in our training set. We also removed single-chain fragment variable (scFv) antibodies. We then parsed the SEQRES records from the original SAbDab PDB files and ran ANARCI83 on these using the IMGT numbering scheme to obtain the final variable domain sequences, as ANARCI retains only the subset of the sequence that could be numbered, thereby removing additions such as purification tags.

To ensure that the training data is dissimilar to the test data, we split the OAS and SAbDab data into heavy and light chains. We then uniformly selected and set aside 20,000 heavy-chain sequences from the OAS data at random and added all heavy chains from SAbDab to construct the test dataset. Next, we used MMSeqs235 to query this combined dataset against the remainder of the OAS dataset with the default sensitivity of 5.7, a minimum sequence identity of 0.95, coverage of 0.8 and coverage mode 0. We removed any hits from the training data. Due to the presence of heavily engineered antibodies, the test data originating from SAbDab is more out-of-distribution than the test data originating from OAS: filtering out sequences similar to SAbDab retained 99% of the data, whereas filtering out sequences similar to the 20,000 uniformly sampled OAS test sequences retained 79% of the OAS data. Filtering out these similar sequences produced 195M training examples, compared to 248M before sequence-similarity filtering.

Training

AbBFN is fine-tuned from the ProtBFN model on the filtered OAS data. For computational efficiency, the maximum sequence length is reduced to 256. It is trained for 100,000 steps with a batch size of 8192. Adam76 is used with β1 = 0.9 and β2 = 0.98 as in ref. 28. The learning rate is initialised to 0 and linearly increased to 10⁻⁵ at step 10,000, after which it is held constant. Throughout training, the norm of the gradient is clipped to 500. A copy of the network's parameters is maintained as an exponential moving average of the weights with a decay rate of 0.999. The training curve of AbBFN is shown in Fig. 6.

Fig. 6: Training curve for the fine-tuning of AbBFN.
figure 6

The loss is averaged over every 1000 consecutive steps to produce a smooth plot. Source data are provided as a Source Data file.

Fine-tuning AbBFN+

To assess whether AbBFN is able to learn the distribution of heavy-chain sequences in SAbDab, we further fine-tune the model on each of the 9 training folds from the 10-fold cross-validation splits used by ref. 44, using the remaining test fold for evaluation. The model is fine-tuned for 1000 training steps with a batch size of 512. Adam76 is used with β1 = 0.9 and β2 = 0.98, and the learning rate is linearly increased from 0 to 10⁻⁵ over the 1000 training steps. Throughout training, the norm of the gradient is clipped to 500. A copy of the network's parameters is maintained as an exponential moving average of the weights with a decay rate of 0.995. The exponential moving average parameters are used for the inpainting results.

Sampling AbBFN

To generate the 10,000 samples used for the AbBFN de novo generation results, we use the sampling algorithm described in Box 2 with N = 10,000, i.e., 10,000 sampling steps, and with the weights of the model obtained from the exponential moving average. We did not find it necessary to filter samples by repetitiveness or perplexity, possibly because the OAS domain is simpler than UniProtCC by virtue of the common immunoglobulin fold.

Inpainting

To inpaint OAS and SAbDab test sequences, we use the algorithm described in Box 3. We use p = 1024 particles for each sequence and N = 100 sampling steps. For both AbBFN and AbBFN+, we use the weights of the model obtained from the exponential moving average during training.

Computational resources

ProtBFN was trained on 128 TPU v4 chips for approximately 2 weeks. AbBFN was further trained for approximately 3 days on 64 TPU v4 chips.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
