Computing empirical NTKs

In this example we will use vivit to compute empirical NTK matrices.

The functorch package makes it possible to do this efficiently. One of its tutorials states that doing this in stock PyTorch is hard … well, challenge accepted! Let’s see how vivit and functorch compare.

Given two data sets \(\mathbf{X}_1\), \(\mathbf{X}_2\), and a model \(f_\theta\), the empirical NTK is \(\mathbf{J}_\theta f_\theta(\mathbf{X}_1) [\mathbf{J}_\theta f_\theta(\mathbf{X}_2)]^{\top}\).
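
To make the definition concrete, here is a minimal sketch that builds the empirical NTK the slow way in stock PyTorch, with one backward pass per output element. The helper names `flat_jacobian` and `naive_ntk`, as well as the toy model and data, are assumptions made for illustration and are not part of vivit or functorch.

```python
import torch

torch.manual_seed(0)


def flat_jacobian(model, x):
    """Jacobian of all outputs w.r.t. all parameters, shape [N * C, D]."""
    params = list(model.parameters())
    outputs = model(x).flatten()  # entry n * C + c belongs to sample n, output c
    rows = []
    for out in outputs:
        # One backward pass per output element (simple, but slow)
        grads = torch.autograd.grad(out, params, retain_graph=True)
        rows.append(torch.cat([g.flatten() for g in grads]))
    return torch.stack(rows)


def naive_ntk(model, x1, x2):
    """Empirical NTK J(x1) @ J(x2).T, reshaped to [N1, N2, C, C]."""
    jac1, jac2 = flat_jacobian(model, x1), flat_jacobian(model, x2)
    n1, n2 = x1.shape[0], x2.shape[0]
    c = jac1.shape[0] // n1
    return (jac1 @ jac2.T).reshape(n1, c, n2, c).permute(0, 2, 1, 3)


# Hypothetical toy model and data, chosen only to keep the example small
toy_net = torch.nn.Sequential(
    torch.nn.Linear(4, 5), torch.nn.Tanh(), torch.nn.Linear(5, 3)
)
x1, x2 = torch.randn(2, 4), torch.randn(3, 4)
ntk = naive_ntk(toy_net, x1, x2)
print(ntk.shape)  # torch.Size([2, 3, 3, 3]), i.e. [N1, N2, C, C]
```

The output layout `[N1, N2, C, C]` is chosen to match the `"NMab"` layout used by the functorch code further down.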

vivit can compute the GGN Gram matrix \([\mathbf{J}_\theta f_\theta(\mathbf{X}) \sqrt{\mathbf{H}}] [\mathbf{J}_\theta f_\theta(\mathbf{X}) \sqrt{\mathbf{H}}]^\top\) on a data set \(\mathbf{X}\) where \(\sqrt{\mathbf{H}}\) is the matrix square root of the loss Hessian w.r.t. the model’s prediction.

For MSELoss with sum reduction we have \(\mathbf{H} = 2 \mathbf{I}\), hence \(\sqrt{\mathbf{H}} = \sqrt{2} \mathbf{I}\), and therefore we can compute \([\mathbf{J}_\theta f_\theta (\mathbf{X}) \sqrt{2} \mathbf{I}] [\mathbf{J}_\theta f_\theta (\mathbf{X}) \sqrt{2} \mathbf{I}]^\top = 2 \mathbf{J}_\theta f_\theta(\mathbf{X}) [\mathbf{J}_\theta f_\theta(\mathbf{X})]^{\top}\). If we stack \(\mathbf{X}_1\) and \(\mathbf{X}_2\) into a data set \(\mathbf{X}\), the off-diagonal block of the Gram matrix that pairs \(\mathbf{X}_1\) with \(\mathbf{X}_2\) is twice the empirical NTK!
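
As a quick sanity check of the scaling, the following sketch computes the Hessian of a sum-reduced MSELoss w.r.t. the prediction with `torch.autograd.functional.hessian` and confirms that it equals \(2 \mathbf{I}\). The matrix `J` is a random stand-in for a Jacobian, used only to illustrate the resulting factor of 2 in the Gram matrix.

```python
import math

import torch
from torch.autograd.functional import hessian

torch.manual_seed(0)

loss_func = torch.nn.MSELoss(reduction="sum")
target = torch.zeros(3)  # the target does not affect the Hessian

# Hessian of the loss w.r.t. a (random) prediction with 3 entries
H = hessian(lambda f: loss_func(f, target), torch.randn(3))
print(torch.allclose(H, 2 * torch.eye(3)))  # True

# Random stand-in for a Jacobian with 3 outputs and 5 parameters
J = torch.randn(3, 5)
sqrt_H = math.sqrt(2) * torch.eye(3)

# sqrt_H is a multiple of the identity, so it only rescales the Gram matrix
gram = (sqrt_H @ J) @ (sqrt_H @ J).T
print(torch.allclose(gram, 2 * J @ J.T, atol=1e-5))  # True
```

This factor of 2 is what the final rescaling by 0.5 undoes. With a different reduction, such as `"mean"`, the loss is scaled differently and the correction factor would have to change accordingly.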

Let’s get the imports out of our way.

import time

from backpack import backpack, extend
from functorch import jacrev, jvp, make_functional, vjp, vmap
from torch import allclose, cat, einsum, eye, manual_seed, randn, stack, zeros_like
from torch.nn import Conv2d, Flatten, Linear, MSELoss, ReLU, Sequential

from vivit.extensions.secondorder.vivit import ViViTGGNExact

device = "cpu"
manual_seed(0)

Out:

<torch._C.Generator object at 0x7fc15a5db370>

Setup

We will use the same CNN as the functorch tutorial and create the data sets \(\mathbf{X}_1\), \(\mathbf{X}_2\).

def CNN():
    """Same as in the functorch tutorial. Sequential for compatibility with BackPACK."""
    return Sequential(
        Conv2d(3, 32, (3, 3)),
        ReLU(),
        Conv2d(32, 32, (3, 3)),
        ReLU(),
        Conv2d(32, 32, (3, 3)),
        Flatten(),
        Linear(21632, 10),
    )


x_train = randn(20, 3, 32, 32, device=device)
x_test = randn(5, 3, 32, 32, device=device)

net = CNN().to(device)

NTK with functorch

The functorch tutorial provides two different methods to compute the empirical NTK. We just copy them over here.

fnet, params = make_functional(net)


def fnet_single(params, x):
    """From the functorch tutorial."""
    return fnet(params, x.unsqueeze(0)).squeeze(0)


def empirical_ntk_functorch(fnet_single, params, x1, x2):
    """From the functorch tutorial."""
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]

    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]

    # Compute J(x1) @ J(x2).T
    result = stack([einsum("Naf,Mbf->NMab", j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result


def empirical_ntk_implicit_functorch(func, params, x1, x2):
    """From the functorch tutorial."""

    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes vec @ J(x1), since vjp_fn was built from func_x1
            # `vec` is some unit vector (a single slice of the Identity matrix)
            vjps = vjp_fn(vec)
            # This computes J(x2) @ vjps
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = eye(output.numel(), dtype=output.dtype, device=output.device).view(
            output.numel(), -1
        )
        return vmap(get_ntk_slice)(basis)

    # get_ntk(x1, x2) computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to empirical_ntk_implicit are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the vmaps here do.
    return vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

Let’s compute an NTK matrix:

ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_train)

NTK with ViViT

As outlined above, to compute the NTK with vivit, we need to stack the two data sets, feed them through the network and an MSELoss function, then compute the GGN Gram matrix during backpropagation. The latter is done by vivit’s ViViTGGNExact extension, which gives access to the Gram matrix of each parameter. We have to accumulate these Gram matrices over all parameters and layers. To do that, we use the following hook:

class AccumulateGramHook:
    """Accumulate the Gram matrix during backpropagation with BackPACK."""

    def __init__(self, delete_buffers):
        self.gram = None
        self.delete_buffers = delete_buffers

    def __call__(self, module):
        for p in module.parameters():
            gram_p = p.vivit_ggn_exact["gram_mat"]()
            self.gram = gram_p if self.gram is None else self.gram + gram_p

            if self.delete_buffers:
                del p.vivit_ggn_exact
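
Summing the Gram matrices is valid because the full Jacobian is the column-wise concatenation of the Jacobians of the individual parameters, so its Gram matrix splits into a sum of per-parameter terms. The sketch below illustrates this identity with random matrices standing in for the Jacobian blocks. The shapes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)

# Random stand-ins for per-parameter Jacobians: 6 output rows, and
# 4, 10, and 3 parameters in the three blocks
jac_blocks = [torch.randn(6, 4), torch.randn(6, 10), torch.randn(6, 3)]

# Gram matrix of the full Jacobian J = [J_1, J_2, J_3]
jac_full = torch.cat(jac_blocks, dim=1)
gram_full = jac_full @ jac_full.T

# Sum of the per-parameter Gram matrices, as accumulated by the hook
gram_sum = sum(jac @ jac.T for jac in jac_blocks)

print(torch.allclose(gram_full, gram_sum, atol=1e-5))  # True
```

Accumulating inside a hook means each per-parameter Gram matrix can be freed as soon as it has been added, which is what the `delete_buffers` option is for.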

The above steps are then implemented by the following function:

def empirical_ntk_vivit(net, x1, x2, delete_buffers=True):
    """Compute the empirical NTK matrix with ViViT."""
    N1 = x1.shape[0]
    X = cat([x1, x2])

    # make BackPACK-ready
    net = extend(net)
    loss_func = extend(MSELoss(reduction="sum"))
    hook = AccumulateGramHook(delete_buffers)

    with backpack(ViViTGGNExact(), extension_hook=hook):
        output = net(X)
        y = zeros_like(output)  # anything, won't affect NTK
        loss = loss_func(output, y)
        loss.backward()

    gram_reordered = einsum("cndm->nmcd", hook.gram)

    # slice out relevant blocks & fix scaling of MSELoss
    return 0.5 * gram_reordered[:N1, N1:]
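
The reordering and slicing at the end of the function can be checked in isolation. This sketch assumes, based on the einsum string above, that the accumulated Gram matrix has the layout `[C, N, C, N]`. It builds such a tensor from random stand-in Jacobians (without the loss Hessian, so no factor of 2 appears) and verifies that the sliced block equals the cross term between the two data sets.

```python
import torch

torch.manual_seed(0)

C, N1, N2, D = 2, 3, 4, 7  # outputs, sizes of X1 and X2, parameters

# Random stand-in for the per-sample Jacobians of the stacked data, [N, C, D]
jac = torch.randn(N1 + N2, C, D)

# Gram matrix in the assumed layout [C, N, C, N]
gram = torch.einsum("ncf,mdf->cndm", jac, jac)

# Same reordering and slicing as in empirical_ntk_vivit
gram_reordered = torch.einsum("cndm->nmcd", gram)
block = gram_reordered[:N1, N1:]

# Reference: contract the Jacobians of X1 and X2 directly
expected = torch.einsum("naf,mbf->nmab", jac[:N1], jac[N1:])

print(block.shape)  # torch.Size([3, 4, 2, 2]), i.e. [N1, N2, C, C]
print(torch.allclose(block, expected, atol=1e-5))  # True
```

The first `N1` rows of the stacked data belong to the first data set and the remaining rows to the second, which is why the slice `[:N1, N1:]` selects the block that pairs the two.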

Check

Let’s check that the vivit and functorch implementations produce the same NTK matrix:

ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_test)
ntk_vivit = empirical_ntk_vivit(net, x_train, x_test)

close = allclose(ntk_functorch, ntk_vivit, atol=1e-6)
if close:
    print("NTK from functorch and vivit match!")
else:
    raise ValueError("NTK from functorch and vivit don't match!")

Out:

NTK from functorch and vivit match!

Runtime

Last but not least, let’s compare the three methods in terms of runtime:

t_functorch = time.time()
empirical_ntk_functorch(fnet_single, params, x_train, x_test)
t_functorch = time.time() - t_functorch

t_functorch_implicit = time.time()
empirical_ntk_implicit_functorch(fnet_single, params, x_train, x_test)
t_functorch_implicit = time.time() - t_functorch_implicit

t_vivit = time.time()
empirical_ntk_vivit(net, x_train, x_test)
t_vivit = time.time() - t_vivit

t_min = min(t_functorch, t_functorch_implicit, t_vivit)

print(f"Time [s] functorch:          {t_functorch:.4f} (x{t_functorch / t_min:.2f})")
print(f"Time [s] vivit:              {t_vivit:.4f} (x{t_vivit / t_min:.2f})")
print(
    f"Time [s] functorch implicit: {t_functorch_implicit:.4f}"
    + f" (x{t_functorch_implicit / t_min:.2f})"
)

Out:

Time [s] functorch:          0.4023 (x1.00)
Time [s] vivit:              1.7386 (x4.32)
Time [s] functorch implicit: 2.4185 (x6.01)

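We can see that vivit is competitive with functorch: in this run it is about 4.3 times slower than the first functorch method, which builds the Jacobians explicitly, but about 1.4 times faster than the implicit functorch method.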
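
Each timing above comes from a single call, so it may include one-off costs and run-to-run noise. A somewhat more robust comparison could use a warm-up call and report the best of several repetitions. The helper `best_time` below is a hypothetical sketch using only the standard library and is not part of vivit or functorch.

```python
import time


def best_time(fn, repeats=5, warmup=1):
    """Return the smallest wall-clock time in seconds over several calls to fn."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)


# Stand-in workload; replace it with one of the NTK functions to time
elapsed = best_time(lambda: sum(range(10_000)), repeats=3)
print(f"Best of 3: {elapsed:.6f} s")
```

For example, `best_time(lambda: empirical_ntk_vivit(net, x_train, x_test))` would time the vivit variant. When running on a GPU, the device would additionally need to be synchronized before reading the clock.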

Total running time of the script: ( 0 minutes 7.769 seconds)
