Proximal GD Idea Explained


Related to:

Classical Gradient Step for Minimizing g(x)

If we were minimizing only $g(x)$, the standard gradient descent step can be derived by minimizing a first-order approximation of $g$ around the current iterate, regularized by a quadratic proximity term:

$$x_{t+1} = \arg\min_{y \in \mathbb{R}^n} \left( g(x_t) + \nabla g(x_t)^T (y - x_t) + \frac{1}{2\gamma} \|y - x_t\|^2 \right).$$

Here:

- $x_t$ is the current iterate,
- $\nabla g(x_t)$ is the gradient of $g$ at $x_t$,
- $\gamma > 0$ is the step size.

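As a quick sanity check (not spelled out in the original derivation), setting the gradient of this quadratic model with respect to $y$ to zero recovers the familiar gradient descent update:

$$\nabla g(x_t) + \frac{1}{\gamma}(y - x_t) = 0 \quad \Longrightarrow \quad x_{t+1} = x_t - \gamma \nabla g(x_t).$$
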
Adding the h(y) Term

When we introduce $h(y)$, which is possibly non-smooth, we keep the quadratic model of $g$ unchanged and simply add $h(y)$ to the objective:

$$x_{t+1} = \arg\min_{y \in \mathbb{R}^n} \left( g(x_t) + \nabla g(x_t)^T (y - x_t) + \frac{1}{2\gamma} \|y - x_t\|^2 + h(y) \right).$$

Now, the objective consists of:

  1. The linear approximation of $g$ around $x_t$.
  2. A quadratic proximity term that keeps $y$ close to $x_t$.
  3. The non-smooth function $h(y)$.
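
Before deriving the closed form of this update, here is a minimal runnable Python sketch of the resulting method. It assumes $h(y) = \lambda \|y\|_1$, whose proximal step is soft-thresholding; the toy lasso problem and names like `soft_threshold` and `proximal_gradient` are illustrative, not from the original text:

```python
import numpy as np

def soft_threshold(v, tau):
    # Proximal operator of tau * ||.||_1: shrinks each coordinate toward zero.
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def proximal_gradient(grad_g, prox_h, x0, gamma, iters=500):
    # Iterates x_{t+1} = prox_{gamma h}( x_t - gamma * grad_g(x_t) ).
    x = x0
    for _ in range(iters):
        x = prox_h(x - gamma * grad_g(x), gamma)
    return x

# Toy lasso problem: g(x) = 0.5 * ||Ax - b||^2, h(x) = lam * ||x||_1.
rng = np.random.default_rng(0)
A = rng.normal(size=(50, 20))
b = rng.normal(size=50)
lam = 0.1

grad_g = lambda x: A.T @ (A @ x - b)
prox_h = lambda v, step: soft_threshold(v, step * lam)

gamma = 1.0 / np.linalg.norm(A, 2) ** 2  # 1/L, L = Lipschitz constant of grad_g
x_hat = proximal_gradient(grad_g, prox_h, np.zeros(20), gamma)
```

Any other $h$ with a cheap proximal step would slot into `prox_h` the same way.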

Completing the Square

Since $g(x_t)$ is independent of $y$, we can ignore it in the minimization problem. The key step is recognizing that the first two remaining terms can be rewritten using a squared norm:

$$\nabla g(x_t)^T (y - x_t) + \frac{1}{2\gamma} \|y - x_t\|^2$$

This is the first-order term of $g$ plus the quadratic regularizer, and up to an additive constant it can be rewritten in the more compact form:

$$\frac{1}{2\gamma} \left\| y - \left( x_t - \gamma \nabla g(x_t) \right) \right\|^2$$

Understanding the Completing-the-Square Step

We need to rewrite the quadratic expression:

$$\nabla g(x_t)^T (y - x_t) + \frac{1}{2\gamma} \|y - x_t\|^2$$

into a squared norm form plus a constant term.
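
The whole calculation below is an instance of completing the square; as a compact preview (with the shorthand $z = y - x_t$ and $b = \nabla g(x_t)$, introduced here only for illustration):

$$b^T z + \frac{1}{2\gamma} \|z\|^2 = \frac{1}{2\gamma} \|z + \gamma b\|^2 - \frac{\gamma}{2} \|b\|^2.$$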


Step-by-Step Breakdown

1. Expand the squared norm

The squared Euclidean norm is given by:

$$\|y - x_t\|^2 = (y - x_t)^T (y - x_t)$$

So, we rewrite the term:

$$\frac{1}{2\gamma} \|y - x_t\|^2 = \frac{1}{2\gamma} (y - x_t)^T (y - x_t)$$

2. Introduce a shift using $\gamma \nabla g(x_t)$

We want to introduce a shifted term of the form $y - (x_t - \gamma \nabla g(x_t))$, so we add and subtract $\gamma \nabla g(x_t)$ cleverly.

Observe that:

$$\nabla g(x_t)^T (y - x_t) = \frac{1}{\gamma} \, \gamma \nabla g(x_t)^T (y - x_t)$$

This suggests rewriting the term as:

$$\nabla g(x_t)^T (y - x_t) = \frac{1}{\gamma} (y - x_t)^T \left( \gamma \nabla g(x_t) \right)$$

3. Expand the shifted squared norm

Consider the squared term:

$$\left\| y - (x_t - \gamma \nabla g(x_t)) \right\|^2$$

Expanding it:

$$\left( y - (x_t - \gamma \nabla g(x_t)) \right)^T \left( y - (x_t - \gamma \nabla g(x_t)) \right)$$

Breaking it down:

$$\left( y - x_t + \gamma \nabla g(x_t) \right)^T \left( y - x_t + \gamma \nabla g(x_t) \right)$$

Expanding using the identity $(a+b)^T(a+b) = a^T a + 2 a^T b + b^T b$ with $a = y - x_t$ and $b = \gamma \nabla g(x_t)$:

$$\|y - x_t\|^2 + 2\gamma \, \nabla g(x_t)^T (y - x_t) + \gamma^2 \|\nabla g(x_t)\|^2$$

Dividing everything by $2\gamma$:

$$\frac{1}{2\gamma} \|y - x_t\|^2 + \nabla g(x_t)^T (y - x_t) + \frac{\gamma}{2} \|\nabla g(x_t)\|^2$$

Thus:

$$\frac{1}{2\gamma} \|y - x_t\|^2 + \nabla g(x_t)^T (y - x_t) + \frac{\gamma}{2} \|\nabla g(x_t)\|^2 = \frac{1}{2\gamma} \left\| y - (x_t - \gamma \nabla g(x_t)) \right\|^2$$

and, moving the constant to the right-hand side,

$$\frac{1}{2\gamma} \|y - x_t\|^2 + \nabla g(x_t)^T (y - x_t) = \frac{1}{2\gamma} \left\| y - (x_t - \gamma \nabla g(x_t)) \right\|^2 - \frac{\gamma}{2} \|\nabla g(x_t)\|^2$$

4. Rearrange the expression

Comparing with the original form:

$$\nabla g(x_t)^T (y - x_t) + \frac{1}{2\gamma} \|y - x_t\|^2$$

We get:

$$\frac{1}{2\gamma} \left\| y - (x_t - \gamma \nabla g(x_t)) \right\|^2 - \frac{\gamma}{2} \|\nabla g(x_t)\|^2$$

This is the final transformed expression!
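
As an optional numerical sanity check of this identity, here is a tiny Python sketch (the vectors and the value of $\gamma$ are arbitrary stand-ins, not from the derivation):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
y = rng.normal(size=n)
x_t = rng.normal(size=n)
g_grad = rng.normal(size=n)  # stands in for grad g(x_t)
gamma = 0.3

lhs = g_grad @ (y - x_t) + np.linalg.norm(y - x_t) ** 2 / (2 * gamma)
rhs = (np.linalg.norm(y - (x_t - gamma * g_grad)) ** 2 / (2 * gamma)
       - gamma / 2 * np.linalg.norm(g_grad) ** 2)
print(np.isclose(lhs, rhs))  # True: both sides agree
```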

BUT! The final formula looks like this:

$$x_{t+1} = \arg\min_{y \in \mathbb{R}^n} \left( \frac{1}{2\gamma} \left\| y - (x_t - \gamma \nabla g(x_t)) \right\|^2 + h(y) \right)$$

This is because the term $\frac{\gamma}{2} \|\nabla g(x_t)\|^2$ is a constant with respect to $y$ and can be ignored in the minimization problem :)
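
This minimization is exactly the standard proximal operator of $h$, which gives the method its name:

$$\operatorname{prox}_{\gamma h}(v) = \arg\min_{y \in \mathbb{R}^n} \left( \frac{1}{2\gamma} \|y - v\|^2 + h(y) \right), \qquad x_{t+1} = \operatorname{prox}_{\gamma h}\!\left( x_t - \gamma \nabla g(x_t) \right).$$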