{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Computing GGN eigenvalues\n\nIn this example we demonstrate how to use ViViT's\n:py:class:`EigvalshComputation <vivit.EigvalshComputation>` to obtain the GGN's\neigenvalues and verify the result with :py:mod:`torch.autograd`.\n\nFirst, the imports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from backpack import backpack, extend\nfrom backpack.utils.examples import _autograd_ggn_exact_columns\nfrom torch import cuda, device, isclose, manual_seed, rand, stack\nfrom torch.linalg import eigvalsh\nfrom torch.nn import Linear, MSELoss, ReLU, Sequential\n\nfrom vivit.linalg.eigvalsh import EigvalshComputation\n\n# make deterministic\nmanual_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data, model & loss\nFor this demo, we use toy data and a small MLP with sufficiently few\nparameters such that we can store and eigen-decompose the GGN matrix to\nverify correctness. We use mean squared error as loss function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "N = 4\nD_in = 7\nD_hidden = 5\nD_out = 3\n\nDEVICE = device(\"cuda\" if cuda.is_available() else \"cpu\")\n\nX = rand(N, D_in).to(DEVICE)\ny = rand(N, D_out).to(DEVICE)\n\nmodel = Sequential(\n    Linear(D_in, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_out),\n).to(DEVICE)\n\nloss_function = MSELoss(reduction=\"mean\").to(DEVICE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Integrate BackPACK\nNext, :py:func:`extend <backpack.extend>` the model and loss function to be able\nto use BackPACK. Then, we perform a forward pass to compute the loss.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = extend(model)\nloss_function = extend(loss_function)\n\nloss = loss_function(model(X), y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Specify GGN approximation\nBy default, :py:class:`vivit.EigvalshComputation` uses the exact GGN. We only need to\nspecify the GGN's parameters via a ``param_groups`` argument that might be familiar\nto you from :py:mod:`torch.optim`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "computation = EigvalshComputation()\n\ngroup = {\"params\": [p for p in model.parameters() if p.requires_grad]}\nparam_groups = [group]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward pass with BackPACK\nWe can now build the BackPACK extension and extension hook that will compute GGN\neigenvalues, pass them to a :py:class:`with backpack <backpack.backpack>`, and\nperform the backward pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "extension = computation.get_extension()\nextension_hook = computation.get_extension_hook(param_groups)\n\nwith backpack(extension, extension_hook=extension_hook):\n    loss.backward()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will compute the GGN eigenvalues for each parameter group and store them\ninternally in the :py:class:`EigvalshComputation <vivit.EigvalshComputation>`\ninstance. We can use the parameter group to request the eigenvalues.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "evals = computation.get_result(group)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verify results\nLet's compute the GGN matrix, column by column, using GGN-vector products that\nonly rely on :py:mod:`torch.autograd`. We can then compute its eigenvalues and compare them.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])\nggn_evals = eigvalsh(ggn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "ViViT eigen-decomposes the GGN Gram matrix, which is smaller than the GGN.\nHence, we compare against the leading eigenvalues from the GGN eigen-decomposition:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gram_dim = evals.numel()\nggn_evals = ggn_evals[-gram_dim:]"
      ]
    },
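    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The truncation above can be motivated with a short sketch. It assumes that the GGN\nfactorizes as $\\mathbf{G} = \\mathbf{V} \\mathbf{V}^\\top$ with\n$\\mathbf{V} \\in \\mathbb{R}^{D \\times NC}$. The symbols $\\mathbf{V}$, $D$ (number of\nparameters) and $C$ (output dimension) are notation introduced here for illustration,\nnot names from the code above. The Gram matrix is then\n$\\tilde{\\mathbf{G}} = \\mathbf{V}^\\top \\mathbf{V} \\in \\mathbb{R}^{NC \\times NC}$, and for\nany eigenpair $(\\lambda, \\tilde{\\mathbf{e}})$ of $\\tilde{\\mathbf{G}}$ with\n$\\lambda \\neq 0$:\n\n$$\n\\mathbf{G} (\\mathbf{V} \\tilde{\\mathbf{e}})\n= \\mathbf{V} (\\mathbf{V}^\\top \\mathbf{V} \\tilde{\\mathbf{e}})\n= \\mathbf{V} (\\lambda \\tilde{\\mathbf{e}})\n= \\lambda (\\mathbf{V} \\tilde{\\mathbf{e}})\n$$\n\nSince $\\mathbf{V} \\tilde{\\mathbf{e}} \\neq \\mathbf{0}$ whenever $\\lambda \\neq 0$, every\nnonzero eigenvalue of the Gram matrix is also an eigenvalue of the GGN. The same argument\nwith the roles of $\\mathbf{G}$ and $\\tilde{\\mathbf{G}}$ swapped gives the converse, so all\nother GGN eigenvalues are zero. In this example, $D = 88$ and $NC = 4 \\cdot 3 = 12$, so\nunder this assumption at most 12 of the 88 GGN eigenvalues are nonzero.\n\n"
      ]
    },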
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's see if the eigenvalues match.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for eval_vivit, eval_torch in zip(evals, ggn_evals):\n    close = isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)\n    print(f\"{eval_vivit:.5e} vs. {eval_torch:.5e}, close: {close}\")\n    if not close:\n        raise ValueError(\"Eigenvalues don't match!\")\n\nprint(\"Eigenvalues match!\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}