Neural tangent kernel

From Wikipedia, the free encyclopedia
This is an old revision of this page, as edited by Dkarkada at 07:38, 3 May 2023 (added animation and summary of main results).

In the study of artificial neural networks (ANNs), the neural tangent kernel (NTK) is a kernel that describes the evolution of deep artificial neural networks during their training by gradient descent. It allows ANNs to be studied using theoretical tools from kernel methods.

A kernel is a positive-semidefinite symmetric function of two inputs that represents some notion of similarity between them. The NTK is a specific kernel derived from a given neural network; in general, when the network parameters change during training, the NTK evolves as well. However, in the limit of large layer width the NTK becomes constant, revealing a duality between training the wide neural network and kernel methods: gradient descent in the infinite-width limit is fully equivalent to kernel gradient descent with the NTK. As a result, using gradient descent to minimize a least-squares loss for neural networks yields the same mean estimator as ridgeless kernel regression with the NTK. This duality enables simple closed-form equations describing the training dynamics, generalization, and predictions of wide neural networks.

The NTK was introduced in 2018 by Arthur Jacot, Franck Gabriel and Clément Hongler,[1] who used it to study the convergence and generalization properties of fully connected neural networks. Later works[2][3] extended the NTK results to other neural network architectures.

Main result (informal)

Let $f(x;\theta)$ denote the scalar function computed by a given neural network with parameters $\theta$ on input $x$. Then the neural tangent kernel is defined[1] as

$$\Theta(x,x';\theta) = \nabla_\theta f(x;\theta) \cdot \nabla_\theta f(x';\theta).$$

Since it is written as a dot product between mapped inputs (with the gradient of the neural network function serving as the feature map), we are guaranteed that the NTK is symmetric and positive semi-definite. The NTK is thus a valid kernel function.
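As a minimal sketch of this definition, assuming NumPy and a hypothetical one-hidden-layer tanh network $f(x;\theta) = a\cdot\tanh(Wx/\sqrt{d})/\sqrt{n}$ with parameters $\theta = (W, a)$ (the names `init_params`, `forward`, `grad_f` and `empirical_ntk` are illustrative, not from any library), the NTK Gram matrix on a few inputs can be computed from hand-derived parameter gradients and checked for symmetry and positive semi-definiteness:

```python
import numpy as np

def init_params(d, n, seed=0):
    """Draw i.i.d. standard normal parameters for the toy network."""
    rng = np.random.default_rng(seed)
    return {"W": rng.standard_normal((n, d)), "a": rng.standard_normal(n)}

def forward(params, x):
    """f(x; theta) = a . tanh(W x / sqrt(d)) / sqrt(n)."""
    n, d = params["W"].shape
    return params["a"] @ np.tanh(params["W"] @ x / np.sqrt(d)) / np.sqrt(n)

def grad_f(params, x):
    """Gradient of f(x; theta) with respect to all parameters, flattened as (W, a)."""
    W, a = params["W"], params["a"]
    n, d = W.shape
    t = np.tanh(W @ x / np.sqrt(d))
    g_W = np.outer(a * (1.0 - t**2), x) / np.sqrt(n * d)  # df/dW_jk
    g_a = t / np.sqrt(n)                                   # df/da_j
    return np.concatenate([g_W.ravel(), g_a])

def empirical_ntk(params, xs):
    """Gram matrix Theta[i, j] = grad f(x_i) . grad f(x_j); the gradient is the feature map."""
    J = np.stack([grad_f(params, x) for x in xs])  # shape (num_inputs, num_params)
    return J @ J.T

xs = np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]])
theta_gram = empirical_ntk(init_params(d=2, n=256), xs)
print(np.allclose(theta_gram, theta_gram.T))           # True: symmetric
print(np.linalg.eigvalsh(theta_gram).min() >= -1e-10)  # True: positive semi-definite
```

Computing gradients by hand keeps the sketch dependency-free; an automatic-differentiation library would normally be used for real architectures.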

Now consider a large ensemble of fully-connected neural networks. All the neural networks are identical in architecture; the only difference between them is the random initialization of parameters. (Equivalently, one can consider a random neural network whose statistics are determined by the random parameter initialization.) Each parameter is chosen i.i.d. according to some mean-zero distribution. This random initialization of $\theta$ induces a distribution over $f(\cdot;\theta)$ which is captured by our ensemble.

At initialization, an ensemble of wide neural networks is a zero-mean Gaussian process; during training (gradient descent on mean-square error), the ensemble evolves according to the neural tangent kernel. The converged ensemble is a Gaussian process whose mean is the ridgeless kernel regression estimator and whose variance vanishes on the training points. Here, the neural network is a scalar function trained on inputs drawn from the unit circle.

The number of neurons in each layer is called the layer’s width. Consider taking the width of every hidden layer to infinity and training each neural network with gradient descent (with a suitably small learning rate). In this infinite-width limit, several useful properties emerge:

  • At initialization (before training), the neural network ensemble is a zero-mean Gaussian process (GP).[4] This means that the distribution of functions is the maximum-entropy distribution with mean $\mathbb{E}_\theta[f(x;\theta)] = 0$ and covariance $\mathbb{E}_\theta[f(x;\theta)\,f(x';\theta)] = \Sigma(x,x')$, where the GP covariance $\Sigma(x,x')$ can be computed from the network architecture. In other words, the distribution of neural network functions at initialization has no structure other than its first and second moments (mean and covariance). This follows from the central limit theorem.
  • Each neural network is linear in its parameters.[5] This means that the network’s dependence on its parameters can be captured by its first-order Taylor expansion around the initial parameters $\theta_0$: $f(x;\theta_0+\Delta\theta) = f(x;\theta_0) + \Delta\theta\cdot\nabla_\theta f(x;\theta_0)$. (In general, the neural network is still nonlinear with respect to the inputs.)
  • The NTK is deterministic.[1][5] In other words, the NTK of each neural network in the ensemble is identical.
  • The NTK does not change during training.[1][5]
  • Each parameter changes negligibly throughout training. Although individual parameters move by a vanishingly small amount, they collectively conspire to provide a finite change in the final output of the network, as is necessary for training.[5]
  • If the loss function is mean-squared error, the training dynamics are equivalent to kernel gradient descent using the NTK as the kernel.[1] In particular, the final ensemble distribution of functions is still a Gaussian process, but with a new mean and covariance.[5] The ensemble mean converges to the same estimator yielded by kernel regression with the NTK as kernel and zero regularization. The ensemble covariance is expressible in terms of the NTK and the GP covariance. It can be shown that the ensemble variance vanishes at the training points (in other words, all neural networks in the ensemble interpolate the training data).
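The equivalence between training and kernel gradient descent can be illustrated with a small sketch. It assumes NumPy, the squared loss $c(w,z)=\tfrac12(w-z)^2$, and, purely as a stand-in for the NTK, a Gaussian (RBF) kernel; any fixed positive-definite kernel behaves the same way here. Starting from the zero function (the ensemble mean at initialization), discretized kernel gradient descent drives the training outputs to the labels and the test outputs to the ridgeless kernel regression estimator $k(x,X)\,K(X,X)^{-1}y$. The function names are hypothetical:

```python
import numpy as np

def rbf_kernel(A, B, lengthscale=0.5):
    """Stand-in positive-definite kernel; in the NTK setting it would be replaced by the NTK."""
    sq_dist = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2 * lengthscale**2))

def kernel_gradient_descent(kernel, X, y, X_test, lr, steps):
    """Explicit-Euler kernel gradient descent on 0.5 * sum_i (f(x_i) - y_i)^2,
    starting from the zero function."""
    K_train, K_test = kernel(X, X), kernel(X_test, X)
    f_train, f_test = np.zeros(len(X)), np.zeros(len(X_test))
    for _ in range(steps):
        residual = f_train - y                       # dc/dw at each training point
        f_train = f_train - lr * K_train @ residual  # same update rule at every input
        f_test = f_test - lr * K_test @ residual
    return f_train, f_test

def ridgeless_regression(kernel, X, y, X_test):
    """Kernel regression with zero regularization: k(x, X) K(X, X)^{-1} y."""
    return kernel(X_test, X) @ np.linalg.solve(kernel(X, X), y)

X = np.array([[-1.0], [-0.3], [0.4], [1.0]])
y = np.array([0.5, -0.2, 0.1, 0.8])
X_test = np.linspace(-1.5, 1.5, 7)[:, None]
lr = 1.0 / np.linalg.eigvalsh(rbf_kernel(X, X)).max()  # stable step size
f_train, f_test = kernel_gradient_descent(rbf_kernel, X, y, X_test, lr, steps=2000)
print(np.allclose(f_train, y, atol=1e-6))  # True: the training data are interpolated
print(np.allclose(f_test, ridgeless_regression(rbf_kernel, X, y, X_test), atol=1e-6))  # True
```

Only the function values at the training and test inputs are tracked, which is possible precisely because the update rule depends on the parameters only through the (here fixed) kernel.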

Definition

Scalar output case

An ANN with scalar output consists of a family of functions $f(\cdot;\theta):\mathbb{R}^{n_{in}}\to\mathbb{R}$ parametrized by a vector of parameters $\theta\in\mathbb{R}^{P}$.

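The NTK is a kernel $\Theta:\mathbb{R}^{n_{in}}\times\mathbb{R}^{n_{in}}\to\mathbb{R}$ defined by

$$\Theta(x,y;\theta) = \sum_{p=1}^{P} \partial_{\theta_p} f(x;\theta)\,\partial_{\theta_p} f(y;\theta).$$

In the language of kernel methods, the NTK $\Theta$ is the kernel associated with the feature map $x\mapsto\left(\partial_{\theta_p} f(x;\theta)\right)_{p=1,\ldots,P}$.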

Vector output case

An ANN with vector output of size $n_{out}$ consists of a family of functions $f(\cdot;\theta):\mathbb{R}^{n_{in}}\to\mathbb{R}^{n_{out}}$ parametrized by a vector of parameters $\theta\in\mathbb{R}^{P}$.

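In this case, the NTK $\Theta$ is a matrix-valued kernel, with values in the space of $n_{out}\times n_{out}$ matrices, defined by

$$\Theta_{k,l}(x,y;\theta) = \sum_{p=1}^{P} \partial_{\theta_p} f_k(x;\theta)\,\partial_{\theta_p} f_l(y;\theta).$$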
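In matrix form, the vector-output NTK is $J(x)\,J(y)^\top$, where $J$ is the $n_{out}\times P$ Jacobian of the outputs with respect to the parameters. The hedged sketch below assumes NumPy and a model given as a Python function `f(x, theta)` returning an array of length $n_{out}$; the helpers are hypothetical and approximate $J$ with central finite differences, which is slow but architecture-agnostic. For a linear model $f(x;W)=Wx$, the NTK reduces to $(x\cdot y)$ times the identity matrix:

```python
import numpy as np

def jacobian_fd(f, x, theta, eps=1e-6):
    """Central finite-difference Jacobian of f(x, theta) with respect to theta.
    Returns an array of shape (n_out, P)."""
    theta = np.asarray(theta, dtype=float)
    columns = []
    for p in range(theta.size):
        step = np.zeros_like(theta)
        step[p] = eps
        columns.append((f(x, theta + step) - f(x, theta - step)) / (2 * eps))
    return np.stack(columns, axis=1)

def ntk_matrix(f, x, y, theta):
    """Matrix-valued NTK: Theta_{k,l}(x, y) = sum_p df_k(x)/dtheta_p * df_l(y)/dtheta_p."""
    return jacobian_fd(f, x, theta) @ jacobian_fd(f, y, theta).T

# Example: linear model f(x; W) = W x with n_out = 3, n_in = 2 and theta = W flattened.
n_out, n_in = 3, 2

def linear_model(x, theta):
    return theta.reshape(n_out, n_in) @ x

theta = np.random.default_rng(0).standard_normal(n_out * n_in)
x, y = np.array([1.0, 2.0]), np.array([0.5, -1.0])
print(ntk_matrix(linear_model, x, y, theta))  # approximately (x . y) * I = -1.5 * I
```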

Derivation

When optimizing the parameters of an ANN to minimize an empirical loss through gradient descent, the NTK governs the dynamics of the ANN output function throughout the training.

Scalar output case

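For a dataset $(x_i)_{i=1,\ldots,n}$ with scalar labels $(z_i)_{i=1,\ldots,n}$ and a loss function $c:\mathbb{R}\times\mathbb{R}\to\mathbb{R}$, the associated empirical loss, defined on functions $f:\mathbb{R}^{n_{in}}\to\mathbb{R}$, is given by

$$\mathcal{C}(f) = \sum_{i=1}^{n} c\big(f(x_i), z_i\big).$$

When the ANN $f(\cdot;\theta)$ is trained to fit the dataset (i.e. minimize $\mathcal{C}$) via continuous-time gradient descent, the parameters $(\theta(t))_{t\ge 0}$ evolve through the ordinary differential equation:

$$\partial_t\theta(t) = -\nabla_\theta\,\mathcal{C}\big(f(\cdot;\theta(t))\big) = -\sum_{i=1}^{n} \partial_w c(w,z_i)\Big|_{w=f(x_i;\theta(t))}\,\nabla_\theta f(x_i;\theta(t)).$$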

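During training, the ANN output function follows an evolution differential equation given in terms of the NTK, obtained by applying the chain rule $\partial_t f(x;\theta(t)) = \nabla_\theta f(x;\theta(t))\cdot\partial_t\theta(t)$ to the parameter dynamics:

$$\partial_t f(x;\theta(t)) = -\sum_{i=1}^{n} \Theta(x,x_i;\theta(t))\,\partial_w c(w,z_i)\Big|_{w=f(x_i;\theta(t))}.$$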

This equation shows how the NTK drives the dynamics of $f(\cdot;\theta(t))$ in the space of functions during training.

Vector output case

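For a dataset $(x_i)_{i=1,\ldots,n}$ with vector labels $(z_i)_{i=1,\ldots,n}$, $z_i\in\mathbb{R}^{n_{out}}$, and a loss function $c:\mathbb{R}^{n_{out}}\times\mathbb{R}^{n_{out}}\to\mathbb{R}$, the corresponding empirical loss on functions $f:\mathbb{R}^{n_{in}}\to\mathbb{R}^{n_{out}}$ is defined by

$$\mathcal{C}(f) = \sum_{i=1}^{n} c\big(f(x_i), z_i\big).$$

The training of $f(\cdot;\theta(t))$ through continuous-time gradient descent yields the following evolution in function space driven by the NTK:

$$\partial_t f_k(x;\theta(t)) = -\sum_{i=1}^{n}\sum_{l=1}^{n_{out}} \Theta_{k,l}(x,x_i;\theta(t))\,\partial_{w_l} c(w,z_i)\Big|_{w=f(x_i;\theta(t))}.$$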

Interpretation

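The NTK $\Theta(x,x_i;\theta)$ represents the influence of the loss gradient $\partial_w c(w,z_i)\big|_{w=f(x_i;\theta)}$ with respect to example $x_i$ on the evolution of ANN output $f(x;\theta)$ through a gradient descent step: in the scalar case, this reads

$$f(x;\theta(t+\epsilon)) - f(x;\theta(t)) \approx -\epsilon\sum_{i=1}^{n} \Theta(x,x_i;\theta(t))\,\partial_w c(w,z_i)\Big|_{w=f(x_i;\theta(t))}.$$

In particular, each data point $x_i$ influences the evolution of the output $f(x;\theta)$ for each $x$ throughout the training, in a way that is captured by the NTK $\Theta(x,x_i;\theta)$.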
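This first-order relation can be checked numerically. The sketch below assumes NumPy, a hypothetical one-hidden-layer tanh network, the squared loss $c(w,z)=\tfrac12(w-z)^2$ (so $\partial_w c = w - z$), and a learning rate `lr` playing the role of $\epsilon$. It takes one explicit gradient-descent step in parameter space and compares the resulting change at a query point with the NTK prediction; the mismatch is expected to shrink roughly quadratically with `lr`:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 2, 128
W = rng.standard_normal((n, d))  # hypothetical toy network parameters
a = rng.standard_normal(n)

def f(W, a, x):
    """f(x; theta) = a . tanh(W x / sqrt(d)) / sqrt(n)."""
    return a @ np.tanh(W @ x / np.sqrt(d)) / np.sqrt(n)

def grads(W, a, x):
    """Gradients of f(x; theta) with respect to W and a."""
    t = np.tanh(W @ x / np.sqrt(d))
    return np.outer(a * (1 - t**2), x) / np.sqrt(n * d), t / np.sqrt(n)

def ntk(W, a, x, y):
    """Empirical NTK: dot product of the parameter gradients at x and y."""
    gW_x, ga_x = grads(W, a, x)
    gW_y, ga_y = grads(W, a, y)
    return np.sum(gW_x * gW_y) + ga_x @ ga_y

X_train = np.array([[1.0, 0.0], [0.0, 1.0], [-0.6, 0.8]])
z = np.array([0.3, -0.5, 1.0])
x_query = np.array([0.2, 0.4])

def one_step(lr):
    """Return (actual, predicted) change of f(x_query) after one gradient step."""
    dc = np.array([f(W, a, xi) - zi for xi, zi in zip(X_train, z)])  # dc/dw = w - z
    gW = sum(dci * grads(W, a, xi)[0] for xi, dci in zip(X_train, dc))
    ga = sum(dci * grads(W, a, xi)[1] for xi, dci in zip(X_train, dc))
    actual = f(W - lr * gW, a - lr * ga, x_query) - f(W, a, x_query)
    predicted = -lr * sum(ntk(W, a, x_query, xi) * dci for xi, dci in zip(X_train, dc))
    return actual, predicted

for lr in (1e-2, 1e-3):
    actual, predicted = one_step(lr)
    print(lr, actual, predicted, abs(actual - predicted))
```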

Large-width limit

Recent theoretical and empirical work in deep learning has shown the performance of ANNs to strictly improve as their layer widths grow larger.[6][7] For various ANN architectures, the NTK yields precise insight into the training in this large-width regime.[1][8][9][5][10][11]

Wide fully-connected ANNs have a deterministic NTK, which remains constant throughout training

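Consider an ANN with fully-connected layers $\ell = 0,\ldots,L$ of widths $n_0 = n_{in}, n_1, \ldots, n_L = n_{out}$, so that $f(\cdot;\theta) = A_{L-1}\circ R_{L-2}\circ\cdots\circ R_0$, where each $R_\ell = \sigma\circ A_\ell$ is the composition of an affine transformation $A_\ell$ with the pointwise application of a nonlinearity $\sigma:\mathbb{R}\to\mathbb{R}$, and $\theta$ parametrizes the affine maps $A_0,\ldots,A_{L-1}$. The parameters $\theta\in\mathbb{R}^{P}$ are initialized randomly, in an independent, identically distributed way.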

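As the widths grow, the NTK's scale is affected by the exact parametrization of the $A_\ell$'s and by the parameter initialization. This motivates the so-called NTK parametrization $A_\ell(x) = \frac{1}{\sqrt{n_\ell}}W^{(\ell)}x + b^{(\ell)}$, with $W^{(\ell)}\in\mathbb{R}^{n_{\ell+1}\times n_\ell}$ and $b^{(\ell)}\in\mathbb{R}^{n_{\ell+1}}$. This parametrization ensures that if the parameters $\theta\in\mathbb{R}^{P}$ are initialized as standard normal variables, the NTK has a finite nontrivial limit. In the large-width limit, the NTK $\Theta(x,y;\theta)$ converges to a deterministic (non-random) limit $\Theta_\infty(x,y)$, which stays constant in time.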
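The following sketch illustrates the deterministic limit for a hypothetical two-layer ReLU network in the NTK parametrization, $f(x) = W^{(1)}\cdot\mathrm{ReLU}(W^{(0)}x/\sqrt{n_0} + b^{(0)})/\sqrt{n_1} + b^{(1)}$, with standard normal parameters (NumPy assumed). It evaluates the empirical NTK $\Theta(x,x;\theta)$ for two independent initializations: at small width the values vary from seed to seed, while at large width both concentrate near the same number. For $x=(1,1)$ that number works out by hand to $3$:

```python
import numpy as np

def empirical_ntk_two_layer(x, y, width, seed):
    """Empirical NTK Theta(x, y; theta) of the two-layer ReLU network at initialization."""
    rng = np.random.default_rng(seed)
    n0 = x.size
    W0 = rng.standard_normal((width, n0))
    b0 = rng.standard_normal(width)
    W1 = rng.standard_normal(width)   # b1 only contributes a constant 1 to the NTK
    hx = W0 @ x / np.sqrt(n0) + b0    # hidden pre-activations
    hy = W0 @ y / np.sqrt(n0) + b0
    term_b1 = 1.0                                             # df/db1 = 1
    term_W1 = np.maximum(hx, 0) @ np.maximum(hy, 0) / width   # df/dW1_j = relu(h_j)/sqrt(width)
    active = (hx > 0) & (hy > 0)
    # Contributions of W0 and b0 through the ReLU derivative (a step function).
    term_layer0 = np.sum(W1**2 * active) / width * (x @ y / n0 + 1.0)
    return term_b1 + term_W1 + term_layer0

x = np.array([1.0, 1.0])
for width in (10, 1000, 100000):
    print(width, [round(float(empirical_ntk_two_layer(x, x, width, seed)), 3) for seed in (0, 1)])
```

The seed-to-seed spread of each term shrinks roughly like $1/\sqrt{\text{width}}$, consistent with the law of large numbers.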

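The NTK $\Theta_\infty$ is explicitly given by $\Theta_\infty = \Theta^{(L)}\otimes\mathrm{Id}_{n_L}$, where $\Theta^{(L)}$ is determined by the set of recursive equations:

$$\begin{aligned}
\Theta^{(1)}(x,y) &= \Sigma^{(1)}(x,y),\\
\Sigma^{(1)}(x,y) &= \frac{1}{n_0}\,x^{\top}y + 1,\\
\Theta^{(\ell+1)}(x,y) &= \Theta^{(\ell)}(x,y)\,\dot\Sigma^{(\ell+1)}(x,y) + \Sigma^{(\ell+1)}(x,y),\\
\Sigma^{(\ell+1)}(x,y) &= L^{\sigma}_{\Sigma^{(\ell)}}(x,y) + 1,\\
\dot\Sigma^{(\ell+1)}(x,y) &= L^{\dot\sigma}_{\Sigma^{(\ell)}}(x,y),
\end{aligned}$$

where $L^{g}_{K}$ denotes the kernel defined in terms of the Gaussian expectation:

$$L^{g}_{K}(x,y) = \mathbb{E}_{(X,Y)\sim\mathcal{N}\left(0,\begin{pmatrix}K(x,x) & K(x,y)\\ K(y,x) & K(y,y)\end{pmatrix}\right)}\big[g(X)\,g(Y)\big].$$

In this formula the kernels $\Sigma^{(\ell)}$ are the ANN's so-called activation kernels.[12][13][4]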
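For the ReLU nonlinearity the Gaussian expectations $L^{\sigma}_K$ and $L^{\dot\sigma}_K$ have closed forms (the arc-cosine kernels of Cho and Saul), so the recursion can be evaluated exactly. A minimal NumPy sketch, assuming $\sigma = \mathrm{ReLU}$, the NTK parametrization with unit bias scale described in this section, and hypothetical function names:

```python
import numpy as np

def relu_expectations(kxx, kxy, kyy):
    """L^sigma_K and L^sigma-dot_K for sigma = ReLU, with (X, Y) ~ N(0, [[kxx, kxy], [kxy, kyy]])."""
    norm = np.sqrt(kxx * kyy)
    cos_phi = np.clip(kxy / norm, -1.0, 1.0)
    phi = np.arccos(cos_phi)
    e_relu = norm / (2 * np.pi) * (np.sin(phi) + (np.pi - phi) * cos_phi)  # E[relu(X) relu(Y)]
    e_step = (np.pi - phi) / (2 * np.pi)                                    # E[step(X) step(Y)]
    return e_relu, e_step

def infinite_width_ntk(x, y, depth):
    """Theta^{(depth)}(x, y) for a fully connected ReLU network (NTK parametrization)."""
    n0 = x.size
    # Sigma^{(1)} at the pairs (x, x), (x, y), (y, y); Theta^{(1)} = Sigma^{(1)}.
    s_xx, s_xy, s_yy = x @ x / n0 + 1.0, x @ y / n0 + 1.0, y @ y / n0 + 1.0
    theta = s_xy
    for _ in range(depth - 1):
        e_xy, e_step_xy = relu_expectations(s_xx, s_xy, s_yy)  # uses Sigma^{(l)}
        e_xx, _ = relu_expectations(s_xx, s_xx, s_xx)
        e_yy, _ = relu_expectations(s_yy, s_yy, s_yy)
        s_xx, s_xy, s_yy = e_xx + 1.0, e_xy + 1.0, e_yy + 1.0  # Sigma^{(l+1)}
        theta = theta * e_step_xy + s_xy                       # Theta^{(l+1)}
    return theta

x, y = np.array([1.0, 1.0]), np.array([1.0, -1.0])
print(infinite_width_ntk(x, x, depth=2))  # approximately 3.0
print(infinite_width_ntk(x, y, depth=3))
```

For other nonlinearities without closed forms, the two expectations could instead be approximated numerically, for example by Gauss-Hermite quadrature or Monte Carlo sampling.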

Wide fully connected networks are linear in their parameters throughout training

The NTK describes the evolution of neural networks under gradient descent in function space. Dual to this perspective is an understanding of how neural networks evolve in parameter space, since the NTK is defined in terms of the gradient of the ANN's outputs with respect to its parameters. In the infinite-width limit, the connection between these two perspectives becomes especially interesting. The NTK remaining constant throughout training at large widths co-occurs with the ANN being well described throughout training by its first-order Taylor expansion around its parameters at initialization:[5]

$$f(x;\theta(t)) = f(x;\theta(0)) + \nabla_\theta f(x;\theta(0))\cdot\big(\theta(t)-\theta(0)\big) + \mathcal{O}\!\left(\min(n_1,\ldots,n_{L-1})^{-1/2}\right).$$
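To illustrate this linearization, the hedged sketch below (NumPy assumed; a hypothetical one-hidden-layer tanh network without biases, four toy training points, and the squared loss) trains the network and its first-order Taylor expansion around initialization with identical full-batch gradient descent, then reports the largest disagreement of their outputs on two test inputs. The disagreement is expected to shrink as the width grows:

```python
import numpy as np

def train_gap(width, steps=200, lr=0.2, seed=0):
    """Largest test-output gap between a trained network and its trained linearization."""
    rng = np.random.default_rng(seed)
    X = np.array([[1.0, 0.0], [0.0, 1.0], [-0.6, 0.8], [0.6, 0.8]])
    y = np.array([0.5, -0.5, 1.0, 0.0])
    X_test = np.array([[0.8, -0.6], [-1.0, 0.0]])
    d = X.shape[1]

    def unpack(theta):
        return theta[: width * d].reshape(width, d), theta[width * d:]

    def forward(theta, Xs):
        W, a = unpack(theta)
        return np.tanh(Xs @ W.T / np.sqrt(d)) @ a / np.sqrt(width)

    def jacobian(theta, Xs):
        W, a = unpack(theta)
        T = np.tanh(Xs @ W.T / np.sqrt(d))  # shape (num_inputs, width)
        J_W = ((1 - T**2) * a)[:, :, None] * Xs[:, None, :] / np.sqrt(width * d)
        J_a = T / np.sqrt(width)
        return np.concatenate([J_W.reshape(len(Xs), -1), J_a], axis=1)

    theta0 = rng.standard_normal(width * d + width)  # NTK parametrization: N(0, 1) init
    J0, J0_test = jacobian(theta0, X), jacobian(theta0, X_test)
    f0, f0_test = forward(theta0, X), forward(theta0, X_test)

    theta, theta_lin = theta0.copy(), theta0.copy()
    for _ in range(steps):
        # Gradient of 0.5 * ||f(X) - y||^2 with respect to theta is J^T (f(X) - y).
        theta = theta - lr * jacobian(theta, X).T @ (forward(theta, X) - y)
        residual_lin = f0 + J0 @ (theta_lin - theta0) - y
        theta_lin = theta_lin - lr * J0.T @ residual_lin

    f_test = forward(theta, X_test)
    f_lin_test = f0_test + J0_test @ (theta_lin - theta0)
    return np.max(np.abs(f_test - f_lin_test))

for width in (16, 256, 4096):
    print(width, train_gap(width))
```

The exact numbers depend on the seed, learning rate, and number of steps chosen here, all of which are illustrative.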

Other architectures

The NTK can be studied for various ANN architectures,[10] in particular convolutional neural networks (CNNs),[14] recurrent neural networks (RNNs) and transformers.[15] In such settings, the large-width limit corresponds to letting the number of parameters grow, while keeping the number of layers fixed: for CNNs, this involves letting the number of channels grow.

Applications

Convergence to a global minimum

For a convex loss functional $\mathcal{C}$ with a global minimum, if the NTK remains positive-definite during training, the loss of the ANN $\mathcal{C}\big(f(\cdot;\theta(t))\big)$ converges to that minimum as $t\to\infty$. This positive-definiteness property has been shown in a number of cases, yielding the first proofs that large-width ANNs converge to global minima during training.[1][8][16][17][18]
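For the squared loss this can be seen in a short derivation, assuming $c(w,z)=\tfrac12(w-z)^2$, residuals $r_i(t) = f(x_i;\theta(t)) - z_i$, and a uniform lower bound $\lambda>0$ on the smallest eigenvalue of the training-set NTK matrix $\Theta_t$ with entries $\Theta(x_i,x_j;\theta(t))$. By the function-space dynamics driven by the NTK, $\partial_t r = -\Theta_t r$, so Grönwall's inequality gives exponential decay of the loss $\mathcal{C} = \tfrac12\|r\|^2$ to its global minimum $0$:

```latex
\begin{aligned}
\partial_t r(t) &= -\Theta_t\, r(t),\\
\frac{d}{dt}\,\|r(t)\|^2 &= -2\, r(t)^\top \Theta_t\, r(t) \;\le\; -2\lambda\,\|r(t)\|^2,\\
\|r(t)\|^2 &\le e^{-2\lambda t}\,\|r(0)\|^2 \;\xrightarrow[t\to\infty]{}\; 0.
\end{aligned}
```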

Kernel methods

The NTK gives a rigorous connection between the inference performed by infinite-width ANNs and that performed by kernel methods: when the loss function is the least-squares loss, the inference performed by an ANN is in expectation equal to the kernel ridge regression (with zero ridge) with respect to the NTK $\Theta_\infty$. This suggests that the performance of large ANNs in the NTK parametrization can be replicated by kernel methods for suitably chosen kernels.[1][10]

Software libraries

Neural Tangents is a free and open-source Python library used for computing and doing inference with the infinite width NTK and neural network Gaussian process (NNGP) corresponding to various common ANN architectures.[19]

References

  1. Jacot, Arthur; Gabriel, Franck; Hongler, Clément (2018), Bengio, S.; Wallach, H.; Larochelle, H.; Grauman, K. (eds.), "Neural Tangent Kernel: Convergence and Generalization in Neural Networks" (PDF), Advances in Neural Information Processing Systems 31, Curran Associates, Inc., pp. 8571–8580, arXiv:1806.07572, Bibcode:2018arXiv180607572J, retrieved 2019-11-27
  2. Arora, Sanjeev; Du, Simon S.; Hu, Wei; Li, Zhiyuan; Salakhutdinov, Ruslan; Wang, Ruosong (2019-11-04). "On Exact Computation with an Infinitely Wide Neural Net". arXiv:1904.11955 [cs, stat].
  3. Yang, Greg (2020-11-29). "Tensor Programs II: Neural Tangent Kernel for Any Architecture". arXiv:2006.14548 [cond-mat, stat].
  4. Lee, Jaehoon; Bahri, Yasaman; Novak, Roman; Schoenholz, Samuel S.; Pennington, Jeffrey; Sohl-Dickstein, Jascha (2018-02-15). "Deep Neural Networks as Gaussian Processes".
  5. Lee, Jaehoon; Xiao, Lechao; Schoenholz, Samuel S.; Bahri, Yasaman; Novak, Roman; Sohl-Dickstein, Jascha; Pennington, Jeffrey (2020). "Wide neural networks of any depth evolve as linear models under gradient descent". Journal of Statistical Mechanics: Theory and Experiment. 2020 (12): 124002. arXiv:1902.06720. Bibcode:2020JSMTE2020l4002L. doi:10.1088/1742-5468/abc62b. S2CID 62841516.
  6. Novak, Roman; Bahri, Yasaman; Abolafia, Daniel A.; Pennington, Jeffrey; Sohl-Dickstein, Jascha (2018-02-15). "Sensitivity and Generalization in Neural Networks: an Empirical Study". arXiv:1802.08760. Bibcode:2018arXiv180208760N.
  7. Canziani, Alfredo; Paszke, Adam; Culurciello, Eugenio (2016-11-04). "An Analysis of Deep Neural Network Models for Practical Applications". arXiv:1605.07678. Bibcode:2016arXiv160507678C.
  8. Allen-Zhu, Zeyuan; Li, Yuanzhi; Song, Zhao (2018). "A convergence theory for deep learning via overparameterization". International Conference on Machine Learning. arXiv:1811.03962.
  9. Du, Simon; Lee, Jason; Li, Haochuan; Wang, Liwei; Zhai, Xiyu (2019-05-24). "Gradient Descent Finds Global Minima of Deep Neural Networks". International Conference on Machine Learning: 1675–1685. arXiv:1811.03804.
  10. Arora, Sanjeev; Du, Simon S; Hu, Wei; Li, Zhiyuan; Salakhutdinov, Russ R; Wang, Ruosong (2019), "On Exact Computation with an Infinitely Wide Neural Net", NeurIPS: 8139–8148, arXiv:1904.11955
  11. Huang, Jiaoyang; Yau, Horng-Tzer (2019-09-17). "Dynamics of Deep Neural Networks and Neural Tangent Hierarchy". arXiv:1909.08156 [cs.LG].
  12. Cho, Youngmin; Saul, Lawrence K. (2009), Bengio, Y.; Schuurmans, D.; Lafferty, J. D.; Williams, C. K. I. (eds.), "Kernel Methods for Deep Learning" (PDF), Advances in Neural Information Processing Systems 22, Curran Associates, Inc., pp. 342–350, retrieved 2019-11-27
  13. Daniely, Amit; Frostig, Roy; Singer, Yoram (2016), Lee, D. D.; Sugiyama, M.; Luxburg, U. V.; Guyon, I. (eds.), "Toward Deeper Understanding of Neural Networks: The Power of Initialization and a Dual View on Expressivity" (PDF), Advances in Neural Information Processing Systems 29, Curran Associates, Inc., pp. 2253–2261, arXiv:1602.05897, Bibcode:2016arXiv160205897D, retrieved 2019-11-27
  14. Yang, Greg (2019-02-13). "Scaling Limits of Wide Neural Networks with Weight Sharing: Gaussian Process Behavior, Gradient Independence, and Neural Tangent Kernel Derivation". arXiv:1902.04760 [cs.NE].
  15. Hron, Jiri; Bahri, Yasaman; Sohl-Dickstein, Jascha; Novak, Roman (2020-06-18). "Infinite attention: NNGP and NTK for deep attention networks". International Conference on Machine Learning. 2020. arXiv:2006.10540. Bibcode:2020arXiv200610540H.
  16. Du, Simon S; Zhai, Xiyu; Poczos, Barnabas; Singh, Aarti (2019). "Gradient descent provably optimizes over-parameterized neural networks". International Conference on Learning Representations. arXiv:1810.02054.
  17. Zou, Difan; Cao, Yuan; Zhou, Dongruo; Gu, Quanquan (2020). "Gradient descent optimizes over-parameterized deep ReLU networks". Machine Learning. 109: 467–492.
  18. Allen-Zhu, Zeyuan; Li, Yuanzhi; Song, Zhao (2018-10-29). "On the convergence rate of training recurrent neural networks". NeurIPS. arXiv:1810.12065.
  19. Novak, Roman; Xiao, Lechao; Hron, Jiri; Lee, Jaehoon; Alemi, Alexander A.; Sohl-Dickstein, Jascha; Schoenholz, Samuel S. (2019-12-05), "Neural Tangents: Fast and Easy Infinite Neural Networks in Python", International Conference on Learning Representations (ICLR), vol. 2020, arXiv:1912.02803, Bibcode:2019arXiv191202803N