Made mcmc example. Added ability to view matplotlib plots.

This commit is contained in:
Alec Helbling
2023-02-02 21:59:31 -05:00
parent 9698907cbf
commit 134be057fb
22 changed files with 322 additions and 81 deletions

View File

@ -1,6 +0,0 @@
from manim_ml.flow.flow import *
class TestScene(Scene):
def construct(self):
self.add(Rectangle())

View File

@ -4,30 +4,99 @@ from manim_ml.diffusion.mcmc import (
MultidimensionalGaussianPosterior,
metropolis_hastings_sampler,
)
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
plt.style.use('dark_background')
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1200
config.frame_height = 12.0
config.frame_width = 12.0
config.frame_height = 10.0
config.frame_width = 10.0
def test_metropolis_hastings_sampler(iterations=100):
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
assert samples.shape == (iterations, 2)
def plot_hexbin_gaussian_on_image_mobject(
sample_func,
xlim=(-4, 4),
ylim=(-4, 4)
):
# Fixing random state for reproducibility
np.random.seed(19680801)
n = 100_000
samples = []
for i in range(n):
samples.append(sample_func())
samples = np.array(samples)
x = samples[:, 0]
y = samples[:, 1]
fig, ax0 = plt.subplots(1, figsize=(5, 5))
hb = ax0.hexbin(x, y, gridsize=50, cmap='gist_heat')
ax0.set(xlim=xlim, ylim=ylim)
return convert_matplotlib_figure_to_image_mobject(fig)
class MCMCTest(Scene):
def construct(self):
axes = MCMCAxes()
self.play(Create(axes))
gaussian_posterior = MultidimensionalGaussianPosterior(
mu=np.array([0.0, 0.0]), var=np.array([4.0, 2.0])
def construct(
self,
mu=np.array([0.0, 0.0]),
var=np.array([[1.0, 1.0]])
):
def gaussian_sample_func():
vals = np.random.multivariate_normal(
mu,
np.eye(2) * var,
1
)[0]
return vals
image_mobject = plot_hexbin_gaussian_on_image_mobject(
gaussian_sample_func
)
show_gaussian_animation = axes.show_ground_truth_gaussian(gaussian_posterior)
self.play(show_gaussian_animation)
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
log_prob_fn=gaussian_posterior, sampling_kwargs={"iterations": 1000}
self.add(image_mobject)
self.play(FadeOut(image_mobject))
axes = MCMCAxes(
x_range=[-4, 4],
y_range=[-4, 4],
)
self.play(
Create(axes)
)
self.play(chain_sampling_animation)
gaussian_posterior = MultidimensionalGaussianPosterior(
mu=np.array([0.0, 0.0]),
var=np.array([1.0, 1.0])
)
chain_sampling_animation, lines = axes.visualize_metropolis_hastings_chain_sampling(
log_prob_fn=gaussian_posterior,
sampling_kwargs={"iterations": 500},
)
self.play(
chain_sampling_animation,
run_time=3.5
)
self.play(
FadeOut(lines)
)
self.wait(1)
self.play(
FadeIn(image_mobject)
)

71
tests/test_plotting.py Normal file
View File

@ -0,0 +1,71 @@
from manim import *
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
plt.style.use('dark_background')
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
from manim_ml.utils.testing.frames_comparison import frames_comparison
__module_test__ = "plotting"
@frames_comparison
def test_matplotlib_to_image_mobject(scene):
# libraries & dataset
df = sns.load_dataset('iris')
# Custom the color, add shade and bandwidth
matplotlib.use('Agg')
plt.figure(figsize=(10,10), dpi=100)
displot = sns.displot(
x=df.sepal_width,
y=df.sepal_length,
cmap="Reds",
kind="kde",
)
plt.axis('off')
fig = displot.fig
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
# Display the image mobject
scene.add(image_mobject)
class TestMatplotlibToImageMobject(Scene):
def construct(self):
# Make a matplotlib plot
# libraries & dataset
df = sns.load_dataset('iris')
# Custom the color, add shade and bandwidth
matplotlib.use('Agg')
plt.figure(figsize=(10,10), dpi=100)
displot = sns.displot(
x=df.sepal_width,
y=df.sepal_length,
cmap="Reds",
kind="kde",
)
plt.axis('off')
fig = displot.fig
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
# Display the image mobject
self.add(image_mobject)
class HexabinScene(Scene):
def construct(self):
# Fixing random state for reproducibility
np.random.seed(19680801)
n = 100_000
x = np.random.standard_normal(n)
y = x + 1.0 * np.random.standard_normal(n)
xlim = -4, 4
ylim = -4, 4
fig, ax0 = plt.subplots(1, figsize=(5, 5))
hb = ax0.hexbin(x, y, gridsize=50, cmap='inferno')
ax0.set(xlim=xlim, ylim=ylim)
self.add(convert_matplotlib_figure_to_image_mobject(fig))