A couple of months ago we (Kevin Scaman, Francis Bach, Yin Tat Lee, Laurent Massoulie and myself) uploaded a new paper on distributed convex optimization. We came up with a pretty clean picture for the optimal oracle complexity of this setting, which I will describe below. I should note that there are hundreds of papers on this topic, but the point of the post is to show our simple cute calculations and not to survey the immense literature on distributed optimization; see the paper itself for a number of pointers to other recent works.
Distributed optimization setting
Let $G=(V,E)$ be an undirected graph on $n$ vertices ($V=[n]$) and with diameter $\Delta$. We will think of the nodes in $V$ as the computing units. To each vertex $i \in V$ there is an associated convex function $f_i : \mathbb{R}^d \to \mathbb{R}$. For machine learning applications one can think that each computing unit has access to a “private” dataset, and $f_i(x)$ represents the fit of the model corresponding to $x \in \mathbb{R}^d$ on this dataset (say measured on least squares loss, or logistic loss for example). The goal will be to find in a distributed way the optimal “consensus” point:
$$x^* \in \mathop{\mathrm{argmin}}_{x \in \mathbb{R}^d} \sum_{i \in V} f_i(x) .$$
The distributed processing protocol is as follows: asynchronously/in parallel, each node $i$ can (i) compute a (local) gradient in time $1$, and (ii) communicate a vector in $\mathbb{R}^d$ to its neighbors in $G$ in time $\tau$. We denote by $x_{i,t}$ the local model (essentially its guess for $x^*$) of node $i$ at time $t$. We aim to characterize the smallest time $T_\epsilon$ such that one can guarantee that all nodes $i$ satisfy $\bar{f}(x_{i,T_\epsilon}) - \bar{f}(x^*) \leq \epsilon$ where $\bar{f} = \frac{1}{n} \sum_{i \in V} f_i$.
We focus on the case where $\bar{f}$ is $\beta$-smooth and $\alpha$-strongly convex ($\kappa = \beta/\alpha$ is the condition number), which is arguably the most challenging case since one expects linear convergence (i.e., the scaling of $T_\epsilon$ in $\epsilon$ should be $\log(1/\epsilon)$), which a priori makes the interaction of optimization error and communication error potentially delicate (one key finding is that in fact it is not delicate!). Also, having in mind applications outside of large-scale machine learning (such as “federated” learning), we will make no assumptions about how the functions at different vertices relate to each other.
A trivial answer
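Recall that Nesterov’s accelerated gradient descent solves the serial problem in time $O(\sqrt{\kappa} \log(1/\epsilon))$, where $\kappa$ is the condition number. Trivially one can distribute a step of Nesterov’s accelerated gradient descent in time $O(1 + \Delta \tau)$, with $\Delta$ the diameter of the graph and $\tau$ the time to send a vector along an edge (simply designate a master node at the beginning, and everybody sends its local gradient to the master node in time $O(\Delta \tau)$). Thus we arrive at the upper bound $T_\epsilon = O(\sqrt{\kappa} (1 + \Delta \tau) \log(1/\epsilon))$ using a trivial (centralized) algorithm. We now show (slightly informally, see the paper for proper definitions) that this is in fact optimal!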
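To make the centralized scheme concrete, here is a minimal sketch in Python. It is only an illustration under assumptions that are mine and not from the paper: the local functions are diagonal quadratics, the master node is assumed to know $\alpha$ and $\beta$, and the communication rounds are simulated by a plain average (the function name `centralized_agd` is hypothetical).

```python
import numpy as np

def centralized_agd(D, B, steps=200):
    """Nesterov's accelerated gradient descent with master-node aggregation.

    Assumption (illustrative only): node i holds the diagonal quadratic
    f_i(x) = 0.5 * sum_k D[i, k] * (x[k] - B[i, k])**2, with D > 0.
    D and B have shape (n, d): one row per node.
    """
    n, d = D.shape
    h = D.mean(axis=0)                    # Hessian diagonal of the average function
    alpha, beta = h.min(), h.max()        # strong convexity and smoothness constants
    q = np.sqrt(beta / alpha)             # sqrt of the condition number
    momentum = (q - 1.0) / (q + 1.0)
    x = np.zeros(d)
    y = np.zeros(d)
    for _ in range(steps):
        local_grads = D * (y - B)         # row i is the gradient of f_i at y (local work)
        g = local_grads.mean(axis=0)      # master aggregates; this is the O(Delta * tau) part
        x_next = y - g / beta             # gradient step on the averaged function
        y = x_next + momentum * (x_next - x)
        x = x_next
    return x
```

In the time model of this post, each pass through the loop costs one local gradient plus one round trip to the master node, which is where the factor $1 + \Delta\tau$ comes from.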
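First let us recall the lower bound proof in the serial case (see for example Theorem 3.15 here). The idea is to introduce the function
$$f(x) = \frac{\alpha (\kappa - 1)}{8} \left( x^\top L x - 2 x(1) \right) + \frac{\alpha}{2} \|x\|^2 ,$$
where $L$ is the Laplacian of the path graph on $\mathbb{N}$, or in other words
$$x^\top L x = \sum_{i \geq 1} (x(i) - x(i+1))^2 .$$
First it is easy to see that this function is indeed $\beta$-smooth and $\alpha$-strongly convex, since $0 \preceq L \preceq 4 I$ and $\beta = \alpha \kappa$. The key point is that, for any algorithm starting at $x_0 = 0$ and such that each iteration stays in the linear span of the previously computed gradients (a very natural assumption), after $t$ gradient computations one has
$$x_t(i) = 0, \quad \forall i > t .$$
In words one can say that each gradient calculation “discovers” a new edge of the path graph involved in the definition of $f$. Concluding the serial proof is then just a matter of brute force calculations.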
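The “one new coordinate per gradient” phenomenon is easy to check numerically. The sketch below is an illustration under my own assumptions: it truncates the path graph to `d` vertices, uses plain gradient descent with step size $1/\beta$ as the span-respecting algorithm, and the function names are hypothetical.

```python
import numpy as np

def path_laplacian(d):
    """Laplacian of the path graph on d vertices (finite truncation of the path on N)."""
    L = 2.0 * np.eye(d) - np.eye(d, k=1) - np.eye(d, k=-1)
    L[0, 0] = L[-1, -1] = 1.0             # endpoints have degree one
    return L

def hard_gradient(x, L, alpha, kappa):
    """Gradient of alpha*(kappa-1)/8 * (x^T L x - 2 x(1)) + alpha/2 * |x|^2."""
    c = alpha * (kappa - 1.0) / 4.0
    e1 = np.zeros_like(x)
    e1[0] = 1.0
    return c * (L @ x - e1) + alpha * x

def support_sizes(d=20, steps=8, alpha=1.0, kappa=10.0):
    """Number of non-zero coordinates after each gradient step, starting from 0."""
    L = path_laplacian(d)
    x = np.zeros(d)
    eta = 1.0 / (alpha * kappa)           # step size 1/beta
    sizes = []
    for _ in range(steps):
        x = x - eta * hard_gradient(x, L, alpha, kappa)
        sizes.append(int(np.count_nonzero(x)))
    return sizes
```

Calling `support_sizes()` returns a list whose $t$-th entry is at most $t$: the support of the iterate grows by at most one coordinate per gradient computation, because $L$ is tridiagonal.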
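Now let us move to the distributed setting, and consider two vertices $u$ and $v$ that realize the diameter of $G$. The idea goes as follows: let $L_1$ (respectively $L_2$) be the Laplacian of the even edges of the path graph on $\mathbb{N}$ (respectively the odd edges), that is
$$x^\top L_1 x = \sum_{i \geq 1} (x(2i) - x(2i+1))^2 , \qquad x^\top L_2 x = \sum_{i \geq 1} (x(2i-1) - x(2i))^2 ,$$
so that $L_1 + L_2$ is the Laplacian of the whole path.
Now define $f_u$, $f_v$, and $f_w$ for any $w \notin \{u, v\}$ by splitting the serial construction: $f_u$ carries the $L_1$ part, $f_v$ carries the $L_2$ part, and no other vertex sees any edge of the path (see the paper for the exact definitions). The key observation is that node $v$ does not “know” about the even edges until it receives a message from $u$, and vice versa for $u$ and the odd edges. Thus it is fairly easy to show that in this case one has, for every node $w \in V$:
$$x_{w,t}(i) = 0, \quad \forall i > \frac{t}{1 + \Delta \tau} + 1 ,$$
which effectively amounts to a slowdown by a factor $1 + \Delta \tau$ compared to the serial case and proves the lower bound $T_\epsilon = \Omega(\sqrt{\kappa} (1 + \Delta \tau) \log(1/\epsilon))$.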
Not so fast!
One can say that the algorithm proposed above defeats a bit the purpose of the distributed setting. Indeed the centralized communication protocol it relies on is not robust to various real-life issues such as machine failures, time-varying graphs, edges with different latency, etc. An elegant and practical solution is to restrict communication to be gossip-like. That is, local computations now have to be communicated via matrix multiplication with a walk matrix $W \in \mathbb{R}^{n \times n}$, which we define as satisfying the following three conditions: (i) $W_{i,j} \neq 0$ only if $i = j$ or $\{i,j\} \in E$, (ii) $\mathrm{Ker}(W) = \mathrm{span}(\mathbf{1})$ where $\mathbf{1}$ is the all-ones vector, and (iii) $W$ is symmetric with $W \succeq 0$. Let us briefly discuss these conditions: (i) simply means that if $x \in \mathbb{R}^n$ represents real values stored at the vertices, then $W x$ can be calculated with a distributed communication protocol; (ii) says that if there is consensus (that is all vertices hold the same value) then no communication occurs with this matrix multiplication; and (iii) will turn out to be natural in just a moment for our algorithm based on duality. A prominent example of a walk matrix would be the (normalized) Laplacian of $G$.
We denote by $\gamma$ the inverse condition number of $W$ on $\mathbf{1}^{\perp}$ (that is the ratio of the smallest non-zero eigenvalue of $W$ to its largest eigenvalue), also known as the spectral gap of $G$ when $W$ is the Laplacian. Notice that $\gamma$ naturally controls the number of gossip steps to reach consensus, in the sense that gossip steps correspond to gradient descent steps on $x \mapsto x^\top W x$, which will converge in $O(\gamma^{-1} \log(1/\epsilon))$ steps. Doing an “accelerated gossip” (also known as Chebyshev gossiping) one could thus hope to essentially replace the diameter $\Delta$ by $O(\gamma^{-1/2} \log(1/\epsilon))$. Notice that this is hopeful thinking because in the centralized model $\Delta$ steps get you to an exact consensus, while in the gossip model one only reaches an $\epsilon$-approximate consensus and errors might compound. In fact with a bit of graph theory one can immediately see that simply replacing $\Delta$ by $\gamma^{-1/2}$ is too good to be true: there are graphs (namely expanders) where $\Delta$ is of order $\log(n)$ while $\gamma$ is of order of a constant, and thus an upper bound of the form (say) $\sqrt{\kappa} (1 + \tau / \sqrt{\gamma}) \log(1/\epsilon)$ would violate our previous lower bound by a factor $\log(n)$.
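Here is a small sketch of the (unaccelerated) gossip picture, taking the Laplacian of a path graph as the walk matrix. The choice of graph, the step size (one over the largest eigenvalue), and the function names are my own illustrative assumptions.

```python
import numpy as np

def path_graph_laplacian(n):
    """Laplacian of the path graph on n vertices, used here as the walk matrix."""
    L = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
    L[0, 0] = L[-1, -1] = 1.0
    return L

def spectral_gap(W):
    """Smallest non-zero eigenvalue over largest eigenvalue (connected graph assumed)."""
    eig = np.linalg.eigvalsh(W)           # ascending order; eig[0] is the zero eigenvalue
    return eig[1] / eig[-1]

def gossip(W, x0, steps):
    """Plain gossip: gradient descent steps on the quadratic form of W."""
    lam_max = np.linalg.eigvalsh(W)[-1]
    x = x0.copy()
    for _ in range(steps):
        x = x - (W @ x) / lam_max         # each node only combines neighbors' values
    return x
```

Since $W$ is symmetric with the all-ones vector in its kernel, every step preserves the average of the stored values, and the component orthogonal to consensus shrinks by a factor at most $1 - \gamma$ per step, which is the $O(\gamma^{-1} \log(1/\epsilon))$ rate mentioned above.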
To save the day we will make extra assumptions, namely that each local function $f_i$ has condition number $\kappa$ and that in addition to computing local gradients the vertices can also compute local gradients of the Fenchel dual functions $f_i^*$. The latter assumption can be removed at the expense of extra logarithmic factors but we will ignore this point here (see the paper for some hints as well as further discussion on this point). For the former assumption we note that the lower bound proof given above completely breaks under this assumption. However one can save the construction for some specific graphs (finding the correct generalization to arbitrary graphs is one of our open problems). For example imagine a line graph, and cluster the vertices into three groups, the first third, the middle, and the last third. Then one could distribute the even part of the Laplacian on $\mathbb{N}$ in the first group, and the odd part on the last group, as well as distribute the Euclidean norm evenly among all vertices. This construction verifies that each vertex function has condition number $O(\kappa)$ and furthermore the rest of the argument still goes through. Interestingly in this case one also has $\Delta \approx 1/\sqrt{\gamma}$ and thus this proves that for the line graph one has $T_\epsilon = \Omega(\sqrt{\kappa} (1 + \tau / \sqrt{\gamma}) \log(1/\epsilon))$ for gossip algorithms. We will now show a matching upper bound (which holds for arbitrary graphs).
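For $X \in \mathbb{R}^{d \times n}$ (which we think of as a set of $n$ column vectors, one for each vertex $i \in V$), denote $X_i$ for the $i^{th}$ column and let $F(X) = \sum_{i \in V} f_i(X_i)$. We are interested in minimizing $F$ under the constraint that all columns are equal, which can be written as $X \sqrt{W} = 0$ (the square root exists by condition (iii), and its kernel is $\mathrm{span}(\mathbf{1})$ by condition (ii)). By definition of the Fenchel dual and a simple change of variable one has:
$$\min_{X : X \sqrt{W} = 0} F(X) = \min_{X} \max_{\Lambda \in \mathbb{R}^{d \times n}} F(X) - \langle \Lambda, X \sqrt{W} \rangle = \max_{\Lambda \in \mathbb{R}^{d \times n}} - F^*(\Lambda \sqrt{W}) .$$
Next observe that gradient ascent on $\Lambda \mapsto - F^*(\Lambda \sqrt{W})$ with step size $\eta$ can be written as
$$\Lambda_{t+1} = \Lambda_t - \eta \nabla F^*(\Lambda_t \sqrt{W}) \sqrt{W} ,$$
and with the notation $Y_t = \Lambda_t \sqrt{W}$ this is simply $Y_{t+1} = Y_t - \eta \nabla F^*(Y_t) W$. Crucially $\nabla F^*(Y_t) W$ exactly corresponds to gossiping the local conjugate gradients (which are also the local models) $X_{i,t} = \nabla f_i^*(Y_{i,t})$. In other words we only have to understand the condition number of the function $\Lambda \mapsto F^*(\Lambda \sqrt{W})$. The beauty of all of this is that this condition number is precisely $\kappa / \gamma$ (i.e. it naturally combines the condition number of the vertex functions with the “condition number” of the graph). Thus by accelerating gradient ascent we arrive at a time complexity of $O(\sqrt{\kappa / \gamma} \, (1 + \tau) \log(1/\epsilon))$ (recall that a gossip step takes time $\tau$). We call the corresponding algorithm SSDA (Single-Step Dual Accelerated). One can improve it slightly in the case of low communication cost by doing multiple rounds of communication between two gradient computations (essentially replacing $W$ by a polynomial in $W$). We call the corresponding algorithm MSDA (Multi-Step Dual Accelerated) and it attains the optimal (in the worst case over graphs) complexity of $O(\sqrt{\kappa} (1 + \tau / \sqrt{\gamma}) \log(1/\epsilon))$.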
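To see the dual iteration in action, here is a minimal sketch of the plain (unaccelerated) version of the update on $Y_t$, so this is not SSDA itself. The assumptions are mine: local functions $f_i(x) = \frac{a_i}{2} \|x - b_i\|^2$, for which $\nabla f_i^*(y) = b_i + y / a_i$, the Laplacian of a path graph as walk matrix, and step size $\eta = \alpha / \lambda_{\max}(W)$ with $\alpha = \min_i a_i$. The function names are hypothetical.

```python
import numpy as np

def path_graph_laplacian(n):
    """Laplacian of the path graph on n vertices, used here as the walk matrix W."""
    L = 2.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
    L[0, 0] = L[-1, -1] = 1.0
    return L

def dual_gossip_ascent(W, a, B, steps=2000):
    """Unaccelerated dual gradient ascent, Y <- Y - eta * grad F^*(Y) W.

    Assumption (illustrative only): node i holds f_i(x) = a[i]/2 * |x - B[:, i]|^2,
    so that grad f_i^*(y) = B[:, i] + y / a[i].
    a has shape (n,), B has shape (d, n): one column per node.
    """
    alpha = a.min()                       # strong convexity of the local functions
    lam_max = np.linalg.eigvalsh(W)[-1]   # largest eigenvalue of the walk matrix
    eta = alpha / lam_max                 # one over the smoothness of the dual objective
    Y = np.zeros_like(B)
    for _ in range(steps):
        X = B + Y / a                     # local step: column i is grad f_i^*(Y_i)
        Y = Y - eta * (X @ W)             # gossip step: multiply the local models by W
    return B + Y / a                      # local models, one column per node
```

For these quadratics the consensus optimum is the weighted average $x^* = \sum_i a_i b_i / \sum_i a_i$, and every column of the returned matrix converges to it at a rate governed by $\kappa / \gamma$; replacing the plain step by an accelerated one is what gives the $\sqrt{\kappa / \gamma}$ dependence.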