Note
Click here to download the full example code
Computing empirical NTKs
In this example we will use vivit to compute empirical NTK matrices.
The functorch package allows to do this efficiently. One of its tutorials
states that doing this in stock PyTorch is hard … well, challenge accepted!
Let’s see how vivit and functorch compare.
Given two data sets \(\mathbf{X}_1\), \(\mathbf{X}_2\), and a model \(f_\theta\), the empirical NTK is \(\mathbf{J}_\theta f_\theta(\mathbf{X}_1) [\mathbf{J}_\theta f_\theta(\mathbf{X}_2)]^{\top}\).
vivit can compute the GGN Gram matrix \([\mathbf{J}_\theta
f_\theta(\mathbf{X}) \sqrt{\mathbf{H}}] [\mathbf{J}_\theta f_\theta(\mathbf{X})
\sqrt{\mathbf{H}}]^\top\) on a data set \(\mathbf{X}\) where
\(\sqrt{\mathbf{H}}\) is the matrix square root of the loss Hessian w.r.t.
the model’s prediction.
For MSELoss we have \(\sqrt{\mathbf{H}} = 2 \mathbf{I}\) and therefore
we can compute \([\mathbf{J}_\theta f_\theta (\mathbf{X}) \sqrt{2}
\mathbf{I}] [\mathbf{J}_\theta f_\theta (\mathbf{X}) \sqrt{2} \mathbf{I}]^\top
= \mathbf{J}_\theta f_\theta(\mathbf{X}) [\mathbf{J}_\theta
f_\theta(\mathbf{X})]^{\top}\). If we stack \(\mathbf{X}_1\) and
\(\mathbf{X}_2\) into a data set \(\mathbf{X}\), a submatrix of the Gram matrix
is proportional to the empirical NTK!
Let’s get the imports out of our way.
import time
from backpack import backpack, extend
from functorch import jacrev, jvp, make_functional, vjp, vmap
from torch import allclose, cat, einsum, eye, manual_seed, randn, stack, zeros_like
from torch.nn import Conv2d, Flatten, Linear, MSELoss, ReLU, Sequential
from vivit.extensions.secondorder.vivit import ViViTGGNExact
device = "cpu"
manual_seed(0)
<torch._C.Generator object at 0x7fc4a4be4e70>
Setup
We will use the same CNN as the functorch tutorial and create the data
sets \(\mathbf{X}_1\), \(\mathbf{X}_2\).
def CNN():
"""Same as in the functorch tutorial. Sequential for compatibility with BackPACK."""
return Sequential(
Conv2d(3, 32, (3, 3)),
ReLU(),
Conv2d(32, 32, (3, 3)),
ReLU(),
Conv2d(32, 32, (3, 3)),
Flatten(),
Linear(21632, 10),
)
x_train = randn(20, 3, 32, 32, device=device)
x_test = randn(5, 3, 32, 32, device=device)
net = CNN().to(device)
NTK with functorch
The functorch tutorial provides two different methods to compute the empirical NTK. We just copy them over here.
fnet, params = make_functional(net)
def fnet_single(params, x):
"""From the functorch tutorial."""
return fnet(params, x.unsqueeze(0)).squeeze(0)
def empirical_ntk_functorch(fnet_single, params, x1, x2):
"""From the functorch tutorial."""
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = stack([einsum("Naf,Mbf->NMab", j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
def empirical_ntk_implicit_functorch(func, params, x1, x2):
"""From the functorch tutorial."""
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes vec @ J(x2).T
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes J(X1) @ vjps
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = eye(output.numel(), dtype=output.dtype, device=output.device).view(
output.numel(), -1
)
return vmap(get_ntk_slice)(basis)
# get_ntk(x1, x2) computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to empirical_ntk_implicit are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the vmaps here do.
return vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
Let’s compute an NTK matrix:
ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_train)
NTK with ViViT
As outlined above, to compute the NTK with vivit, we need to stack the
two data sets, feed them through the network and an MSELoss function, then
compute the GGN Gram matrix during backpropagation. The latter is done by
vivit’s ViViTGGNExact extension, which gives access to the per-layer
Gram matrix. We have to accumulate the Gram matrices over layers. To do that,
we use the following hook:
class AccumulateGramHook:
"""Accumulate the Gram matrix during backpropagation with BackPACK."""
def __init__(self, delete_buffers):
self.gram = None
self.delete_buffers = delete_buffers
def __call__(self, module):
for p in module.parameters():
gram_p = p.vivit_ggn_exact["gram_mat"]()
self.gram = gram_p if self.gram is None else self.gram + gram_p
if self.delete_buffers:
del p.vivit_ggn_exact
The above steps are then implemented by the following function:
def empirical_ntk_vivit(net, x1, x2, delete_buffers=True):
"""Compute the empirical NTK matrix with ViViT."""
N1 = x1.shape[0]
X = cat([x1, x2])
# make BackPACK-ready
net = extend(net)
loss_func = extend(MSELoss(reduction="sum"))
hook = AccumulateGramHook(delete_buffers)
with backpack(ViViTGGNExact(), extension_hook=hook):
output = net(X)
y = zeros_like(output) # anything, won't affect NTK
loss = loss_func(output, y)
loss.backward()
gram_reordered = einsum("cndm->nmcd", hook.gram)
# slice out relevant blocks & fix scaling of MSELoss
return 0.5 * gram_reordered[:N1, N1:]
Check
Let’s check that the vivit and functorch implementations produce the
same NTK matrix:
ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_test)
ntk_vivit = empirical_ntk_vivit(net, x_train, x_test)
close = allclose(ntk_functorch, ntk_vivit, atol=1e-6)
if close:
print("NTK from functorch and vivit match!")
else:
raise ValueError("NTK from functorch and vivit don't match!")
NTK from functorch and vivit match!
Runtime
Last but not least, let’s compare the three methods in terms of runtime:
t_functorch = time.time()
empirical_ntk_functorch(fnet_single, params, x_train, x_test)
t_functorch = time.time() - t_functorch
t_functorch_implicit = time.time()
empirical_ntk_implicit_functorch(fnet_single, params, x_train, x_test)
t_functorch_implicit = time.time() - t_functorch_implicit
t_vivit = time.time()
empirical_ntk_vivit(net, x_train, x_test)
t_vivit = time.time() - t_vivit
t_min = min(t_functorch, t_functorch_implicit, t_vivit)
print(f"Time [s] functorch: {t_functorch:.4f} (x{t_functorch / t_min:.2f})")
print(f"Time [s] vivit: {t_vivit:.4f} (x{t_vivit/ t_min:.2f})")
print(
f"Time [s] functorch implicit: {t_functorch_implicit:.4f}"
+ f" (x{t_functorch_implicit / t_min:.2f})"
)
Time [s] functorch: 0.4243 (x1.00)
Time [s] vivit: 1.8497 (x4.36)
Time [s] functorch implicit: 2.9742 (x7.01)
We can see that vivit is competitive with functorch.
Total running time of the script: ( 0 minutes 8.634 seconds)