Loss & Regularization Functionals

What is a Functional?

A functional is a special type of operator in mathematics and computational science. Unlike typical operators, which map one vector space to another, a functional maps vectors to real numbers. Think of it as a way to assign a numerical “score” to any given input. This is particularly useful in optimization problems, where the goal is to find the input that minimizes (or maximizes) this score.

In computational imaging, functionals often serve as objective criteria to optimize. These objectives usually consist of two or more parts: a loss (data-fidelity) term and one or more regularization terms. In Bayesian contexts, the objective functional often represents the negative log-posterior density. The corresponding posterior distribution then needs to be either sampled or summarized (e.g. by its moments).
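As a minimal illustration, the sketch below evaluates an objective made of a least-squares loss term plus an L1 regularization term. It assumes NumPy and uses a made-up matrix `A`, data `b`, and weight `lam`, all hypothetical. Each term, and their sum, is a functional: it maps a vector to a single real number.

```python
import numpy as np

# Illustrative NumPy-only sketch; A, b and lam are hypothetical placeholders.
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))  # hypothetical forward model
b = rng.standard_normal(20)        # hypothetical observed data
lam = 0.5                          # hypothetical regularization weight

def loss(x):
    # Data-fidelity term: squared Euclidean distance between A x and b.
    return float(np.sum((A @ x - b) ** 2))

def regularizer(x):
    # Regularization term: L1 norm, which favors sparse solutions.
    return float(np.sum(np.abs(x)))

def objective(x):
    # The objective functional is the weighted sum of both terms.
    return loss(x) + lam * regularizer(x)

x = np.zeros(10)
print(objective(x))  # equals ||b||^2, since A @ 0 = 0 and ||0||_1 = 0
```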

Functionals Hierarchy in Pyxu

Pyxu provides a versatile and robust class hierarchy to define various types of functionals. This structure aims to offer maximal flexibility for both elementary and complex computational imaging tasks. Below is an exploration of these different classes, complete with (very simplified) example implementations.

Important Note: The classes described below are designed as abstract base classes. This means they serve as templates or “blueprints” for creating specific functionals that suit your project needs. You’re not supposed to instantiate them directly. Instead, you have two main options:

  1. Subclassing: Extend these classes to create customized functionals tailored to your problem.

  2. Generic Constructor Routine: Utilize the from_source🔗 function to define new functionals from their core methods, like apply()🔗, grad()🔗, or prox()🔗.

Additionally, don’t forget to explore our comprehensive Reference API. It features pre-implemented versions of many commonly used functionals, serving as both a shortcut for common tasks and a useful learning resource.

Func: The Foundation Stone 🧱

The Func🔗 class is the base class in the functional hierarchy. Its core method is apply()🔗, which computes the value of the functional at a given input.

Here’s a simplified example, implementing the squared \(L_2\) norm:

import numpy as np
from pyxu.abc import Func

class SquaredL2(Func):
    def __init__(self, dim_shape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)

    def apply(self, arr):
        # Reduce over the trailing core axes (the last len(dim_shape) axes). Any leading
        # axes are batch dimensions, so one norm is computed per batch element.
        axis = tuple(range(-len(self.dim_shape), 0))
        return (arr**2).sum(axis=axis)[..., np.newaxis]
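To see what the axis reduction does with batched inputs, here is a NumPy-only sketch that needs no Pyxu. It repeats the body of `apply()` as a plain function. The name `sq_l2_apply` and the shapes are illustrative.

```python
import numpy as np

# Illustrative sketch: sq_l2_apply is a hypothetical stand-in for SquaredL2.apply.
dim_shape = (3, 4)  # core shape of one input, e.g. a small 2D image

def sq_l2_apply(arr, dim_shape=dim_shape):
    # Reduce only over the trailing core axes; leading axes are batch axes.
    axis = tuple(range(-len(dim_shape), 0))
    return (arr**2).sum(axis=axis)[..., np.newaxis]

single = np.ones(dim_shape)
batch = np.ones((5, 2) + dim_shape)  # a 5 x 2 batch of inputs

print(sq_l2_apply(single).shape)  # (1,)
print(sq_l2_apply(batch).shape)   # (5, 2, 1)
```

Each input of shape `dim_shape` yields one value, and the trailing axis of size 1 matches `codim_shape=1`.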

This base functional can be adapted into a loss functional using the argshift()🔗 method:

l2_loss = SquaredL2(data.shape).argshift(-data)  # computes ||<arr> - data||^{2}

In this example, argshift(-data) shifts the argument of the functional by -data, so the new functional evaluates SquaredL2 at <arr> - data. This turns the squared norm into a loss that measures the squared Euclidean distance between the input and the data.
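The shift semantics can be mimicked in plain NumPy. The sketch below assumes, consistently with the comment in the snippet above, that `f.argshift(c)` evaluates `f(x + c)`. The helper `argshift` here is a hypothetical stand-in, not Pyxu's implementation.

```python
import numpy as np

# Illustrative sketch; argshift below is a hypothetical helper, not Pyxu's method.

def sq_l2(x):
    # Squared L2 norm over all axes.
    return float(np.sum(x**2))

def argshift(f, shift):
    # Returns the functional x -> f(x + shift).
    return lambda x: f(x + shift)

data = np.array([1.0, -2.0, 3.0])
l2_loss = argshift(sq_l2, -data)  # x -> ||x - data||^2

print(l2_loss(data))         # 0.0: the loss vanishes at the data
print(l2_loss(np.zeros(3)))  # 14.0 = 1 + 4 + 9
```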

Special Cases: argshift() doesn’t always produce a simple shift of the original functional. For instance, when applied to an entropy functional, it yields a relative entropy.

DiffFunc: Differentiable Functionals 🎯

The DiffFunc🔗 class extends Func for functionals that have a well-defined gradient. It introduces an additional method, grad()🔗, for gradient computation.

import numpy as np
from pyxu.abc import DiffFunc

class SquaredL2(DiffFunc):
    def __init__(self, dim_shape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)

    def apply(self, arr):
        axis = tuple(range(-len(self.dim_shape), 0))
        return (arr**2).sum(axis=axis)[..., np.newaxis]

    def grad(self, arr):
        # The gradient of ||x||^2 is 2x; it has the same shape as the input,
        # including any leading batch axes.
        return 2 * arr
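A quick way to validate a hand-written `grad()` is a central finite-difference check. This NumPy-only sketch applies it to the squared L2 norm. `check_grad`, the step `h`, and the input shape are illustrative choices, not part of Pyxu.

```python
import numpy as np

# Illustrative sketch; check_grad is a hypothetical helper, not Pyxu API.

def f(x):
    return float(np.sum(x**2))

def grad_f(x):
    return 2 * x

def check_grad(f, grad_f, x, h=1e-6):
    # Compare each partial derivative with a central difference quotient
    # and return the largest absolute discrepancy.
    fd = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e.flat[i] = h
        fd.flat[i] = (f(x + e) - f(x - e)) / (2 * h)
    return np.max(np.abs(fd - grad_f(x)))

x = np.random.default_rng(0).standard_normal((2, 3))
print(check_grad(f, grad_f, x))  # small, up to round-off
```

A wrong gradient, such as `lambda x: x`, produces a discrepancy far above round-off.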

🌈Tip: With Pyxu, you can use from_torch🔗 or from_jax🔗 to automatically compute the gradient if you have a PyTorch or JAX implementation of your functional.

ProxFunc: Proximable Functionals 🛡️

For functionals with a simple proximal operator, you’ll find ProxFunc🔗 extremely useful. It offers the prox()🔗 method, which evaluates the proximal operator of the functional. Here’s an example using the \(L_1\) norm:

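import numpy as np
from pyxu.abc import ProxFunc

class L1Norm(ProxFunc):
    def __init__(self, dim_shape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)

    def apply(self, arr):
        axis = tuple(range(-len(self.dim_shape), 0))
        # Sum over the core axes, then append one axis of size 1 (codim_shape=1), as in
        # SquaredL2; keepdims=True would leave one size-1 axis per core axis instead.
        return np.abs(arr).sum(axis=axis)[..., np.newaxis]

    def prox(self, arr, tau):
        # Soft-thresholding: shrink each entry toward zero by tau.
        return np.sign(arr) * np.clip(np.abs(arr) - tau, 0, None)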
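To check that the closed-form `prox()` above is right, you can solve the defining problem \(\arg\min_u |u| + \frac{1}{2\tau}(u - x)^{2}\) by brute force on a fine grid. The NumPy-only sketch below does this in the scalar case. The grid range and resolution are arbitrary illustrative choices.

```python
import numpy as np

# Illustrative sketch; both helpers are hypothetical, not Pyxu API.

def soft_threshold(x, tau):
    # Closed-form prox of tau * |.|, as in L1Norm.prox above.
    return np.sign(x) * np.clip(np.abs(x) - tau, 0, None)

def prox_brute_force(x, tau, grid=np.linspace(-5, 5, 200_001)):
    # Minimize |u| + (u - x)^2 / (2 tau) over a dense grid of candidates u.
    objective = np.abs(grid) + (grid - x) ** 2 / (2 * tau)
    return grid[np.argmin(objective)]

tau = 0.7
for x in [-2.0, -0.3, 0.0, 0.5, 1.9]:
    print(x, soft_threshold(x, tau), prox_brute_force(x, tau))
```

Inputs with \(|x| \le \tau\) are mapped to zero, and larger inputs are shrunk toward zero by \(\tau\).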

Moreau Envelope for Smoothing

You can also smooth out a proximable functional using the moreau_envelope()🔗 method. For example, you can smooth the L1 norm to create the Huber loss function as follows:

huber = L1Norm(dim_shape).moreau_envelope(mu=0.1)
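The Moreau envelope with parameter \(\mu\) can be computed from the proximal operator as \(f(p) + \frac{1}{2\mu}\|x - p\|^{2}\) with \(p = \text{prox}_{\mu f}(x)\). The standalone NumPy sketch below evaluates this formula for \(f = |\cdot|\) in the scalar case and compares it with the closed-form Huber function. It is not Pyxu's implementation.

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical, not Pyxu API.

def prox_abs(x, tau):
    # Soft-thresholding: prox of tau * |.|
    return np.sign(x) * np.clip(np.abs(x) - tau, 0, None)

def moreau_envelope_abs(x, mu):
    # env(x) = f(p) + (x - p)^2 / (2 mu), where p = prox_{mu f}(x) and f = |.|
    p = prox_abs(x, mu)
    return np.abs(p) + (x - p) ** 2 / (2 * mu)

def huber(x, mu):
    # Quadratic near the origin, linear (shifted down by mu/2) in the tails.
    return np.where(np.abs(x) <= mu, x**2 / (2 * mu), np.abs(x) - mu / 2)

x = np.linspace(-1, 1, 11)
print(np.allclose(moreau_envelope_abs(x, 0.1), huber(x, 0.1)))  # True
```

Smaller values of `mu` bring the envelope closer to the original \(|\cdot|\), at the cost of a sharper bend near zero.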

Demystifying the Proximal Operator 🎭

The proximal operator is a powerful tool, especially in nonsmooth optimization. Proximal algorithms apply it repeatedly to update an estimate so that the functional decreases. It’s like asking the algorithm to fine-tune a guess toward the actual optimal solution.

For example, consider the indicator function of a nonempty closed convex set, which is \(0\) inside the set and \(+\infty\) outside. Its proximal operator is simply the Euclidean projection onto that set. Essentially, it pulls any “off-domain” points back into the permissible set, ensuring they comply with the constraints of your problem.
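For two simple closed convex sets, the projection has a closed form. The NumPy sketch below projects onto a box, which amounts to componentwise clipping, and onto a Euclidean ball, which amounts to radial rescaling. The helper names are illustrative. The parameter \(\tau\) plays no role here, because \(\tau\) times an indicator is the same indicator.

```python
import numpy as np

# Illustrative sketch; helper names are hypothetical, not Pyxu API.

def project_box(x, lo, hi):
    # Prox of the indicator of {u : lo <= u <= hi}: componentwise clipping.
    return np.clip(x, lo, hi)

def project_ball(x, radius):
    # Prox of the indicator of {u : ||u|| <= radius}: rescale only if outside.
    norm = np.linalg.norm(x)
    return x if norm <= radius else x * (radius / norm)

x = np.array([3.0, -4.0])
print(project_box(x, -1.0, 1.0))  # [ 1. -1.]
print(project_ball(x, 1.0))       # [ 0.6 -0.8]
```

Points already inside the set are returned unchanged.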

Mathematically, the proximal operator of a function \(f: \mathbb{R}^n \rightarrow \mathbb{R} \cup \{ +\infty \}\) is defined as:

\[\text{prox}_{\tau f}(x) = \arg \min_{u \in \mathbb{R}^n} \left( f(u) + \frac{1}{2\tau} \| u - x \|^{2} \right).\]

Here, \(\tau > 0\) is a parameter and \(\| \cdot \|\) is the Euclidean norm. The term \(\frac{1}{2\tau} \| u - x \|^{2}\) is a regularization term that pulls the solution \(u\) toward \(x\). The proximal operator \(\text{prox}_{\tau f}(x)\) returns the point \(u\) that minimizes this expression, essentially finding a compromise between minimizing \(f(u)\) and staying close to the original point \(x\). When \(f\) is proper, convex and lower semicontinuous, this minimizer exists and is unique, so the proximal operator is well defined.
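As a worked example, take \(f = \| \cdot \|^{2}\). The objective in the definition of \(\text{prox}_{\tau f}\) is then differentiable and strictly convex, so its unique minimizer is the point where its gradient vanishes:

\[\nabla_u \left( \| u \|^{2} + \frac{1}{2\tau} \| u - x \|^{2} \right) = 2u + \frac{1}{\tau} (u - x) = 0 \quad \Longrightarrow \quad \text{prox}_{\tau f}(x) = \frac{x}{1 + 2\tau}.\]

In other words, the proximal operator of the squared norm shrinks \(x\) toward the origin, and it does so more strongly for larger \(\tau\).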

This mathematical tool is invaluable for optimization problems, especially those involving nonsmooth or complex functionals. It provides a way to make “smart” steps toward the minimum, even when you can’t directly calculate the gradient for all points.
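A standard instance is the proximal gradient method (ISTA) for \(\min_x \frac{1}{2}\|A x - b\|^{2} + \lambda \|x\|_1\). Each iteration takes a gradient step on the smooth term and then applies the prox of the nonsmooth term, so no gradient of the L1 norm is ever needed. The NumPy sketch below is illustrative: it uses random data, a fixed step \(1/L\) with \(L = \|A\|_2^{2}\), and a fixed iteration count. It is not one of Pyxu's solvers.

```python
import numpy as np

# Illustrative ISTA sketch (NumPy only); function names are hypothetical, not Pyxu API.

def soft_threshold(x, tau):
    # Prox of tau * ||.||_1 (componentwise soft-thresholding).
    return np.sign(x) * np.clip(np.abs(x) - tau, 0, None)

def ista(A, b, lam, n_iter=200):
    # Step 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient
    # of the smooth term 0.5 * ||A x - b||^2.
    L = np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    history = []
    for _ in range(n_iter):
        grad = A.T @ (A @ x - b)                   # gradient of the smooth term
        x = soft_threshold(x - grad / L, lam / L)  # prox step on lam * ||x||_1
        history.append(0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x)))
    return x, history

rng = np.random.default_rng(0)
A = rng.standard_normal((30, 50))
b = rng.standard_normal(30)
x_hat, history = ista(A, b, lam=1.0)
print(history[0], history[-1])  # the objective never increases along the iterations
```

With the step \(1/L\), each iteration is guaranteed not to increase the objective.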

Specialized Classes: ProxDiffFunc, LinFunc, QuadraticFunc 🎨

These classes cover functionals with additional structure: being both proximable and differentiable (ProxDiffFunc🔗), linear (LinFunc🔗), or quadratic (QuadraticFunc🔗). Quadratic functionals are especially important because primal-dual methods can exploit them for faster convergence, so use them whenever you can!
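For intuition, consider a quadratic functional of the generic form \(f(x) = \frac{1}{2}\langle x, Q x \rangle + \langle c, x \rangle + t\), with \(Q\) symmetric positive semi-definite. This form is an assumption for illustration. Both the gradient and the proximal operator of \(f\) have closed forms, and the prox amounts to solving the linear system \((I + \tau Q) u = x - \tau c\). The NumPy sketch below is not Pyxu's QuadraticFunc. It checks the prox against its optimality condition.

```python
import numpy as np

# Illustrative sketch of a generic quadratic functional; not Pyxu's QuadraticFunc.
rng = np.random.default_rng(0)
M = rng.standard_normal((4, 4))
Q = M.T @ M                    # symmetric positive semi-definite
c = rng.standard_normal(4)
t = 1.0

def f(x):
    return 0.5 * x @ Q @ x + c @ x + t

def grad_f(x):
    return Q @ x + c

def prox_f(x, tau):
    # Zeroing the gradient of f(u) + ||u - x||^2 / (2 tau) gives (I + tau Q) u = x - tau c.
    return np.linalg.solve(np.eye(len(x)) + tau * Q, x - tau * c)

x, tau = rng.standard_normal(4), 0.5
u = prox_f(x, tau)
print(np.linalg.norm(grad_f(u) + (u - x) / tau))  # ~0: optimality condition holds
```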

Note: when dealing with ProxDiffFunc instances, it can be hard to decide whether the gradient or the proximal operator should be used for optimization purposes. The general rule of thumb is to use the gradient as much as possible: gradient-based updates presuppose more regularity (differentiability) of the objective functional, and solvers can leverage this regularity for faster convergence.

Implicit Functionals: The Undercover Agents 🕵️‍♀️

In some cases, you may not know the functional itself, but you might know its proximal operator or gradient (e.g. for plug-and-play or score-based priors). Pyxu lets you define such “implicit functionals” as follows:

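from pyxu.abc import ProxFunc
from scipy.ndimage import median_filter

class MedianFilterPrior(ProxFunc):
    def __init__(self, dim_shape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)

    def apply(self, arr):
        # The functional itself is unknown: only its "proximal operator" is provided.
        raise NotImplementedError("MedianFilterPrior has no explicit apply().")

    def prox(self, arr, tau):
        # A median filter stands in for the proximal operator; tau is ignored.
        # Filter over the core axes only (window 5), never across leading batch axes.
        n_batch = arr.ndim - len(self.dim_shape)
        size = (1,) * n_batch + (5,) * len(self.dim_shape)
        return median_filter(arr, size=size)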
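A typical use of such an implicit prior is a plug-and-play iteration, in which a denoiser takes the place of the proximal operator inside a proximal-gradient loop. The NumPy-only sketch below makes three illustrative assumptions: a least-squares data term \(\frac{1}{2}\|A x - b\|^{2}\), a hypothetical 3-tap blur `A`, and a hand-rolled 1D median filter `median_denoise` (window 3) in place of `scipy.ndimage.median_filter`. It shows the structure of the iteration only. No convergence guarantee is implied for a median-filter prior.

```python
import numpy as np

# Illustrative plug-and-play sketch; all names are hypothetical, not Pyxu API.

def median_denoise(x, size=3):
    # 1D median filter with edge padding; plays the role of prox() in the loop.
    pad = size // 2
    xp = np.pad(x, pad, mode="edge")
    windows = np.lib.stride_tricks.sliding_window_view(xp, size)
    return np.median(windows, axis=-1)

def pnp_pgd(A, b, denoise, n_iter=100):
    # x <- denoise(x - step * grad), for the data term 0.5 * ||A x - b||^2,
    # with step 1 / ||A||_2^2.
    step = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        x = denoise(x - step * A.T @ (A @ x - b))
    return x

rng = np.random.default_rng(0)
x_true = np.repeat([0.0, 1.0, 0.0], 20)                       # piecewise-constant signal
A = (np.eye(60) + np.eye(60, k=1) + np.eye(60, k=-1)) / 3     # hypothetical 3-tap blur
b = A @ x_true + 0.05 * rng.standard_normal(60)               # noisy blurred observation
x_hat = pnp_pgd(A, b, median_denoise)
```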

Crafting Custom Loss Functionals through Composition with Forward Operators

Inverse problems often involve unknown variables that are related to observable data through a forward operator. In simple terms, a forward operator is like a “real-world filter” that transforms your unknown variable, and what you actually observe is the transformed version. In computational imaging, for example, this could represent the blurring of an image. Pyxu makes it easy to handle these intricacies by allowing you to integrate these forward operators directly into your loss functionals.

For example, let’s say you have a blurred image, represented by the variable \(b\). The blurring occurred through a known process—described by a forward operator \(A\)—applied to the original image \(x\). Mathematically, this is:

\[b = A x\]

The challenge here is to recover \(x\) given \(b\) and \(A\). But why does this matter? Because in real-world applications, you often don’t observe \(x\) directly. What you have is \(b\), and you have to work your way backward to find \(x\).
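To make this concrete, the NumPy sketch below writes a 1D blur as a matrix `A` and applies it to a piecewise-constant signal `x` to produce `b = A x`. The blur is a hypothetical 5-tap moving average with zero boundary handling. The blur smears the sharp jump in `x` over several samples, and that detail is what the reconstruction then has to recover.

```python
import numpy as np

# Illustrative sketch; blur_matrix is a hypothetical forward operator, not Pyxu API.

def blur_matrix(n, width=5):
    # Each output sample is the average of `width` neighboring input samples
    # (entries falling outside the signal are treated as zero).
    A = np.zeros((n, n))
    half = width // 2
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)
        A[i, lo:hi] = 1.0 / width
    return A

x = np.repeat([0.0, 1.0, 0.0], 10)  # original signal with sharp edges
A = blur_matrix(x.size)
b = A @ x                           # blurred observation
print(np.round(b[7:13], 2))         # the jump at index 10 becomes a ramp
```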

Practical Example: Deblurring through Least-Squares 🌠

A common strategy to solve this problem is to minimize the least-squares difference between \(b\) and \(A x\). In Pyxu, you can create a composite loss functional for this exact problem like so:

from pyxu.operator import SquaredL2Norm

# A: forward (e.g. blurring) operator whose output shape matches b.shape; b: observed data.
loss = SquaredL2Norm(dim_shape=b.shape).argshift(-b) * A

What did we do here? We took the squared \(L_2\) norm, which measures the squared length of a vector, and used the .argshift() method to turn it into a loss functional that measures the squared distance to our blurred image \(b\). The * A part then composes this loss with the forward operator \(A\), so the resulting functional evaluates \(\|A x - b\|^{2}\). This way we compare apples to apples, or in this case, blurred images to blurred images.

The Benefit: Automatic Propagation 🚀

The beauty of this approach is that it streamlines your optimization. The gradient calculations and other relevant methods get automatically updated to incorporate \(A\), making your life a lot easier!
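Concretely, for the least-squares loss above the chain rule gives \(\nabla_x \|A x - b\|^{2} = 2 A^{T}(A x - b)\). This is the kind of expression a composed functional has to evaluate for its gradient. The NumPy-only sketch below checks this formula against central finite differences. The random `A` and `b` are placeholders, and nothing here calls Pyxu.

```python
import numpy as np

# Illustrative sketch; A and b are hypothetical placeholders, no Pyxu involved.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 5))  # hypothetical forward operator
b = rng.standard_normal(8)       # hypothetical observation

def loss(x):
    return float(np.sum((A @ x - b) ** 2))

def loss_grad(x):
    # Chain rule: the gradient of ||y||^2 is 2y, evaluated at y = A x - b,
    # then pulled back through the adjoint A^T.
    return A.T @ (2 * (A @ x - b))

x = rng.standard_normal(5)
h = 1e-6
fd = np.array([(loss(x + h * e) - loss(x - h * e)) / (2 * h) for e in np.eye(5)])
print(np.max(np.abs(fd - loss_grad(x))))  # small, up to round-off
```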

The Takeaway: Flexibility and Power 🌈

This feature adds another layer of flexibility and power to Pyxu, allowing you to tackle a wide range of problems involving different types of transformations and physical processes with ease.