Files
Varuna Jayasiri c4d2e8cd22 docs
2025-07-31 08:48:07 +05:30

315 lines
75 KiB
HTML
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<!DOCTYPE html>
<html lang="en">
<head>
<!-- Document encoding and responsive viewport. -->
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A simple PyTorch implementation/tutorial of Generative Adversarial Networks (GAN) loss functions."/>
<!-- Twitter card metadata. -->
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Generative Adversarial Networks (GAN)"/>
<meta name="twitter:description" content="A simple PyTorch implementation/tutorial of Generative Adversarial Networks (GAN) loss functions."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<!-- Open Graph metadata; each property is declared exactly once. -->
<meta property="og:url" content="https://nn.labml.ai/gan/original/index.html"/>
<meta property="og:title" content="Generative Adversarial Networks (GAN)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="Generative Adversarial Networks (GAN)"/>
<meta property="og:type" content="object"/>
<meta property="og:description" content="A simple PyTorch implementation/tutorial of Generative Adversarial Networks (GAN) loss functions."/>
<title>Generative Adversarial Networks (GAN)</title>
<!-- Icon, page stylesheet, canonical URL, and the KaTeX stylesheet (pinned version, verified with SRI). -->
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css?v=1"/>
<link rel="canonical" href="https://nn.labml.ai/gan/original/index.html"/>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
// Reuse the page-level queue if it already exists; otherwise start with an empty array.
window.dataLayer = window.dataLayer || [];
// Each gtag(...) call appends its `arguments` object (not a copied array) to the queue.
function gtag() {
dataLayer.push(arguments);
}
// Queue a 'js' entry carrying the current time, then a 'config' entry for this measurement ID.
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<!-- Breadcrumb trail: site root, the gan index, then this page (marked as the current page). -->
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">gan</a>
<a class="parent" href="index.html" aria-current="page">original</a>
</p>
<!-- Badge links open in a new tab; rel="noopener" keeps the opened page from reaching window.opener.
     Each image's alt text names the link destination, since the image is the link's only content. -->
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" rel="noopener" target="_blank">
<img alt="labmlai/annotated_deep_learning_paper_implementations on GitHub"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai" rel="nofollow noopener" target="_blank">
<img alt="@labmlai on Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
<!-- Link to the Python source this page documents. -->
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/original/__init__.py" rel="noopener" target="_blank">
View code on GitHub</a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Generative Adversarial Networks (GAN)</h1>
<p>This is an implementation of <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a>.</p>
<p>The generator, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mopen">(</span><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">g</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> generates samples that match the distribution of data, while the discriminator, <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.036108em;vertical-align:-0.286108em;"></span><span class="mord coloredeq equ" style=""><span 
class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span><span class="mpunct">;</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:-0.02778em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> gives the probability that <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span 
class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span> came from data rather than <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span></span></span></span></span>.</p>
<p>We train <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span></span></span></span></span> simultaneously on a two-player min-max game with value function <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mopen">(</span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mclose">)</span></span></span></span></span>.</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.494331em;vertical-align:-0.7443310000000001em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.66786em;"><span style="top:-2.355669em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqv" style=""><span class="mord mathnormal mtight" style="">G</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">min</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.744331em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.43055999999999983em;"><span style="top:-2.355669em;margin-left:0em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq equ" style=""><span class="mord mathnormal mtight" style="margin-right:0.02778em">D</span></span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span><span class="mop">max</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.7443310000000001em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.22222em;">V</span><span class="mopen">(</span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1.2051999999999998em;vertical-align:-0.3551999999999999em;"></span><span class="mop"><span class="mop mathbb" style="position:relative;top:0.094445em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.34480000000000005em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqn" style=""><span class="mord mtight" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.07169642857142856em"></span><span class="mord mathnormal mtight" style="">x</span></span></span><span class="mrel mtight">∼</span><span class="mord mtight coloredeq eqi" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqk" style=""><span class="mord mathnormal mtight" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" 
style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">a</span><span class="mord mathnormal mtight" style="">t</span><span class="mord mathnormal mtight" style="">a</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span><span class="mopen mtight" style="">(</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqn" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.07169642857142856em"></span><span class="mord mathnormal mtight" style="">x</span></span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.3551999999999999em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size1">[</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mopen">(</span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span 
class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size1">]</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.2051999999999998em;vertical-align:-0.3551999999999999em;"></span><span class="mop"><span class="mop mathbb" style="position:relative;top:0.094445em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.34480000000000005em;"><span style="top:-2.5198em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqo" style=""><span class="mord mtight" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.07169642857142856em"></span><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="mrel mtight">∼</span><span class="mord mtight coloredeq eqg" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">p</span><span class="msupsub"><span class="vlist-t 
vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16454285714285719em;"><span style="top:-2.357em;margin-left:0em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqo" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.10037499999999999em"></span><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.143em;"><span></span></span></span></span></span></span><span class="mopen mtight" style="">(</span><span class="mord mtight" style=""><span class="mord mtight coloredeq eqo" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.07169642857142856em"></span><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="mclose mtight" style="">)</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.3551999999999999em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size1">[</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.20001em;vertical-align:-0.35001em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mopen">(</span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mopen">(</span><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="mclose">))</span><span class="mord"><span class="delimsizing size1">]</span></span></span></span></span></span></span></p>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqi" style=""><span class="mord" style=""><span class="mord coloredeq eqk" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">a</span><span class="mord mathnormal mtight" style="">t</span><span class="mord mathnormal mtight" style="">a</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqn" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> is the probability distribution over data, whilst <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqg" style=""><span class="mord" style=""><span 
class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mtight coloredeq eqo" style=""><span class="mord vbox mtight" style=""><span class="thinbox mtight" style=""><span class="rlap mtight" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace mtight" style="margin-right:0.07169642857142856em"></span><span class="mord mathnormal mtight" style="margin-right:0.04398em">z</span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mopen" style="">(</span><span class="mord" style=""><span class="mord coloredeq eqo" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="mclose" style="">)</span></span></span></span></span></span> is the probability distribution of <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" 
style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span></span>, which is set to Gaussian noise.</p>
<p>This file defines the loss functions. <a href="experiment.html">Here</a> is an MNIST example with two multilayer perceptrons for the generator and discriminator.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">34</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">35</span><span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="lineno">36</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">37</span><span class="kn">import</span> <span class="nn">torch.utils.data</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Discriminator Loss</h2>
<p>Discriminator should <strong>ascend</strong> on the gradient,</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.027669em;vertical-align:-1.277669em;"></span><span class="mord"><span class="mord">∇</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999985em;"><span style="top:-2.55em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3448em;"><span style="top:-2.3487714285714287em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight">d</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15122857142857138em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.25586em;"><span></span></span></span></span></span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqw" style=""><span class="mord mathnormal" style="">m</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" 
style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6513970000000002em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqt" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3000050000000005em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqw" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size4">[</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mord"><span 
class="delimsizing size2">(</span></span><span class="mord"><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.938em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqq" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.80002em;vertical-align:-0.65002em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" 
style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord"><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.938em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqq" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord"><span class="delimsizing size4">]</span></span></span></span></span></span></span></p>
<p><span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqw" style=""><span class="mord mathnormal" style="">m</span></span></span></span></span></span> is the mini-batch size and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqq" style=""><span class="mopen" style="">(</span><span class="mord mathnormal" style="">i</span><span class="mclose" style="">)</span></span></span></span></span></span> is used to index samples in the mini-batch. <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span> are samples from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" 
style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">a</span><span class="mord mathnormal mtight" style="">t</span><span class="mord mathnormal mtight" style="">a</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span></span></span></span></span> are samples from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right:0.04398em;">z</span></span></span></span><span class="vlist-s"></span></span><span 
class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span>.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">40</span><span class="k">class</span> <span class="nc">DiscriminatorLogitsLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">55</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">):</span>
<span class="lineno">56</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>We use PyTorch Binary Cross Entropy Loss, which is <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.80002em;vertical-align:-0.65002em;"></span><span class="mord">−</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-symbol small-op" style="position:relative;top:-0.0000050000000000050004em;">∑</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size2">[</span></span><span class="mord coloredeq eqx" style=""><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord coloredeq eqm" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqx" style="margin-right:0.03588em">y</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mopen">(</span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord coloredeq eqx" style=""><span class="mord mathnormal" style="margin-right:0.03588em">y</span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:1.80002em;vertical-align:-0.65002em;"></span><span class="mord coloredeq eqm" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqx" style="margin-right:0.03588em">y</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord"><span class="delimsizing size2">]</span></span></span></span></span></span>, where <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqx" style=""><span class="mord mathnormal" 
style="margin-right:0.03588em">y</span></span></span></span></span></span> are the labels and <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqm" style=""><span class="mord accent" style=""><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.69444em;"><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="mord" style=""><span class="mord mathnormal coloredeq eqx" style="margin-right:0.03588em">y</span></span></span><span style="top:-3em;"><span class="pstrut" style="height:3em;"></span><span class="accent-body" style="left:-0.19444em;"><span class="mord" style="">^</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.19444em;"><span></span></span></span></span></span></span></span></span></span></span> are the predictions. <em>Note the negative sign</em>. 
We use labels equal to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span></span></span></span></span> for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span> from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqk" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight" style="">d</span><span class="mord mathnormal mtight" style="">a</span><span class="mord mathnormal mtight" style="">t</span><span class="mord mathnormal mtight" style="">a</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" 
style="height:0.15em;"><span></span></span></span></span></span></span></span></span></span></span></span> and labels equal to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">0</span></span></span></span></span> for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span> from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqv" style="">G</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span 
class="mord" style="">.</span></span></span></span></span></span> Then descending on the sum of these is the same as ascending on the above gradient.</p>
<p><code class="highlight"><span></span><span class="n">BCEWithLogitsLoss</span></code>
combines sigmoid and binary cross entropy loss. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">67</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_true</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">()</span>
<span class="lineno">68</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_false</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>We use label smoothing because it seems to work better in some cases.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span> <span class="o">=</span> <span class="n">smoothing</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Labels are registered as buffers and persistence is set to <code class="highlight"><span></span><span class="kc">False</span></code>, so they are not saved in the module's state dictionary.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">74</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;labels_true&#39;</span><span class="p">,</span> <span class="n">_create_labels</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">smoothing</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">75</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;labels_false&#39;</span><span class="p">,</span> <span class="n">_create_labels</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">smoothing</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p> <code class="highlight"><span></span><span class="n">logits_true</span></code>
are logits from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.138em;vertical-align:-0.25em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqq" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> and <code class="highlight"><span></span><span class="n">logits_false</span></code>
are logits from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:1.138em;vertical-align:-0.25em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mopen">(</span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mopen">(</span><span class="mord"><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.8879999999999999em;"><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqq" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mclose">))</span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">77</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits_true</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">logits_false</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">82</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">logits_true</span><span class="p">)</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_true</span><span class="p">):</span>
<span class="lineno">83</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s2">&quot;labels_true&quot;</span><span class="p">,</span>
<span class="lineno">84</span> <span class="n">_create_labels</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">logits_true</span><span class="p">),</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">logits_true</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">85</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">logits_false</span><span class="p">)</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels_false</span><span class="p">):</span>
<span class="lineno">86</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s2">&quot;labels_false&quot;</span><span class="p">,</span>
<span class="lineno">87</span> <span class="n">_create_labels</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">logits_false</span><span class="p">),</span> <span class="mf">0.0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span><span class="p">,</span> <span class="n">logits_false</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">88</span>
<span class="lineno">89</span> <span class="k">return</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_true</span><span class="p">(</span><span class="n">logits_true</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_true</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">logits_true</span><span class="p">)]),</span>
<span class="lineno">90</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_false</span><span class="p">(</span><span class="n">logits_false</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels_false</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">logits_false</span><span class="p">)]))</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<h2>Generator Loss</h2>
<p>Generator should <strong>descend</strong> on the gradient,</p>
<p><span ><span class="katex-display"><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:3.027669em;vertical-align:-1.277669em;"></span><span class="mord"><span class="mord">∇</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.33610799999999996em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.02778em;">θ</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.16454285714285716em;"><span style="top:-2.357em;margin-left:-0.02778em;margin-right:0.07142857142857144em;"><span class="pstrut" style="height:2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">g</span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.2818857142857143em;"><span></span></span></span></span></span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.34731999999999996em;"><span></span></span></span></span></span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqw" style=""><span class="mord mathnormal" style="">m</span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span 
style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6513970000000002em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight coloredeq eqt" style=""><span class="mord mtight" style="">1</span></span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3000050000000005em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight coloredeq eqw" style=""><span class="mord mathnormal mtight" style="">m</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size4">[</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord coloredeq eqt" 
style=""><span class="mord" style="">1</span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:3.0000299999999998em;vertical-align:-1.25003em;"></span><span class="mord coloredeq equ" style=""><span class="mord mathnormal" style="margin-right:0.02778em">D</span></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord coloredeq eqv" style=""><span class="mord mathnormal" style="">G</span></span><span class="mord"><span class="delimsizing size2">(</span></span><span class="mord"><span class="mord coloredeq eqo" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="margin-right:0.04398em">z</span></span></span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height:0.938em;"><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight coloredeq eqq" style=""><span class="mopen mtight" style="">(</span><span class="mord mathnormal mtight" style="">i</span><span class="mclose mtight" style="">)</span></span></span></span></span></span></span></span></span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord"><span class="delimsizing size2">)</span></span><span class="mord"><span class="delimsizing 
size2">)</span></span><span class="mord"><span class="delimsizing size4">]</span></span></span></span></span></span></span></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">93</span><span class="k">class</span> <span class="nc">GeneratorLogitsLoss</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">104</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.2</span><span class="p">):</span>
<span class="lineno">105</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">106</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_true</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BCEWithLogitsLoss</span><span class="p">()</span>
<span class="lineno">107</span> <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span> <span class="o">=</span> <span class="n">smoothing</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>We use labels equal to <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord coloredeq eqt" style=""><span class="mord" style="">1</span></span></span></span></span></span> for <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="mord coloredeq eqn" style=""><span class="mord" style=""><span class="mord vbox" style=""><span class="thinbox" style=""><span class="rlap" style=""><span class="strut" style="height:0.43056em;vertical-align:0em;"></span><span class="inner"><span class="mord" style=""><span class="mord mathnormal" style="">x</span></span></span><span class="fix"></span></span></span></span><span class="mspace" style="margin-right:0.050187499999999996em"></span><span class="mord mathnormal" style="">x</span></span></span></span></span></span></span> from <span ><span class="katex"><span aria-hidden="true" class="katex-html"><span class="base"><span class="strut" style="height:0.625em;vertical-align:-0.19444em;"></span><span class="mord coloredeq eqp" style=""><span class="mord" style=""><span class="mord mathnormal" style="">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.32833099999999993em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight" style=""><span class="mord mtight" style=""><span class="mord mtight" style=""><span class="mord mathnormal mtight coloredeq eqv" style="">G</span></span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mord" 
style="">.</span></span></span></span></span></span> Then descending on this loss is the same as descending on the above gradient. </p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">111</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s1">&#39;fake_labels&#39;</span><span class="p">,</span> <span class="n">_create_labels</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">smoothing</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">113</span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logits</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span>
<span class="lineno">114</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">logits</span><span class="p">)</span> <span class="o">&gt;</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">fake_labels</span><span class="p">):</span>
<span class="lineno">115</span> <span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s2">&quot;fake_labels&quot;</span><span class="p">,</span>
<span class="lineno">116</span> <span class="n">_create_labels</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">logits</span><span class="p">),</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">logits</span><span class="o">.</span><span class="n">device</span><span class="p">),</span> <span class="kc">False</span><span class="p">)</span>
<span class="lineno">117</span>
<span class="lineno">118</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_true</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">fake_labels</span><span class="p">[:</span><span class="nb">len</span><span class="p">(</span><span class="n">logits</span><span class="p">)])</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p> Create smoothed labels</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span><span class="k">def</span> <span class="nf">_create_labels</span><span class="p">(</span><span class="n">n</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">r1</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">r2</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">125</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="n">r1</span><span class="p">,</span> <span class="n">r2</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
/**
 * Enhance every image that is a direct child of a paragraph.
 *
 * Looks up all `p>img` elements in the document and hands each one, in
 * document order, to `handleImage`.
 */
function handleImages() {
    var paragraphImages = document.querySelectorAll('p>img')
    // Borrow Array's forEach so this works on any array-like NodeList.
    Array.prototype.forEach.call(paragraphImages, function (image) {
        handleImage(image)
    })
}
/**
 * Turn an inline image into a click-to-enlarge thumbnail.
 *
 * Centres the image inside its parent element and builds a detached
 * `#modal` overlay made of a content wrapper, an enlarged copy of the
 * image, and a close control. Clicking the image attaches the overlay to
 * `document.body`; clicking the close control detaches it again.
 *
 * @param {HTMLImageElement} img image element to enhance
 */
function handleImage(img) {
    img.parentElement.style.textAlign = 'center'

    // Overlay container; it stays detached until the image is clicked.
    var modal = document.createElement('div')
    modal.id = 'modal'

    var modalContent = document.createElement('div')
    modal.appendChild(modalContent)

    var modalImage = document.createElement('img')
    modalContent.appendChild(modalImage)

    // Close control rendered as a literal "x" with the `close` class.
    var span = document.createElement('span')
    span.classList.add('close')
    span.textContent = 'x'
    modal.appendChild(span)

    img.onclick = function () {
        document.body.appendChild(modal)
        modalImage.src = img.src
        // Carry the text alternative over so the enlarged copy is not an
        // unlabeled image for assistive technology.
        modalImage.alt = img.alt || ''
    }

    span.onclick = function () {
        // removeChild throws when the node is not a child of body, so only
        // detach the overlay while it is actually attached.
        if (modal.parentNode === document.body) {
            document.body.removeChild(modal)
        }
    }
}
// Run once, right away: this inline script sits just before </body>, after the page content above it.
handleImages()
</script>
</body>
</html>