{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Computing empirical NTKs\n\nIn this example we will use ``vivit`` to compute empirical NTK matrices.\n\nThe ``functorch`` package makes this efficient. `One of its tutorials\n<https://pytorch.org/functorch/stable/notebooks/neural_tangent_kernels.html>`_\nstates that doing this in stock PyTorch is hard ... well, challenge accepted!\nLet's see how ``vivit`` and ``functorch`` compare.\n\nGiven two data sets $\\mathbf{X}_1$, $\\mathbf{X}_2$, and a model\n$f_\\theta$, the empirical NTK is $\\mathbf{J}_\\theta\nf_\\theta(\\mathbf{X}_1) [\\mathbf{J}_\\theta f_\\theta(\\mathbf{X}_2)]^{\\top}$.\n\n``vivit`` can compute the GGN Gram matrix $[\\sqrt{\\mathbf{H}}\\,\\mathbf{J}_\\theta\nf_\\theta(\\mathbf{X})] [\\sqrt{\\mathbf{H}}\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X})]^\\top$\non a data set $\\mathbf{X}$, where $\\sqrt{\\mathbf{H}}$ is the matrix square root\nof the loss Hessian w.r.t. the model's prediction.\n\nFor ``MSELoss(reduction=\"sum\")`` the loss Hessian is $\\mathbf{H} = 2 \\mathbf{I}$,\nso $\\sqrt{\\mathbf{H}} = \\sqrt{2} \\mathbf{I}$ and the Gram matrix becomes\n$[\\sqrt{2}\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X})] [\\sqrt{2}\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X})]^\\top\n= 2\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X}) [\\mathbf{J}_\\theta\nf_\\theta(\\mathbf{X})]^{\\top}$. If we stack $\\mathbf{X}_1$ and\n$\\mathbf{X}_2$ into a data set $\\mathbf{X}$, the off-diagonal block of the Gram\nmatrix that pairs $\\mathbf{X}_1$ with $\\mathbf{X}_2$ is twice the empirical NTK!\n\nLet's get the imports out of our way.\n"
      ]
    },
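    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the proportionality constant explicit, here is a short derivation. It assumes ``MSELoss(reduction=\"sum\")``, as used in the ``vivit`` code below, and stacks the predictions on all data points into a single vector $\\mathbf{f}$ with targets $\\mathbf{y}$:\n\n$$\n\\ell(\\mathbf{f}) = \\lVert \\mathbf{f} - \\mathbf{y} \\rVert_2^2\\,,\n\\qquad\n\\mathbf{H} = \\nabla^2_{\\mathbf{f}}\\, \\ell = 2\\mathbf{I}\\,,\n\\qquad\n\\sqrt{\\mathbf{H}} = \\sqrt{2}\\,\\mathbf{I}\\,.\n$$\n\nWriting $\\mathbf{J}_i = \\mathbf{J}_\\theta f_\\theta(\\mathbf{X}_i)$ for $i = 1, 2$ and stacking $\\mathbf{X}_1$, $\\mathbf{X}_2$ into $\\mathbf{X}$, the Gram matrix splits into blocks\n\n$$\n[\\sqrt{\\mathbf{H}}\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X})] [\\sqrt{\\mathbf{H}}\\,\\mathbf{J}_\\theta f_\\theta(\\mathbf{X})]^\\top\n= 2 \\begin{pmatrix} \\mathbf{J}_1 \\\\ \\mathbf{J}_2 \\end{pmatrix}\n\\begin{pmatrix} \\mathbf{J}_1 \\\\ \\mathbf{J}_2 \\end{pmatrix}^\\top\n= 2 \\begin{pmatrix}\n\\mathbf{J}_1 \\mathbf{J}_1^\\top & \\mathbf{J}_1 \\mathbf{J}_2^\\top \\\\\n\\mathbf{J}_2 \\mathbf{J}_1^\\top & \\mathbf{J}_2 \\mathbf{J}_2^\\top\n\\end{pmatrix},\n$$\n\nso the empirical NTK $\\mathbf{J}_1 \\mathbf{J}_2^\\top$ is one half of the top-right block.\n"
      ]
    },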
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\n\nfrom backpack import backpack, extend\nfrom functorch import jacrev, jvp, make_functional, vjp, vmap\nfrom torch import allclose, cat, einsum, eye, manual_seed, randn, stack, zeros_like\nfrom torch.nn import Conv2d, Flatten, Linear, MSELoss, ReLU, Sequential\n\nfrom vivit.extensions.secondorder.vivit import ViViTGGNExact\n\ndevice = \"cpu\"\nmanual_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup\n\nWe will use the same CNN as the ``functorch`` tutorial and create the data\nsets $\\mathbf{X}_1$, $\\mathbf{X}_2$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def CNN():\n    \"\"\"Same as in the functorch tutorial. Sequential for compatibility with BackPACK.\"\"\"\n    return Sequential(\n        Conv2d(3, 32, (3, 3)),\n        ReLU(),\n        Conv2d(32, 32, (3, 3)),\n        ReLU(),\n        Conv2d(32, 32, (3, 3)),\n        Flatten(),\n        Linear(21632, 10),\n    )\n\n\nx_train = randn(20, 3, 32, 32, device=device)\nx_test = randn(5, 3, 32, 32, device=device)\n\nnet = CNN().to(device)"
      ]
    },
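    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The input size ``21632`` of the final ``Linear`` layer follows from the convolutions: each unpadded $3 \\times 3$ convolution shrinks both spatial dimensions by 2 ($32 \\to 30 \\to 28 \\to 26$), and flattening 32 channels of size $26 \\times 26$ gives $32 \\cdot 26^2 = 21632$ features. A quick check with the standard output-size formula of a convolution, using a hypothetical helper ``conv_out``:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def conv_out(size, kernel=3, stride=1, padding=0):\n    \"\"\"Hypothetical helper: spatial output size of a convolution (dilation 1).\"\"\"\n    return (size + 2 * padding - kernel) // stride + 1\n\n\nside = 32\nfor _ in range(3):  # three 3x3 convolutions without padding\n    side = conv_out(side)\n\nprint(side, 32 * side * side)  # 26 21632"
      ]
    },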
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
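        "## NTK with functorch\n\nThe functorch tutorial provides two methods to compute the empirical NTK: an\nexplicit one that materializes the Jacobians and contracts them, and an implicit\none that only uses vector-Jacobian and Jacobian-vector products. We simply copy\nthem over here.\n\n"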
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fnet, params = make_functional(net)\n\n\ndef fnet_single(params, x):\n    \"\"\"From the functorch tutorial.\"\"\"\n    return fnet(params, x.unsqueeze(0)).squeeze(0)\n\n\ndef empirical_ntk_functorch(fnet_single, params, x1, x2):\n    \"\"\"From the functorch tutorial.\"\"\"\n    # Compute J(x1)\n    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n    jac1 = [j.flatten(2) for j in jac1]\n\n    # Compute J(x2)\n    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n    jac2 = [j.flatten(2) for j in jac2]\n\n    # Compute J(x1) @ J(x2).T\n    result = stack([einsum(\"Naf,Mbf->NMab\", j1, j2) for j1, j2 in zip(jac1, jac2)])\n    result = result.sum(0)\n    return result\n\n\ndef empirical_ntk_implicit_functorch(func, params, x1, x2):\n    \"\"\"From the functorch tutorial.\"\"\"\n\n    def get_ntk(x1, x2):\n        def func_x1(params):\n            return func(params, x1)\n\n        def func_x2(params):\n            return func(params, x2)\n\n        output, vjp_fn = vjp(func_x1, params)\n\n        def get_ntk_slice(vec):\n            # This computes vec @ J(x1), i.e. J(x1).T @ vec\n            # `vec` is some unit vector (a single slice of the identity matrix)\n            vjps = vjp_fn(vec)\n            # This computes J(x2) @ vjps, i.e. one row of J(x1) @ J(x2).T\n            _, jvps = jvp(func_x2, (params,), vjps)\n            return jvps\n\n        # Here's our identity matrix\n        basis = eye(output.numel(), dtype=output.dtype, device=output.device).view(\n            output.numel(), -1\n        )\n        return vmap(get_ntk_slice)(basis)\n\n    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n    # Since the x1, x2 inputs to empirical_ntk_implicit_functorch are batched,\n    # we actually wish to compute the NTK between every pair of data points\n    # from {x1} and {x2}. That's what the vmaps here do.\n    return vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)"
      ]
    },
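    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why ``empirical_ntk_implicit_functorch`` works, here is a minimal sketch of the same trick in stock PyTorch, using ``torch.autograd.functional.vjp`` and ``torch.autograd.functional.jvp`` instead of ``functorch``: a vector-Jacobian product with $\\mathbf{J}_\\theta f_\\theta(\\mathbf{x}_1)$ followed by a Jacobian-vector product with $\\mathbf{J}_\\theta f_\\theta(\\mathbf{x}_2)$ yields one NTK row per unit vector. It assumes a hypothetical linear model $f(\\mathbf{x}) = \\mathbf{W}\\mathbf{x} + \\mathbf{b}$ and a single input per data set, for which the NTK has the closed form $(\\mathbf{x}_1^\\top \\mathbf{x}_2 + 1)\\,\\mathbf{I}$, so the result can be checked.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.autograd.functional as autograd_fn  # alias avoids shadowing functorch's vjp/jvp\n\ntorch.manual_seed(0)\nD, C = 4, 3  # input and output dimension of the hypothetical linear model\nW, b = torch.randn(C, D), torch.randn(C)\nx1, x2 = torch.randn(D), torch.randn(D)\n\n\ndef linear_model(W, b, x):\n    \"\"\"Hypothetical toy model f(x) = W x + b.\"\"\"\n    return W @ x + b\n\n\nrows = []\nfor vec in torch.eye(C):\n    # vec @ J(x1), i.e. J(x1).T @ vec, returned as a tuple (dW, db)\n    _, vjps = autograd_fn.vjp(lambda W_, b_: linear_model(W_, b_, x1), (W, b), vec)\n    # J(x2) @ vjps, the NTK row that belongs to this unit vector\n    _, jvps = autograd_fn.jvp(lambda W_, b_: linear_model(W_, b_, x2), (W, b), vjps)\n    rows.append(jvps)\nntk_implicit = torch.stack(rows)\n\nntk_closed_form = (x1 @ x2 + 1) * torch.eye(C)\nprint(torch.allclose(ntk_implicit, ntk_closed_form, atol=1e-5))"
      ]
    },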
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's compute an NTK matrix:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_train)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## NTK with ViViT\n\nAs outlined above, to compute the NTK with ``vivit``, we need to stack the\ntwo data sets, feed them through the network and an ``MSELoss`` function, then\ncompute the GGN Gram matrix during backpropagation. The latter is done by\n``vivit``'s ``ViViTGGNExact`` extension, which gives access to the Gram matrix\nof each parameter. Since the Jacobian w.r.t. all parameters is the\nconcatenation of the per-parameter Jacobians, the full Gram matrix is the sum of\nthe per-parameter Gram matrices. To accumulate them during backpropagation, we\nuse the following hook:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class AccumulateGramHook:\n    \"\"\"Accumulate the Gram matrix during backpropagation with BackPACK.\"\"\"\n\n    def __init__(self, delete_buffers):\n        self.gram = None\n        self.delete_buffers = delete_buffers\n\n    def __call__(self, module):\n        # BackPACK calls this hook on each module after its extensions have run\n        for p in module.parameters():\n            # evaluate this parameter's Gram matrix and add it to the running sum\n            gram_p = p.vivit_ggn_exact[\"gram_mat\"]()\n            self.gram = gram_p if self.gram is None else self.gram + gram_p\n\n            # optionally free the quantities ViViT stored on the parameter\n            if self.delete_buffers:\n                del p.vivit_ggn_exact"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above steps are then implemented by the following function:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def empirical_ntk_vivit(net, x1, x2, delete_buffers=True):\n    \"\"\"Compute the empirical NTK matrix with ViViT.\"\"\"\n    N1 = x1.shape[0]\n    X = cat([x1, x2])\n\n    # make BackPACK-ready\n    net = extend(net)\n    loss_func = extend(MSELoss(reduction=\"sum\"))\n    hook = AccumulateGramHook(delete_buffers)\n\n    with backpack(ViViTGGNExact(), extension_hook=hook):\n        output = net(X)\n        # labels only enter the loss gradient, not its Hessian, so any value works\n        y = zeros_like(output)\n        loss = loss_func(output, y)\n        loss.backward()\n\n    # [C, N, C, N] -> [N, N, C, C], the same layout as functorch's result\n    gram_reordered = einsum(\"cndm->nmcd\", hook.gram)\n\n    # slice out the (x1, x2) block and undo the factor 2 from the MSELoss Hessian\n    return 0.5 * gram_reordered[:N1, N1:]"
      ]
    },
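    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check of the stacking, accumulation, reordering, and slicing logic without BackPACK, the sketch below builds the per-parameter Gram matrices of a hypothetical linear model $f(\\mathbf{x}) = \\mathbf{W}\\mathbf{x} + \\mathbf{b}$ explicitly with ``torch.autograd.functional.jacobian``, scales them by 2 as the ``MSELoss`` Hessian would, and lays them out as ``[C, N, C, N]``. That layout is an assumption about ViViT's output, inferred from the ``\"cndm->nmcd\"`` reordering in ``empirical_ntk_vivit`` above. Summing over parameters, reordering, halving, and slicing should recover the closed-form NTK $(\\mathbf{x}_1^\\top \\mathbf{x}_2 + 1)\\,\\mathbf{I}$ of the linear model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.autograd.functional import jacobian\n\ntorch.manual_seed(0)\nD, C = 4, 3  # input and output dimension of the hypothetical linear model\nN1, N2 = 2, 5\nW, b = torch.randn(C, D), torch.randn(C)\nx1, x2 = torch.randn(N1, D), torch.randn(N2, D)\nX = torch.cat([x1, x2])  # stacked data set with N = N1 + N2 samples\nN = N1 + N2\n\n# Jacobians of the [N, C] outputs w.r.t. W and b: shapes [N, C, C, D] and [N, C, C]\njac_W, jac_b = jacobian(lambda W_, b_: X @ W_.T + b_, (W, b))\n\ngram = torch.zeros(C, N, C, N)\nfor jac in (jac_W, jac_b):\n    # per-parameter Jacobian as [C, N, P], matching the assumed [C, N, C, N] layout\n    J_p = jac.flatten(start_dim=2).transpose(0, 1)\n    # per-parameter Gram matrix with sqrt(H) = sqrt(2) I, accumulated over parameters\n    gram += 2 * torch.einsum(\"cnp,dmp->cndm\", J_p, J_p)\n\n# same post-processing as in empirical_ntk_vivit\nntk = 0.5 * torch.einsum(\"cndm->nmcd\", gram)[:N1, N1:]\n\nntk_closed_form = (x1 @ x2.T + 1)[:, :, None, None] * torch.eye(C)\nprint(torch.allclose(ntk, ntk_closed_form, atol=1e-5))"
      ]
    },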
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Check\n\nLet's check that the ``vivit`` and ``functorch`` implementations produce the\nsame NTK matrix:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ntk_functorch = empirical_ntk_functorch(fnet_single, params, x_train, x_test)\nntk_vivit = empirical_ntk_vivit(net, x_train, x_test)\n\nclose = allclose(ntk_functorch, ntk_vivit, atol=1e-6)\nif close:\n    print(\"NTK from functorch and vivit match!\")\nelse:\n    raise ValueError(\"NTK from functorch and vivit don't match!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Runtime\n\nLast but not least, let's compare the three methods in terms of runtime:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# time.perf_counter is the recommended clock for measuring short intervals\nt_functorch = time.perf_counter()\nempirical_ntk_functorch(fnet_single, params, x_train, x_test)\nt_functorch = time.perf_counter() - t_functorch\n\nt_functorch_implicit = time.perf_counter()\nempirical_ntk_implicit_functorch(fnet_single, params, x_train, x_test)\nt_functorch_implicit = time.perf_counter() - t_functorch_implicit\n\nt_vivit = time.perf_counter()\nempirical_ntk_vivit(net, x_train, x_test)\nt_vivit = time.perf_counter() - t_vivit\n\nt_min = min(t_functorch, t_functorch_implicit, t_vivit)\n\nprint(f\"Time [s] functorch:          {t_functorch:.4f} (x{t_functorch / t_min:.2f})\")\nprint(f\"Time [s] vivit:              {t_vivit:.4f} (x{t_vivit / t_min:.2f})\")\nprint(\n    f\"Time [s] functorch implicit: {t_functorch_implicit:.4f}\"\n    + f\" (x{t_functorch_implicit / t_min:.2f})\"\n)"
      ]
    },
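    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The timings above come from a single call per method, which makes them noisy. A more careful comparison could repeat each measurement and keep the fastest run, for example with the hypothetical helper ``benchmark`` below, built on the standard library's ``timeit`` module. In this notebook, it would be called as ``benchmark(empirical_ntk_vivit, net, x_train, x_test)``.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import timeit\n\n\ndef benchmark(fn, *args, repeat=5, number=1):\n    \"\"\"Hypothetical helper: best-of-``repeat`` runtime of ``fn(*args)`` in seconds.\n\n    An untimed warm-up call absorbs one-off costs of a first call.\n    \"\"\"\n    fn(*args)  # warm-up, not timed\n    timer = timeit.Timer(lambda: fn(*args))\n    # each entry is the total time of ``number`` calls; keep the fastest, per call\n    return min(timer.repeat(repeat=repeat, number=number)) / number\n\n\n# usage with a cheap stand-in function\nprint(f\"sorted(): {benchmark(sorted, list(range(10_000, 0, -1))):.6f} s\")"
      ]
    },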
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
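        "On this problem, ``vivit`` is competitive with ``functorch`` in terms of runtime.\nNote that each timing above comes from a single call, so the numbers fluctuate\nbetween runs and depend on the hardware and the problem size.\n\n"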
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}