mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-14 09:31:42 +08:00
251 lines
5.4 KiB
Plaintext
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Counterfactual Regret Minimization (CFR) on Kuhn Poker",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.5"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AYV_dMVDxyc2"
      },
      "source": [
        "[GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations)\n",
        "[Open in Colab](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/cfr/kuhn/experiment.ipynb) \n",
        "\n",
        "## [Counterfactual Regret Minimization (CFR)](https://nn.labml.ai/cfr/index.html) on Kuhn Poker\n",
        "\n",
        "This experiment learns to play Kuhn Poker with the Counterfactual Regret Minimization (CFR) algorithm."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AahG_i2y5tY9"
      },
      "source": [
        "Install the `labml-nn` package"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZCzmCrAIVg0L"
      },
      "source": [
        "%%capture\n",
        "!pip install labml-nn"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SE2VUQ6L5zxI"
      },
      "source": [
        "Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0hJXx_g0wS2C"
      },
      "source": [
        "from labml import experiment, analytics\n",
        "from labml_nn.cfr.analytics import plot_infosets\n",
        "from labml_nn.cfr.kuhn import Configs\n",
        "from labml_nn.cfr.infoset_saver import InfoSetSaver"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Lpggo0wM6qb-"
      },
      "source": [
        "Create an experiment. We write tracking information only to `sqlite` to speed things up.\n",
        "Since the algorithm iterates fast and we track data on each iteration, writing to\n",
        "other destinations such as TensorBoard can be relatively time-consuming.\n",
        "SQLite is enough for our analytics."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bFcr9k-l4cAg"
      },
      "source": [
        "experiment.create(name='kuhn_poker', writers={'sqlite'})"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-OnHLi626tJt"
      },
      "source": [
        "Initialize configurations"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Piz0c5f44hRo"
      },
      "source": [
        "conf = Configs()"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wwMzCqpD6vkL"
      },
      "source": [
        "Set the experiment configurations and pass a dictionary of values to override the defaults"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "e6hmQhTw4nks",
        "outputId": "e20b5ea3-605b-4bcc-c9de-0da93b6c7f32"
      },
      "source": [
        "experiment.configs(conf, {'epochs': 1_000_000})"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KJZRf8527GxL"
      },
      "source": [
        "Start the experiment and run the training loop."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 187
        },
        "id": "aIAWo7Fw5DR8",
        "outputId": "18cd2384-d6c0-42a3-feae-5a309563bb33"
      },
      "source": [
        "# Start the experiment\n",
        "with experiment.start():\n",
        "    conf.cfr.iterate()"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oBXXlP2b7XZO"
      },
      "source": [
        "inds = analytics.runs(experiment.get_uuid())"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RJ0L4XH2Y8g4"
      },
      "source": [
        "# dir(inds)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "htumVaMnY8g4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 568
        },
        "outputId": "735a944d-1b96-49e8-97b6-64317ea515b1"
      },
      "source": [
        "plot_infosets(inds['average_strategy.*'], width=600, height=500).display()"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yTDu8JKdY8g4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 690
        },
        "outputId": "6047dae2-095e-4323-ee91-f49573ad509c"
      },
      "source": [
        "analytics.scatter(inds.average_strategy_Q_b, inds.average_strategy_Kb_b,\n",
        "                  width=400, height=400)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GnI67bbLY8g5"
      },
      "source": [
        ""
      ],
      "outputs": [],
      "execution_count": null
    }
  ]
}