This week we had the pleasure of hosting Tengyu Ma from Princeton University, who told us about the recent progress he has made with co-authors to understand various linearized versions of neural networks. I will describe here two such results, one for Residual Neural Networks and one for Recurrent Neural Networks.
Some properties to look for in non-convex optimization
We will say that a function $f$ admits first order optimality (respectively second order optimality) if all critical points (respectively all local minima) of $f$ are global minima (of course first order optimality implies second order optimality for smooth functions). In particular with first order optimality one has that gradient descent converges to the global minimum, and with second order optimality this is also true provided that one avoids saddle points. To obtain rates of convergence it can be useful to make more quantitative statements. For example, writing $f^*$ for the global minimum value of $f$, we say that $f$ is $\alpha$-Polyak if
$$\|\nabla f(x)\|^2 \geq \alpha (f(x) - f^*) .$$
Clearly $\alpha$-Polyak implies first order optimality, but more importantly it also implies linear convergence rate for gradient descent on $f$. A variant of this condition is $\alpha$-weak-quasi-convexity with respect to a global minimizer $x^*$:
$$\langle \nabla f(x), x - x^* \rangle \geq \alpha (f(x) - f^*) ,$$
in which case gradient descent converges at the slow non-smooth rate $1/\sqrt{t}$ (and in this case it is also robust to noise, i.e. one can write a stochastic gradient descent version). The proofs of these statements just mimic the usual convex proofs. For more on these conditions see for instance this paper.
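To make the Polyak condition concrete, here is a minimal sketch on a one-dimensional non-convex function. The test function $x^2 + 3\sin^2(x)$, the grid, the step size and all names are my own choices for illustration, not something taken from the talk. The code estimates the smallest ratio $|f'(x)|^2 / (f(x) - f^*)$ on a grid and then runs plain gradient descent.

```python
import math

def f(x):
    # Non-convex test function (an assumption for illustration); f* = f(0) = 0.
    return x * x + 3.0 * math.sin(x) ** 2

def grad_f(x):
    # d/dx [x^2 + 3 sin^2 x] = 2x + 3 sin(2x)
    return 2.0 * x + 3.0 * math.sin(2.0 * x)

def polyak_ratio(x):
    # |f'(x)|^2 / (f(x) - f*); alpha-Polyak means this stays >= alpha.
    return grad_f(x) ** 2 / f(x)

# Empirical lower bound on the ratio over [-10, 10], skipping the minimizer 0.
grid = [i / 100.0 for i in range(-1000, 1001) if i != 0]
alpha_hat = min(polyak_ratio(x) for x in grid)

def gradient_descent(x0, step, n_steps):
    # Plain gradient descent; returns the sequence of objective values.
    x, values = x0, [f(x0)]
    for _ in range(n_steps):
        x -= step * grad_f(x)
        values.append(f(x))
    return values

# f'' = 2 + 6 cos(2x) lies in [-4, 8], so a step of 0.1 is below 1/8.
values = gradient_descent(x0=3.0, step=0.1, n_steps=50)
print("empirical Polyak constant on the grid:", alpha_hat)
print("objective after 50 steps:", values[-1])
```

The second derivative changes sign, so the function is not convex, yet the ratio stays bounded away from zero on the grid and the objective values decay geometrically to zero.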
Linearized Residual Networks
Recall that a neural network is just a map $x \in \mathbb{R}^n \mapsto \sigma \circ A_L \circ \sigma \circ A_{L-1} \circ \dots \circ \sigma \circ A_0 (x)$ where $A_0, \dots, A_L$ are linear maps (i.e. they are the matrices parametrizing the neural network) and $\sigma$ is some non-linear map (the most popular one, ReLU, is just the coordinate-wise positive part). Alternatively you can think of a neural network as a sequence of hidden states $h_0, h_1, \dots, h_{L+1}$ where $h_0 = x$ and $h_{t+1} = \sigma(A_t h_t)$. In 2015 a team of researchers at MSR Asia introduced the concept of a residual neural network where the hidden states are updated as before for even $t$, but for odd $t$ we set $h_{t+1} = h_{t-1} + \sigma(A_t h_t)$. Apparently this trick allowed them to train much deeper networks, though it is not clear why this would help from a theoretical point of view (the intuition is that at least when the network is initialized with all matrices being $0$ it still does something non-trivial, namely it computes the identity).
In their most recent paper Moritz Hardt and Tengyu Ma try to explain why adding this “identity connection” could be a good idea from a geometric point of view. They consider an (extremely) simplified model where there is no non-linearity, i.e. $\sigma$ is the identity map. A neural network is then just a product of matrices. In particular, for data $(x,y)$ drawn from some distribution $\nu$, the landscape we are looking at for least-squares with such a model is of the form:
$$g(A_0, \dots, A_L) = \mathbb{E}_{(x,y) \sim \nu} \left\| A_L A_{L-1} \cdots A_0 \, x - y \right\|^2 ,$$
which is of course a non-convex function (just think of the function $(a,b) \mapsto (ab-1)^2$ and observe that on the line $a=b$ it gives the non-convex function $(a^2-1)^2$). However it actually satisfies the second-order optimality condition:
Proposition [Kawaguchi 2016]
Assume that $x$ has a full rank covariance matrix and that $y = Rx$ for some deterministic matrix $R$. Then all local minima of $g$ are global minima.
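As a sanity check on the non-convexity of the product landscape, here is a minimal sketch assuming the scalar two-layer example $(a,b) \mapsto (ab-1)^2$ (function and variable names are mine). It shows that midpoint convexity fails, and that the origin is a critical point which is not a global minimum. This does not contradict the proposition, because the origin is a saddle point rather than a local minimum.

```python
def g(a, b):
    # Scalar two-layer linear "network" a*b fitted to the target 1.
    return (a * b - 1.0) ** 2

def grad_g(a, b):
    # Gradient of (ab - 1)^2 with respect to (a, b).
    r = a * b - 1.0
    return (2.0 * r * b, 2.0 * r * a)

# Convexity would force g(midpoint) <= average of the endpoint values.
left, right = (-1.0, -1.0), (1.0, 1.0)
midpoint_value = g(0.0, 0.0)
endpoint_average = 0.5 * (g(*left) + g(*right))

# The origin is a critical point with value 1, so it is not a global minimum.
origin_gradient = grad_g(0.0, 0.0)

# It is a saddle: g decreases along a = b and increases along a = -b.
eps = 1e-3
along_diagonal = g(eps, eps) - g(0.0, 0.0)
along_antidiagonal = g(eps, -eps) - g(0.0, 0.0)

print(midpoint_value, endpoint_average)
print(origin_gradient, along_diagonal, along_antidiagonal)
```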
I won’t give the proof of this result as it requires taking the second derivative of $g$, which is a bit annoying (I will compute the first derivative below). Now in this linearized setting the residual network version (where the identity connection is added at every layer) corresponds simply to a reparametrization around the identity, in other words we consider now the following function:
$$f(A_0, \dots, A_L) = \mathbb{E}_{(x,y) \sim \nu} \left\| (A_L + \mathrm{I}) (A_{L-1} + \mathrm{I}) \cdots (A_0 + \mathrm{I}) \, x - y \right\|^2 .$$
Proposition [Hardt and Ma 2016]
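Assume that $x$ has a full rank covariance matrix and that $y = Rx$ for some deterministic matrix $R$. Then $f$ has first order optimality on the set $\{(A_0, \dots, A_L) : \|A_i\| < 1 \text{ for all } i\}$ (where $\|\cdot\|$ is the operator norm).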
Thus adding the identity connection makes the objective function better behaved around the starting point with all-zero matrices (in the sense that gradient descent doesn’t have to worry about avoiding saddle points). The proof is just a few lines of standard calculations to take derivatives of functions with matrix-valued inputs.
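The effect of the reparametrization at the all-zeros starting point can be seen in a scalar toy version. This is a minimal sketch under my own assumptions (three scalar layers, target $r = 2$, unit-variance input, finite-difference gradients, hypothetical function names), not the setting of the paper. The plain product has zero gradient at the origin and never moves, while the residual version makes progress from the same starting point.

```python
def product(values):
    out = 1.0
    for v in values:
        out *= v
    return out

def plain_loss(a, r):
    # Scalar analogue of the plain product landscape: layers multiply directly.
    return (r - product(a)) ** 2

def residual_loss(a, r):
    # Scalar analogue of the residual landscape: each layer is (1 + a_i).
    return (r - product(1.0 + ai for ai in a)) ** 2

def numerical_grad(loss, a, r, h=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = []
    for i in range(len(a)):
        up = a[:i] + [a[i] + h] + a[i + 1:]
        down = a[:i] + [a[i] - h] + a[i + 1:]
        grad.append((loss(up, r) - loss(down, r)) / (2.0 * h))
    return grad

def gradient_descent(loss, a0, r, step=0.05, n_steps=200):
    a = list(a0)
    for _ in range(n_steps):
        g = numerical_grad(loss, a, r)
        a = [ai - step * gi for ai, gi in zip(a, g)]
    return a

r = 2.0
zeros = [0.0, 0.0, 0.0]
plain_final = gradient_descent(plain_loss, zeros, r)        # stuck at the origin
residual_final = gradient_descent(residual_loss, zeros, r)  # moves toward the target

print("plain:", plain_final, plain_loss(plain_final, r))
print("residual:", residual_final, residual_loss(residual_final, r))
```

In this toy run the residual iterates also stay inside the region $|a_i| < 1$, which is the scalar counterpart of the set appearing in the proposition.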
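Proof: One has with $E = R - (A_L + \mathrm{I}) \cdots (A_0 + \mathrm{I})$ and $\Sigma = \mathbb{E} \, x x^{\top}$,
$$f = \mathbb{E} \, (E x)^{\top} E x = \mathrm{Tr}(E \Sigma E^{\top}) =: \|E\|_{\Sigma}^2 ,$$
so with $E_{<i} = (A_{i-1} + \mathrm{I}) \cdots (A_0 + \mathrm{I})$ and $E_{>i} = (A_L + \mathrm{I}) \cdots (A_{i+1} + \mathrm{I})$,
$$f(A_0, \dots, A_i + V, \dots, A_L) = \|E - E_{>i} V E_{<i}\|_{\Sigma}^2 = \|E\|_{\Sigma}^2 - 2 \langle E \Sigma, E_{>i} V E_{<i} \rangle + \|E_{>i} V E_{<i}\|_{\Sigma}^2 ,$$
which, since $\langle E \Sigma, E_{>i} V E_{<i} \rangle = \langle E_{>i}^{\top} E \Sigma E_{<i}^{\top}, V \rangle$, exactly means that the derivative of $f$ with respect to $A_i$ is equal to $-2 E_{>i}^{\top} E \Sigma E_{<i}^{\top}$. On the set under consideration one has that $E_{<i}$ and $E_{>i}$ are invertible (and so is $\Sigma$ by assumption), and thus if this derivative is equal to $0$ it must be that $E = 0$ and thus $f = 0$ (which is the global minimum).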
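The derivative computation can be checked numerically. The sketch below assumes the product is ordered as $(A_L + \mathrm{I}) \cdots (A_0 + \mathrm{I})$, writes $E$ for the error matrix $R$ minus that product and $E_{<i}$, $E_{>i}$ for the partial products on either side of layer $i$, and compares the closed form $-2 E_{>i}^{\top} E \Sigma E_{<i}^{\top}$ against central finite differences. The helper names, the $2 \times 2$ size and the random instance are mine.

```python
import random

def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def transpose(X):
    return [list(row) for row in zip(*X)]

def add(X, Y, scale=1.0):
    return [[x + scale * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def identity(d):
    return [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]

def chain(mats, d):
    # (A_k + I) ... (A_0 + I) for mats = [A_0, ..., A_k]; the empty product is I.
    P = identity(d)
    for A in mats:
        P = matmul(add(A, identity(d)), P)
    return P

def objective(As, R, Sigma):
    # f = Tr(E Sigma E^T) with E = R - (A_L + I) ... (A_0 + I).
    d = len(R)
    E = add(R, chain(As, d), scale=-1.0)
    M = matmul(matmul(E, Sigma), transpose(E))
    return sum(M[i][i] for i in range(d))

def gradient(As, R, Sigma, i):
    # Closed form -2 E_{>i}^T E Sigma E_{<i}^T for the derivative in A_i.
    d = len(R)
    E = add(R, chain(As, d), scale=-1.0)
    E_lt = chain(As[:i], d)
    E_gt = chain(As[i + 1:], d)
    G = matmul(matmul(matmul(transpose(E_gt), E), Sigma), transpose(E_lt))
    return [[-2.0 * g for g in row] for row in G]

def finite_difference(As, R, Sigma, i, h=1e-6):
    d = len(R)
    G = [[0.0] * d for _ in range(d)]
    for r in range(d):
        for c in range(d):
            plus = [[row[:] for row in A] for A in As]
            minus = [[row[:] for row in A] for A in As]
            plus[i][r][c] += h
            minus[i][r][c] -= h
            G[r][c] = (objective(plus, R, Sigma) - objective(minus, R, Sigma)) / (2 * h)
    return G

random.seed(0)
d, L = 2, 2
As = [[[random.uniform(-0.3, 0.3) for _ in range(d)] for _ in range(d)]
      for _ in range(L + 1)]
R = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(d)]
Sigma = [[2.0, 0.5], [0.5, 1.0]]  # a full rank covariance matrix

max_error = max(
    abs(a - b)
    for i in range(L + 1)
    for row_a, row_b in zip(gradient(As, R, Sigma, i),
                            finite_difference(As, R, Sigma, i))
    for a, b in zip(row_a, row_b)
)
print("largest gap between closed form and finite differences:", max_error)
```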
Linearized recurrent neural networks
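The simplest version of a recurrent neural network is as follows. It is a mapping of the form $(x_1, \dots, x_T) \mapsto (y_1, \dots, y_T)$ (we are thinking of doing sequence to sequence prediction). In these networks the hidden state is updated as $h_{t+1} = \sigma_1(A h_t + B x_t)$ (with $h_1 = 0$) and the output is $y_t = \sigma_2(C h_t + D x_t)$. I will now describe a paper by Hardt, Ma and Recht (see also this blog post) that tries to understand the geometry of least-squares for this problem in the linearized version where $\sigma_1 = \sigma_2 = \mathrm{Id}$. That is we are looking at the function:
$$f(\hat{A}, \hat{B}, \hat{C}, \hat{D}) = \mathbb{E}_{(x_t)} \frac{1}{T} \sum_{t=1}^T \|\hat{y}_t - y_t\|^2 ,$$
where $(y_t)$ is obtained from $(x_t)$ via some unknown recurrent neural network with parameters $A, B, C, D$, and $(\hat{y}_t)$ is the output on the same input of the network with parameters $\hat{A}, \hat{B}, \hat{C}, \hat{D}$. First observe that by induction one can easily see that $h_{t+1} = \sum_{k=1}^{t} A^{t-k} B x_k$ and $y_t = D x_t + \sum_{k=1}^{t-1} C A^{t-1-k} B x_k$. In particular, assuming that $(x_t)$ is an i.i.d. isotropic sequence one obtains
$$\mathbb{E} \|y_t - \hat{y}_t\|^2 = \|D - \hat{D}\|_F^2 + \sum_{k=1}^{t-1} \|\hat{C} \hat{A}^{t-1-k} \hat{B} - C A^{t-1-k} B\|_F^2 ,$$
and thus, averaging over $t$,
$$f(\hat{A}, \hat{B}, \hat{C}, \hat{D}) = \|D - \hat{D}\|_F^2 + \sum_{k=1}^{T-1} \left(1 - \frac{k}{T}\right) \|\hat{C} \hat{A}^{k-1} \hat{B} - C A^{k-1} B\|_F^2 .$$
In particular we see that the effect of $\hat{D}$ is decoupled from the other variables and that it appears as a convex function, thus we will just ignore it. Next we make the natural assumption that the spectral radius of $A$ is less than $1$ (for otherwise the influence of the initial input $x_1$ is growing over time which doesn’t seem natural) and thus up to some small error term (for large $T$) one can consider the idealized risk:
$$(\hat{A}, \hat{B}, \hat{C}) \mapsto \sum_{k=0}^{+\infty} \|\hat{C} \hat{A}^{k} \hat{B} - C A^{k} B\|_F^2 .$$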
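The next idea is a cute one which makes the above expression more tractable. Consider the series $r_k = C A^k B$ and its Fourier transform:
$$G(\theta) = \sum_{k=0}^{+\infty} r_k \exp(i k \theta) = C \left( \sum_{k=0}^{+\infty} (\exp(i \theta) A)^k \right) B = C (\mathrm{I} - \exp(i \theta) A)^{-1} B .$$
By Parseval’s theorem the idealized risk is equal to the squared $L_2$ distance between $G$ and $\hat{G}$ (i.e. $\frac{1}{2\pi} \int_{-\pi}^{\pi} \|G(\theta) - \hat{G}(\theta)\|_F^2 \, d\theta$), where $\hat{G}$ is defined like $G$ with $\hat{A}, \hat{B}, \hat{C}$ in place of $A, B, C$. We will now show that under appropriate further assumptions, for any $\theta$ the map $(\hat{A}, \hat{B}, \hat{C}) \mapsto \|G(\theta) - \hat{G}(\theta)\|_F^2$ is weakly-quasi-convex (in particular this shows that the idealized risk is weakly-quasi-convex). The big assumption that Hardt, Ma and Recht make is that the system is a “single-input single-output” model, that is both $x_t$ and $y_t$ are scalar. In this case it turns out that control theory shows that there is a “canonical controllable form” where $B = (0, \dots, 0, 1)$, $C = (c_1, \dots, c_n)$, and $A$ has zeros everywhere except on the upper diagonal where it has ones and on the last row where it has $-a_n, \dots, -a_1$ (I don’t know the proof of this result, if some reader has a pointer for a simple proof please share in the comments!). Note that with this form the system is simple to interpret as one has $A h = (h(2), \dots, h(n), -\sum_{j=1}^n a_{n+1-j} h(j))$ and $A h + B x = (h(2), \dots, h(n), x - \sum_{j=1}^n a_{n+1-j} h(j))$. Now with just a few lines of algebra:
$$G(\theta) = z \cdot \frac{c_1 + c_2 z + \dots + c_n z^{n-1}}{z^n + a_1 z^{n-1} + \dots + a_n} , \quad \text{where} \quad z = \exp(- i \theta) .$$
Since $|z| = 1$ the leading factor $z$ does not affect the modulus of $G(\theta) - \hat{G}(\theta)$. Thus we are just asking to check the weak-quasi-convexity of
$$(\hat{a}, \hat{c}) \mapsto \left| \frac{\hat{c}_1 + \dots + \hat{c}_n z^{n-1}}{z^n + \hat{a}_1 z^{n-1} + \dots + \hat{a}_n} - \frac{c_1 + \dots + c_n z^{n-1}}{z^n + a_1 z^{n-1} + \dots + a_n} \right|^2 .$$
Weak-quasi-convexity is preserved under composition with linear (or affine) maps, and both the numerator and the denominator are affine in $(\hat{a}, \hat{c})$, so we just need to understand the map
$$(\hat{u}, \hat{v}) \in \mathbb{C} \times \mathbb{C} \mapsto \left| \frac{\hat{u}}{\hat{v}} - \frac{u}{v} \right|^2 ,$$
which is weak-quasi-convex provided that $\hat{v}$ has a positive inner product with $v$ (viewing both as vectors in $\mathbb{R}^2$). In particular we just proved the following: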
Theorem [Hardt, Ma, Recht 2016]
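Let $C(a) = \{z^n + a_1 z^{n-1} + \dots + a_n : z \in \mathbb{C}, |z| = 1\}$ and assume there is some cone $\mathcal{C} \subset \mathbb{C} \simeq \mathbb{R}^2$ of angle less than $\pi/2 - \alpha$ such that $C(a) \subset \mathcal{C}$. Then the idealized risk is $\alpha$-weakly-quasi-convex on the set of $\hat{a}$ such that $C(\hat{a}) \subset \mathcal{C}$.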
(In the paper they specifically pick the cone where the imaginary part is larger than the real part.) This theorem naturally suggests that by overparametrizing the network (i.e. adding dimensions to $\hat{a}$ and $\hat{c}$) one could have a nicer landscape (indeed in this case the above condition can be easier to check), see the paper for more details!
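For readers who want to play with the single-input single-output case, here is a minimal sketch that simulates the canonical controllable form and compares the truncated Fourier series of the impulse response $r_k = C A^k B$ with a rational closed form. Several conventions here are my assumptions rather than something stated in the talk: the last row of $A$ is taken to be $(-a_n, \dots, -a_1)$, the series uses $\exp(i k \theta)$, and with $z = \exp(-i\theta)$ the closed form is $z (c_1 + \dots + c_n z^{n-1}) / (z^n + a_1 z^{n-1} + \dots + a_n)$. The function names and the numerical instance are hypothetical.

```python
import cmath

def companion_step(h, a, x):
    # One step h -> A h + B x in the assumed canonical form: shift the state
    # up and write x - sum_j a_{n+1-j} h(j) in the last coordinate.
    n = len(h)
    last = x - sum(a[n - 1 - j] * h[j] for j in range(n))
    return h[1:] + [last]

def impulse_response(a, c, n_terms):
    # r_k = C A^k B, obtained by feeding a unit impulse followed by zeros.
    h = companion_step([0.0] * len(a), a, 1.0)  # the state now equals B
    out = []
    for _ in range(n_terms):
        out.append(sum(ci * hi for ci, hi in zip(c, h)))
        h = companion_step(h, a, 0.0)
    return out

def fourier_series(r, theta):
    # Truncated version of sum_k r_k exp(i k theta).
    return sum(rk * cmath.exp(1j * k * theta) for k, rk in enumerate(r))

def rational_form(a, c, theta):
    # z (c_1 + ... + c_n z^{n-1}) / (z^n + a_1 z^{n-1} + ... + a_n), z = exp(-i theta).
    z = cmath.exp(-1j * theta)
    n = len(a)
    numerator = sum(c[j] * z ** j for j in range(n))
    denominator = z ** n + sum(a[j] * z ** (n - 1 - j) for j in range(n))
    return z * numerator / denominator

def parseval_gap(a, c, n_terms=80, n_grid=512):
    # Compare sum_k r_k^2 with the average of |G(theta)|^2 over a uniform grid.
    r = impulse_response(a, c, n_terms)
    energy = sum(rk * rk for rk in r)
    thetas = [2.0 * cmath.pi * m / n_grid for m in range(n_grid)]
    mean_sq = sum(abs(rational_form(a, c, t)) ** 2 for t in thetas) / n_grid
    return abs(energy - mean_sq)

# z^2 - 0.2 z - 0.15 = (z - 0.5)(z + 0.3), so the spectral radius of A is 0.5.
a = [-0.2, -0.15]
c = [1.0, 2.0]
r = impulse_response(a, c, n_terms=80)
max_gap = max(abs(fourier_series(r, t) - rational_form(a, c, t))
              for t in (0.3, 1.0, 2.5))
print("series vs rational form:", max_gap)
print("Parseval gap:", parseval_gap(a, c))
```

The same functions can be called with a second pair $(\hat{a}, \hat{c})$ to evaluate the pointwise gap between the two transfer functions at any $\theta$.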