Denoising Diffusion Probabilistic Models (#98)

This commit is contained in:
Varuna Jayasiri
2021-10-08 21:33:04 +05:30
committed by GitHub
parent e309638fea
commit e2be5ddf35
21 changed files with 6123 additions and 0 deletions

<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Training code for Denoising Diffusion Probabilistic Model."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
<meta name="twitter:description" content="Training code for Denoising Diffusion Probabilistic Model."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/experiment.html"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
<meta property="og:description" content="Training code for Denoising Diffusion Probabilistic Model."/>
<title>Denoising Diffusion Probabilistic Models (DDPM) training</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/experiment.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">diffusion</a>
<a class="parent" href="index.html">ddpm</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/experiment.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="index.html">Denoising Diffusion Probabilistic Models (DDPM)</a> training</h1>
<p>This trains a DDPM-based model on the CelebA HQ dataset. You can find download instructions in this
<a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>.
Save the images inside the <a href="#dataset_path"><code>data/celebA</code> folder</a>.</p>
<p>The paper used an exponential moving average of the model weights with a decay of $0.9999$. We have skipped this for
simplicity.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">18</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="lineno">19</span>
<span class="lineno">20</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">21</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">22</span><span class="kn">import</span> <span class="nn">torchvision</span>
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="lineno">24</span>
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">monit</span>
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span>
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm</span> <span class="kn">import</span> <span class="n">DenoiseDiffusion</span>
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm.unet</span> <span class="kn">import</span> <span class="n">UNet</span></pre></div>
</div>
</div>
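The exponential moving average mentioned above is easy to add back if needed. As a minimal illustration (plain Python floats standing in for parameter tensors, not the paper's implementation):

```python
def ema_update(shadow, params, decay=0.9999):
    """Update EMA (shadow) copies of parameters in place.

    shadow, params: lists of floats standing in for parameter tensors.
    """
    for i, p in enumerate(params):
        shadow[i] = decay * shadow[i] + (1.0 - decay) * p
    return shadow

# Toy usage: the shadow value drifts slowly toward the current parameter.
# With decay 0.5 and a constant parameter 1.0, after n updates the shadow
# value is 1 - 0.5**n.
shadow = [0.0]
for _ in range(10):
    ema_update(shadow, [1.0], decay=0.5)
```

In practice the shadow copy would be updated after every optimizer step and used for sampling.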
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Configurations</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<p>Device to train the model on.
<a href="https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs"><code>DeviceConfigs</code></a>
picks up an available CUDA device or defaults to CPU.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
<p>U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">eps_model</span><span class="p">:</span> <span class="n">UNet</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p><a href="index.html">DDPM algorithm</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">44</span> <span class="n">diffusion</span><span class="p">:</span> <span class="n">DenoiseDiffusion</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>Number of channels in the image. $3$ for RGB.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">image_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>Image size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>Number of channels in the initial feature map</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>The list of channel numbers at each resolution.
The number of channels is <code>channel_multipliers[i] * n_channels</code></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">channel_multipliers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span></pre></div>
</div>
</div>
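As a quick check of the arithmetic above, with `n_channels = 64` and `channel_multipliers = [1, 2, 2, 4]` the feature-map widths at each resolution work out to:

```python
n_channels = 64
channel_multipliers = [1, 2, 2, 4]

# Number of channels at resolution i is channel_multipliers[i] * n_channels
channels = [m * n_channels for m in channel_multipliers]
print(channels)  # [64, 128, 128, 256]
```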
<div class='section' id='section-9'>
<div class='docs'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<p>The list of booleans that indicate whether to use attention at each resolution</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">is_attention</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p>Number of time steps $T$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1_000</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>Batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
<p>Number of samples to generate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">n_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
</div>
</div>
<div class='section' id='section-13'>
<div class='docs'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<p>Learning rate</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">2e-5</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>Number of training epochs</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1_000</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>Dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Dataloader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">data_loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<p>Adam optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p>Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span> <span class="o">=</span> <span class="n">UNet</span><span class="p">(</span>
<span class="lineno">81</span> <span class="n">image_channels</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_channels</span><span class="p">,</span>
<span class="lineno">82</span> <span class="n">n_channels</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">,</span>
<span class="lineno">83</span> <span class="n">ch_mults</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">channel_multipliers</span><span class="p">,</span>
<span class="lineno">84</span> <span class="n">is_attn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">is_attention</span><span class="p">,</span>
<span class="lineno">85</span> <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>Create <a href="index.html">DDPM class</a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span> <span class="o">=</span> <span class="n">DenoiseDiffusion</span><span class="p">(</span>
<span class="lineno">89</span> <span class="n">eps_model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">,</span>
<span class="lineno">90</span> <span class="n">n_steps</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">,</span>
<span class="lineno">91</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
<span class="lineno">92</span> <span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>Create dataloader</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>Create optimizer</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>Image logging</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_image</span><span class="p">(</span><span class="s2">&quot;sample&quot;</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<h3>Sample images</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">102</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-26'>
<div class='docs'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<p>$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">],</span>
<span class="lineno">109</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Remove noise for $T$ steps</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">for</span> <span class="n">t_</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Sample&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>$t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span> <span class="o">-</span> <span class="n">t_</span> <span class="o">-</span> <span class="mi">1</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span><span class="o">.</span><span class="n">p_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">new_full</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span><span class="p">,),</span> <span class="n">t</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">))</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Log samples</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;sample&#39;</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span></pre></div>
</div>
</div>
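The reverse step used in the loop above follows the DDPM update $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon_\theta(x_t, t)\Big) + \sqrt{\beta_t}\,z$. Here is a scalar sketch of one such step, assuming a linear $\beta$ schedule; it only illustrates the formula and stands in for `DenoiseDiffusion.p_sample`, which operates on image tensors:

```python
import math


def p_sample_scalar(x_t, eps_pred, t, betas, z=None):
    """One DDPM reverse step on a scalar value.

    x_t: current noisy value, eps_pred: predicted noise eps_theta(x_t, t),
    betas: full noise schedule, z: optional standard normal sample
    (no noise is added at t == 0).
    """
    alpha_t = 1.0 - betas[t]
    # alpha_bar_t is the cumulative product of (1 - beta_s) up to step t
    alpha_bar_t = 1.0
    for s in range(t + 1):
        alpha_bar_t *= 1.0 - betas[s]
    mean = (x_t - betas[t] / math.sqrt(1.0 - alpha_bar_t) * eps_pred) / math.sqrt(alpha_t)
    if t == 0 or z is None:
        return mean
    return mean + math.sqrt(betas[t]) * z


# Toy schedule: 1,000 linearly spaced betas, as in the DDPM paper
betas = [1e-4 + (0.02 - 1e-4) * i / 999 for i in range(1000)]
x = p_sample_scalar(1.0, 0.5, t=10, betas=betas, z=0.0)
```

The sampling loop applies this step from $t = T - 1$ down to $t = 0$, starting from pure Gaussian noise.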
<div class='section' id='section-31'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<h3>Train</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>Iterate through the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">&#39;Train&#39;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_loader</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-33'>
<div class='docs'>
<div class='section-link'>
<a href='#section-33'>#</a>
</div>
<p>Increment global step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-34'>
<div class='docs'>
<div class='section-link'>
<a href='#section-34'>#</a>
</div>
<p>Move data to device</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-35'>
<div class='docs'>
<div class='section-link'>
<a href='#section-35'>#</a>
</div>
<p>Make the gradients zero</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-36'>
<div class='docs'>
<div class='section-link'>
<a href='#section-36'>#</a>
</div>
<p>Calculate loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span><span class="o">.</span><span class="n">loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-37'>
<div class='docs'>
<div class='section-link'>
<a href='#section-37'>#</a>
</div>
<p>Compute gradients</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-38'>
<div class='docs'>
<div class='section-link'>
<a href='#section-38'>#</a>
</div>
<p>Take an optimization step</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">140</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-39'>
<div class='docs'>
<div class='section-link'>
<a href='#section-39'>#</a>
</div>
<p>Track the loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;loss&#39;</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
</div>
</div>
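The loss computed in each training step above is the simplified DDPM objective, $\mathbb{E}\big[\lVert\epsilon - \epsilon_\theta(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\; t)\rVert^2\big]$. A scalar sketch of that computation, standing in for `DenoiseDiffusion.loss` (which samples $t$ and $\epsilon$ and works on batches of images):

```python
import math


def ddpm_loss_scalar(x0, eps, t, alpha_bars, eps_model):
    """Simplified DDPM training loss for one scalar sample.

    x0: clean value, eps: sampled noise, t: timestep,
    alpha_bars: cumulative products of (1 - beta),
    eps_model: callable predicting the noise from (x_t, t).
    """
    # Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
    x_t = math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1.0 - alpha_bars[t]) * eps
    # Squared error between the true and predicted noise
    return (eps - eps_model(x_t, t)) ** 2


# A perfect noise predictor gives zero loss
alpha_bars = [0.99, 0.98, 0.97]
loss = ddpm_loss_scalar(1.0, 0.3, t=1,
                        alpha_bars=alpha_bars,
                        eps_model=lambda x_t, t: 0.3)
print(loss)  # 0.0
```

In the real training step, `eps_model` is the U-Net and the mean squared error is taken over whole image batches.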
<div class='section' id='section-40'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-40'>#</a>
</div>
<h3>Training loop</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-41'>
<div class='docs'>
<div class='section-link'>
<a href='#section-41'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-42'>
<div class='docs'>
<div class='section-link'>
<a href='#section-42'>#</a>
</div>
<p>Train the model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-43'>
<div class='docs'>
<div class='section-link'>
<a href='#section-43'>#</a>
</div>
<p>Sample some images</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-44'>
<div class='docs'>
<div class='section-link'>
<a href='#section-44'>#</a>
</div>
<p>New line in the console</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-45'>
<div class='docs'>
<div class='section-link'>
<a href='#section-45'>#</a>
</div>
<p>Save the model</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-46'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-46'>#</a>
</div>
<h3>CelebA HQ dataset</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">159</span><span class="k">class</span> <span class="nc">CelebADataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-47'>
<div class='docs'>
<div class='section-link'>
<a href='#section-47'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
<span class="lineno">165</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-48'>
<div class='docs'>
<div class='section-link'>
<a href='#section-48'>#</a>
</div>
<p>CelebA images folder</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">folder</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">&#39;celebA&#39;</span></pre></div>
</div>
</div>
<div class='section' id='section-49'>
<div class='docs'>
<div class='section-link'>
<a href='#section-49'>#</a>
</div>
<p>List of files</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">170</span> <span class="bp">self</span><span class="o">.</span><span class="n">_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">folder</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;**/*.jpg&#39;</span><span class="p">)]</span></pre></div>
</div>
</div>
<div class='section' id='section-50'>
<div class='docs'>
<div class='section-link'>
<a href='#section-50'>#</a>
</div>
<p>Transformations to resize the image and convert to tensor</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">173</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="lineno">174</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span>
<span class="lineno">175</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">176</span> <span class="p">])</span></pre></div>
</div>
</div>
<div class='section' id='section-51'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-51'>#</a>
</div>
<p>Size of the dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-52'>
<div class='docs'>
<div class='section-link'>
<a href='#section-52'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">182</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_files</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-53'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-53'>#</a>
</div>
<p>Get an image</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-54'>
<div class='docs'>
<div class='section-link'>
<a href='#section-54'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_files</span><span class="p">[</span><span class="n">index</span><span class="p">])</span>
<span class="lineno">189</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-55'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-55'>#</a>
</div>
<p>Create CelebA dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="s1">&#39;CelebA&#39;</span><span class="p">)</span>
<span class="lineno">193</span><span class="k">def</span> <span class="nf">celeb_dataset</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-56'>
<div class='docs'>
<div class='section-link'>
<a href='#section-56'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">return</span> <span class="n">CelebADataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-57'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-57'>#</a>
</div>
<h3>MNIST dataset</h3>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">200</span><span class="k">class</span> <span class="nc">MNISTDataset</span><span class="p">(</span><span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-58'>
<div class='docs'>
<div class='section-link'>
<a href='#section-58'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">205</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">):</span>
<span class="lineno">206</span> <span class="n">transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="lineno">207</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span>
<span class="lineno">208</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="lineno">209</span> <span class="p">])</span>
<span class="lineno">210</span>
<span class="lineno">211</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()),</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-59'>
<div class='docs'>
<div class='section-link'>
<a href='#section-59'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span>
<span class="lineno">214</span> <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__getitem__</span><span class="p">(</span><span class="n">item</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-60'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-60'>#</a>
</div>
<p>Create MNIST dataset</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">217</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="s1">&#39;MNIST&#39;</span><span class="p">)</span>
<span class="lineno">218</span><span class="k">def</span> <span class="nf">mnist_dataset</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-61'>
<div class='docs'>
<div class='section-link'>
<a href='#section-61'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">return</span> <span class="n">MNISTDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-62'>
<div class='docs'>
<div class='section-link'>
<a href='#section-62'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">225</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
</div>
</div>
<div class='section' id='section-63'>
<div class='docs'>
<div class='section-link'>
<a href='#section-63'>#</a>
</div>
<p>Create experiment</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;diffuse&#39;</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-64'>
<div class='docs'>
<div class='section-link'>
<a href='#section-64'>#</a>
</div>
<p>Create configurations</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-65'>
<div class='docs'>
<div class='section-link'>
<a href='#section-65'>#</a>
</div>
<p>Set configurations. You can override the defaults by passing the values in the dictionary.</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
<span class="lineno">234</span> <span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-66'>
<div class='docs'>
<div class='section-link'>
<a href='#section-66'>#</a>
</div>
<p>Initialize</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">configs</span><span class="o">.</span><span class="n">init</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-67'>
<div class='docs'>
<div class='section-link'>
<a href='#section-67'>#</a>
</div>
<p>Set models for saving and loading</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">&#39;eps_model&#39;</span><span class="p">:</span> <span class="n">configs</span><span class="o">.</span><span class="n">eps_model</span><span class="p">})</span></pre></div>
</div>
</div>
<div class='section' id='section-68'>
<div class='docs'>
<div class='section-link'>
<a href='#section-68'>#</a>
</div>
<p>Start and run the training loop</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
<span class="lineno">244</span> <span class="n">configs</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='section' id='section-69'>
<div class='docs'>
<div class='section-link'>
<a href='#section-69'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">248</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
<span class="lineno">249</span> <span class="n">main</span><span class="p">()</span></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/index.html"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
<title>Denoising Diffusion Probabilistic Models (DDPM)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">diffusion</a>
<a class="parent" href="index.html">ddpm</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Denoising Diffusion Probabilistic Models (DDPM)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
<a href="https://papers.labml.ai/paper/2006.11239">Denoising Diffusion Probabilistic Models</a>.</p>
<p>In simple terms, we take an image from the data and add noise to it step by step.
Then we train a model to predict that noise at each step, and use the model to
generate images.</p>
<p>The following definitions and derivations show how this works.
For details please refer to <a href="https://papers.labml.ai/paper/2006.11239">the paper</a>.</p>
<h2>Forward Process</h2>
<p>The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1- \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
q(x_{1:T} | x_0) = \prod_{t = 1}^{T} q(x_t | x_{t-1})
\end{align}</script>
</p>
<p>where $\beta_1, \dots, \beta_T$ is the variance schedule.</p>
<p>We can sample $x_t$ at any timestep $t$ with,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}</script>
</p>
<p>where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$</p>
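<p>As a quick sanity check (a plain-Python sketch, independent of the torch code later in this page), iterating the single-step transition on the mean and variance must reproduce this closed form:</p>

```python
import math

# Plain-Python sketch: iterate the single-step forward transition
# q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I) on the mean and
# variance, then compare against the closed form q(x_t | x_0).
def iterate_forward(x0: float, betas: list) -> tuple:
    mean, var = x0, 0.0
    for beta in betas:
        mean = math.sqrt(1.0 - beta) * mean   # mean scales by sqrt(alpha_t)
        var = (1.0 - beta) * var + beta       # var_t = alpha_t * var_{t-1} + beta_t
    return mean, var

def closed_form(x0: float, betas: list) -> tuple:
    # alpha_bar_t = prod of alphas; q(x_t|x_0) = N(sqrt(abar) x0, (1 - abar) I)
    abar = 1.0
    for beta in betas:
        abar *= 1.0 - beta
    return math.sqrt(abar) * x0, 1.0 - abar
```

<p>Both routes give the same Gaussian, which is why $x_t$ can be sampled in one shot during training instead of running $t$ noising steps.</p>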
<h2>Reverse Process</h2>
<p>The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
for $T$ time steps.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big) \\
\color{cyan}{p_\theta}(x_{0:T}) &= \color{cyan}{p_\theta}(x_T) \prod_{t = 1}^{T} \color{cyan}{p_\theta}(x_{t-1} | x_t) \\
\color{cyan}{p_\theta}(x_0) &= \int \color{cyan}{p_\theta}(x_{0:T}) dx_{1:T}
\end{align}</script>
</p>
<p>$\color{cyan}\theta$ are the parameters we train.</p>
<h2>Loss</h2>
<p>We optimize the ELBO (from Jensen&rsquo;s inequality) on the negative log likelihood.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\mathbb{E}[-\log \color{cyan}{p_\theta}(x_0)]
&\le \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&=L
\end{align}</script>
</p>
<p>The loss can be rewritten as follows.</p>
<p>
<script type="math/tex; mode=display">\begin{align}
L
&= \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
&= \mathbb{E}_q [
-\log \frac{p(x_T)}{q(x_T|x_0)}
-\sum_{t=2}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
-\log \color{cyan}{p_\theta}(x_0|x_1)] \\
&= \mathbb{E}_q [
D_{KL}(q(x_T|x_0) \Vert p(x_T))
+\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))
-\log \color{cyan}{p_\theta}(x_0|x_1)]
\end{align}</script>
</p>
<p>$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.</p>
<h3>Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))$</h3>
<p>The forward process posterior conditioned by $x_0$ is,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
\end{align}</script>
</p>
<p>The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
$\beta_t$ or $\tilde\beta_t$.</p>
<p>Then,
<script type="math/tex; mode=display">\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big)</script>
</p>
<p>For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$</p>
<p>
<script type="math/tex; mode=display">\begin{align}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\epsilon\Big)
\end{align}</script>
</p>
<p>This gives,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
L_{t-1}
&= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)) \\
&= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
\Big \Vert \tilde\mu(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big \Vert^2 \Bigg] \\
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
\bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
\Big) - \color{cyan}{\mu_\theta}(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg] \\
\end{align}</script>
</p>
<p>Re-parameterizing with a model to predict noise</p>
<p>
<script type="math/tex; mode=display">\begin{align}
\color{cyan}{\mu_\theta}(x_t, t) &= \tilde\mu \bigg(x_t,
\frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
\sqrt{1-\bar\alpha_t}\color{cyan}{\epsilon_\theta}(x_t, t) \Big) \bigg) \\
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
\end{align}</script>
</p>
<p>where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.</p>
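<p>A scalar sketch of this reparameterized mean (hypothetical helper names, not the actual model code):</p>

```python
import math

# Hedged scalar sketch of the reparameterized reverse mean:
# mu_theta(x_t, t) = (x_t - beta_t / sqrt(1 - abar_t) * eps_hat) / sqrt(alpha_t)
# `eps_hat` stands in for the output of the noise predictor eps_theta(x_t, t).
def reverse_mean(xt: float, eps_hat: float, beta_t: float, abar_t: float) -> float:
    alpha_t = 1.0 - beta_t
    return (xt - beta_t / math.sqrt(1.0 - abar_t) * eps_hat) / math.sqrt(alpha_t)
```

<p>For $t=1$ (where $\bar\alpha_1 = \alpha_1$), feeding back the exact noise used in the forward step recovers $x_0$, as expected.</p>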
<p>This gives,</p>
<p>
<script type="math/tex; mode=display">\begin{align}
L_{t-1}
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
\Big\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\Big\Vert^2 \Bigg]
\end{align}</script>
</p>
<p>That is, we are training to predict the noise.</p>
<h3>Simplified loss</h3>
<p>
<script type="math/tex; mode=display">L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]</script>
</p>
<p>This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$, discarding the
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
increases the weight given to higher $t$ (which have higher noise levels), thereby improving sample quality.</p>
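<p>The simplified objective can be sketched on scalars (a hypothetical stand-in; <code>eps_model</code> plays the role of the UNet):</p>

```python
import math
import random

# Hedged scalar sketch of L_simple: pick a random timestep, noise x0 with the
# closed form of q(x_t|x_0), and score the predictor with squared error.
# `eps_model(xt, t)` is a stand-in for eps_theta, not the real UNet.
def l_simple(x0: float, alpha_bar: list, eps_model, rng: random.Random) -> float:
    t = rng.randrange(len(alpha_bar))
    eps = rng.gauss(0.0, 1.0)
    xt = math.sqrt(alpha_bar[t]) * x0 + math.sqrt(1.0 - alpha_bar[t]) * eps
    return (eps - eps_model(xt, t)) ** 2
```

<p>With a predictor that always outputs zero, the expected loss is $\mathbb{E}[\epsilon^2] = 1$, a handy baseline when debugging training.</p>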
<p>This file implements the loss calculation and a basic sampling method that we use to generate images during
training.</p>
<p>Here is the <a href="unet.html">UNet model</a> that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and
<a href="experiment.html">training code</a>.
<a href="evaluate.html">This file</a> can generate samples and interpolations from a trained model.</p>
<p><a href="https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">162</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span>
<span class="lineno">163</span>
<span class="lineno">164</span><span class="kn">import</span> <span class="nn">torch</span>
<span class="lineno">165</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
<span class="lineno">166</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
<span class="lineno">167</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
<span class="lineno">168</span>
<span class="lineno">169</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm.utils</span> <span class="kn">import</span> <span class="n">gather</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<h2>Denoise Diffusion</h2>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">172</span><span class="k">class</span> <span class="nc">DenoiseDiffusion</span><span class="p">:</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
<ul>
<li><code>eps_model</code> is $\color{cyan}{\epsilon_\theta}(x_t, t)$ model</li>
<li><code>n_steps</code> is $T$</li>
<li><code>device</code> is the device to place constants on</li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">177</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">eps_model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-3'>
<div class='docs'>
<div class='section-link'>
<a href='#section-3'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">183</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span> <span class="o">=</span> <span class="n">eps_model</span></pre></div>
</div>
</div>
<div class='section' id='section-4'>
<div class='docs'>
<div class='section-link'>
<a href='#section-4'>#</a>
</div>
<p>Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.0001</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-5'>
<div class='docs'>
<div class='section-link'>
<a href='#section-5'>#</a>
</div>
<p>$\alpha_t = 1 - \beta_t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span></pre></div>
</div>
</div>
<div class='section' id='section-6'>
<div class='docs'>
<div class='section-link'>
<a href='#section-6'>#</a>
</div>
<p>$\bar\alpha_t = \prod_{s=1}^t \alpha_s$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-7'>
<div class='docs'>
<div class='section-link'>
<a href='#section-7'>#</a>
</div>
<p>$T$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span> <span class="o">=</span> <span class="n">n_steps</span></pre></div>
</div>
</div>
<div class='section' id='section-8'>
<div class='docs'>
<div class='section-link'>
<a href='#section-8'>#</a>
</div>
<p>$\sigma^2 = \beta$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span></pre></div>
</div>
</div>
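<div class='section'>
<div class='docs'>
<p>These constants can be sanity-checked with a plain-Python version of the schedule (a sketch mirroring <code>torch.linspace</code> and <code>torch.cumprod</code>, not the code above):</p>

```python
# Plain-Python sketch of the linear variance schedule and its cumulative
# products; mirrors torch.linspace(0.0001, 0.02, n_steps) and torch.cumprod.
def linear_betas(n_steps: int, lo: float = 0.0001, hi: float = 0.02) -> list:
    step = (hi - lo) / (n_steps - 1)
    return [lo + i * step for i in range(n_steps)]

def alpha_bars(betas: list) -> list:
    bars, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta           # alpha_t = 1 - beta_t
        bars.append(prod)            # alpha_bar_t = product of alphas so far
    return bars
```

<p>For $T=1000$ the final $\bar\alpha_T$ is tiny (on the order of $e^{-10}$), so $q(x_T|x_0)$ is almost exactly a standard Gaussian, which is what lets sampling start from pure noise.</p>
</div>
</div>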
<div class='section' id='section-9'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-9'>#</a>
</div>
<h4>Get $q(x_t|x_0)$ distribution</h4>
<p>
<script type="math/tex; mode=display">\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">def</span> <span class="nf">q_xt_x0</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span></pre></div>
</div>
</div>
<div class='section' id='section-10'>
<div class='docs'>
<div class='section-link'>
<a href='#section-10'>#</a>
</div>
<p><a href="utils.html">gather</a> $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x0</span></pre></div>
</div>
</div>
<div class='section' id='section-11'>
<div class='docs'>
<div class='section-link'>
<a href='#section-11'>#</a>
</div>
<p>$(1-\bar\alpha_t) \mathbf{I}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">var</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-12'>
<div class='docs'>
<div class='section-link'>
<a href='#section-12'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">212</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">var</span></pre></div>
</div>
</div>
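<div class='section'>
<div class='docs'>
<p>A plain-Python stand-in (hypothetical, not the actual <code>gather</code> utility) for what this method computes per batch element:</p>

```python
# Plain-Python stand-in for the torch version: for each batch element, look up
# alpha_bar at that element's timestep and form the mean and variance of
# q(x_t | x_0). The real code does this with a broadcasted gather.
def q_xt_x0(x0_batch: list, t_batch: list, alpha_bar: list) -> tuple:
    means, variances = [], []
    for x0, t in zip(x0_batch, t_batch):
        abar = alpha_bar[t]
        means.append(abar ** 0.5 * x0)     # sqrt(alpha_bar_t) * x0
        variances.append(1.0 - abar)       # (1 - alpha_bar_t)
    return means, variances
```

<p>Each batch element can have its own timestep, which is exactly why the torch code gathers $\bar\alpha_t$ per element rather than using a single scalar.</p>
</div>
</div>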
<div class='section' id='section-13'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-13'>#</a>
</div>
<h4>Sample from $q(x_t|x_0)$</h4>
<p>
<script type="math/tex; mode=display">\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="nf">q_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-14'>
<div class='docs'>
<div class='section-link'>
<a href='#section-14'>#</a>
</div>
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">if</span> <span class="n">eps</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">225</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-15'>
<div class='docs'>
<div class='section-link'>
<a href='#section-15'>#</a>
</div>
<p>get $q(x_t|x_0)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">mean</span><span class="p">,</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_xt_x0</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-16'>
<div class='docs'>
<div class='section-link'>
<a href='#section-16'>#</a>
</div>
<p>Sample from $q(x_t|x_0)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">230</span> <span class="k">return</span> <span class="n">mean</span> <span class="o">+</span> <span class="p">(</span><span class="n">var</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">eps</span></pre></div>
</div>
</div>
<div class='section' id='section-17'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-17'>#</a>
</div>
<h4>Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$</h4>
<p>
<script type="math/tex; mode=display">\begin{align}
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
\color{cyan}{\mu_\theta}(x_t, t)
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
\end{align}</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">def</span> <span class="nf">p_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">xt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-18'>
<div class='docs'>
<div class='section-link'>
<a href='#section-18'>#</a>
</div>
<p>$\color{cyan}{\epsilon_\theta}(x_t, t)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">eps_theta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-19'>
<div class='docs'>
<div class='section-link'>
<a href='#section-19'>#</a>
</div>
<p><a href="utils.html">gather</a> $\bar\alpha_t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">248</span> <span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-20'>
<div class='docs'>
<div class='section-link'>
<a href='#section-20'>#</a>
</div>
<p>$\alpha_t$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-21'>
<div class='docs'>
<div class='section-link'>
<a href='#section-21'>#</a>
</div>
<p>$\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">eps_coef</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha_bar</span><span class="p">)</span> <span class="o">**</span> <span class="o">.</span><span class="mi">5</span></pre></div>
</div>
</div>
<div class='section' id='section-22'>
<div class='docs'>
<div class='section-link'>
<a href='#section-22'>#</a>
</div>
<p>
<script type="math/tex; mode=display">\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">255</span> <span class="n">mean</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">xt</span> <span class="o">-</span> <span class="n">eps_coef</span> <span class="o">*</span> <span class="n">eps_theta</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-23'>
<div class='docs'>
<div class='section-link'>
<a href='#section-23'>#</a>
</div>
<p>$\sigma_t^2$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">var</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma2</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-24'>
<div class='docs'>
<div class='section-link'>
<a href='#section-24'>#</a>
</div>
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">xt</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">xt</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-25'>
<div class='docs'>
<div class='section-link'>
<a href='#section-25'>#</a>
</div>
<p>Sample</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">return</span> <span class="n">mean</span> <span class="o">+</span> <span class="p">(</span><span class="n">var</span> <span class="o">**</span> <span class="o">.</span><span class="mi">5</span><span class="p">)</span> <span class="o">*</span> <span class="n">eps</span></pre></div>
</div>
</div>
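To generate an image, <code>p_sample</code> is applied repeatedly, starting from pure noise $x_T$ and walking $t$ down to $0$. The loop below is a sketch under the same assumed linear schedule, with $\sigma_t^2 = \beta_t$ as in the paper; the dummy model that always predicts zero noise is purely illustrative, standing in for the trained U-Net.

```python
import torch

n_steps = 1000
beta = torch.linspace(1e-4, 0.02, n_steps)  # assumed linear schedule
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
sigma2 = beta  # the paper sets sigma_t^2 = beta_t

def gather(consts, t):
    return consts.gather(-1, t).reshape(-1, 1, 1, 1)

@torch.no_grad()
def p_sample(eps_model, xt, t):
    # One reverse step: x_{t-1} ~ p_theta(x_{t-1} | x_t)
    eps_theta = eps_model(xt, t)
    alpha_bar_t = gather(alpha_bar, t)
    alpha_t = gather(alpha, t)
    eps_coef = (1 - alpha_t) / (1 - alpha_bar_t) ** 0.5
    mean = 1 / (alpha_t ** 0.5) * (xt - eps_coef * eps_theta)
    var = gather(sigma2, t)
    return mean + var ** 0.5 * torch.randn_like(xt)

@torch.no_grad()
def sample(eps_model, shape):
    # Start from pure noise x_T and denoise step by step down to x_0
    x = torch.randn(shape)
    for s in reversed(range(n_steps)):
        t = x.new_full((shape[0],), s, dtype=torch.long)
        x = p_sample(eps_model, x, t)
    return x

# Illustrative stand-in "model" that always predicts zero noise
dummy_model = lambda xt, t: torch.zeros_like(xt)
imgs = sample(dummy_model, (2, 3, 8, 8))
```

With a real trained $\epsilon_\theta$ the same loop produces image samples; only the model changes.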
<div class='section' id='section-26'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-26'>#</a>
</div>
<h4>Simplified Loss</h4>
<p>
<script type="math/tex; mode=display">L_\text{simple}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]</script>
</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">264</span> <span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">noise</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-27'>
<div class='docs'>
<div class='section-link'>
<a href='#section-27'>#</a>
</div>
<p>Get batch size</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x0</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
</div>
</div>
<div class='section' id='section-28'>
<div class='docs'>
<div class='section-link'>
<a href='#section-28'>#</a>
</div>
<p>Get random $t$ for each sample in the batch</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">275</span> <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">device</span><span class="o">=</span><span class="n">x0</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-29'>
<div class='docs'>
<div class='section-link'>
<a href='#section-29'>#</a>
</div>
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">if</span> <span class="n">noise</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="lineno">279</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-30'>
<div class='docs'>
<div class='section-link'>
<a href='#section-30'>#</a>
</div>
<p>Sample $x_t$ from $q(x_t|x_0)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">xt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_sample</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">noise</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-31'>
<div class='docs'>
<div class='section-link'>
<a href='#section-31'>#</a>
</div>
<p>Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">284</span> <span class="n">eps_theta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
</div>
</div>
<div class='section' id='section-32'>
<div class='docs'>
<div class='section-link'>
<a href='#section-32'>#</a>
</div>
<p>MSE loss</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">noise</span><span class="p">,</span> <span class="n">eps_theta</span><span class="p">)</span></pre></div>
</div>
</div>
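A training step is then just this loss plus an optimizer update. The sketch below assumes the same linear schedule and uses a single convolution as a hypothetical stand-in for the U-Net noise predictor (it ignores $t$; the real model conditions on it).

```python
import torch
import torch.nn.functional as F

n_steps = 1000
beta = torch.linspace(1e-4, 0.02, n_steps)  # assumed linear schedule
alpha_bar = torch.cumprod(1.0 - beta, dim=0)

def gather(consts, t):
    return consts.gather(-1, t).reshape(-1, 1, 1, 1)

def loss(eps_model, x0):
    batch_size = x0.shape[0]
    # Random step t per sample, random target noise
    t = torch.randint(0, n_steps, (batch_size,), device=x0.device, dtype=torch.long)
    noise = torch.randn_like(x0)
    # x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
    xt = gather(alpha_bar, t) ** 0.5 * x0 + (1 - gather(alpha_bar, t)) ** 0.5 * noise
    # Simplified objective: MSE between true and predicted noise
    return F.mse_loss(noise, eps_model(xt, t))

# Toy conv "model" standing in for the U-Net (illustration only)
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

x0 = torch.randn(4, 3, 8, 8)
l = loss(lambda xt, t: model(xt), x0)
opt.zero_grad()
l.backward()
opt.step()
```

Because $t$ is sampled uniformly per example, every minibatch trains the model across many noise levels at once.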
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>


@ -0,0 +1,150 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/readme.html"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
<meta property="og:description" content=""/>
<title>Denoising Diffusion Probabilistic Models (DDPM)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">diffusion</a>
<a class="parent" href="index.html">ddpm</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/diffusion/ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
<a href="https://papers.labml.ai/paper/2006.11239">Denoising Diffusion Probabilistic Models</a>.</p>
<p>In simple terms, we take an image from the data and add noise to it step by step.
We then train a model to predict that noise at each step, and use the model to
generate images.</p>
<p>Here is the <a href="https://nn.labml.ai/diffusion/ddpm/unet.html">UNet model</a> that predicts the noise and
<a href="https://nn.labml.ai/diffusion/ddpm/experiment.html">training code</a>.
<a href="https://nn.labml.ai/diffusion/ddpm/evaluate.html">This file</a> can generate samples and interpolations
from a trained model.</p>
<p><a href="https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

File diff suppressed because it is too large

Binary file not shown (new file, 130 KiB).


@ -0,0 +1,163 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="Utility functions for DDPM experiment"/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Utility functions for DDPM experiment"/>
<meta name="twitter:description" content="Utility functions for DDPM experiment"/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/utils.html"/>
<meta property="og:title" content="Utility functions for DDPM experiment"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Utility functions for DDPM experiment"/>
<meta property="og:description" content="Utility functions for DDPM experiment"/>
<title>Utility functions for DDPM experiment</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/utils.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">diffusion</a>
<a class="parent" href="index.html">ddpm</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/utils.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Utility functions for the <a href="index.html">DDPM</a> experiment</h1>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">import</span> <span class="nn">torch.utils.data</span></pre></div>
</div>
</div>
<div class='section' id='section-1'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-1'>#</a>
</div>
<p>Gather consts for $t$ and reshape to feature map shape</p>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">13</span><span class="k">def</span> <span class="nf">gather</span><span class="p">(</span><span class="n">consts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
</div>
</div>
<div class='section' id='section-2'>
<div class='docs'>
<div class='section-link'>
<a href='#section-2'>#</a>
</div>
</div>
<div class='code'>
<div class="highlight"><pre><span class="lineno">15</span> <span class="n">c</span> <span class="o">=</span> <span class="n">consts</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="lineno">16</span> <span class="k">return</span> <span class="n">c</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
</div>
</div>
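Concretely, <code>gather</code> picks one constant per batch element and reshapes it so it broadcasts over image tensors of shape <code>(batch, channels, height, width)</code>. A small, self-contained example (the values here are arbitrary):

```python
import torch

def gather(consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Select consts[t] for each batch element ...
    c = consts.gather(-1, t)
    # ... and reshape to (batch, 1, 1, 1) so it broadcasts over images
    return c.reshape(-1, 1, 1, 1)

alpha_bar = torch.tensor([0.9, 0.8, 0.7])  # arbitrary per-step constants
t = torch.tensor([2, 0])                   # each sample's diffusion step
c = gather(alpha_bar, t)
```

Here `c` has shape `(2, 1, 1, 1)`, holding `alpha_bar[2]` for the first sample and `alpha_bar[0]` for the second, ready to multiply a batch of feature maps.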
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

docs/diffusion/index.html (new file, 142 lines)

@ -0,0 +1,142 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Diffusion models"/>
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/diffusion/index.html"/>
<meta property="og:title" content="Diffusion models"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Diffusion models"/>
<meta property="og:description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
<title>Diffusion models</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/diffusion/index.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="index.html">diffusion</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/__init__.py">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs doc-strings'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1>Diffusion models</h1>
<ul>
<li><a href="ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></li>
</ul>
</div>
<div class='code'>
<div class="highlight"><pre></pre></div>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>


@ -112,6 +112,10 @@ implementations.</p>
<li><a href="gan/wasserstein/gradient_penalty/index.html">Wasserstein GAN with Gradient Penalty</a></li>
<li><a href="gan/stylegan/index.html">StyleGAN 2</a></li>
</ul>
<h4><a href="diffusion/index.html">Diffusion models</a></h4>
<ul>
<li><a href="diffusion/ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></li>
</ul>
<h4><a href="sketch_rnn/index.html">Sketch RNN</a></h4>
<h4>✨ Graph Neural Networks</h4>
<ul>


@ -42,6 +42,12 @@
"1503.02531": [ "1503.02531": [
"https://nn.labml.ai/distillation/index.html" "https://nn.labml.ai/distillation/index.html"
], ],
"1505.04597": [
"https://nn.labml.ai/diffusion/ddpm/unet.html"
],
"2006.11239": [
"https://nn.labml.ai/diffusion/ddpm/index.html"
],
"2010.07468": [ "2010.07468": [
"https://nn.labml.ai/optimizers/ada_belief.html" "https://nn.labml.ai/optimizers/ada_belief.html"
], ],

docs/rl/dqn/readme.html (new file, 147 lines)

@ -0,0 +1,147 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<meta name="description" content=""/>
<meta name="twitter:card" content="summary"/>
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta name="twitter:title" content="Deep Q Networks (DQN)"/>
<meta name="twitter:description" content=""/>
<meta name="twitter:site" content="@labmlai"/>
<meta name="twitter:creator" content="@labmlai"/>
<meta property="og:url" content="https://nn.labml.ai/rl/dqn/readme.html"/>
<meta property="og:title" content="Deep Q Networks (DQN)"/>
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
<meta property="og:site_name" content="LabML Neural Networks"/>
<meta property="og:type" content="object"/>
<meta property="og:title" content="Deep Q Networks (DQN)"/>
<meta property="og:description" content=""/>
<title>Deep Q Networks (DQN)</title>
<link rel="shortcut icon" href="/icon.png"/>
<link rel="stylesheet" href="../../pylit.css">
<link rel="canonical" href="https://nn.labml.ai/rl/dqn/readme.html"/>
<!-- Global site tag (gtag.js) - Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {
dataLayer.push(arguments);
}
gtag('js', new Date());
gtag('config', 'G-4V3HC8HBLH');
</script>
</head>
<body>
<div id='container'>
<div id="background"></div>
<div class='section'>
<div class='docs'>
<p>
<a class="parent" href="/">home</a>
<a class="parent" href="../index.html">rl</a>
<a class="parent" href="index.html">dqn</a>
</p>
<p>
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/readme.md">
<img alt="Github"
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
style="max-width:100%;"/></a>
<a href="https://twitter.com/labmlai"
rel="nofollow">
<img alt="Twitter"
src="https://img.shields.io/twitter/follow/labmlai?style=social"
style="max-width:100%;"/></a>
</p>
</div>
</div>
<div class='section' id='section-0'>
<div class='docs'>
<div class='section-link'>
<a href='#section-0'>#</a>
</div>
<h1><a href="https://nn.labml.ai/rl/dqn/index.html">Deep Q Networks (DQN)</a></h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
<a href="https://papers.labml.ai/paper/1312.5602">Playing Atari with Deep Reinforcement Learning</a>
along with <a href="https://nn.labml.ai/rl/dqn/model.html">Dueling Network</a>, <a href="https://nn.labml.ai/rl/dqn/replay_buffer.html">Prioritized Replay</a>
and Double Q Network.</p>
<p>Here is the <a href="https://nn.labml.ai/rl/dqn/experiment.html">experiment</a> and <a href="https://nn.labml.ai/rl/dqn/model.html">model</a> implementation.</p>
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
<a href="https://app.labml.ai/run/fe1ad986237511ec86e8b763a2d3f710"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
</div>
<div class='code'>
</div>
</div>
<div class='footer'>
<a href="https://papers.labml.ai">Trending Research Papers</a>
<a href="https://labml.ai">labml.ai</a>
</div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
</script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true,
processEnvironments: true
},
// Center justify equations in code and markdown cells. Elsewhere
// we use CSS to left justify single line equations in code cells.
displayAlign: 'center',
"HTML-CSS": { fonts: ["TeX"] }
});
</script>
<script>
function handleImages() {
var images = document.querySelectorAll('p>img')
console.log(images);
for (var i = 0; i < images.length; ++i) {
handleImage(images[i])
}
}
function handleImage(img) {
img.parentElement.style.textAlign = 'center'
var modal = document.createElement('div')
modal.id = 'modal'
var modalContent = document.createElement('div')
modal.appendChild(modalContent)
var modalImage = document.createElement('img')
modalContent.appendChild(modalImage)
var span = document.createElement('span')
span.classList.add('close')
span.textContent = 'x'
modal.appendChild(span)
img.onclick = function () {
console.log('clicked')
document.body.appendChild(modal)
modalImage.src = img.src
}
span.onclick = function () {
document.body.removeChild(modal)
}
}
handleImages()
</script>
</body>
</html>

View File

@ -328,6 +328,48 @@
</url>
<url>
<loc>https://nn.labml.ai/diffusion/index.html</loc>
<lastmod>2021-10-06T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/diffusion/ddpm/unet.html</loc>
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/diffusion/ddpm/index.html</loc>
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/diffusion/ddpm/experiment.html</loc>
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/diffusion/ddpm/utils.html</loc>
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/diffusion/ddpm/evaluate.html</loc>
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
<priority>1.00</priority>
</url>
<url>
<loc>https://nn.labml.ai/optimizers/adam_warmup.html</loc>
<lastmod>2021-01-13T16:30:00+00:00</lastmod>

View File

@ -52,6 +52,10 @@ implementations.
* [Wasserstein GAN with Gradient Penalty](gan/wasserstein/gradient_penalty/index.html)
* [StyleGAN 2](gan/stylegan/index.html)
#### ✨ [Diffusion models](diffusion/index.html)
* [Denoising Diffusion Probabilistic Models (DDPM)](diffusion/ddpm/index.html)
#### ✨ [Sketch RNN](sketch_rnn/index.html)
#### ✨ Graph Neural Networks

View File

@ -0,0 +1,11 @@
"""
---
title: Diffusion models
summary: >
A set of PyTorch implementations/tutorials of diffusion models.
---
# Diffusion models
* [Denoising Diffusion Probabilistic Models (DDPM)](ddpm/index.html)
"""

View File

@ -0,0 +1,287 @@
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM)
summary: >
PyTorch implementation and tutorial of the paper
Denoising Diffusion Probabilistic Models (DDPM).
---
# Denoising Diffusion Probabilistic Models (DDPM)
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
[Denoising Diffusion Probabilistic Models](https://papers.labml.ai/paper/2006.11239).
In simple terms, we take an image from the data and add noise to it step by step.
We then train a model to predict the noise at each step and use that model to
generate images.
The following definitions and derivations show how this works.
For details please refer to [the paper](https://papers.labml.ai/paper/2006.11239).
## Forward Process
The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.
\begin{align}
q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1- \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
q(x_{1:T} | x_0) = \prod_{t = 1}^{T} q(x_t | x_{t-1})
\end{align}
where $\beta_1, \dots, \beta_T$ is the variance schedule.
We can sample $x_t$ at any timestep $t$ with,
\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}
where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
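This closed form can be checked numerically. The following is a small sketch (not part of the implementation below) that iterates the per-step mean/variance recursion of $q(x_t|x_{t-1})$ and compares the result against $\sqrt{\bar\alpha_t}$ and $1-\bar\alpha_t$, assuming the paper's linear $\beta$ schedule:

```python
import torch

# Iterate the per-step recursion: x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) eps.
# The mean coefficient multiplies by sqrt(alpha_t) each step and the variance
# follows v_t = alpha_t v_{t-1} + beta_t; both should match the closed form.
n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps, dtype=torch.float64)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)

mean_coef = torch.tensor(1., dtype=torch.float64)
var = torch.tensor(0., dtype=torch.float64)
for t in range(n_steps):
    mean_coef = (1 - beta[t]).sqrt() * mean_coef
    var = (1 - beta[t]) * var + beta[t]

# Closed form: mean coefficient sqrt(alpha_bar_T), variance 1 - alpha_bar_T
assert torch.allclose(mean_coef, alpha_bar[-1].sqrt())
assert torch.allclose(var, 1 - alpha_bar[-1])
```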
## Reverse Process
The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
for $T$ time steps.
\begin{align}
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big) \\
\color{cyan}{p_\theta}(x_{0:T}) &= p(x_T) \prod_{t = 1}^{T} \color{cyan}{p_\theta}(x_{t-1} | x_t) \\
\color{cyan}{p_\theta}(x_0) &= \int \color{cyan}{p_\theta}(x_{0:T}) dx_{1:T}
\end{align}
$\color{cyan}\theta$ are the parameters we train.
## Loss
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.
\begin{align}
\mathbb{E}[-\log \color{cyan}{p_\theta}(x_0)]
&\le \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&=L
\end{align}
The loss can be rewritten as follows.
\begin{align}
L
&= \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
&= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
&= \mathbb{E}_q [
-\log \frac{p(x_T)}{q(x_T|x_0)}
-\sum_{t=2}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
-\log \color{cyan}{p_\theta}(x_0|x_1)] \\
&= \mathbb{E}_q [
D_{KL}(q(x_T|x_0) \Vert p(x_T))
+\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))
-\log \color{cyan}{p_\theta}(x_0|x_1)]
\end{align}
$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
### Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))$
The forward process posterior conditioned by $x_0$ is,
\begin{align}
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
\end{align}
The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
$\beta_t$ or $\tilde\beta_t$.
Then,
$$\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big)$$
For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$
\begin{align}
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\epsilon\Big)
\end{align}
This gives,
\begin{align}
L_{t-1}
&= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)) \\
&= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
\Big \Vert \tilde\mu(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big \Vert^2 \Bigg] \\
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
\bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
\Big) - \color{cyan}{\mu_\theta}(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg] \\
\end{align}
Re-parameterizing with a model to predict noise
\begin{align}
\color{cyan}{\mu_\theta}(x_t, t) &= \tilde\mu \bigg(x_t,
\frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
\sqrt{1-\bar\alpha_t}\color{cyan}{\epsilon_\theta}(x_t, t) \Big) \bigg) \\
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
\end{align}
where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
This gives,
\begin{align}
L_{t-1}
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
\Big\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\Big\Vert^2 \Bigg]
\end{align}
That is, we are training to predict the noise.
### Simplified loss
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$
This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t > 1$, discarding the
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
increases the weight given to higher $t$ (which have higher noise levels), therefore improving sample quality.
This file implements the loss calculation and a basic sampling method that we use to generate images during
training.
Here is the [UNet model](unet.html) that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and
[training code](experiment.html).
[This file](evaluate.html) can generate samples and interpolations from a trained model.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686)
"""
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
from labml_nn.diffusion.ddpm.utils import gather
class DenoiseDiffusion:
"""
## Denoise Diffusion
"""
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
"""
* `eps_model` is $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
* `n_steps` is $T$
* `device` is the device to place constants on
"""
super().__init__()
self.eps_model = eps_model
# Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
# $\alpha_t = 1 - \beta_t$
self.alpha = 1. - self.beta
# $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
# $T$
self.n_steps = n_steps
# $\sigma^2 = \beta$
self.sigma2 = self.beta
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
#### Get $q(x_t|x_0)$ distribution
\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}
"""
# [gather](utils.html) $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$
mean = gather(self.alpha_bar, t) ** 0.5 * x0
# $(1-\bar\alpha_t) \mathbf{I}$
var = 1 - gather(self.alpha_bar, t)
#
return mean, var
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
"""
#### Sample from $q(x_t|x_0)$
\begin{align}
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
\end{align}
"""
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
if eps is None:
eps = torch.randn_like(x0)
# get $q(x_t|x_0)$
mean, var = self.q_xt_x0(x0, t)
# Sample from $q(x_t|x_0)$
return mean + (var ** 0.5) * eps
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
"""
#### Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
\begin{align}
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
\color{cyan}{\mu_\theta}(x_t, t)
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
\end{align}
"""
# $\color{cyan}{\epsilon_\theta}(x_t, t)$
eps_theta = self.eps_model(xt, t)
# [gather](utils.html) $\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
# $\alpha_t$
alpha = gather(self.alpha, t)
# $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
# $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
# \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
# $\sigma^2$
var = gather(self.sigma2, t)
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
eps = torch.randn(xt.shape, device=xt.device)
# Sample
return mean + (var ** .5) * eps
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
"""
#### Simplified Loss
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
\bigg\Vert^2 \Bigg]$$
"""
# Get batch size
batch_size = x0.shape[0]
# Get random $t$ for each sample in the batch
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
if noise is None:
noise = torch.randn_like(x0)
# Sample $x_t$ from $q(x_t|x_0)$
xt = self.q_sample(x0, t, eps=noise)
# Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
eps_theta = self.eps_model(xt, t)
# MSE loss
return F.mse_loss(noise, eps_theta)
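For illustration, here is a hedged standalone sketch of the simplified loss above, with the schedule and loss reimplemented inline so it runs on its own, and a hypothetical one-layer `ToyEps` model standing in for the U-Net (`ddpm_loss` and `ToyEps` are names invented for this sketch, not part of the implementation):

```python
import torch
import torch.nn.functional as F
from torch import nn

# Inline version of the simplified DDPM loss (mirrors DenoiseDiffusion.loss):
# sample t, form x_t from x_0 and noise, and regress the model output on the noise.
def ddpm_loss(eps_model, x0, beta):
    alpha_bar = torch.cumprod(1. - beta, dim=0)
    b = x0.shape[0]
    t = torch.randint(0, len(beta), (b,), device=x0.device)
    noise = torch.randn_like(x0)
    ab = alpha_bar[t].view(b, 1, 1, 1)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    return F.mse_loss(noise, eps_model(xt, t))

# Toy noise-prediction model: one conv layer stands in for the U-Net
class ToyEps(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, t):
        # Ignore t in this toy model; the real U-Net embeds it
        return self.conv(x)

eps_model = ToyEps()
loss = ddpm_loss(eps_model, torch.randn(2, 1, 8, 8), torch.linspace(1e-4, 0.02, 100))
loss.backward()  # gradients flow to the toy model's parameters
```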

View File

@ -0,0 +1,327 @@
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling
summary: >
Code to generate samples from a trained
Denoising Diffusion Probabilistic Model.
---
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) evaluation/sampling
This is the code to generate images and create interpolations between given images.
"""
import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image, resize
from labml import experiment, monit
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
from labml_nn.diffusion.ddpm.experiment import Configs
class Sampler:
"""
## Sampler class
"""
def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
"""
* `diffusion` is the `DenoiseDiffusion` instance
* `image_channels` is the number of channels in the image
* `image_size` is the image size
* `device` is the device of the model
"""
self.device = device
self.image_size = image_size
self.image_channels = image_channels
self.diffusion = diffusion
# $T$
self.n_steps = diffusion.n_steps
# $\color{cyan}{\epsilon_\theta}(x_t, t)$
self.eps_model = diffusion.eps_model
# $\beta_t$
self.beta = diffusion.beta
# $\alpha_t$
self.alpha = diffusion.alpha
# $\bar\alpha_t$
self.alpha_bar = diffusion.alpha_bar
# $\bar\alpha_{t-1}$
alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])
# To calculate
# \begin{align}
# q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
# \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
# + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
# \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
# \end{align}
# $\tilde\beta_t$
self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
# $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$
self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
# $$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$$
self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
# $\sigma^2 = \beta$
self.sigma2 = self.beta
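The reparameterized mean used by `p_sample` below rests on an algebraic identity: substituting $x_0 = \big(x_t - \sqrt{1-\bar\alpha_t}\,\epsilon\big)/\sqrt{\bar\alpha_t}$ into $\tilde\mu_t(x_t, x_0)$ gives $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\epsilon\big)$. A small numeric sketch confirming this, assuming the same linear schedule and an arbitrary $t$:

```python
import torch

# Check: tilde_mu(x_t, x_0), with x_0 rewritten in terms of eps, equals
# (1/sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps).
beta = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float64)
alpha = 1. - beta
alpha_bar = torch.cumprod(alpha, dim=0)
t = 500
ab, ab_prev = alpha_bar[t], alpha_bar[t - 1]

xt = torch.randn(3, dtype=torch.float64)
eps = torch.randn(3, dtype=torch.float64)
# x_0 implied by x_t and eps via q(x_t|x_0)
x0 = (xt - (1 - ab).sqrt() * eps) / ab.sqrt()

# Posterior mean coefficients as in the constructor above
coef1 = beta[t] * ab_prev.sqrt() / (1 - ab)
coef2 = alpha[t].sqrt() * (1 - ab_prev) / (1 - ab)
mu_tilde = coef1 * x0 + coef2 * xt

# Reparameterized form used for sampling
mu_theta = (xt - (beta[t] / (1 - ab).sqrt()) * eps) / alpha[t].sqrt()
assert torch.allclose(mu_tilde, mu_theta)
```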
def show_image(self, img, title=""):
"""Helper function to display an image"""
img = img.clip(0, 1)
img = img.cpu().numpy()
plt.imshow(img.transpose(1, 2, 0))
plt.title(title)
plt.show()
def make_video(self, frames, path="video.mp4"):
"""Helper function to create a video"""
import imageio
# 20 second video
writer = imageio.get_writer(path, fps=len(frames) // 20)
# Add each image
for f in frames:
f = f.clip(0, 1)
f = to_pil_image(resize(f, [368, 368]))
writer.append_data(np.array(f))
#
writer.close()
def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
"""
#### Sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
We sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$ and at each step
show the estimate
$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
\Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
"""
# $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)
# Interval to log $\hat{x}_0$
interval = self.n_steps // n_frames
# Frames for video
frames = []
# Sample $T$ steps
for t_inv in monit.iterate('Denoise', self.n_steps):
# $t$
t_ = self.n_steps - t_inv - 1
# $t$ in a tensor
t = xt.new_full((1,), t_, dtype=torch.long)
# $\color{cyan}{\epsilon_\theta}(x_t, t)$
eps_theta = self.eps_model(xt, t)
if t_ % interval == 0:
# Get $\hat{x}_0$ and add to frames
x0 = self.p_x0(xt, t, eps_theta)
frames.append(x0[0])
if not create_video:
self.show_image(x0[0], f"{t_}")
# Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
xt = self.p_sample(xt, t, eps_theta)
# Make video
if create_video:
self.make_video(frames)
def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
"""
#### Interpolate two images $x_0$ and $x'_0$
We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.
Then interpolate to
$$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
Then get
$$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
* `x1` is $x_0$
* `x2` is $x'_0$
* `lambda_` is $\lambda$
* `t_` is $t$
"""
# Number of samples
n_samples = x1.shape[0]
# $t$ tensor
t = torch.full((n_samples,), t_, device=self.device)
# $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)
# $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
return self._sample_x0(xt, t_)
def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
create_video=True):
"""
#### Interpolate two images $x_0$ and $x'_0$ and make a video
* `x1` is $x_0$
* `x2` is $x'_0$
* `n_frames` is the number of frames in the animation
* `t_` is $t$
* `create_video` specifies whether to make a video or to show each frame
"""
# Show original images
self.show_image(x1, "x1")
self.show_image(x2, "x2")
# Add batch dimension
x1 = x1[None, :, :, :]
x2 = x2[None, :, :, :]
# $t$ tensor
t = torch.full((1,), t_, device=self.device)
# $x_t \sim q(x_t|x_0)$
x1t = self.diffusion.q_sample(x1, t)
# $x'_t \sim q(x'_t|x_0)$
x2t = self.diffusion.q_sample(x2, t)
frames = []
# Get frames with different $\lambda$
for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
# $\lambda$
lambda_ = i / n_frames
# $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
xt = (1 - lambda_) * x1t + lambda_ * x2t
# $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
x0 = self._sample_x0(xt, t_)
# Add to frames
frames.append(x0[0])
# Show frame
if not create_video:
self.show_image(x0[0], f"{lambda_ :.2f}")
# Make video
if create_video:
self.make_video(frames)
def _sample_x0(self, xt: torch.Tensor, n_steps: int):
"""
#### Sample an image using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
* `xt` is $x_t$
* `n_steps` is $t$
"""
# Number of samples
n_samples = xt.shape[0]
# Iterate until $t$ steps
for t_ in monit.iterate('Denoise', n_steps):
t = n_steps - t_ - 1
# Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))
# Return $x_0$
return xt
def sample(self, n_samples: int = 16):
"""
#### Generate images
"""
# $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)
# $$x_0 \sim \color{cyan}{p_\theta}(x_0|x_T)$$
x0 = self._sample_x0(xt, self.n_steps)
# Show images
for i in range(n_samples):
self.show_image(x0[i])
def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
"""
#### Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
\begin{align}
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
\color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
\color{cyan}{\mu_\theta}(x_t, t)
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
\end{align}
"""
# [gather](utils.html) $\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
# $\alpha_t$
alpha = gather(self.alpha, t)
# $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
# $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
# \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
# $\sigma^2$
var = gather(self.sigma2, t)
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
eps = torch.randn(xt.shape, device=xt.device)
# Sample
return mean + (var ** .5) * eps
def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
"""
#### Estimate $x_0$
$$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
\Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
"""
# [gather](utils.html) $\bar\alpha_t$
alpha_bar = gather(self.alpha_bar, t)
# $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
# \Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
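The estimate above is exact when the true noise is known: if $x_t$ is built from a known $x_0$ and $\epsilon$ via $q(x_t|x_0)$, inverting with that same $\epsilon$ recovers $x_0$. A hedged numeric sketch of this (standalone, using the same linear schedule as the model):

```python
import torch

# Build x_t from a known x_0 and eps, then invert with the x0-estimate formula:
# x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
beta = torch.linspace(0.0001, 0.02, 1000)
alpha_bar = torch.cumprod(1. - beta, dim=0)
t = 250

x0 = torch.randn(1, 3, 8, 8)
eps = torch.randn_like(x0)
xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps

x0_hat = (xt - (1 - alpha_bar[t]).sqrt() * eps) / alpha_bar[t].sqrt()
assert torch.allclose(x0_hat, x0, atol=1e-5)
```

During sampling the true $\epsilon$ is unknown, so the model's prediction $\color{cyan}{\epsilon_\theta}(x_t, t)$ is used instead, which is why $\hat{x}_0$ is only an estimate there.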
def main():
"""Generate samples"""
# Training experiment run UUID
run_uuid = "a44333ea251411ec8007d1a1762ed686"
# Start an evaluation
experiment.evaluate()
# Create configs
configs = Configs()
# Load custom configuration of the training run
configs_dict = experiment.load_configs(run_uuid)
# Set configurations
experiment.configs(configs, configs_dict)
# Initialize
configs.init()
# Set PyTorch modules for saving and loading
experiment.add_pytorch_models({'eps_model': configs.eps_model})
# Load training experiment
experiment.load(run_uuid)
# Create sampler
sampler = Sampler(diffusion=configs.diffusion,
image_channels=configs.image_channels,
image_size=configs.image_size,
device=configs.device)
# Start evaluation
with experiment.start():
# No gradients
with torch.no_grad():
# Sample an image with a denoising animation
sampler.sample_animation()
if False:
# Get some images from the data
data = next(iter(configs.data_loader)).to(configs.device)
# Create an interpolation animation
sampler.interpolate_animate(data[0], data[1])
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,249 @@
"""
---
title: Denoising Diffusion Probabilistic Models (DDPM) training
summary: >
Training code for
Denoising Diffusion Probabilistic Model.
---
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
This trains a DDPM based model on the CelebA HQ dataset. You can find the download instructions in this
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
Save the images inside [`data/celebA` folder](#dataset_path).
The paper used an exponential moving average of the model weights with a decay of $0.9999$. We have skipped this for
simplicity.
"""
from typing import List
import torch
import torch.utils.data
import torchvision
from PIL import Image
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_helpers.device import DeviceConfigs
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
class Configs(BaseConfigs):
"""
## Configurations
"""
# Device to train the model on.
# [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
# picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()
# U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$
eps_model: UNet
# [DDPM algorithm](index.html)
diffusion: DenoiseDiffusion
# Number of channels in the image. $3$ for RGB.
image_channels: int = 3
# Image size
image_size: int = 32
# Number of channels in the initial feature map
n_channels: int = 64
# The list of channel numbers at each resolution.
# The number of channels is `channel_multipliers[i] * n_channels`
channel_multipliers: List[int] = [1, 2, 2, 4]
# The list of booleans that indicate whether to use attention at each resolution
is_attention: List[int] = [False, False, False, True]
# Number of time steps $T$
n_steps: int = 1_000
# Batch size
batch_size: int = 64
# Number of samples to generate
n_samples: int = 16
# Learning rate
learning_rate: float = 2e-5
# Number of training epochs
epochs: int = 1_000
# Dataset
dataset: torch.utils.data.Dataset
# Dataloader
data_loader: torch.utils.data.DataLoader
# Adam optimizer
optimizer: torch.optim.Adam
def init(self):
# Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
self.eps_model = UNet(
image_channels=self.image_channels,
n_channels=self.n_channels,
ch_mults=self.channel_multipliers,
is_attn=self.is_attention,
).to(self.device)
# Create [DDPM class](index.html)
self.diffusion = DenoiseDiffusion(
eps_model=self.eps_model,
n_steps=self.n_steps,
device=self.device,
)
# Create dataloader
self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
# Create optimizer
self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
# Image logging
tracker.set_image("sample", True)
def sample(self):
"""
### Sample images
"""
with torch.no_grad():
# $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
device=self.device)
# Remove noise for $T$ steps
for t_ in monit.iterate('Sample', self.n_steps):
# $t$
t = self.n_steps - t_ - 1
# Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
# Log samples
tracker.save('sample', x)
def train(self):
"""
### Train
"""
# Iterate through the dataset
for data in monit.iterate('Train', self.data_loader):
# Increment global step
tracker.add_global_step()
# Move data to device
data = data.to(self.device)
# Make the gradients zero
self.optimizer.zero_grad()
# Calculate loss
loss = self.diffusion.loss(data)
# Compute gradients
loss.backward()
# Take an optimization step
self.optimizer.step()
# Track the loss
tracker.save('loss', loss)
def run(self):
"""
### Training loop
"""
for _ in monit.loop(self.epochs):
# Train the model
self.train()
# Sample some images
self.sample()
# New line in the console
tracker.new_line()
# Save the model
experiment.save_checkpoint()
class CelebADataset(torch.utils.data.Dataset):
"""
### CelebA HQ dataset
"""
def __init__(self, image_size: int):
super().__init__()
# CelebA images folder
folder = lab.get_data_path() / 'celebA'
# List of files
self._files = [p for p in folder.glob('**/*.jpg')]
# Transformations to resize the image and convert to tensor
self._transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_size),
torchvision.transforms.ToTensor(),
])
def __len__(self):
"""
Size of the dataset
"""
return len(self._files)
def __getitem__(self, index: int):
"""
Get an image
"""
img = Image.open(self._files[index])
return self._transform(img)
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
"""
Create CelebA dataset
"""
return CelebADataset(c.image_size)
class MNISTDataset(torchvision.datasets.MNIST):
"""
### MNIST dataset
"""
def __init__(self, image_size):
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize(image_size),
torchvision.transforms.ToTensor(),
])
super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)
def __getitem__(self, item):
return super().__getitem__(item)[0]
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
"""
Create MNIST dataset
"""
return MNISTDataset(c.image_size)
def main():
# Create experiment
experiment.create(name='diffuse')
# Create configurations
configs = Configs()
# Set configurations. You can override the defaults by passing the values in the dictionary.
experiment.configs(configs, {
})
# Initialize
configs.init()
# Set models for saving and loading
experiment.add_pytorch_models({'eps_model': configs.eps_model})
# Start and run the training loop
with experiment.start():
configs.run()
#
if __name__ == '__main__':
main()

View File

@ -0,0 +1,15 @@
# [Denoising Diffusion Probabilistic Models (DDPM)](https://nn.labml.ai/diffusion/ddpm/index.html)
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
[Denoising Diffusion Probabilistic Models](https://papers.labml.ai/paper/2006.11239).
In simple terms, we get an image from data and add noise step by step.
Then we train a model to predict that noise at each step and use the model to
generate images.
Here is the [UNet model](https://nn.labml.ai/diffusion/ddpm/unet.html) that predicts the noise and
[training code](https://nn.labml.ai/diffusion/ddpm/experiment.html).
[This file](https://nn.labml.ai/diffusion/ddpm/evaluate.html) can generate samples and interpolations
from a trained model.
[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686)


@ -0,0 +1,410 @@
"""
---
title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM)
summary: >
UNet model for Denoising Diffusion Probabilistic Models (DDPM)
---
# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)
This is a [U-Net](https://papers.labml.ai/paper/1505.04597) based model to predict noise
$\color{cyan}{\epsilon_\theta}(x_t, t)$.
U-Net gets its name from the U shape in the model diagram.
It processes a given image by progressively lowering (halving) the feature map resolution and then
increasing the resolution.
There are pass-through connections at each resolution.
![U-Net diagram from paper](unet.png)
This implementation contains a number of modifications to the original U-Net (residual blocks, multi-head attention)
and also adds time-step embeddings for $t$.
"""
import math
from typing import Optional, Tuple, Union, List
import torch
from torch import nn
from labml_helpers.module import Module
class Swish(Module):
"""
### Swish activation function
$$x \cdot \sigma(x)$$
"""
def forward(self, x):
return x * torch.sigmoid(x)
class TimeEmbedding(nn.Module):
"""
### Embeddings for $t$
"""
def __init__(self, n_channels: int):
"""
* `n_channels` is the number of dimensions in the embedding
"""
super().__init__()
self.n_channels = n_channels
# First linear layer
self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
# Activation
self.act = Swish()
# Second linear layer
self.lin2 = nn.Linear(self.n_channels, self.n_channels)
def forward(self, t: torch.Tensor):
# Create sinusoidal position embeddings
# [same as those from the transformer](../../transformers/positional_encoding.html)
# \begin{align}
# PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
# PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
# \end{align}
# where $d$ is `half_dim`
half_dim = self.n_channels // 8
emb = math.log(10_000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)
# Transform with the MLP
emb = self.act(self.lin1(emb))
emb = self.lin2(emb)
#
return emb
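The sinusoidal part of the forward pass above can be sketched in plain Python for a single time step (a minimal sketch with hypothetical values; the module additionally passes the result through the two linear layers):

```python
import math

def sinusoidal_embedding(t: float, half_dim: int):
    # Frequencies follow 10000^(-i / (half_dim - 1)) for i = 0 .. half_dim - 1
    scale = math.log(10_000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    # Concatenate the sine and cosine components: 2 * half_dim dimensions
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = sinusoidal_embedding(5.0, 4)  # 8 values: 4 sine followed by 4 cosine
```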
class ResidualBlock(Module):
"""
### Residual block
A residual block has two convolution layers with group normalization.
Each resolution is processed with two residual blocks.
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
"""
* `in_channels` is the number of input channels
* `out_channels` is the number of output channels
* `time_channels` is the number channels in the time step ($t$) embeddings
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
"""
super().__init__()
# Group normalization and the first convolution layer
self.norm1 = nn.GroupNorm(n_groups, in_channels)
self.act1 = Swish()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
# Group normalization and the second convolution layer
self.norm2 = nn.GroupNorm(n_groups, out_channels)
self.act2 = Swish()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
# If the number of input channels is not equal to the number of output channels we have to
# project the shortcut connection
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
else:
self.shortcut = nn.Identity()
# Linear layer for time embeddings
self.time_emb = nn.Linear(time_channels, out_channels)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size, time_channels]`
"""
# First convolution layer
h = self.conv1(self.act1(self.norm1(x)))
# Add time embeddings
h += self.time_emb(t)[:, :, None, None]
# Second convolution layer
h = self.conv2(self.act2(self.norm2(h)))
# Add the shortcut connection and return
return h + self.shortcut(x)
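The line `h += self.time_emb(t)[:, :, None, None]` broadcasts a per-channel time embedding over the spatial dimensions. A plain-Python sketch with tiny hypothetical shapes (`[batch=1, channels=1, 2, 2]`):

```python
h = [[[[1.0, 1.0], [1.0, 1.0]]]]  # feature map, shape [1, 1, 2, 2]
t_emb = [[0.5]]                   # time embedding, shape [1, 1]
# Add t_emb[b][c] to every spatial position of channel c in sample b
out = [[[[v + t_emb[b][c] for v in row] for row in ch]
        for c, ch in enumerate(img)] for b, img in enumerate(h)]
```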
class AttentionBlock(Module):
"""
### Attention block
This is similar to [transformer multi-head attention](../../transformers/mha.html).
"""
def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
"""
* `n_channels` is the number of channels in the input
* `n_heads` is the number of heads in multi-head attention
* `d_k` is the number of dimensions in each head
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
"""
super().__init__()
# Default `d_k`
if d_k is None:
d_k = n_channels
# Normalization layer
self.norm = nn.GroupNorm(n_groups, n_channels)
# Projections for query, key and values
self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
# Linear layer for final transformation
self.output = nn.Linear(n_heads * d_k, n_channels)
# Scale for dot-product attention
self.scale = d_k ** -0.5
#
self.n_heads = n_heads
self.d_k = d_k
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size, time_channels]`
"""
# `t` is not used. It is kept in the arguments so that the function
# signature matches that of `ResidualBlock`.
_ = t
# Get shape
batch_size, n_channels, height, width = x.shape
# Change `x` to shape `[batch_size, seq, n_channels]`
x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
# Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
# Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
q, k, v = torch.chunk(qkv, 3, dim=-1)
# Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
# Softmax along the sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
attn = attn.softmax(dim=2)
# Multiply by values
res = torch.einsum('bijh,bjhd->bihd', attn, v)
# Reshape to `[batch_size, seq, n_heads * d_k]`
res = res.view(batch_size, -1, self.n_heads * self.d_k)
# Transform to `[batch_size, seq, n_channels]`
res = self.output(res)
# Add skip connection
res += x
# Change to shape `[batch_size, in_channels, height, width]`
res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
#
return res
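The scaled dot-product step can be illustrated without tensors. A stdlib sketch for one head, sequence length 2, and $d_k = 2$ (all values hypothetical):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
scale = 2 ** -0.5  # d_k ** -0.5
# Attention weights: softmax over the key positions for each query
attn = [softmax([scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k])
        for q_row in q]
# Weighted sum of the values for each query position
out = [[sum(w * v_row[d] for w, v_row in zip(a_row, v)) for d in range(2)]
       for a_row in attn]
```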
class DownBlock(Module):
"""
### Down block
This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution.
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
super().__init__()
self.res = ResidualBlock(in_channels, out_channels, time_channels)
if has_attn:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res(x, t)
x = self.attn(x)
return x
class UpBlock(Module):
"""
### Up block
This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution.
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
super().__init__()
# The input has `in_channels + out_channels` because we concatenate the output of the same resolution
# from the first half of the U-Net
self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
if has_attn:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res(x, t)
x = self.attn(x)
return x
class MiddleBlock(Module):
"""
### Middle block
It combines a `ResidualBlock` and an `AttentionBlock`, followed by another `ResidualBlock`.
This block is applied at the lowest resolution of the U-Net.
"""
def __init__(self, n_channels: int, time_channels: int):
super().__init__()
self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
self.attn = AttentionBlock(n_channels)
self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res1(x, t)
x = self.attn(x)
x = self.res2(x, t)
return x
class Upsample(nn.Module):
"""
### Scale up the feature map by $2 \times$
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
# `t` is not used. It is kept in the arguments so that the function
# signature matches that of `ResidualBlock`.
_ = t
return self.conv(x)
class Downsample(nn.Module):
"""
### Scale down the feature map by $\frac{1}{2} \times$
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
# `t` is not used. It is kept in the arguments so that the function
# signature matches that of `ResidualBlock`.
_ = t
return self.conv(x)
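The kernel, stride, and padding choices above follow the standard convolution size arithmetic; `Downsample` halves the spatial size and `Upsample` inverts it. A quick sketch (assumed sizes for illustration):

```python
def conv_out(n: int, k: int, s: int, p: int) -> int:
    # Convolution output size: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

def conv_transpose_out(n: int, k: int, s: int, p: int) -> int:
    # Transposed convolution inverts it: (n - 1) * s - 2p + k
    return (n - 1) * s - 2 * p + k

down = conv_out(32, k=3, s=2, p=1)            # Downsample: 32 -> 16
up = conv_transpose_out(16, k=4, s=2, p=1)    # Upsample: 16 -> 32
```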
class UNet(Module):
"""
## U-Net
"""
def __init__(self, image_channels: int = 3, n_channels: int = 64,
ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
n_blocks: int = 2):
"""
* `image_channels` is the number of channels in the image. $3$ for RGB.
* `n_channels` is the number of channels in the initial feature map that we transform the image into
* `ch_mults` is the list of channel multipliers at each resolution; the number of channels is multiplied by `ch_mults[i]` going into resolution `i`
* `is_attn` is a list of booleans that indicate whether to use attention at each resolution
* `n_blocks` is the number of `DownBlock`s (and `UpBlock`s) at each resolution
"""
super().__init__()
# Number of resolutions
n_resolutions = len(ch_mults)
# Project image into feature map
self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
# Time embedding layer. Time embedding has `n_channels * 4` channels
self.time_emb = TimeEmbedding(n_channels * 4)
# #### First half of U-Net - decreasing resolution
down = []
# Number of channels
out_channels = in_channels = n_channels
# For each resolution
for i in range(n_resolutions):
# Number of output channels at this resolution
out_channels = in_channels * ch_mults[i]
# Add `n_blocks`
for _ in range(n_blocks):
down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
# Down sample at all resolutions except the last
if i < n_resolutions - 1:
down.append(Downsample(in_channels))
# Combine the set of modules
self.down = nn.ModuleList(down)
# Middle block
self.middle = MiddleBlock(out_channels, n_channels * 4)
# #### Second half of U-Net - increasing resolution
up = []
# Number of channels
in_channels = out_channels
# For each resolution
for i in reversed(range(n_resolutions)):
# `n_blocks` at the same resolution
out_channels = in_channels
for _ in range(n_blocks):
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
# Final block to reduce the number of channels
out_channels = in_channels // ch_mults[i]
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
# Up sample at all resolutions except last
if i > 0:
up.append(Upsample(in_channels))
# Combine the set of modules
self.up = nn.ModuleList(up)
# Final normalization and convolution layer
self.norm = nn.GroupNorm(8, n_channels)
self.act = Swish()
self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
* `x` has shape `[batch_size, in_channels, height, width]`
* `t` has shape `[batch_size]`
"""
# Get time-step embeddings
t = self.time_emb(t)
# Get image projection
x = self.image_proj(x)
# `h` will store outputs at each resolution for skip connection
h = [x]
# First half of U-Net
for m in self.down:
x = m(x, t)
h.append(x)
# Middle (bottom)
x = self.middle(x, t)
# Second half of U-Net
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
else:
# Get the skip connection from first half of U-Net and concatenate
s = h.pop()
x = torch.cat((x, s), dim=1)
#
x = m(x, t)
# Final normalization and convolution
return self.final(self.act(self.norm(x)))
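Note that `out_channels = in_channels * ch_mults[i]` compounds across resolutions, so the per-resolution channel counts can be traced with a few lines (a sketch of the bookkeeping in `__init__` for the default configuration):

```python
n_channels = 64
ch_mults = (1, 2, 2, 4)
channels, c = [], n_channels
for m in ch_mults:
    # Each resolution multiplies the running channel count by its multiplier
    c = c * m
    channels.append(c)
# channels grows multiplicatively: [64, 128, 256, 1024]
```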


@ -0,0 +1,16 @@
"""
---
title: Utility functions for DDPM experiment
summary: >
Utility functions for DDPM experiment
---
# Utility functions for [DDPM](index.html) experiment
"""
import torch.utils.data
def gather(consts: torch.Tensor, t: torch.Tensor):
"""Gather consts for $t$ and reshape to feature map shape"""
c = consts.gather(-1, t)
return c.reshape(-1, 1, 1, 1)
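What `gather` does, in plain Python (hypothetical values): pick one constant per sample according to its time step, then reshape to `[batch_size, 1, 1, 1]` so it broadcasts against feature maps of shape `[batch_size, C, H, W]`.

```python
consts = [0.1, 0.2, 0.3, 0.4]    # e.g. a schedule over T = 4 steps
t = [2, 0, 3]                    # a time step per sample in the batch
picked = [consts[i] for i in t]  # one constant per sample
```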


@ -57,6 +57,11 @@ implementations almost weekly.
* [Wasserstein GAN with Gradient Penalty](https://nn.labml.ai/gan/wasserstein/gradient_penalty/index.html)
* [StyleGAN 2](https://nn.labml.ai/gan/stylegan/index.html)
#### ✨ [Diffusion models](https://nn.labml.ai/diffusion/index.html)
* [Denoising Diffusion Probabilistic Models (DDPM)](https://nn.labml.ai/diffusion/ddpm/index.html)
#### ✨ [Sketch RNN](https://nn.labml.ai/sketch_rnn/index.html)
#### ✨ Graph Neural Networks