Computing GGN eigenpairs

In this example we demonstrate how to use ViViT’s EighComputation to obtain the leading GGN eigenpairs (eigenvalues and associated eigenvectors). We verify the result with torch.autograd.

First, the imports.

from typing import List

from backpack import backpack, extend
from backpack.utils.examples import _autograd_ggn_exact_columns
from torch import Tensor, allclose, cat, cuda, device, manual_seed, rand, stack
from torch.nn import Linear, MSELoss, ReLU, Sequential

from vivit.linalg.eigh import EighComputation

# make deterministic
manual_seed(0)
<torch._C.Generator object at 0x7fc4a4be4e70>

Data, model & loss

For this demo, we use toy data and a small MLP with few enough parameters that we can store the full GGN matrix and verify the eigen-properties of our results. One could use matrix-free GGN-vector products instead, but expanding the GGN matrix helps us become familiar with the format of the results. We use mean squared error as the loss function.

N = 4
D_in = 7
D_hidden = 5
D_out = 3

DEVICE = device("cuda" if cuda.is_available() else "cpu")

X = rand(N, D_in).to(DEVICE)
y = rand(N, D_out).to(DEVICE)

model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
).to(DEVICE)

loss_function = MSELoss(reduction="mean").to(DEVICE)
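
To make "few parameters" concrete, the following sketch counts the trainable parameters of the MLP from its layer sizes alone. It uses plain Python and only the dimensions defined above; the variable names `layers` and `num_params` are illustrative and not part of the original example.

```python
# Layer sizes taken from the example above.
D_in, D_hidden, D_out = 7, 5, 3

# (fan_in, fan_out) of each Linear layer in the MLP.
layers = [(D_in, D_hidden), (D_hidden, D_hidden), (D_hidden, D_out)]

# Each Linear layer has fan_out * fan_in weights plus fan_out biases.
num_params = sum(fan_out * fan_in + fan_out for fan_in, fan_out in layers)

print(f"Parameters:  {num_params}")
print(f"GGN shape:   ({num_params}, {num_params})")
print(f"GGN entries: {num_params ** 2}")
```

The GGN is a square matrix with one row and one column per parameter, so its storage cost grows quadratically with the parameter count. That is cheap here, but it is the reason the full matrix is only expanded for small models.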

Integrate BackPACK

Next, extend the model and loss function to be able to use BackPACK. Then, we perform a forward pass to compute the loss.

model = extend(model)
loss_function = extend(loss_function)

loss = loss_function(model(X), y)

Specify GGN approximation and eigenpair filter

By default, vivit.EighComputation uses the exact GGN. Furthermore, we need to specify the GGN’s parameters via a param_groups argument, which may be familiar to you from torch.optim. Each group also contains a filter function, stored under the "criterion" key, that selects the eigenvalues whose eigenvectors should be computed (computing all eigenvectors is infeasible for big architectures).

computation = EighComputation()

def select_top_k(evals: Tensor, k=4) -> List[int]:
    """Select the top-k eigenvalues for the eigenvector computation.

    Args:
        evals: Eigenvalues, sorted in ascending order.
        k: Number of leading eigenvalues. Defaults to ``4``.

    Returns:
        Indices of top-k eigenvalues.
    """
    return [evals.numel() - k + idx for idx in range(k)]

group = {
    "params": [p for p in model.parameters() if p.requires_grad],
    "criterion": select_top_k,
}
param_groups = [group]
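
The filter function only has to map the eigenvalues to a list of indices. As a minimal sketch of that index arithmetic, the snippet below uses plain Python lists as a stand-in for the eigenvalue tensor. The helper names `top_k_indices` and `indices_above` and the sample values are hypothetical, chosen for illustration; the threshold variant is not taken from ViViT.

```python
from typing import List, Sequence


def top_k_indices(evals: Sequence[float], k: int = 4) -> List[int]:
    """Return the indices of the k largest entries, assuming ascending order."""
    num_evals = len(evals)
    return [num_evals - k + idx for idx in range(k)]


def indices_above(evals: Sequence[float], threshold: float) -> List[int]:
    """Hypothetical alternative: keep every eigenvalue above a threshold."""
    return [idx for idx, ev in enumerate(evals) if ev > threshold]


# Made-up eigenvalues, sorted in ascending order.
evals = [0.0, 0.01, 0.1, 0.7, 0.8, 1.1]

print(top_k_indices(evals, k=4))  # [2, 3, 4, 5]
print(indices_above(evals, 0.5))  # [3, 4, 5]
```

Because the eigenvalues are sorted in ascending order, the leading ones sit at the end of the sequence. Note that the top-k formula yields negative indices if `k` exceeds the number of eigenvalues, so a criterion used in practice may want to guard against that case.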

Backward pass with BackPACK

We can now build the BackPACK extension and the extension hook that will compute the eigenpairs, pass both to a with backpack context, and perform the backward pass.

extension = computation.get_extension()
extension_hook = computation.get_extension_hook(param_groups)

with backpack(extension, extension_hook=extension_hook):
    loss.backward()

This will compute the GGN eigenpairs for each parameter group and store them internally in the EighComputation instance. We can use the parameter group to request the eigenpairs.

evals, evecs = computation.get_result(group)

The eigenvectors have a format similar to that of the parameters: for each parameter there is one tensor with an additional leading axis. Indexing that leading axis selects the eigenvector that belongs to one eigenvalue.

print("Parameter shape    |  Eigenvector shape")
for p, v in zip(group["params"], evecs):
    print(f"{str(p.shape):<19}|  {v.shape}")
Parameter shape    |  Eigenvector shape
torch.Size([5, 7]) |  torch.Size([4, 5, 7])
torch.Size([5])    |  torch.Size([4, 5])
torch.Size([5, 5]) |  torch.Size([4, 5, 5])
torch.Size([5])    |  torch.Size([4, 5])
torch.Size([3, 5]) |  torch.Size([4, 3, 5])
torch.Size([3])    |  torch.Size([4, 3])

In the following, we flatten the eigenvectors and concatenate them across parameters, such that evecs_flat[k,:] is the GGN eigenvector with eigenvalue evals[k]:

evecs_flat = cat([e.flatten(start_dim=1) for e in evecs], dim=1)
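
The effect of this flattening can be checked on stand-in data. The sketch below builds random tensors with the eigenvector shapes printed above (they are not real eigenvectors), applies the same flatten-and-concatenate step, and then splits one row back into per-parameter pieces. The names `param_shapes`, `sizes`, and `pieces` are illustrative assumptions.

```python
import torch

k = 4  # number of selected eigenvalues
# Parameter shapes as printed in the table above.
param_shapes = [(5, 7), (5,), (5, 5), (5,), (3, 5), (3,)]

# Random stand-ins with a leading axis of size k, mimicking the result format.
evecs = [torch.rand(k, *shape) for shape in param_shapes]

# Flatten everything except the leading axis, then join along the parameter axis.
evecs_flat = torch.cat([e.flatten(start_dim=1) for e in evecs], dim=1)
print(evecs_flat.shape)

# Round trip: split the first row back into tensors shaped like the parameters.
sizes = [e[0].numel() for e in evecs]
pieces = [
    chunk.reshape(e.shape[1:])
    for chunk, e in zip(evecs_flat[0].split(sizes), evecs)
]
print([tuple(p.shape) for p in pieces])
```

Each row of the flattened tensor has one entry per model parameter, which matches the number of rows and columns of the GGN matrix used for verification.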

Verify results

To verify the above, let’s compute the GGN matrix, column by column, using GGN-vector products that only rely on torch.autograd.

ggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])

We can then check that application of the GGN to an eigenvector rescales the latter by its eigenvalue.

for e, v in zip(evals, evecs_flat):
    ggn_v = ggn @ v
    close = allclose(e * v, ggn_v, rtol=1e-4, atol=1e-7)

    print(f"Eigenvalue {e:.5e}, Eigenvector properties: {close}")
    if not close:
        raise ValueError("Eigenvector properties failed!")

print("Eigenvector properties confirmed!")
Eigenvalue 9.14433e-02, Eigenvector properties: True
Eigenvalue 6.81850e-01, Eigenvector properties: True
Eigenvalue 7.62203e-01, Eigenvector properties: True
Eigenvalue 1.05583e+00, Eigenvector properties: True
Eigenvector properties confirmed!
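
The same kind of check can be tried in isolation on a matrix whose eigenpairs are easy to obtain. The following sketch is independent of ViViT and BackPACK: it assumes a random symmetric positive semi-definite matrix as a stand-in for the GGN and uses `torch.linalg.eigh` as the reference solver. The matrix sizes and tolerances are arbitrary choices.

```python
import torch

torch.manual_seed(0)

# Stand-in for a GGN: a symmetric positive semi-definite matrix of rank <= 3.
B = torch.rand(6, 3, dtype=torch.float64)
A = B @ B.T

# eigh returns eigenvalues in ascending order and eigenvectors as columns.
evals, evecs = torch.linalg.eigh(A)

# Keep the k leading eigenpairs; transpose so that rows are eigenvectors,
# mirroring the layout of evecs_flat above.
k = 3
top_evals = evals[-k:]
top_evecs = evecs[:, -k:].T

# Largest absolute deviation between A v and lambda v for each pair.
residuals = [
    (A @ v - lam * v).abs().max().item() for lam, v in zip(top_evals, top_evecs)
]

for lam, v in zip(top_evals, top_evecs):
    close = torch.allclose(A @ v, lam * v, rtol=1e-8, atol=1e-10)
    print(f"Eigenvalue {lam.item():.5e}, Eigenvector properties: {close}")
```

Comparing the eigenvalues of the expanded GGN against such a reference solver would be a complementary check to the one above, which only confirms that each returned pair satisfies the eigenvalue equation.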
