Abstract
A longstanding debate surrounds the related hypotheses that low-curvature minima generalize better, and that stochastic gradient descent (SGD) discourages curvature. We offer a more complete and nuanced view in support of both hypotheses. First, we show that curvature harms test performance through two new mechanisms, the shift-curvature and bias-curvature, in addition to a known parameter-covariance mechanism. The shift refers to the difference between train and test local minima, and the bias and covariance are those of the parameter distribution. These three curvature-mediated contributions to test performance are reparametrization-invariant even though curvature itself is not. Although the shift is unknown at training time, the shift-curvature as well as the other mechanisms can still be mitigated by minimizing overall curvature. Second, we derive a new, explicit SGD steady-state distribution showing that SGD optimizes an effective potential related to but different from train loss, and that SGD noise mediates a trade-off between low-loss versus low-curvature regions of this effective potential. Third, combining our test performance analysis with the SGD steady state shows that for small SGD noise, the shift-curvature is the dominant of the three mechanisms. Our experiments demonstrate the significant impact of shift-curvature on test loss, and further explore the relationship between SGD noise and curvature.
Original content from this work may be used under the terms of the Creative Commons Attribution 4.0 license. Any further distribution of this work must maintain attribution to the author(s) and the title of the work, journal citation and DOI.
1. Introduction
Understanding generalization remains one of the key questions in machine learning. In typical machine learning applications, we train a model on a training dataset hoping that the model will generalize well to a test dataset, which is a proxy for unseen data. However, our theoretical understanding of how datasets, models, and learning algorithms combine to determine test performance is still incomplete, particularly in the case of overparametrized models that defy classical expectations about overfitting (Belkin 2021). A decades-old debate focuses on the related hypotheses that low curvature of the loss function results in better test performance, and that stochastic gradient descent (SGD) favors lower-curvature local minima. There is strong evidence for these hypotheses. First, many studies have shown that increasing SGD noise (by increasing the ratio of learning rate to batch size) can improve test performance (Goyal et al 2017, Hoffer et al 2017, You et al 2017, Golmant et al 2018, McCandlish et al 2018, Shallue et al 2018, He et al 2019a, Smith et al 2020). Our experiment in figure 8 also confirms this. Second, Keskar et al (2016), Li et al (2017a), Jiang et al (2019) offer empirical evidence that lower-curvature local minima tend to generalize better, an idea often attributed to Hochreiter and Schmidhuber (1997). Third, Jastrzkebski et al (2017) make direct theoretical connections between test performance, curvature, and SGD noise (albeit under strong assumptions). They show that under the SGD steady-state distribution, the expected test loss near a given local minimum depends on the trace of curvature times parameter covariance, assuming a constant Hessian equal to SGD noise. They also show that the probability of SGD landing near a particular local minimum is inversely related to its curvature, with noise mediating a depth/curvature tradeoff, assuming constant, isotropic SGD noise and valid Laplace approximation. 
On the other hand, there are several arguments against the idea that curvature predicts test performance. Dinh et al (2017) point out that model reparametrization can arbitrarily change curvature without changing test performance. Further, a number of recent works argue that Jastrzkebski et al (2017) and others present an oversimplified view of curvature, since loss landscapes often (e.g. when there are many more model parameters than training datapoints) have many directions of near-zero curvature rather than locally strictly-convex valleys; in particular, the SGD noise matrix and Hessian of the loss are different, and both are highly non-isotropic with many near-zero eigenvalues (Draxler et al 2018, Sagun et al 2018, Li et al 2020). Our work aims to help reconcile these points of view by removing unrealistic assumptions such as constant, isotropic Hessian and loss gradient covariance, identifying new and possibly more significant mechanisms by which curvature affects test performance, and showing that these curvature-dependent contributions to test performance are reparametrization invariant.
At a high level, we find that the intuitive connections between test performance, curvature, and SGD noise still hold, but in a more complete and nuanced way. First, we show that the test loss averaged over an arbitrary parameter distribution depends on the shift-curvature (in addition to the bias-curvature and the previously-identified covariance-curvature). The shift-curvature is equal to the curvature of the loss in the direction of the shift between nearby train and test local minima, as shown in figure 1; such a shift may be caused by dataset sampling or possibly by test and train datasets coming from different distributions. However, the shift is unknown at training time, so it is beneficial to reduce curvature in all directions. Second, we show that the SGD steady-state distribution favors train local minima with low curvature in all directions, although the curvature is not that of the train loss, but rather of a related effective potential which we obtain by refining existing results about the steady state of SGD. Third, we find that the shift-curvature may be the most significant curvature term when SGD noise is small. Our empirical results show that shift-curvature indeed has a significant impact on test loss.
In more detail, we refine existing results for the steady-state distribution of SGD as follows. Starting from a continuous-time diffusion approximation of SGD due to Li et al (2017b) (for which we also offer an alternative, constructive derivation which may be helpful for analyzing other algorithms), we derive a steady-state distribution showing that SGD minimizes an explicit effective potential that is related to but different from the training loss. Our expression is more explicit than the result in Chaudhari and Soatto (2018), and more general than that in Jastrzkebski et al (2017) and Seung et al (1992) since we remove their assumptions of constant and isotropic SGD noise.
This paper is organized as follows. Sections 2 and 3 present the setup and main results for the simpler underparametrized case (more data samples than model parameters). Section 4 discusses the SGD diffusion approximation and its steady state in more detail, describing how the results in section 3 are obtained. Section 5 extends the results to the more complex overparametrized case. The appendices contain derivations and additional experiments.
2. Setup
In this section and the next, we present our setup and main results in the underparametrized case (section 5 extends to the overparametrized case). We assume we are given a loss function ℓ(x, θ) that depends on a single data sample x ∈ ℝ^ι and model parameters θ ∈ ℝ^p, where ι and p are positive integers. We are also given a training set of N_tr points (tr is short for train) sampled i.i.d. from an underlying train distribution, and a test set of N_te points sampled i.i.d. from an underlying test distribution. Often the train and test distributions are identical, but sometimes they differ from one another (for example, when the data distribution evolves over time and test is sampled at a later time than train). But even when train and test are both sampled from the same underlying distribution, the train and test sets still differ due to sampling (because each data set consists of a finite number of samples). We define the train and test losses for a given θ as the expectation over the train and test sets, respectively:

L_tr(θ) = (1/N_tr) Σ_{i=1}^{N_tr} ℓ(x_tr,i, θ),  L_te(θ) = (1/N_te) Σ_{i=1}^{N_te} ℓ(x_te,i, θ).  (1)
In general, L_tr and L_te are similar but not identical functions, each with multiple local minima. We make three initial assumptions about the local minima: (i) the local minima of both train and test are strict, (ii) each local minimum of the train loss has a nearby local minimum of the test loss, and (iii) the train local minima are countable. We expect (i) to hold for underparametrized models (provided there are at least p linearly-independent per-sample gradients) but not for overparametrized models (which are handled in section 5, where we drop assumption (i)). Assumption (ii) is validated by our experiments, and can be satisfied, for example, in the absence of distribution shift by having large enough datasets. Both (ii) and (iii) could actually be relaxed to apply only to train local minima that receive significant weight in the SGD steady-state distribution, as (4) will make clear, but we will assume they hold for all train minima to simplify the exposition. With these assumptions, we denote the (strict) local minima of train by θ*_tr,k and the corresponding closest local minima of test by θ*_te,k, and generally reserve k as an index that runs through train local minima (k = 1, 2, …). We can then define the shift between corresponding local minima of train and test as:

s_k = θ*_te,k − θ*_tr,k.  (2)
The shift is small as a consequence of assumption (ii), but nonzero in general because the train and test losses are slightly different (due to dataset sampling or possibly distribution shift). The experiments of He et al (2019b), as well as our own, demonstrate a nonzero train/test shift for several models and datasets. (We generalize definition (2) to apply to non-strict minima in section 5.) Table 1 summarizes the main notation we use for easy reference.
Table 1. Notation reference guide.
Symbol | Definition | Notes |
---|---|---|
ℓ(x, θ) | Loss function | |
L_tr(θ), L_te(θ) | Train/test loss at θ | |
θ*_te,k, θ*_tr,k | Local minima of test/train loss | k indexes train local minima |
L*_te,k, L*_tr,k | Test/train loss at local minima | |
H*_te,k, H*_tr,k | Curvature at test/train local minima | |
s_k | Shift between train/test minima | s_k = θ*_te,k − θ*_tr,k |
ρ | Approx. parameter distribution | Mixture of Gaussians |
b_k | Bias of ρ at k | b_k = µ_k − θ*_tr,k |
D(θ) | SGD diffusion matrix | Train-set gradient covariance |
We further assume that a stochastic learning algorithm such as SGD processes the train dataset to produce model parameters θ with a distribution ρ(θ); that is, different runs of SGD on the same training set yield i.i.d. samples from ρ. The resulting models are evaluated on a test dataset to determine the test performance. One of our main goals is to understand the expected test loss over the distribution ρ of model parameters the algorithm produces (while keeping datasets fixed).
3. Main results
First, consider a solution θ obtained by a single run of a training algorithm like SGD, i.e. θ is a sample from ρ. To predict the performance of the trained model on unseen data, we would evaluate the test loss at θ. We expect θ to be near a train local minimum θ*_tr,k because it was found by an algorithm that attempts to minimize the train loss. We can therefore write θ = θ*_te,k − s_k + b, where s_k is the train/test shift introduced above, and b = θ − θ*_tr,k denotes the bias of θ relative to the train local minimum, with both s_k and b assumed small. We then Taylor expand about θ*_te,k, noting that ∇L_te(θ*_te,k) = 0:

L_te(θ) ≈ L_te(θ*_te,k) + ½ (b − s_k)ᵀ ∇²L_te(θ*_te,k) (b − s_k),  (3)
where we introduce the more compact notation L*_te,k for the test loss at a local minimum, and H*_te,k for the Hessian (or curvature) of the test loss at one of its local minima 1 . Although it comes from a simple Taylor expansion, equation (3) identifies important quantities impacting test loss that have so far been inadequately discussed in the literature. The equation says that the test loss evaluated near a train local minimum is equal to the test loss at the test local minimum plus a term that depends on an interaction between the curvature, the shift between the train and test local minima, and the bias between θ and the train local minimum. The shift-curvature term, s_kᵀ H*_te,k s_k, plays an especially important role in the rest of this paper; the basic intuition is that, given train and test losses with local minima slightly shifted relative to each other, the test loss evaluated near a train local minimum depends on the local curvature in the direction of the local-minimum shift (Keskar et al 2016 seem to appeal to similar intuition, but do not make it concrete). Since the shift cannot be known at training time, it is beneficial to minimize curvature in all directions, which, as we will see shortly, is what SGD does.
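The role of the shift and bias in equation (3) can be illustrated on a toy one-dimensional pair of losses. All constants below are illustrative, not taken from our experiments; because the toy losses are exact quadratics, the expansion holds with equality:

```python
import numpy as np

# Toy 1D check of equation (3). H, s, L_star, b are illustrative constants.
H = 4.0        # curvature of the test loss at its minimum
s = 0.3        # train/test shift
L_star = 0.1   # test loss at the test local minimum
b = 0.05       # bias of the trained parameter relative to the train minimum

L_test = lambda th: L_star + 0.5 * H * (th - s) ** 2  # test minimum at s
theta = 0.0 + b                # parameter found near the train minimum at 0
approx = L_star + 0.5 * H * (b - s) ** 2              # equation (3)
print(L_test(theta), approx)   # identical for a quadratic test loss
```

Note that with b = 0 the correction reduces to the pure shift-curvature term ½ H s².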
In order to obtain an explicit approximation for the expected test loss over a distribution of parameters, we model the parameter distribution as a countable mixture of Gaussians with one mixture component for each train loss minimum, as proposed by Jastrzkebski et al (2017) for the case of SGD. That is, we write ρ(θ) = Σ_k w_k N(θ; µ_k, C_k), where w_k ⩾ 0, Σ_k w_k = 1, the component means µ_k are close to the train local minima θ*_tr,k, b_k = µ_k − θ*_tr,k denote the component biases, and C_k denote the component covariances 2 . Averaging the approximate test loss in (3) over such a parameter distribution results in the approximate test performance:

E_ρ[L_te(θ)] ≈ Σ_k w_k [ L*_te,k + ½ (b_k − s_k)ᵀ H*_te,k (b_k − s_k) + ½ tr(H*_te,k C_k) ];  (4)
here tr(·) denotes the trace of a matrix. Equation (4) is one of our main results, and makes explicit how curvature determines test performance: through the covariance of model parameters as well as a quadratic function of the shifts s_k and biases b_k. Buntine (1991), Jastrzkebski et al (2017) identify the same covariance-curvature contribution to test performance (under the assumption that C = D, which we remove), but the other curvature-dependent terms in (4) are new as far as we know. In particular, the new shift-curvature term (discussed above) appears as ½ Σ_k w_k s_kᵀ H*_te,k s_k, showing that test loss improves when ρ places more weight on minima with lower shift-curvature. The role of shift-curvature is consistent with the empirical results of Verpoort et al (2020), who show that curvature has a major impact on test performance for small sample sizes, which correspond to larger train/test shift.
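A quick Monte Carlo sanity check of the single-component version of (4) can be run with a synthetic quadratic test loss; the values of H, s, b, C, and L* below are illustrative:

```python
import numpy as np

# Monte Carlo check of one component of equation (4) for a quadratic test
# loss in 2D (synthetic sanity check, not the paper's experiment).
rng = np.random.default_rng(0)
H = np.array([[2.0, 0.3], [0.3, 1.0]])    # test-loss Hessian at its minimum
s = np.array([0.2, -0.1])                 # train/test shift
b = np.array([0.05, 0.02])                # bias of the parameter distribution
C = np.array([[0.01, 0.0], [0.0, 0.02]])  # parameter covariance
L_star = 0.5                              # test loss at its minimum

theta_tr = np.zeros(2)                    # train minimum at the origin
theta_te = theta_tr + s                   # test minimum

samples = rng.multivariate_normal(theta_tr + b, C, size=200_000)
d = samples - theta_te
mc = (L_star + 0.5 * np.einsum('ni,ij,nj->n', d, H, d)).mean()
formula = L_star + 0.5 * (b - s) @ H @ (b - s) + 0.5 * np.trace(H @ C)
print(mc, formula)   # agree to Monte Carlo accuracy
```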
We can also show that (4) is reparametrization-invariant. Although multiple authors, including Keskar et al (2016), Jiang et al (2019), and Hochreiter and Schmidhuber (1997), have suggested that generalization quality can be tied directly to curvature, Dinh et al (2017) argue against this by observing that model reparametrization can arbitrarily change the curvature of the local minima. Although the curvature alone is indeed not invariant to reparametrization, we show in the appendix that each of the curvature-dependent contributions to test performance in (4) is reparametrization-invariant.
Our second main result is the approximate steady-state distribution of SGD, derived in section 4:

ρ(θ) = Z⁻¹ exp(−υ(θ)/T), with υ(θ) = 2 ∫_{θ₀→θ} D(θ′)⁻¹ ∇L_tr(θ′) · dθ′.  (5)
We call υ the effective potential, D(θ) is the training-set gradient covariance, ∇L_tr is the gradient of the training loss, and Z is the partition function of this Boltzmann-like distribution. The temperature T, which is equal to the ratio of learning rate to minibatch size (so typically T ≪ 1), captures the strength of SGD noise (as discussed further in section 4). The effective potential υ is related to the training loss L_tr, but in general not the same: υ ∝ L_tr only if D is constant and isotropic 3 . Result (5) requires the assumption that the integrand in the line integral is a gradient, which is equivalent to its having zero curl, as well as for D(θ) to be invertible, which is only true in the underparametrized case 4 . Equation (5) is more general than the results in Seung et al (1992), Jastrzkebski et al (2017), which assume that the noise is constant and isotropic (i.e. a multiple of the identity). It is different and more explicit than results in Chaudhari and Soatto (2018) for nonconstant, anisotropic noise, enabled by the zero-curl assumption that we rely on to obtain a solution with zero probability current (J in equation (10)). Chaudhari and Soatto (2018) say that the steady state cannot have zero probability current for anisotropic and/or non-constant noise, but we offer a counterexample in appendix B.1. They also stop short of explicitly solving for the effective potential, presumably because they are rightly concerned about inverting D(θ); however, our analysis of the nullspace of D shows that this is only an issue in the overparametrized case, in which case a modified SGD still allows us to find a solution (see section 5).
Approximating ρ as a mixture of Gaussians shows that ρ assigns to basin k the weight:

w_k ∝ exp(−υ*_k/T) / √(det H*_υ,k),  (6)

where υ*_k and H*_υ,k denote the depth and the Hessian of the effective potential υ at its kth local minimum.
This shows how SGD noise mediates a trade-off between depth and curvature of the effective potential, generalizing the result given in Jastrzkebski et al (2017) for the special case of constant, isotropic covariance. When T = 0 the weights w_k indicate that ρ places all its probability mass on the deepest minimum of the effective potential, but as T increases, ρ begins to transfer weight onto other minima. It then favors minima with lower curvature of the effective potential due to the inverse-square-root dependence on the determinant of the curvature of the effective potential. In particular, if two local minima have the same depth, SGD will prefer the one with lower curvature. Compared to Jastrzkebski et al (2017)'s result, ours highlights the complexity arising from non-isotropic noise, in particular that the depth and curvature are those of the effective potential rather than those of the training loss.
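The trade-off can be made concrete with two hypothetical one-dimensional basins of the effective potential, one deep and narrow and one shallow and wide (the depths and curvatures below are illustrative):

```python
import numpy as np

# Sketch of the depth/curvature trade-off in equation (6) for two 1D basins
# of a hypothetical effective potential: basin 0 is deep/narrow, basin 1 wide.
v = np.array([0.0, 0.2])     # depths v*_k of the effective potential
h = np.array([50.0, 2.0])    # curvatures H*_k (basin 1 is much wider)

def weights(T):
    w = np.exp(-v / T) / np.sqrt(h)   # unnormalized Laplace weights, eq. (6)
    return w / w.sum()

for T in [0.01, 0.1, 0.5]:
    print(T, weights(T))
# As T grows, weight moves from the deep, narrow basin 0 to the wide basin 1.
```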
Third, to approximate the expected SGD test performance, we specialize (4) to the Gaussian mixture approximation of the SGD steady-state ρ in (5) to obtain:

E_ρ[L_te(θ)] ≈ Σ_k w_k (L*_te,k + ½ s_kᵀ H*_te,k s_k) + O(T),  (7)

with weights w_k given by (6).
The terms in (4) involving bias b_k and covariance C_k are both O(T), as detailed in section 4. Equation (7) shows that for SGD with small T and non-negligible shifts s_k, our newly-identified shift-curvature term may be the most significant way that curvature affects the test loss. In summary, we find that SGD temperature mediates a trade-off between depth and curvature of the effective potential, so that as temperature increases, greater weight w_k is placed on lower-curvature basins. These lower-curvature basins tend to have smaller values of shift-curvature, s_kᵀ H*_te,k s_k, which can reduce the expected test loss overall 5 . We illustrate this with a conceptual example in the next section.
Our experiments in figures 1–3 demonstrate the significant contribution of shift-curvature to test performance, support the inverse dependence of curvature on SGD temperature, and offer additional empirical insights about the test local minima, curvature, and shift. For figures 1–3, we trained VGG16 (Simonyan and Zisserman 2014) and Resnet10 (He et al 2016) networks, both without batch normalization, on CIFAR10 (Krizhevsky et al 2009). We also performed the same experiment and obtained qualitatively similar results for Resnet20 with batch normalization on CIFAR10, and Resnet18 with batch normalization on a subset of ImageNet (Deng et al 2009) (figures 5–7 in appendix G.1). The experiment is designed as follows. We repeated the following procedure over a range of temperatures for the initial SGD run, and over several seeds for each temperature. For each temperature and seed, we first trained our model using SGD with the specified temperature to get close to local minima of train (note that SGD typically finds different train local minima for different temperatures). Next, we initialized from the SGD solution and kept training at extremely low temperature on the train set to get even closer to the train local minimum. Finally, we again initialized from the SGD solution and trained at extremely low temperature, but this time on the test set, to get as close as possible to the test local minimum. Having found the train and test local minima, we linearly interpolated between them in order to study the train and test accuracy and loss along the line connecting their local minima (as shown in figure 1 for two different initial SGD temperatures; note the higher curvature at lower SGD temperature). Finally, we estimated the curvature of the loss functions along this line by reflecting them about their local minima in the direction from θ*_tr,k toward θ*_te,k, and fitting a quadratic centered at the local minimum (as shown in figure 1, right pane). (We focus on the line connecting the local minima because all quantities in (3) with b = 0 lie along this line. Letting θ(α) = θ*_tr,k + α s_k denote the line, with θ(0) = θ*_tr,k and θ(1) = θ*_te,k, the curvature of the fitted quadratic corresponds to s_kᵀ H s_k, as shown in (31).) Although we could compute the full Hessian matrix, in these experiments we are mainly interested in the curvature along the line, which can be computed much more cheaply via the 1D projection; further, because of the shapes of the losses we find that the symmetrized quadratic fit better captures the relevant overall curvature of the basin, as discussed in the appendix.
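The interpolation and symmetrized quadratic fit can be sketched as follows, with a stand-in quadratic loss and stand-in minima in place of the trained networks (all names and constants here are hypothetical):

```python
import numpy as np

# Sketch of the line-interpolation and symmetrized quadratic-fit procedure.
# loss_fn and the two endpoints are stand-ins, not the paper's trained models.
def loss_fn(theta):
    return 0.3 + 0.5 * 3.0 * np.sum((theta - 1.0) ** 2)

theta_train = np.zeros(3)            # stand-in train local minimum
theta_test = np.ones(3)              # stand-in test local minimum

alphas = np.linspace(-1.0, 2.0, 61)  # extend past both minima
line = [theta_train + a * (theta_test - theta_train) for a in alphas]
losses = np.array([loss_fn(th) for th in line])

# Symmetrized quadratic fit about the test minimum (alpha = 1): reflect the
# loss about alpha = 1, then fit a quadratic to the reflected data.
d = alphas - 1.0
coef = np.polyfit(np.concatenate([d, -d]),
                  np.concatenate([losses, losses]), 2)
curvature_along_line = 2.0 * coef[0]   # second derivative of the fit
print(curvature_along_line)
```

For this stand-in loss the curvature along the line equals the directional second derivative s ᵀ H s with s = θ_test − θ_train.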
3.1. A simple two-basin example
To illustrate the impact of shift and curvature on test loss and the temperature-controlled depth/curvature trade-off, we construct a simplified synthetic optimization problem where p = 1 and the train and test losses each have two local minima (the data is implicit because we directly define the train and test losses as functions of θ). We construct the train and test losses to be identical except for a constant shift s (i.e. L_te(θ) = L_tr(θ − s)). We let the noise be a constant, D = c, as well, so the bias is zero. Thus (5) becomes ρ(θ) ∝ exp(−2L_tr(θ)/(cT)), the basin weights become w_k ∝ exp(−2L*_tr,k/(cT))/√(H*_tr,k), and (7) becomes:

E_ρ[L_te(θ)] ≈ Σ_k w_k (L*_tr,k + ½ s² H*_tr,k) + O(T),  (8)
since L*_te,k = L*_tr,k and H*_te,k = H*_tr,k. In figure 4, the losses each have two local minima with different curvature (hence different shift-curvature, since the shift is constant). On the left, the minimum at θ = 0 (let us assign it k = 0) is deeper but narrower than the minimum at θ = 1 (k = 1), so that L*_tr,0 < L*_tr,1 and H*_tr,0 > H*_tr,1. At T = 0, ρ is a Dirac delta at θ = 0 and w_0 = 1. As T increases, w_1 increases, and we get a mixture between the two basins. As w_1 increases, the depth term worsens (Σ_k w_k L*_tr,k increases), while the shift-curvature term improves (½ Σ_k w_k s² H*_tr,k decreases). For larger temperatures, the variance and/or bias terms in the loss become large enough to worsen test performance again, so there is a non-zero and finite optimal temperature that minimizes the expected test loss. On the right, we have a similar setup, except that the deeper basin is also wider. In this case, we see no improvement in test loss with increasing temperature, since there is no trade-off to explore between depth and shift-curvature. Appendix G.3 includes similar one-basin and three-basin cases. In the one-basin case, increasing the temperature can only worsen the expected test loss, since there is no trade-off to explore, and increasing the SGD variance can only hurt both the train and test losses on average. The three-basin case is qualitatively similar to the two-basin case, but also shows the preference for lower curvature when the two basins are equally deep.
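A numerical version of this two-basin picture reproduces the non-monotonic dependence of expected test loss on temperature. The losses below are illustrative stand-ins for figure 4 (not its actual functions), and the constant 2/c is absorbed into T:

```python
import numpy as np

# 1D double-well train loss; test loss is a constant shift of it; steady
# state rho ~ exp(-L_tr / T) with 2/c absorbed into T. Constants illustrative.
theta = np.linspace(-1.0, 2.5, 4001)
s = 0.15                                           # train/test shift
# Basin 0 (deep, narrow) near 0; basin 1 (shallower, wide) near 1.5.
L_tr = np.minimum(40 * theta**2 - 0.3, 1.5 * (theta - 1.5)**2)
L_te = np.minimum(40 * (theta - s)**2 - 0.3, 1.5 * (theta - 1.5 - s)**2)

def expected_test_loss(T):
    w = np.exp(-(L_tr - L_tr.min()) / T)   # steady-state weights on the grid
    w /= w.sum()
    return (w * L_te).sum()

for T in [0.01, 0.2, 1.0]:
    print(T, round(expected_test_loss(T), 4))
# Moderate T beats both very low T (stuck in the deep, narrow, high
# shift-curvature basin) and very high T (variance dominates).
```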
4. Stochastic gradient descent
In this section we derive equations (5)–(7). Recall that SGD attempts to minimize the training loss through a discrete-time stochastic process with state θ_t, where t indexes the number of SGD updates:

θ_{t+1} = θ_t − (λ/B) Σ_{i=1}^{B} ∇_θ ℓ(x_i, θ_t),  (9)
where λ is the learning rate, B is the minibatch size, and the xi
are i.i.d. samples from the train set that comprise the minibatch. Letting g(θ) = ∇L_tr(θ), we note that the minibatch gradient has mean g(θ) and covariance D(θ)/B, where D(θ) is the covariance of the per-sample gradients over the train set — we refer to the latter as the SGD noise, and think of the temperature T = λ/B (by analogy with statistical mechanics) as a good summary of noise strength. Typical applications use λ ≪ 1 and B ⩾ 1, so that T ≪ 1, since running SGD with too high a learning rate or temperature results in unstable dynamics (especially early in a run) that often fail to converge, see, e.g., Goyal et al (2017). So we generally assume that λ ≪ 1 and T ≪ 1. As shown by Li et al (2017b) and our complementary derivation in appendix A.1, the probability density ρ(θ, t) of the SGD iterates approximately obeys the Fokker–Planck equation:

∂ρ/∂t = −∇·J, with J(θ, t) = −λ [ ρ ∇L_tr + (T/2) ∇·(D ρ) ],  (10)

where J
is called the probability current. Redefining training time through the change of coordinates t′ = λt has the only effect of removing λ from (10), i.e. yielding J = −[ρ ∇L_tr + (T/2) ∇·(D ρ)] and leaving T as the only equation parameter. (Explicitly, rescaling the temporal axis of SGD runs so that λ₁ n₁ = λ₂ n₂, where λ₁, λ₂ are the learning rates and n₁, n₂ are the number of SGD updates for two different runs, should produce similar results as long as the diffusion approximation holds; this can be a useful rule of thumb for adjusting the number of SGD updates as the learning rate is changed.) The fact that T is the only remaining parameter after rescaling time justifies the notion that T is a good summary of SGD noise strength, and means that different SGD runs with the same value of T but different learning rates and minibatch sizes still produce approximately the same steady-state parameter distributions. The dependence of the SGD distribution on T is in agreement with theory given in Jastrzkebski et al (2017), and the empirical results of Goyal et al (2017), He et al (2019a). We also confirm the T dependence in figure 8 (appendix G.2), which shows similar test performance for different minibatch sizes and learning rates as long as T is kept fixed. Finally, Goyal et al (2017) notes that the simple dependence of SGD on T breaks down when the learning rate is sufficiently large even when T is small; a situation where we expect the diffusion approximation to be inaccurate.
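The T = λ/B rule can be checked on a one-dimensional quadratic toy problem: two (learning rate, batch size) pairs with equal T should produce a similar stationary spread of the iterates (the curvature and noise constants below are illustrative):

```python
import numpy as np

# Two SGD runs on a 1D quadratic loss (h/2) * theta^2 with per-sample
# gradient noise of variance sigma2: same T = lr / batch, different lr.
rng = np.random.default_rng(1)
h, sigma2 = 1.0, 1.0

def run_sgd(lr, batch, steps=200_000):
    noise = rng.normal(0.0, np.sqrt(sigma2 / batch), size=steps)
    theta = 0.0
    hist = np.empty(steps)
    for t in range(steps):
        theta -= lr * (h * theta + noise[t])   # minibatch gradient step
        hist[t] = theta
    return hist[steps // 2:].var()             # discard burn-in

v1 = run_sgd(lr=0.01, batch=10)   # T = 0.001
v2 = run_sgd(lr=0.05, batch=50)   # same T = 0.001
print(v1, v2)                     # both near T * sigma2 / (2 h) = 5e-4
```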
With the change of variables just described, (10) is equivalent to the Langevin (stochastic differential) equation:

dθ = −∇L_tr(θ) dt + √(T D(θ)) dW(t),  (11)
that appears in Li et al (2017b) and Chaudhari and Soatto (2018) as an approximation of SGD. Fokker–Planck and Langevin equations describe the same underlying continuous-time diffusion process; see, e.g., Gardiner (2009). Li et al (2017b) showed that (10) is an order-1 weak approximation of SGD. In appendix A.1 we offer a complementary informal constructive derivation that shows how to approximate any discrete-time Markov process by a continuous-time diffusion process through the truncation of the infinite Kramers–Moyal expansion, which may be useful for analyzing other algorithms as well. Our approach essentially relies on approximating each SGD step as a sum of i.i.d. infinitesimal increments, and validates a conjecture in Bazant (2005) that the moments of the continuous-time diffusion approximation are proportional to the cumulants of the discrete-time process 6 . For the rest of this paper, we assume that (10) is a good approximation to SGD, and study the resulting stationary distribution. Next, we assume that the process described by (10) is ergodic, and seek its unique steady-state solution, i.e. the distribution ρ such that ∂ρ/∂t = 0. We can follow Gardiner (2009) and seek a solution where J = 0. Appendix B.1 shows that some algebra starting from J = 0, assuming that the integrand of the line integral in (5) is a gradient, yields the steady-state distribution in (5). This approach needs an invertible D(θ), which is only possible in the underparametrized case. Section 5 describes what happens in the overparametrized case.
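In one dimension the zero-current construction is easy to visualize: υ is the antiderivative of L′_tr/D (the 1D version of the line integral in (5), up to the constant factor), and a non-constant D re-ranks the basins of the train loss. The double well and noise profile below are illustrative:

```python
import numpy as np

# 1D effective potential: v(theta) = integral of L'(theta) / D(theta).
# Two equally deep wells of L_tr; gradient noise D is larger for theta > 0,
# so the noisy well becomes shallower in v. Constants are illustrative.
theta = np.linspace(-2.0, 2.0, 4001)
dtheta = theta[1] - theta[0]

L = 0.25 * theta**4 - 0.5 * theta**2    # double well, equal depths at +/-1
dL = theta**3 - theta                   # L'(theta)
D = 1.0 + 2.0 * (theta > 0)             # noisier gradients for theta > 0

v = np.cumsum(dL / D) * dtheta          # numeric antiderivative of L'/D
v -= v.min()

i_left = np.argmin(np.abs(theta + 1.0))
i_right = np.argmin(np.abs(theta - 1.0))
print(L[i_left] - L[i_right])   # ~0: the wells of L_tr are equally deep
print(v[i_left], v[i_right])    # the high-noise well sits higher in v
```

Since exp(−υ/T) weights basins by υ rather than L_tr, the steady state concentrates on the low-noise well even though the train-loss wells are equally deep.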
To approximate the expected SGD test loss (7), we approximate the SGD steady-state distribution as a mixture of Gaussians (similar to Jastrzkebski et al 2017). Since (5) has the form ρ ∝ e^(−υ/T), it is reasonable to expect that, when T is small, ρ is multimodal with peaks at the local minima of the effective potential υ. So we make a Laplace approximation, i.e. we approximate υ by a quadratic function in a neighborhood of each local minimum to obtain a local Gaussian approximation. Combining the local approximations at every local minimum yields a weighted Gaussian mixture approximation of ρ, with weights given by (6), and bias and variance both O(T). Details are in appendix B.2. (Note that the bias is linear in T but nonzero, confirming a result in Chaudhari and Soatto (2018) that the critical points of υ differ linearly-in-T from those of L_tr.) Substituting these results into the approximate test loss (4) and keeping only leading terms in T results in the approximate SGD test performance in (7). The validity of the Gaussian mixture approximation of (5) depends on the local quadratic approximation of υ, which we explore in our experiments with quadratic fits to the train loss, to which υ is related.
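The Laplace approximation can be checked numerically in one dimension: for ρ ∝ exp(−υ/T) with a smooth local minimum, the local variance approaches T/υ″(θ*) as T → 0 (the potential below is an illustrative choice, not from the paper):

```python
import numpy as np

# Laplace-approximation check for rho ~ exp(-v/T) around a 1D minimum.
# v has minimum at 0 with v''(0) = 3; the quartic term is the anharmonicity.
theta = np.linspace(-1.0, 1.0, 20001)
v = 0.5 * 3.0 * theta**2 + theta**4

for T in [0.3, 0.1, 0.03]:
    rho = np.exp(-v / T)
    rho /= rho.sum()
    var = (rho * theta**2).sum()   # mean is 0 by symmetry
    print(T, var, T / 3.0)         # variance approaches T / v''(0) as T -> 0
```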
5. More parameters than datapoints
We now consider the overparametrized situation (p > N_tr and/or p > N_te, recalling that p denotes the number of model parameters and N_tr, N_te the number of train/test datapoints). As discussed in appendix D, the main complications are that the local minima are no longer strict, and that the gradient covariance D(θ) is rank-deficient with a varying nullspace.
In appendix D.4, we generalize (3) to the overparametrized case. We first add ℓ₂-regularization to the train loss, L_tr(θ) + (µ/2)‖θ‖², which makes the train local minima strict (so that the bias in (3) makes sense); however, the test loss is unregularized so its local minima are non-strict. We find that (3) still holds if we generalize the definition of the shift to:

s_k = P_k (θ*_te,k − θ*_tr,k),

where P_k is the projector onto the subspace spanned by the per-test-sample gradients.
In order to generalize equation (5), we modify SGD to include isotropic noise in addition to the ℓ₂-regularization of the train loss, as follows:

θ_{t+1} = θ_t − λ [ (1/B) Σ_{i=1}^{B} ∇_θ ℓ(x_i, θ_t) + µ θ_t ] + σ w_t,

where µ, σ > 0 are scalars, and where w_t is zero-mean Gaussian noise in ℝ^p with covariance equal to the identity, that is independent of everything else. The isotropic Gaussian noise ensures that the diffusion matrix is invertible, while ℓ₂-regularization controls the part of the drift in the nullspace of D(θ) but not in its orthogonal complement, as discussed in appendix D.3. The diffusion process approximation of the modified SGD is then exactly like the one for the underparametrized case after substituting the regularized gradient ∇L_tr(θ) + µθ in place of ∇L_tr(θ), and the full-rank covariance D(θ) + (σ²/(λT)) I in place of D(θ). With these substitutions, the steady-state distribution still has the form in equation (5), and its Gaussian mixture approximation still has the form in equation (7).
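A minimal sketch of one step of this modified SGD, with hypothetical values for the scalars and per-sample gradients passed in generically:

```python
import numpy as np

# Sketch of the modified SGD step of section 5: minibatch gradient plus
# l2-regularization (mu) and small isotropic Gaussian noise (sigma).
# mu and sigma are illustrative values, not taken from the paper.
rng = np.random.default_rng(0)

def modified_sgd_step(theta, grads, lr=0.1, batch=8, mu=1e-4, sigma=1e-3):
    """grads: (num_samples, p) array of per-sample gradients at theta."""
    idx = rng.choice(len(grads), size=batch, replace=False)
    g = grads[idx].mean(axis=0) + mu * theta          # regularized drift
    noise = sigma * rng.standard_normal(theta.shape)  # full-rank isotropic noise
    return theta - lr * g + noise

theta = np.ones(5)
grads = rng.standard_normal((100, 5))   # stand-in per-sample gradients
theta = modified_sgd_step(theta, grads)
print(theta.shape)
```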
6. Other related work
There are many works aimed at understanding generalization in the context of SGD, and the interplay between SGD noise, curvature, and test performance. In addition to the most relevant works already discussed: Ahn et al (2012), Mandt et al (2017), Smith and Le (2017), He et al (2019a) take a Bayesian perspective, He et al (2019b) study asymmetrical valleys, Belkin et al (2019) focus on model capacity, Martin and Mahoney (2018) apply random matrix theory, Corneanu et al (2020) propose persistent topology measures, Russo and Zou (2016), Xu and Raginsky (2017) provide information-theoretic bounds, Smith et al (2021) perform error analysis relative to gradient flow, Lee et al (2017), Khan et al (2019) connect to Gaussian processes, Wu et al (2019) analyze multiplicative noise, and Zhu et al (2018) study 'escaping efficiency'. He et al (2019b) show that SGD is biased toward the flatter side of asymmetrical loss valleys, and such a bias can improve generalization (offering additional insight into the bias term in (3) which we identify but do not focus on). Although Mandt et al (2017), Smith and Le (2017), He et al (2019a) make important connections between generalization and curvature, none provide an explicit expression for test loss in terms of curvature. Also, their analyses are focused around a single local minimum, rather than comparing different local minima as we do in this paper. Mandt et al (2017) show that SGD can perform approximate Bayesian inference by studying a diffusion approximation of SGD and its steady-state, but assuming a single, quadratic minimum, and constant isotropic noise; assumptions we relax in this paper. He et al (2019a) derive a PAC-Bayes bound on the generalization gap at a given local minimum, which shows that wider minima have a lower upper-bound on risk; in contrast, we make a direct approximation showing that wider minima have lower test loss.
They also show that the bound on the risk can increase as log(1/T) for a particular local minimum, provided the model is large relative to T and the curvature; this is quite different from our analysis, which involves analyzing the role of T in modulating the SGD steady-state distribution's preference for different local minima (and with no assumptions on model size). Smith and Le (2017) connect Bayesian evidence to curvature at a single local minimum, and empirically suggest evidence as a proxy for test loss (by contrast we connect test loss directly to curvature). They also show that the steady state of a Langevin equation with constant isotropic noise is proportional to evidence, and by analogy surmise that SGD noise drives the steady-state toward wider minima; we agree with their intuition, but we allow for non-constant non-isotropic noise, and use a Gaussian mixture approximation to directly show the preference for wider minima. Finally, there is a long history of using statistical mechanics techniques to study machine learning, as we do in this work. For example, Seung et al (1992) prove among many other results that the generalization gap is positive; however, like many works both before and after, they assume constant and isotropic noise. Our work motivates updating these earlier studies to the algorithms currently used in practice, e.g. by using the parameter distribution in (5).
7. Conclusion and discussion
This paper contributes to longstanding debates about whether curvature can predict generalization, and how and whether SGD minimizes curvature. First, we show that curvature harms test performance through three mechanisms: the shift-curvature, the bias-curvature, and the covariance-curvature; we believe the first two to be new. We address the concern of Dinh et al (2017) by showing that all three are reparametrization-invariant although curvature is not. Our main focus is the shift-curvature, or curvature along the line connecting train and test local minima; although the shift is unknown at training time, the shift-curvature still shows that any direction with high curvature can potentially harm test performance if it happens to align with the shift. Second, we derive a new and explicit SGD steady-state distribution showing that SGD optimizes an effective potential related to but different from train loss, and show that SGD noise mediates a trade-off between the depth and curvature of this effective potential. Our steady-state solution removes assumptions in earlier works of constant, isotropic gradient covariance, and treats the overparametrized case with care. Third, we combine our test performance analysis with the SGD steady state to show that for small SGD noise, shift-curvature may be the dominant mechanism. Our experiments validate our approximations of test loss, show that shift-curvature is indeed a major contributor to test performance, and show that SGD with higher noise chooses local minima with lower curvature and lower shift-curvature.
However, many avenues for future work remain. First of all, we have made substantial, but not complete, progress in the overparametrized case. One of the main challenges is directions of near-zero curvature (a concern raised by Draxler et al 2018, Sagun et al 2018, and others). On the positive side, overparametrization poses no problem for our test performance approximation (3) or shift-curvature concept (since zero-curvature directions simply have no effect on either local test performance or shift-curvature). Furthermore, our SGD results eliminate constant, isotropic assumptions that appear in earlier works, clarify the challenges in the overparametrized case (such as the rank-deficient gradient covariance with varying nullspace), and still obtain the steady-state solution (5) of a slightly modified problem. However, the Gaussian mixture approximation (6), which highlights SGD's low-curvature preference, is less clear in the presence of near-zero-curvature directions. Therefore, further work is needed to clarify how near-zero-curvature directions may affect the unmodified SGD steady state as well as its preference for low-curvature regions. (A possibly useful intuition is that, locally, SGD should force all near-zero-curvature directions to zero, making them essentially irrelevant; however, there are difficulties in making this precise, as discussed in the appendix.)
Acknowledgments
We would like to thank Luca Zappella and Santiago Akle for inspiring us to study this problem; Vimal Thilak for assistance with compute infrastructure; Josh Susskind, Moises Goldszmidt, and Omid Saremi for helpful discussion and suggestions on the draft; Rudolph van der Merwe and John Giannandrea for their support; and our PhD advisors Wing H Wong and George C Verghese for first introducing us to the techniques behind this work.
Data availability statement
The data generated and/or analyzed during the current study are not publicly available for legal/ethical reasons but are available from the corresponding author on reasonable request.
Appendix A.: Continuous-time approximation of SGD
A.1. Continuous-time approximation of a discrete-time continuous-space Markov process
Our goal in this section is to construct a continuous-time Markov process that approximates a discrete-time Markov process. Later we will specialize to SGD, but for now we consider any discrete-time stochastic process with state (where j indexes time), that evolves according to:
where is an arbitrary function of θj . The SGD update is a special case when Δ is given by (9). (Note the different notation from the main text: j rather than t is the discrete time index here, so that we can use t to represent continuous time.) We wish to construct a continuous-time Markov process , with probability distribution , that approximates θj in the following sense. We assume the discrete-time updates of θj occur every τ > 0 continuous-time units of , where τ is arbitrary. We want such that if then
We will construct via its Kramers–Moyal expansion, which is a differential equation that describes the evolution of the probability distribution of . Truncating the Kramers–Moyal expansion to second order yields a Fokker–Planck approximation of the discrete-time stochastic process (which we will later specialize to the case of SGD). The Kramers–Moyal equation depends on the moments of the infinitesimal increments of the continuous process (i.e. , for infinitesimally small), so our basic strategy is to construct these moments in such a way that the continuous-time update over time interval τ matches the discrete-time update (we will find that defining the continuous moments proportional to the discrete cumulants achieves this).
We state the final results here, and justify them in the next sections. First we need to introduce some notation. We adopt multi-index notation to streamline the exposition, so for example if then , where . Dropping the time index j from the discrete-time update Δ for now, we denote the moments and cumulants of Δ (which play an important role) by and , respectively, with (using multi-index notation for ). Next, we define the increment of the continuous-time process for any two times as the random variable (Note that the Markov property implies that the increment only depends on the time difference when the value of is known.)
With this notation, we can present our main results. In order to specify the process via the Kramers–Moyal expansion, we actually only need to define the probability distribution of when is infinitesimally small. We will show that we can (approximately) match the continuous-time update over time interval τ to the discrete-time update by defining the moments of as:
with (where denotes the p-dimensional non-negative integers). That is, the moments of the small-time increments of are directly proportional to the cumulants of the discrete-time process θj it seeks to approximate. (We will find that the resulting process approximates the discrete one well when does not change much in value within each τ time increment, so that the small time increments during that period are approximately independent and identically distributed.)
Then we will show that (setting τ = 1 for simplicity now) the following Kramers–Moyal (KM) expansion describes the evolution of the distribution of :
where for an arbitrary function (as is standard in multi-index notation), and denotes the p-dimensional positive integers. The KM expansion can be truncated to second order to obtain a Fokker–Planck approximation of the discrete-time stochastic process.
A.1.1. Moments of the continuous time process in terms of cumulants of the discrete time process
Our goal in this section is to justify equation (15). Let for some positive integer K, and define for and Recall that is the small time increment of our continuous process between and so that is the total change in our process over τ time units when breaking the temporal interval into K equal-sized increments. We want to determine the moments of so that matches the corresponding moment of the discrete SGD step . Assuming the increments are approximately i.i.d. within each time interval of length τ, we can approximate as a sum of i.i.d. random variables. With simplified notation, we model the problem as follows. Consider a sum where Xi are i.i.d. random variables (each in ). Suppose that we know the 'desired' limiting random variable along with its cumulants. We want to find the moments of the i.i.d. random variables X so that . We will use the following definitions. Let and denote the moments and cumulants of an arbitrary random variable. Let be the moment generating function of and the cumulant generating function, so that and . We will also use the identity:
To restate the problem more precisely in this notation, we want to find such that . Using the i.i.d. assumption and identity (17), we obtain:
so and since we want , we equate to find that so:
In order to achieve the desired limit, we then need the moments of the i.i.d. variables X to be equal to 1/K times the cumulants of the desired limiting distribution S. To explicitly connect this result back to the continuous approximation of SGD, we associate , , , and . Then translates to equation (15), as desired. Note that the i.i.d. assumption on the increments is an approximation that introduces error into (15). In the approximation, we assume that for , , . If τ is small and varies slowly, then this condition approximately holds.
A.1.2. The continuous Kramers–Moyal expansion
Our goal in this section is to obtain the Kramers–Moyal expansion (16). The approach is essentially to Taylor-expand the Chapman–Kolmogorov equation, take the limit as dt → 0, and use result (15) for the moments of the infinitesimal increments. The Markov assumption for implies the continuous-time Chapman–Kolmogorov equation:
where is the transition probability function, and Substituting , and using our definition of , we obtain:
Next we Taylor-expand the left-hand-side of (19) in Δ (recalling that in multi-index notation, the infinite Taylor expansion of is ):
Finally we take the limit as dt → 0 and use (15):
The result above is equation (16).
A.1.3. The continuous process approximates the discrete process
Now that we have defined via its Kramers–Moyal expansion (16), we need to check that it approximates SGD in the sense of equation (14) (we need to show that ). We seek to understand the approximation error in terms of the difference between and The error arises because relies on the approximation that small-time increments of are i.i.d. We assume that and then study starting from the Chapman–Kolmogorov equation (18):
The two integrals above can be Taylor-expanded in the same way as in our derivation of the Kramers–Moyal expansion to yield:
where we define the error in the γ-th moment to be So, i.e.
Recall that the errors are due to the i.i.d. assumption on the infinitesimal increments in the continuous approximation; that is, the assumption that the distribution of is approximately constant in an interval of length τ. Therefore, if varies slowly relative to the timescale τ, the approximation will be close.
A.2. Continuous-time diffusion approximation of SGD
We now choose Δ to be as in equation (9) to study SGD. Our goal here is to approximate SGD with a continuous-time process. Because cumulants of independent random variables are additive, and letting be the γth cumulant of (i.e. the gradient evaluated at a single sample of x), we find that the cumulant for the minibatch of size B is:
So the KM expansion (16) of the continuous-time SGD approximation then becomes:
Clearly, then, as T gets small, fewer terms in the expansion matter. The first two cumulants for SGD are the mean and variance of the gradients of U over the training distribution, respectively:
When T is small enough, or when the cumulants of order three and higher are small (for example, when the gradient distribution is Gaussian, all cumulants above second order vanish), we can approximate the KM expansion (23) by the Fokker–Planck (FP) equation that retains only the first two terms in the expansion. Switching to the notation of the main text:
(so is now the empirical covariance matrix of the gradients over the training set, also called the diffusion matrix, and is the train loss), we obtain:
Rewriting using the divergence operator , we obtain the Fokker–Planck SGD equation stated in the main text (10).
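As a numerical sanity check of the cumulant scaling behind this derivation, the sketch below draws synthetic per-sample gradients (all sizes and distributions here are hypothetical, not from the paper's experiments) and verifies that the covariance of minibatch-averaged gradients is the per-sample covariance divided by B:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-sample gradients: N samples, p parameters, with correlated
# coordinates so the covariance is not diagonal.
N, p, B = 100_000, 3, 10
A = rng.normal(size=(p, p))
per_sample_grads = rng.normal(size=(N, p)) @ A.T

# Covariance (second cumulant) of a single per-sample gradient.
cov_single = np.cov(per_sample_grads, rowvar=False)

# Minibatch gradients: averages over non-overlapping batches of size B.
minibatch_grads = per_sample_grads.reshape(N // B, B, p).mean(axis=1)
cov_minibatch = np.cov(minibatch_grads, rowvar=False)

# Cumulants of an average of B i.i.d. terms scale as 1/B^(|gamma|-1);
# for the second cumulant this is the familiar Sigma / B.
print(np.abs(cov_minibatch - cov_single / B).max())
```

Higher cumulants are suppressed even faster (as 1/B^2 for the third cumulant, and so on), which is why large batches push the minibatch gradient toward Gaussianity and make the FP truncation more accurate.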
A.2.1. When is the SGD diffusion approximation accurate?
The FP approximation of SGD follows from two approximations. First, we need the increments within a time interval of τ (corresponding to a single SGD update) to be approximately i.i.d., so that the KM expansion (16) is an accurate approximation of the discrete-time SGD process. This condition holds when the product of the expected change in θ during a single update and the derivative w.r.t. θ of the density of any small increment is small (footnote 7). The second condition required for the FP equation to hold is that terms of order in the KM expansion must be small (relative to the terms), so that the truncation that yields the FP equation is appropriate. The latter is satisfied when T is small, and/or the third and higher cumulants of the gradients (i.e. for ) are small. Finally, our steady-state analysis in the next section relies on the distribution having actually reached steady state; the number of SGD steps required to reach steady state scales inversely with temperature, so in practice with very small temperature the steady state could be difficult to attain. Therefore, practically speaking, we expect the SGD diffusion approximation to start to break down for large T (i.e. a large ratio of LR to batch size), or for fixed small T at extreme (very small or very large) learning rates or batch sizes. A large learning rate makes the expected change in θ in a single update large, potentially violating the i.i.d. assumption on the small time increments, while a large batch size at small fixed temperature also implies a large learning rate, and has the same effect. We similarly expect the mean change in an SGD update to be large at the beginning of an SGD run, and our FP approximation to be invalid during some initial transient period. Lastly, the i.i.d. assumption can be violated when the derivative of the mean and covariance of a single SGD update with respect to θ is large.
Appendix B.: Steady state of the SGD diffusion approximation
B.1. Steady state distribution
Our goal in this section is to show that (5) is the steady-state solution of our SGD diffusion approximation (10), and to clarify the assumptions under which this is true. Our solution of the Fokker–Planck equation is identical to one presented in Gardiner (2009) section 6.2. We consider the underparametrized case (section 5 extends to the overparametrized case), so that is invertible and the SGD drift term spans even without regularization. We assume the process is ergodic, and therefore has a single steady-state distribution. To find it, we first set (10) to zero and obtain the steady-state condition . A distribution where J is constant, if one can be found, satisfies this condition. We also require that and as , so we attempt to find a solution with J = 0 everywhere. Some algebra then yields the steady-state solution:
This solution relies on a line integral to define , which must be path-independent, and therefore needs the following assumption: Assumption 1.
The drift term is a gradient; that is, its curl vanishes.
In other words, since is defined as , the first equation above can only be satisfied if is a gradient; a necessary and sufficient condition for this is the vanishing of the curl or so-called potential conditions where denotes the ith entry of . When these conditions hold, the Hessian of is symmetric, as it should be. So we assume that and are such that the assumption above holds.
B.1.1. Examples where assumption 1 holds
Assumption 1 clearly holds in the case where is constant and isotropic (i.e. for positive constant c), since then , which is the gradient of . But it can also hold for nonconstant, anisotropic as well, as the following examples confirm.
- (a)Isotropic but non-constant noise: consider so that . Then is a gradient, and we have , and This argument works more generally: let be any smooth scalar function, and let where (the first example is a special case with ). Then the chain rule implies that Similarly, so and
- (b)Anisotropic but constant noise: Consider and suppose that is quadratic, so that is constant and . Then which shows that is a gradient. In this case, .
- (c)Non-constant and anisotropic noise: Let and suppose that , where for an arbitrary smooth scalar function Note that both Ui and di are functions only of θi . Then . Also, . So . Defining , we get . This example results in independent dynamics for each parameter, which is unlikely to hold for realistic models, but it does show that assumption 1 holds for more models than just those with constant isotropic noise. (As an additional example, note that in example (b), where , we obtain the same if we remove the quadratic/constant assumption but take the limit as T → 0 to drop the term.)
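Example (a) can also be checked numerically. The sketch below uses a hypothetical positive loss U with isotropic, non-constant noise D = U·I (the special case d = U), and verifies by finite differences that the drift D^{-1}∇U equals ∇log U, so that the effective potential is log U and the curl of the drift vanishes:

```python
import numpy as np

# Grid over a 2D parameter space.
h = 0.01
x = np.arange(-1.0, 1.0, h)
X, Y = np.meshgrid(x, x, indexing="ij")

# Hypothetical positive train loss U, with noise D = U * I (example (a) with d = U).
U = 1.0 + X**2 + 2.0 * Y**2 + (X * Y) ** 2

# Drift field F = D^{-1} grad U = grad U / U.
Ux, Uy = np.gradient(U, h, h)
Fx, Fy = Ux / U, Uy / U

# F should be the gradient of the effective potential v = log U ...
Lx, Ly = np.gradient(np.log(U), h, h)
err_grad = max(np.abs((Fx - Lx)[2:-2, 2:-2]).max(),
               np.abs((Fy - Ly)[2:-2, 2:-2]).max())

# ... and therefore curl-free (assumption 1): dFy/dx - dFx/dy ~ 0.
curl = np.gradient(Fy, h, axis=0) - np.gradient(Fx, h, axis=1)
err_curl = np.abs(curl[2:-2, 2:-2]).max()

print(err_grad, err_curl)
```

Both errors are at the level of the finite-differencing truncation, consistent with the drift being an exact gradient in this example.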
Further work is necessary to characterize all the situations for which SGD satisfies assumption 1. Chaudhari and Soatto (2018) state in effect that J ≠ 0 whenever is nonconstant and/or anisotropic, so the examples above are counterexamples since they all satisfy J = 0 (a consequence of assumption 1).
B.2. Gaussian mixture approximation
We further approximate the steady-state distribution of SGD by a mixture of Gaussians, i.e. as a distribution of the form So we need to find the means µk , covariances , and weights wk . In our derivation, we assume the underparametrized case, where for example the gradient covariance and Hessian are full rank. However, the arguments are similar in the overparametrized case using the modified SGD (13). From (5), the first two derivatives of υ are:
If µ is any local extremum of υ, then:
where the first term of drops out since is positive definite, so ). Now, let be a local minimum of , and µk a nearby local minimum of υ. To find the approximate bias , we can expand the equation above:
since So,
For the covariance, we can make a Laplace approximation to centered at µk :
Therefore, is also O(T) (and inversely proportional to the curvature of the effective potential). The Laplace approximation also allows us to find the weights. We use the eigendecomposition of to approximate the integral of over the basin :
This is equation (6) in the main text. It tells us that the basin weights depend on the depth and width of the basin in terms of the effective potential.
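The basin-weight formula can be illustrated numerically in one dimension. The sketch below uses a hypothetical double-well effective potential and compares the exact probability mass of each basin under ρ ∝ exp(−υ/T) against the Laplace-approximation weights exp(−υ(µk)/T)/√(υ''(µk)):

```python
import numpy as np

# Hypothetical 1D effective potential with two basins of different depth and curvature.
def v(x):
    return (x**2 - 1.0) ** 2 + 0.05 * x

T, h = 0.02, 1e-3
x = np.arange(-3.0, 3.0, h)
rho = np.exp(-v(x) / T)
rho /= rho.sum() * h                                     # normalized rho ~ exp(-v/T)

# "Exact" basin weights: probability mass on either side of the barrier (near x = 0).
mass = np.array([rho[x < 0].sum() * h, rho[x >= 0].sum() * h])

# Laplace approximation: w_k proportional to exp(-v(mu_k)/T) / sqrt(v''(mu_k)).
w = []
for side in (x < 0, x >= 0):
    mu = x[side][np.argmin(v(x[side]))]                  # basin minimum on the grid
    vpp = (v(mu + h) - 2.0 * v(mu) + v(mu - h)) / h**2   # curvature at the minimum
    w.append(np.exp(-v(mu) / T) / np.sqrt(vpp))
w = np.array(w) / np.sum(w)

print(mass, w)
```

For small T the two agree closely, and the deeper, flatter basin carries almost all of the mass, matching the depth/width trade-off described above.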
Appendix C.: Reparametrization invariance
C.1. Expected test loss under any parameter distribution is reparametrization-invariant
Let denote any loss function (for example, the test loss, or the train loss with or without regularization), and let be any distribution over the model parameters. Consider a reparametrization of θ (with ), where is invertible. For an arbitrary function , we define the reparametrized version by . With this notation, we want to show that is reparametrization-invariant.
Let denote the p.d.f. of , and note that (using a general formula for invertible functions of random variables) . Then:
Therefore is reparametrization-invariant.
In particular, this means that is reparametrization-invariant, where is the test loss and ρ is the SGD steady-state distribution.
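This change-of-variables identity is easy to check by Monte Carlo. In the sketch below (with a hypothetical Gaussian parameter distribution, loss L(θ) = θ², and reparametrization f = sinh), the expected loss computed in either parametrization agrees:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical parameter distribution rho and test loss L.
theta = rng.normal(1.0, 0.5, size=200_000)
L = lambda t: t**2

# Invertible reparametrization f, and the reparametrized loss L~(eta) = L(f^{-1}(eta)).
f, f_inv = np.sinh, np.arcsinh
eta = f(theta)                      # samples from the pushforward distribution
L_tilde = lambda e: L(f_inv(e))

# The expected loss is the same in either parametrization.
print(L(theta).mean(), L_tilde(eta).mean())
```

The two sample means coincide because each pushforward sample η = f(θ) evaluates to exactly the same loss value as its preimage θ.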
C.2. Taylor approximation is reparametrization-invariant
Using equations (3) and (4), we obtain the following approximation for the expected test loss:
Our goal in this section is to show that this expression is approximately reparametrization-invariant for any distribution such that bk and sk are both small (which holds for SGD because we are already assuming that sk is small, and the SGD steady-state distribution satisfies ). As before, let be an invertible reparametrization. Defining , we need to show that the terms appearing on the right-hand-side of (26) are (approximately) reparametrization-invariant.
We can show that the weights wk are reparametrization-invariant using a similar argument to :
Now we want to show that the remaining terms on the right-hand side of (26) are approximately reparametrization-invariant. First, since the test loss is a scalar function, it is equal to its reparametrization, i.e.:
So is invariant. To help with the other terms, note that the first two derivatives of any scalar function are:
where , , , , . We use these to find:
Since we have seen that each of the terms appearing in (26) (i.e. wk , , etc) is reparametrization-invariant, we conclude that the whole expression for the expected test loss is reparametrization-invariant.
Appendix D.: The overparametrized case
D.1. Some intuition
One of the important properties of the overparametrized case (more model parameters than training samples) is that the loss landscape necessarily has zero- or near-zero-curvature directions, since in the absence of -regularization there can be no more curved directions than there are training samples, and -regularization adds a small amount of curvature in all directions. A question, then, is how zero-curvature directions affect the conclusions in this paper. As a rough intuition, we argue that zero-curvature directions may have little impact on test performance, and may be deterministically forced to zero by SGD at steady state in the presence of any regularization, making them essentially irrelevant in the context of our analysis. Consider a simplified case where the set of nonzero-curvature directions is constant (rather than depending on θ, as it does in general) and the same for train and test. Suppose that the train loss has a small amount of -regularization (as is typical in practice) but the test loss does not. In this case, the train local minima are strict (single points) but the test local minima are subspaces that include all the zero-curvature directions. The Taylor approximation (3) is identical for any point in the test local-minimum subspace; that is, only the nonzero-curvature directions impact test performance or its shift-curvature approximation. Furthermore, SGD converges to a solution where the parameters in the zero-curvature directions are all equal to zero (as there is no diffusion in these directions and the drift term is , where α > 0 is the regularization parameter). Therefore, the zero-curvature directions are essentially irrelevant in the context of generalization and SGD.
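The claim that regularization drives the flat directions to zero can be seen in a minimal sketch (a hypothetical two-parameter loss that is flat in the second coordinate, plus regularization with strength α = 0.01):

```python
import numpy as np

# Toy overparametrized loss: curvature only in theta[0]; theta[1] is a flat
# (zero-curvature) direction. alpha is the regularization strength.
def grad_U(theta, alpha=0.01):
    g = np.array([2.0 * (theta[0] - 2.0), 0.0])  # data-fit gradient, flat in theta[1]
    return g + alpha * theta                      # regularization drift alpha * theta

theta = np.array([5.0, 3.0])
lr = 0.1
for _ in range(20_000):
    theta = theta - lr * grad_U(theta)

# The flat direction decays to zero; the curved one converges near its optimum.
print(theta)  # approximately [1.99, 0.0]
```

In the flat direction the only dynamics are the exponential decay induced by the −αθ drift, while the curved direction settles at its (slightly regularization-shifted) minimum.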
Of course, the situation becomes more complex when we allow the nonzero-curvature directions to vary with θ and differ for train and test. Since (3) is a local approximation, the argument above still holds, and we find that the local zero-curvature directions do not affect the test performance near a train local minimum. However, (5) is more complicated, because it depends on the global landscape, where the set of nonzero-curvature directions can vary (also, the gradient covariance only has rank N − 1, causing technical difficulties), which is why we require the problem modifications discussed in section 5. In addition to -regularization on the train loss, we also need to add a small amount (β) of noise to the diffusion matrix to make it invertible; this changes the problem so that SGD does not force the zero-curvature directions to zero as in the discussion above, but instead converges to an infinite-support distribution, with the (local) zero-curvature directions represented locally by Gaussians with mean zero and variance .
D.2. Nonzero-curvature subspaces
Our goal is to derive the results in section 5, but we first take a detour to describe in detail why and how overparametrization changes the problem. In the underparametrized case (more data samples N than model parameters p), the per-sample gradients typically span all of , while in the overparametrized case (p > N), the per-sample gradients span a subspace of dimension no larger than N. To work out the many consequences of this, we introduce some more compact notation. We consider a typical loss function . The test loss and the train loss without -regularization in (1) both have this form. Let be a shorthand for the N per-sample gradients, and let denote their span. Thus is the nonzero-curvature subspace at θ. Let denote the complement subspace orthogonal to . To simplify the exposition, we assume that all per-sample gradients are linearly independent, so that the dimension of is . Thus, if , then and is empty; but if p > N, then has dimension N and has dimension p − N. Because the gradients are generally functions of θ, the spaces and can change with θ.
As shown at the end of this section, the average gradient is in , the range of the gradient covariance is contained in but has one less dimension, and the range of the Hessian is equal to provided is constant in a neighborhood around θ. In particular, the gradient covariance and Hessian are both rank-deficient (hence non-invertible). All of these facts apply directly to the test loss, and to the part of the train loss excluding -regularization. Adding -regularization to the train loss changes the results as follows. The average gradient is , so that is no longer constrained to lie in and can be anywhere in . The gradient covariance is unchanged since αθ is not random; in particular, it is still rank-deficient. The Hessian becomes full-rank, because regularization adds α to every eigenvalue of the Hessian. These facts have important implications for the Taylor approximation (3) and the SGD steady-state (5) in the overparametrized case. First, recalling that a local minimum is strict when the Hessian is full-rank, we see that test local minima are non-strict, while train local minima are strict only if we add -regularization. Second, is non-invertible, invalidating (5), but we can resolve this by adding isotropic Gaussian noise to the SGD updates to make full rank, together with -regularization to control the part of the drift term in .
D.2.1. Average gradient, covariance, and Hessian are in
The average gradient is which clearly lies in the span of the per-sample gradients: The gradient covariance is We will show shortly that the range of is also contained in , but it is not equal to it, since has one more dimension. This means that has a nonempty nullspace containing and one extra direction, and hence is not invertible. To see that , note that (like any p × p sample covariance constructed from N data points) has rank N − 1 (since it can be written as a sum of N rank-1 matrices, , but only has dimension N − 1 when p > N). Therefore, when p > N, the nullspace of has dimension p − N + 1. It is also clear that the nullspace of contains (to see this, consider any vector , i.e. such that for all ; it then follows from the definition of that ). Therefore, the nullspace of contains all of (which has dimension p − N), plus one extra direction. Equivalently, the range of is contained in , but is one dimension smaller. Finally, the range of the Hessian is equal to if is constant in some open neighborhood of θ. To see this, note that for any , we have that , since for all θ in a neighborhood of by assumption. So the nullspace of contains . Also, assuming that the per-sample Hessians in the dataset are linearly independent, the range of the Hessian has dimension so it is equal to .
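These rank facts are easy to verify numerically. The sketch below builds hypothetical random per-sample gradients with p > N and checks that the gradients span an N-dimensional subspace, that the gradient covariance has rank N − 1, and that any direction orthogonal to all the gradients lies in its nullspace:

```python
import numpy as np

rng = np.random.default_rng(2)

# Overparametrized setting: p parameters, N < p per-sample gradients (hypothetical data).
p, N = 20, 5
G = rng.normal(size=(N, p))                 # rows: per-sample gradients at some theta

mean_grad = G.mean(axis=0)
centered = G - mean_grad
D = centered.T @ centered / N               # p x p gradient covariance

rank_G = np.linalg.matrix_rank(G)           # N: gradients span an N-dim subspace
rank_D = np.linalg.matrix_rank(D)           # N - 1: centering removes one direction

# Any vector orthogonal to all per-sample gradients lies in the nullspace of D.
q, _ = np.linalg.qr(G.T, mode="complete")   # columns N..p-1 are orthogonal to all g_i
v = q[:, N]
residual = np.linalg.norm(D @ v)

print(rank_G, rank_D, residual)
```

The extra lost dimension comes from mean subtraction: the N centered gradients sum to zero, so they span only an (N − 1)-dimensional subspace.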
D.3. Overparametrized SGD steady-state
In the overparametrized case, we use the modified SGD (13), which yields the diffusion approximation with the same FP description (10) but with modified probability current Recall that the diffusion matrix in the diffusion process approximation is the covariance of , which is now and is invertible when β ≠ 0. But the added isotropic noise creates a new problem: since lies in we now have diffusion noise but no drift in , creating the possibility of a distribution of parameters that spreads out forever as . To prevent this, we need drift in that keeps parameter values small, and α > 0 accomplishes this for us. So, with , the same approach given in appendix B.1 works to obtain the steady-state distribution, resulting in the same form after the appropriate redefinitions of the diffusion matrix (adding to it) and drift term (adding αθ to it). With β = 0, it is unclear how to find the steady-state distribution: for example, one might try replacing the inverse of -projected-onto- with its pseudoinverse, and including a delta function in ρ to enforce conditions implied by the ODE in the one-dimensional nullspace of -projected-onto-, but this choice turns out to yield a nonzero divergence of the probability current , hence it is not a steady-state solution.
D.4. Overparametrized Taylor approximation of test loss
If the local minima of train and test are strict (i.e. single points), then the Taylor expansion approximation of the test loss (3) goes through unchanged. As we saw in appendix D.2, the local minima of train are always strict, since we assume that the train loss includes nonzero -regularization (which makes the Hessian full-rank). However, the local minima of test are non-strict in the overparametrized case, since the Hessian is rank-deficient (unless we also add -regularization to the test loss, which is not usually done in practice). Nevertheless, we can still find a natural analog of the Taylor approximation in the overparametrized case. We assume that is constant in a neighborhood of , where . From appendix D, we know that , where , and has dimension N < p. Therefore (as shown at the end of this section) each test local minimum is a set:
The local minima of the train loss are strict due to the -regularization. (The train gradients are , which can only be zero when the component in is equal to zero.) Recall that in (3), we seek to approximate the test loss at a point found by a single trajectory of SGD. We know from the previous section that must have its projection on equal to zero, so it is also unique. Taylor-expanding the test loss about any in the test-loss local-minimum subspace, we obtain:
where and .
Since the range of is orthogonal to , the term only depends on the part of s(w) in , hence the Taylor approximation of is the same for any choice of w. We would get the same result for by defining for any w, but a natural choice (that eliminates the irrelevant-to-the-test-loss part of s in ) is (This choice of s(w) corresponds to , which is the solution to .) The overall result is:
We see that the Taylor approximation of the test loss at a train local minimum has the same form as in equation (3), but the definition of the shift takes a more general form: it is now the projection of the difference between the test and train local minima onto (the nonzero-curvature subspace of the test loss). This definition coincides with the underparametrized definition of sk when is all of .
D.4.1. Test local minima are given by (28)
Let be an orthonormal matrix such that is a basis for and is a basis for . So, i.e. the projector onto is Define Then
In particular, we have Therefore the unregularized train loss must be constant w.r.t. , so we can write (by picking ):
Appendix E.: Higher order curvature
The full Taylor-expansion of a function (for example, the train or test loss) about a local minimum θk is given by:
here we are using multi-index notation, so that
Note that the derivatives of order 3 and higher are tensors, and represents a tensor contraction.
In one dimension, a critical point θk of a function is a local minimum if:
This follows from Taylor's theorem. For example, if , then we can look at the next two derivatives: if , then θk is a saddle point; otherwise, it is a local minimum if and a local maximum if . If the 3rd and 4th derivatives are zero, we need to look at the 5th and 6th, and so on. In higher dimensions (), we can generalize to for all x (analogous to the definition of a positive-definite Hessian). But the condition for a local minimum becomes more complex since, for example, the second derivative could be positive semidefinite (with both positive and zero eigenvalues); in positive-eigenvalue directions we can be confident of a local minimum, but in the nullspace we must consider higher derivatives. However, we also know that a critical point θk is a local minimum of U if θk is a local minimum of the 1D function obtained by evaluating U along any line through θk . This characterization allows us to fall back to the simpler 1D characterization (30) along any line, and suffices for many of our purposes.
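The 1D characterization can be turned into a small routine: given the derivatives of order 2, 3, … at a critical point, find the first nonzero one; an odd order means a saddle, while an even order gives a minimum or maximum according to its sign (a sketch; the functions and derivative values in the examples are hypothetical):

```python
# Classify a 1D critical point from its higher derivatives, following the
# even-order condition in this section.
def classify_critical_point(derivs, tol=1e-12):
    """derivs[j] is the (j + 2)-th derivative of U at the critical point."""
    for order, d in enumerate(derivs, start=2):
        if abs(d) > tol:
            if order % 2 == 1:
                return "saddle"              # first nonzero derivative has odd order
            return "minimum" if d > 0 else "maximum"
    return "degenerate"                      # all supplied derivatives vanish

print(classify_critical_point([0.0, 0.0, 24.0]))  # U = x^4 at 0 -> minimum
print(classify_critical_point([0.0, 6.0]))        # U = x^3 at 0 -> saddle
print(classify_critical_point([-2.0]))            # U = -x^2 at 0 -> maximum
```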
In particular, in the Taylor expansion (3), taking to simplify the argument, we wish to approximate by expanding about , so we only care about the values of along the line between the train and test local minima. Defining:
(note that this Taylor expansion looks slightly unusual because it is centered at , the test local minimum, and evaluated at 0, the train local minimum, so that .) Since is a local minimum of , we know that has a local minimum at , so condition (30) must hold there. This means that the order J of the minimum nonzero derivative along the line must be even, and we must have:
This offers a way to characterize the curvature in terms of a higher-even-order derivative , even if the Hessian is zero along the line between the local minima. Higher curvature along this line corresponds to larger values of . We can also partially generalize the SGD steady-state results to the higher-order curvature case. However, we currently only know how to evaluate the necessary integrals in the special case where the local curvature of υ depends only on a single even derivative (for example, only a positive-definite 4th derivative), and where the even derivative tensor is diagonalizable (which is not the case in general for tensors of order 3 and higher). That is, we must assume that at its local minima υ can be approximated as:
where J is even, and is diagonalizable. Then,
To find the basin weights, we need to evaluate the integral:
To make more progress, we need to evaluate integrals of the form (here j is even but otherwise arbitrary). This is where we need the assumption that is diagonalizable, which always holds for symmetric tensors of order two with real entries (i.e. matrices), but does not hold in general for symmetric tensors of higher orders (Comon 1994). Specifically, we assume that can be written as:
where if r < n we add n − r additional columns to Q to complete an orthonormal basis for . (Note that the assumption that the qi 's are orthogonal implies that since more than n vectors in cannot be linearly independent, and also implies that for all i because is positive semidefinite (Qi 2005). If we take only the first r nonzero eigenvalues we can assume for all . In this case we have an analog of the determinant given by Qi (2005): Then:
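As a numerical sanity check on integrals of this kind, the 1D building block ∫ exp(−a x^J) dx over the real line has the closed form 2 Γ(1 + 1/J) a^(−1/J) for even J and a > 0, reducing to the usual Gaussian normalization at J = 2. The sketch below uses our own parametrization; the constants appearing in the basin weights depend on conventions not reproduced here:

```python
import math
from scipy.integrate import quad

def gen_gaussian_normalization(a, J):
    """Closed form for the integral of exp(-a * x**J) over the real line,
    valid for even J and a > 0:  2 * Gamma(1 + 1/J) * a**(-1/J).
    At J = 2 this reduces to the Gaussian result sqrt(pi / a)."""
    return 2.0 * math.gamma(1.0 + 1.0 / J) * a ** (-1.0 / J)

a = 0.7
for J in (2, 4, 6):
    # integrate over [-50, 50]; the integrand is negligible outside this range
    numeric, _ = quad(lambda x, J=J: math.exp(-a * x**J), -50, 50)
    closed = gen_gaussian_normalization(a, J)
    assert abs(numeric - closed) < 1e-8
    print(f"J={J}: {closed:.6f}")
```

Larger J flattens the integrand near the origin and sharpens its walls, which is the higher-order analog of lower curvature widening a quadratic basin.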
To see that agrees with the J = 2 case, note that
Appendix F.: Experiment notes
F.1. Shift-curvature symmetrized quadratic fits
Due to the shapes of the losses, we find that a symmetrized quadratic fit more accurately captures the relevant overall curvature of the basin. First, the loss curves tend to have a small 'flattened' region right around the local minima despite looking roughly quadratic overall (we suspect the flattening is due to sigmoid saturation), so the curvature evaluated exactly at a local minimum tends to underestimate the overall curvature of the basin, which is better captured by fitting a quadratic. Second, we find that the train and test losses have an asymmetric shape (He et al 2019b) along the line between the local minima. Hence we reflect the test loss about the test local minimum in the direction toward the train local minimum, since this is the direction that affects the test loss at the train local minimum; similarly, we reflect the train loss about its local minimum in the same direction (i.e. away from the test local minimum), since we are interested in train curvature as a proxy for test curvature and should therefore study the same direction of each. We believe the observed asymmetry of the losses is due to the experimental design: to find the test minimum, we run gradient descent starting from close to the train minimum, which means we descend the steepest direction of the test basin away from the train local minimum until we reach a test local minimum. If the test basin has a flat region at the bottom, continuing in the same direction may incur zero or very small test loss until we reach the opposite wall of the basin, whereas this direction has little impact on the test loss at the train local minimum.
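The reflect-and-fit procedure can be sketched in a few lines (a toy illustration with invented names, not the paper's code): keep only the branch of the loss on the relevant side of the minimum, reflect it about the minimum, and fit a quadratic.

```python
import numpy as np

def symmetrized_quadratic_curvature(ts, losses, t_min, side=+1):
    """Fit a quadratic to a 1D loss curve symmetrized about its minimum.

    ts:      offsets along the line between the two local minima
    losses:  loss evaluated at each offset
    t_min:   offset of the local minimum
    side:    which branch to keep (+1: t > t_min, -1: t < t_min); that
             branch is reflected about t_min before fitting, so only the
             curvature in the chosen direction enters the fit.
    Returns the fitted curvature (twice the quadratic coefficient).
    """
    keep = side * (ts - t_min) >= 0
    t_branch = ts[keep] - t_min
    l_branch = losses[keep]
    # reflect the kept branch to build a symmetric curve about t_min
    t_sym = np.concatenate([-t_branch, t_branch])
    l_sym = np.concatenate([l_branch, l_branch])
    coeffs = np.polyfit(t_sym, l_sym, 2)  # a*t^2 + b*t + c, highest degree first
    return 2.0 * coeffs[0]

# toy check: an asymmetric basin, steeper on the right than on the left
ts = np.linspace(-1.0, 1.0, 201)
losses = np.where(ts >= 0, 3.0 * ts**2, 0.5 * ts**2)
print(symmetrized_quadratic_curvature(ts, losses, 0.0, side=+1))  # ≈ 6.0
```

Choosing `side` implements the directional reflection described above: the two sides of this toy basin yield different fitted curvatures, just as the asymmetric train/test losses do.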
Appendix G.: Additional experiments
G.1. More shift-curvature experiments
G.2. Temperature and shift experiments
Figure 8 shows experiments on CIFAR10 (Krizhevsky et al 2009) to demonstrate the phenomenon of generalization improving with temperature, as well as to understand the shift between the training and test sets for CIFAR10. We used the VGG9 network (Simonyan and Zisserman 2014) and the training procedure of Li et al (2017a) (but with no momentum). For the temperature experiment, we first trained a network using a large batch size (4096) and a decreasing learning-rate schedule (starting from 0.1, scaled by 0.1 at 0.5 and 0.75 of total epochs, ending with LR 0.001) until the training converged. Then, we continued training from that initialization with a variety of batch sizes and learning rates (all LRs held constant in the second stage) for 1000 epochs (regardless of batch size). For each second-stage run, we took the median of the last 100 steps (to eliminate spiky outliers) as the 'final loss'. We repeated this two-stage experiment 10 times (with different initializations), and plotted the mean and standard deviation over all trials of the final train and test loss as a function of temperature (learning rate divided by batch size), ranging from 2.44 × 10−8 to 2.44 × 10−5. (At the end of the first stage, the final temperature was 2.44 × 10−7, the mean final train loss was 0.005 (var 4 × 10−7), and the mean test loss was 0.555 (var 1 × 10−4).) Our results show that both the train loss and test loss depend primarily on the temperature, and that test loss decreases with increasing temperature while train loss remains roughly constant. For small batches (50 and 100), train loss was higher at the beginning of the second stage for all LRs; train loss dropped significantly by the end of training with large LRs but remained high with small LRs. We believe this was due to mismatched batchnorm statistics and parameters (parameters learned for batch size 4096 but applied to batch size 50 or 100); with larger LRs the network was able to retrain the batchnorm parameters, but smaller LRs did not allow sufficient progress.
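The aggregation step just described (final loss as the median of the last 100 steps, then mean and standard deviation across trials at each temperature) can be sketched as follows, assuming a hypothetical per-run data layout of our own:

```python
import numpy as np

def final_loss_vs_temperature(runs):
    """Aggregate second-stage runs: per run, the 'final loss' is the median
    of the last 100 recorded steps; runs are grouped by temperature
    (learning rate / batch size) and summarized by mean and std.

    `runs` is a hypothetical layout for illustration: dicts with keys
    'lr', 'batch_size', and a per-step 'loss_history' list."""
    by_temp = {}
    for run in runs:
        temperature = run["lr"] / run["batch_size"]
        final = float(np.median(run["loss_history"][-100:]))
        by_temp.setdefault(temperature, []).append(final)
    return {t: (float(np.mean(v)), float(np.std(v))) for t, v in by_temp.items()}
```

The median over the tail of the loss history is what suppresses the spiky outliers mentioned above; the mean/std per temperature is then taken over independent trials.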
The shift experiments aimed to clarify whether the train/test shift in CIFAR10 is more likely due to distribution shift or to finite sampling. The training procedure was the same as in the initial stage of the temperature experiment. To study distribution shift, we ran a shuffling experiment, where we merged the train and test sets, reshuffled them to create a single dataset, and then randomly split the data into training and test at each trial (25 total). If there were a distribution shift between the default train and test split in the CIFAR data, this process would remove it, and we would expect the test loss to decrease. However, our experiment shows that there is no significant difference in the test loss for different shufflings of the train and test sets, suggesting that distribution shift is not present. To better understand the role of finite sampling, we ran a sample-size experiment, where for each trial we chose a subsample of the full dataset of a given size, split it proportionally into train and test, and trained on the training set. If finite sampling were causing the shift between the training and test distributions, we would expect smaller sample sizes to exacerbate the difference (since sampled distributions generally become closer to the underlying distribution as the sample size grows). Our experiment shows that the test loss increases as the sample size decreases, consistent with finite sampling as the source of distribution shift.
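The two dataset manipulations described above can be sketched as follows (minimal illustrative versions with invented names, not the paper's code):

```python
import numpy as np

def reshuffled_split(train_x, train_y, test_x, test_y, seed):
    """Shuffling experiment: merge the default train/test sets, reshuffle,
    and re-split with the original sizes, removing any distribution shift
    between the two splits."""
    x = np.concatenate([train_x, test_x])
    y = np.concatenate([train_y, test_y])
    perm = np.random.default_rng(seed).permutation(len(x))
    n_train = len(train_x)
    tr, te = perm[:n_train], perm[n_train:]
    return (x[tr], y[tr]), (x[te], y[te])

def proportional_subsample(x, y, sample_size, train_frac, seed):
    """Sample-size experiment: draw a subsample of the full dataset and
    split it proportionally into train/test."""
    idx = np.random.default_rng(seed).choice(len(x), size=sample_size, replace=False)
    n_train = int(round(train_frac * sample_size))
    tr, te = idx[:n_train], idx[n_train:]
    return (x[tr], y[tr]), (x[te], y[te])
```

In the shuffling experiment the total sample size is held fixed while the split boundary is randomized; in the sample-size experiment the train/test proportion is held fixed while the total size varies, which is what isolates the finite-sampling effect.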
G.3. One and three basin experiments
Figure 9 shows 1-basin and 3-basin experiments analogous to the 2-basin experiment in figure 4. In the 1-basin example, both train and test error increase with increasing temperature. The 3-basin example shows qualitatively similar behavior to the 2-basin case and confirms the SGD steady-state distribution's preference for wider minima when the depths are similar.
G.4. Experiment details and hyperparameters
G.4.1. Figures 1–3 and 5–7
- Resnet10 (without batch norm): Shon (2021) haiku implementation of Resnet10 (He et al 2016). Modified by reducing blocks per group from 3 to 2, and removing batch norm.
- Resnet20 (with batch norm): Shon (2021)'s haiku implementation of Resnet20.
- Resnet18 (with batch norm) for ImageNet: haiku implementation (Hennigan et al 2020).
- VGG16: Yoshida (2021) haiku implementation of VGG16 (Simonyan and Zisserman 2014). Adapted for CIFAR10 by reducing size of dense layers to 512 as proposed by Liu and Deng (2015) and removing adaptive pooling.
- Data: CIFAR10 (Krizhevsky et al 2009), ImageNet (Deng et al 2009)
- Training: adapted from Shon (2021)'s JAX sample code.
CIFAR10 experiment hyperparameters:

| Initial SGD | Find Train/Test Local Minima |
| --- | --- |
| BATCH_SIZE: 100 | TRAIN_SAMPLE_COUNT: 50000 |
| TRAIN_SAMPLE_COUNT: 50000 | BATCH_SIZE: 10000 (res), 5000 (vgg) |
| TRAIN_SHUFFLE: True | LEARNING_RATE: 1e-4 |
| TRAIN_DROP_LAST: True | MAX_EPOCH: 10000 |
| MAX_EPOCH: 400 | L2_REG: 1e-5 |
| MAX_EPOCH_REF_TEMP*: 0.00005 | L2_AT_EVAL: True |
| LEARNING_RATE: various | AUGMENT_TRAIN: False |
| L2_REG: 1e-5 | TRAIN_SHUFFLE: False |
| L2_AT_EVAL: False | TRAIN_DROP_LAST: False |
| AUGMENT_TRAIN: False | |
ImageNet experiment hyperparameters:

| Initial SGD | Find Train/Test Local Minima |
| --- | --- |
| BATCH_SIZE: 100 | TRAIN_SAMPLE_COUNT: 127149 |
| TRAIN_SAMPLE_COUNT: 127149 | BATCH_SIZE: 500 |
| TRAIN_SHUFFLE: True | LEARNING_RATE: 1e-4 |
| TRAIN_DROP_LAST: True | MAX_EPOCH: 300 |
| MAX_EPOCH: 500 | L2_REG: 1e-5 |
| MAX_EPOCH_REF_TEMP*: 1e-5 | L2_AT_EVAL: True |
| LEARNING_RATE: various | AUGMENT_TRAIN: False |
| L2_REG: 1e-5 | TRAIN_SHUFFLE: False |
| L2_AT_EVAL: False | TRAIN_DROP_LAST: False |
| AUGMENT_TRAIN: False | |
*We scale the number of epochs (E) with the temperature (T) so that E * T = MAX_EPOCH * MAX_EPOCH_REF_TEMP.
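Concretely, the epoch count at temperature T is chosen so that the product E · T stays fixed; a small illustrative helper (our own, using the CIFAR10 values from the table above):

```python
def scaled_epochs(temperature, max_epoch, ref_temp):
    """Keep E * T constant: E = MAX_EPOCH * MAX_EPOCH_REF_TEMP / T."""
    return int(round(max_epoch * ref_temp / temperature))

# CIFAR10 settings from the table above: MAX_EPOCH = 400, reference temperature 0.00005
print(scaled_epochs(0.00005, 400, 0.00005))  # 400 at the reference temperature
print(scaled_epochs(0.0001, 400, 0.00005))   # 200 at double the temperature
```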
G.4.2. Figures 4 and 9
The objective plots show the synthetic train and test objective (loss) functions and . For the steady-state distribution plots, we plug the train loss , the gradient variance , and the current temperature into (5), and evaluate as a function of θ. For the 'probability of each basin' plots, we integrate the steady-state distribution over the basin of each minimum (i.e. between the two local maxima adjacent to the minimum, or the range endpoint, noting that ρ → 0 at the range endpoints); the probabilities are plotted as a function of temperature. Finally, the train versus test loss plots show the train and test losses as a function of temperature, i.e. the train loss is computed as (where depends on the temperature per (5)).
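Since the synthetic objectives are 1D, the basin integration can be sketched directly. The snippet below is a simplified stand-in, not the paper's code: it replaces the steady state of (5) with a plain Gibbs density ρ ∝ exp(−U/T) (omitting the effective potential and gradient variance) and integrates it over each basin numerically.

```python
import numpy as np

def basin_probabilities(u, thetas, temperature):
    """Probability of each basin for a 1D objective sampled on a uniform grid.

    Simplified stand-in for the steady state: rho ∝ exp(-u / T). Basins are
    delimited by the interior local maxima of u (or the range endpoints),
    and the normalized density is integrated over each basin."""
    dx = thetas[1] - thetas[0]
    rho = np.exp(-(u - u.min()) / temperature)
    rho /= rho.sum() * dx
    # interior local maxima of u delimit the basins
    maxima = [i for i in range(1, len(u) - 1) if u[i] > u[i - 1] and u[i] > u[i + 1]]
    edges = [0] + maxima + [len(u)]
    return np.array([rho[a:b].sum() * dx for a, b in zip(edges[:-1], edges[1:])])

# two Gaussian wells: a wide shallow one at -1 and a narrow deeper one at +1
thetas = np.linspace(-3.0, 3.0, 2001)
u = -0.10 * np.exp(-((thetas + 1) / 0.5) ** 2) - 0.11 * np.exp(-((thetas - 1) / 0.1) ** 2)
for T in (0.002, 0.05):
    print(T, basin_probabilities(u, thetas, T))
# low T favors the deeper (narrow) basin; higher T shifts mass to the wider one
```

Even in this simplified form, the depth/width trade-off mediated by temperature is visible: at low temperature the deeper basin dominates, while raising the temperature shifts probability toward the wider, lower-curvature basin.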
| Param | Figure 4 Left | Figure 4 Right | Figure 9 Left | Figure 9 Right |
| --- | --- | --- | --- | --- |
| seed | 0 | 0 | 0 | 0 |
| minima | [-1 1] | [-1 1] | [0] | [-1 0 1] |
| weights | [0.021 0.1] | [0.019 0.1] | [0.1] | [0.1 0.051 0.3] |
| sigmas | [0.1 0.5] | [0.1 0.5] | [0.1] | [0.1 0.05 0.3] |
| c | 0.001 | 0.001 | 0.001 | 0.001 |
| stddev_shift | 0.1 | 0.1 | 0.03 | 0.1 |
| lscale | 1.0 | 1.0 | 1.0 | 1.0 |
| wscale | 0.0 | 0.0 | 0.0 | 0.0 |
G.4.3. Figure 8
We used a slightly modified version of the code accompanying Li et al (2017a) for training a VGG9 network on CIFAR10. We changed the learning rate schedule to drop at 0.5, 0.75, and 0.9 of total epochs (rather than at fixed epochs), and also added support for reshuffling the train and test sets, and using a subsample of the full dataset.
| Stage 1 hyperparameters | Stage 2 hyperparameters | Figure 8 right |
| --- | --- | --- |
| rand_seed: range(0, 10) | resume_epoch: 1200 | rand_seed: -1 |
| batch_size: 4096 | rand_seed: -1 | batch_size: 3000 |
| lr: 0.1 | batch_size: various | lr: 0.1 |
| lr_decay: 0.1 | lr: various | lr_decay: 0.1 |
| epochs: 300 | lr_decay: 1.0 | epochs: 2000 |
| weight_decay: 0.0005 | epochs: 2200 | weight_decay: 0.0005 |
| momentum: 0.9 | weight_decay: 0.0005 | momentum: 0 |
| | momentum: 0 | shuffle_seed: 0:25 |
| | | sample_size: 10000:50000 |
Footnotes
- 1
In section 5, we find that (3) still holds in the overparametrized case after generalizing the definition of sk to apply to non-strict minima. We further note that for asymmetric losses, only the curvature in the direction from the test toward the train local minimum matters, as explored in our experiments. Finally, in appendix E, we show how this result and others can be partially generalized to the case where the quadratic term is zero, but curvature can still be characterized in terms of a higher-order even derivative.
- 2
Generally, such a model is reasonable when ρ has the form for some constant c and function υ with local minima near those of the train loss, so that ρ is multimodal with peaks near the train local minima, enabling a Laplace approximation about each local minimum of υ, which yields a Gaussian mixture with small biases. We will see shortly that the SGD steady state has this form.
- 3
When for some scalar function , then . If in addition is a constant denoted by then
- 4
- 5
This is in contrast to the effective potential υ, since some algebra shows that , so υ can only worsen with increasing T. Interpreting as the energy of the system makes the heat capacity (Niroomand et al 2022).
- 6
Informally, our derivation needs the probability distribution of to vary slowly with θ; making λ sufficiently small guarantees this, ensuring that the infinitesimal increments comprising an SGD step are approximately i.i.d. We also need to truncate the Kramers–Moyal expansion to second order to obtain (10). The same conditions also appear in Li et al (2017b): follows from the bound on the accuracy of the order-1 weak approximation, which is a constant times the learning rate with batch size assumed constant. The requirement that the probability of the updates varies slowly with θ appears formally there as a Lipschitz and growth condition on the loss.
- 7
To see this, consider the joint probability of two infinitesimal increments within the same time period of length i.e. where and assume at the beginning of the dt0 interval. We can write the latter probability as:
is a random variable describing the increment from the start of the τ time interval until the beginning of the kth infinitesimal increment. Letting µw denote the mean of , we can Taylor expand to first order around W = 0 to find that:
plus higher order terms. When the second term above is small compared to the first, we can finally write that and since , and these increments only depend on these magnitudes and on θ at the beginning of each increment (both equal to θ0 in our expansion above), the i.i.d. assumption is satisfied. Conversely, when µw and/or are large, we expect the increments not to be i.i.d. and the KM expansion not to approximate the discrete process well.