14. PROXIMAL GRADIENT DESCENT
Proximal Gradient Descent is an extension of gradient descent for optimizing composite functions, i.e., sums of a smooth function and a possibly non-smooth function. It generalizes gradient descent by incorporating a proximal step that accounts for non-smooth regularization terms.
Composite Optimization Problems
We consider optimization problems of the form:

$$\min_{x \in \mathbb{R}^d} F(x) = f(x) + g(x),$$

where:

- $f$ is a smooth function (e.g., differentiable with an $L$-Lipschitz gradient),
- $g$ may not be differentiable but is convex.

The challenge:

- How do we minimize $F(x) = f(x) + g(x)$ when $g$ is not smooth?
- We solve this using the proximal gradient descent method!
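As a concrete illustration (an example of my own, not from the notes), the Lasso objective has exactly this composite structure: a smooth least-squares term plays the role of $f$ and the non-smooth L1 penalty plays the role of $g$. The minimal sketch below, assuming NumPy and randomly generated data, only sets up the two pieces and a Lipschitz constant for $\nabla f$:

```python
import numpy as np

# Hypothetical composite problem F(x) = f(x) + g(x): the Lasso objective.
#   f(x) = 0.5 * ||A x - b||^2   -> smooth; grad_f is L-Lipschitz with L = largest eigenvalue of A^T A
#   g(x) = lam * ||x||_1         -> convex but not differentiable at 0
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))
b = rng.standard_normal(50)
lam = 0.1

def f(x):
    return 0.5 * np.sum((A @ x - b) ** 2)

def grad_f(x):
    return A.T @ (A @ x - b)

def g(x):
    return lam * np.sum(np.abs(x))

L = np.linalg.norm(A.T @ A, 2)  # Lipschitz constant of grad_f (spectral norm of A^T A)
```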
Idea of Proximal Gradient Descent
The standard gradient descent update for minimizing a smooth function $f$ is

$$x_{k+1} = x_k - \gamma \nabla f(x_k),$$

which can equivalently be written as the minimizer of a quadratic model of $f$ around $x_k$:

$$x_{k+1} = \arg\min_{x} \left\{ f(x_k) + \nabla f(x_k)^\top (x - x_k) + \frac{1}{2\gamma} \|x - x_k\|^2 \right\}.$$

For composite functions $F(x) = f(x) + g(x)$, we just add the non-smooth term $g(x)$ to this model:

$$x_{k+1} = \arg\min_{x} \left\{ f(x_k) + \nabla f(x_k)^\top (x - x_k) + \frac{1}{2\gamma} \|x - x_k\|^2 + g(x) \right\}.$$

Rewriting this (completing the square and dropping terms that do not depend on $x$), we get the proximal gradient descent update:

$$x_{k+1} = \operatorname{prox}_{\gamma g}\!\left(x_k - \gamma \nabla f(x_k)\right).$$

Here is the explanation of this transformation: Proximal GD Idea Explained.
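As a sanity check on this rewriting (a 1-D illustration of my own, with $f(x) = \tfrac{1}{2}(x-3)^2$ and $g(x) = \lambda|x|$), minimizing the quadratic model plus $g$ over a fine grid lands on the same point as the closed-form proximal update, which for the absolute value is soft-thresholding:

```python
import numpy as np

# 1-D check: f(x) = 0.5 * (x - 3)^2, g(x) = lam * |x|
lam, gamma, x_k = 1.0, 0.5, 0.2
grad_f_xk = x_k - 3.0                      # f'(x_k)

# Minimize the quadratic model of f plus g over a dense grid
# (the constant f(x_k) is dropped since it does not affect the argmin).
grid = np.linspace(-5, 5, 200001)
model = grad_f_xk * (grid - x_k) + (grid - x_k) ** 2 / (2 * gamma) + lam * np.abs(grid)
x_model = grid[np.argmin(model)]

# Closed-form prox of gamma * lam * |.| is soft-thresholding.
z = x_k - gamma * grad_f_xk                # gradient step
x_prox = np.sign(z) * max(abs(z) - gamma * lam, 0.0)

print(x_model, x_prox)                     # both are (approximately) 1.1
```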
Proximal Gradient Descent Algorithm
An iteration of proximal gradient descent is defined as:

$$x_{k+1} = \operatorname{prox}_{\gamma g}\!\left(x_k - \gamma \nabla f(x_k)\right),$$

where

$$\operatorname{prox}_{\gamma g}(z) = \arg\min_{x} \left\{ g(x) + \frac{1}{2\gamma} \|x - z\|^2 \right\}.$$

Steps

- Gradient Descent Step: compute $z_k = x_k - \gamma \nabla f(x_k)$ (just like in gradient descent).
- Proximal Minimization: compute the proximal operator $x_{k+1} = \operatorname{prox}_{\gamma g}(z_k)$. This step ensures that $x_{k+1}$ stays close to the gradient step $z_k$ while keeping the non-smooth term $g(x_{k+1})$ small.
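Putting the two steps together, here is a minimal sketch of the full loop for the L1-regularized least-squares case (the function names soft_threshold and proximal_gradient_descent are my own; the prox of $\gamma\lambda\|\cdot\|_1$ is elementwise soft-thresholding):

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t * ||.||_1, applied elementwise."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def proximal_gradient_descent(grad_f, prox_g, x0, gamma, n_iters=500):
    """x_{k+1} = prox_{gamma g}(x_k - gamma * grad_f(x_k))."""
    x = x0.copy()
    for _ in range(n_iters):
        z = x - gamma * grad_f(x)   # 1. gradient step on the smooth part f
        x = prox_g(z, gamma)        # 2. proximal step on the non-smooth part g
    return x

# Hypothetical Lasso instance.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))
b = rng.standard_normal(50)
lam = 1.0
L = np.linalg.norm(A.T @ A, 2)      # Lipschitz constant of grad f

x_hat = proximal_gradient_descent(
    grad_f=lambda x: A.T @ (A @ x - b),
    prox_g=lambda z, gamma: soft_threshold(z, gamma * lam),
    x0=np.zeros(20),
    gamma=1.0 / L,                  # standard step-size choice
)
print(x_hat)                        # the soft-threshold step can set coordinates exactly to zero
```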
Proximal Gradient Descent as a Generalization of Gradient Descent
Proximal gradient descent recovers basic gradient descent and projected gradient descent as special cases:
- If $g = 0$, we recover gradient descent.
- If $g = \iota_C$ (the indicator function of a convex set $C$), we recover projected gradient descent, where:
  - The indicator function $\iota_C$ is defined as:
    $$\iota_C(x) = \begin{cases} 0 & \text{if } x \in C, \\ +\infty & \text{otherwise.} \end{cases}$$
  - The proximal mapping simplifies to a projection onto $C$ (see the sketch after this list):
    $$\operatorname{prox}_{\gamma \iota_C}(z) = \arg\min_{x \in C} \|x - z\|^2 = \Pi_C(z).$$
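For example (a sketch of my own, not from the notes), taking $C$ to be the box $[-1, 1]^d$ makes the projection a simple elementwise clipping, and the proximal gradient loop becomes projected gradient descent:

```python
import numpy as np

# Projected gradient descent = proximal gradient descent with g = indicator of the box C = [-1, 1]^d.
def project_box(z, lower=-1.0, upper=1.0):
    """prox of the box indicator: Euclidean projection onto C, i.e., elementwise clipping."""
    return np.clip(z, lower, upper)

rng = np.random.default_rng(1)
A = rng.standard_normal((30, 10))
b = rng.standard_normal(30)
L = np.linalg.norm(A.T @ A, 2)
gamma = 1.0 / L

x = np.zeros(10)
for _ in range(300):
    z = x - gamma * A.T @ (A @ x - b)   # gradient step on the smooth least-squares term
    x = project_box(z)                  # proximal step = projection onto C

print(np.all(np.abs(x) <= 1.0))         # True: every iterate stays feasible
```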
Convergence Rates of Proximal Gradient Descent
The convergence of proximal gradient descent follows the same principles as gradient descent, now extended to non-smooth functions.

If $f$ is convex with an $L$-Lipschitz gradient, $g$ is convex, and we use the step size $\gamma = 1/L$, then proximal gradient descent satisfies:

$$F(x_k) - F(x^*) \leq \frac{L \|x_0 - x^*\|^2}{2k}.$$

This shows that proximal gradient descent converges at a rate of $\mathcal{O}(1/k)$, the same rate as gradient descent on smooth convex functions.
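To get a feel for the bound (numbers chosen purely for illustration): with $L = 10$, $\|x_0 - x^*\| = 1$, and $k = 100$ iterations,

$$F(x_{100}) - F(x^*) \;\leq\; \frac{10 \cdot 1^2}{2 \cdot 100} = 0.05,$$

so halving the guaranteed gap requires doubling the number of iterations, which is the hallmark of an $\mathcal{O}(1/k)$ rate.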
Summary
- Proximal gradient descent extends gradient descent to composite functions $F(x) = f(x) + g(x)$.
- It consists of a gradient step for $f$ and a proximal step for $g$.
- It generalizes gradient descent (when $g = 0$) and projected gradient descent (when $g$ is an indicator function).
- Convergence is similar to gradient descent for smooth functions.
This method is widely used in sparse learning, compressed sensing, and machine learning, where non-smooth regularization terms (e.g., L1-norm in Lasso regression) play a key role in inducing sparsity.