mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-08-26 16:50:39 +08:00
Denoising Diffusion Probabilistic Models (#98)
This commit is contained in:
1198
docs/diffusion/ddpm/evaluate.html
Normal file
1198
docs/diffusion/ddpm/evaluate.html
Normal file
File diff suppressed because it is too large
Load Diff
945
docs/diffusion/ddpm/experiment.html
Normal file
945
docs/diffusion/ddpm/experiment.html
Normal file
@ -0,0 +1,945 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="Training code for Denoising Diffusion Probabilistic Model."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
|
||||
<meta name="twitter:description" content="Training code for Denoising Diffusion Probabilistic Model."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/experiment.html"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM) training"/>
|
||||
<meta property="og:description" content="Training code for Denoising Diffusion Probabilistic Model."/>
|
||||
|
||||
<title>Denoising Diffusion Probabilistic Models (DDPM) training</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/experiment.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">diffusion</a>
|
||||
<a class="parent" href="index.html">ddpm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/experiment.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="index.html">Denoising Diffusion Probabilistic Models (DDPM)</a> training</h1>
|
||||
<p>This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this
|
||||
<a href="https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3">discussion on fast.ai</a>.
|
||||
Save the images inside <a href="#dataset_path"><code>data/celebA</code> folder</a>.</p>
|
||||
<p>The paper used an exponential moving average of the model with a decay of $0.9999$. We have skipped this for
|
||||
simplicity.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">18</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
|
||||
<span class="lineno">19</span>
|
||||
<span class="lineno">20</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">21</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">22</span><span class="kn">import</span> <span class="nn">torchvision</span>
|
||||
<span class="lineno">23</span><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
|
||||
<span class="lineno">24</span>
|
||||
<span class="lineno">25</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">lab</span><span class="p">,</span> <span class="n">tracker</span><span class="p">,</span> <span class="n">experiment</span><span class="p">,</span> <span class="n">monit</span>
|
||||
<span class="lineno">26</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">BaseConfigs</span><span class="p">,</span> <span class="n">option</span>
|
||||
<span class="lineno">27</span><span class="kn">from</span> <span class="nn">labml_helpers.device</span> <span class="kn">import</span> <span class="n">DeviceConfigs</span>
|
||||
<span class="lineno">28</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm</span> <span class="kn">import</span> <span class="n">DenoiseDiffusion</span>
|
||||
<span class="lineno">29</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm.unet</span> <span class="kn">import</span> <span class="n">UNet</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Configurations</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">32</span><span class="k">class</span> <span class="nc">Configs</span><span class="p">(</span><span class="n">BaseConfigs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<p>Device to train the model on.
|
||||
<a href="https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs"><code>DeviceConfigs</code></a>
|
||||
picks up an available CUDA device or defaults to CPU.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">39</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">DeviceConfigs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
<p>U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">42</span> <span class="n">eps_model</span><span class="p">:</span> <span class="n">UNet</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p><a href="index.html">DDPM algorithm</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">44</span> <span class="n">diffusion</span><span class="p">:</span> <span class="n">DenoiseDiffusion</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>Number of channels in the image. $3$ for RGB.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">47</span> <span class="n">image_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>Image size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">49</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>Number of channels in the initial feature map</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">51</span> <span class="n">n_channels</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>The list of channel numbers at each resolution.
|
||||
The number of channels is <code>channel_multipliers[i] * n_channels</code></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">54</span> <span class="n">channel_multipliers</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<p>The list of booleans that indicate whether to use attention at each resolution</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">56</span> <span class="n">is_attention</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">True</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p>Number of time steps $T$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">59</span> <span class="n">n_steps</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1_000</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>Batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">61</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">64</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
<p>Number of samples to generate</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">63</span> <span class="n">n_samples</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<p>Learning rate</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">learning_rate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">2e-5</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>Number of training epochs</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">68</span> <span class="n">epochs</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1_000</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>Dataset</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">71</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Dataloader</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">73</span> <span class="n">data_loader</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<p>Adam optimizer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">76</span> <span class="n">optimizer</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">78</span> <span class="k">def</span> <span class="nf">init</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p>Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">80</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span> <span class="o">=</span> <span class="n">UNet</span><span class="p">(</span>
|
||||
<span class="lineno">81</span> <span class="n">image_channels</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_channels</span><span class="p">,</span>
|
||||
<span class="lineno">82</span> <span class="n">n_channels</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_channels</span><span class="p">,</span>
|
||||
<span class="lineno">83</span> <span class="n">ch_mults</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">channel_multipliers</span><span class="p">,</span>
|
||||
<span class="lineno">84</span> <span class="n">is_attn</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">is_attention</span><span class="p">,</span>
|
||||
<span class="lineno">85</span> <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>Create <a href="index.html">DDPM class</a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">88</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span> <span class="o">=</span> <span class="n">DenoiseDiffusion</span><span class="p">(</span>
|
||||
<span class="lineno">89</span> <span class="n">eps_model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">,</span>
|
||||
<span class="lineno">90</span> <span class="n">n_steps</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">,</span>
|
||||
<span class="lineno">91</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span>
|
||||
<span class="lineno">92</span> <span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>Create dataloader</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">95</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">pin_memory</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>Create optimizer</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">97</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">learning_rate</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>Image logging</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">100</span> <span class="n">tracker</span><span class="o">.</span><span class="n">set_image</span><span class="p">(</span><span class="s2">"sample"</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<h3>Sample images</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">102</span> <span class="k">def</span> <span class="nf">sample</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">106</span> <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<p>$x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">108</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_channels</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">],</span>
|
||||
<span class="lineno">109</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<p>Remove noise for $T$ steps</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">112</span> <span class="k">for</span> <span class="n">t_</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Sample'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>$t$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">114</span> <span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span> <span class="o">-</span> <span class="n">t_</span> <span class="o">-</span> <span class="mi">1</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">116</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span><span class="o">.</span><span class="n">p_sample</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">new_full</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span><span class="p">,),</span> <span class="n">t</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">))</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Log samples</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">119</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'sample'</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<h3>Train</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">121</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>Iterate through the dataset</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">127</span> <span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">iterate</span><span class="p">(</span><span class="s1">'Train'</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_loader</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-33'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-33'>#</a>
|
||||
</div>
|
||||
<p>Increment global step</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">129</span> <span class="n">tracker</span><span class="o">.</span><span class="n">add_global_step</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-34'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-34'>#</a>
|
||||
</div>
|
||||
<p>Move data to device</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">131</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-35'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-35'>#</a>
|
||||
</div>
|
||||
<p>Make the gradients zero</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">134</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-36'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-36'>#</a>
|
||||
</div>
|
||||
<p>Calculate loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">136</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">diffusion</span><span class="o">.</span><span class="n">loss</span><span class="p">(</span><span class="n">data</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-37'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-37'>#</a>
|
||||
</div>
|
||||
<p>Compute gradients</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">138</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-38'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-38'>#</a>
|
||||
</div>
|
||||
<p>Take an optimization step</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">140</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-39'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-39'>#</a>
|
||||
</div>
|
||||
<p>Track the loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">142</span> <span class="n">tracker</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'loss'</span><span class="p">,</span> <span class="n">loss</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-40'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-40'>#</a>
|
||||
</div>
|
||||
<h3>Training loop</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">144</span> <span class="k">def</span> <span class="nf">run</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-41'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-41'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">148</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">monit</span><span class="o">.</span><span class="n">loop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">epochs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-42'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-42'>#</a>
|
||||
</div>
|
||||
<p>Train the model</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">150</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-43'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-43'>#</a>
|
||||
</div>
|
||||
<p>Sample some images</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">152</span> <span class="bp">self</span><span class="o">.</span><span class="n">sample</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-44'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-44'>#</a>
|
||||
</div>
|
||||
<p>New line in the console</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">154</span> <span class="n">tracker</span><span class="o">.</span><span class="n">new_line</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-45'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-45'>#</a>
|
||||
</div>
|
||||
<p>Save the model</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">156</span> <span class="n">experiment</span><span class="o">.</span><span class="n">save_checkpoint</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-46'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-46'>#</a>
|
||||
</div>
|
||||
<h3>CelebA HQ dataset</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">159</span><span class="k">class</span> <span class="nc">CelebADataset</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-47'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-47'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">164</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
|
||||
<span class="lineno">165</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-48'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-48'>#</a>
|
||||
</div>
|
||||
<p>CelebA images folder</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">folder</span> <span class="o">=</span> <span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()</span> <span class="o">/</span> <span class="s1">'celebA'</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-49'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-49'>#</a>
|
||||
</div>
|
||||
<p>List of files</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">170</span> <span class="bp">self</span><span class="o">.</span><span class="n">_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">p</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">folder</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="sa">f</span><span class="s1">'**/*.jpg'</span><span class="p">)]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-50'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-50'>#</a>
|
||||
</div>
|
||||
<p>Transformations to resize the image and convert to tensor</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">173</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
|
||||
<span class="lineno">174</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span>
|
||||
<span class="lineno">175</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
|
||||
<span class="lineno">176</span> <span class="p">])</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-51'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-51'>#</a>
|
||||
</div>
|
||||
<p>Size of the dataset</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">178</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-52'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-52'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">182</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_files</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-53'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-53'>#</a>
|
||||
</div>
|
||||
<p>Get an image</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">184</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-54'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-54'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">188</span> <span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_files</span><span class="p">[</span><span class="n">index</span><span class="p">])</span>
|
||||
<span class="lineno">189</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_transform</span><span class="p">(</span><span class="n">img</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-55'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-55'>#</a>
|
||||
</div>
|
||||
<p>Create CelebA dataset</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">192</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="s1">'CelebA'</span><span class="p">)</span>
|
||||
<span class="lineno">193</span><span class="k">def</span> <span class="nf">celeb_dataset</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-56'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-56'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">197</span> <span class="k">return</span> <span class="n">CelebADataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-57'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-57'>#</a>
|
||||
</div>
|
||||
<h3>MNIST dataset</h3>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">200</span><span class="k">class</span> <span class="nc">MNISTDataset</span><span class="p">(</span><span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-58'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-58'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">205</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">):</span>
|
||||
<span class="lineno">206</span> <span class="n">transform</span> <span class="o">=</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
|
||||
<span class="lineno">207</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">image_size</span><span class="p">),</span>
|
||||
<span class="lineno">208</span> <span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
|
||||
<span class="lineno">209</span> <span class="p">])</span>
|
||||
<span class="lineno">210</span>
|
||||
<span class="lineno">211</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">lab</span><span class="o">.</span><span class="n">get_data_path</span><span class="p">()),</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">transform</span><span class="o">=</span><span class="n">transform</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-59'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-59'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">213</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">item</span><span class="p">):</span>
|
||||
<span class="lineno">214</span> <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__getitem__</span><span class="p">(</span><span class="n">item</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-60'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-60'>#</a>
|
||||
</div>
|
||||
<p>Create MNIST dataset</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">217</span><span class="nd">@option</span><span class="p">(</span><span class="n">Configs</span><span class="o">.</span><span class="n">dataset</span><span class="p">,</span> <span class="s1">'MNIST'</span><span class="p">)</span>
|
||||
<span class="lineno">218</span><span class="k">def</span> <span class="nf">mnist_dataset</span><span class="p">(</span><span class="n">c</span><span class="p">:</span> <span class="n">Configs</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-61'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-61'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">222</span> <span class="k">return</span> <span class="n">MNISTDataset</span><span class="p">(</span><span class="n">c</span><span class="o">.</span><span class="n">image_size</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-62'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-62'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">225</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-63'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-63'>#</a>
|
||||
</div>
|
||||
<p>Create experiment</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">227</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'diffuse'</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-64'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-64'>#</a>
|
||||
</div>
|
||||
<p>Create configurations</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">230</span> <span class="n">configs</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-65'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-65'>#</a>
|
||||
</div>
|
||||
<p>Set configurations. You can override the defaults by passing the values in the dictionary.</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">233</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">configs</span><span class="p">,</span> <span class="p">{</span>
|
||||
<span class="lineno">234</span> <span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-66'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-66'>#</a>
|
||||
</div>
|
||||
<p>Initialize</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">237</span> <span class="n">configs</span><span class="o">.</span><span class="n">init</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-67'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-67'>#</a>
|
||||
</div>
|
||||
<p>Set models for saving and loading</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">240</span> <span class="n">experiment</span><span class="o">.</span><span class="n">add_pytorch_models</span><span class="p">({</span><span class="s1">'eps_model'</span><span class="p">:</span> <span class="n">configs</span><span class="o">.</span><span class="n">eps_model</span><span class="p">})</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-68'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-68'>#</a>
|
||||
</div>
|
||||
<p>Start and run the training loop</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">243</span> <span class="k">with</span> <span class="n">experiment</span><span class="o">.</span><span class="n">start</span><span class="p">():</span>
|
||||
<span class="lineno">244</span> <span class="n">configs</span><span class="o">.</span><span class="n">run</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-69'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-69'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">248</span><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">'__main__'</span><span class="p">:</span>
|
||||
<span class="lineno">249</span> <span class="n">main</span><span class="p">()</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
666
docs/diffusion/ddpm/index.html
Normal file
666
docs/diffusion/ddpm/index.html
Normal file
@ -0,0 +1,666 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta name="twitter:description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/index.html"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta property="og:description" content="PyTorch implementation and tutorial of the paper Denoising Diffusion Probabilistic Models (DDPM)."/>
|
||||
|
||||
<title>Denoising Diffusion Probabilistic Models (DDPM)</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">diffusion</a>
|
||||
<a class="parent" href="index.html">ddpm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Denoising Diffusion Probabilistic Models (DDPM)</h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
|
||||
<a href="https://papers.labml.ai/paper/2006.11239">Denoising Diffusion Probabilistic Models</a>.</p>
|
||||
<p>In simple terms, we get an image from data and add noise step by step.
|
||||
Then we train a model to predict that noise at each step and use the model to
|
||||
generate images.</p>
|
||||
<p>The following definitions and derivations show how this works.
|
||||
For details please refer to <a href="https://papers.labml.ai/paper/2006.11239">the paper</a>.</p>
|
||||
<h2>Forward Process</h2>
|
||||
<p>The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1- \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
|
||||
q(x_{1:T} | x_0) = \prod_{t = 1}^{T} q(x_t | x_{t-1})
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>where $\beta_1, \dots, \beta_T$ is the variance schedule.</p>
|
||||
<p>We can sample $x_t$ at any timestep $t$ with,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$</p>
|
||||
<h2>Reverse Process</h2>
|
||||
<p>The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
|
||||
for $T$ time steps.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
|
||||
\color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big) \\
|
||||
\color{cyan}{p_\theta}(x_{0:T}) &= \color{cyan}{p_\theta}(x_T) \prod_{t = 1}^{T} \color{cyan}{p_\theta}(x_{t-1} | x_t) \\
|
||||
\color{cyan}{p_\theta}(x_0) &= \int \color{cyan}{p_\theta}(x_{0:T}) dx_{1:T}
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>$\color{cyan}\theta$ are the parameters we train.</p>
|
||||
<h2>Loss</h2>
|
||||
<p>We optimize the ELBO (from Jensen’s inequality) on the negative log likelihood.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\mathbb{E}[-\log \color{cyan}{p_\theta}(x_0)]
|
||||
&\le \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
|
||||
&=L
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>The loss can be rewritten as follows.</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
L
|
||||
&= \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
|
||||
&= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
|
||||
&= \mathbb{E}_q [
|
||||
-\log \frac{p(x_T)}{q(x_T|x_0)}
|
||||
-\sum_{t=2}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
|
||||
-\log \color{cyan}{p_\theta}(x_0|x_1)] \\
|
||||
&= \mathbb{E}_q [
|
||||
D_{KL}(q(x_T|x_0) \Vert p(x_T))
|
||||
+\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))
|
||||
-\log \color{cyan}{p_\theta}(x_0|x_1)]
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.</p>
|
||||
<h3>Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))$</h3>
|
||||
<p>The forward process posterior conditioned by $x_0$ is,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
|
||||
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
|
||||
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
|
||||
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
|
||||
$\beta_t$ or $\tilde\beta_t$.</p>
|
||||
<p>Then,
|
||||
<script type="math/tex; mode=display">\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big)</script>
|
||||
</p>
|
||||
<p>For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
|
||||
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\epsilon\Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>This gives,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
L_{t-1}
|
||||
&= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)) \\
|
||||
&= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
|
||||
\Big \Vert \tilde\mu(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big \Vert^2 \Bigg] \\
|
||||
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
|
||||
\bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
|
||||
x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
|
||||
\Big) - \color{cyan}{\mu_\theta}(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg] \\
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>Re-parameterizing with a model to predict noise</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\color{cyan}{\mu_\theta}(x_t, t) &= \tilde\mu \bigg(x_t,
|
||||
\frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
|
||||
\sqrt{1-\bar\alpha_t}\color{cyan}{\epsilon_\theta}(x_t, t) \Big) \bigg) \\
|
||||
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>where $\color{cyan}{\epsilon_\theta}$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.</p>
|
||||
<p>This gives,</p>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
L_{t-1}
|
||||
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
|
||||
\Big\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\Big\Vert^2 \Bigg]
|
||||
\end{align}</script>
|
||||
</p>
|
||||
<p>That is, we are training to predict the noise.</p>
|
||||
<h3>Simplified loss</h3>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\bigg\Vert^2 \Bigg]</script>
|
||||
</p>
|
||||
<p>This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$ discarding the
|
||||
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
|
||||
increases the weight given to higher $t$ (which have higher noise levels), therefore increasing the sample quality.</p>
|
||||
<p>This file implements the loss calculation and a basic sampling method that we use to generate images during
|
||||
training.</p>
|
||||
<p>Here is the <a href="unet.html">UNet model</a> that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and
|
||||
<a href="experiment.html">training code</a>.
|
||||
<a href="evaluate.html">This file</a> can generate samples and interpolations from a trained model.</p>
|
||||
<p><a href="https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">162</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Optional</span>
|
||||
<span class="lineno">163</span>
|
||||
<span class="lineno">164</span><span class="kn">import</span> <span class="nn">torch</span>
|
||||
<span class="lineno">165</span><span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="nn">F</span>
|
||||
<span class="lineno">166</span><span class="kn">import</span> <span class="nn">torch.utils.data</span>
|
||||
<span class="lineno">167</span><span class="kn">from</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="n">nn</span>
|
||||
<span class="lineno">168</span>
|
||||
<span class="lineno">169</span><span class="kn">from</span> <span class="nn">labml_nn.diffusion.ddpm.utils</span> <span class="kn">import</span> <span class="n">gather</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<h2>Denoise Diffusion</h2>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">172</span><span class="k">class</span> <span class="nc">DenoiseDiffusion</span><span class="p">:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
<ul>
|
||||
<li><code>eps_model</code> is $\color{cyan}{\epsilon_\theta}(x_t, t)$ model</li>
|
||||
<li><code>n_steps</code> is $T$</li>
|
||||
<li><code>device</code> is the device to place constants on</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">177</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">eps_model</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">device</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-3'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-3'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">183</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
|
||||
<span class="lineno">184</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span> <span class="o">=</span> <span class="n">eps_model</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-4'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-4'>#</a>
|
||||
</div>
|
||||
<p>Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">187</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.0001</span><span class="p">,</span> <span class="mf">0.02</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-5'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-5'>#</a>
|
||||
</div>
|
||||
<p>$\alpha_t = 1 - \beta_t$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">190</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-6'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-6'>#</a>
|
||||
</div>
|
||||
<p>$\bar\alpha_t = \prod_{s=1}^t \alpha_s$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">192</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-7'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-7'>#</a>
|
||||
</div>
|
||||
<p>$T$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">194</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span> <span class="o">=</span> <span class="n">n_steps</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-8'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-8'>#</a>
|
||||
</div>
|
||||
<p>$\sigma^2 = \beta$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">196</span> <span class="bp">self</span><span class="o">.</span><span class="n">sigma2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-9'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-9'>#</a>
|
||||
</div>
|
||||
<h4>Get $q(x_t|x_0)$ distribution</h4>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">198</span> <span class="k">def</span> <span class="nf">q_xt_x0</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">Tuple</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]:</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-10'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-10'>#</a>
|
||||
</div>
|
||||
<p><a href="utils.html">gather</a> $\alpha_t$ and compute $\sqrt{\bar\alpha_t} x_0$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">208</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span> <span class="o">**</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">x0</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-11'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-11'>#</a>
|
||||
</div>
|
||||
<p>$(1-\bar\alpha_t) \mathbf{I}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">210</span> <span class="n">var</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-12'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-12'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">212</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">var</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-13'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-13'>#</a>
|
||||
</div>
|
||||
<h4>Sample from $q(x_t|x_0)$</h4>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">214</span> <span class="k">def</span> <span class="nf">q_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">eps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-14'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-14'>#</a>
|
||||
</div>
|
||||
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">224</span> <span class="k">if</span> <span class="n">eps</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">225</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-15'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-15'>#</a>
|
||||
</div>
|
||||
<p>get $q(x_t|x_0)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">228</span> <span class="n">mean</span><span class="p">,</span> <span class="n">var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_xt_x0</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-16'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-16'>#</a>
|
||||
</div>
|
||||
<p>Sample from $q(x_t|x_0)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">230</span> <span class="k">return</span> <span class="n">mean</span> <span class="o">+</span> <span class="p">(</span><span class="n">var</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="n">eps</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-17'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-17'>#</a>
|
||||
</div>
|
||||
<h4>Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$</h4>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\begin{align}
|
||||
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
|
||||
\color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
|
||||
\color{cyan}{\mu_\theta}(x_t, t)
|
||||
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
|
||||
\end{align}</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">232</span> <span class="k">def</span> <span class="nf">p_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">xt</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-18'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-18'>#</a>
|
||||
</div>
|
||||
<p>$\color{cyan}{\epsilon_\theta}(x_t, t)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">246</span> <span class="n">eps_theta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-19'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-19'>#</a>
|
||||
</div>
|
||||
<p><a href="utils.html">gather</a> $\bar\alpha_t$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">248</span> <span class="n">alpha_bar</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha_bar</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-20'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-20'>#</a>
|
||||
</div>
|
||||
<p>$\alpha_t$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">250</span> <span class="n">alpha</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">alpha</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-21'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-21'>#</a>
|
||||
</div>
|
||||
<p>$\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">252</span> <span class="n">eps_coef</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha_bar</span><span class="p">)</span> <span class="o">**</span> <span class="o">.</span><span class="mi">5</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-22'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-22'>#</a>
|
||||
</div>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">255</span> <span class="n">mean</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">**</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">xt</span> <span class="o">-</span> <span class="n">eps_coef</span> <span class="o">*</span> <span class="n">eps_theta</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-23'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-23'>#</a>
|
||||
</div>
|
||||
<p>$\sigma^2$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">257</span> <span class="n">var</span> <span class="o">=</span> <span class="n">gather</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sigma2</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-24'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-24'>#</a>
|
||||
</div>
|
||||
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">260</span> <span class="n">eps</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">xt</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">xt</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-25'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-25'>#</a>
|
||||
</div>
|
||||
<p>Sample</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">262</span> <span class="k">return</span> <span class="n">mean</span> <span class="o">+</span> <span class="p">(</span><span class="n">var</span> <span class="o">**</span> <span class="o">.</span><span class="mi">5</span><span class="p">)</span> <span class="o">*</span> <span class="n">eps</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-26'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-26'>#</a>
|
||||
</div>
|
||||
<h4>Simplified Loss</h4>
|
||||
<p>
|
||||
<script type="math/tex; mode=display">L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\bigg\Vert^2 \Bigg]</script>
|
||||
</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">264</span> <span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x0</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">noise</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-27'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-27'>#</a>
|
||||
</div>
|
||||
<p>Get batch size</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">273</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">x0</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-28'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-28'>#</a>
|
||||
</div>
|
||||
<p>Get random $t$ for each sample in the batch</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">275</span> <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_steps</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,),</span> <span class="n">device</span><span class="o">=</span><span class="n">x0</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">long</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-29'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-29'>#</a>
|
||||
</div>
|
||||
<p>$\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">278</span> <span class="k">if</span> <span class="n">noise</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
||||
<span class="lineno">279</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">x0</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-30'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-30'>#</a>
|
||||
</div>
|
||||
<p>Sample $x_t$ from $q(x_t|x_0)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">282</span> <span class="n">xt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_sample</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="n">noise</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-31'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-31'>#</a>
|
||||
</div>
|
||||
<p>Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">284</span> <span class="n">eps_theta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps_model</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-32'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-32'>#</a>
|
||||
</div>
|
||||
<p>MSE loss</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">287</span> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">noise</span><span class="p">,</span> <span class="n">eps_theta</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
150
docs/diffusion/ddpm/readme.html
Normal file
150
docs/diffusion/ddpm/readme.html
Normal file
@ -0,0 +1,150 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/readme.html"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Denoising Diffusion Probabilistic Models (DDPM)"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>Denoising Diffusion Probabilistic Models (DDPM)</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/readme.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">diffusion</a>
|
||||
<a class="parent" href="index.html">ddpm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/readme.md">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="https://nn.labml.ai/diffusion/ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation/tutorial of the paper
|
||||
<a href="https://papers.labml.ai/paper/2006.11239">Denoising Diffusion Probabilistic Models</a>.</p>
|
||||
<p>In simple terms, we get an image from data and add noise step by step.
|
||||
Then we train a model to predict that noise at each step and use the model to
|
||||
generate images.</p>
|
||||
<p>Here is the <a href="https://nn.labml.ai/diffusion/ddpm/unet.html">UNet model</a> that predicts the noise and
|
||||
<a href="https://nn.labml.ai/diffusion/ddpm/experiment.html">training code</a>.
|
||||
<a href="https://nn.labml.ai/diffusion/ddpm/evaluate.html">This file</a> can generate samples and interpolations
|
||||
from a trained model.</p>
|
||||
<p><a href="https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
1336
docs/diffusion/ddpm/unet.html
Normal file
1336
docs/diffusion/ddpm/unet.html
Normal file
File diff suppressed because it is too large
Load Diff
BIN
docs/diffusion/ddpm/unet.png
Normal file
BIN
docs/diffusion/ddpm/unet.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 130 KiB |
163
docs/diffusion/ddpm/utils.html
Normal file
163
docs/diffusion/ddpm/utils.html
Normal file
@ -0,0 +1,163 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="Utility functions for DDPM experiment"/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Utility functions for DDPM experiment"/>
|
||||
<meta name="twitter:description" content="Utility functions for DDPM experiment"/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/diffusion/ddpm/utils.html"/>
|
||||
<meta property="og:title" content="Utility functions for DDPM experiment"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Utility functions for DDPM experiment"/>
|
||||
<meta property="og:description" content="Utility functions for DDPM experiment"/>
|
||||
|
||||
<title>Utility functions for DDPM experiment</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/diffusion/ddpm/utils.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">diffusion</a>
|
||||
<a class="parent" href="index.html">ddpm</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/ddpm/utils.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Utility functions for <a href="index.html">DDPM</a> experiment</h1>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">10</span><span></span><span class="kn">import</span> <span class="nn">torch.utils.data</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-1'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-1'>#</a>
|
||||
</div>
|
||||
<p>Gather consts for $t$ and reshape to feature map shape</p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">13</span><span class="k">def</span> <span class="nf">gather</span><span class="p">(</span><span class="n">consts</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-2'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-2'>#</a>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre><span class="lineno">15</span> <span class="n">c</span> <span class="o">=</span> <span class="n">consts</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
|
||||
<span class="lineno">16</span> <span class="k">return</span> <span class="n">c</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
142
docs/diffusion/index.html
Normal file
142
docs/diffusion/index.html
Normal file
@ -0,0 +1,142 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Diffusion models"/>
|
||||
<meta name="twitter:description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/diffusion/index.html"/>
|
||||
<meta property="og:title" content="Diffusion models"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Diffusion models"/>
|
||||
<meta property="og:description" content="A set of PyTorch implementations/tutorials of diffusion models."/>
|
||||
|
||||
<title>Diffusion models</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/diffusion/index.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="index.html">diffusion</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/diffusion/__init__.py">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs doc-strings'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1>Diffusion models</h1>
|
||||
<ul>
|
||||
<li><a href="ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class='code'>
|
||||
<div class="highlight"><pre></pre></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -112,6 +112,10 @@ implementations.</p>
|
||||
<li><a href="gan/wasserstein/gradient_penalty/index.html">Wasserstein GAN with Gradient Penalty</a></li>
|
||||
<li><a href="gan/stylegan/index.html">StyleGAN 2</a></li>
|
||||
</ul>
|
||||
<h4>✨ <a href="diffusion/index.html">Diffusion models</a></h4>
|
||||
<ul>
|
||||
<li><a href="diffusion/ddpm/index.html">Denoising Diffusion Probabilistic Models (DDPM)</a></li>
|
||||
</ul>
|
||||
<h4>✨ <a href="sketch_rnn/index.html">Sketch RNN</a></h4>
|
||||
<h4>✨ Graph Neural Networks</h4>
|
||||
<ul>
|
||||
|
@ -42,6 +42,12 @@
|
||||
"1503.02531": [
|
||||
"https://nn.labml.ai/distillation/index.html"
|
||||
],
|
||||
"1505.04597": [
|
||||
"https://nn.labml.ai/diffusion/ddpm/unet.html"
|
||||
],
|
||||
"2006.11239": [
|
||||
"https://nn.labml.ai/diffusion/ddpm/index.html"
|
||||
],
|
||||
"2010.07468": [
|
||||
"https://nn.labml.ai/optimizers/ada_belief.html"
|
||||
],
|
||||
|
147
docs/rl/dqn/readme.html
Normal file
147
docs/rl/dqn/readme.html
Normal file
@ -0,0 +1,147 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="content-type" content="text/html;charset=utf-8"/>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
||||
<meta name="description" content=""/>
|
||||
|
||||
<meta name="twitter:card" content="summary"/>
|
||||
<meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta name="twitter:title" content="Deep Q Networks (DQN)"/>
|
||||
<meta name="twitter:description" content=""/>
|
||||
<meta name="twitter:site" content="@labmlai"/>
|
||||
<meta name="twitter:creator" content="@labmlai"/>
|
||||
|
||||
<meta property="og:url" content="https://nn.labml.ai/rl/dqn/readme.html"/>
|
||||
<meta property="og:title" content="Deep Q Networks (DQN)"/>
|
||||
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
|
||||
<meta property="og:site_name" content="LabML Neural Networks"/>
|
||||
<meta property="og:type" content="object"/>
|
||||
<meta property="og:title" content="Deep Q Networks (DQN)"/>
|
||||
<meta property="og:description" content=""/>
|
||||
|
||||
<title>Deep Q Networks (DQN)</title>
|
||||
<link rel="shortcut icon" href="/icon.png"/>
|
||||
<link rel="stylesheet" href="../../pylit.css">
|
||||
<link rel="canonical" href="https://nn.labml.ai/rl/dqn/readme.html"/>
|
||||
<!-- Global site tag (gtag.js) - Google Analytics -->
|
||||
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
|
||||
<script>
|
||||
window.dataLayer = window.dataLayer || [];
|
||||
|
||||
function gtag() {
|
||||
dataLayer.push(arguments);
|
||||
}
|
||||
|
||||
gtag('js', new Date());
|
||||
|
||||
gtag('config', 'G-4V3HC8HBLH');
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div id='container'>
|
||||
<div id="background"></div>
|
||||
<div class='section'>
|
||||
<div class='docs'>
|
||||
<p>
|
||||
<a class="parent" href="/">home</a>
|
||||
<a class="parent" href="../index.html">rl</a>
|
||||
<a class="parent" href="index.html">dqn</a>
|
||||
</p>
|
||||
<p>
|
||||
|
||||
<a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/rl/dqn/readme.md">
|
||||
<img alt="Github"
|
||||
src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
<a href="https://twitter.com/labmlai"
|
||||
rel="nofollow">
|
||||
<img alt="Twitter"
|
||||
src="https://img.shields.io/twitter/follow/labmlai?style=social"
|
||||
style="max-width:100%;"/></a>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class='section' id='section-0'>
|
||||
<div class='docs'>
|
||||
<div class='section-link'>
|
||||
<a href='#section-0'>#</a>
|
||||
</div>
|
||||
<h1><a href="https://nn.labml.ai/rl/dqn/index.html">Deep Q Networks (DQN)</a></h1>
|
||||
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of paper
|
||||
<a href="https://papers.labml.ai/paper/1312.5602">Playing Atari with Deep Reinforcement Learning</a>
|
||||
along with <a href="https://nn.labml.ai/rl/dqn/model.html">Dueling Network</a>, <a href="https://nn.labml.ai/rl/dqn/replay_buffer.html">Prioritized Replay</a>
|
||||
and Double Q Network.</p>
|
||||
<p>Here is the <a href="https://nn.labml.ai/rl/dqn/experiment.html">experiment</a> and <a href="https://nn.labml.ai/rl/dqn/model.html">model</a> implementation.</p>
|
||||
<p><a href="https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/rl/dqn/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
|
||||
<a href="https://app.labml.ai/run/fe1ad986237511ec86e8b763a2d3f710"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
|
||||
</div>
|
||||
<div class='code'>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
<div class='footer'>
|
||||
<a href="https://papers.labml.ai">Trending Research Papers</a>
|
||||
<a href="https://labml.ai">labml.ai</a>
|
||||
</div>
|
||||
</div>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML">
|
||||
</script>
|
||||
<!-- MathJax configuration -->
|
||||
<script type="text/x-mathjax-config">
|
||||
MathJax.Hub.Config({
|
||||
tex2jax: {
|
||||
inlineMath: [ ['$','$'] ],
|
||||
displayMath: [ ['$$','$$'] ],
|
||||
processEscapes: true,
|
||||
processEnvironments: true
|
||||
},
|
||||
// Center justify equations in code and markdown cells. Elsewhere
|
||||
// we use CSS to left justify single line equations in code cells.
|
||||
displayAlign: 'center',
|
||||
"HTML-CSS": { fonts: ["TeX"] }
|
||||
});
|
||||
|
||||
</script>
|
||||
<script>
|
||||
function handleImages() {
|
||||
var images = document.querySelectorAll('p>img')
|
||||
|
||||
console.log(images);
|
||||
for (var i = 0; i < images.length; ++i) {
|
||||
handleImage(images[i])
|
||||
}
|
||||
}
|
||||
|
||||
function handleImage(img) {
|
||||
img.parentElement.style.textAlign = 'center'
|
||||
|
||||
var modal = document.createElement('div')
|
||||
modal.id = 'modal'
|
||||
|
||||
var modalContent = document.createElement('div')
|
||||
modal.appendChild(modalContent)
|
||||
|
||||
var modalImage = document.createElement('img')
|
||||
modalContent.appendChild(modalImage)
|
||||
|
||||
var span = document.createElement('span')
|
||||
span.classList.add('close')
|
||||
span.textContent = 'x'
|
||||
modal.appendChild(span)
|
||||
|
||||
img.onclick = function () {
|
||||
console.log('clicked')
|
||||
document.body.appendChild(modal)
|
||||
modalImage.src = img.src
|
||||
}
|
||||
|
||||
span.onclick = function () {
|
||||
document.body.removeChild(modal)
|
||||
}
|
||||
}
|
||||
|
||||
handleImages()
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -328,6 +328,48 @@
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/index.html</loc>
|
||||
<lastmod>2021-10-06T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/ddpm/unet.html</loc>
|
||||
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/ddpm/index.html</loc>
|
||||
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/ddpm/experiment.html</loc>
|
||||
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/ddpm/utils.html</loc>
|
||||
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/diffusion/ddpm/evaluate.html</loc>
|
||||
<lastmod>2021-10-08T16:30:00+00:00</lastmod>
|
||||
<priority>1.00</priority>
|
||||
</url>
|
||||
|
||||
|
||||
<url>
|
||||
<loc>https://nn.labml.ai/optimizers/adam_warmup.html</loc>
|
||||
<lastmod>2021-01-13T16:30:00+00:00</lastmod>
|
||||
|
@ -52,6 +52,10 @@ implementations.
|
||||
* [Wasserstein GAN with Gradient Penalty](gan/wasserstein/gradient_penalty/index.html)
|
||||
* [StyleGAN 2](gan/stylegan/index.html)
|
||||
|
||||
#### ✨ [Diffusion models](diffusion/index.html)
|
||||
|
||||
* [Denoising Diffusion Probabilistic Models (DDPM)](diffusion/ddpm/index.html)
|
||||
|
||||
#### ✨ [Sketch RNN](sketch_rnn/index.html)
|
||||
|
||||
#### ✨ Graph Neural Networks
|
||||
|
11
labml_nn/diffusion/__init__.py
Normal file
11
labml_nn/diffusion/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""
|
||||
---
|
||||
title: Diffusion models
|
||||
summary: >
|
||||
A set of PyTorch implementations/tutorials of diffusion models.
|
||||
---
|
||||
|
||||
# Diffusion models
|
||||
|
||||
* [Denoising Diffusion Probabilistic Models (DDPM)](ddpm/index.html)
|
||||
"""
|
287
labml_nn/diffusion/ddpm/__init__.py
Normal file
287
labml_nn/diffusion/ddpm/__init__.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""
|
||||
---
|
||||
title: Denoising Diffusion Probabilistic Models (DDPM)
|
||||
summary: >
|
||||
PyTorch implementation and tutorial of the paper
|
||||
Denoising Diffusion Probabilistic Models (DDPM).
|
||||
---
|
||||
|
||||
# Denoising Diffusion Probabilistic Models (DDPM)
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
|
||||
[Denoising Diffusion Probabilistic Models](https://papers.labml.ai/paper/2006.11239).
|
||||
|
||||
In simple terms, we get an image from data and add noise step by step.
|
||||
Then we train a model to predict that noise at each step and use the model to
|
||||
generate images.
|
||||
|
||||
The following definitions and derivations show how this works.
|
||||
For details please refer to [the paper](https://papers.labml.ai/paper/2006.11239).
|
||||
|
||||
## Forward Process
|
||||
|
||||
The forward process adds noise to the data $x_0 \sim q(x_0)$, for $T$ timesteps.
|
||||
|
||||
\begin{align}
|
||||
q(x_t | x_{t-1}) = \mathcal{N}\big(x_t; \sqrt{1- \beta_t} x_{t-1}, \beta_t \mathbf{I}\big) \\
|
||||
q(x_{1:T} | x_0) = \prod_{t = 1}^{T} q(x_t | x_{t-1})
|
||||
\end{align}
|
||||
|
||||
where $\beta_1, \dots, \beta_T$ is the variance schedule.
|
||||
|
||||
We can sample $x_t$ at any timestep $t$ with,
|
||||
|
||||
\begin{align}
|
||||
q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
|
||||
\end{align}
|
||||
|
||||
where $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
|
||||
|
||||
## Reverse Process
|
||||
|
||||
The reverse process removes noise starting at $p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
|
||||
for $T$ time steps.
|
||||
|
||||
\begin{align}
|
||||
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
|
||||
\color{cyan}{\mu_\theta}(x_t, t), \color{cyan}{\Sigma_\theta}(x_t, t)\big) \\
|
||||
\color{cyan}{p_\theta}(x_{0:T}) &= \color{cyan}{p_\theta}(x_T) \prod_{t = 1}^{T} \color{cyan}{p_\theta}(x_{t-1} | x_t) \\
|
||||
\color{cyan}{p_\theta}(x_0) &= \int \color{cyan}{p_\theta}(x_{0:T}) dx_{1:T}
|
||||
\end{align}
|
||||
|
||||
$\color{cyan}\theta$ are the parameters we train.
|
||||
|
||||
## Loss
|
||||
|
||||
We optimize the ELBO (from Jensen's inequality) on the negative log likelihood.
|
||||
|
||||
\begin{align}
|
||||
\mathbb{E}[-\log \color{cyan}{p_\theta}(x_0)]
|
||||
&\le \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
|
||||
&=L
|
||||
\end{align}
|
||||
|
||||
The loss can be rewritten as follows.
|
||||
|
||||
\begin{align}
|
||||
L
|
||||
&= \mathbb{E}_q [ -\log \frac{\color{cyan}{p_\theta}(x_{0:T})}{q(x_{1:T}|x_0)} ] \\
|
||||
&= \mathbb{E}_q [ -\log p(x_T) - \sum_{t=1}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} ] \\
|
||||
&= \mathbb{E}_q [
|
||||
-\log \frac{p(x_T)}{q(x_T|x_0)}
|
||||
-\sum_{t=2}^T \log \frac{\color{cyan}{p_\theta}(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)}
|
||||
-\log \color{cyan}{p_\theta}(x_0|x_1)] \\
|
||||
&= \mathbb{E}_q [
|
||||
D_{KL}(q(x_T|x_0) \Vert p(x_T))
|
||||
+\sum_{t=2}^T D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))
|
||||
-\log \color{cyan}{p_\theta}(x_0|x_1)]
|
||||
\end{align}
|
||||
|
||||
$D_{KL}(q(x_T|x_0) \Vert p(x_T))$ is constant since we keep $\beta_1, \dots, \beta_T$ constant.
|
||||
|
||||
### Computing $L_{t-1} = D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t))$
|
||||
|
||||
The forward process posterior conditioned by $x_0$ is,
|
||||
|
||||
\begin{align}
|
||||
q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
|
||||
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
|
||||
+ \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
|
||||
\tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
|
||||
\end{align}
|
||||
|
||||
The paper sets $\color{cyan}{\Sigma_\theta}(x_t, t) = \sigma_t^2 \mathbf{I}$ where $\sigma_t^2$ is set to constants
|
||||
$\beta_t$ or $\tilde\beta_t$.
|
||||
|
||||
Then,
|
||||
$$\color{cyan}{p_\theta}(x_{t-1} | x_t) = \mathcal{N}\big(x_{t-1}; \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big)$$
|
||||
|
||||
For given noise $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ using $q(x_t|x_0)$
|
||||
|
||||
\begin{align}
|
||||
x_t(x_0, \epsilon) &= \sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon \\
|
||||
x_0 &= \frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t(x_0, \epsilon) - \sqrt{1-\bar\alpha_t}\epsilon\Big)
|
||||
\end{align}
|
||||
|
||||
This gives,
|
||||
|
||||
\begin{align}
|
||||
L_{t-1}
|
||||
&= D_{KL}(q(x_{t-1}|x_t,x_0) \Vert \color{cyan}{p_\theta}(x_{t-1}|x_t)) \\
|
||||
&= \mathbb{E}_q \Bigg[ \frac{1}{2\sigma_t^2}
|
||||
\Big \Vert \tilde\mu(x_t, x_0) - \color{cyan}{\mu_\theta}(x_t, t) \Big \Vert^2 \Bigg] \\
|
||||
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{1}{2\sigma_t^2}
|
||||
\bigg\Vert \frac{1}{\sqrt{\alpha_t}} \Big(
|
||||
x_t(x_0, \epsilon) - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}} \epsilon
|
||||
\Big) - \color{cyan}{\mu_\theta}(x_t(x_0, \epsilon), t) \bigg\Vert^2 \Bigg] \\
|
||||
\end{align}
|
||||
|
||||
Re-parameterizing with a model to predict noise
|
||||
|
||||
\begin{align}
|
||||
\color{cyan}{\mu_\theta}(x_t, t) &= \tilde\mu \bigg(x_t,
|
||||
\frac{1}{\sqrt{\bar\alpha_t}} \Big(x_t -
|
||||
\sqrt{1-\bar\alpha_t}\color{cyan}{\epsilon_\theta}(x_t, t) \Big) \bigg) \\
|
||||
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
|
||||
\end{align}
|
||||
|
||||
where $\epsilon_\theta$ is a learned function that predicts $\epsilon$ given $(x_t, t)$.
|
||||
|
||||
This gives,
|
||||
|
||||
\begin{align}
|
||||
L_{t-1}
|
||||
&= \mathbb{E}_{x_0, \epsilon} \Bigg[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}
|
||||
\Big\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\Big\Vert^2 \Bigg]
|
||||
\end{align}
|
||||
|
||||
That is, we are training to predict the noise.
|
||||
|
||||
### Simplified loss
|
||||
|
||||
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\bigg\Vert^2 \Bigg]$$
|
||||
|
||||
This minimizes $-\log \color{cyan}{p_\theta}(x_0|x_1)$ when $t=1$ and $L_{t-1}$ for $t\gt1$ discarding the
|
||||
weighting in $L_{t-1}$. Discarding the weights $\frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1 - \bar\alpha_t)}$
|
||||
increases the weight given to higher $t$ (which have higher noise levels), therefore increasing the sample quality.
|
||||
|
||||
This file implements the loss calculation and a basic sampling method that we use to generate images during
|
||||
training.
|
||||
|
||||
Here is the [UNet model](unet.html) that gives $\color{cyan}{\epsilon_\theta}(x_t, t)$ and
|
||||
[training code](experiment.html).
|
||||
[This file](evaluate.html) can generate samples and interpolations from a trained model.
|
||||
|
||||
[](https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686)
|
||||
"""
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
from torch import nn
|
||||
|
||||
from labml_nn.diffusion.ddpm.utils import gather
|
||||
|
||||
|
||||
class DenoiseDiffusion:
|
||||
"""
|
||||
## Denoise Diffusion
|
||||
"""
|
||||
|
||||
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
|
||||
"""
|
||||
* `eps_model` is $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
|
||||
* `n_steps` is $t$
|
||||
* `device` is the device to place constants on
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps_model = eps_model
|
||||
|
||||
# Create $\beta_1, \dots, \beta_T$ linearly increasing variance schedule
|
||||
self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
|
||||
|
||||
# $\alpha_t = 1 - \beta_t$
|
||||
self.alpha = 1. - self.beta
|
||||
# $\bar\alpha_t = \prod_{s=1}^t \alpha_s$
|
||||
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
|
||||
# $T$
|
||||
self.n_steps = n_steps
|
||||
# $\sigma^2 = \beta$
|
||||
self.sigma2 = self.beta
|
||||
|
||||
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    #### Get $q(x_t|x_0)$ distribution

    \begin{align}
    q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
    \end{align}
    """

    # [gather](utils.html) $\bar\alpha_t$ once for this batch of time steps
    alpha_bar_t = gather(self.alpha_bar, t)
    # Mean $\sqrt{\bar\alpha_t} x_0$
    mean = alpha_bar_t ** 0.5 * x0
    # Variance $(1-\bar\alpha_t) \mathbf{I}$
    var = 1 - alpha_bar_t
    # Return the Gaussian parameters
    return mean, var
|
||||
|
||||
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
    """
    #### Sample from $q(x_t|x_0)$

    \begin{align}
    q(x_t|x_0) &= \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)
    \end{align}
    """

    # Draw $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ when no noise is supplied
    if eps is None:
        eps = torch.randn_like(x0)

    # Gaussian parameters of $q(x_t|x_0)$
    mean, var = self.q_xt_x0(x0, t)
    # Re-parameterized sample: mean plus scaled noise
    return mean + (var ** 0.5) * eps
|
||||
|
||||
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
|
||||
"""
|
||||
#### Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
|
||||
|
||||
\begin{align}
|
||||
\color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
|
||||
\color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
|
||||
\color{cyan}{\mu_\theta}(x_t, t)
|
||||
&= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
|
||||
\end{align}
|
||||
"""
|
||||
|
||||
# $\color{cyan}{\epsilon_\theta}(x_t, t)$
|
||||
eps_theta = self.eps_model(xt, t)
|
||||
# [gather](utils.html) $\bar\alpha_t$
|
||||
alpha_bar = gather(self.alpha_bar, t)
|
||||
# $\alpha_t$
|
||||
alpha = gather(self.alpha, t)
|
||||
# $\frac{\beta}{\sqrt{1-\bar\alpha_t}}$
|
||||
eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
|
||||
# $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
|
||||
# \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
|
||||
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
|
||||
# $\sigma^2$
|
||||
var = gather(self.sigma2, t)
|
||||
|
||||
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
|
||||
eps = torch.randn(xt.shape, device=xt.device)
|
||||
# Sample
|
||||
return mean + (var ** .5) * eps
|
||||
|
||||
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
#### Simplified Loss
|
||||
|
||||
$$L_simple(\theta) = \mathbb{E}_{t,x_0, \epsilon} \Bigg[ \bigg\Vert
|
||||
\epsilon - \color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)
|
||||
\bigg\Vert^2 \Bigg]$$
|
||||
"""
|
||||
# Get batch size
|
||||
batch_size = x0.shape[0]
|
||||
# Get random $t$ for each sample in the batch
|
||||
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
|
||||
|
||||
# $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
|
||||
# Sample $x_t$ for $q(x_t|x_0)$
|
||||
xt = self.q_sample(x0, t, eps=noise)
|
||||
# Get $\color{cyan}{\epsilon_\theta}(\sqrt{\bar\alpha_t} x_0 + \sqrt{1-\bar\alpha_t}\epsilon, t)$
|
||||
eps_theta = self.eps_model(xt, t)
|
||||
|
||||
# MSE loss
|
||||
return F.mse_loss(noise, eps_theta)
|
327
labml_nn/diffusion/ddpm/evaluate.py
Normal file
327
labml_nn/diffusion/ddpm/evaluate.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""
|
||||
---
|
||||
title: Denoising Diffusion Probabilistic Models (DDPM) evaluation/sampling
|
||||
summary: >
|
||||
Code to generate samples from a trained
|
||||
Denoising Diffusion Probabilistic Model.
|
||||
---
|
||||
|
||||
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) evaluation/sampling
|
||||
|
||||
This is the code to generate images and create interpolations between given images.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from torchvision.transforms.functional import to_pil_image, resize
|
||||
|
||||
from labml import experiment, monit
|
||||
from labml_nn.diffusion.ddpm import DenoiseDiffusion, gather
|
||||
from labml_nn.diffusion.ddpm.experiment import Configs
|
||||
|
||||
|
||||
class Sampler:
    """
    ## Sampler class

    Generates samples, step-by-step denoising animations, and interpolations
    from a trained [`DenoiseDiffusion`](index.html) model.
    """

    def __init__(self, diffusion: DenoiseDiffusion, image_channels: int, image_size: int, device: torch.device):
        """
        * `diffusion` is the `DenoiseDiffusion` instance
        * `image_channels` is the number of channels in the image
        * `image_size` is the image size
        * `device` is the device of the model
        """
        self.device = device
        self.image_size = image_size
        self.image_channels = image_channels
        self.diffusion = diffusion

        # $T$
        self.n_steps = diffusion.n_steps
        # $\color{cyan}{\epsilon_\theta}(x_t, t)$
        self.eps_model = diffusion.eps_model
        # $\beta_t$
        self.beta = diffusion.beta
        # $\alpha_t$
        self.alpha = diffusion.alpha
        # $\bar\alpha_t$
        self.alpha_bar = diffusion.alpha_bar
        # $\bar\alpha_{t-1}$, with $\bar\alpha_0 = 1$ prepended so indexing lines up
        alpha_bar_tm1 = torch.cat([self.alpha_bar.new_ones((1,)), self.alpha_bar[:-1]])

        # To calculate
        #
        # \begin{align}
        # q(x_{t-1}|x_t, x_0) &= \mathcal{N} \Big(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t \mathbf{I} \Big) \\
        # \tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}x_0
        # + \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}x_t \\
        # \tilde\beta_t &= \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t} \beta_t
        # \end{align}

        # $\tilde\beta_t$
        self.beta_tilde = self.beta * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar\alpha_t}$$
        self.mu_tilde_coef1 = self.beta * (alpha_bar_tm1 ** 0.5) / (1 - self.alpha_bar)
        # $$\frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})}{1-\bar\alpha_t}$$
        self.mu_tilde_coef2 = (self.alpha ** 0.5) * (1 - alpha_bar_tm1) / (1 - self.alpha_bar)
        # $\sigma^2 = \beta$
        self.sigma2 = self.beta

    def show_image(self, img, title=""):
        """Helper function to display an image with matplotlib"""
        img = img.clip(0, 1)
        img = img.cpu().numpy()
        # Channels-first tensor -> channels-last for `imshow`
        plt.imshow(img.transpose(1, 2, 0))
        plt.title(title)
        plt.show()

    def make_video(self, frames, path="video.mp4"):
        """Helper function to write `frames` out as a video file at `path`"""
        import imageio
        # 20 second video regardless of frame count
        writer = imageio.get_writer(path, fps=len(frames) // 20)
        # Add each image
        for f in frames:
            f = f.clip(0, 1)
            f = to_pil_image(resize(f, [368, 368]))
            writer.append_data(np.array(f))
        #
        writer.close()

    def sample_animation(self, n_frames: int = 1000, create_video: bool = True):
        """
        #### Sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

        We sample an image step-by-step using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$ and at each step
        show the estimate
        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        \Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
        """

        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([1, self.image_channels, self.image_size, self.image_size], device=self.device)

        # Interval to log $\hat{x}_0$
        interval = self.n_steps // n_frames
        # Frames for video
        frames = []
        # Sample $T$ steps
        for t_inv in monit.iterate('Denoise', self.n_steps):
            # $t$ counts down from $T - 1$ to $0$
            t_ = self.n_steps - t_inv - 1
            # $t$ in a tensor
            t = xt.new_full((1,), t_, dtype=torch.long)
            # $\color{cyan}{\epsilon_\theta}(x_t, t)$ — computed once, reused for both
            # the $\hat{x}_0$ estimate and the sampling step
            eps_theta = self.eps_model(xt, t)
            if t_ % interval == 0:
                # Get $\hat{x}_0$ and add to frames
                x0 = self.p_x0(xt, t, eps_theta)
                frames.append(x0[0])
                if not create_video:
                    self.show_image(x0[0], f"{t_}")
            # Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
            xt = self.p_sample(xt, t, eps_theta)

        # Make video
        if create_video:
            self.make_video(frames)

    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor, lambda_: float, t_: int = 100):
        """
        #### Interpolate two images $x_0$ and $x'_0$

        We get $x_t \sim q(x_t|x_0)$ and $x'_t \sim q(x'_t|x'_0)$.

        Then interpolate to
        $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$

        Then get
        $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `lambda_` is $\lambda$
        * `t_` is $t$
        """

        # Number of samples
        n_samples = x1.shape[0]
        # $t$ tensor
        # NOTE(review): relies on `torch.full` inferring an integer dtype from the
        # integer fill value; pass `dtype=torch.long` explicitly if `gather` complains.
        t = torch.full((n_samples,), t_, device=self.device)
        # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
        xt = (1 - lambda_) * self.diffusion.q_sample(x1, t) + lambda_ * self.diffusion.q_sample(x2, t)

        # $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
        return self._sample_x0(xt, t_)

    def interpolate_animate(self, x1: torch.Tensor, x2: torch.Tensor, n_frames: int = 100, t_: int = 100,
                            create_video=True):
        """
        #### Interpolate two images $x_0$ and $x'_0$ and make a video

        * `x1` is $x_0$
        * `x2` is $x'_0$
        * `n_frames` is the number of frames for the image
        * `t_` is $t$
        * `create_video` specifies whether to make a video or to show each frame
        """

        # Show original images
        self.show_image(x1, "x1")
        self.show_image(x2, "x2")
        # Add batch dimension
        x1 = x1[None, :, :, :]
        x2 = x2[None, :, :, :]
        # $t$ tensor (see dtype note in `interpolate`)
        t = torch.full((1,), t_, device=self.device)
        # $x_t \sim q(x_t|x_0)$
        x1t = self.diffusion.q_sample(x1, t)
        # $x'_t \sim q(x'_t|x'_0)$
        x2t = self.diffusion.q_sample(x2, t)

        frames = []
        # Get frames with different $\lambda$
        for i in monit.iterate('Interpolate', n_frames + 1, is_children_silent=True):
            # $\lambda$
            lambda_ = i / n_frames
            # $$\bar{x}_t = (1 - \lambda)x_t + \lambda x'_t$$
            xt = (1 - lambda_) * x1t + lambda_ * x2t
            # $$\bar{x}_0 \sim \color{cyan}{p_\theta}(x_0|\bar{x}_t)$$
            x0 = self._sample_x0(xt, t_)
            # Add to frames
            frames.append(x0[0])
            # Show frame
            if not create_video:
                self.show_image(x0[0], f"{lambda_ :.2f}")

        # Make video
        if create_video:
            self.make_video(frames)

    def _sample_x0(self, xt: torch.Tensor, n_steps: int):
        """
        #### Sample an image using $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

        * `xt` is $x_t$
        * `n_steps` is $t$, the number of denoising steps to run
        """

        # Number of samples
        n_samples = xt.shape[0]
        # Iterate until $t$ steps
        for t_ in monit.iterate('Denoise', n_steps):
            t = n_steps - t_ - 1
            # Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
            xt = self.diffusion.p_sample(xt, xt.new_full((n_samples,), t, dtype=torch.long))

        # Return $x_0$
        return xt

    def sample(self, n_samples: int = 16):
        """
        #### Generate `n_samples` images from pure noise and display them
        """
        # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
        xt = torch.randn([n_samples, self.image_channels, self.image_size, self.image_size], device=self.device)

        # $$x_0 \sim \color{cyan}{p_\theta}(x_0|x_t)$$
        x0 = self._sample_x0(xt, self.n_steps)

        # Show images
        for i in range(n_samples):
            self.show_image(x0[i])

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, eps_theta: torch.Tensor):
        """
        #### Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$

        Same as `DenoiseDiffusion.p_sample`, except that the caller supplies a
        pre-computed `eps_theta` so it can be shared with `p_x0`.

        \begin{align}
        \color{cyan}{p_\theta}(x_{t-1} | x_t) &= \mathcal{N}\big(x_{t-1};
        \color{cyan}{\mu_\theta}(x_t, t), \sigma_t^2 \mathbf{I} \big) \\
        \color{cyan}{\mu_\theta}(x_t, t)
        &= \frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)
        \end{align}
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)
        # $\alpha_t$
        alpha = gather(self.alpha, t)
        # $\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}$
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        # $$\frac{1}{\sqrt{\alpha_t}} \Big(x_t -
        # \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        # $\sigma^2$
        var = gather(self.sigma2, t)

        # $\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
        eps = torch.randn(xt.shape, device=xt.device)
        # Sample
        return mean + (var ** .5) * eps

    def p_x0(self, xt: torch.Tensor, t: torch.Tensor, eps: torch.Tensor):
        """
        #### Estimate $x_0$

        $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        \Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
        """
        # [gather](utils.html) $\bar\alpha_t$
        alpha_bar = gather(self.alpha_bar, t)

        # $$x_0 \approx \hat{x}_0 = \frac{1}{\sqrt{\bar\alpha_t}}
        # \Big( x_t - \sqrt{1 - \bar\alpha_t} \color{cyan}{\epsilon_\theta}(x_t, t) \Big)$$
        return (xt - (1 - alpha_bar) ** 0.5 * eps) / (alpha_bar ** 0.5)
|
||||
|
||||
|
||||
def main():
    """Generate samples from a trained DDPM experiment run."""

    # Training experiment run UUID
    run_uuid = "a44333ea251411ec8007d1a1762ed686"

    # Start an evaluation
    experiment.evaluate()

    # Create configs
    configs = Configs()
    # Load custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
    # Set configurations
    experiment.configs(configs, configs_dict)

    # Initialize
    configs.init()

    # Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Load training experiment (restores the trained model weights)
    experiment.load(run_uuid)

    # Create sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

    # Start evaluation
    with experiment.start():
        # No gradients needed for sampling
        with torch.no_grad():
            # Sample an image with a denoising animation
            sampler.sample_animation()

            # Disabled by default; flip to `True` to render an interpolation animation
            if False:
                # Get some images from the data loader
                data = next(iter(configs.data_loader)).to(configs.device)

                # Create an interpolation animation
                sampler.interpolate_animate(data[0], data[1])


#
if __name__ == '__main__':
    main()
|
249
labml_nn/diffusion/ddpm/experiment.py
Normal file
249
labml_nn/diffusion/ddpm/experiment.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""
|
||||
---
|
||||
title: Denoising Diffusion Probabilistic Models (DDPM) training
|
||||
summary: >
|
||||
Training code for
|
||||
Denoising Diffusion Probabilistic Model.
|
||||
---
|
||||
|
||||
# [Denoising Diffusion Probabilistic Models (DDPM)](index.html) training
|
||||
|
||||
This trains a DDPM based model on CelebA HQ dataset. You can find the download instruction in this
|
||||
[discussion on fast.ai](https://forums.fast.ai/t/download-celeba-hq-dataset/45873/3).
|
||||
Save the images inside [`data/celebA` folder](#dataset_path).
|
||||
|
||||
The paper used an exponential moving average of the model with a decay of $0.9999$. We have skipped this for
|
||||
simplicity.
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
|
||||
from labml import lab, tracker, experiment, monit
|
||||
from labml.configs import BaseConfigs, option
|
||||
from labml_helpers.device import DeviceConfigs
|
||||
from labml_nn.diffusion.ddpm import DenoiseDiffusion
|
||||
from labml_nn.diffusion.ddpm.unet import UNet
|
||||
|
||||
|
||||
class Configs(BaseConfigs):
    """
    ## Configurations

    Bundles the model, diffusion process, dataset and training loop for the
    DDPM experiment.
    """
    # Device to train the model on.
    # [`DeviceConfigs`](https://docs.labml.ai/api/helpers.html#labml_helpers.device.DeviceConfigs)
    # picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model for $\color{cyan}{\epsilon_\theta}(x_t, t)$
    eps_model: UNet
    # [DDPM algorithm](index.html)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. $3$ for RGB.
    image_channels: int = 3
    # Image size
    image_size: int = 32
    # Number of channels in the initial feature map
    n_channels: int = 64
    # The list of channel numbers at each resolution.
    # The number of channels is `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]

    # Number of time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5

    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader

    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Create the model, diffusion process, data loader and optimizer."""
        # Create $\color{cyan}{\epsilon_\theta}(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create [DDPM class](index.html)
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
        # Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Image logging
        tracker.set_image("sample", True)

    def sample(self):
        """
        ### Sample `n_samples` images from the current model and log them
        """
        with torch.no_grad():
            # $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # $t$ counts down from $T - 1$ to $0$
                t = self.n_steps - t_ - 1
                # Sample from $\color{cyan}{p_\theta}(x_{t-1}|x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log samples
            tracker.save('sample', x)

    def train(self):
        """
        ### Train the model for one epoch over the dataset
        """

        # Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """
        ### Training loop: train, sample, and checkpoint each epoch
        """
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()
            # Save the model
            experiment.save_checkpoint()
|
||||
|
||||
|
||||
class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Loads `.jpg` images (recursively) from the `data/celebA` folder, resizes
    them to `image_size` and returns them as tensors in `[0, 1]`.
    """

    def __init__(self, image_size: int):
        """
        * `image_size` is the side length images are resized to
        """
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of image files. The original wrapped `glob` in a pass-through
        # comprehension with a needless f-string; `list(...)` is the idiom.
        self._files = list(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """
        Size of the dataset
        """
        return len(self._files)

    def __getitem__(self, index: int):
        """
        Get the image at `index` as a transformed tensor
        """
        img = Image.open(self._files[index])
        return self._transform(img)
|
||||
|
||||
|
||||
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """
    Create CelebA dataset (registered as the `'CelebA'` option for `Configs.dataset`)
    """
    return CelebADataset(c.image_size)
|
||||
|
||||
|
||||
class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    Wraps `torchvision`'s MNIST so that `__getitem__` returns only the image
    tensor (the DDPM training loop has no use for the labels).
    """

    def __init__(self, image_size):
        # Resize to `image_size` and convert to a tensor in `[0, 1]`
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        # Download (if necessary) the training split into the data folder
        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        # Drop the label; return only the image
        return super().__getitem__(item)[0]
|
||||
|
||||
|
||||
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """
    Create MNIST dataset (registered as the `'MNIST'` option for `Configs.dataset`)
    """
    return MNISTDataset(c.image_size)
|
||||
|
||||
|
||||
def main():
    """Create the experiment, set up configurations, and run training."""
    # Create experiment
    experiment.create(name='diffuse')

    # Create configurations
    configs = Configs()

    # Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
    })

    # Initialize
    configs.init()

    # Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start and run the training loop
    with experiment.start():
        configs.run()


#
if __name__ == '__main__':
    main()
|
15
labml_nn/diffusion/ddpm/readme.md
Normal file
15
labml_nn/diffusion/ddpm/readme.md
Normal file
@ -0,0 +1,15 @@
|
||||
# [Denoising Diffusion Probabilistic Models (DDPM)](https://nn.labml.ai/diffusion/ddpm/index.html)
|
||||
|
||||
This is a [PyTorch](https://pytorch.org) implementation/tutorial of the paper
|
||||
[Denoising Diffusion Probabilistic Models](https://papers.labml.ai/paper/2006.11239).
|
||||
|
||||
In simple terms, we get an image from data and add noise step by step.
|
||||
Then we train a model to predict that noise at each step and use the model to
|
||||
generate images.
|
||||
|
||||
Here is the [UNet model](https://nn.labml.ai/diffusion/ddpm/unet.html) that predicts the noise and
|
||||
[training code](https://nn.labml.ai/diffusion/ddpm/experiment.html).
|
||||
[This file](https://nn.labml.ai/diffusion/ddpm/evaluate.html) can generate samples and interpolations
|
||||
from a trained model.
|
||||
|
||||
[](https://app.labml.ai/run/a44333ea251411ec8007d1a1762ed686)
|
410
labml_nn/diffusion/ddpm/unet.py
Normal file
410
labml_nn/diffusion/ddpm/unet.py
Normal file
@ -0,0 +1,410 @@
|
||||
"""
|
||||
---
|
||||
title: U-Net model for Denoising Diffusion Probabilistic Models (DDPM)
|
||||
summary: >
|
||||
UNet model for Denoising Diffusion Probabilistic Models (DDPM)
|
||||
---
|
||||
|
||||
# U-Net model for [Denoising Diffusion Probabilistic Models (DDPM)](index.html)
|
||||
|
||||
This is a [U-Net](https://papers.labml.ai/paper/1505.04597) based model to predict noise
|
||||
$\color{cyan}{\epsilon_\theta}(x_t, t)$.
|
||||
|
||||
U-Net gets its name from the U shape in the model diagram.
|
||||
It processes a given image by progressively lowering (halving) the feature map resolution and then
|
||||
increasing the resolution.
|
||||
There are pass-through connections at each resolution.
|
||||
|
||||

|
||||
|
||||
This implementation contains a bunch of modifications to original U-Net (residual blocks, multi-head attention)
|
||||
and also adds time-step embeddings $t$.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from labml_helpers.module import Module
|
||||
|
||||
|
||||
class Swish(Module):
    """
    ### Swish activation function

    $$x \cdot \sigma(x)$$
    """

    def forward(self, x):
        # Gate the input by its own sigmoid: swish(x) = x * sigmoid(x)
        gate = torch.sigmoid(x)
        return gate * x
|
||||
|
||||
|
||||
class TimeEmbedding(nn.Module):
    """
    ### Embeddings for $t$

    Maps integer time steps to vectors of size `n_channels` using sinusoidal
    position embeddings followed by a two-layer MLP.
    """

    def __init__(self, n_channels: int):
        """
        * `n_channels` is the number of dimensions in the embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # First linear layer; expands the `n_channels // 4` sinusoidal features
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        # Activation
        self.act = Swish()
        # Second linear layer
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        # Create sinusoidal position embeddings
        # [same as those from the transformer](../../transformers/positional_encoding.html)
        #
        # \begin{align}
        # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
        # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
        # \end{align}
        #
        # where $d$ is `half_dim`
        half_dim = self.n_channels // 8
        emb = math.log(10_000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        # Concatenate sines and cosines -> shape `[batch_size, n_channels // 4]`
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)

        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)

        #
        return emb
|
||||
|
||||
|
||||
class ResidualBlock(Module):
    """
    ### Residual block

    A residual block has two convolution layers with group normalization.
    Each resolution is processed with two residual blocks.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()
        # Group normalization and the first convolution layer
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # Group normalization and the second convolution layer
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))

        # If the number of input channels is not equal to the number of output channels we have to
        # project the shortcut connection with a $1 \times 1$ convolution
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        else:
            self.shortcut = nn.Identity()

        # Linear layer for time embeddings
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # First convolution layer (pre-activation: norm -> act -> conv)
        h = self.conv1(self.act1(self.norm1(x)))
        # Add time embeddings, broadcast over the spatial dimensions
        h += self.time_emb(t)[:, :, None, None]
        # Second convolution layer
        h = self.conv2(self.act2(self.norm2(h)))

        # Add the shortcut connection and return
        return h + self.shortcut(x)
|
||||
|
||||
|
||||
class AttentionBlock(Module):
    """
    ### Attention block

    This is similar to [transformer multi-head attention](../../transformers/mha.html).
    """

    def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
        """
        * `n_channels` is the number of channels in the input
        * `n_heads` is the number of heads in multi-head attention
        * `d_k` is the number of dimensions in each head
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
        """
        super().__init__()

        # Default `d_k`
        if d_k is None:
            d_k = n_channels
        # Normalization layer
        # NOTE(review): `self.norm` is never applied in `forward` — confirm whether
        # the input was meant to be normalized before the projection.
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # Projections for query, key and values
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Linear layer for final transformation
        self.output = nn.Linear(n_heads * d_k, n_channels)
        # Scale for dot-product attention
        self.scale = d_k ** -0.5
        #
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size, time_channels]`
        """
        # `t` is not used, but it's kept in the arguments because for the attention layer function signature
        # to match with `ResidualBlock`.
        _ = t
        # Get shape
        batch_size, n_channels, height, width = x.shape
        # Change `x` to shape `[batch_size, seq, n_channels]`
        x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # Get query, key, and values (concatenated) and shape it to `[batch_size, seq, n_heads, 3 * d_k]`
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split query, key, and values. Each of them will have shape `[batch_size, seq, n_heads, d_k]`
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Calculate scaled dot-product $\frac{Q K^\top}{\sqrt{d_k}}$
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Softmax along the key-sequence dimension $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$.
        # `attn` has shape `[batch, i, j, heads]` where `j` indexes the keys, so
        # this must be `dim=2`; the original `dim=1` normalized over the queries.
        attn = attn.softmax(dim=2)
        # Multiply by values
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, seq, n_heads * d_k]`
        res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Transform to `[batch_size, seq, n_channels]`
        res = self.output(res)

        # Add skip connection
        res += x

        # Change to shape `[batch_size, in_channels, height, width]`
        res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)

        #
        return res
|
||||
|
||||
|
||||
class DownBlock(Module):
    """
    ### Down block

    A `ResidualBlock` followed by an (optional) `AttentionBlock`; used in the
    first half of the U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # Residual transformation of the features
        self.res = ResidualBlock(in_channels, out_channels, time_channels)
        # Attention only where requested; otherwise a no-op pass-through
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Residual block first, then attention
        return self.attn(self.res(x, t))
|
||||
|
||||
|
||||
class UpBlock(Module):
    """
    ### Up block

    A `ResidualBlock` followed by an (optional) `AttentionBlock`; used in the
    second half of the U-Net at each resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool):
        super().__init__()
        # The residual block takes `in_channels + out_channels` because the output
        # of the same resolution from the first half of the U-Net is concatenated in
        self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        # Attention only where requested; otherwise a no-op pass-through
        self.attn = AttentionBlock(out_channels) if has_attn else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Residual block first, then attention
        return self.attn(self.res(x, t))
|
||||
|
||||
|
||||
class MiddleBlock(Module):
    """
    ### Middle block

    It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`.
    This block is applied at the lowest resolution of the U-Net.
    """

    def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
        # res -> attn -> res, all at a constant channel count
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        h = self.res1(x, t)
        h = self.attn(h)
        return self.res2(h, t)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """
    ### Scale up the feature map by $2 \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        # A stride-2 transposed convolution doubles the spatial resolution while
        # keeping the number of channels unchanged.
        self.conv = nn.ConvTranspose2d(n_channels, n_channels,
                                       kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is unused; it is accepted only so the signature matches `ResidualBlock`,
        # letting the U-Net call every module in its lists uniformly.
        _ = t
        return self.conv(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    ### Scale down the feature map by $\frac{1}{2} \times$
    """

    def __init__(self, n_channels):
        super().__init__()
        # A stride-2 convolution halves the spatial resolution while keeping
        # the number of channels unchanged.
        self.conv = nn.Conv2d(n_channels, n_channels,
                              kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # `t` is unused; it is accepted only so the signature matches `ResidualBlock`,
        # letting the U-Net call every module in its lists uniformly.
        _ = t
        return self.conv(x)
|
||||
|
||||
|
||||
class UNet(Module):
    """
    ## U-Net

    The noise-prediction network $\epsilon_\theta(x_t, t)$: a U-Net with
    time-step conditioning, optional attention at selected resolutions,
    and skip connections between the down and up halves.
    """

    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        * `image_channels` is the number of channels in the image. $3$ for RGB.
        * `n_channels` is number of channels in the initial feature map that we transform the image into
        * `ch_mults` is the list of channel numbers at each resolution. The number of channels is `ch_mults[i] * n_channels`
        * `is_attn` is a list of booleans that indicate whether to use attention at each resolution
        * `n_blocks` is the number of `UpDownBlocks` at each resolution
        """
        super().__init__()

        # Number of resolutions
        n_resolutions = len(ch_mults)

        # Project image into feature map
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))

        # Time embedding layer. Time embedding has `n_channels * 4` channels
        self.time_emb = TimeEmbedding(n_channels * 4)

        # #### First half of U-Net - decreasing resolution
        down = []
        # Number of channels
        out_channels = in_channels = n_channels
        # For each resolution
        for i in range(n_resolutions):
            # Number of output channels at this resolution
            out_channels = in_channels * ch_mults[i]
            # Add `n_blocks`
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
                # After the first block of a resolution the channel count stays at `out_channels`
                in_channels = out_channels
            # Down sample at all resolutions except the last
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))

        # Combine the set of modules
        self.down = nn.ModuleList(down)

        # Middle block
        self.middle = MiddleBlock(out_channels, n_channels * 4, )

        # #### Second half of U-Net - increasing resolution
        up = []
        # Number of channels
        in_channels = out_channels
        # For each resolution
        for i in reversed(range(n_resolutions)):
            # `n_blocks` at the same resolution
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            # Final block to reduce the number of channels
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
            in_channels = out_channels
            # Up sample at all resolutions except last
            if i > 0:
                up.append(Upsample(in_channels))

        # Combine the set of modules
        self.up = nn.ModuleList(up)

        # Final normalization and convolution layer
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
        * `t` has shape `[batch_size]`
        """

        # Get time-step embeddings
        t = self.time_emb(t)

        # Get image projection
        x = self.image_proj(x)

        # `h` will store outputs at each resolution for skip connection
        h = [x]
        # First half of U-Net
        for m in self.down:
            x = m(x, t)
            h.append(x)

        # Middle (bottom)
        x = self.middle(x, t)

        # Second half of U-Net
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Get the skip connection from first half of U-Net and concatenate
                # along the channel dimension; `UpBlock` expects the extra channels.
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                #
                x = m(x, t)

        # Final normalization and convolution
        return self.final(self.act(self.norm(x)))
|
16
labml_nn/diffusion/ddpm/utils.py
Normal file
16
labml_nn/diffusion/ddpm/utils.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""
|
||||
---
|
||||
title: Utility functions for DDPM experiment
|
||||
summary: >
|
||||
Utility functions for DDPM experiment
|
||||
---
|
||||
|
||||
# Utility functions for [DDPM](index.html) experiment
|
||||
"""
|
||||
import torch.utils.data
|
||||
|
||||
|
||||
def gather(consts: torch.Tensor, t: torch.Tensor):
    """Gather consts for $t$ and reshape to feature map shape"""
    # Pick the constant for each time step in the batch, then reshape to
    # `[batch_size, 1, 1, 1]` so it broadcasts over image feature maps.
    return consts.gather(-1, t).reshape(-1, 1, 1, 1)
|
@ -57,6 +57,11 @@ implementations almost weekly.
|
||||
* [Wasserstein GAN with Gradient Penalty](https://nn.labml.ai/gan/wasserstein/gradient_penalty/index.html)
|
||||
* [StyleGAN 2](https://nn.labml.ai/gan/stylegan/index.html)
|
||||
|
||||
#### ✨ [Diffusion models](https://nn.labml.ai/diffusion/index.html)
|
||||
|
||||
* [Denoising Diffusion Probabilistic Models (DDPM)](https://nn.labml.ai/diffusion/ddpm/index.html)
|
||||
|
||||
|
||||
#### ✨ [Sketch RNN](https://nn.labml.ai/sketch_rnn/index.html)
|
||||
|
||||
#### ✨ Graph Neural Networks
|
||||
|
Reference in New Issue
Block a user