Computing GGN eigenvalues

In this example we demonstrate how to use ViViT’s EigvalshComputation to obtain the eigenvalues of the generalized Gauss-Newton (GGN) matrix and verify the result with torch.autograd.

First, the imports.

from backpack import backpack, extend
from backpack.utils.examples import _autograd_ggn_exact_columns
from torch import cuda, device, isclose, manual_seed, rand, stack
from torch.linalg import eigvalsh
from torch.nn import Linear, MSELoss, ReLU, Sequential

from vivit.linalg.eigvalsh import EigvalshComputation

# make deterministic
manual_seed(0)

Data, model & loss

For this demo, we use toy data and a small MLP with few enough parameters that we can store the GGN matrix explicitly and eigen-decompose it to verify correctness. We use mean squared error as the loss function.

N = 4
D_in = 7
D_hidden = 5
D_out = 3

DEVICE = device("cuda" if cuda.is_available() else "cpu")

X = rand(N, D_in).to(DEVICE)
y = rand(N, D_out).to(DEVICE)

model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
).to(DEVICE)

loss_function = MSELoss(reduction="mean").to(DEVICE)
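To see why this model is small enough, one can count its parameters by hand: each Linear(d_in, d_out) layer contributes a d_out × d_in weight matrix and a bias of length d_out. The sketch below repeats that arithmetic in plain Python with the dimensions above; the helper linear_params is a hypothetical name of ours, not part of ViViT or BackPACK.

```python
# Dimensions copied from the example above
D_in, D_hidden, D_out = 7, 5, 3


def linear_params(d_in, d_out):
    """Number of parameters of a Linear(d_in, d_out) layer: weight plus bias."""
    return d_out * d_in + d_out


# (input, output) sizes of the three Linear layers; ReLU has no parameters
layer_shapes = [(D_in, D_hidden), (D_hidden, D_hidden), (D_hidden, D_out)]
num_params = sum(linear_params(i, o) for i, o in layer_shapes)

# The exact GGN is square, with one row and one column per parameter
ggn_entries = num_params**2
print(f"{num_params} parameters -> GGN of shape {num_params} x {num_params} ({ggn_entries} entries)")
```

An 88 × 88 matrix is trivial to store and decompose. Because the number of entries grows quadratically with the parameter count, the same explicit check is infeasible for networks with millions of parameters.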

Integrate BackPACK

Next, extend the model and loss function to be able to use BackPACK. Then, we perform a forward pass to compute the loss.

model = extend(model)
loss_function = extend(loss_function)

loss = loss_function(model(X), y)

Specify GGN approximation

By default, vivit.EigvalshComputation uses the exact GGN. We only need to specify the parameters whose GGN we are interested in, via a param_groups argument: a list of dictionaries, each holding a "params" entry, in the same format that torch.optim optimizers accept.

computation = EigvalshComputation()

group = {"params": [p for p in model.parameters() if p.requires_grad]}
param_groups = [group]
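Since the eigenvalues are computed for each parameter group, one could, for instance, split the parameters layer by layer; the GGN restricted to one group is the corresponding diagonal block of the full GGN. The sketch below only builds such groups and sanity-checks their structure with a torch.optim optimizer, which accepts the same format. Passing per_layer_groups (a hypothetical name) to get_extension_hook in place of param_groups is an assumption to verify against the ViViT documentation.

```python
from torch.nn import Linear, ReLU, Sequential
from torch.optim import SGD

# Same architecture as in the example above
D_in, D_hidden, D_out = 7, 5, 3
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
)

# One group holding all trainable parameters, as in the example
full_group = [{"params": [p for p in model.parameters() if p.requires_grad]}]

# Hypothetical alternative: one group per layer with trainable parameters
# (the ReLU layers have none and are skipped)
per_layer_groups = []
for layer in model:
    params = [p for p in layer.parameters() if p.requires_grad]
    if params:
        per_layer_groups.append({"params": params})

# torch.optim accepts the same list-of-dicts format
optimizer = SGD(per_layer_groups, lr=0.1)
print(len(full_group), len(per_layer_groups), len(optimizer.param_groups))
```

Per-layer groups trade the full spectrum for the spectra of smaller diagonal blocks, which ignores interactions between layers.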

Backward pass with BackPACK

We can now build the BackPACK extension and the extension hook that computes the GGN eigenvalues, pass both to a with backpack(...) block, and perform the backward pass inside it.

extension = computation.get_extension()
extension_hook = computation.get_extension_hook(param_groups)

with backpack(extension, extension_hook=extension_hook):
    loss.backward()

This computes the GGN eigenvalues for each parameter group and stores them inside the EigvalshComputation instance. To retrieve them, pass the corresponding parameter group to get_result:

evals = computation.get_result(group)

Verify results

Let’s compute the GGN matrix column by column, using GGN-vector products that rely only on torch.autograd. We can then compute its eigenvalues and compare them with ViViT’s.

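# each iteration yields a (column index, GGN column) pair; since the GGN is
# symmetric, stacking its columns as rows gives the same matrix
ggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])
ggn_evals = eigvalsh(ggn)  # eigenvalues in ascending order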
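The helper _autograd_ggn_exact_columns comes from BackPACK's example utilities. To illustrate the idea behind it, here is an independent sketch (not BackPACK's implementation) of a GGN-vector product G v = Jᵀ H J v assembled from torch.autograd.grad calls, where J is the Jacobian of the model output with respect to the parameters and H is the Hessian of the loss with respect to the model output. Multiplying with every unit vector yields the GGN column by column. The function names ggn_vector_product and ggn_matrix are our own; the sketch runs on the CPU and uses float64 to keep round-off small.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential


def ggn_vector_product(model, loss_function, X, y, vec):
    """Return G v = J^T H J v; vec is a list of tensors shaped like the parameters."""
    params = [p for p in model.parameters() if p.requires_grad]
    output = model(X)
    loss = loss_function(output, y)

    # J v via the double-backward trick: u -> J^T u is linear in u,
    # so differentiating <J^T u, v> with respect to u yields J v
    u = torch.zeros_like(output, requires_grad=True)
    JTu = torch.autograd.grad(output, params, grad_outputs=u, create_graph=True)
    Jv = torch.autograd.grad(JTu, u, grad_outputs=vec, retain_graph=True)[0]

    # H (J v): Hessian-vector product of the loss w.r.t. the model output
    grad_output = torch.autograd.grad(loss, output, create_graph=True)[0]
    HJv = torch.autograd.grad(grad_output, output, grad_outputs=Jv, retain_graph=True)[0]

    # J^T (H J v): vector-Jacobian product w.r.t. the parameters
    return torch.autograd.grad(output, params, grad_outputs=HJv)


def ggn_matrix(model, loss_function, X, y):
    """Assemble the GGN column by column from products with unit vectors."""
    params = [p for p in model.parameters() if p.requires_grad]
    sizes = [p.numel() for p in params]
    num_params = sum(sizes)
    columns = []
    for i in range(num_params):
        e_i = torch.zeros(num_params, dtype=params[0].dtype)
        e_i[i] = 1.0
        # split the flat unit vector into parameter-shaped chunks
        vec = [chunk.view_as(p) for chunk, p in zip(e_i.split(sizes), params)]
        Gv = ggn_vector_product(model, loss_function, X, y, vec)
        columns.append(torch.cat([g.flatten() for g in Gv]))
    return torch.stack(columns, dim=1)


torch.manual_seed(0)
N, D_in, D_hidden, D_out = 4, 7, 5, 3
X = torch.rand(N, D_in, dtype=torch.float64)
y = torch.rand(N, D_out, dtype=torch.float64)
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
).double()

G = ggn_matrix(model, MSELoss(reduction="mean"), X, y)
print(G.shape)
```

Because J has only N · D_out = 12 rows, the GGN of this model has rank at most 12 regardless of its 88 parameters; this low rank is what makes a small Gram matrix sufficient.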

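ViViT eigen-decomposes the GGN’s Gram matrix, which is much smaller than the GGN: it is 12 × 12 here (N · D_out), whereas the GGN is 88 × 88. Both matrices share their nonzero eigenvalues, and the GGN’s remaining eigenvalues are zero. Hence, we compare against the leading (largest) eigenvalues of the GGN; since eigvalsh returns eigenvalues in ascending order, these are the last gram_dim entries: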

gram_dim = evals.numel()
ggn_evals = ggn_evals[-gram_dim:]
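The comparison rests on a linear-algebra identity: if the GGN can be written as G = V Vᵀ with V of shape P × K (P parameters, K ≤ P), then the Gram matrix Vᵀ V is only K × K, and its eigenvalues equal the K largest eigenvalues of G, while the other P - K eigenvalues of G are zero. The sketch below checks this with a random stand-in for V; it does not use ViViT, and the random V is a placeholder assumption rather than the actual GGN factor.

```python
import torch

torch.manual_seed(0)
P, K = 88, 12  # parameters and Gram dimension (N * D_out) in the example

# Random stand-in for the GGN factor V, so that G = V V^T is symmetric PSD
V = torch.randn(P, K, dtype=torch.float64)
G = V @ V.T  # P x P, plays the role of the GGN
gram = V.T @ V  # K x K, plays the role of the Gram matrix

ggn_evals = torch.linalg.eigvalsh(G)  # ascending order
gram_evals = torch.linalg.eigvalsh(gram)  # ascending order

# The K largest eigenvalues of G match the Gram eigenvalues...
print(torch.allclose(ggn_evals[-K:], gram_evals))
# ...and the remaining P - K eigenvalues of G are numerically zero
print(ggn_evals[:-K].abs().max().item())
```

Decomposing the K × K Gram matrix instead of the P × P GGN is what keeps the approach tractable when P is large but the batch size and output dimension are small.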

Let’s see if the eigenvalues match.

for eval_vivit, eval_torch in zip(evals, ggn_evals):
    close = isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)
    print(f"{eval_vivit:.5e} vs. {eval_torch:.5e}, close: {close}")
    if not close:
        raise ValueError("Eigenvalues don't match!")

print("Eigenvalues match!")
1.46938e-08 vs. 3.14615e-08, close: True
1.33847e-03 vs. 1.33849e-03, close: True
3.20110e-03 vs. 3.20108e-03, close: True
5.41562e-03 vs. 5.41563e-03, close: True
8.00165e-03 vs. 8.00164e-03, close: True
1.30415e-02 vs. 1.30415e-02, close: True
3.93476e-02 vs. 3.93476e-02, close: True
5.08563e-02 vs. 5.08563e-02, close: True
9.14433e-02 vs. 9.14433e-02, close: True
6.81850e-01 vs. 6.81850e-01, close: True
7.62203e-01 vs. 7.62203e-01, close: True
1.05583e+00 vs. 1.05583e+00, close: True
Eigenvalues match!
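The first pair above differs by roughly a factor of two, yet the check passes. torch.isclose tests |a - b| ≤ atol + rtol · |b|; for eigenvalues that are numerically zero, the relative term is negligible, so the absolute tolerance atol=1e-7 does the work. The snippet below replays that first comparison with the printed values, once with the example’s tolerances and once without the absolute tolerance.

```python
import torch

# First pair printed above: ViViT vs. torch.autograd, both numerically zero
eval_vivit = torch.tensor(1.46938e-08)
eval_torch = torch.tensor(3.14615e-08)

# |a - b| is about 1.68e-08, below atol + rtol * |b| (about 1e-07): passes
with_atol = torch.isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)

# Without atol the bound shrinks to rtol * |b| (about 3.1e-12): fails
without_atol = torch.isclose(eval_vivit, eval_torch, rtol=1e-4, atol=0.0)

print(with_atol.item(), without_atol.item())
```

A purely relative comparison of eigenvalues at the level of float32 round-off is unreliable, which is why the check combines rtol with a small atol.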

Total running time of the script: ( 0 minutes 1.358 seconds)