Computing GGN eigenvalues
In this example, we demonstrate how to use ViViT’s EigvalshComputation to obtain the eigenvalues of the generalized Gauss-Newton (GGN) matrix and verify the result with torch.autograd.
First, the imports.
from backpack import backpack, extend
from backpack.utils.examples import _autograd_ggn_exact_columns
from torch import cuda, device, isclose, manual_seed, rand, stack
from torch.linalg import eigvalsh
from torch.nn import Linear, MSELoss, ReLU, Sequential
from vivit.linalg.eigvalsh import EigvalshComputation
# make deterministic
manual_seed(0)
<torch._C.Generator object at 0x7fc4a4be4e70>
Data, model & loss
For this demo, we use toy data and a small MLP with few enough parameters that we can store and eigen-decompose the GGN matrix to verify correctness. We use mean squared error as the loss function.
N = 4
D_in = 7
D_hidden = 5
D_out = 3
DEVICE = device("cuda" if cuda.is_available() else "cpu")
X = rand(N, D_in).to(DEVICE)
y = rand(N, D_out).to(DEVICE)
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
).to(DEVICE)
loss_function = MSELoss(reduction="mean").to(DEVICE)
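To make "few enough parameters" concrete, the following sketch rebuilds the same architecture on the CPU and counts its trainable parameters. The variable names num_params and gram_dim are introduced here for illustration only.

```python
from torch.nn import Linear, ReLU, Sequential

N, D_in, D_hidden, D_out = 4, 7, 5, 3

# Same architecture as above, kept on the CPU for simplicity
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
)

# (7*5 + 5) + (5*5 + 5) + (5*3 + 3) = 40 + 30 + 18
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_params)  # 88, so the dense GGN has 88 x 88 entries

# One row/column per (sample, output) pair
gram_dim = N * D_out
print(gram_dim)  # 12
```

An 88 x 88 matrix is cheap to store and eigen-decompose, which is what makes the verification at the end of this example feasible.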
Integrate BackPACK
Next, we extend the model and loss function so that they can be used with BackPACK. Then, we perform a forward pass to compute the loss.
model = extend(model)
loss_function = extend(loss_function)
loss = loss_function(model(X), y)
Specify GGN approximation
By default, vivit.EigvalshComputation uses the exact GGN. We only need to specify the parameters with respect to which the GGN is computed. This is done via a param_groups argument, which may be familiar to you from torch.optim.
computation = EigvalshComputation()
group = {"params": [p for p in model.parameters() if p.requires_grad]}
param_groups = [group]
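Because param_groups is a list of dictionaries with a "params" entry, it can in principle hold more than one group. The sketch below splits the toy model's parameters into a weight group and a bias group. The split and the names weight_group and bias_group are hypothetical, and the assumption is that each group would be treated as its own GGN block, as the per-group results described later suggest.

```python
from torch.nn import Linear, ReLU, Sequential

model = Sequential(
    Linear(7, 5),
    ReLU(),
    Linear(5, 5),
    ReLU(),
    Linear(5, 3),
)

# Hypothetical split: one group for all weights, one for all biases
weight_group = {
    "params": [p for name, p in model.named_parameters() if name.endswith("weight")]
}
bias_group = {
    "params": [p for name, p in model.named_parameters() if name.endswith("bias")]
}
param_groups = [weight_group, bias_group]

for i, group in enumerate(param_groups):
    num_entries = sum(p.numel() for p in group["params"])
    print(f"group {i}: {len(group['params'])} tensors, {num_entries} entries")
```

With such a split, results would be requested per group, so each dictionary should be kept around for later lookup.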
Backward pass with BackPACK
We can now build the BackPACK extension and extension hook that will compute the GGN eigenvalues, pass them to a with backpack context, and perform the backward pass.
extension = computation.get_extension()
extension_hook = computation.get_extension_hook(param_groups)
with backpack(extension, extension_hook=extension_hook):
    loss.backward()
This computes the GGN eigenvalues for each parameter group and stores them internally in the EigvalshComputation instance. We can use the parameter group to request the eigenvalues.
evals = computation.get_result(group)
Verify results
Let’s compute the GGN matrix, column by column, using GGN-vector products that only rely on torch.autograd. We can then compute its eigenvalues and compare them.
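As a rough illustration of what a torch.autograd-only reference can look like, the sketch below builds the dense GGN for the same toy setup on the CPU. It is not the helper used by this example: instead of GGN-vector products, it assembles the Jacobian of the model output with respect to the parameters row by row, and it relies on the fact that the Hessian of MSELoss(reduction="mean") with respect to the model output is 2 / (N * D_out) times the identity. The function name ggn_mse_via_autograd is hypothetical.

```python
import torch
from torch.nn import Linear, ReLU, Sequential

torch.manual_seed(0)

N, D_in, D_hidden, D_out = 4, 7, 5, 3
X = torch.rand(N, D_in)
model = Sequential(
    Linear(D_in, D_hidden),
    ReLU(),
    Linear(D_hidden, D_hidden),
    ReLU(),
    Linear(D_hidden, D_out),
)
params = [p for p in model.parameters() if p.requires_grad]


def ggn_mse_via_autograd(model, params, X):
    """Hypothetical helper: dense GGN for mean-reduced MSE loss."""
    output = model(X).flatten()  # N * D_out entries
    rows = []
    for out in output:
        # Gradient of one output entry w.r.t. all parameters = one Jacobian row
        grads = torch.autograd.grad(out, params, retain_graph=True)
        rows.append(torch.cat([g.flatten() for g in grads]))
    jac = torch.stack(rows)  # shape [N * D_out, num_params]
    # GGN = J^T H J with H = (2 / (N * D_out)) * identity for mean-reduced MSE
    return (2.0 / output.numel()) * jac.T @ jac


ggn = ggn_mse_via_autograd(model, params, X)
ggn_evals = torch.linalg.eigvalsh(ggn)  # eigenvalues in ascending order
print(ggn.shape)
```

Because the Jacobian has only N * D_out = 12 rows, at most 12 of the 88 eigenvalues are nonzero up to round-off. The targets y do not appear because, for MSE loss, the GGN does not depend on them.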
ViViT eigen-decomposes the GGN Gram matrix, which is smaller than the GGN. Hence, we compare against the leading eigenvalues from the GGN eigen-decomposition:
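The relationship behind this comparison can be checked on a toy matrix. Assuming the GGN factorizes as V Vᵀ with a tall matrix V (here a random stand-in, not ViViT's actual factor), the Gram matrix Vᵀ V is much smaller but has the same nonzero eigenvalues. The sizes P and K below mirror the parameter count and N * D_out of this example.

```python
import torch

torch.manual_seed(0)

P, K = 88, 12  # number of parameters, N * D_out
V = torch.rand(P, K, dtype=torch.float64)  # random stand-in factor

ggn_like = V @ V.T  # [P, P], rank at most K
gram = V.T @ V  # [K, K]

# eigvalsh returns eigenvalues in ascending order,
# so the leading K eigenvalues are the last K entries
gram_evals = torch.linalg.eigvalsh(gram)
leading_evals = torch.linalg.eigvalsh(ggn_like)[-K:]

print(torch.allclose(gram_evals, leading_evals))
```

The remaining P - K eigenvalues of the large matrix are zero up to round-off, which is why only the leading ones take part in the comparison.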
Let’s see if the eigenvalues match.
for eval_vivit, eval_torch in zip(evals, ggn_evals):
    close = isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)
    print(f"{eval_vivit:.5e} vs. {eval_torch:.5e}, close: {close}")
    if not close:
        raise ValueError("Eigenvalues don't match!")

print("Eigenvalues match!")
1.46938e-08 vs. 3.14615e-08, close: True
1.33847e-03 vs. 1.33849e-03, close: True
3.20110e-03 vs. 3.20108e-03, close: True
5.41562e-03 vs. 5.41563e-03, close: True
8.00165e-03 vs. 8.00164e-03, close: True
1.30415e-02 vs. 1.30415e-02, close: True
3.93476e-02 vs. 3.93476e-02, close: True
5.08563e-02 vs. 5.08563e-02, close: True
9.14433e-02 vs. 9.14433e-02, close: True
6.81850e-01 vs. 6.81850e-01, close: True
7.62203e-01 vs. 7.62203e-01, close: True
1.05583e+00 vs. 1.05583e+00, close: True
Eigenvalues match!
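The first pair in the output differs by roughly a factor of two, yet it is reported as close. That follows from how isclose combines the two tolerances: it checks |a - b| <= atol + rtol * |b|, so for values near zero the absolute tolerance dominates. The sketch below reproduces the check for that pair by hand, using the printed values rather than the full-precision ones.

```python
import torch

# First pair from the output above (as printed, so rounded)
eval_vivit = torch.tensor(1.46938e-08)
eval_torch = torch.tensor(3.14615e-08)
rtol, atol = 1e-4, 1e-7

# Manual version of the criterion used by torch.isclose
abs_diff = (eval_vivit - eval_torch).abs()
tolerance = atol + rtol * eval_torch.abs()
manual = bool(abs_diff <= tolerance)

builtin = bool(torch.isclose(eval_vivit, eval_torch, rtol=rtol, atol=atol))
relative_only = bool(torch.isclose(eval_vivit, eval_torch, rtol=rtol, atol=0.0))

print(manual, builtin, relative_only)
```

With the absolute tolerance set to zero, the same pair fails the check, which suggests that this eigenvalue is numerically zero and that its exact value is dominated by round-off.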
Total running time of the script: ( 0 minutes 1.358 seconds)