Files
ManimML/tests/test_mcmc.py

32 lines
1.1 KiB
Python

from manim import *
from manim_ml.diffusion.mcmc import MCMCAxes, MultidimensionalGaussianPosterior, metropolis_hastings_sampler
# Render settings for the scenes in this file: a square 1200 x 1200 pixel
# output whose camera frame covers 12 x 12 scene units. Chained assignment
# sets the targets left to right (pixel_height before pixel_width, then
# frame_height before frame_width).
config.pixel_height = config.pixel_width = 1200
config.frame_height = config.frame_width = 12.0
def test_metropolis_hastings_sampler(iterations=100):
    """Check that the sampler returns one 2-D sample per iteration.

    ``iterations`` has a default value, so pytest does not request it as a
    fixture. When collected by pytest the test runs with 100 iterations.

    Args:
        iterations: Number of sampler iterations to request.
    """
    # Only the first return value (the chain of samples) is checked; the
    # remaining two values returned by the sampler are deliberately ignored.
    samples, _, _ = metropolis_hastings_sampler(iterations=iterations)
    assert samples.shape == (iterations, 2)
class MCMCTest(Scene):
    """Scene that draws MCMC axes, a target Gaussian, and a sampled chain."""

    def construct(self):
        """Animate the axes, the ground-truth Gaussian, then the MH chain.

        The Gaussian posterior is built with mu=[0, 0] and var=[4, 2]. It is
        passed directly as ``log_prob_fn``; presumably the posterior object is
        callable as a log-probability function (TODO confirm in
        manim_ml.diffusion.mcmc).
        """
        mcmc_axes = MCMCAxes()
        self.play(Create(mcmc_axes))

        posterior = MultidimensionalGaussianPosterior(
            mu=np.array([0.0, 0.0]),
            var=np.array([4.0, 2.0]),
        )
        self.play(mcmc_axes.show_ground_truth_gaussian(posterior))

        # The chain animation is built only after the Gaussian has been
        # played, preserving the original call order.
        self.play(
            mcmc_axes.visualize_metropolis_hastings_chain_sampling(
                log_prob_fn=posterior,
                sampling_kwargs={"iterations": 1000},
            )
        )