{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Computing GGN eigenpairs\n\nIn this example, we demonstrate how to use ViViT's\n:py:class:`EighComputation <vivit.EighComputation>` to obtain the leading eigenpairs\n(eigenvalues and associated eigenvectors) of the generalized Gauss-Newton (GGN) matrix.\nWe verify the result with :py:mod:`torch.autograd`.\n\nFirst, the imports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List\n\nfrom backpack import backpack, extend\nfrom backpack.utils.examples import _autograd_ggn_exact_columns\nfrom torch import Tensor, allclose, cat, cuda, device, manual_seed, rand, stack\nfrom torch.nn import Linear, MSELoss, ReLU, Sequential\n\nfrom vivit.linalg.eigh import EighComputation\n\n# make deterministic\nmanual_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data, model & loss\nFor this demo, we use toy data and a small MLP with few enough parameters that we\ncan store the GGN matrix and verify the eigen-properties of our results (yes, one could\nuse matrix-free GGN-vector products instead, but expanding the GGN matrix helps us\nbecome familiar with the format of the results). We use mean squared error as the loss\nfunction.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "N = 4\nD_in = 7\nD_hidden = 5\nD_out = 3\n\nDEVICE = device(\"cuda\" if cuda.is_available() else \"cpu\")\n\nX = rand(N, D_in).to(DEVICE)\ny = rand(N, D_out).to(DEVICE)\n\nmodel = Sequential(\n    Linear(D_in, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_out),\n).to(DEVICE)\n\nloss_function = MSELoss(reduction=\"mean\").to(DEVICE)"
      ]
    },
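    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The GGN is a square matrix with one row and one column per parameter. As a quick\nsanity check (plain PyTorch, not part of ViViT), we count the trainable parameters to\nconfirm that the GGN of this MLP is small enough to store: with the dimensions above,\nthere are ``(7 * 5 + 5) + (5 * 5 + 5) + (5 * 3 + 3) = 88`` parameters, so the GGN is\nan 88-by-88 matrix.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# number of trainable parameters = side length of the GGN matrix\nnum_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\nprint(f\"Number of parameters: {num_params}\")\nprint(f\"GGN shape: ({num_params}, {num_params})\")"
      ]
    },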
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Integrate BackPACK\nNext, :py:func:`extend <backpack.extend>` the model and loss function to be able\nto use BackPACK. Then, we perform a forward pass to compute the loss.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = extend(model)\nloss_function = extend(loss_function)\n\nloss = loss_function(model(X), y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Specify GGN approximation and eigenpair filter\nBy default, :py:class:`vivit.EighComputation` uses the exact GGN. Furthermore, we need\nto specify the parameters whose GGN should be eigendecomposed via a ``param_groups``\nargument, similar to the one you might know from :py:mod:`torch.optim`. Each group also\ncontains a filter function (under the key ``\"criterion\"``) that selects the eigenvalues\nwhose eigenvectors should be computed, because computing all eigenvectors is infeasible\nfor big architectures.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "computation = EighComputation()\n\n\ndef select_top_k(evals: Tensor, k=4) -> List[int]:\n    \"\"\"Select the top-k eigenvalues for the eigenvector computation.\n\n    Args:\n        evals: Eigenvalues, sorted in ascending order.\n        k: Number of leading eigenvalues. Defaults to ``4``.\n\n    Returns:\n        Indices of top-k eigenvalues.\n    \"\"\"\n    return [evals.numel() - k + idx for idx in range(k)]\n\n\ngroup = {\n    \"params\": [p for p in model.parameters() if p.requires_grad],\n    \"criterion\": select_top_k,\n}\nparam_groups = [group]"
      ]
    },
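    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the filter returns, we can call ``select_top_k`` on a made-up, ascending\nsequence of eigenvalues (``toy_evals`` is purely illustrative and unrelated to the GGN).\nIt returns the indices of the ``k`` largest entries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch import arange\n\n# hypothetical eigenvalues 0, 1, ..., 9 in ascending order, as select_top_k assumes\ntoy_evals = arange(10.0)\n\nprint(select_top_k(toy_evals))  # [6, 7, 8, 9]\nprint(select_top_k(toy_evals, k=2))  # [8, 9]"
      ]
    },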
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward pass with BackPACK\nWe can now build the BackPACK extension and the extension hook that will compute the\neigenpairs, pass them to a :py:class:`with backpack <backpack.backpack>` context, and\nperform the backward pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "extension = computation.get_extension()\nextension_hook = computation.get_extension_hook(param_groups)\n\nwith backpack(extension, extension_hook=extension_hook):\n    loss.backward()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will compute the GGN eigenpairs for each parameter group and store them\ninternally in the :py:class:`EighComputation <vivit.EighComputation>` instance.\nWe can use the parameter group to request the eigenpairs.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "evals, evecs = computation.get_result(group)"
      ]
    },
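    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The eigenvectors come with one tensor per parameter. Each tensor has the shape of its\nparameter plus an additional leading axis that indexes the eigenpairs: slicing it at\nposition ``k`` yields the part of the eigenvector for ``evals[k]`` that belongs to this\nparameter.\n\n"
      ]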
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\"Parameter shape    |  Eigenvector shape\")\nfor p, v in zip(group[\"params\"], evecs):\n    print(f\"{str(p.shape):<19}|  {v.shape}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the following, we flatten and concatenate them across parameters, such that\n``evecs_flat[k,:]`` is the GGN eigenvector with eigenvalue ``evals[k]``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "evecs_flat = cat([e.flatten(start_dim=1) for e in evecs], dim=1)"
      ]
    },
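    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since the GGN is symmetric, its eigenvectors can be chosen to be orthonormal. Assuming\nViViT returns normalized eigenvectors (an assumption this cell checks, not a property\nstated above), the rows of ``evecs_flat`` should satisfy ``evecs_flat @ evecs_flat.T == I``\nup to numerical tolerance. The cell also prints the shape of ``evecs_flat``: one row per\nselected eigenpair, one column per parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch import eye\n\n# one row per selected eigenpair, one column per parameter\nprint(f\"evecs_flat shape: {tuple(evecs_flat.shape)}\")\n\n# pairwise inner products of the eigenvectors; orthonormal vectors give the identity\ngram = evecs_flat @ evecs_flat.T\nidentity = eye(gram.shape[0], device=gram.device, dtype=gram.dtype)\n\nprint(f\"Orthonormal eigenvectors: {allclose(gram, identity, atol=1e-4)}\")"
      ]
    },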
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verify results\nTo verify the above, let's compute the GGN matrix column by column, using GGN-vector\nproducts that rely only on :py:mod:`torch.autograd`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])"
      ]
    },
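    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an independent cross-check that does not use BackPACK's private helper\n``_autograd_ggn_exact_columns``, we can also assemble the GGN directly as\n``J.T @ H @ J``, where ``J`` is the Jacobian of the flattened model output with respect\nto the parameters and ``H`` is the Hessian of the loss with respect to the model output.\nFor ``MSELoss(reduction=\"mean\")``, ``H`` is ``2 / (N * D_out)`` times the identity, so the\nGGN simplifies to ``2 / (N * D_out) * J.T @ J``. This sketch builds ``J`` row by row with\n:py:func:`torch.autograd.grad`, which is only practical for tiny models like this one.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch import autograd\n\nparams = group[\"params\"]\noutput = model(X).flatten()  # shape [N * D_out]\n\n# row i of J holds the gradient of output[i] w.r.t. all parameters (flattened, concatenated)\nJ = stack(\n    [\n        cat([g.flatten() for g in autograd.grad(o, params, retain_graph=True)])\n        for o in output\n    ]\n)\n\n# Hessian of the mean-reduced MSE w.r.t. the output is 2 / (N * D_out) * identity\nggn_manual = 2.0 / output.numel() * J.T @ J\n\nprint(f\"GGN matches J^T H J: {allclose(ggn, ggn_manual, rtol=1e-4, atol=1e-6)}\")"
      ]
    },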
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can then check that applying the GGN to an eigenvector rescales it by the\ncorresponding eigenvalue, i.e. that ``ggn @ v`` equals ``e * v`` up to numerical\ntolerance.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for e, v in zip(evals, evecs_flat):\n    ggn_v = ggn @ v\n    close = allclose(e * v, ggn_v, rtol=1e-4, atol=1e-7)\n\n    print(f\"Eigenvalue {e:.5e}, Eigenvector properties: {close}\")\n    if not close:\n        raise ValueError(\"Eigenvector properties failed!\")\n\nprint(\"Eigenvector properties confirmed!\")"
      ]
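    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The check above confirms the eigen-property, but not that the returned eigenvalues are\nthe *leading* ones. Since we have the full GGN matrix, we can compare them against the\nlargest eigenvalues from :py:func:`torch.linalg.eigvalsh`, which returns all eigenvalues\nin ascending order. We sort ViViT's eigenvalues first in case their order differs.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.linalg import eigvalsh\n\n# all GGN eigenvalues, in ascending order\nggn_evals = eigvalsh(ggn)\n\n# compare the k largest with ViViT's (sorted) eigenvalues\nk = evals.numel()\nmatch = allclose(evals.sort().values, ggn_evals[-k:], rtol=1e-4, atol=1e-6)\n\nprint(f\"Leading eigenvalues match: {match}\")"
      ]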
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}