Quantifying inductive bias with Bayesian priors
The prior over functions, P(f), is the probability that a DNN \(\mathcal{N}(\Theta )\) expresses f upon random sampling of parameters over a parameter initialization distribution Ppar(Θ):
$$P(f)=\int \mathbb{1}[\mathcal{N}(\Theta )=f]\,P_{\rm par}(\Theta )\,d\Theta ,$$
(1)
where \(\mathbb{1}\) is an indicator function (1 if its argument is true, and 0 otherwise). Explicitly, this term is 1 if the neural network \(\mathcal{N}(\Theta )\) expresses f with parameters Θ, else 0. It was shown in ref. 24 that, for ReLU activation functions, P(f) for the Boolean system was insensitive to different choices of Ppar(Θ), and that it exhibits an exponential bias of the form \(P(f)\lesssim 2^{-a\tilde{K}(f)+b}\) towards simple functions with low descriptional complexity \(\tilde{K}(f)\), which is a proxy for the true (but uncomputable) Kolmogorov complexity. We will, as in ref. 24, calculate \(\tilde{K}(f)\) using C_LZ, a Lempel-Ziv (LZ) based complexity measure from ref. 25, applied to the 2^n-long bitstring that describes the function on an ordered list of inputs. Other complexity measures give similar results24,26, so there is nothing fundamental about this particular choice. To simplify notation, we will use K(f) instead of \(\tilde{K}(f)\). The exponential drop of P(f) with K(f) in the map from parameters to functions is consistent with a simplicity bias bound25 inspired by the algorithmic information theory (AIT) coding theorem27, which holds for a much wider set of input-output maps. It was argued in ref. 24 that if this inductive bias in the priors matches the simplicity of structured data then it would help explain why DNNs generalize so well. However, the weakness of that work, and of related works arguing for such a bias towards simplicity24,26,28,29,30,31,32,33,34,35, is that it is typically not possible to significantly change this inductive bias towards simplicity, making it hard to conclusively show that it is not some other property of the network that instead generates the good performance. Here we exploit a particularity of \(\tanh\) activation functions that enables us to significantly vary the inductive bias of DNNs. In particular, for a Gaussian Ppar(Θ) with standard deviation σw, it was shown36,37 that, as σw increases, there is a transition to a chaotic regime. Moreover, it was recently demonstrated that the simplicity bias in P(f) becomes weaker in the chaotic regime38 (see also Supplementary Note 3). We will exploit this behavior to systematically vary the inductive bias over functions in the prior.
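As an illustration, the following minimal sketch computes an LZ76-based complexity proxy on the 2^n-bit string describing a Boolean function. The phrase-counting routine follows the standard Kaspar–Schuster scheme, and the normalization (log2 of the string length times the mean phrase count of the string and its reverse) follows the spirit of refs. 24,25; the exact convention used in the paper may differ.

```python
import math

def lz76_phrases(s: str) -> int:
    """Count LZ76 phrases with the standard Kaspar-Schuster scheme."""
    n = len(s)
    i, k, l, k_max, c = 0, 1, 1, 1, 1
    while True:
        if s[i + k - 1] == s[l + k - 1]:
            k += 1
            if l + k > n:
                c += 1
                break
        else:
            if k > k_max:
                k_max = k
            i += 1
            if i == l:           # all earlier starting points exhausted: new phrase
                c += 1
                l += k_max
                if l + 1 > n:
                    break
                i, k, k_max = 0, 1, 1
            else:
                k = 1
    return c

def K_LZ(bits: str) -> float:
    """LZ-based complexity proxy for the 2^n-bit truth table of a Boolean function.
    Trivial (all-0 or all-1) strings are assigned the minimal value log2(len)."""
    n = len(bits)
    if bits in ('0' * n, '1' * n):
        return math.log2(n)
    return math.log2(n) * (lz76_phrases(bits) + lz76_phrases(bits[::-1])) / 2

# Simple strings score low, more irregular strings score higher
print(K_LZ('0' * 128), K_LZ('01' * 64), K_LZ('0110100110010110' * 8))
```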
In Fig. 1a, b we depict prior probabilities P(f) for functions f defined on all 128 inputs of an n = 7 Boolean system upon random sampling of parameters of an FCN with 10 layers and hidden width 40 (which is provably fully expressive for this system31), and \(\tanh\) activation functions. The simplicity bias in P(f) becomes weaker as the standard deviation σw of the Gaussian Ppar(Θ) increases. By contrast, for ReLU activations, the bias in P(f) barely changes with σw (see Fig. S3a). The effect of the decrease in simplicity bias on DNN generalization performance is demonstrated in Fig. 1c for a DNN trained to zero error on a training set S of size m = 64 using advSGD (an SGD variant taken from ref. 24), and tested on the other 64 inputs xi ∈ T. The generalization error (the fraction of incorrect predictions on T) varies as a function of the complexity of the target function. Although all these DNNs exhibit simplicity bias, weaker forms of the bias correspond to significantly worse generalization on the simpler targets (see also Supplementary Note 10). For very complex targets, both networks perform poorly. For reference, we also show an unbiased learner, where functions f are chosen uniformly at random with the proviso that they exactly fit the training set S. Not surprisingly, given the 2^64 ≈ 2 × 10^19 functions that can fit S, the performance of this unbiased learner is no better than random chance.
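The sketch below illustrates how such a prior can be estimated empirically: sample Gaussian parameters for a tanh FCN, read off the induced Boolean function on all 2^7 inputs, and count how often each function recurs. The ±1 input encoding, the 1/√fan-in weight scaling, and the bias scale are assumptions made for illustration; the paper's exact sampling conventions may differ.

```python
import numpy as np
from collections import Counter

def sample_boolean_function(sigma_w, n=7, depth=10, width=40, rng=None):
    """Return the 2^n-bit string of the Boolean function expressed by a randomly
    initialized tanh FCN (weights ~ N(0, sigma_w^2/fan_in); an assumed convention)."""
    rng = np.random.default_rng() if rng is None else rng
    X = np.array([[(i >> j) & 1 for j in range(n)] for i in range(2 ** n)], dtype=float)
    h = 2 * X - 1                                   # +/-1 input encoding (assumption)
    dims = [n] + [width] * depth + [1]
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        W = rng.normal(0.0, sigma_w / np.sqrt(d_in), size=(d_in, d_out))
        b = rng.normal(0.0, 0.1, size=d_out)        # small bias scale (assumption)
        h = np.tanh(h @ W + b)
    return ''.join('1' if y > 0 else '0' for y in h.ravel())

# Crude estimate of the prior P(f): frequency of each sampled function
rng = np.random.default_rng(0)
for sigma_w in (1.0, 8.0):
    counts = Counter(sample_boolean_function(sigma_w, rng=rng) for _ in range(5000))
    top = counts.most_common(3)
    print(f"sigma_w={sigma_w}: top function counts out of 5000:", [c for _, c in top])
```

With small σw the same simple functions recur very often, while in the chaotic regime almost every sample yields a distinct, complex function.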

a Prior P(f) that an Nl-layer FCN with \(\tanh\) activations generates n = 7 Boolean functions f, ranked by probability of individual functions, generated from 10^8 random samples of parameters Θ over a Gaussian Ppar(Θ) with standard deviations σw = 1…8. Also compared is a ReLU-activated DNN. The dotted blue line denotes a Zipf's law prior24 \(P(f)=1/((128\ln 2){\rm Rank}(f))\). b P(f) versus LZ complexity K for the networks from (a). c Generalization error versus K of the target function for an unbiased learner (green), and \(\sigma _w=1,8\,\tanh\) networks trained to zero error with advSGD24 on cross-entropy loss with a training set S of size m = 64, for 1000 random initializations. The error is calculated on the remaining ∣T∣ = 64 inputs. Error bars are one standard deviation (see Fig. S17 for PAC-Bayes bounds on this data). d, e, f Scatterplots of generalization error versus learned function LZ complexity, from 1000 random initializations for three target functions from subfigure (c). The dashed vertical line denotes the target function complexity. The black cross represents the mode function. The histograms at the top (side) of the plots show the posterior probability upon training as a function of complexity, PSGD(K∣S) (error, PSGD(ϵG∣S)). g The prior probability P(K) to obtain a function of LZ complexity K for uniform random sampling of 10^8 functions, compared to a theoretical perfect compressor. 90% of the probability mass lies to the right of the vertical dotted lines, and the dash-dot line denotes an extrapolation to low K. h P(K) is relatively uniform in K for the σw = 1 system, while it is highly biased towards complex functions for the σw = 8 networks. The large difference in these priors helps explain the significant variation in DNN performance. i Generalization error for the K-learning restriction for the σw = 1, 8 DNNs and for an unbiased learner, all for ∣S∣ = 100. ϵS is the training error and ϵG is the generalization error on the test set. The vertical dashed line is the complexity Kt of the target. Also compared are the standard realizable PAC and marginal-likelihood PAC-Bayes bounds for the unbiased learner. In 10^4 samples, no solutions were found with K ≲ 70 for the σw = 8 DNN, and with K ≳ 70 for the σw = 1 DNN.
The scatter plots of Fig. 1d–f depict a more fine-grained picture of the behavior of the SGD-trained networks for three different target functions. For each target, 1000 independent initializations of the SGD optimizer, with initial parameters taken from Ppar(σw), are used. The generalization error and complexity of each function found when the DNN first reaches zero training error are plotted. Since there are 2^64 possible functions that give zero error on the training set S, it is not surprising that the DNN converges to many different functions upon different random initializations. For the σw = 1 network (where P(f) resembles that of ReLU networks) the most common function is typically simpler than the target. By contrast, the less biased network converges on functions that are typically more complex than the target. As the target itself becomes more complex, the relative difference between the two generalization errors decreases, because the strong inductive bias towards simple functions of the first network becomes less useful. No free lunch theorems for supervised learning tell us that when averaged over all target functions, the three learners above will perform equally badly39,40 (see also Supplementary Note 43).
Priors over complexity
To understand why relatively modest changes in the inductive bias towards simplicity lead to such significant differences in generalization performance, we need another important ingredient, namely how the number of functions varies with complexity. Basic counting arguments imply that the number of strings of a fixed length that have complexity K scales exponentially as 2^K (ref. 27). Therefore, the vast majority of functions picked at random will have high complexity. This exponential growth of the number of functions with complexity can be captured in a more coarse-grained prior, the probability P(K) that the DNN expresses a function of complexity K upon random sampling of parameters over a parameter initialization distribution Ppar(Θ), which can also be written in terms of functions as \(P(K^{\prime} )=\sum _{f\in {\mathcal{H}}_{K^{\prime} }}P(f)\), the sum of the prior over the set \({\mathcal{H}}_{K^{\prime} }\) of all functions with complexity \(\tilde{K}(f)=K^{\prime} \). In Fig. 1g P(K) is shown for uniform random sampling of functions with 10^8 samples using the LZ measure, and also for the theoretical ideal compressor with \(P(K)=2^{K-K_{\max }-1}\) over all 2^128 ≈ 3 × 10^38 functions (see also Supplementary Note 9). In (h) we display P(K) for functions sampled not at random, but rather from the two networks. There is a dramatic difference between random sampling of functions (as in (g)) and the network with σw = 1, for which P(K) is nearly flat. This behavior follows from the interesting fact that the AIT coding theorem-like scaling24,25 of the prior over functions \(P(f) \sim 2^{-\tilde{K}(f)}\) counters the 2^K growth in the number of functions.
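A schematic version of this cancellation, under the idealization that roughly 2^K functions have complexity K, and writing the simplicity-bias bound quoted above as \(P(f)\approx 2^{-aK(f)+b}\), is

$$P_{\rm uniform}(K)\approx \frac{2^{K}}{\sum _{K^{\prime} \le K_{\max }}2^{K^{\prime} }}\approx 2^{K-K_{\max }-1},\qquad P_{\rm DNN}(K)\approx \underbrace{2^{K}}_{\#\,{\rm functions\ at\ }K}\times \underbrace{2^{-aK+b}}_{{\rm prior\ per\ function}}=2^{(1-a)K+b}.$$

For a ≈ 1 the two exponentials approximately cancel and P(K) is roughly flat in K, whereas uniform sampling (a = 0) concentrates essentially all of the probability mass at the highest complexities.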
By contrast, even though, relative to the 38 or so orders of magnitude scale on which P(f) varies, the more artefactual σw = 8 system has strong simplicity bias (we estimate that for the simplest functions, P(f) is about 10^25 times higher than the mean probability \(\langle P(f)\rangle =2^{-128}\approx 3\times 10^{-39}\)), this is not enough to counter the 2^K growth in the number of functions with complexity. Therefore, this DNN is exponentially more likely to throw up complex functions, an effect that SGD is unable to overcome.
More generally, the fact that the number of complex functions grows exponentially with complexity K lies at the heart of the classical explanation of why an insufficiently biased agent suffers from variance: it can too easily find many different functions that all fit the data. The marked differences in generalization performance between the two networks observed in Fig. 1c–f can therefore be traced to differences in the inductive bias of the networks, as measured by the differences in their priors.
Artificially restricting model capacity
To further illustrate the effect of inductive bias we create a K-learner that only allows functions with complexity ≤ KM to be learned and discards all others. As can be seen in Fig. 1i, the learners typically cannot reach zero training error on the training set if KM is less than the target function complexity Kt. For KM ≥ Kt, zero training error can be reached and, not surprisingly, the lowest generalization error occurs when KM = Kt. As the upper limit KM is increased, all three learning agents are more likely to make errors in predictions due to variance. The random learner has an error that grows linearly with KM. This behavior can be understood with a classic Probably Approximately Correct (PAC) bound6 where the generalization error (with confidence 0 ≤ (1 − δ) ≤ 1) scales as \(\epsilon _G\le (\ln | {\mathcal{H}}_{\le K_M}| -\ln \delta )/m\), where \(| {\mathcal{H}}_{\le K_M}| \) is the size of the hypothesis class of all functions with K ≤ KM. Since this hypothesis class grows exponentially with KM, the bound scales linearly in KM, as the error does (see Supplementary Note 7 for further discussion, including the more sophisticated PAC-Bayes bound41,42). The generalization error for the σw = 1 DNN does not change much with KM for KM > Kt because the strong inductive bias towards simple solutions means that access to higher complexity solutions does not significantly change what the DNN converges on.
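A minimal numerical sketch of this bound, assuming the counting-argument idealization \(| {\mathcal{H}}_{\le K_M}| \approx 2^{K_M+1}\) (the exact hypothesis-class count used for Fig. 1i may differ):

```python
import math

def pac_bound(K_M, m=100, delta=0.1):
    """Realizable PAC bound eps_G <= (ln|H| - ln delta)/m, with the hypothesis
    class of all functions of complexity <= K_M idealized as |H| ~ 2^(K_M + 1)."""
    ln_H = (K_M + 1) * math.log(2)
    return (ln_H - math.log(delta)) / m

for K_M in (20, 40, 60, 80, 100, 128):
    print(K_M, round(pac_bound(K_M), 2))   # grows linearly in K_M, like the error
```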
Finally, we show data for DNNs in the ordered regime with σw ≪ 1, and for other optimizers, loss functions, and activation functions in Figs. S6–S11. These results broadly exhibit the same behavior we describe here.
Calculating the Bayesian posterior and likelihood
To better understand the generalization behavior observed in Fig. 1 we apply Bayes' rule, P(f∣S) = P(S∣f)P(f)/P(S), to calculate the Bayesian posterior P(f∣S) from the prior P(f), the likelihood P(S∣f), and the marginal likelihood P(S). Since we condition on zero training error, the likelihood takes on a simple form: P(S∣f) = 1 if ∀ xi ∈ S, f(xi) = yi, while P(S∣f) = 0 otherwise. For a fixed training set, all the variation in P(f∣S) for f ∈ U(S), the set of all functions compatible with S, comes from the prior P(f), since P(S) is constant. Therefore, in this Bayesian picture, the bias in the prior is translated over to the posterior.
The marginal likelihood also takes a relatively simple form for discrete functions, since P(S) = ∑f P(S∣f)P(f) = ∑f∈U(S) P(f). It is equivalent to the probability that the DNN obtains zero error on the training set S upon random sampling of parameters, and so can be interpreted as a measure of the inductive bias towards the data. The marginal-likelihood PAC-Bayes bound42 makes a direct link \(P(S)\lesssim e^{-m\epsilon _G}\) to the generalization error ϵG, which captures the intuition that, for a given m, a better inductive bias towards the data (larger P(S)) implies better performance (lower ϵG).
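To make the 0/1-likelihood bookkeeping concrete, here is a toy sketch on an n = 3 system (8 inputs, 256 functions) with a hypothetical simplicity-biased prior. The crude complexity proxy and the target are placeholders, and the final line only illustrates the scaling ϵG ≲ −ln P(S)/m implied by the marginal-likelihood bound, not the full PAC-Bayes expression (which carries extra confidence terms).

```python
import itertools, math

def crude_K(bits):                       # toy stand-in for an LZ complexity measure
    return 1 + sum(a != b for a, b in zip(bits, bits[1:]))

functions = [''.join(b) for b in itertools.product('01', repeat=8)]
prior = {f: 2.0 ** (-crude_K(f)) for f in functions}        # hypothetical 2^-K prior
Z = sum(prior.values())
prior = {f: p / Z for f, p in prior.items()}

target, m = '00110011', 4
S = target[:m]                                              # labels of the first m inputs
U_S = [f for f in functions if f[:m] == S]                  # zero training error set U(S)
P_S = sum(prior[f] for f in U_S)                            # marginal likelihood P(S)
posterior = {f: prior[f] / P_S for f in U_S}                # Bayes rule, 0/1 likelihood

print('P(S) =', round(P_S, 4))
print('most probable function given S:', max(posterior, key=posterior.get))
print('error scale -ln P(S)/m =', round(-math.log(P_S) / m, 3))
```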
One can also define the posterior probability PSGD(f∣S), that a network trained with SGD (or another optimizer) on training set S, when initialized with Ppar(Θ), converges on function f. For simplicity, we take this probability at the epoch where the system first reaches zero training error. Note that in Fig. 1d–f it is this SGD-based posterior that we plot in the histograms at the top and sides of the plots, with functions grouped either by complexity, which we will call PSGD(K∣S), or by generalization error ϵG, which we will call PSGD(ϵG∣S).
DNNs are typically trained by some form of SGD, and not by randomly sampling over parameters which is much less efficient. However, a recent study43 which carefully compared the two posteriors has shown that to first order, PB(f∣S) ≈ PSGD(f∣S), for many different data sets and DNN architectures. We demonstrate this close similarity in Fig. S15 explicitly for our n = 7 Boolean system. This evidence suggests that Bayesian posteriors calculated by random sampling of parameters, which are much simpler to analyze, can be used to understand the dominant behavior of an SGD-trained DNN, even if, for example, hyperparameter tuning can lead to 2nd-order deviations between the two methods (see also Supplementary Note 1).
To test the predictive power of our Bayesian picture, we first define the function error ϵ(f) as the fraction of incorrect labels f produces on the full set of inputs. Next, we average Bayes’ rule over all training sets S of size m:
$$\langle P(f| S)\rangle _m=P(f){\left\langle \frac{P(S| f)}{P(S)}\right\rangle }_m\approx \frac{P(f){\left(1-\epsilon (f)\right)}^{m}}{{\langle P(S)\rangle }_m},$$
(2)
where the mean likelihood 〈P(S∣f)〉m = (1−ϵ(f))^m is the probability of a function f obtaining zero error on a training set of size m. In the second step, we approximate the average of the ratio by the ratio of the averages, which should be accurate if P(S) is highly concentrated, as is expected if the training set is not too small.
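A quick Monte Carlo check of this mean-likelihood form, under the assumption that the m training inputs are drawn i.i.d. (uniformly, with replacement) from the full input set, which is what the (1−ϵ(f))^m expression presumes:

```python
import numpy as np

rng = np.random.default_rng(0)
N, m, trials = 128, 64, 20000
for n_err in (2, 4, 8):              # inputs on which f disagrees with the labels
    eps = n_err / N
    # fraction of random training sets on which f has zero training error
    hits = sum(not np.isin(rng.integers(0, N, size=m), np.arange(n_err)).any()
               for _ in range(trials))
    print(f"eps={eps:.4f}  MC={hits / trials:.4f}  (1-eps)^m={(1 - eps) ** m:.4f}")
```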
Equation (2) is hard to calculate, so we coarse-grain it by grouping together functions by their complexity:
$$\langle P(K| S)\rangle _m=\sum _{C_{\rm LZ}(f)=K}\langle P(f| S)\rangle _m\propto P(K){\left\langle {\left(1-\epsilon _G(K)\right)}^{m}\right\rangle }_l,$$
(3)
and in the second step make a decoupling approximation where we average the likelihood term over a small number l of functions of complexity K with the lowest generalization error ϵG(K), since the smallest errors dominate the sum exponentially, given that \((1-\epsilon _G)\approx e^{-\epsilon _G}\) for ϵG ≪ 1. We then multiply by P(K), which takes into account the value of the prior and the multiplicity of functions at that K, and normalize so that ∑K P(K∣S) = 1. For a given target, we make the ansatz that this decoupling approximation provides an estimate that scales as the true (averaged) posterior.
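A sketch of the resulting recipe, with hypothetical values of P(K) and per-complexity error lists standing in for the measured quantities:

```python
import numpy as np

def coarse_posterior(P_K, eps_by_K, m, l=5):
    """Decoupling approximation of Eq. (3): <P(K|S)>_m ~ P(K) * <(1 - eps_G)^m>_l,
    averaging over the l lowest-error functions available at each complexity K,
    then normalizing over K."""
    post = {}
    for K, pK in P_K.items():
        lows = sorted(eps_by_K.get(K, []))[:l]
        if lows:
            post[K] = pK * np.mean([(1.0 - e) ** m for e in lows])
    Z = sum(post.values())
    return {K: v / Z for K, v in post.items()}

# Hypothetical inputs, for illustration only
P_K = {20: 0.30, 40: 0.30, 60: 0.20, 80: 0.20}
eps_by_K = {20: [0.20, 0.22], 40: [0.05, 0.06], 60: [0.08, 0.10], 80: [0.15, 0.18]}
print(coarse_posterior(P_K, eps_by_K, m=64, l=2))
```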
To test our approximations, we first plot, in Fig. 2a–c, the likelihood term in Equation (3) for three different target functions. To obtain these curves, we considered a large number of functions (including all functions with up to 5 errors w.r.t. the target, with further functions sampled). For each complexity, we average this term over the l = 5 functions with smallest ϵG. Not surprisingly, functions close to the complexity of the target have the smallest error. These graphs help illustrate how the DNN interacts with data. As the training set size m increases, functions that are most likely to be found upon training to zero training error are increasingly concentrated close to the complexity of the target function.

a–c depict the mean likelihood \(\langle (1-\epsilon _G(K))^m\rangle _5\) from Equation (3), averaged over training sets, and over the 5 lowest error functions at each K. This term depends on the data and is independent of the DNN architecture. With increasing m it peaks more sharply around the complexity of the target. In (d–f) we compare the posteriors over complexity, \({\langle P_{\rm SGD}(K| S)\rangle }_m\), for SGD (darker blue and red), averaged over training sets of size m, to the prediction of 〈P(K∣S)〉m from Equation (3) (lighter blue and orange), calculated by multiplying the Bayesian likelihood curves in (a–c) by the prior P(K) shown in Fig. 1h. The light (Bayes) and dark (DNN) blue histograms are from the σw = 1 system, and the orange (Bayes) and red (DNN) histograms are from the σw = 8 system, which has less bias towards simple functions. The Bayesian decoupling approximation (Equation (3)) captures the dominant trends in the behavior of the SGD-trained networks as a function of data complexity and training set size. Quantitative measures of the similarity between the posteriors can be found in Fig. S22.
To test the decoupling approximation from Eq. (3), we compare in Fig. 2d–f the posterior 〈P(K∣S)〉m, calculated by multiplying the Bayesian likelihood curves from Fig. 2a–c with the two Bayesian priors P(K) from Fig. 1h, to the posteriors \({\langle P_{\rm SGD}(K| S)\rangle }_m\) calculated with advSGD24 over 1000 different parameter initializations and training sets. It is remarkable to see how well the simple decoupling approximation performs across target functions and training set sizes. In Figs. S13 and S14 we demonstrate the robustness of our approach by showing that using l = 1 or l = 50 functions does not change the predictions much. This success suggests that our simple approach captures the essence of the interaction between the data (measured by the likelihood, which is independent of the learning algorithm) and the DNN architecture (measured by the prior, which is independent of the data).
We have therefore separated out two of the three parts of the tripartite scheme, which leaves the training algorithm. In the figures above our Bayesian approximation captures the dominant behavior of an SGD-trained network. This correspondence is consistent with the results and arguments of ref. 43. We checked this further in Fig. S15 for a similar set-up using MSE loss, where Bayesian posteriors can be exactly calculated using Gaussian processes (GPs). The direct Bayesian GP calculation closely matches SGD-based results for our much smaller network. Note that, in the spirit of model calculations, as called for in refs. 3,4, we mainly used a much smaller DNN. But the agreement of its posteriors with the GP-based posteriors, calculated in the infinite-width limit, shows that, at the level of the 1st-order generalization question we are addressing here, the size of the DNN is not an important factor. The width of a DNN can, of course, be a factor for 2nd-order generalization questions.
Beyond the Boolean model: MNIST & CIFAR-10
Can the principles worked out for the Boolean system be observed in larger systems that are closer to the standard practice of DNNs? To this end, we show, in Fig. 3a, b, how the generalization error for the popular image datasets MNIST and CIFAR-10 changes as a function of the standard deviation σw of the initial parameter distribution and the number of layers Nl for a standard FCN trained with SGD on cross-entropy loss with \(\tanh\) activation functions. Larger σw and larger Nl push the system deeper into the chaotic regime36,37 and result in decreasing generalization performance, similar to what we observe for the Boolean system on relatively simple targets. In Fig. 3c, we plot the prior over complexity P(K) for a complexity measure called the critical sample ratio (CSR)28, an estimate of the density of decision boundaries that should be appropriate for this problem. Again, increasing σw greatly increases the prior probability that the DNN produces more complex functions upon random sampling of parameters. Thus the decrease in generalization performance is consistent with the inductive bias of the network becoming less biased towards simplicity, and therefore less well aligned with structure in the data. Indeed, datasets such as MNIST and CIFAR-10 are thought to be relatively simple44,45.
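For reference, here is a crude sketch of how a CSR-style estimate can be obtained for a trained classifier: probe each input with random bounded perturbations and count the fraction whose predicted class flips. The original CSR definition in ref. 28 uses a box-constrained adversarial search; random probing and the radius r here are simplifying assumptions, and `predict` is an assumed interface returning integer class labels for a batch of inputs.

```python
import numpy as np

def critical_sample_ratio(predict, X, r=0.3, n_trials=50, rng=None):
    """Monte Carlo estimate of the fraction of inputs lying within an
    infinity-norm distance r of a decision boundary of `predict`."""
    rng = np.random.default_rng() if rng is None else rng
    base = predict(X)                               # predicted classes on clean inputs
    critical = np.zeros(len(X), dtype=bool)
    for _ in range(n_trials):
        noise = rng.uniform(-r, r, size=X.shape)    # random perturbation in the r-box
        critical |= (predict(X + noise) != base)    # did any probe flip the prediction?
    return critical.mean()
```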

a MNIST generalization error for FCNs trained on a 1000-image training set versus σw for three depths. b CIFAR-10 generalization error for FCNs trained on a 5000-image training set versus σw for three depths. The FCNs, made of multiple hidden layers of width 200, were trained with SGD with batch size 32 and learning rate 10^−3 until 100% accuracy was first achieved on the training set. Error bars are one standard deviation. c Complexity prior P(K), for CSR complexity, for 1000 MNIST images, for randomly initialized networks of 10 layers and σw = 1, 2. Probabilities are estimated from 2 × 10^4 parameter samples. d, e, f Scatterplots of generalization error versus the CSR for 1000 networks trained to 100% accuracy on a training set of 1000 MNIST images and tested on 1000 different images. In (d) the training labels are uncorrupted; in (e, f) 25% and 50% of the training labels are corrupted, respectively. Note the qualitative similarity to the scatter plots in Fig. 1d–f.
These patterns are further illustrated in Fig. 3d–f, where we show scatterplots of generalization error versus CSR complexity for three target functions that vary in complexity (here obtained by corrupting labels). The qualitative behavior is similar to that observed for the Boolean system in Fig. 1. The more simplicity-biased networks perform significantly better on the simpler targets, but the difference from the less simplicity-biased network decreases for more complex targets. While we are unable to directly calculate the likelihoods because these systems are too big, we argue that the strong similarities to our simpler model system suggest that the same basic principles of inductive bias are at work here.