{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "Counterfactual Regret Minimization (CFR) on Kuhn Poker", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "AYV_dMVDxyc2" }, "source": [ "[![Github](https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social)](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/cfr/kuhn/experiment.ipynb) \n", "\n", "## [Counterfactual Regret Minimization (CFR)](https://nn.labml.ai/cfr/index.html) on Kuhn Poker\n", "\n", "This experiment learns to play Kuhn Poker with the Counterfactual Regret Minimization (CFR) algorithm."
] }, { "cell_type": "markdown", "metadata": { "id": "AahG_i2y5tY9" }, "source": [ "Install the `labml-nn` package" ] }, { "cell_type": "code", "metadata": { "id": "ZCzmCrAIVg0L" }, "source": [ "%%capture\n", "!pip install labml-nn" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": { "id": "SE2VUQ6L5zxI" }, "source": [ "Imports" ] }, { "cell_type": "code", "metadata": { "id": "0hJXx_g0wS2C" }, "source": [ "from labml import experiment, analytics\n", "from labml_nn.cfr.analytics import plot_infosets\n", "from labml_nn.cfr.kuhn import Configs\n", "from labml_nn.cfr.infoset_saver import InfoSetSaver" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": { "id": "Lpggo0wM6qb-" }, "source": [ "Create an experiment. We write tracking information only to `sqlite` to speed things up.\n", "The algorithm iterates fast and we track data on every iteration, so writing to\n", "other destinations such as TensorBoard can be relatively time-consuming.\n", "SQLite is enough for our analytics."
] }, { "cell_type": "code", "metadata": { "id": "bFcr9k-l4cAg" }, "source": [ "experiment.create(name='kuhn_poker', writers={'sqlite'})" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": { "id": "-OnHLi626tJt" }, "source": [ "Initialize configurations" ] }, { "cell_type": "code", "metadata": { "id": "Piz0c5f44hRo" }, "source": [ "conf = Configs()" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": { "id": "wwMzCqpD6vkL" }, "source": [ "Set the experiment configurations and pass a dictionary of values to override the defaults" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 17 }, "id": "e6hmQhTw4nks", "outputId": "e20b5ea3-605b-4bcc-c9de-0da93b6c7f32" }, "source": [ "experiment.configs(conf, {'epochs': 1_000_000})" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": { "id": "KJZRf8527GxL" }, "source": [ "Start the experiment and run the training loop."
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 187 }, "id": "aIAWo7Fw5DR8", "outputId": "18cd2384-d6c0-42a3-feae-5a309563bb33" }, "source": [ "# Start the experiment\n", "with experiment.start():\n", " conf.cfr.iterate()" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "id": "oBXXlP2b7XZO" }, "source": [ "inds = analytics.runs(experiment.get_uuid())" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "id": "RJ0L4XH2Y8g4" }, "source": [ "# dir(inds)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "id": "htumVaMnY8g4", "colab": { "base_uri": "https://localhost:8080/", "height": 568 }, "outputId": "735a944d-1b96-49e8-97b6-64317ea515b1" }, "source": [ "plot_infosets(inds['average_strategy.*'], width=600, height=500).display()" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "id": "yTDu8JKdY8g4", "colab": { "base_uri": "https://localhost:8080/", "height": 690 }, "outputId": "6047dae2-095e-4323-ee91-f49573ad509c" }, "source": [ "analytics.scatter(inds.average_strategy_Q_b, inds.average_strategy_Kb_b,\n", " width=400, height=400)" ], "outputs": [], "execution_count": null }, { "cell_type": "code", "metadata": { "id": "GnI67bbLY8g5" }, "source": [ "" ], "outputs": [], "execution_count": null } ] }