Computing directional derivatives along GGN eigenvectors

In this example, we demonstrate how to use ViViT’s DirectionalDerivativesComputation to obtain the 1ˢᵗ- and 2ⁿᵈ-order directional derivatives along the leading GGN eigenvectors. We verify the results with torch.autograd.
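In formulas, matching the torch.autograd check further down: let g_n denote the gradient of the loss on sample n, G_n the GGN of that per-sample loss, and e_k the k-th selected eigenvector of the full GGN. The two quantities are then the projection of g_n onto e_k and the curvature of G_n along e_k:

$$
\gamma_{nk} = \mathbf{g}_n^\top \mathbf{e}_k ,
\qquad
\lambda_{nk} = \mathbf{e}_k^\top \mathbf{G}_n \mathbf{e}_k ,
\qquad n = 1, \dots, N, \quad k = 1, \dots, K .
$$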

First, the imports.

from typing import List

from backpack import backpack, extend
from backpack.utils.examples import _autograd_ggn_exact_columns
from torch import Tensor, cuda, device, einsum, isclose, manual_seed, rand, stack, zeros
from torch.autograd import grad
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.nn.utils.convert_parameters import parameters_to_vector

from vivit.optim.directional_derivatives import DirectionalDerivativesComputation

# make deterministic
manual_seed(0)

Data, model & loss

For this demo, we use toy data and a small MLP with few enough parameters that we can store the GGN matrix to verify our results (yes, one could use matrix-free GGN-vector products instead). We use mean squared error as the loss function.

N = 4
D_in = 7
D_hidden = 5
D_out = 3

DEVICE = device("cuda" if cuda.is_available() else "cpu")

X = rand(N, D_in).to(DEVICE)
y = rand(N, D_out).to(DEVICE)

model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
).to(DEVICE)

loss_function = MSELoss(reduction="mean").to(DEVICE)
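The claim that the GGN is small enough to store can be checked by hand. The following plain-Python sketch counts the parameters of the three Linear layers (the ReLU activations have none); the helper linear_param_count is a hypothetical name introduced only for illustration. It also uses the standard bound that the GGN, a sum of N terms of the form J_nᵀ H_n J_n with D_out × D_out loss Hessians H_n, has rank at most N · D_out.

```python
from typing import List, Tuple


def linear_param_count(d_in: int, d_out: int) -> int:
    """Number of parameters (weight and bias) of a Linear(d_in, d_out) layer."""
    return d_in * d_out + d_out


# Sizes used in this example.
N, D_in, D_hidden, D_out = 4, 7, 5, 3

# (input, output) sizes of the three Linear layers; ReLU layers add no parameters.
layers: List[Tuple[int, int]] = [(D_in, D_hidden), (D_hidden, D_hidden), (D_hidden, D_out)]
num_params = sum(linear_param_count(d_in, d_out) for d_in, d_out in layers)

# The dense GGN is a num_params x num_params matrix.
ggn_entries = num_params**2

# GGN = sum over N samples of J_n^T H_n J_n with D_out x D_out blocks H_n,
# so its rank is at most N * D_out.
max_ggn_rank = N * D_out

print(num_params, ggn_entries, max_ggn_rank)  # 88 7744 12
```

So the 88 × 88 GGN has at most 12 nonzero eigenvalues, which is why storing it densely is cheap here and why only a handful of leading eigenvectors are of interest.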

Integrate BackPACK

Next, extend the model and loss function to be able to use BackPACK. Then, we perform a forward pass to compute the loss.

model = extend(model)
loss_function = extend(loss_function)

loss = loss_function(model(X), y)

Specify GGN approximation and directions

By default, vivit.DirectionalDerivativesComputation uses the exact GGN. We also need to specify the parameters that define the GGN via a param_groups argument, which might be familiar to you from torch.optim. Each group also contains a filter function (under the "criterion" key) that selects the eigenvalues whose eigenvectors are used as directions for evaluating the directional derivatives.

computation = DirectionalDerivativesComputation()

def select_top_k(evals: Tensor, k=4) -> List[int]:
    """Select the top-k eigenvalues as directions to evaluate derivatives.

    Args:
        evals: Eigenvalues, sorted in ascending order.
        k: Number of leading eigenvalues. Defaults to ``4``.

    Returns:
        Indices of top-k eigenvalues.
    """
    return [evals.numel() - k + idx for idx in range(k)]

group = {
    "params": [p for p in model.parameters() if p.requires_grad],
    "criterion": select_top_k,
}
param_groups = [group]
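The index arithmetic in select_top_k relies on the eigenvalues arriving in ascending order, so the k largest ones sit at the last k positions. A plain-Python sketch of the same arithmetic, where top_k_indices is a hypothetical stand-in that uses len on a list instead of Tensor.numel, and the eigenvalues are illustrative numbers rather than values from the model above:

```python
from typing import List, Sequence


def top_k_indices(evals: Sequence[float], k: int = 4) -> List[int]:
    """Index arithmetic of select_top_k, applied to a plain ascending list."""
    # For ascending order, the k largest values sit at the last k positions.
    return [len(evals) - k + idx for idx in range(k)]


# Illustrative ascending "eigenvalues" (not computed from the model).
toy_evals = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
toy_keep = top_k_indices(toy_evals)

print(toy_keep)  # [2, 3, 4, 5]
print([toy_evals[i] for i in toy_keep])  # [0.2, 0.3, 0.4, 0.5]
```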

Backward pass with BackPACK

We can now build the BackPACK extensions and the extension hook that compute the directional derivatives, pass them to a with backpack(...) context, and perform the backward pass inside it.

extensions = computation.get_extensions()
extension_hook = computation.get_extension_hook(param_groups)

with backpack(*extensions, extension_hook=extension_hook):
    loss.backward()

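This computes the directional derivatives for each parameter group and stores them internally in the DirectionalDerivativesComputation instance. We can use the parameter group to request them: get_result returns the first-order derivatives (gammas) and the second-order derivatives (lambdas), each with one row per sample and one column per selected direction.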

gammas_vivit, lambdas_vivit = computation.get_result(group)

Verify results

To verify the above, let’s first compute the per-sample gradients and GGNs using torch.autograd.

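batch_grad = []
batch_ggn = []

for n in range(N):
    x_n, y_n = X[[n]], y[[n]]

    # per-sample gradient, flattened into a single vector
    grad_n = grad(
        loss_function(model(x_n), y_n),
        [p for p in model.parameters() if p.requires_grad],
    )
    batch_grad.append(parameters_to_vector(grad_n))

    # per-sample GGN, assembled column by column
    ggn_n = stack(
        [col for _, col in _autograd_ggn_exact_columns(x_n, y_n, model, loss_function)]
    )
    batch_ggn.append(ggn_n)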
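Comparing these per-sample quantities with ViViT's output rests on how the batch loss decomposes. With reduction="mean", MSELoss averages over all N · D_out entries, so the batch loss equals the average of the per-sample losses computed in the loop; by linearity, the batch gradient and the batch GGN are then the averages of the per-sample gradients and GGNs. A plain-Python sketch of the loss identity, using illustrative numbers and a hypothetical helper mse_mean that mimics the mean reduction:

```python
from typing import List

Matrix = List[List[float]]


def mse_mean(preds: Matrix, targets: Matrix) -> float:
    """Mean of squared errors over all entries, like MSELoss(reduction="mean")."""
    squared = [
        (p - t) ** 2
        for row_p, row_t in zip(preds, targets)
        for p, t in zip(row_p, row_t)
    ]
    return sum(squared) / len(squared)


# Illustrative predictions and targets for 2 samples with 3 outputs each.
toy_preds = [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]]
toy_targets = [[1.0, 1.0, 1.0], [0.0, 2.0, 4.0]]

batch_loss = mse_mean(toy_preds, toy_targets)
per_sample = [mse_mean([p], [t]) for p, t in zip(toy_preds, toy_targets)]
mean_of_per_sample = sum(per_sample) / len(per_sample)

# Both equal 13/6 up to floating-point rounding.
print(batch_loss, mean_of_per_sample)
```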

We also need the eigenvectors of the full GGN, which serve as directions:

from torch.linalg import eigh  # replaces the deprecated Tensor.symeig

ggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])
evals, evecs = eigh(ggn)  # eigenvalues in ascending order
keep = select_top_k(evals)
evals, evecs = evals[keep], evecs[:, keep]

We are now ready to compute and compare the target quantities.

First, compute and compare the first-order directional derivatives. Note that the GGN eigenvectors used as directions are only unique up to sign: if e_k is an eigenvector, so is -e_k. The first-order directional derivatives can therefore differ in sign, so we only compare their absolute values.
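The sign ambiguity can be seen on a toy example. The plain-Python sketch below (illustrative vectors, hypothetical helper dot) projects the same gradient onto a unit direction and onto its negation; the magnitudes agree and only the sign changes.

```python
from typing import List


def dot(u: List[float], v: List[float]) -> float:
    """Euclidean inner product of two equally long vectors."""
    return sum(a * b for a, b in zip(u, v))


toy_g_n = [0.5, -1.0, 2.0]  # illustrative per-sample gradient
toy_e_k = [0.6, 0.0, 0.8]  # illustrative unit-norm direction
toy_e_k_flipped = [-x for x in toy_e_k]  # equally valid eigenvector, opposite sign

gamma = dot(toy_g_n, toy_e_k)
gamma_flipped = dot(toy_g_n, toy_e_k_flipped)

# Same magnitude (about 1.9), opposite sign.
print(gamma, gamma_flipped)
```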

K = evals.numel()
gammas_torch = zeros(N, K, device=evals.device, dtype=evals.dtype)

for n in range(N):
    grad_n = batch_grad[n]
    for k in range(K):
        e_k = evecs[:, k]

        gammas_torch[n, k] = einsum("i,i", grad_n, e_k)

for gamma_vivit, gamma_torch in zip(gammas_vivit.flatten(), gammas_torch.flatten()):
    close = isclose(abs(gamma_vivit), abs(gamma_torch), rtol=1e-4, atol=1e-7)
    print(f"{gamma_vivit:.5e} vs. {gamma_torch:.5e}, close: {close}")
    if not close:
        raise ValueError("1ˢᵗ-order directional derivatives don't match!")

print("1ˢᵗ-order directional derivatives match!")
1.91209e-02 vs. -1.91209e-02, close: True
-3.92638e-01 vs. -3.92638e-01, close: True
-4.43494e-01 vs. -4.43494e-01, close: True
1.70077e-01 vs. 1.70076e-01, close: True
-1.64624e-02 vs. 1.64625e-02, close: True
-8.22119e-01 vs. -8.22119e-01, close: True
-2.33368e-01 vs. -2.33368e-01, close: True
3.45461e-01 vs. 3.45461e-01, close: True
-3.89742e-03 vs. 3.89743e-03, close: True
-3.74231e-01 vs. -3.74231e-01, close: True
-1.96793e-01 vs. -1.96793e-01, close: True
-8.43964e-02 vs. -8.43964e-02, close: True
5.49901e-02 vs. -5.49899e-02, close: True
-9.46718e-01 vs. -9.46719e-01, close: True
-7.72283e-01 vs. -7.72283e-01, close: True
3.40165e-01 vs. 3.40165e-01, close: True
1ˢᵗ-order directional derivatives match!

Next, compute and compare the second-order directional derivatives.
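Unlike the first-order case, no absolute value is needed here: λ_nk = e_kᵀ G_n e_k is quadratic in e_k, so flipping the sign of the eigenvector leaves it unchanged. A plain-Python sketch of this invariance, with an illustrative symmetric positive definite matrix and a hypothetical helper quadratic_form that performs the same contraction as einsum("i,ij,j", ...):

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def quadratic_form(e: Vector, G: Matrix) -> float:
    """Compute e^T G e, the same contraction as einsum("i,ij,j", e, G, e)."""
    size = len(e)
    return sum(e[i] * G[i][j] * e[j] for i in range(size) for j in range(size))


toy_G_n = [[2.0, 1.0], [1.0, 3.0]]  # illustrative symmetric per-sample GGN
toy_e_k = [0.6, 0.8]  # illustrative unit-norm direction
toy_e_k_flipped = [-x for x in toy_e_k]

lam = quadratic_form(toy_e_k, toy_G_n)
lam_flipped = quadratic_form(toy_e_k_flipped, toy_G_n)

# Identical values (about 3.6): the sign of the direction cancels out.
print(lam, lam_flipped)
```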

lambdas_torch = zeros(N, K, device=evals.device, dtype=evals.dtype)

for n in range(N):
    ggn_n = batch_ggn[n]
    for k in range(K):
        e_k = evecs[:, k]

        lambdas_torch[n, k] = einsum("i,ij,j", e_k, ggn_n, e_k)

for lambda_vivit, lambda_torch in zip(lambdas_vivit.flatten(), lambdas_torch.flatten()):
    close = isclose(lambda_vivit, lambda_torch, rtol=1e-4, atol=1e-7)
    print(f"{lambda_vivit:.5e} vs. {lambda_torch:.5e}, close: {close}")
    if not close:
        raise ValueError("2ⁿᵈ-order directional derivatives don't match!")

print("2ⁿᵈ-order directional derivatives match!")
2.20225e-01 vs. 2.20224e-01, close: True
6.65268e-01 vs. 6.65268e-01, close: True
7.21850e-01 vs. 7.21850e-01, close: True
7.13450e-01 vs. 7.13450e-01, close: True
3.93340e-03 vs. 3.93345e-03, close: True
6.90360e-01 vs. 6.90360e-01, close: True
8.00855e-01 vs. 8.00855e-01, close: True
1.36017e+00 vs. 1.36017e+00, close: True
1.37852e-01 vs. 1.37852e-01, close: True
6.78889e-01 vs. 6.78889e-01, close: True
6.86441e-01 vs. 6.86441e-01, close: True
9.45868e-01 vs. 9.45868e-01, close: True
3.76307e-03 vs. 3.76304e-03, close: True
6.92882e-01 vs. 6.92883e-01, close: True
8.39668e-01 vs. 8.39668e-01, close: True
1.20384e+00 vs. 1.20384e+00, close: True
2ⁿᵈ-order directional derivatives match!

Last, we check that the sample means of second-order derivatives coincide with the eigenvalues.
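Why this check should pass: with reduction="mean", the batch loss is the average of the N per-sample losses, so the full GGN is the average of the per-sample GGNs, G = (1/N) Σ_n G_n. Writing λ_nk = e_kᵀ G_n e_k for the second-order derivative of sample n along the unit-norm eigenvector e_k of G with eigenvalue λ_k, the sample mean reduces to the eigenvalue:

$$
\frac{1}{N}\sum_{n=1}^{N} \lambda_{nk}
= \frac{1}{N}\sum_{n=1}^{N} \mathbf{e}_k^\top \mathbf{G}_n \mathbf{e}_k
= \mathbf{e}_k^\top \Big(\frac{1}{N}\sum_{n=1}^{N} \mathbf{G}_n\Big) \mathbf{e}_k
= \mathbf{e}_k^\top \mathbf{G}\, \mathbf{e}_k
= \lambda_k\, \mathbf{e}_k^\top \mathbf{e}_k
= \lambda_k .
$$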

for eval_vivit, eval_torch in zip(lambdas_vivit.mean(0), evals):
    close = isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)
    print(f"{eval_vivit:.5e} vs. {eval_torch:.5e}, close: {close}")
    if not close:
        raise ValueError("Averaged 2ⁿᵈ-order directional derivatives don't match eigenvalues!")

print("Averaged 2ⁿᵈ-order directional derivatives match eigenvalues!")
9.14433e-02 vs. 9.14434e-02, close: True
6.81850e-01 vs. 6.81850e-01, close: True
7.62204e-01 vs. 7.62203e-01, close: True
1.05583e+00 vs. 1.05583e+00, close: True
Averaged 2ⁿᵈ-order directional derivatives match eigenvalues!
