sophia speed up

commit b43fb807a8
parent 594f89c8cc
Author: Varuna Jayasiri
Date: 2023-07-15 08:30:41 +05:30

4 changed files with 57 additions and 43 deletions

File diff suppressed because one or more lines are too long


@@ -484,7 +484,7 @@
 <url>
 <loc>https://nn.labml.ai/index.html</loc>
-<lastmod>2023-06-30T16:30:00+00:00</lastmod>
+<lastmod>2023-07-14T16:30:00+00:00</lastmod>
 <priority>1.00</priority>
 </url>
@@ -715,7 +715,7 @@
 <url>
 <loc>https://nn.labml.ai/optimizers/index.html</loc>
-<lastmod>2022-06-03T16:30:00+00:00</lastmod>
+<lastmod>2023-07-14T16:30:00+00:00</lastmod>
 <priority>1.00</priority>
 </url>


@@ -156,8 +156,7 @@ class Sophia(GenericAdaptiveOptimizer):
         We do the following parameter update,
         \begin{align}
-        \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg) \\
-        \theta_{t + 1} &\leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)
+        \theta_{t + 1} &\leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)
         \end{align}
         """
@@ -182,8 +181,11 @@ class Sophia(GenericAdaptiveOptimizer):
         # Get maximum learning rate $\eta \rho$
         lr = group['lr']
-        # $$\operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
-        ratio = (m / (rho * hessian + group['eps'])).clamp(-1, 1)
-        # $$\theta_{t + 1} \leftarrow \theta_t - \eta \rho \cdot \operatorname{clip} \bigg(\frac{m_t}{\rho h_t + \epsilon}, 1 \bigg)$$
-        param.data.add_(ratio, alpha=-lr)
+        # $\eta$
+        eta = lr / rho
+        # $$\operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
+        ratio = (m / (hessian + group['eps'])).clamp(-rho, rho)
+        # $$\theta_{t + 1} \leftarrow \theta_t - \eta \cdot \operatorname{clip} \bigg(\frac{m_t}{h_t + \epsilon}, \rho \bigg)$$
+        param.data.add_(ratio, alpha=-eta)
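For context, a minimal standalone sketch (not code from this commit; shapes and hyperparameter values are invented for illustration) contrasting the old and new steps. The gain is that the new form drops the elementwise `rho * hessian` multiply from the per-parameter inner loop, replacing it with a single scalar division `lr / rho`:

import torch

# Invented values for illustration; `lr` plays the role of the maximum learning rate (eta * rho).
lr, rho, eps = 2e-4, 0.03, 1e-12
param = torch.zeros(1024)
m = torch.randn(1024)              # stand-in for the EMA of gradients (optimizer state)
hessian = torch.rand(1024) + 0.1   # stand-in for the EMA of Hessian diagonal estimates

# Old step: scale the denominator by rho, clamp to [-1, 1], step with lr = eta * rho.
ratio_old = (m / (rho * hessian + eps)).clamp(-1, 1)
updated_old = param - lr * ratio_old

# New step: plain denominator, clamp to [-rho, rho], step with eta = lr / rho.
# The per-tensor `rho * hessian` multiply disappears from the inner loop.
eta = lr / rho
ratio_new = (m / (hessian + eps)).clamp(-rho, rho)
updated_new = param - eta * ratio_new

# The two steps agree up to how eps enters the denominator.
assert torch.allclose(updated_old, updated_new, atol=1e-9)

Both paths compute the same parameter update up to the $\epsilon$ placement noted above, so the rewrite is an equivalent but cheaper formulation.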


@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:
 setuptools.setup(
     name='labml-nn',
-    version='0.4.134',
+    version='0.4.135',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="🧑‍🏫 Implementations/tutorials of deep learning papers with side-by-side notes 📝; including transformers (original, xl, switch, feedback, vit), optimizers (adam, radam, adabelief), gans(dcgan, cyclegan, stylegan2), 🎮 reinforcement learning (ppo, dqn), capsnet, distillation, diffusion, etc. 🧠",