{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Computing directional derivatives along GGN eigenvectors\n\nIn this example we demonstrate how to use ViViT's\n:py:class:`DirectionalDerivativesComputation <vivit.DirectionalDerivativesComputation>`\nto obtain the 1\u02e2\u1d57- and 2\u207f\u1d48-order directional derivatives along the leading GGN\neigenvectors. We verify the result with :py:mod:`torch.autograd`.\n\nFirst, the imports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List\n\nfrom backpack import backpack, extend\nfrom backpack.utils.examples import _autograd_ggn_exact_columns\nfrom torch import Tensor, cuda, device, einsum, isclose, manual_seed, rand, stack, zeros\nfrom torch.autograd import grad\nfrom torch.nn import Linear, MSELoss, ReLU, Sequential\nfrom torch.nn.utils.convert_parameters import parameters_to_vector\n\nfrom vivit.optim.directional_derivatives import DirectionalDerivativesComputation\n\n# make deterministic\nmanual_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data, model & loss\nFor this demo, we use toy data and a small MLP with few enough\nparameters that we can store the full GGN matrix to verify our results\n(yes, one could use matrix-free GGN-vector products instead).\nWe use mean squared error as the loss function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "N = 4\nD_in = 7\nD_hidden = 5\nD_out = 3\n\nDEVICE = device(\"cuda\" if cuda.is_available() else \"cpu\")\n\nX = rand(N, D_in).to(DEVICE)\ny = rand(N, D_out).to(DEVICE)\n\nmodel = Sequential(\n    Linear(D_in, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_out),\n).to(DEVICE)\n\nloss_function = MSELoss(reduction=\"mean\").to(DEVICE)"
      ]
    },
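    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To put \"few enough parameters\" into numbers: a ``Linear(d_in, d_out)`` layer has\n``d_in * d_out`` weights and ``d_out`` biases (``ReLU`` layers have none), and the dense GGN\nhas one entry per pair of parameters. The standalone sketch below counts them with plain\narithmetic, hard-coding the layer sizes from the cell above. For this MLP it should report\n88 parameters, i.e. an 88 x 88 GGN with 7744 entries.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# (d_in, d_out) of the three Linear layers, hard-coded from the model above\nlayer_sizes = [(7, 5), (5, 5), (5, 3)]\n\n# each Linear layer: weight matrix (d_in * d_out) plus bias vector (d_out)\nnum_params = sum(d_in * d_out + d_out for d_in, d_out in layer_sizes)\nnum_ggn_entries = num_params**2\n\nprint(f\"parameters: {num_params}, GGN shape: {num_params} x {num_params}, entries: {num_ggn_entries}\")"
      ]
    },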
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Integrate BackPACK\nNext, :py:func:`extend <backpack.extend>` the model and loss function to be able\nto use BackPACK. Then, we perform a forward pass to compute the loss.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = extend(model)\nloss_function = extend(loss_function)\n\nloss = loss_function(model(X), y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Specify GGN approximation and directions\nBy default, :py:class:`vivit.DirectionalDerivativesComputation` uses the exact GGN.\nFurthermore, we need to specify the parameters that enter the GGN via a ``param_groups``\nargument, similar to the one you might know from :py:mod:`torch.optim`. Each group also\ncontains a ``criterion``: a filter function that selects the eigenvalues whose eigenvectors\nare used as directions along which the directional derivatives are evaluated. Here, we\nselect the four largest eigenvalues.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "computation = DirectionalDerivativesComputation()\n\n\ndef select_top_k(evals: Tensor, k=4) -> List[int]:\n    \"\"\"Select the top-k eigenvalues as directions to evaluate derivatives.\n\n    Args:\n        evals: Eigenvalues, sorted in ascending order.\n        k: Number of leading eigenvalues. Defaults to ``4``.\n\n    Returns:\n        Indices of top-k eigenvalues.\n    \"\"\"\n    return [evals.numel() - k + idx for idx in range(k)]\n\n\ngroup = {\n    \"params\": [p for p in model.parameters() if p.requires_grad],\n    \"criterion\": select_top_k,\n}\nparam_groups = [group]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward pass with BackPACK\nWe can now build the BackPACK extensions and the extension hook that will compute\nthe directional derivatives, pass them to a :py:class:`with backpack <backpack.backpack>`\ncontext, and perform the backward pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "extensions = computation.get_extensions()\nextension_hook = computation.get_extension_hook(param_groups)\n\nwith backpack(*extensions, extension_hook=extension_hook):\n    loss.backward()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will compute the directional derivatives for each\nparameter group and store them internally in the\n:py:class:`DirectionalDerivativesComputation<vivit.DirectionalDerivativesComputation>`\ninstance. We can use the parameter group to request them.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gammas_vivit, lambdas_vivit = computation.get_result(group)"
      ]
    },
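    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Judging from how they are compared in the verification below, both results are indexed by\nsample $n$ and selected direction $k$, i.e. they have shape ``[N, K]`` with ``K`` the number of\neigenvalues kept by ``select_top_k`` (here, 4). With the per-sample gradient $\\mathbf{g}_n$,\nthe per-sample GGN $\\mathbf{G}_n$, and the $k$-th selected GGN eigenvector $\\mathbf{e}_k$, the\nentries are\n\n$$\n\\gamma_{nk} = \\mathbf{g}_n^\\top \\mathbf{e}_k\\,, \\qquad\n\\lambda_{nk} = \\mathbf{e}_k^\\top \\mathbf{G}_n \\mathbf{e}_k\\,.\n$$\n\n"
      ]
    },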
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verify results\nTo verify the above, let's first compute the per-sample gradients and GGNs using\n:py:mod:`torch.autograd`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "batch_grad = []\nbatch_ggn = []\n\nfor n in range(N):\n    x_n, y_n = X[[n]], y[[n]]\n\n    grad_n = grad(\n        loss_function(model(x_n), y_n),\n        [p for p in model.parameters() if p.requires_grad],\n    )\n    batch_grad.append(parameters_to_vector(grad_n))\n\n    ggn_n = stack(\n        [col for _, col in _autograd_ggn_exact_columns(x_n, y_n, model, loss_function)]\n    )\n    batch_ggn.append(ggn_n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
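        "We also need the GGN eigenvectors to use as directions. We compute the full GGN with\n:py:mod:`torch.autograd`, eigendecompose it, and keep the eigenpairs selected by ``select_top_k``.\n\n"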
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
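        "from torch.linalg import eigh\n\nggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])\n\n# Tensor.symeig is deprecated in favor of torch.linalg.eigh, which also returns\n# the eigenvalues in ascending order, as assumed by select_top_k\nevals, evecs = eigh(ggn)\nkeep = select_top_k(evals)\nevals, evecs = evals[keep], evecs[:, keep]"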
      ]
    },
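    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "``select_top_k`` assumes that the eigenvalues arrive in ascending order, so that the last ``k``\nindices belong to the largest eigenvalues. ``torch.linalg.eigh`` returns eigenvalues in\nascending order; the standalone sketch below checks this on a made-up 3 x 3 symmetric matrix,\nre-implementing the index rule of ``select_top_k`` inline and comparing it against\n``torch.topk``. The ``order_*`` names are illustrative only.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n# made-up symmetric matrix standing in for the GGN\norder_mat = torch.tensor(\n    [[4.0, 1.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64\n)\norder_evals, order_evecs = torch.linalg.eigh(order_mat)\n\n# same index rule as select_top_k: the last k positions of an ascending order\norder_k = 2\norder_keep = [order_evals.numel() - order_k + idx for idx in range(order_k)]\n\n# indices of the k largest eigenvalues, found independently with torch.topk\norder_top = sorted(torch.topk(order_evals, order_k).indices.tolist())\n\nprint(order_evals)\nprint(order_keep, order_top)"
      ]
    },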
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We are now ready to compute and compare the target quantities.\n\nFirst, compute and compare the first-order directional derivatives. GGN eigenvectors are\nonly unique up to sign: if $\\mathbf{e}_k$ is an eigenvector, so is $-\\mathbf{e}_k$. The\nfirst-order directional derivatives from ViViT and from our reference can thus differ in sign,\nso we only compare their absolute values.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "K = evals.numel()\ngammas_torch = zeros(N, K, device=evals.device, dtype=evals.dtype)\n\nfor n in range(N):\n    grad_n = batch_grad[n]\n    for k in range(K):\n        e_k = evecs[:, k]\n\n        gammas_torch[n, k] = einsum(\"i,i\", grad_n, e_k)\n\nfor gamma_vivit, gamma_torch in zip(gammas_vivit.flatten(), gammas_torch.flatten()):\n    close = isclose(abs(gamma_vivit), abs(gamma_torch), rtol=1e-4, atol=1e-7)\n    print(f\"{gamma_vivit:.5e} vs. {gamma_torch:.5e}, close: {close}\")\n    if not close:\n        raise ValueError(\"1\u02e2\u1d57-order directional derivatives don't match!\")\n\nprint(\"1\u02e2\u1d57-order directional derivatives match!\")"
      ]
    },
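    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A standalone toy check of the sign argument: flipping an eigenvector flips the first-order\ndirectional derivative, while the second-order one, being quadratic in the direction, stays\nthe same. The matrix and gradient in this sketch are made-up values, not quantities from the\nmodel.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n# made-up symmetric PSD matrix and gradient (toy stand-ins, not the model's quantities)\nflip_ggn = torch.tensor([[2.0, 1.0], [1.0, 3.0]], dtype=torch.float64)\nflip_grad = torch.tensor([0.5, -1.0], dtype=torch.float64)\n\nflip_evals, flip_evecs = torch.linalg.eigh(flip_ggn)\nflip_e = flip_evecs[:, -1]  # leading eigenvector; -flip_e is an equally valid choice\n\nflip_results = []\nfor sign in (1.0, -1.0):\n    direction = sign * flip_e\n    gamma = (flip_grad @ direction).item()  # first order: linear in the direction\n    lam = (direction @ flip_ggn @ direction).item()  # second order: quadratic in it\n    flip_results.append((gamma, lam))\n    print(f\"sign {sign:+.0f}: gamma = {gamma:+.5f}, lambda = {lam:.5f}\")"
      ]
    },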
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, compute and compare the second-order directional derivatives.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "lambdas_torch = zeros(N, K, device=evals.device, dtype=evals.dtype)\n\nfor n in range(N):\n    ggn_n = batch_ggn[n]\n    for k in range(K):\n        e_k = evecs[:, k]\n\n        lambdas_torch[n, k] = einsum(\"i,ij,j\", e_k, ggn_n, e_k)\n\nfor lambda_vivit, lambda_torch in zip(lambdas_vivit.flatten(), lambdas_torch.flatten()):\n    close = isclose(lambda_vivit, lambda_torch, rtol=1e-4, atol=1e-7)\n    print(f\"{lambda_vivit:.5e} vs. {lambda_torch:.5e}, close: {close}\")\n    if not close:\n        raise ValueError(\"2\u207f\u1d48-order directional derivatives don't match!\")\n\nprint(\"2\u207f\u1d48-order directional derivatives match!\")"
      ]
    },
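    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The final check relies on the GGN of the mean loss being the average of the per-sample GGNs\n$\\mathbf{G}_n$. For a normalized eigenvector $\\mathbf{e}_k$ with eigenvalue $\\lambda_k$, this gives\n$\\frac{1}{N}\\sum_n \\mathbf{e}_k^\\top \\mathbf{G}_n \\mathbf{e}_k = \\mathbf{e}_k^\\top \\mathbf{G} \\mathbf{e}_k = \\lambda_k$.\nThe standalone sketch below illustrates this identity with random positive semi-definite\nmatrices (made-up data, not the model's GGNs).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n# random per-sample PSD matrices G_n = A_n A_n^T (made-up data, N=4, D=3)\navg_gen = torch.Generator().manual_seed(0)\navg_factors = torch.rand(4, 3, 3, generator=avg_gen, dtype=torch.float64)\navg_per_sample = avg_factors @ avg_factors.transpose(1, 2)  # shape [N, D, D]\navg_mean = avg_per_sample.mean(0)  # batch matrix = average of per-sample matrices\n\navg_evals, avg_evecs = torch.linalg.eigh(avg_mean)\n\n# lambda_nk = e_k^T G_n e_k for every sample n and eigenvector k, shape [N, D]\navg_lambdas = torch.einsum(\"ik,nij,jk->nk\", avg_evecs, avg_per_sample, avg_evecs)\n\n# the sample mean of lambda_nk recovers the eigenvalues of the averaged matrix\nprint(avg_lambdas.mean(0))\nprint(avg_evals)"
      ]
    },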
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Last, we check that the sample means of second-order derivatives coincide with\nthe eigenvalues.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for eval_vivit, eval_torch in zip(lambdas_vivit.mean(0), evals):\n    close = isclose(eval_vivit, eval_torch, rtol=1e-4, atol=1e-7)\n    print(f\"{eval_vivit:.5e} vs. {eval_torch:.5e}, close: {close}\")\n    if not close:\n        raise ValueError(\"Averaged 2\u207f\u1d48-order directional derivatives don't match eigenvalues!\")\n\nprint(\"Averaged 2\u207f\u1d48-order directional derivatives match eigenvalues!\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}