mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 18:55:54 +08:00
Made mcmc example. Added ability to view matplotlib plots.
This commit is contained in:
BIN
tests/control_data/plotting/matplotlib_to_image_mobject.npz
Normal file
BIN
tests/control_data/plotting/matplotlib_to_image_mobject.npz
Normal file
Binary file not shown.
@ -1,6 +0,0 @@
|
||||
from manim_ml.flow.flow import *
|
||||
|
||||
|
||||
class TestScene(Scene):
|
||||
def construct(self):
|
||||
self.add(Rectangle())
|
@ -4,30 +4,99 @@ from manim_ml.diffusion.mcmc import (
|
||||
MultidimensionalGaussianPosterior,
|
||||
metropolis_hastings_sampler,
|
||||
)
|
||||
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib
|
||||
plt.style.use('dark_background')
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1200
|
||||
config.frame_height = 12.0
|
||||
config.frame_width = 12.0
|
||||
|
||||
config.frame_height = 10.0
|
||||
config.frame_width = 10.0
|
||||
|
||||
def test_metropolis_hastings_sampler(iterations=100):
|
||||
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
|
||||
assert samples.shape == (iterations, 2)
|
||||
|
||||
def plot_hexbin_gaussian_on_image_mobject(
|
||||
sample_func,
|
||||
xlim=(-4, 4),
|
||||
ylim=(-4, 4)
|
||||
):
|
||||
# Fixing random state for reproducibility
|
||||
np.random.seed(19680801)
|
||||
n = 100_000
|
||||
samples = []
|
||||
for i in range(n):
|
||||
samples.append(sample_func())
|
||||
samples = np.array(samples)
|
||||
|
||||
x = samples[:, 0]
|
||||
y = samples[:, 1]
|
||||
|
||||
fig, ax0 = plt.subplots(1, figsize=(5, 5))
|
||||
|
||||
hb = ax0.hexbin(x, y, gridsize=50, cmap='gist_heat')
|
||||
|
||||
ax0.set(xlim=xlim, ylim=ylim)
|
||||
|
||||
return convert_matplotlib_figure_to_image_mobject(fig)
|
||||
|
||||
class MCMCTest(Scene):
|
||||
def construct(self):
|
||||
axes = MCMCAxes()
|
||||
self.play(Create(axes))
|
||||
gaussian_posterior = MultidimensionalGaussianPosterior(
|
||||
mu=np.array([0.0, 0.0]), var=np.array([4.0, 2.0])
|
||||
|
||||
def construct(
|
||||
self,
|
||||
mu=np.array([0.0, 0.0]),
|
||||
var=np.array([[1.0, 1.0]])
|
||||
):
|
||||
|
||||
def gaussian_sample_func():
|
||||
vals = np.random.multivariate_normal(
|
||||
mu,
|
||||
np.eye(2) * var,
|
||||
1
|
||||
)[0]
|
||||
|
||||
return vals
|
||||
|
||||
image_mobject = plot_hexbin_gaussian_on_image_mobject(
|
||||
gaussian_sample_func
|
||||
)
|
||||
show_gaussian_animation = axes.show_ground_truth_gaussian(gaussian_posterior)
|
||||
self.play(show_gaussian_animation)
|
||||
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
|
||||
log_prob_fn=gaussian_posterior, sampling_kwargs={"iterations": 1000}
|
||||
self.add(image_mobject)
|
||||
self.play(FadeOut(image_mobject))
|
||||
|
||||
axes = MCMCAxes(
|
||||
x_range=[-4, 4],
|
||||
y_range=[-4, 4],
|
||||
)
|
||||
self.play(
|
||||
Create(axes)
|
||||
)
|
||||
|
||||
self.play(chain_sampling_animation)
|
||||
gaussian_posterior = MultidimensionalGaussianPosterior(
|
||||
mu=np.array([0.0, 0.0]),
|
||||
var=np.array([1.0, 1.0])
|
||||
)
|
||||
|
||||
chain_sampling_animation, lines = axes.visualize_metropolis_hastings_chain_sampling(
|
||||
log_prob_fn=gaussian_posterior,
|
||||
sampling_kwargs={"iterations": 500},
|
||||
)
|
||||
|
||||
self.play(
|
||||
chain_sampling_animation,
|
||||
run_time=3.5
|
||||
)
|
||||
self.play(
|
||||
FadeOut(lines)
|
||||
)
|
||||
self.wait(1)
|
||||
self.play(
|
||||
FadeIn(image_mobject)
|
||||
)
|
||||
|
||||
|
||||
|
71
tests/test_plotting.py
Normal file
71
tests/test_plotting.py
Normal file
@ -0,0 +1,71 @@
|
||||
|
||||
from manim import *
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib
|
||||
plt.style.use('dark_background')
|
||||
|
||||
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
|
||||
from manim_ml.utils.testing.frames_comparison import frames_comparison
|
||||
|
||||
__module_test__ = "plotting"
|
||||
|
||||
@frames_comparison
|
||||
def test_matplotlib_to_image_mobject(scene):
|
||||
# libraries & dataset
|
||||
df = sns.load_dataset('iris')
|
||||
# Custom the color, add shade and bandwidth
|
||||
matplotlib.use('Agg')
|
||||
plt.figure(figsize=(10,10), dpi=100)
|
||||
displot = sns.displot(
|
||||
x=df.sepal_width,
|
||||
y=df.sepal_length,
|
||||
cmap="Reds",
|
||||
kind="kde",
|
||||
)
|
||||
plt.axis('off')
|
||||
fig = displot.fig
|
||||
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
|
||||
# Display the image mobject
|
||||
scene.add(image_mobject)
|
||||
|
||||
class TestMatplotlibToImageMobject(Scene):
|
||||
|
||||
def construct(self):
|
||||
# Make a matplotlib plot
|
||||
# libraries & dataset
|
||||
df = sns.load_dataset('iris')
|
||||
# Custom the color, add shade and bandwidth
|
||||
matplotlib.use('Agg')
|
||||
plt.figure(figsize=(10,10), dpi=100)
|
||||
displot = sns.displot(
|
||||
x=df.sepal_width,
|
||||
y=df.sepal_length,
|
||||
cmap="Reds",
|
||||
kind="kde",
|
||||
)
|
||||
plt.axis('off')
|
||||
fig = displot.fig
|
||||
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
|
||||
# Display the image mobject
|
||||
self.add(image_mobject)
|
||||
|
||||
|
||||
class HexabinScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
# Fixing random state for reproducibility
|
||||
np.random.seed(19680801)
|
||||
n = 100_000
|
||||
x = np.random.standard_normal(n)
|
||||
y = x + 1.0 * np.random.standard_normal(n)
|
||||
xlim = -4, 4
|
||||
ylim = -4, 4
|
||||
|
||||
fig, ax0 = plt.subplots(1, figsize=(5, 5))
|
||||
|
||||
hb = ax0.hexbin(x, y, gridsize=50, cmap='inferno')
|
||||
ax0.set(xlim=xlim, ylim=ylim)
|
||||
|
||||
self.add(convert_matplotlib_figure_to_image_mobject(fig))
|
Reference in New Issue
Block a user