diff --git a/examples/mcmc/warmup_mcmc.py b/examples/mcmc/warmup_mcmc.py new file mode 100644 index 0000000..65c1997 --- /dev/null +++ b/examples/mcmc/warmup_mcmc.py @@ -0,0 +1,75 @@ +from manim import * + +import scipy.stats +from manim_ml.diffusion.mcmc import MCMCAxes +import matplotlib.pyplot as plt +import numpy as np + +plt.style.use('dark_background') + +# Make the specific scene +config.pixel_height = 720 +config.pixel_width = 720 +config.frame_height = 7.0 +config.frame_width = 7.0 + +class MCMCWarmupScene(Scene): + + def construct(self): + # Define the Gaussian Mixture likelihood + def gaussian_mm_logpdf(x): + """Gaussian Mixture Model Log PDF""" + # Choose two arbitrary Gaussians + # Big Gaussian + big_gaussian_pdf = scipy.stats.multivariate_normal( + mean=[-0.5, -0.5], + cov=[1.0, 1.0] + ).pdf(x) + # Little Gaussian + little_gaussian_pdf = scipy.stats.multivariate_normal( + mean=[2.3, 1.9], + cov=[0.3, 0.3] + ).pdf(x) + # Sum their likelihoods and take the log + logpdf = np.log(big_gaussian_pdf + little_gaussian_pdf) + + return logpdf + + # Generate a bunch of true samples + true_samples = [] + # Generate samples for little gaussian + little_gaussian_samples = np.random.multivariate_normal( + mean=[2.3, 1.9], + cov=[[0.3, 0.0], [0.0, 0.3]], + size=(10000) + ) + big_gaussian_samples = np.random.multivariate_normal( + mean=[-0.5, -0.5], + cov=[[1.0, 0.0], [0.0, 1.0]], + size=(10000) + ) + true_samples = np.concatenate((little_gaussian_samples, big_gaussian_samples)) + # Make the MCMC axes + axes = MCMCAxes( + x_range=[-5, 5], + y_range=[-5, 5], + x_length=7.0, + y_length=7.0 + ) + axes.move_to(ORIGIN) + self.play( + Create(axes) + ) + # Make the chain sampling animation + chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling( + log_prob_fn=gaussian_mm_logpdf, + true_samples=true_samples, + sampling_kwargs={ + "iterations": 2000, + "warm_up": 50, + "initial_location": np.array([-3.5, 3.5]), + "sampling_seed": 4 + }, + ) + self.play(chain_sampling_animation) + self.wait(3) diff --git a/manim_ml/diffusion/mcmc.py b/manim_ml/diffusion/mcmc.py index 929874e..10be816 100644 --- a/manim_ml/diffusion/mcmc.py +++ b/manim_ml/diffusion/mcmc.py @@ -2,16 +2,20 @@ Tool for animating Markov Chain Monte Carlo simulations in 2D. """ from manim import * +import matplotlib +import matplotlib.pyplot as plt +from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject import numpy as np import scipy import scipy.stats from tqdm import tqdm +import seaborn as sns from manim_ml.utils.mobjects.probability import GaussianDistribution ######################## MCMC Algorithms ######################### -def gaussian_proposal(x, sigma=1.0): +def gaussian_proposal(x, sigma=0.3): """ Gaussian proposal distribution. @@ -94,6 +98,7 @@ def metropolis_hastings_sampler( iterations=25, warm_up=0, ndim=2, + sampling_seed=1 ): """Samples using a Metropolis-Hastings sampler. @@ -119,7 +124,7 @@ def metropolis_hastings_sampler( candidate_samples: np.ndarray numpy array of the candidate samples for each time step """ - assert warm_up == 0, "Warmup not implemented yet" + np.random.seed(sampling_seed) # initialize chain, acceptance rate and lnprob chain = np.zeros((iterations, ndim)) proposals = np.zeros((iterations, ndim)) @@ -156,6 +161,43 @@ def metropolis_hastings_sampler( #################### MCMC Visualization Tools ###################### +def make_dist_image_mobject_from_samples(samples, ylim, xlim): + # Make the plot + matplotlib.use('Agg') + plt.figure(figsize=(10,10), dpi=100) + print(np.shape(samples[:, 0])) + displot = sns.displot( + x=samples[:, 0], + y=samples[:, 1], + cmap="Reds", + kind="kde", + norm=matplotlib.colors.LogNorm() + ) + plt.ylim(ylim[0], ylim[1]) + plt.xlim(xlim[0], xlim[1]) + plt.axis('off') + fig = displot.fig + image_mobject = convert_matplotlib_figure_to_image_mobject(fig) + + return image_mobject + +class Uncreate(Create): + def __init__( + self, + mobject, + reverse_rate_function: bool = True, + introducer: bool = True, + remover: bool = True, + **kwargs, + ) -> None: + super().__init__( + mobject, + reverse_rate_function=reverse_rate_function, + introducer=introducer, + remover=remover, + **kwargs, + ) + class MCMCAxes(Group): """Container object for visualizing MCMC on a 2D axis""" @@ -166,7 +208,7 @@ class MCMCAxes(Group): accept_line_color=GREEN, reject_line_color=RED, line_color=BLUE, - line_stroke_width=3, + line_stroke_width=2, x_range=[-3, 3], y_range=[-3, 3], x_length=5, @@ -180,6 +222,10 @@ class MCMCAxes(Group): self.line_color = line_color self.line_stroke_width = line_stroke_width # Make the axes + self.x_length = x_length + self.y_length = y_length + self.x_range = x_range + self.y_range = y_range self.axes = Axes( x_range=x_range, y_range=y_range, @@ -290,6 +336,7 @@ class MCMCAxes(Group): log_prob_fn=MultidimensionalGaussianPosterior(), prop_fn=gaussian_proposal, show_dots=False, + true_samples=None, sampling_kwargs={}, ): """ @@ -318,12 +365,14 @@ class MCMCAxes(Group): """ # Compute the chain samples using a Metropolis Hastings Sampler mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler( - log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs + log_prob_fn=log_prob_fn, + prop_fn=prop_fn, + **sampling_kwargs ) # print(f"MCMC samples: {mcmc_samples}") # print(f"Candidate samples: {candidate_samples}") # Make the animation for visualizing the chain - animations = [] + transition_animations = [] # Place the initial point current_point = mcmc_samples[0] current_point = Dot( @@ -332,10 +381,11 @@ class MCMCAxes(Group): radius=self.dot_radius, ) create_initial_point = Create(current_point) - animations.append(create_initial_point) + transition_animations.append(create_initial_point) # Show the initial point's proposal distribution # NOTE: visualize the warm up and the iterations lines = [] + warmup_points = [] num_iterations = len(mcmc_samples) + len(warm_up_samples) for iteration in tqdm(range(1, num_iterations)): next_sample = mcmc_samples[iteration] @@ -362,14 +412,50 @@ class MCMCAxes(Group): transition_animation, line = self.make_transition_animation( current_point, next_point, candidate_point ) + # Save assets lines.append(line) - animations.append(transition_animation) + if iteration < len(warm_up_samples): + warmup_points.append(candidate_point) + + # Add the transition animation + transition_animations.append(transition_animation) # Setup for next iteration current_point = next_point - # Make the final animation group - animation_group = AnimationGroup( - *animations, + # Overall MCMC animation + # 1. Fade in the distribution + image_mobject = make_dist_image_mobject_from_samples( + true_samples, + xlim=(self.x_range[0], self.x_range[1]), + ylim=(self.y_range[0], self.y_range[1]) + ) + image_mobject.scale_to_fit_height( + self.y_length + ) + image_mobject.move_to(self.axes) + fade_in_distribution = FadeIn( + image_mobject, + run_time=0.5 + ) + # 2. Start sampling the chain + chain_sampling_animation = AnimationGroup( + *transition_animations, + lag_ratio=1.0, + run_time=5.0 + ) + # 3. Convert the chain to points, excluding the warmup + lines = VGroup(*lines) + warm_up_points = VGroup(*warmup_points) + fade_out_lines_and_warmup = AnimationGroup( + Uncreate(lines), + Uncreate(warm_up_points), + lag_ratio=0.0 + ) + # Make the final animation + animation_group = Succession( + fade_in_distribution, + chain_sampling_animation, + fade_out_lines_and_warmup, lag_ratio=1.0 ) - return animation_group, VGroup(*lines) + return animation_group diff --git a/manim_ml/utils/mobjects/plotting.py b/manim_ml/utils/mobjects/plotting.py index 296091b..26eb2f3 100644 --- a/manim_ml/utils/mobjects/plotting.py +++ b/manim_ml/utils/mobjects/plotting.py @@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200): matplotlib figure """ fig.tight_layout(pad=0) - plt.axis('off') + # plt.axis('off') fig.canvas.draw() # Save data into a buffer image_buffer = io.BytesIO() diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index fad0ac3..602d537 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -15,8 +15,8 @@ plt.style.use('dark_background') # Make the specific scene config.pixel_height = 1200 config.pixel_width = 1200 -config.frame_height = 10.0 -config.frame_width = 10.0 +config.frame_height = 7.0 +config.frame_width = 7.0 def test_metropolis_hastings_sampler(iterations=100): samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)