
Particle dual averaging: optimization of mean field neural network with global convergence rate analysis*


Published 24 November 2022 © 2022 IOP Publishing Ltd and SISSA Medialab srl
Machine Learning 2022. Citation: Atsushi Nitanda et al J. Stat. Mech. (2022) 114010. DOI: 10.1088/1742-5468/ac98a8


Abstract

We propose the particle dual averaging (PDA) method, which generalizes the dual averaging method in convex optimization to the optimization over probability distributions with a quantitative runtime guarantee. The algorithm consists of an inner loop and an outer loop: the inner loop utilizes the Langevin algorithm to approximately solve for a stationary distribution, which is then optimized in the outer loop. The method can thus be interpreted as an extension of the Langevin algorithm to naturally handle nonlinear functionals on the probability space. An important application of the proposed method is the optimization of neural networks in the mean field regime, which is theoretically attractive due to the presence of nonlinear feature learning, but whose quantitative convergence rate can be challenging to obtain. By adapting finite-dimensional convex optimization theory to the space of measures, we analyze PDA in regularized empirical/expected risk minimization, and establish quantitative global convergence in learning two-layer mean field neural networks under more general settings. Our theoretical results are supported by numerical simulations on neural networks of reasonable size.


1. Introduction

Gradient-based optimization can achieve vanishing training error on neural networks, despite the apparent non-convex landscape. Among the various works that explain this global convergence, one common ingredient is to utilize overparameterization to translate the training dynamics into function space, and then exploit the convexity of the loss function with respect to the function. Such endeavors usually consider models in one of two categories: the mean field regime or the kernel regime.

On one hand, analysis in the kernel (lazy) regime connects gradient descent on wide neural networks to kernel regression with respect to the neural tangent kernel [1], which leads to global convergence at a linear rate [2–4]. However, key to the analysis is the linearization of the training dynamics, which requires appropriate scaling of the model such that the distance traveled by the parameters vanishes [5]. This regime thus fails to explain the feature learning of neural networks [6], which is believed to be an important advantage of deep learning; indeed, it has been shown that deep learning can outperform kernel models due to this adaptivity [7, 8].

In contrast, the mean field regime describes the gradient descent dynamics as a Wasserstein gradient flow in the probability space [9–11], which captures the potentially nonlinear evolution of parameters traveling beyond the kernel regime. While the mean field limit is appealing due to the presence of 'feature learning', its characterization is more challenging and quantitative analysis is largely lacking. Recent works established convergence rates in continuous time under modified dynamics [12], strong assumptions on the target function [13], or regularized objectives [14], but such results can be fragile in the discrete-time or finite-particle setting; in fact, the discretization error often scales exponentially with the time horizon or dimensionality, which limits the applicability of the theory. Hence, an important research problem that we aim to address is

Can we develop optimization algorithms for neural networks in the mean field regime with quantitative guarantees comparable to those the kernel regime enjoys?

We address this question by introducing the particle dual averaging (PDA) method, which globally optimizes an entropic regularized nonlinear functional. For the two-layer mean field network, which is an important application, we establish a polynomial runtime guarantee for the discrete-time algorithm; to our knowledge this is the first quantitative global convergence result in this setting.

1.1. Contributions

We propose the PDA algorithm, which draws inspiration from the dual averaging (DA) method originally developed for finite-dimensional convex optimization [15–17]. We iteratively optimize a probability distribution in the form of a Boltzmann distribution, samples from which can be obtained from the Langevin algorithm (see figure 1). The resulting algorithm has a per-iteration cost comparable to gradient descent and can be efficiently implemented.


Figure 1. 1D visualization of the parameter distribution of a mean field two-layer neural network (tanh) optimized by PDA. The inner loop uses the Langevin algorithm to approximately solve for a stationary distribution ${q}_{\ast }^{(t)}$, which is then optimized in the outer loop towards the true target q*.


For optimizing two-layer neural networks in the mean-field regime, we establish a quantitative global convergence rate of PDA in minimizing a KL-regularized objective: the algorithm requires $\tilde{O}({{\epsilon}}^{-3})$ steps and $\tilde{O}({{\epsilon}}^{-2})$ particles to reach an ε-accurate solution, where $\tilde{O}$ hides logarithmic factors. Importantly, our analysis does not couple the learning dynamics with a certain continuous time limit, but directly handles the discrete update. This leads to a simpler analysis that covers more general settings. We also derive a generalization bound for the solution obtained by the algorithm. From the optimization viewpoint, PDA is an extension of the Langevin algorithm that handles entropic-regularized nonlinear functionals on the probability space. Hence we believe our proposed method can also be applied to other distribution optimization problems beyond the training of neural networks.

1.2. Related literature

Mean field limit of two-layer NNs. The key observation for the mean field analysis is that when the number of neurons becomes large, the evolution of parameters is well-described by a nonlinear partial differential equation (PDE), which can be viewed as solving an infinite-dimensional convex problem [18, 19]. Global convergence can be derived by studying the limiting PDE [10, 11, 20, 21], yet quantitative convergence rates generally require additional assumptions.

Javanmard et al [13] analyzed a particular RBF network and established linear convergence (up to a certain error) for strongly concave target functions. Rotskoff et al [12] provided a sublinear rate in continuous time for a modified gradient flow. In the regularized setting, Chizat [22] obtained local linear convergence under a certain non-degeneracy assumption on the objective. Wei et al [23] also proved a polynomial rate for a perturbed dynamics under weak ℓ2 regularization.

Our setting is most related to Hu et al [14], who studied the minimization of a nonlinear functional with KL regularization on the probability space, and showed linear convergence (in continuous time) of a particle dynamics named mean field Langevin dynamics when the regularization is sufficiently strong. Chen et al [24] also considered optimizing a KL-regularized objective in the infinite-width and continuous-time limit, and derived NTK-like convergence guarantee under certain parameter scaling. Compared to these prior works, we directly handle the discrete time update in the mean-field regime, and our analysis covers a wider range of regularization parameters and loss functions.

Langevin algorithm. Langevin dynamics can be viewed as optimization in the space of probability measures [25, 26]; this perspective has been explored in [27, 28]. It is known that the continuous-time Langevin diffusion converges exponentially fast to target distributions satisfying certain growth conditions [29, 30]. The discretized Langevin algorithm has a sublinear convergence rate that depends on the numerical scheme [31] and has been studied under various metrics [32–34].

The Langevin algorithm can also optimize certain non-convex objectives [35–37], in which one finite-dimensional 'particle' can attain approximate global convergence due to the concentration of the Boltzmann distribution around the true minimizer. However, such results often depend on a spectral gap that grows exponentially in dimensionality, which renders the analysis ineffective for neural net optimization in the high-dimensional parameter space.

Very recently, convergence of Hamiltonian Monte Carlo in learning certain mean field models has been analyzed in Bou-Rabee and Schuh [38], Bou-Rabee and Eberle [39]. Compared to these concurrent results, our formulation covers a more general class of potentials, and in the context of two-layer neural network, we provide optimization guarantees for a wider range of loss functions.

1.3. Notations

Let ${\mathbb{R}}_{+}$ denote the set of non-negative real numbers and ||⋅||2 the Euclidean norm. Given a density function $q:{\mathbb{R}}^{p}\to {\mathbb{R}}_{+}$, we denote the expectation with respect to q(θ)dθ by ${\mathbb{E}}_{q}[\cdot ]$. For a function $f:{\mathbb{R}}^{p}\to \mathbb{R}$, we define ${\mathbb{E}}_{q}[f]=\int f(\theta )q(\theta )\mathrm{d}\theta $ when f is integrable. KL is the Kullback–Leibler divergence: $\mathrm{K}\mathrm{L}(q{\Vert}{q}^{\prime })\stackrel{\mathrm{d}\mathrm{e}\mathrm{f}}{=}\int q(\theta )\mathrm{log}\left(\frac{q(\theta )}{{q}^{\prime }(\theta )}\right)\mathrm{d}\theta $. Let ${\mathcal{P}}_{2}$ denote the set of positive densities q on ${\mathbb{R}}^{p}$ such that the second-order moment ${\mathbb{E}}_{q}[{\Vert}\theta {{\Vert}}_{2}^{2}]< \infty $ and entropy $-\infty < -{\mathbb{E}}_{q}[\mathrm{log}(q)]< +\infty $ are well defined. $\mathcal{N}(0,{I}_{p})$ is the Gaussian distribution on ${\mathbb{R}}^{p}$ with mean 0 and covariance matrix Ip .

2. Problem setting

We consider the problem of risk minimization with neural networks in the mean field regime. For simplicity, we focus on supervised learning. We here formalize the problem setting and models. Let $\mathcal{X}\subset {\mathbb{R}}^{d}$ and $\mathcal{Y}\subset \mathbb{R}$ be the input and output spaces, respectively. For given input data $x\in \mathcal{X}$, we predict a corresponding output $y=h(x)\in \mathcal{Y}$ through a hypothesis function $h:\mathcal{X}\to \mathcal{Y}$.

2.1. Neural network and mean field limit

We adopt a neural network in the mean field regime as a hypothesis function. Let ${\Omega}={\mathbb{R}}^{p}$ be a parameter space and ${h}_{\theta }:\mathcal{X}\to \mathcal{Y}\ (\theta \in {\Omega})$ be a bounded function which will be a component of a neural network. We sometimes denote h(θ, x) = hθ (x). Let q(θ)dθ be a probability distribution on the parameter space Ω and ${\Theta}={\left\{{\theta }_{r}\right\}}_{r=1}^{M}$ be the set of parameters θr sampled from q(θ)dθ. A hypothesis is defined as an ensemble of ${h}_{{\theta }_{r}}$ as follows:

Equation (1): $h_{\Theta}(x)\stackrel{\mathrm{def}}{=}\frac{1}{M}\sum_{r=1}^{M}h_{\theta_r}(x).$

A typical example in the literature of the above formulation is a two-layer neural network.

Example 1 (two-layer network). Let ${a}_{r}\in \mathbb{R}$ and ${b}_{r}\in {\mathbb{R}}^{d}\ (r\in \left\{1,2,\dots ,M\right\})$ be parameters for output and input layers, respectively. We set θr = (ar , br ) and ${\Theta}={\left\{{\theta }_{r}\right\}}_{r=1}^{M}$. Denote ${h}_{{\theta }_{r}}(x)\stackrel{\mathrm{d}\mathrm{e}\mathrm{f}}{=}{\sigma }_{2}({a}_{r}{\sigma }_{1}({b}_{r}^{\top }x))\ (x\in \mathcal{X})$, where σ1 and σ2 are smooth activation functions. Then the hypothesis hΘ is a two-layer neural network composed of neurons ${h}_{{\theta }_{r}}$: ${h}_{{\Theta}}(x)=\frac{1}{M}{\sum }_{r=1}^{M}\;{\sigma }_{2}({a}_{r}{\sigma }_{1}({b}_{r}^{\top }x))$.

Remark. The purpose of σ2 in the last layer is to ensure the boundedness of the output (e.g. see assumption 2 in [10]); this nonlinearity can also be removed if the parameters of the output layer are fixed. In addition, although we mainly focus on the optimization of two-layer neural networks, our proposed method can also be applied to an ensemble hΘ of deep neural networks ${h}_{{\theta }_{r}}$.

Suppose the parameters θr follow a probability distribution q(θ)dθ; then hΘ can be viewed as a finite-particle discretization of the following expectation,

Equation (2): $h_q(x)\stackrel{\mathrm{def}}{=}\mathbb{E}_{q}[h(\theta,x)]=\int h(\theta,x)\,q(\theta)\,\mathrm{d}\theta,$

which we refer to as the mean field limit of the neural network hΘ. As previously discussed, when hΘ is overparameterized, optimizing hΘ becomes 'close' to directly optimizing the probability distribution on the parameter space Ω, for which convergence to the optimal solution may be established under appropriate conditions [9–11]. Hence, the study of optimization of hq with respect to the probability distribution q(θ)dθ may shed light on important properties of overparameterized neural networks.
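For illustration, the following minimal NumPy sketch evaluates the finite-particle prediction (1) and shows how it acts as a Monte Carlo estimate of the mean field limit (2); the tanh activations and the Gaussian choice of q are illustrative assumptions on our part, not prescriptions of the analysis.

import numpy as np

def neuron(theta, x):
    # h_theta(x) = sigma_2(a * sigma_1(b^T x)) as in example 1, with tanh activations (assumption)
    a, b = theta[0], theta[1:]
    return np.tanh(a * np.tanh(b @ x))

def h_finite(Theta, x):
    # equation (1): average of M neurons
    return np.mean([neuron(theta, x) for theta in Theta])

rng = np.random.default_rng(0)
d, M = 5, 10_000
x = rng.standard_normal(d)
# particles theta_r = (a_r, b_r) drawn i.i.d. from q = N(0, I_{d+1})
Theta = rng.standard_normal((M, d + 1))
# by the law of large numbers, h_Theta(x) approaches h_q(x) = E_q[h(theta, x)] as M grows
print(h_finite(Theta, x))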

2.2. Regularized empirical risk minimization

We briefly outline our setting for regularized expected/empirical risk minimization. The prediction error of a hypothesis is measured by the loss function $\ell (z,y)\ (z,y\in \mathcal{Y})$, such as the squared loss ℓ(z, y) = 0.5(z − y)2 for regression, or the logistic loss ℓ(z, y) = log(1 + exp(−yz)) for binary classification. Let $\mathcal{D}$ be a data distribution over $\mathcal{X}\times \mathcal{Y}$. For expected risk minimization, the distribution $\mathcal{D}$ is set to the true data distribution; whereas for empirical risk minimization, we take $\mathcal{D}$ to be the empirical distribution defined by training data ${\left\{({x}_{i},{y}_{i})\right\}}_{i=1}^{n}\ ({x}_{i}\in \mathcal{X},{y}_{i}\in \mathcal{Y})$ independently sampled from the data distribution. We aim to minimize the expected/empirical risk together with a regularization term, which controls the model complexity and also stabilizes the optimization. The regularized objective can be written as follows: for λ1, λ2 > 0,

Equation (3): $\min_{q\in\mathcal{P}_2}\left\{\mathcal{L}(q)\stackrel{\mathrm{def}}{=}\mathbb{E}_{(X,Y)\sim\mathcal{D}}[\ell(h_q(X),Y)]+R_{\lambda_1,\lambda_2}(q)\right\},$

where ${R}_{{\lambda }_{1},{\lambda }_{2}}$ is a regularization term composed of the weighted sum of the second-order moment and negative entropy with regularization parameters λ1, λ2:

Equation (4): $R_{\lambda_1,\lambda_2}(q)\stackrel{\mathrm{def}}{=}\lambda_1\mathbb{E}_q[\Vert\theta\Vert_2^2]+\lambda_2\mathbb{E}_q[\log q].$

Note that this regularization is, up to scaling and an additive constant, the KL divergence of q from a Gaussian distribution. In our setting, such regularization ensures that the Gibbs distributions ${q}_{\ast }^{(t)}$ specified in section 3 are well defined.
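Concretely (a short calculation we include for the reader, with the Gaussian reference $\mathcal{N}(0,\tfrac{\lambda_2}{2\lambda_1}I_p)$ chosen to match the weights in (4)),

$\lambda_2\,\mathrm{KL}\left(q\,\Big\Vert\,\mathcal{N}\!\left(0,\tfrac{\lambda_2}{2\lambda_1}I_p\right)\right)=\lambda_2\mathbb{E}_q[\log q]+\lambda_1\mathbb{E}_q[\Vert\theta\Vert_2^2]+\frac{p\lambda_2}{2}\log\frac{\pi\lambda_2}{\lambda_1}=R_{\lambda_1,\lambda_2}(q)+\mathrm{const},$

so minimizing $R_{\lambda_1,\lambda_2}$ alone drives q towards this Gaussian.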

While our primary focus is the optimization of the objective (3), we can also derive a generalization error bound of order O(n−1/2) for the empirical risk minimizer, in both the regression and binary classification settings, following Chen et al [24]. We defer the details to appendix D.

2.3. The Langevin algorithm

Before presenting our proposed method, we briefly review the Langevin algorithm. For a given smooth potential function $f:{\Omega}\to \mathbb{R}$, the Langevin algorithm performs the following update: given the initial θ(1)q(1)(θ)dθ, step size η > 0, and Gaussian noise ${\zeta }^{(k)}\sim \mathcal{N}(0,{I}_{p})$,

Equation (5): $\theta^{(k+1)}=\theta^{(k)}-\eta\nabla_{\theta}f(\theta^{(k)})+\sqrt{2\eta}\,\zeta^{(k)}.$

Under appropriate conditions on f, it is known that the law of θ(k) converges to the stationary distribution proportional to exp(−f(⋅)) in terms of KL divergence at a linear rate (e.g. [40]) up to O(η) error, where we hide additional factors in the big-O notation.

Alternatively, note that when the normalization constant ∫exp(−f(θ))dθ exists, the Boltzmann distribution in proportion to exp(−f(⋅)) is the solution of the following optimization problem,

Equation (6): $\min_{q\in\mathcal{P}_2}\left\{\mathbb{E}_q[f(\theta)]+\mathbb{E}_q[\log q]\right\}.$

Hence we may interpret the Langevin algorithm as approximately solving an entropic regularized linear functional (i.e. free energy functional) on the probability space. This connection between sampling and optimization (see Dalalyan [41], Wibisono [27], Durmus et al [28]) enables us to employ the Langevin algorithm to obtain (samples from) the closed-form Boltzmann distribution which is the minimizer of (6); for example, many Bayesian inference problems fall into this category.
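To make the update (5) concrete, here is a minimal Python sketch of the Langevin algorithm targeting a distribution proportional to exp(−f); the quadratic potential, step size, and chain lengths are illustrative assumptions of ours.

import numpy as np

def langevin(grad_f, theta0, eta, n_steps, rng):
    # Langevin update (5): theta <- theta - eta * grad f(theta) + sqrt(2*eta) * Gaussian noise
    theta = theta0.copy()
    for _ in range(n_steps):
        theta = theta - eta * grad_f(theta) + np.sqrt(2 * eta) * rng.standard_normal(theta.shape)
    return theta

# example potential f(theta) = ||theta||^2 / 2, whose Boltzmann distribution exp(-f) is N(0, I_p)
rng = np.random.default_rng(0)
p = 3
samples = np.array([langevin(lambda th: th, np.zeros(p), eta=0.01, n_steps=2000, rng=rng)
                    for _ in range(200)])
print(samples.mean(axis=0), samples.var(axis=0))  # roughly 0 and 1, up to O(eta) bias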

However, the objective (3) that we aim to optimize is beyond the scope of the Langevin algorithm: due to the nonlinearity of the loss ℓ(z, y) with respect to z, the stationary distribution cannot be described as a closed-form solution of (6). To overcome this limitation, we develop the PDA algorithm which efficiently solves (3) with quantitative runtime guarantees.

3. Proposed method

We now propose the PDA method to approximately solve the problem (3) by optimizing a two-layer neural network in the mean field regime; we also introduce the mean field limit of the proposed method to explain the algorithmic intuition and develop the convergence analysis.

3.1. Particle dual averaging

Our proposed PDA method (algorithm 1) is an optimization algorithm on the space of probability measures. The algorithm consists of an inner loop and an outer loop; we run the Langevin algorithm in the inner loop to approximate a Gibbs distribution, which is optimized in the outer loop so that it converges to the optimal distribution q*. This outer loop update is designed to extend the classical DA scheme [15–17] to infinite-dimensional optimization problems (described in section 3.2). Below we provide a more detailed explanation.

  • In the outer loop, the last iterate ${\tilde{{\Theta}}}^{(t)}$ of the previous inner loop is given. We compute ${\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(t)}}({x}_{t}),{y}_{t})$, which is a component of the Gibbs potential, and initialize a set of particles Θ(1) at ${\tilde{{\Theta}}}^{(t)}$. In appendix B we introduce a different 'restarting' scheme for the initialization.
  • In the inner loop, we run the Langevin algorithm (noisy gradient descent) starting from Θ(1), where the gradient at the kth inner step is given by ${\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)})$, which is the sum of a weighted average of ${\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(s)}}({x}_{s}),{y}_{s}){\partial }_{\theta }h({\theta }_{r}^{(k)},{x}_{s})$ and the gradient of the ℓ2-regularization (see algorithm 1).

Algorithm 1. PDA.

 Input: data distribution $\mathcal{D}$, initial density q(1), number of outer-iterations T, learning rates ${\left\{{\eta }_{t}\right\}}_{t=1}^{T}$, number of inner-iterations ${\left\{{T}_{t}\right\}}_{t=1}^{T}$
 Randomly draw i.i.d. initial parameters ${\tilde{\theta }}_{r}^{(1)}\sim {q}^{(1)}(\theta )\mathrm{d}\theta \ (r\in \left\{1,2,\dots ,M\right\})$
 ${\tilde{{\Theta}}}^{(1)}{\leftarrow}{\left\{{\tilde{\theta }}_{r}^{(1)}\right\}}_{r=1}^{M}$
 for t = 1 to T do
  Randomly draw data (xt , yt ) from $\mathcal{D}$
  ${{\Theta}}^{(1)}={\left\{{\theta }_{r}^{(1)}\right\}}_{r=1}^{M}{\leftarrow}{\tilde{{\Theta}}}^{(t)}$
  for k = 1 to Tt do
   Run inexact noisy gradient descent for r ∈ {1, 2, ..., M}:
   ${\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)}){\leftarrow}\frac{2}{{\lambda }_{2}(t+2)(t+1)}{\sum }_{s=1}^{t}\;s{\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(s)}}({x}_{s}),{y}_{s}){\partial }_{\theta }h({\theta }_{r}^{(k)},{x}_{s})+\frac{2{\lambda }_{1}t}{{\lambda }_{2}(t+2)}{\theta }_{r}^{(k)}$
   ${\theta }_{r}^{(k+1)}{\leftarrow}{\theta }_{r}^{(k)}-{\eta }_{t}{\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)})+\sqrt{2{\eta }_{t}}{\zeta }_{r}^{(k)}$ (i.i.d. Gaussian noise ${\zeta }_{r}^{(k)}\sim \mathcal{N}(0,{I}_{p})$)
  end for
  ${\tilde{{\Theta}}}^{(t+1)}{\leftarrow}{{\Theta}}^{({T}_{t}+1)}={\left\{{\theta }_{r}^{({T}_{t}+1)}\right\}}_{r=1}^{M}$
 end for
 Randomly pick t ∈ {2, 3, ..., T + 1} with probability $\mathbb{P}[t]=\frac{2t}{T(T+3)}$ and return ${h}_{{\tilde{{\Theta}}}^{(t)}}$

Figure 1 provides a pictorial illustration of algorithm 1. Note that this procedure is a slight modification of the normal gradient descent algorithm: the first term of ${\nabla }_{\theta }{\bar{g}}^{(t)}$ is similar to the gradient of the loss ${\partial }_{{\theta }_{r}}\ell ({h}_{{{\Theta}}^{(k)}}(x),y)\sim {\partial }_{z}\ell ({h}_{{{\Theta}}^{(k)}}(x),y){\partial }_{\theta }h({\theta }_{r}^{(k)},x)$ where ${{\Theta}}^{(k)}={\left\{{\theta }_{r}^{(k)}\right\}}_{r=1}^{M}$. Indeed, if we set the number of inner-iterations Tt = 1 and replace the direction ${\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)})$ with the gradient of the L2-regularized loss, then PDA exactly reduces to the standard noisy gradient descent algorithm considered in [10]. Algorithm 1 can be extended to the minibatch variant in the obvious manner; for efficient implementation in the empirical risk minimization setting see appendix E.1.
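To complement the pseudocode, here is a minimal NumPy sketch that mirrors algorithm 1 for a two-layer tanh network (fixed second layer) with squared loss; the data-generating teacher, hyperparameter values, and schedules are our own illustrative assumptions, not the paper's experimental configuration.

import numpy as np

rng = np.random.default_rng(0)
d, M, T = 2, 100, 20                    # input dim, number of particles, outer iterations
lam1, lam2 = 1e-2, 1e-3                 # regularization parameters (illustrative values)

def predict(Theta, x):                  # h_Theta(x): mean of M tanh neurons (fixed second layer)
    return np.tanh(Theta @ x).mean()

def dh_all(Theta, x):                   # partial_theta h(theta_r, x) for all particles, shape (M, d)
    return (1.0 - np.tanh(Theta @ x) ** 2)[:, None] * x

def dloss(z, y):                        # derivative of the squared loss 0.5 * (z - y)^2 w.r.t. z
    return z - y

Theta = rng.standard_normal((M, d))     # initial particles drawn i.i.d. from q^(1) = N(0, I_d)
history = []                            # stores (partial_z loss evaluated at Theta^(s), x_s) for s = 1..t

for t in range(1, T + 1):
    x = rng.standard_normal(d)          # draw (x_t, y_t); single-neuron teacher is an assumption
    y = np.tanh(x[0]) + 0.1 * rng.standard_normal()
    history.append((dloss(predict(Theta, x), y), x))

    eta_t = 1e-3 / np.sqrt(t)           # heuristic schedules in the spirit of section 5
    for _ in range(10 * t):             # inner Langevin loop on the weighted potential g_bar^(t)
        grad = sum(s * dl * dh_all(Theta, xs) for s, (dl, xs) in enumerate(history, start=1))
        grad = 2.0 * grad / (lam2 * (t + 2) * (t + 1)) + 2.0 * lam1 * t / (lam2 * (t + 2)) * Theta
        Theta += -eta_t * grad + np.sqrt(2 * eta_t) * rng.standard_normal((M, d))

print("prediction at a fresh point:", predict(Theta, rng.standard_normal(d)))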

3.2. Mean field view of PDA

In this subsection we discuss the mean field limit of PDA and explain its algorithmic intuition. Note that the inner loop of algorithm 1 is the Langevin algorithm with M particles, which optimizes the potential function given by the weighted sum
$\bar{g}^{(t)}(\theta)=\frac{2}{\lambda_2(t+2)(t+1)}\sum_{s=1}^{t}s\left[\partial_z\ell(h_{\tilde{\Theta}^{(s)}}(x_s),y_s)\,h(\theta,x_s)+\lambda_1\Vert\theta\Vert_2^2\right].$

Due to the rapid convergence of the Langevin algorithm outlined in subsection 2.3, the particles ${\theta }_{r}^{(k+1)}\ (r\in \left\{1,\dots ,M\right\})$ can be regarded as (approximate) samples from the Boltzmann distribution proportional to $\mathrm{exp}\left(-{\bar{g}}^{(t)}\right)$. Hence, the inner loop of PDA returns an M-particle approximation of some stationary distribution, which is then modified in the outer loop. Importantly, the update on the stationary distribution is designed so that the algorithm converges to the optimal solution of the problem (3).

We now introduce the mean field limit of PDA, i.e. taking the number of particles M → ∞ and directly optimizing the problem (3) over q. We refer to this mean field limit simply as the DA algorithm. The DA method was originally developed for convex optimization in finite-dimensional spaces [15–17], and here we adapt it to optimization on the probability space. The details of the DA algorithm are described in algorithm 2.

Algorithm 2. DA.

 Input: data distribution $\mathcal{D}$ and initial density q(1)
 for t = 1 to T do
  Randomly draw a data point (xt , yt ) from $\mathcal{D}$
  ${g}^{(t)}{\leftarrow}{\partial }_{z}\ell ({h}_{{q}^{(t)}}({x}_{t}),{y}_{t})h(\cdot ,{x}_{t})+{\lambda }_{1}{\Vert}\cdot {{\Vert}}_{2}^{2}$
  Obtain an approximation q(t+1) of the density function ${q}_{\ast }^{(t+1)}\propto \mathrm{exp}\left(-\frac{{\sum }_{s=1}^{t}\;2s{g}^{(s)}}{{\lambda }_{2}(t+2)(t+1)}\right)$
 end for
 Randomly pick t ∈ {2, 3, ..., T + 1} with probability $\mathbb{P}[t]=\frac{2t}{T(T+3)}$ and return ${h}_{{q}^{(t)}}$

Algorithm 2 iteratively updates the density function ${q}_{\ast }^{(t+1)}\in {\mathcal{P}}_{2}$ which is a solution to the objective:

Equation (7): $\min_{q\in\mathcal{P}_2}\left\{\frac{2}{(t+2)(t+1)}\sum_{s=1}^{t}s\,\mathbb{E}_q[g^{(s)}(\theta)]+\lambda_2\mathbb{E}_q[\log q]\right\},$

where the function ${g}^{(t)}={\partial }_{z}\ell ({h}_{{q}^{(t)}}({x}_{t}),{y}_{t})h(\cdot ,{x}_{t})+{\lambda }_{1}{\Vert}\cdot {{\Vert}}_{2}^{2}$ is the functional derivative of $\ell ({h}_{q}({x}_{t}),{y}_{t})+{\lambda }_{1}{\mathbb{E}}_{q}[{\Vert}\theta {{\Vert}}_{2}^{2}]$ with respect to q at q(t). In other words, the objective (7) is the weighted average of linear approximations of the loss function plus the entropic regularization in the space of probability distributions. In this sense, the DA method can be seen as an extension of the Langevin algorithm to handle entropic regularized nonlinear functionals on the probability space by iteratively linearizing the objective.
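Since each $g^{(s)}$ already contains the $\lambda_1\Vert\cdot\Vert_2^2$ term, (7) is an entropic regularized linear functional of the type analyzed in appendix A.1, and its minimizer is exactly the Boltzmann distribution used in algorithm 2:

$q_*^{(t+1)}(\theta)\propto\exp\left(-\frac{\sum_{s=1}^{t}2s\,g^{(s)}(\theta)}{\lambda_2(t+2)(t+1)}\right).$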

To sum up, we may interpret the DA method as approximating the optimal distribution q* by iteratively optimizing ${q}_{\ast }^{(t)}$, which takes the form of a Boltzmann distribution. In the inner loop of the PDA algorithm, we obtain M (approximate) samples from ${q}_{\ast }^{(t)}$ via the Langevin algorithm. In other words, PDA can be viewed as a finite-particle approximation of DA: indeed, the stationary distributions obtained in PDA converge to ${q}_{\ast }^{(t+1)}$ as M → ∞. In the following section, we present the convergence rate of the DA method, and also take into account the iteration complexity of the Langevin algorithm; we defer the finite-particle approximation error analysis to appendix C.

4. Convergence analysis

We now provide quantitative global convergence guarantee for our proposed method in discrete time. We first derive the outer loop complexity, assuming approximate optimality of the inner loop iterates, which we then verify in the inner loop analysis. The total complexity is then simply obtained by combining the outer- and inner-loop runtime.

4.1. Outer loop complexity

We first analyze the convergence rate of the DA method (algorithm 2). Our analysis will be made under the following assumptions.

Assumption 1. 

  • (A1)  
    $\mathcal{Y}\subset [-1,1]$. ℓ(z, y) is a smooth convex function w.r.t. z and |∂z ℓ(z, y)| ⩽ 2 for $y,z\in \mathcal{Y}$.
  • (A2)  
    |h(θ, x)| ⩽ 1 and h(θ, x) is smooth with respect to θ for $x\in \mathcal{X}$.
  • (A3)  
    $\mathrm{K}\mathrm{L}({q}^{(t+1)}{\Vert}{q}_{\ast }^{(t+1)})\leqslant 1/{t}^{2}$.

Remark.  (A2) is satisfied by smooth activation functions such as sigmoid and tanh. Many loss functions including the squared loss and logistic loss satisfy (A1) under the boundedness assumptions $\mathcal{Y}\subset [-1,1]$ and |hθ (x)| ⩽ 1. Note that constants in (A1) and (A2) are defined for simplicity and can be relaxed to any value. (A3) specifies the precision of approximate solutions of sub-problems (7) to guarantee the global convergence of algorithm 2, which we verify in our inner loop analysis.

We first introduce the following quantity for $q\in {\mathcal{P}}_{2}$,

Observe that the expression consists of the negative entropy minus its lower bound for ${q}_{\ast }^{(t)}$ under assumptions (A1) and (A2); in other words $e({q}_{\ast }^{(t)})\geqslant 0$. We have the following convergence rate of DA.

Theorem 1 (convergence of DA). Under assumptions (A1), (A2), and (A3), for arbitrary ${q}_{\ast }\in {\mathcal{P}}_{2}$, the iterates of the DA method (algorithm 2) satisfy

where the expectation $\mathbb{E}[\mathcal{L}({q}^{(t)})]$ is taken with respect to the history of examples.

Theorem 1 demonstrates the convergence rate of algorithm 2 to the optimal value of the regularized objective (3) in expectation. Note that $\frac{2}{T(T+3)}{\sum }_{t=2}^{T+1}\;t\mathbb{E}[\mathcal{L}({q}^{(t)})]$ is the average of $\mathbb{E}[\mathcal{L}({q}^{(t)})]$ weighted by the probability $\mathbb{P}[t]=\frac{2t}{T(T+3)}\ (t\in \left\{2,\dots ,T+1\right\})$ as specified in algorithm 2. If we take p, λ1, λ2 as constants and use $\tilde{O}$ to hide the logarithmic terms, we can deduce that after $\tilde{O}({{\epsilon}}^{-1})$ iterations, an ε-accurate solution, i.e. $\mathcal{L}(q)\leqslant {\mathrm{inf}}_{q\in {\mathcal{P}}_{2}}\,\mathcal{L}(q)+{\epsilon}$, is achieved in expectation. Importantly, this convergence rate applies to any choice of regularization parameters, in contrast to the strong regularization required in [14, 42]. On the other hand, due to the exponential dependence on ${\lambda }_{2}^{-1}$, our convergence rate is not informative under weak regularization λ2 → 0. Such dependence follows from the classical LSI perturbation lemma [43], which is likely unavoidable for Langevin-based methods in the most general setting [44], unless additional assumptions are imposed (e.g. a student–teacher setup); we intend to further investigate these conditions in future work.
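As a quick sanity check, the weights $\mathbb{P}[t]=\frac{2t}{T(T+3)}$ used for this averaging (and for the final sampling step of algorithm 2) indeed sum to one:

$\sum_{t=2}^{T+1}\frac{2t}{T(T+3)}=\frac{2}{T(T+3)}\left(\frac{(T+1)(T+2)}{2}-1\right)=\frac{T(T+3)}{T(T+3)}=1.$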

4.2. Inner loop complexity

In order to derive the total complexity (i.e. taking both the outer loop and inner loop into account) towards a required accuracy, we also need to estimate the iteration complexity of Langevin algorithm. We utilize the following convergence result under the log-Sobolev inequality (definition A):

Theorem 2 [40]. Consider a probability density q(θ) ∝ exp(−f(θ)) satisfying the log-Sobolev inequality with constant α, and assume f is smooth and ∇f is L-Lipschitz, i.e. $\Vert {\nabla }_{\theta }f(\theta )-{\nabla }_{\theta }f({\theta }^{\prime })\Vert_{2}\leqslant L\Vert \theta -{\theta }^{\prime }\Vert_{2}$. If we run the Langevin algorithm (5) with learning rate $0< \eta \leqslant \frac{\alpha }{4{L}^{2}}$ and let q(k)(θ)dθ be the probability distribution that θ(k) follows, then we have

$\mathrm{K}\mathrm{L}({q}^{(k)}{\Vert}q)\leqslant \mathrm{e}^{-\alpha \eta k}\,\mathrm{K}\mathrm{L}({q}^{(1)}{\Vert}q)+\frac{8\eta p{L}^{2}}{\alpha }.$

Theorem 2 implies that a δ-accurate solution in KL divergence can be obtained by the Langevin algorithm with $\eta \leqslant \frac{\alpha }{4{L}^{2}}\,\mathrm{min}\left\{1,\frac{\delta }{4p}\right\}$ and $\frac{1}{\alpha \eta }\,\mathrm{log}\,\frac{2\mathrm{K}\mathrm{L}({q}^{(1)}{\Vert}q)}{\delta }$ iterations. Since the optimal solution of a sub-problem in DA (algorithm 2) takes the form of ${q}_{\ast }^{(t+1)}\propto \mathrm{exp}\left(-\frac{{\sum }_{s=1}^{t}\;2s{g}^{(s)}}{{\lambda }_{2}(t+2)(t+1)}\right)$, we can verify the LSI and determine the constant for ${q}_{\ast }^{(t+1)}(\theta )\mathrm{d}\theta $ based on the LSI perturbation lemma from Holley and Stroock [43] (see lemma B and example B in appendix A.2). Consequently, we can apply theorem 2 to ${q}_{\ast }^{(t+1)}$ for the inner loop complexity when ${\nabla }_{\theta }\,\mathrm{log}\,{q}_{\ast }^{(t+1)}$ is Lipschitz continuous, which motivates us to introduce the following assumption.
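To spell out this step (a short verification based on the bound displayed in theorem 2 above): the step size condition makes the discretization bias at most δ/2, and the iteration count makes the contracting term at most δ/2,

$\eta\leqslant\frac{\alpha}{4L^2}\cdot\frac{\delta}{4p}\;\Rightarrow\;\frac{8\eta pL^2}{\alpha}\leqslant\frac{\delta}{2},\qquad k\geqslant\frac{1}{\alpha\eta}\log\frac{2\,\mathrm{KL}(q^{(1)}\Vert q)}{\delta}\;\Rightarrow\;\mathrm{e}^{-\alpha\eta k}\,\mathrm{KL}(q^{(1)}\Vert q)\leqslant\frac{\delta}{2},$

so that $\mathrm{KL}(q^{(k)}\Vert q)\leqslant\delta$.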

Assumption 2. 

  • (A4)  
    ∂θ h(⋅, x) is one-Lipschitz continuous: ||∂θ h(θ, x) − ∂θ h(θ', x)||2 ⩽ ||θ − θ'||2, $\forall x\in \mathcal{X}$, θ, θ' ∈ Ω.

Remark.  (A4) is parallel to [10, assumption A3], and is satisfied by two-layer neural network in example 1 when the output or input layer is fixed and the input space $\mathcal{X}$ is compact. We remark that this assumption can be relaxed to Hölder continuity of ∂θ h(⋅, x) via the recent result of Erdogdu and Hosseinzadeh [45], which allows us to extend theorem 1 to general Lp -norm regularizer for p > 1. For now we work with (A4) for simplicity of the presentation and proof.

Setting δt+1 to be the desired accuracy of an approximate solution q(t+1) specified in (A3), δt+1 = 1/(t + 1)2, we have the following guarantee for the inner loop.

Corollary 1 (inner loop complexity). Under (A1), (A2), and (A4), if we run the Langevin algorithm with step size ${\eta }_{t}=O\left(\frac{{\lambda }_{1}{\lambda }_{2}{\delta }_{t+1}}{p{(1+{\lambda }_{1})}^{2}\mathrm{exp}(8/{\lambda }_{2})}\right)$ on (7), then an approximate solution satisfying $\mathrm{K}\mathrm{L}({q}^{(t+1)}{\Vert}{q}_{\ast }^{(t+1)})\leqslant {\delta }_{t+1}$ can be obtained within $O\left(\frac{{\lambda }_{2}\,\mathrm{exp}(8/{\lambda }_{2})}{{\lambda }_{1}{\eta }_{t}}\,\mathrm{log}\,\frac{2\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t+1)})}{{\delta }_{t+1}}\right)$-iterations. Moreover,$\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t+1)})\ (t\in \left\{1,2,\dots ,T+1\right\})$ are uniformly bounded with respect to t as long as q(1) is a Gaussian distribution and (A3) is satisfied.

We comment that for the inner loop we utilized the overdamped Langevin algorithm, since it is the most standard and commonly used sampling method for the objective (7). Our analysis can easily incorporate other inner loop updates such as the underdamped Langevin algorithm [46, 47] or the Metropolis-adjusted Langevin algorithm [29, 48], which may improve the iteration complexity.

4.3. Total complexity

Combining theorem 1 and corollary 1, we can now derive the total complexity of our proposed algorithm. For simplicity, we take p, λ1, λ2 as constants and hide logarithmic terms in $\tilde{O}$ and $\tilde{{\Theta}}$. The following corollary establishes a $\tilde{O}({{\epsilon}}^{-3})$ total iteration complexity to obtain an ε-accurate solution in expectation, because ${T}_{t}=\tilde{{\Theta}}({t}^{2})=\tilde{O}({{\epsilon}}^{-2})$ for t ⩽ T.

Corollary 2 (total complexity). Let ε > 0 be an arbitrary desired accuracy and q(1) be a Gaussian distribution. Under assumptions (A1), (A2), (A3), and (A4), if we run algorithm 2 for $T=\tilde{{\Theta}}({{\epsilon}}^{-1})$ iterations on the outer loop, and the Langevin algorithm with step size ${\eta }_{t}={\Theta}\left(\frac{{\lambda }_{1}{\lambda }_{2}{\delta }_{t+1}}{p{(1+{\lambda }_{1})}^{2}\mathrm{exp}(8/{\lambda }_{2})}\right)$ for ${T}_{t}=\tilde{{\Theta}}({\eta }_{t}^{-1})$ iterations on the inner loop, then an ε-accurate solution $\mathcal{L}(q)\leqslant {\mathrm{inf}}_{q\in {\mathcal{P}}_{2}}\,\mathcal{L}(q)+{\epsilon}$ of the objective (3) is achieved in expectation.
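Summing the inner loop lengths used in corollary 2 makes the total gradient-step count explicit:

$\sum_{t=1}^{T}T_t=\tilde{\Theta}\left(\sum_{t=1}^{T}t^{2}\right)=\tilde{\Theta}(T^{3})=\tilde{\Theta}({\epsilon}^{-3})\quad\text{with}\quad T=\tilde{\Theta}({\epsilon}^{-1}).$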

Quantitative convergence guarantee. To translate the above convergence rate result to the finite-particle PDA (algorithm 1), we also characterize the finite-particle discretization error in appendix C. For the particle complexity analysis, we consider two versions of the particle update: (i) the warm-start scheme described in algorithm 1, in which Θ(1) is initialized at the last iterate ${\tilde{{\Theta}}}^{(t)}$ of the previous inner loop, and (ii) the resampling scheme, in which Θ(1) is initialized from the initial distribution q(1)(θ)dθ (see appendix B for details). Remarkably, for the resampling scheme, we provide convergence rate guarantees in time- and space-discretized settings that are polynomial in both the number of iterations and the particle size; specifically, a particle complexity of $\tilde{O}({{\epsilon}}^{-2})$, together with the total iteration complexity of $\tilde{O}({{\epsilon}}^{-3})$, suffices to obtain an ε-accurate solution to the objective (3) (see appendices B and C for precise statements).

5. Experiments

5.1. Experiment setup

We employ our proposed algorithm in both synthetic student–teacher settings (see figures 2(a) and (b)) and a real-world dataset (see figure 2(c)). For the student–teacher setup, the labels are generated as yi = f*(xi ) + ɛi , where f* is the teacher model (target function), and ɛi is zero-mean i.i.d. label noise. For the student model f, we follow Mei et al [10, section 2.1] and parameterize a two-layer neural network with fixed second layer as:

Equation (8)

which we train to minimize the objective (3) using PDA. Note that α = 1 corresponds to the mean field regime (which we are interested in), whereas setting α = 1/2 leads to the kernel (NTK) regime.
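Equation (8) is not reproduced above; a minimal sketch of one such parameterization with a 1/M^α output scaling and fixed second layer (the exact form, tanh activation, and dimensions here are our own assumptions) is:

import numpy as np

def student(Theta, x, alpha=1.0):
    # width-M two-layer net with fixed second layer and 1/M**alpha output scaling (assumed form);
    # alpha = 1 gives the mean field scaling, alpha = 0.5 the NTK-type scaling discussed in the text
    M = Theta.shape[0]
    return np.sum(np.tanh(Theta @ x)) / M ** alpha

rng = np.random.default_rng(0)
d, M = 10, 500
Theta = rng.standard_normal((M, d))
x = rng.standard_normal(d)
print(student(Theta, x, alpha=1.0), student(Theta, x, alpha=0.5))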


Figure 2. (a) Iteration complexity of PDA: the O(T−1) outer loop rate agrees with theorem 1. (b) Parameter trajectory of PDA: darker color (purple) indicates earlier in training, and vice versa. (c) Odd vs even classification on MNIST; we report the training loss (red) as well as the train and test accuracy (blue and green).


Synthetic student–teacher setting. For figures 2(a) and (b) we design synthetic experiments for both regression and classification tasks, where the student model is a two-layer tanh network with M = 500. For regression, we take the target function f* to be a multiple-index model with m neurons: ${f}_{\ast }(x)=\frac{1}{\sqrt{m}}{\sum }_{i=1}^{m}\;{\sigma }_{\ast }(\langle {w}_{i}^{\ast },x\rangle )$, and the input is drawn from a unit Gaussian $\mathcal{N}(0,{I}_{p})$. For binary classification, we consider a simple two-dimensional dataset from sklearn.datasets.make_circles [49], in which the goal is to separate two groups of data on concentric circles (red and blue in figure 2(b)). We include additional experimental results in appendix F.
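A minimal data-generation sketch for these two settings (our own; the choice σ* = tanh, the noise level, and the sample counts are illustrative assumptions) could look as follows:

import numpy as np
from sklearn.datasets import make_circles

rng = np.random.default_rng(0)

# regression: multiple-index teacher f_*(x) = (1/sqrt(m)) * sum_i sigma_*(<w_i^*, x>)
d, m, n = 10, 4, 1000
W_star = rng.standard_normal((m, d))
X_reg = rng.standard_normal((n, d))                        # inputs drawn from a standard Gaussian
f_star = np.tanh(X_reg @ W_star.T).sum(axis=1) / np.sqrt(m)
y_reg = f_star + 0.1 * rng.standard_normal(n)              # zero-mean i.i.d. label noise

# binary classification: two groups of points on concentric circles
X_clf, y_clf = make_circles(n_samples=1000, noise=0.05, factor=0.5, random_state=0)
y_clf = 2 * y_clf - 1                                       # labels in {-1, +1} for the logistic loss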

PDA hyperparameters. We optimize the squared loss for regression and the logistic loss for binary classification. The model is trained by PDA with batch size 50. We scale the number of inner loop steps Tt with t, and the step size ηt with $1/\sqrt{t}$, where t is the outer loop iteration; this heuristic is consistent with the required inner-loop accuracy in theorem 1 and corollary 2.

5.2. Empirical findings

Convergence rate. In figure 2(a) we verify the O(T−1) iteration complexity of the outer loop in theorem 1. We apply PDA to optimize the expected risk (analogous to one-pass SGD) in the regression setting, in which the input dimensionality p = 1 and the target function is a single-index model (m = 1) with tanh activation. We employ the resampled update (i.e. without warm-start; see appendix B) with hyperparameters λ1 = 10−2, λ2 = 10−3. To compute the entropy in the objective (3), we adopt the k-nearest neighbors estimator [50] with k = 10.
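For completeness, here is a sketch of a Kozachenko–Leonenko-style k-nearest-neighbor entropy estimator, one common variant of the estimator in [50]; this is our own implementation, so the constants should be checked against the reference.

import numpy as np
from scipy.special import digamma, gammaln
from sklearn.neighbors import NearestNeighbors

def knn_entropy(samples, k=10):
    # k-nearest-neighbor estimate of differential entropy from i.i.d. samples (one common variant)
    n, p = samples.shape
    nn = NearestNeighbors(n_neighbors=k + 1).fit(samples)
    dist, _ = nn.kneighbors(samples)           # dist[:, 0] is the distance to the point itself
    r_k = dist[:, k]                            # distance to the k-th nearest neighbor
    log_unit_ball = (p / 2) * np.log(np.pi) - gammaln(p / 2 + 1)
    return digamma(n) - digamma(k) + log_unit_ball + p * np.mean(np.log(r_k))

# for a standard Gaussian in p = 2 dimensions the true entropy is log(2 * pi * e)
rng = np.random.default_rng(0)
print(knn_entropy(rng.standard_normal((5000, 2)), k=10), np.log(2 * np.pi * np.e))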

Presence of feature learning. In figure 2(b) we visualize the evolution of neural network parameters optimized by PDA in a two-dimensional classification problem. Due to the structure of the input data (concentric rings), we expect that, for a two-layer neural network to be a good separator, its parameters should also be distributed on a circle. Indeed the converged solution of PDA (bright yellow) agrees with this intuition and demonstrates that PDA learns useful features beyond the kernel regime.

Binary classification on MNIST. In figure 2(c) we report the training and test performance of PDA in separating odd vs even digits from the MNIST dataset. We subsample n = 2500 training examples with binary labels, and learn a two-layer tanh network with width M = 2500. We use the resampled update of PDA to optimize the cross entropy loss, with hyperparameters λ1 = 10−2, λ2 = 10−4. Observe that the algorithm achieves good generalization performance (green) and roughly maintains the O(T−1) iteration complexity (red) in optimizing the training objective (3).

6. Conclusion

We proposed the PDA algorithm for optimizing two-layer neural networks in the mean field regime. Leveraging tools from finite-dimensional convex optimization developed for the original DA method, we established quantitative convergence rates of PDA for regularized empirical and expected risk minimization. We also provided particle complexity analysis and generalization bounds for both regression and classification problems. Our theoretical findings are aligned with experimental results on neural network optimization. Looking forward, we plan to investigate specific problem instances in which convergence rates can be obtained under vanishing regularization. It is also important to consider accelerated variants of PDA to further improve the convergence rate in the empirical risk minimization setting. Another interesting direction would be to explore other applications of PDA beyond two-layer neural networks, such as deep models [51–54], as well as other optimization problems involving entropic regularized nonlinear functionals.

Acknowledgments

The authors would like to thank Murat A Erdogdu for helpful feedback. A N was partially supported by JSPS Kakenhi (22H03650) and JST-PRESTO (JPMJPR1928). D W was partially supported by a Borealis AI Fellowship. T S was partially supported by JSPS KAKENHI (20H00576), Japan Digital Design and JST CREST.

Missing proofs

Appendix A.: Preliminaries

A.1. Entropic regularized linear functional

In this section, we explain the property of the optimal solution of the entropic regularized linear functional. We here define the gradient of the negative entropy ${\mathbb{E}}_{q}[\mathrm{log}(q)]$ with respect to q over the probability space as ${\nabla }_{q}{\mathbb{E}}_{q}[\mathrm{log}(q)]=\mathrm{log}(q)$. Note that this gradient is well defined up to constants as a linear operator on the probability space: q' ↦ ∫(q' − q)(θ)log(q(θ))dθ. The following lemma shows the strong convexity of the negative entropy.

Lemma A. Let q, q' be probability densities such that the entropy and Kullback–Leibler divergence $\mathrm{K}\mathrm{L}({q}^{\prime }{\Vert}q)=\int {q}^{\prime }(\theta )\mathrm{log}\left(\frac{{q}^{\prime }(\theta )}{q(\theta )}\right)\mathrm{d}\theta $ are well defined. Then, we have

The first equality of this lemma can be shown by the direct computation of the entropy, and the second inequality can be obtained by Pinsker's inequality $\frac{1}{2}{\Vert}{q}^{\prime }-q{{\Vert}}_{{L}_{1}(\mathrm{d}\theta )}^{2}\leqslant \mathrm{K}\mathrm{L}({q}^{\prime }{\Vert}q)$.

Recall that ${\mathcal{P}}_{2}$ is the set of positive densities on ${\mathbb{R}}^{p}$ such that the second moment ${\mathbb{E}}_{q}[{\Vert}\theta {{\Vert}}_{2}^{2}]< \infty $ and entropy $-\infty < -{\mathbb{E}}_{q}[\mathrm{log}(q)]< +\infty $ are well defined. We here consider the minimization problem of entropic regularized linear functional on ${\mathcal{P}}_{2}$. Let λ1, λ2 > 0 be positive real numbers and $H:{\mathbb{R}}^{p}\to \mathbb{R}$ be a bounded continuous function.

Equation (9): $\min_{q\in\mathcal{P}_2}\left\{F(q)\stackrel{\mathrm{def}}{=}\mathbb{E}_q[H(\theta)]+\lambda_1\mathbb{E}_q[\Vert\theta\Vert_2^2]+\lambda_2\mathbb{E}_q[\log q]\right\}.$

Then, we can show $q\propto \mathrm{exp}\left(-\frac{H(\theta )+{\lambda }_{1}{\Vert}\theta {{\Vert}}_{2}^{2}}{{\lambda }_{2}}\right)$ is an optimal solution of the problem (9) as follows. Clearly, $q\in {\mathcal{P}}_{2}$ and the assumption on q in lemma A holds with ${q}^{\prime }\in {\mathcal{P}}_{2}$. Hence, for $\forall {q}^{\prime }\in {\mathcal{P}}_{2}$,

Equation (10)

For the inequality we used lemma A and for the last equality we used $q\propto \mathrm{exp}\left(-\frac{H(\theta )+{\lambda }_{1}{\Vert}\theta {{\Vert}}_{2}^{2}}{{\lambda }_{2}}\right)$. Therefore, we conclude that q is a minimizer of F on ${\mathcal{P}}_{2}$ and the strong convexity of F holds at q with respect to L1(dθ)-norm. This crucial property is used in the proof of theorem 1.

A.2. Log-Sobolev and Talagrand's inequalities

The log-Sobolev inequality is useful in establishing the convergence rate of Langevin algorithm.

Definition A (log-Sobolev inequality). Let dμ = p(θ)dθ be a probability distribution with a positive smooth density p > 0 on ${\mathbb{R}}^{p}$. We say that μ satisfies the log-Sobolev inequality with constant α > 0 if for any smooth function $f:{\mathbb{R}}^{p}\to \mathbb{R}$,

$\mathbb{E}_{\mu}[f^2\log f^2]-\mathbb{E}_{\mu}[f^2]\log\mathbb{E}_{\mu}[f^2]\leqslant\frac{2}{\alpha}\mathbb{E}_{\mu}[\Vert\nabla f\Vert_2^2].$

This inequality is analogous to strong convexity in optimization: let dν = q(θ)dμ be a probability distribution on ${\mathbb{R}}^{p}$ such that q is smooth and positive. Then, if μ satisfies the log-Sobolev inequality with α, it follows that

$\mathrm{K}\mathrm{L}(\nu\Vert\mu)\leqslant\frac{1}{2\alpha}\mathbb{E}_{\nu}[\Vert\nabla\log q\Vert_2^2].$

The above relation is directly obtained by setting $f=\sqrt{q}$ in the definition of log-Sobolev inequality. Note that the right-hand side is nothing else but the squared norm of functional gradient of KL(ν||μ) with respect to a transport map for ν.

It is well-known that strongly log-concave densities satisfy the LSI with a dimension-free constant (up to the spectral norm of the covariance).

Example B [55]. Let q ∝ exp(−f) be a probability density, where $f:{\mathbb{R}}^{p}\to \mathbb{R}$ is a smooth function. If there exists c > 0 such that ${\nabla }^{2}f\succcurlyeq c{I}_{p}$, then q(θ)dθ satisfies the log-Sobolev inequality with constant c.

In addition, the LSI is preserved under bounded perturbation, as originally shown in [43]. We also provide a proof for completeness.

Lemma B [43]. Let q(θ)dθ be a probability distribution on ${\mathbb{R}}^{p}$ satisfying the log-Sobolev inequality with a constant α. For a bounded function $B:{\mathbb{R}}^{p}\to \mathbb{R}$, we define a probability distribution qB (θ)dθ as follows:

Then, qB dθ satisfies the log-Sobolev inequality with a constant $\alpha /\mathrm{exp}(4{\Vert}B{{\Vert}}_{\infty })$.

Proof. Taking an expectation ${\mathbb{E}}_{{q}_{B}}$ of the Bregman divergence defined by a convex function x log x, for ∀a > 0,

Since the minimum is attained at $a={\mathbb{E}}_{{q}_{B}}[{f}^{2}(\theta )]$,

where we used the non-negativity of the integrand for the second inequality. □

We next introduce Talagrand's inequality.

Definition B (Talagrand's inequality). We say that a probability distribution q(θ)dθ satisfies Talagrand's inequality with a constant α > 0 if for any probability distribution q'(θ)dθ,

$\frac{\alpha}{2}W_2^2(q',q)\leqslant\mathrm{K}\mathrm{L}(q'\Vert q),$

where W2(q', q) denotes the two-Wasserstein distance between q(θ)dθ and q'(θ)dθ.

The next theorem gives a relationship between KL divergence and two-Wasserstein distance.

Theorem A [56]. If a probability distribution q(θ)dθ satisfies the log-Sobolev inequality with constant α > 0, then q(θ)dθ satisfies Talagrand's inequality with the same constant.

Appendix B.: Proof of main results

B.1. Extension of algorithm

In this section, we prove the main theorem that provides the convergence rate of the DA method. We first introduce a slight extension of PDA (algorithm 1) which incorporates two different initializations at each outer loop step. We refer to the two versions as the warm-start and the resampled update, respectively. Note that algorithm 1 in the main text only includes the warm-start update. In appendix C we provide particle complexity analysis for both updates. We remark that the benefit of the resampling strategy is the simplicity of estimating the approximation error $\vert {h}_{x}^{(t)}-{h}_{{q}^{(t)}}({x}_{t})\vert $, because ${h}_{x}^{(t)}$ is composed of i.i.d. particles and a simple concentration inequality can be applied to estimate this error.

We also extend the mean field limit (algorithm 2) to take into account the inexactness in computing ${h}_{{q}^{(t)}}({x}_{t})$. This relaxation is useful in the convergence analysis of algorithm 3 with resampling because it allows us to regard this method as an instance of the generalized DA method (algorithm 4) by setting an inexact estimate ${h}_{x}^{(t)}={h}_{{\tilde{{\Theta}}}^{(t)}}({x}_{t})$, instead of the exact value of ${h}_{{q}^{(t)}}({x}_{t})$, which is actually used to define the potential for which the Langevin algorithm is run in algorithm 3. This means that the convergence analysis of algorithm 4 (theorem B) immediately provides a convergence guarantee for algorithm 3 if the discretization error $\vert {h}_{x}^{(t)}-{h}_{{q}^{(t)}}({x}_{t})\vert $ can be estimated (as in the resampling scheme).

Algorithm 3. PDA (general version).

 Input: data distribution $\mathcal{D}$, initial density q(1), number of outer-iterations T, learning rates ${\left\{{\eta }_{t}\right\}}_{t=1}^{T}$, number of inner-iterations ${\left\{{T}_{t}\right\}}_{t=1}^{T}$
 Randomly draw i.i.d. initial parameters ${\tilde{\theta }}_{r}^{(1)}\sim {q}^{(1)}(\theta )\mathrm{d}\theta \ (r\in \left\{1,2,\dots ,M\right\})$
 ${\tilde{{\Theta}}}^{(1)}{\leftarrow}{\left\{{\tilde{\theta }}_{r}^{(1)}\right\}}_{r=1}^{M}$
 for t = 1 to T do
  Randomly draw a data point (xt , yt ) from $\mathcal{D}$
  Either ${{\Theta}}^{(1)}={\left\{{\theta }_{r}^{(1)}\right\}}_{r=1}^{M}{\leftarrow}{\tilde{{\Theta}}}^{(t)}$ (warm-start)
  Or randomly initialize Θ(1) from q(1)(θ)dθ (resampling)
  for k = 1 to Tt do
   Run an inexact noisy gradient descent for r ∈ {1, 2, ..., M}:
   ${\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)}){\leftarrow}\frac{2}{{\lambda }_{2}(t+2)(t+1)}{\sum }_{s=1}^{t}\;s{\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(s)}}({x}_{s}),{y}_{s}){\partial }_{\theta }h({\theta }_{r}^{(k)},{x}_{s})+\frac{2{\lambda }_{1}t}{{\lambda }_{2}(t+2)}{\theta }_{r}^{(k)}$
   ${\theta }_{r}^{(k+1)}{\leftarrow}{\theta }_{r}^{(k)}-{\eta }_{t}{\nabla }_{\theta }{\bar{g}}^{(t)}({\theta }_{r}^{(k)})+\sqrt{2{\eta }_{t}}{\zeta }_{r}^{(k)}$ (i.i.d. Gaussian noise ${\zeta }_{r}^{(k)}\sim \mathcal{N}(0,{I}_{p})$)
  end for
  ${\tilde{{\Theta}}}^{(t+1)}{\leftarrow}{{\Theta}}^{({T}_{t}+1)}={\left\{{\theta }_{r}^{({T}_{t}+1)}\right\}}_{r=1}^{M}$
 end for
 Randomly pick t ∈ {2, 3, ..., T + 1} with probability $\mathbb{P}[t]=\frac{2t}{T(T+3)}$ and return ${h}_{{\tilde{{\Theta}}}^{(t)}}$

Algorithm 4. DA (general version).

 Input: data distribution $\mathcal{D}$ and initial density q(1)
 for t = 1 to T do
  Randomly draw a data point (xt , yt ) from $\mathcal{D}$
  Compute an approximation ${h}_{x}^{(t)}$ of ${h}_{{q}^{(t)}}({x}_{t})$
  ${g}^{(t)}{\leftarrow}{\partial }_{z}\ell ({h}_{x}^{(t)},{y}_{t})h(\cdot ,{x}_{t})+{\lambda }_{1}{\Vert}\cdot {{\Vert}}_{2}^{2}$
  Obtain an approximation q(t+1) of the density function ${q}_{\ast }^{(t+1)}\propto \mathrm{exp}\left(-\frac{{\sum }_{s=1}^{t}\;2s{g}^{(s)}}{{\lambda }_{2}(t+2)(t+1)}\right)$
 end for
 Randomly pick t ∈ {2, 3, ..., T + 1} with probability $\mathbb{P}[t]=\frac{2t}{T(T+3)}$ and return ${h}_{{q}^{(t)}}$

On the other hand, the convergence analysis of the warm-start scheme requires the convergence of the mean field limit due to certain technical difficulties; that is, we show the convergence of algorithm 3 with warm-start by coupling the update with its mean field limit (algorithm 2) and taking into account the discretization error which stems from the finite-particle approximation.

We now present a generalized version of the outer loop convergence rate of DA. We highlight the tolerance factor ε appearing in the generalized assumption (A3').

Assumption A. Let ε > 0 be a given accuracy.

  • (A1')  
    $\mathcal{Y}\subset [-1,1]$. ℓ(z, y) is a smooth convex function w.r.t. z, |∂z ℓ(z, y)| ⩽ 2 for $y,z\in \mathcal{Y}$, and ∂z ℓ(⋅, y) is one-Lipschitz continuous for $y\in \mathcal{Y}$.
  • (A2')  
    |hθ (x)| ⩽ 1 and h(θ, x) is smooth w.r.t. θ for $x\in \mathcal{X}$.
  • (A3')  
    $\mathrm{K}\mathrm{L}({q}^{(t+1)}{\Vert}{q}_{\ast }^{(t+1)})\leqslant 1/{t}^{2}$ and $\vert {h}_{x}^{(t)}-{h}_{{q}^{(t)}}({x}_{t})\vert \leqslant {\epsilon}$ for t ⩾ 1.

Remark. The new condition in (A3') allows for inexactness in computing ${h}_{{q}^{(t)}}({x}_{t})$. When showing only the convergence of algorithm 2, which is the exact mean-field limit, the original assumptions (A1), (A2), and (A3) are sufficient; in other words, we can take ε = 0, and the Lipschitz continuity of ∂z ℓ(⋅, y) in (A1') can be relaxed.

Theorem B (convergence of general DA). Under assumptions (A1'), (A2'), and (A3') with ε ⩾ 0, for arbitrary ${q}_{\ast }\in {\mathcal{P}}_{2}$, the iterates of the general DA method (algorithm 4) satisfy

where the expectation $\mathbb{E}[\mathcal{L}({q}^{(t)})]$ is taken with respect to the history of examples.

Notation. In the proofs, we use the following notations which are consistent with the description of algorithms 3 and 4:

When considering the resampling scheme, ${h}_{x}^{(t)}$ is set to the approximation ${h}_{{\tilde{{\Theta}}}^{(t)}}({x}_{t})$, whereas when considering the warm-start scheme, ${h}_{x}^{(t)}$ is set to ${h}_{{q}^{(t)}}({x}_{t})$ in the mean field limit M → ∞ and without tolerance (ε = 0).

B.2. Auxiliary lemmas

We introduce several auxiliary results used in the proof of theorem 1 (theorem B) and corollary 1. The following lemma provides a tail bound for chi-squared variables [57].

Lemma C (tail bound for chi-squared variable). Let $\theta \sim \mathcal{N}(0,{\sigma }^{2}{I}_{p})$ be a Gaussian random variable on ${\mathbb{R}}^{p}$. Then, we get for ∀c2,

Based on lemma C, we get the following bound.

Lemma D. Let $\theta \sim \mathcal{N}(0,{\sigma }^{2}{I}_{p})$ be Gaussian random variable on ${\Theta}={\mathbb{R}}^{p}$. Then, we get for ∀R2,

where $Z=\int \mathrm{exp}\left(-\frac{{\Vert}\theta {{\Vert}}_{2}^{2}}{2{\sigma }^{2}}\right)\mathrm{d}\theta $.

Proof. We set $p(\theta )=\mathrm{exp}(-{\Vert}\theta {{\Vert}}_{2}^{2}/2{\sigma }^{2})/Z$. Then,

Proposition A (continuity). Let ${q}_{\ast }(\theta )\propto \mathrm{exp}\left(-H(\theta )-\lambda {\Vert}\theta {{\Vert}}_{2}^{2}\right)\ (\lambda > 0)$ be a density on ${\mathbb{R}}^{p}$ such that ${\Vert}H{{\Vert}}_{\infty }\leqslant c$. Then, for ∀δ > 0 and a density $\forall q\in {\mathcal{P}}_{2}$,

Proof. Let γ be an optimal coupling between qdθ and q*dθ. Using Young's inequality, we have

Equation (11)

The last term can be bounded as follows:

Equation (12)

where the last equality comes from the variance of Gaussian distribution.

From (11) and (12),

From the symmetry of (11), and applying (11) again with (12),

From lemma B and example B, we see q* satisfies the log-Sobolev inequality with a constant 2λ/exp(4c). As a result, q* satisfies Talagrand's inequality with the same constant from theorem A. Hence, by combining the above two inequalities, we have

Therefore, we know that

where we used Pinsker's inequality for the last inequality. This finishes the proof. □

Proposition B (maximum entropy). Let ${q}_{\ast }(\theta )\propto \mathrm{exp}\left(-H(\theta )-\lambda {\Vert}\theta {{\Vert}}_{2}^{2}\right)\ (\lambda > 0)$ on ${\mathbb{R}}^{p}$ be a density such that ${\Vert}H{{\Vert}}_{\infty }\leqslant c$. Then,

Proof. It follows that

where we used (12) and Gaussian integral for the last inequality. □

Proposition C (boundedness of KL-divergence). Let ${q}_{\ast }(\theta )\propto \mathrm{exp}\left(-{H}_{\ast }(\theta )-{\lambda }_{\ast }{\Vert}\theta {{\Vert}}_{2}^{2}\right)({\lambda }_{\ast } > 0)$ be a density on ${\mathbb{R}}^{p}$ such that ${\Vert}{H}_{\ast }{{\Vert}}_{\infty }\leqslant {c}_{\ast }$, and ${q}_{{\sharp}}(\theta )\propto \mathrm{exp}\left(-{H}_{{\sharp}}(\theta )-{\lambda }_{{\sharp}}{\Vert}\theta {{\Vert}}_{2}^{2}\right)({\lambda }_{{\sharp}} > 0)$ be a density on ${\mathbb{R}}^{p}$ such that ${\Vert}{H}_{{\sharp}}{{\Vert}}_{\infty }\leqslant {c}_{{\sharp}}$. Then, for any density q,

Proof. Applying proposition A with δ = 1,

We next bound the first term in the last equation as follows

where for the first inequality we used a similar inequality as in (12) and for the second inequality we used the Gaussian integral. Hence, we get

Lemma E. Suppose assumptions (A1') and (A2') hold. If $\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t)})\leqslant \frac{1}{{t}^{2}}$ for t ⩾ 2, then

Proof. Recall the definition of ${g}^{(t)},{\bar{g}}^{(t)}$ and ${q}_{\ast }^{(t)}$ (see notations in subsection 2.1). We set ${\gamma }_{t+1}=\frac{{\sum }_{s=1}^{t}\;s}{{\lambda }_{2}{\sum }_{s=1}^{t+1}\;s}=\frac{t}{{\lambda }_{2}(t+2)}$. Note that for t ⩾ 1,

Equation (13)

Equation (14)

Equation (15)

Therefore, we have for t ⩾ 2 from proposition A with δ = 1/t < 1,

Moreover, we have for t ⩾ 2,

This finishes the proof. □

B.3. Outer loop complexity

Based on the auxiliary results and the convex optimization theory developed in Nesterov [16], Xiao [17], we now prove theorem B which is an extension of theorem 1.

Proof of theorem B. For t ⩾ 1 we define,

From the definition, the density ${q}_{\ast }^{(t+1)}\in {\mathcal{P}}_{2}$ calculated in algorithm 4 maximizes Vt (q). We denote ${V}_{t}^{\ast }={V}_{t}({q}_{\ast }^{(t+1)})$. Then, for t ⩾ 2, we get

Equation (16)

where for the first inequality we used the optimality of ${q}_{\ast }^{(t)}$ and the strong convexity (10) at ${q}_{\ast }^{(t)}$, and for the final inequality we used lemma E.

We set ${R}_{t}=\left(\frac{3}{2}p+15\right)\frac{{\lambda }_{2}}{{\lambda }_{1}}\,\mathrm{log}(1+t)$ and also ${\gamma }_{t+1}=\frac{{\sum }_{s=1}^{t}\;s}{{\lambda }_{2}{\sum }_{s=1}^{t+1}\;s}=\frac{t}{{\lambda }_{2}(t+2)}$, as done in the proof of lemma E.

From assumptions (A1'), (A2') and ${q}_{\ast }^{(t)}=\mathrm{exp}\left(-\frac{{\sum }_{s=1}^{t-1}\;s{g}^{(s)}}{{\lambda }_{2}{\sum }_{s=1}^{t}\;s}\right)/\int \mathrm{exp}\left(-\frac{{\sum }_{s=1}^{t-1}\;s{g}^{(s)}(\theta )}{{\lambda }_{2}{\sum }_{s=1}^{t}\;s}\right)\mathrm{d}\theta \ (t\geqslant 2)$, we have for t ⩾ 2,

Equation (17)

Using (17) and applying lemma D with ${\sigma }^{2}=\frac{1}{2{\gamma }_{t}{\lambda }_{1}},\enspace \frac{1}{2{\gamma }_{t+1}{\lambda }_{1}}$ and R = Rt , we have for t ⩾ 2,

where for the fifth inequality we used (15) and for the sixth inequality we used 15λ2/λ1Rt .

Applying Young's inequality $ab\leqslant \frac{{a}^{2}}{2\delta }+\frac{\delta {b}^{2}}{2}$ with $a=\left(2+2\left(\frac{3}{2}p+15\right){\lambda }_{2}\,\mathrm{log}(1+t)\right)$, $b={\Vert}{q}_{\ast }^{(t)}-{q}_{\ast }^{(t+1)}{{\Vert}}_{{L}_{1}(\mathrm{d}\theta )}$, and $\delta =\frac{{\lambda }_{2}}{2}(t+1)$, we get

Equation (18)

Combining (16) and (18), we have for t ⩾ 2,

Equation (19)

where we set ${\alpha }_{t}=O\left((1+\mathrm{exp}(8/{\lambda }_{2})){p}^{2}{\lambda }_{2}\,{\mathrm{log}}^{2}(1+t)\right)$.

From proposition B, (14), and (15),

meaning $e({q}_{\ast }^{(t)})\geqslant 0$. Hence,

Summing the inequality (19) over t ∈ {2, ..., T + 1},

Equation (20)

where we used ${\lambda }_{2}t\left\vert e({q}^{(t)})-e({q}_{\ast }^{(t)})\right\vert ={\alpha }_{t}$ (lemma E), 2αt = O(αt ), and $e({q}_{\ast }^{(T+2)})\geqslant 0$.

On the other hand, for $\forall {q}_{\ast }\in {\mathcal{P}}_{2}$,

Equation (21)

Using (A1'), (A2'), and (A3'), we have for any density function q,

Equation (22)

Hence, from (20)–(22), and the convexity of the loss,

Taking the expectation with respect to the history of examples, we have

B.4. Inner loop complexity

We next prove corollary 1 which gives an estimate of the inner loop iteration complexity. This result is derived by utilizing the convergence rate of the Langevin algorithm under LSI developed in [40]. We here consider the ideal algorithm 2 (i.e. warm-start and the exact mean field limit (ε = 0)).

Proof of corollary 1. We verify the assumptions required in theorem 2. We recall that ${q}_{\ast }^{(t+1)}$ takes the form of Boltzmann distribution: for t ⩾ 1,

Note that $\frac{{\lambda }_{1}}{{\lambda }_{2}}\geqslant \frac{{\lambda }_{1}t}{{\lambda }_{2}(t+2)}\geqslant \frac{{\lambda }_{1}}{3{\lambda }_{2}}\ (t\geqslant 1)$ and $\left\vert \frac{1}{{\lambda }_{2}\;{\sum }_{s=1}^{t+1}s}\;{\sum }_{s=1}^{t}\;s{\partial }_{z}\ell ({h}_{x}^{(t)},{y}_{t})h(\cdot ,{x}_{t})\right\vert \leqslant \frac{2t}{{\lambda }_{2}(t+2)}\leqslant \frac{2}{{\lambda }_{2}}$. Therefore, from example B and lemma B, we know that ${q}_{\ast }^{(t+1)}$ satisfies the log-Sobolev inequality with a constant $\frac{2{\lambda }_{1}}{3{\lambda }_{2}\,\mathrm{exp}(8/{\lambda }_{2})}$; in addition, the gradient of $\mathrm{log}({q}_{\ast }^{(t+1)})$ is $\frac{2}{{\lambda }_{2}}(1+{\lambda }_{1})$-Lipschitz continuous. Therefore, from theorem 2 we deduce that Langevin algorithm with learning rate ${\eta }_{t}\leqslant \frac{{\lambda }_{1}{\lambda }_{2}{\delta }_{t+1}}{96p{(1+{\lambda }_{1})}^{2}\mathrm{exp}(8/{\lambda }_{2})}$ yields qt+1 satisfying $\mathrm{K}\mathrm{L}({q}^{(t+1)}{\Vert}{q}_{\ast }^{(t+1)})\leqslant {\delta }_{t+1}$ within $\frac{3{\lambda }_{2}\mathrm{exp}(8/{\lambda }_{2})}{2{\lambda }_{1}{\eta }_{t}}\,\mathrm{log}\,\frac{2\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t+1)})}{{\delta }_{t+1}}$-iterations.

We next bound $\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t+1)})$. We apply proposition C with q = q(t), ${q}_{\ast }={q}_{\ast }^{(t+1)}$, and ${q}_{{\sharp}}={q}_{\ast }^{(t)}$. Note that in this setting, the constants ${c}_{\ast }$, ${c}_{\sharp }$, ${\lambda }_{\ast }$, and ${\lambda }_{\sharp }$ satisfy

Then, we get

Hence, we can conclude that $\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t+1)})$ is uniformly bounded with respect to t ∈ {1, ..., T} as long as $\mathrm{K}\mathrm{L}({q}^{(t)}{\Vert}{q}_{\ast }^{(t)})\leqslant {\delta }_{t}$ and q(1) is a Gaussian distribution. □

Case of resampling. We note that for the resampling scheme, a similar inner loop complexity of $O\left(\frac{{\lambda }_{2}\,\mathrm{exp}(8/{\lambda }_{2})}{{\lambda }_{1}{\eta }_{t}}\,\mathrm{log}\,\frac{2\mathrm{K}\mathrm{L}({q}^{(1)}{\Vert}{q}_{\ast }^{(t+1)})}{{\delta }_{t+1}}\right)$ can be immediately obtained by replacing the initial distribution of the Langevin algorithm with q(1)(θ)dθ. Moreover, the uniform boundedness of $\mathrm{K}\mathrm{L}({q}^{(1)}{\Vert}{q}_{\ast }^{(t+1)})$ with respect to t is also guaranteed by applying proposition C with $q={q}_{\sharp }={q}^{(1)}$ and ${q}_{\ast }={q}_{\ast }^{(t+1)}$, as long as q(1)(θ)dθ is a Gaussian distribution.
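To make the dependence on the hyperparameters concrete, the following short Python snippet evaluates the step-size bound and iteration count stated in the proof of corollary 1. The numerical values of λ1, λ2, p, δ and the initial KL divergence below are purely illustrative assumptions, not the settings used elsewhere in the paper.

import math

# Step-size bound and inner-loop iteration count from the proof of corollary 1.
# All numerical values below are illustrative assumptions.
lam1, lam2 = 0.1, 1.0      # regularization parameters (assumed)
p = 10                     # parameter dimension (assumed)
delta = 1e-2               # target KL accuracy delta_{t+1} (assumed)
kl_init = 1.0              # KL(q^{(t)} || q_*^{(t+1)}) at warm start (assumed)

# eta_t <= lam1 * lam2 * delta_{t+1} / (96 p (1 + lam1)^2 exp(8 / lam2))
eta_max = lam1 * lam2 * delta / (96 * p * (1 + lam1) ** 2 * math.exp(8 / lam2))

# iterations >= 3 lam2 exp(8 / lam2) / (2 lam1 eta_t) * log(2 KL / delta_{t+1})
iters = 3 * lam2 * math.exp(8 / lam2) / (2 * lam1 * eta_max) * math.log(2 * kl_init / delta)

print(f"largest admissible eta_t: {eta_max:.3e}")
print(f"inner-loop iterations:    {iters:.3e}")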

Additional results and discussions

Appendix C.: Discretization error of finite particles

C.1. Case of resampling

As discussed in subsection 2.1, to establish the finite-particle convergence guarantees of algorithm 3 with resampling up to O(ε) error, we need to show that ${h}_{x}^{(t)}={h}_{{\tilde{{\Theta}}}^{(t)}}({x}_{t})$ satisfies the condition $\vert {h}_{x}^{(t)}-{h}_{{q}^{(t)}}({x}_{t})\vert \leqslant {\epsilon}$ in (A3'). Hence, we are interested in characterizing the discretization error that stems from using finitely many particles.

For the resampling scheme, we can easily derive that the required number of particles is $O({{\epsilon}}^{-2}\,\mathrm{log}(T/\delta ))$ with probability at least 1 − δ, because i.i.d. particles are obtained by the Langevin algorithm and Hoeffding's inequality is applicable.

Lemma F (Hoeffding's inequality). Let Z, Z1, ..., Zm be i.i.d. random variables taking values in [−a, a] for a > 0. Then, for any ρ > 0, we get
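In its standard form the bound reads

$$\mathbb{P}\left(\left\vert \frac{1}{m}\sum_{j=1}^{m}{Z}_{j}-\mathbb{E}[Z]\right\vert \geqslant \rho \right)\leqslant 2\,\mathrm{exp}\left(-\frac{m{\rho }^{2}}{2{a}^{2}}\right).$$

As a sketch of how this yields the particle complexity stated above: taking a = 1 (using the boundedness $\vert h(\theta ,x)\vert \leqslant 1$ assumed elsewhere in this appendix), ρ = ε, and a union bound over the T prediction events, it suffices that $2T\,\mathrm{exp}(-M{{\epsilon}}^{2}/2)\leqslant \delta $, i.e. $M\geqslant 2{{\epsilon}}^{-2}\,\mathrm{log}(2T/\delta )=O({{\epsilon}}^{-2}\,\mathrm{log}(T/\delta ))$.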

C.2. Case of warm-start

We next consider the warm-start scheme. Note that the convergence of PDA with warm-start is guaranteed by coupling it with its mean-field limit as M → ∞ and applying theorem 1 without tolerance (i.e. ε = 0). To analyze the particle complexity, we make an additional assumption regarding the regularity of the loss function and the model.

Assumption B. 

  • (A5)  
    h(⋅, x) is one-Lipschitz continuous 9 for $\forall x\in \mathcal{X}$.

Remark. The above regularity assumption is common in the literature and covers many important problem settings in the optimization of two-layer neural networks in the mean field regime. Indeed, (A5) is satisfied for the two-layer network in example 1 when the output or input layer is fixed and the activation function is Lipschitz continuous.

The following proposition shows the convergence of algorithm 1 to algorithm 2 as M → ∞.

Proposition D (finite particle approximation). For training examples ${\left\{{x}_{t}\right\}}_{t=1}^{T}$ and any example $\tilde{x}$, define

Under (A1'), (A2), (A4), and (A5), if we run PDA (algorithm 1) on $\tilde{{\Theta}}$ and the corresponding mean field limit DA (algorithm 2) on q, then ${\mathrm{lim}}_{M\to \infty }\,{\rho }_{T,M}=0$ with high probability. Moreover, if we set ${\eta }_{t}\leqslant \frac{{\lambda }_{2}}{2{\lambda }_{1}}$, ${\lambda }_{1}\geqslant \frac{3}{2}$, and ${T}_{t}\geqslant \frac{3{\lambda }_{2}\,\mathrm{log}\left(4\right)}{(2{\lambda }_{1}-1){\eta }_{t}}$, then with probability at least 1 − δ,

Remark. Proposition D and corollary 2 together imply that under appropriate regularization, a prediction on any point with an ε-gap from an ε-accurate solution of the regularized objective (4) can be achieved with high probability by running PDA with warm-start (algorithm 1) in poly(ε−1) steps using poly(ε−1) particles, where we omit dependence on hyperparameters and logarithmic factors. Note that the specific choices of hyperparameters in proposition D are consistent with those in corollary 2. We also remark that under weak regularization (vanishing λ1), our current derivation suggests that the required particle size could be exponential in the time horizon, due to the particle correlation in the warm-start scheme. Finally, we remark that for empirical risk minimization, the term log(2(T + 1)2/δ) could be changed to log(2n(T + 1)/δ) in the obvious way.

Proof of proposition D. We analyze the error of the finite particle approximation for a fixed history of data ${\left\{{x}_{t}\right\}}_{t=1}^{T}$. To compare algorithm 2 with the corresponding particle dynamics (algorithm 1), we construct a semi-PDA update, which is intermediate between these two algorithms. In particular, the semi-PDA method is defined by replacing ${h}_{{\tilde{{\Theta}}}^{(t)}}$ in algorithm 1 with ${h}_{{q}^{(t)}}$ for q(t) in algorithm 2. Let ${\tilde{{\Theta}}}^{\prime (t)}={\left\{{\tilde{\theta }}_{r}^{\prime (t)}\right\}}_{r=1}^{M}$ be the parameters obtained in the outer loop of the semi-PDA. We first estimate the gap between algorithm 2 and the semi-PDA.

Note that there is no interaction among ${\tilde{{\Theta}}}^{\prime (t)}$; in other words, these are i.i.d. particles sampled from q(t), and we can thus apply Hoeffding's inequality (lemma F) to ${h}_{{\tilde{{\Theta}}}^{\prime (t)}}(\tilde{x})$ and ${h}_{{\tilde{{\Theta}}}^{\prime (t)}}({x}_{s})\ (s\in \left\{1,\dots ,T\right\},t\in \left\{1,\dots ,T+1\right\})$. Hence, for ∀δ > 0, ∀s ∈ {1, ..., T}, and ∀t ∈ {1, ..., T + 1}, with probability at least 1 − δ

Equation (23)

Equation (24)

We next bound the gap between the semi-PDA and algorithm 1, which share the same history of Gaussian noises and the same initial particles; that is, ${\tilde{\theta }}_{r}^{(1)}={\tilde{\theta }}_{r}^{\prime (1)}$. Let ${{\Theta}}^{(k)}={\left\{{\theta }_{r}^{(k)}\right\}}_{r=1}^{M}$ and ${{\Theta}}^{\prime (k)}={\left\{{\theta }_{r}^{\prime (k)}\right\}}_{r=1}^{M}$ denote inner iterations of these methods.

  • (a)  
    Here we show the first statement of the proposition. We set ρ1 = 0 and ${\bar{\rho }}_{1}=0$. We define ρt and ${\bar{\rho }}_{t}$ recursively as follows
    Equation (25)
    and ${\bar{\rho }}_{t+1}={\mathrm{max}}_{s\in \left\{1,\dots ,t+1\right\}}{\rho }_{s}$. We show that for any event where (23) and (24) hold, ${\Vert}{\tilde{\theta }}_{r}^{(t)}-{\tilde{\theta }}_{r}^{\prime (t)}{{\Vert}}_{2}\leqslant {\rho }_{t}\ (\forall t\in \left\{1,\dots ,T+1\right\},\ \forall r\in \left\{1,\dots ,M\right\})$ by induction. Suppose ${\Vert}{\tilde{\theta }}_{r}^{(s)}-{\tilde{\theta }}_{r}^{\prime (s)}{{\Vert}}_{2}\leqslant {\rho }_{s}\ (\forall s\in \left\{1,\dots ,t\right\},\ \forall r\in \left\{1,\dots ,M\right\})$ holds. Then, for any x and s ∈ {1, ..., t}
    Equation (26)
    Consider the inner loop at the t-th outer step. Then, for any event where (23) holds,
    Expanding this inequality,
    Hence, ${\Vert}{\tilde{\theta }}_{r}^{(t)}-{\tilde{\theta }}_{r}^{\prime (t)}{{\Vert}}_{2}\leqslant {\bar{\rho }}_{T+1}$ for ∀t ∈ {1, ..., T + 1}. Noting that ${\bar{\rho }}_{1}=0$ and
    we see ${\bar{\rho }}_{T+1}\to 0$ as M → +∞. Then, the proof is finished because for ∀t ∈ {1, ..., T + 1} and ∀s ∈ {1, ..., T}, with probability at least 1 − δ,
  • (b)  
    We next show the second statement of the proposition. We change the definition (25) of ρt+1 as follows:
    We prove that for any event where (23) and (24) hold, ${\Vert}{\tilde{\theta }}_{r}^{(t)}-{\tilde{\theta }}_{r}^{\prime (t)}{{\Vert}}_{2}\leqslant {\rho }_{t}\ (\forall t\in \left\{1,\dots ,T+1\right\},\ \forall r\in \left\{1,\dots ,M\right\})$ by induction. Suppose ${\Vert}{\tilde{\theta }}_{r}^{(s)}-{\tilde{\theta }}_{r}^{\prime (s)}{{\Vert}}_{2}\leqslant {\rho }_{s}\ (\forall s\in \left\{1,\dots ,t\right\},\ \forall r\in \left\{1,\dots ,M\right\})$ holds. Consider the inner loop at the t-th outer step. Note that ${\eta }_{t}\leqslant \frac{{\lambda }_{2}}{2{\lambda }_{1}}$ implies $1-\frac{2{\lambda }_{1}t{\eta }_{t}}{{\lambda }_{2}(t+2)} > 0$. Therefore, by a similar argument to the above, we get
    Expanding this inequality,
    where we used $0< 1+\frac{(1-2{\lambda }_{1})t{\eta }_{t}}{{\lambda }_{2}(t+2)}< 1$ and ${\lambda }_{1}\geqslant \frac{3}{2}$. Noting that ${(1-x)}^{1/x}\leqslant \mathrm{exp}(-1)$ for ∀x ∈ (0, 1], we see that
    where we used ${T}_{t}\geqslant \frac{3{\lambda }_{2}\,\mathrm{log}\left(4\right)}{(2{\lambda }_{1}-1){\eta }_{t}}$. Hence, we know that for any t,
    Equation (27)
    This means that ${\Vert}{\tilde{\theta }}_{r}^{(t+1)}-{\tilde{\theta }}_{r}^{\prime (t+1)}{{\Vert}}_{2}\leqslant {\rho }_{t+1}$ and finishes the induction. Next, we show
    Equation (28)
    This inequality obviously holds for t = 1 because ${\bar{\rho }}_{1}=0$. We suppose it is true for t ⩽ T. Then,
    Hence, the inequality (28) holds for ∀t ∈ {1, ..., T + 1}, yielding
    In summary, it follows that for ∀t ∈ {1, ..., T + 1} and ∀s ∈ {1, ..., T}, with probability at least 1 − δ,
    where we used (26). This completes the proof. □

Appendix D.: Generalization bounds for empirical risk minimization

In this section, we give generalization bounds for the problem (3) in the context of empirical risk minimization, by using techniques developed by Chen et al [24]. We consider the smoothed hinge loss and squared loss for binary classification and regression problems, respectively.

D.1. Auxiliary results

For a set $\mathcal{F}$ of functions from a space $\mathcal{Z}$ to $\mathbb{R}$ and a set $S={\left\{{z}_{i}\right\}}_{i=1}^{n}\subset \mathcal{Z}$, the empirical Rademacher complexity ${\hat{\mathfrak{R}}}_{S}(\mathcal{F})$ is defined as follows:
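(We state the standard definition for completeness.)

$${\hat{\mathfrak{R}}}_{S}(\mathcal{F})={\mathbb{E}}_{\sigma }\left[\underset{f\in \mathcal{F}}{\mathrm{sup}}\,\frac{1}{n}\sum_{i=1}^{n}{\sigma }_{i}f({z}_{i})\right],$$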

where $\sigma ={({\sigma }_{i})}_{i=1}^{n}$ are i.i.d. random variables taking −1 or 1 with equal probability.

We introduce the uniform bound using the empirical Rademacher complexity (see Mohri et al [58]).

Lemma G (uniform bound). Let $\mathcal{F}$ be a set of functions from $\mathcal{Z}$ to $[-C,C]\ (C\in \mathbb{R})$ and $\mathcal{D}$ be a distribution over $\mathcal{Z}$. Let $S={\left\{{z}_{i}\right\}}_{i=1}^{n}\subset \mathcal{Z}$ be a set of size n drawn from $\mathcal{D}$. Then, for any δ ∈ (0, 1), with probability at least 1 − δ over the choice of S, we have

The contraction lemma (see Shalev-Shwartz and Ben-David [59]) is useful in estimating the Rademacher complexity.

Lemma H (contraction lemma). Let ${\phi }_{i}:\mathbb{R}\to \mathbb{R}\ (i\in \left\{1,\dots ,n\right\})$ be ρ-Lipschitz functions and $\mathcal{F}$ be a set of functions from $\mathcal{Z}$ to $\mathbb{R}$. Then it follows that for any ${\left\{{z}_{i}\right\}}_{i=1}^{n}\subset \mathcal{Z}$,

Let p0(θ)dθ be the distribution proportional to $\mathrm{exp}\left(-\frac{{\lambda }_{1}}{{\lambda }_{2}}{\Vert}\theta {{\Vert}}_{2}^{2}\right)\mathrm{d}\theta $. We define a family of mean field neural networks as follows: for R > 0,
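(From lemma I and the way this class is used in the proofs below, we take the definition to be the KL-ball of mean field networks.)

$${\mathcal{F}}_{\mathrm{K}\mathrm{L}}(R)=\left\{{h}_{q}\;:\;q\in {\mathcal{P}}_{2},\ \mathrm{K}\mathrm{L}(q{\Vert}{p}_{0})\leqslant R\right\}.$$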

The Rademacher complexity of this function class is obtained by Chen et al [24].

Lemma I (Chen et al [24]). Suppose |hθ (x)| ⩽ 1 holds for ∀θ ∈ Ω and $\forall x\in \mathcal{X}$. We have for any constant $R\leqslant \frac{1}{2}$ and set $S\subset \mathcal{X}$ of size n,

D.2. Generalization bound on the binary classification problems

We here give a generalization bound for binary classification problems; we suppose $\mathcal{Y}=\left\{-1,1\right\}$ and consider the problem (3) with the smoothed hinge loss defined below

We also define the 0–1 loss as ${\ell }_{01}(z,y)=\mathbb{1}[zy< 0]$.

Theorem C. Let $\mathcal{D}$ be a distribution over $\mathcal{X}\times \mathcal{Y}$. Suppose there exists a true distribution $q{}^{\circ}\in {\mathcal{P}}_{2}$ satisfying hq°(x)y ⩾ 1/2 for $\forall (x,y)\in \mathrm{s}\mathrm{u}\mathrm{p}\mathrm{p}(\mathcal{D})$ and KL(q°||p0) ⩽ 1/2. Let $S={\left\{({x}_{i},{y}_{i})\right\}}_{i=1}^{n}$ be training examples independently sampled from $\mathcal{D}$. Suppose |hθ (x)| ⩽ 1 holds for $\forall (\theta ,x)\in {\Omega}\times \mathcal{X}$. Then, for the minimizer ${q}_{\ast }\in {\mathcal{P}}_{2}$ of the problem (3), it follows that with probability at least 1 − δ over the choice of S,

Proof. We first estimate a radius R such that ${q}_{\ast }\in {\mathcal{F}}_{\mathrm{K}\mathrm{L}}(R)$. Note that the regularization term of the objective $\mathcal{L}(q)$ is λ2KL(q||p0) and that $\ell ({h}_{q{}^{\circ}}({x}_{i}),{y}_{i})=0$ from the assumption on q° and the definition of the smoothed hinge loss. Since $\mathcal{L}({q}_{\ast })\leqslant \mathcal{L}(q{}^{\circ})$, we get

Equation (29)

Equation (30)
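Since the empirical loss of q° vanishes, these two inequalities read, in the form we use below,

$$\frac{1}{n}\sum_{i=1}^{n}\ell ({h}_{{q}_{\ast }}({x}_{i}),{y}_{i})+{\lambda }_{2}\,\mathrm{K}\mathrm{L}({q}_{\ast }{\Vert}{p}_{0})\leqslant {\lambda }_{2}\,\mathrm{K}\mathrm{L}(q{}^{\circ}{\Vert}{p}_{0}),\qquad \mathrm{K}\mathrm{L}({q}_{\ast }{\Vert}{p}_{0})\leqslant \mathrm{K}\mathrm{L}(q{}^{\circ}{\Vert}{p}_{0})\leqslant \frac{1}{2}.$$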

In particular, setting R = KL(q°||p0), we see ${q}_{\ast }\in {\mathcal{F}}_{\mathrm{K}\mathrm{L}}(R)$.

We next define the set of composite functions of loss and mean field neural networks as follows:

Equation (31)

Since $\ell (z,y)$ is four-Lipschitz continuous with respect to z, we can estimate the Rademacher complexity ${\hat{\mathfrak{R}}}_{S}(\mathcal{F}(R))$ by using lemma H with ${\phi }_{i}(\cdot )=\ell (\cdot ,{y}_{i})$ as follows:

Equation (32)

where we used lemma I for the last inequality.

From the boundedness assumption on hq , we have $0\leqslant \ell ({h}_{q}(x),y)\leqslant 5$ for $\forall q\in {\mathcal{P}}_{2}$. Applying lemma G with $\mathcal{F}=\mathcal{F}(R)$, we have with probability at least 1 − δ,

where we used ${\ell }_{01}(z,y)\leqslant \ell (z,y)$, (30), and (32). □

This theorem results in the following corollary:

Corollary A. Suppose the same assumptions as in theorem C hold. Moreover, we set ${\lambda }_{1}=\lambda /\sqrt{n}\ (\lambda > 0)$ and ${\lambda }_{2}=1/\sqrt{n}$. Then, the following bound holds with probability at least 1 − δ over the choice of training examples,

where ${p}_{0}^{\prime }$ is the Gaussian distribution proportional to $\mathrm{exp}(-\lambda {\Vert}\cdot {{\Vert}}_{2}^{2})$.

D.3. Generalization bound on the regression problem

We here give a generalization bound for regression problems. We consider the squared loss $\ell (z,y)=\frac{1}{2}{(z-y)}^{2}$ and bounded labels $\mathcal{Y}\subset [-1,1]$.

Theorem D. Let $\mathcal{D}$ be a distribution over $\mathcal{X}\times \mathcal{Y}$. Suppose there exists a true distribution $q{}^{\circ}\in {\mathcal{P}}_{2}$ satisfying y = hq°(x) for $\forall (x,y)\in \mathrm{s}\mathrm{u}\mathrm{p}\mathrm{p}(\mathcal{D})$ and KL(q°||p0) ⩽ 1/2. Let $S={\left\{({x}_{i},{y}_{i})\right\}}_{i=1}^{n}$ be training examples independently sampled from $\mathcal{D}$. Suppose |hθ (x)| ⩽ 1 holds for $\forall (\theta ,x)\in {\Omega}\times \mathcal{X}$. Then, for the minimizer ${q}_{\ast }\in {\mathcal{P}}_{2}$ of the problem (3), it follows that with probability at least 1 − δ over the choice of S,

Proof. The proof is very similar to that of theorem C. Note that $\ell ({h}_{q{}^{\circ}}({x}_{i}),{y}_{i})=0$ from the assumption on q° and that inequalities (29) and (30) hold in this case too. Hence, setting R = KL(q°||p0), we see ${q}_{\ast }\in {\mathcal{F}}_{\mathrm{K}\mathrm{L}}(R)$.

Since $\ell (z,y)$ is two-Lipschitz continuous with respect to z ∈ [−1, 1] for any $y\in \mathcal{Y}\subset [-1,1]$, we can estimate the Rademacher complexity ${\hat{\mathfrak{R}}}_{S}(\mathcal{F}(R))$ of $\mathcal{F}(R)$ (defined in (31)) in the same way as in theorem C:

Equation (33)

From the boundedness assumption on hq and $\mathcal{Y}$, we have $0\leqslant \ell ({h}_{q}(x),y)\leqslant 2$ for $\forall q\in {\mathcal{P}}_{2}$. Hence, applying lemma G with $\mathcal{F}=\mathcal{F}(R)$, we have with probability at least 1 − δ,

where we used (30) and (33). □

This theorem results in the following corollary:

Corollary B. Suppose the same assumptions as in theorem D hold. Moreover, we set ${\lambda }_{1}=\lambda /\sqrt{n}\ (\lambda > 0)$ and ${\lambda }_{2}=1/\sqrt{n}$. Then, the following bound holds with probability at least 1 − δ over the choice of training examples,

where ${p}_{0}^{\prime }$ is the Gaussian distribution proportional to $\mathrm{exp}(-\lambda {\Vert}\cdot {{\Vert}}_{2}^{2})$.

Appendix E.: Additional discussions

E.1. Efficient implementation of PDA

Note that similar to SGD, algorithm 1 only requires gradient queries (and additional Gaussian noise); in particular, a weighted average ${\bar{g}}^{(t)}$ of functions g(t) is updated and its derivative with respect to parameters is calculated. In the case of empirical risk minimization, this procedure can be implemented as follows. We use ${\left\{{w}_{i}\right\}}_{i=1}^{n}$ (initialized as zeros) to store the weighted sums of ${\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(t)}}({x}_{{i}_{t}}),{y}_{{i}_{t}})$. At step t in the outer loop, ${w}_{{i}_{t}}$ is updated as

The average ${\nabla }_{{\theta }_{r}}{\bar{g}}^{(t)}({{\Theta}}^{(k)})$ can then be computed as

where we use ${\left\{{\theta }_{r}^{(k)}\right\}}_{r=1}^{M}$ to denote the parameters Θ(k) at step k of the inner loop. This formulation makes algorithm 1 straightforward to implement.
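A minimal Python sketch of this bookkeeping is given below. It is only meant to illustrate the data structures (the per-example weights wi and the inner Langevin loop); the exact s-weighting and the λ2-scaled normalizer are our reading of algorithm 1 and should be treated as assumptions, and the single-neuron model h, its gradient, and the squared loss are hypothetical placeholders.

import numpy as np

def h(theta, x):
    # hypothetical component model: a single tanh neuron
    return np.tanh(theta @ x)

def grad_h(theta, x):
    return (1.0 - np.tanh(theta @ x) ** 2) * x

def dloss(z, y):
    # derivative of the squared loss 0.5 * (z - y)^2 with respect to z
    return z - y

def pda_outer_step(t, Theta, w, X, Y, i_t, lam1, lam2, eta, n_inner, rng):
    """One outer step of PDA (algorithm 1) for empirical risk minimization.
    Theta: (M, d) particles; w: length-n array of stored weights (initialized to zero)."""
    M, d = Theta.shape
    # finite-particle prediction on the sampled example, then accumulate the s-weighted derivative
    z = np.mean([h(th, X[i_t]) for th in Theta])
    w[i_t] += t * dloss(z, Y[i_t])
    # assumed normalizer lambda_2 * sum_{s=1}^{t+1} s, matching the averaged potential in the text
    Z_t = lam2 * (t + 1) * (t + 2) / 2
    for _ in range(n_inner):                      # inner Langevin loop on the fixed potential
        for r in range(M):
            grad = sum(w[i] * grad_h(Theta[r], X[i]) for i in range(len(w)) if w[i] != 0.0)
            grad += 2.0 * lam1 * (t * (t + 1) / 2) * Theta[r]
            Theta[r] = Theta[r] - eta * grad / Z_t + np.sqrt(2.0 * eta) * rng.standard_normal(d)
    return Theta, w

For instance, one would initialize Theta = rng.standard_normal((M, d)) and w = np.zeros(n), drawing i_t uniformly from {0, ..., n − 1} at each outer step.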

In addition, the PDA algorithm can also be implemented with mini-batch updates, in which a set of data indices It = {it,1, ..., it,b } ⊂ {1, 2, ..., n} is selected per outer loop step instead of a single index it . Due to the reduced variance, mini-batch updates can stabilize the algorithm and lead to faster convergence. Our theoretical results in the sequel trivially extend to the mini-batch setting.

E.2. Extension to multi-class classification

We give a natural extension of the PDA method to multi-class classification settings. Let $\mathcal{C}$ denote the finite set of all class labels and $\vert \mathcal{C}\vert $ denote its cardinality. For multi-class classification problems, we define a component h(θ, x) of an ensemble as follows. Let ${a}_{r}\in {\mathbb{R}}^{\vert \mathcal{C}\vert }$ and ${b}_{r}\in {\mathbb{R}}^{d}\ (r\in \left\{1,\dots ,M\right\})$ be parameters for the output and input layers, respectively, and set θr = (ar , br ) and ${\Theta}={\left\{{\theta }_{r}\right\}}_{r=1}^{M}$. Then, we define ${h}_{{\theta }_{r}}(x)=h({\theta }_{r},x)={\sigma }_{2}({a}_{r}{\sigma }_{1}({b}_{r}^{\top }x))$, which is a neural network with one hidden neuron, 10 and denote

Note that hΘ(x) is a natural two-layer neural network with multiple outputs. Suppose that each parameter θr follows q(θ)dθ. Then the mean field limit can be defined as

Let $\ell (z,y)\ (z={\left\{{z}_{y}\right\}}_{y\in \mathcal{C}}\in {\mathbb{R}}^{\vert \mathcal{C}\vert },y\in \mathcal{C})$ be the loss for multi-class classification problems. A typical choice is the cross-entropy loss with the soft-max activation, that is

In this case, the functional derivative of $\ell ({h}_{q}(x),y)$ with respect to q is
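(By the chain rule for the linear mean field model, with the soft-max partial derivatives written out; recorded here as a sketch.)

$$\frac{\delta \ell ({h}_{q}(x),y)}{\delta q}(\theta )=\sum_{{y}^{\prime }\in \mathcal{C}}{\partial }_{{z}_{{y}^{\prime }}}\ell ({h}_{q}(x),y)\,{h}_{{y}^{\prime }}(\theta ,x),\qquad {\partial }_{{z}_{{y}^{\prime }}}\ell (z,y)=\frac{\mathrm{exp}({z}_{{y}^{\prime }})}{{\sum }_{c\in \mathcal{C}}\mathrm{exp}({z}_{c})}-\mathbb{1}[{y}^{\prime }=y],$$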

where we supposed the outputs of hθ and hq are also indexed by $\mathcal{C}$. Hence, the counterpart of g(t) in algorithm 2 in this setting is

Using this function, the DA method for multi-class classification problems can be obtained in the same manner as algorithm 2. Moreover, its discretization can also be immediately derived by replacing the function ${\bar{g}}^{(t)}$ used in algorithm 1 with

In the case of empirical risk minimization, we can adopt an efficient implementation as done in appendix E.1. We use ${\left\{{w}_{i,y}\right\}}_{i\in \left\{1,\dots ,n\right\},y\in \mathcal{C}}$ (initialized as zeros) to store the coefficients of hy (⋅, xi ). At step t in the outer loop, ${w}_{{i}_{t},y}\ (y\in \mathcal{C})$ are updated as

Then, ${\nabla }_{{\theta }_{r}}{\bar{g}}^{(t)}({{\Theta}}^{(k)})$ can be computed as

where we use ${\left\{{\theta }_{r}^{(k)}\right\}}_{r=1}^{M}$ to denote the parameters Θ(k) at step k of the inner loop.

Finally, we remark that while we here utilize a simple network hθ (x) to recover a standard two-layer neural network, it is also possible to use deep narrow networks or narrow convolutional neural networks as the component hθ (x); in other words, hΘ can represent an ensemble of various types of small networks. While such extensions are not covered by our current theoretical analysis, they may achieve better practical performance.

E.3. Correspondence with finite-dimensional dual averaging method

We explain the correspondence between the finite-dimensional DA method developed in [15–17] and our proposed method (algorithm 2); our goal here is to provide an intuitive understanding of the derivation of algorithm 2 in the context of the classical DA method.

First, we introduce the (regularized) DA method [16, 17] in a more general form for solving a regularized optimization problem on a finite-dimensional space. Let $w\in {\mathbb{R}}^{m}$ be a parameter, l(w, z) be a convex loss in w, where z is a random variable which represents an example, and Ψ(w) be a regularization function. Then, the problem solved by the DA method is given as

Let ${\left\{{w}^{(s)}\right\}}_{s=1}^{t}$ and ${\left\{{f}^{(s)}\right\}}_{s=1}^{t}={\left\{{\partial }_{w}l({w}^{(s)},{z}_{s})\right\}}_{s=1}^{t}$ be the histories of iterates and stochastic gradients. The subproblem that produces the next iterate in the DA method is designed using a strongly convex function d(w) and positive hyperparameters ${\left\{{\alpha }_{s}\right\}}_{s=1}^{\infty }$ and ${\left\{{\beta }_{s}\right\}}_{s=2}^{\infty }$. Specifically, the next iterate w(t+1) is defined as the minimizer of the following problem, in which the loss function is linearized and a weighted sum is taken over the history:

Equation (34)

Next, we consider our problem setting of optimizing the probability distribution and reformulate the subproblem (7) solved in algorithm 2 as follows:

Equation (35)

By comparing (34) and (35), we arrive at the following correspondence: ${\alpha }_{s}={\beta }_{s}=s,\enspace {f}^{(s)}\sim {g}^{(s)},\enspace d(w)={\Psi}(w)\sim {\lambda }_{2}{\mathbb{E}}_{q}[\mathrm{log}(q)]$. We note that in our problem setting the expectation with respect to q can be seen as an inner product with the integrand, and ${\lambda }_{2}{\mathbb{E}}_{q}[\mathrm{log}(q)]$ also plays the role of d(w) because the negative entropy acts as a strongly convex function (lemma A).
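To make the correspondence concrete, the following Python sketch implements the finite-dimensional update under the (assumed) quadratic choices $\Psi (w)=\frac{\lambda }{2}{\Vert}w{{\Vert}}_{2}^{2}$ and $d(w)=\frac{1}{2}{\Vert}w{{\Vert}}_{2}^{2}$, for which the subproblem admits a closed-form minimizer; it is an illustration of the classical method, not of algorithm 2 itself.

import numpy as np

def regularized_da(sample_grad, w0, alphas, betas, lam, T, rng):
    """Regularized dual averaging with quadratic Psi and d (a sketch under the assumed
    subproblem  min_w  sum_{s<=t} alpha_s (<f_s, w> + (lam/2)||w||^2) + (beta_{t+1}/2)||w||^2)."""
    w = w0.copy()
    g_sum = np.zeros_like(w0)      # running weighted sum of stochastic gradients
    a_sum = 0.0                    # running sum of the weights alpha_s
    for t in range(1, T + 1):
        f_t = sample_grad(w, rng)  # stochastic gradient of the loss at the current iterate
        g_sum += alphas(t) * f_t
        a_sum += alphas(t)
        w = -g_sum / (lam * a_sum + betas(t + 1))   # closed-form minimizer of the subproblem
    return w

# Example: least-squares loss l(w, (x, y)) = 0.5 * (<w, x> - y)^2 with alpha_s = beta_s = s,
# mirroring the weighting used by algorithm 2.
rng = np.random.default_rng(0)
d = 5
w_star = rng.standard_normal(d)
def sample_grad(w, rng):
    x = rng.standard_normal(d)
    y = x @ w_star
    return (x @ w - y) * x
w_hat = regularized_da(sample_grad, np.zeros(d), alphas=lambda s: s, betas=lambda s: s, lam=0.1, T=2000, rng=rng)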

Appendix F.: Additional experiments

F.1. Comparison of generalization error

We provide additional experimental results on the generalization performance of PDA. We consider empirical risk minimization for a regression problem (squared loss): the input ${x}_{i}\sim \mathcal{N}(0,{I}_{p})$, and f* is a single-index model: f*(x) = sign(⟨w*, x⟩). We set n = 1000, p = 50, M = 200, and implement both noisy gradient descent [10] using the full-batch gradient and our proposed algorithm 1 (PDA) using mini-batch updates with batch size 50.

In figure 3 we compare the generalization performance of different training methods: noisy GD and PDA in the mean field regime, and also noisy GD in the kernel regime. We fix the ℓ2 and entropy regularization to be the same across all settings: λ1 = 10−2, λ2 = 5 × 10−4. We set the total number of iterations (outer + inner loop steps) in PDA to be the same as for GD, and tuned the learning rate for optimal generalization. Observe that

  • Model with the NTK scaling (green) generalizes worse than the mean field models (red and blue). This is consistent with observations in Chizat and Bach [5].
  • For the mean field scaling, PDA (under early stopping) leads to slightly lower test error than noisy GD (see appendix D for generalization bounds of the PDA solution); we intend to further investigate this difference in generalization performance.

Figure 3. Test error of mean field neural networks (α = 1) trained with noisy GD (red) and PDA (blue), and network in the kernel regime (α = 1/2) optimized by GD (green).

F.2. PDA beyond ℓ2 regularization

Note that our current formulation (4) considers ℓ2 regularization, which allows us to establish a polynomial runtime guarantee for the inner loop via the log-Sobolev inequality. As remarked in section 4, our global convergence analysis can easily be extended to Hölder-smooth gradients via the convergence rate of the Langevin algorithm given in Erdogdu and Hosseinzadeh [45]. Although we do not provide details for this extension in the current work (due to the use of Vempala and Wibisono [40]), we empirically demonstrate one of its applications in handling ℓp-regularized objectives for p > 1 in the following form,

Equation (36)

Erdogdu and Hosseinzadeh [45] cannot directly cover the non-smooth ℓ1 regularization, but we can still obtain relatively sparse solutions by setting p close to 1. Intuitively speaking, when the underlying task exhibits certain low-dimensional or sparse structure, we expect a sparsity-promoting regularization to achieve better generalization performance.

Figure 4(a) demonstrates the advantage of ℓp-norm regularization for p < 2 in empirical risk minimization, when the target function exhibits sparse structure. We set n = 1000, d = 50; the teacher is a multiple-index model (m = 2) with binary activation, and the parameters of each neuron are one-sparse. We optimize the student model with PDA (warm-start), where we set λ1 = 10−2, λ2 = 10−4, and vary the norm exponent p from 1.01 to 2. Note that smaller p results in favorable generalization due to the induced sparsity. On the other hand, we expect the benefit of sparse regularization to diminish when the target function is not sparse. This intuition is confirmed in figure 4(b), where we control the target sparsity by randomly selecting r parameters to be non-zero, and we define s = r/d to be the sparsity level. Observe that the benefit of sparsity-inducing regularization (smaller p) is more prominent under small s (brighter color), which indicates a sparse target function.

Figure 4. PDA with a general ℓp regularizer (objective (36)). (a) Generalization error vs training time in learning a one-sparse target function. (b) Generalization error vs sparsity level s of the target function.

F.3. On the role of entropy regularization

Our objective (3) includes an entropy regularization with magnitude λ2. In this section we illustrate the impact of this regularization term. In figure 5(a) we consider a synthetic 1D dataset (n = 15) and plot the output of a two-layer tanh network with 200 neurons trained by SGD and PDA to minimize the squared loss till convergence. We use the same ℓ2 regularization (λ1 = 10−3) for both algorithms, and for PDA we set the entropic term λ2 = 10−4. Observe that SGD with weak regularization (red) almost interpolates the noisy training data, whereas PDA with entropy regularization finds a smoother, lower-complexity solution (blue).

Figure 5. (a) 1D illustration of the impact of entropy regularization in two-layer tanh network: PDA (blue) finds a smoother solution that does not interpolate the training data due to entropy regularization. (b),(c) Test error of two-layer tanh network trained till convergence. PDA (blue) becomes advantageous compared to SGD (red) when labels become noisy, and the NTK model (green, note that the y-axis is on different scale) generalizes considerably worse than the mean field models.

We therefore expect entropy regularization to be beneficial when the labels are noisy and the underlying target function (teacher) is 'simple'. We verify this intuition in figure 5(b). We set n = 500, d = 50 and M = 500, and the teacher model is a linear function on the input features. We employ SGD or PDA to optimize the squared error. For both algorithms we use the same ℓ2 regularization λ1 = 10−2, but PDA includes a small entropy term λ2 = 5 × 10−4. We plot the generalization error of the converged model under varying amounts of label noise. Note that as the labels become more corrupted, PDA (blue) results in lower test error due to the entropy regularization 11 . On the other hand, the model under the kernel scaling (green) generalizes poorly compared to the mean field models. Furthermore, figure 5(c) demonstrates that entropy regularization can be beneficial under low-noise (or even noiseless) settings as well. We construct the teacher model to be a multiple-index model with binary activation. Note that in this setting PDA achieves lower stationary risk across all noise levels, and the advantage amplifies as labels are further corrupted.

F.4. Adaptivity of mean field neural networks

Recall that one motivation to study the mean field regime (instead of the kernel regime) is the presence of feature learning. We illustrate this behavior in a simple student–teacher setup, where the target function is a single-index model with tanh activation. We set n = 500, d = 50, and optimize a two-layer tanh network (M = 1000), either in the mean field regime using PDA, or in the kernel regime using SGD. For both methods we choose λ1 = 10−3, and for PDA we choose λ2 = 10−4.

In figure 6 we plot the evolution of the cosine similarity between the target vector w* and the top-five singular vectors (PC1-5) of the weight matrix during training. In figure 6(a) we observe that the mean field model trained with PDA 'adapts' to the low-dimensional structure of the target function; in particular, the leading singular vector (bright yellow) aligns with the target direction. In contrast, we do not observe such alignment on the network in the kernel regime (figure 6(b)), because the parameters do not travel away from the initialization. This comparison demonstrates the benefit of the mean field parameterization.

Figure 6. Cosine similarity between the target vector w* and the top-five singular vectors (PC1-5) of the weight matrix during training. The learned parameters 'align' with the target function under the mean field parameterization (a), but not the NTK parameterization (b).

Appendix G.: Additional related work

Particle inference algorithms. Bayesian inference is another example of optimization over probability distributions, in which the objective is to minimize an entropy-regularized linear functional. In addition to the Langevin algorithm, several interacting particle methods have been developed for this purpose, such as particle mirror descent (PMD) [60], Stein variational gradient descent [61], and the ensemble Kalman sampler [62], and the corresponding mean field limits have been analyzed in Lu et al [63] and Ding and Li [64]. We remark that naive gradient-based methods on the probability space often involve computing the probability of particles for the entropy term (e.g. kernel density estimation in PMD), which presents significant difficulty in constructing particle inference algorithms. In contrast, our proposed algorithm avoids this computational challenge due to its algorithmic structure.

Optimization of probability distributions. Parallel to our work, several recent papers also propose optimization methods over the space of probability measures by adapting finite-dimensional convex optimization theory: [65–67] extend the mirror descent method, the Frank–Wolfe method, and the (accelerated) Bregman proximal gradient method to the optimization of probability measures, respectively. In addition, [68] developed an entropic mirror descent algorithm for generative adversarial networks, and [69] analyzed probability functional descent in the context of variational inference and reinforcement learning.

The kernel regime and beyond. The neural tangent kernel model [1] describes the learning dynamics of neural networks under appropriate scaling. Such a description builds upon the linearization of the learning dynamics around the initialization, and (quantitative) global convergence guarantees of gradient-based methods for neural networks can be shown for regression problems [2–4, 70] as well as classification problems [71–73].

However, due to the linearization, the NTK model cannot explain the presence of 'feature learning' in neural networks (i.e. parameters being able to travel and adapt to the structure of the learning problem). In fact, various works have shown that deep learning is more powerful than kernel methods in terms of approximation and estimation error [7, 74–78], and in certain settings, neural networks optimized with gradient-based methods can outperform the NTK model (or more generally any kernel method) in terms of generalization error or excess risk [8, 79–85].

Footnotes

  • This article is an updated version of: Nitanda A, Wu D and Suzuki T 2021 Particle dual averaging: optimization of mean field neural network with global convergence rate analysis Advances in Neural Information Processing Systems vol 34 ed M Ranzato, A Beygelzimer, Y Dauphin, P S Liang and J Wortman Vaughan pp 19608–21.

  • Note that such error yields a sublinear rate with respect to an arbitrarily small accuracy ε.

  • In algorithm 1, the terms ${\partial }_{z}\ell ({h}_{{\tilde{{\Theta}}}^{(s)}}({x}_{s}),{y}_{s})$ appear in the inner loop; but note that these terms only need to be computed in the outer loop because they are independent of the inner loop iterates.

  • In appendix B we introduce a more general version of theorem 1 that allows for inexact ${h}_{{q}^{(t)}}(x)$, which simplifies the analysis of finite-particle discretization presented in appendix C.

  • We use the term kernel regime only to indicate the parameter scaling α; this does not necessarily imply that the NTK linearization is an accurate description of the trained model.

  • Note that the estimated training objective (red) slightly deviates from the ideal 1/T-rate; this may be due to inaccuracy in the entropy estimation, or non-convergence of the algorithm (i.e. overestimation of $\mathcal{L}({q}_{\ast })$).

  • WLOG the Lipschitz constant is set to 1, since the same analysis works for any fixed constant.

  • Here, ${a}_{r}{\sigma }_{1}({b}_{r}^{\top }x)$ is a scalar ${\sigma }_{1}({b}_{r}^{\top }x)$ times a vector ar .

  • Note that entropy regularization is not the only way to reduce overfitting—such capacity control can also be achieved by proper early stopping or other types of explicit regularization.
