{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Computing directionally damped Newton steps\n\nIn this example we demonstrate how to use ViViT's\n:py:class:`DirectionalDampedNewtonComputation\n<vivit.DirectionalDampedNewtonComputation>` to compute directionally damped\nNewton steps with the generalized Gauss-Newton (GGN) matrix. We verify the\nresult with :py:mod:`torch.autograd`.\n\nFirst, the imports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from typing import List\n\nfrom backpack import backpack, extend\nfrom backpack.utils.examples import _autograd_ggn_exact_columns\nfrom torch import (\n    Tensor,\n    allclose,\n    cuda,\n    device,\n    einsum,\n    manual_seed,\n    ones_like,\n    rand,\n    stack,\n    zeros_like,\n)\nfrom torch.autograd import grad\nfrom torch.nn import Linear, MSELoss, ReLU, Sequential\nfrom torch.nn.utils.convert_parameters import parameters_to_vector\n\nfrom vivit.optim.directional_damped_newton import DirectionalDampedNewtonComputation\n\n# make deterministic\nmanual_seed(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data, model & loss\n\nFor this demo, we use toy data and a small MLP with few enough parameters\nthat we can store the full GGN matrix to verify our results. We use mean\nsquared error as the loss function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "N = 4\nD_in = 7\nD_hidden = 5\nD_out = 3\n\nDEVICE = device(\"cuda\" if cuda.is_available() else \"cpu\")\n\nX = rand(N, D_in).to(DEVICE)\ny = rand(N, D_out).to(DEVICE)\n\nmodel = Sequential(\n    Linear(D_in, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_hidden),\n    ReLU(),\n    Linear(D_hidden, D_out),\n).to(DEVICE)\n\nloss_function = MSELoss(reduction=\"mean\").to(DEVICE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Integrate BackPACK\n\nNext, we :py:func:`extend <backpack.extend>` the model and loss function so\nthat they can be used with BackPACK. Then, we perform a forward pass to compute\nthe loss.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = extend(model)\nloss_function = extend(loss_function)\n\nloss = loss_function(model(X), y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Specify GGN approximation and directions\n\nBy default, :py:class:`vivit.DirectionalDampedNewtonComputation` uses the\nexact GGN. Furthermore, we need to specify the parameters that the GGN is\ncomputed for via a ``param_groups`` argument that might be familiar to you from\n:py:mod:`torch.optim`. Each group also contains a filter function that selects\nthe eigenvalues whose eigenvectors will be used as directions for the Newton\nstep.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "computation = DirectionalDampedNewtonComputation()\n\n\ndef select_top_k(evals: Tensor, k: int = 4) -> List[int]:\n    \"\"\"Select the top-k eigenvalues whose eigenvectors span the Newton step.\n\n    Args:\n        evals: Eigenvalues, sorted in ascending order.\n        k: Number of leading eigenvalues. Defaults to ``4``.\n\n    Returns:\n        Indices of top-k eigenvalues.\n    \"\"\"\n    return [evals.numel() - k + idx for idx in range(k)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Specify directional damping\n\nWe also need a damping function that provides the damping value for each\ndirection. This function receives the GGN's eigenvalues, the Gram matrix\neigenvectors, as well as the first- and second-order directional derivatives.\nIt returns a one-dimensional tensor that contains the damping values for all\ndirections.\n\nThis interface may seem overly complicated, but it allows incorporating\ninformation about gradient and curvature noise into the damping value.\n\nFor simplicity, we will use a constant damping of 1 for all\ndirections.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "DAMPING = 1.0\n\n\ndef constant_damping(\n    evals: Tensor, evecs: Tensor, gammas: Tensor, lambdas: Tensor\n) -> Tensor:\n    \"\"\"Constant damping along all directions.\n\n    Args:\n        evals: GGN eigenvalues. Shape ``[K]``.\n        evecs: GGN Gram matrix eigenvectors. Shape ``[NC, K]``.\n        gammas: Directional gradients. Shape ``[N, K]``.\n        lambdas: Directional curvatures. Shape ``[N, K]``.\n\n    Returns:\n        Directional dampings. Shape ``[K]``.\n    \"\"\"\n    return DAMPING * ones_like(evals)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's put everything together and set up the parameter groups.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "group = {\n    \"params\": [p for p in model.parameters() if p.requires_grad],\n    \"criterion\": select_top_k,\n    \"damping\": constant_damping,\n}\nparam_groups = [group]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Backward pass with BackPACK\n\nWe can now build the BackPACK extensions and extension hook that will compute\nthe damped Newton step, pass them to a :py:class:`with backpack\n<backpack.backpack>` context, and perform the backward pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "extensions = computation.get_extensions()\nextension_hook = computation.get_extension_hook(param_groups)\n\nwith backpack(*extensions, extension_hook=extension_hook):\n    loss.backward()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This will compute the damped Newton step for each parameter group and store\nit internally in the :py:class:`vivit.DirectionalDampedNewtonComputation`\ninstance. We can use the parameter group to request it.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "newton_step = computation.get_result(group)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The result has the same format as the ``group['params']`` entry, with one\ntensor per parameter:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for param, newton in zip(group[\"params\"], newton_step):\n    print(f\"Parameter shape:   {param.shape}\\nNewton step shape: {newton.shape}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We will flatten and concatenate the Newton step over parameters to simplify\nthe comparison with :py:mod:`torch.autograd`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "newton_step_flat = parameters_to_vector(newton_step)\nprint(newton_step_flat.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Verify results\n\nLet's compute the damped Newton step with :py:mod:`torch.autograd` and verify\nit leads to the same result.\n\nWe need the gradient and the GGN.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gradient = grad(\n    loss_function(model(X), y), [p for p in model.parameters() if p.requires_grad]\n)\ngradient = parameters_to_vector(gradient)\n\nggn = stack([col for _, col in _autograd_ggn_exact_columns(X, y, model, loss_function)])\n\nprint(gradient.shape, ggn.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, eigen-decompose the GGN and filter the relevant eigenpairs:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# eigh replaces the removed Tensor.symeig; eigenvalues come sorted ascending\nfrom torch.linalg import eigh\n\nevals, evecs = eigh(ggn)\nkeep = select_top_k(evals)\nevals, evecs = evals[keep], evecs[:, keep]"
      ]
    },
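    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This is sufficient to form the damped Newton step\n\n\\begin{align}s = \\sum_{k=1}^K \\frac{-\\gamma_k}{\\lambda_k + \\delta} e_k\\end{align}\n\nwith constant damping $\\delta = 1$. Here, $e_k$ and $\\lambda_k$ are the\nselected GGN eigenvectors and eigenvalues, and $\\gamma_k = e_k^\\top g$ is the\nslope of the loss along $e_k$, with $g$ the gradient.\n\n"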
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "newton_step_torch = zeros_like(gradient)\n\nK = evals.numel()\n\nfor k in range(K):\n    evec = evecs[:, k]\n    gamm = einsum(\"i,i\", gradient, evec)\n    lamb = evals[k]\n\n    newton = (-gamm / (lamb + DAMPING)) * evec\n    newton_step_torch += newton\n\nprint(newton_step_torch.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both damped Newton steps should be identical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "close = allclose(newton_step_flat, newton_step_torch, rtol=1e-5, atol=1e-7)\nif not close:\n    raise ValueError(\"Directionally damped Newton steps don't match!\")\n\nprint(\"Directionally damped Newton steps match!\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}