mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 10:45:54 +08:00
Added changes to the MCMC sampling code. Added an MCMC example.
This commit is contained in:
75
examples/mcmc/warmup_mcmc.py
Normal file
75
examples/mcmc/warmup_mcmc.py
Normal file
@ -0,0 +1,75 @@
|
||||
from manim import *
|
||||
|
||||
import scipy.stats
|
||||
from manim_ml.diffusion.mcmc import MCMCAxes
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
plt.style.use('dark_background')
|
||||
|
||||
# Make the specific scene
|
||||
config.pixel_height = 720
|
||||
config.pixel_width = 720
|
||||
config.frame_height = 7.0
|
||||
config.frame_width = 7.0
|
||||
|
||||
class MCMCWarmupScene(Scene):
|
||||
|
||||
def construct(self):
|
||||
# Define the Gaussian Mixture likelihood
|
||||
def gaussian_mm_logpdf(x):
|
||||
"""Gaussian Mixture Model Log PDF"""
|
||||
# Choose two arbitrary Gaussians
|
||||
# Big Gaussian
|
||||
big_gaussian_pdf = scipy.stats.multivariate_normal(
|
||||
mean=[-0.5, -0.5],
|
||||
cov=[1.0, 1.0]
|
||||
).pdf(x)
|
||||
# Little Gaussian
|
||||
little_gaussian_pdf = scipy.stats.multivariate_normal(
|
||||
mean=[2.3, 1.9],
|
||||
cov=[0.3, 0.3]
|
||||
).pdf(x)
|
||||
# Sum their likelihoods and take the log
|
||||
logpdf = np.log(big_gaussian_pdf + little_gaussian_pdf)
|
||||
|
||||
return logpdf
|
||||
|
||||
# Generate a bunch of true samples
|
||||
true_samples = []
|
||||
# Generate samples for little gaussian
|
||||
little_gaussian_samples = np.random.multivariate_normal(
|
||||
mean=[2.3, 1.9],
|
||||
cov=[[0.3, 0.0], [0.0, 0.3]],
|
||||
size=(10000)
|
||||
)
|
||||
big_gaussian_samples = np.random.multivariate_normal(
|
||||
mean=[-0.5, -0.5],
|
||||
cov=[[1.0, 0.0], [0.0, 1.0]],
|
||||
size=(10000)
|
||||
)
|
||||
true_samples = np.concatenate((little_gaussian_samples, big_gaussian_samples))
|
||||
# Make the MCMC axes
|
||||
axes = MCMCAxes(
|
||||
x_range=[-5, 5],
|
||||
y_range=[-5, 5],
|
||||
x_length=7.0,
|
||||
y_length=7.0
|
||||
)
|
||||
axes.move_to(ORIGIN)
|
||||
self.play(
|
||||
Create(axes)
|
||||
)
|
||||
# Make the chain sampling animation
|
||||
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
|
||||
log_prob_fn=gaussian_mm_logpdf,
|
||||
true_samples=true_samples,
|
||||
sampling_kwargs={
|
||||
"iterations": 2000,
|
||||
"warm_up": 50,
|
||||
"initial_location": np.array([-3.5, 3.5]),
|
||||
"sampling_seed": 4
|
||||
},
|
||||
)
|
||||
self.play(chain_sampling_animation)
|
||||
self.wait(3)
|
@ -2,16 +2,20 @@
|
||||
Tool for animating Markov Chain Monte Carlo simulations in 2D.
|
||||
"""
|
||||
from manim import *
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats
|
||||
from tqdm import tqdm
|
||||
import seaborn as sns
|
||||
|
||||
from manim_ml.utils.mobjects.probability import GaussianDistribution
|
||||
|
||||
######################## MCMC Algorithms #########################
|
||||
|
||||
def gaussian_proposal(x, sigma=1.0):
|
||||
def gaussian_proposal(x, sigma=0.3):
|
||||
"""
|
||||
Gaussian proposal distribution.
|
||||
|
||||
@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
|
||||
iterations=25,
|
||||
warm_up=0,
|
||||
ndim=2,
|
||||
sampling_seed=1
|
||||
):
|
||||
"""Samples using a Metropolis-Hastings sampler.
|
||||
|
||||
@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
|
||||
candidate_samples: np.ndarray
|
||||
numpy array of the candidate samples for each time step
|
||||
"""
|
||||
assert warm_up == 0, "Warmup not implemented yet"
|
||||
np.random.seed(sampling_seed)
|
||||
# initialize chain, acceptance rate and lnprob
|
||||
chain = np.zeros((iterations, ndim))
|
||||
proposals = np.zeros((iterations, ndim))
|
||||
@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
|
||||
|
||||
#################### MCMC Visualization Tools ######################
|
||||
|
||||
def make_dist_image_mobject_from_samples(samples, ylim, xlim):
|
||||
# Make the plot
|
||||
matplotlib.use('Agg')
|
||||
plt.figure(figsize=(10,10), dpi=100)
|
||||
print(np.shape(samples[:, 0]))
|
||||
displot = sns.displot(
|
||||
x=samples[:, 0],
|
||||
y=samples[:, 1],
|
||||
cmap="Reds",
|
||||
kind="kde",
|
||||
norm=matplotlib.colors.LogNorm()
|
||||
)
|
||||
plt.ylim(ylim[0], ylim[1])
|
||||
plt.xlim(xlim[0], xlim[1])
|
||||
plt.axis('off')
|
||||
fig = displot.fig
|
||||
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
|
||||
|
||||
return image_mobject
|
||||
|
||||
class Uncreate(Create):
|
||||
def __init__(
|
||||
self,
|
||||
mobject,
|
||||
reverse_rate_function: bool = True,
|
||||
introducer: bool = True,
|
||||
remover: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
mobject,
|
||||
reverse_rate_function=reverse_rate_function,
|
||||
introducer=introducer,
|
||||
remover=remover,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
class MCMCAxes(Group):
|
||||
"""Container object for visualizing MCMC on a 2D axis"""
|
||||
|
||||
@ -166,7 +208,7 @@ class MCMCAxes(Group):
|
||||
accept_line_color=GREEN,
|
||||
reject_line_color=RED,
|
||||
line_color=BLUE,
|
||||
line_stroke_width=3,
|
||||
line_stroke_width=2,
|
||||
x_range=[-3, 3],
|
||||
y_range=[-3, 3],
|
||||
x_length=5,
|
||||
@ -180,6 +222,10 @@ class MCMCAxes(Group):
|
||||
self.line_color = line_color
|
||||
self.line_stroke_width = line_stroke_width
|
||||
# Make the axes
|
||||
self.x_length = x_length
|
||||
self.y_length = y_length
|
||||
self.x_range = x_range
|
||||
self.y_range = y_range
|
||||
self.axes = Axes(
|
||||
x_range=x_range,
|
||||
y_range=y_range,
|
||||
@ -290,6 +336,7 @@ class MCMCAxes(Group):
|
||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||
prop_fn=gaussian_proposal,
|
||||
show_dots=False,
|
||||
true_samples=None,
|
||||
sampling_kwargs={},
|
||||
):
|
||||
"""
|
||||
@ -318,12 +365,14 @@ class MCMCAxes(Group):
|
||||
"""
|
||||
# Compute the chain samples using a Metropolis Hastings Sampler
|
||||
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
|
||||
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
|
||||
log_prob_fn=log_prob_fn,
|
||||
prop_fn=prop_fn,
|
||||
**sampling_kwargs
|
||||
)
|
||||
# print(f"MCMC samples: {mcmc_samples}")
|
||||
# print(f"Candidate samples: {candidate_samples}")
|
||||
# Make the animation for visualizing the chain
|
||||
animations = []
|
||||
transition_animations = []
|
||||
# Place the initial point
|
||||
current_point = mcmc_samples[0]
|
||||
current_point = Dot(
|
||||
@ -332,10 +381,11 @@ class MCMCAxes(Group):
|
||||
radius=self.dot_radius,
|
||||
)
|
||||
create_initial_point = Create(current_point)
|
||||
animations.append(create_initial_point)
|
||||
transition_animations.append(create_initial_point)
|
||||
# Show the initial point's proposal distribution
|
||||
# NOTE: visualize the warm up and the iterations
|
||||
lines = []
|
||||
warmup_points = []
|
||||
num_iterations = len(mcmc_samples) + len(warm_up_samples)
|
||||
for iteration in tqdm(range(1, num_iterations)):
|
||||
next_sample = mcmc_samples[iteration]
|
||||
@ -362,14 +412,50 @@ class MCMCAxes(Group):
|
||||
transition_animation, line = self.make_transition_animation(
|
||||
current_point, next_point, candidate_point
|
||||
)
|
||||
# Save assets
|
||||
lines.append(line)
|
||||
animations.append(transition_animation)
|
||||
if iteration < len(warm_up_samples):
|
||||
warmup_points.append(candidate_point)
|
||||
|
||||
# Add the transition animation
|
||||
transition_animations.append(transition_animation)
|
||||
# Setup for next iteration
|
||||
current_point = next_point
|
||||
# Make the final animation group
|
||||
animation_group = AnimationGroup(
|
||||
*animations,
|
||||
# Overall MCMC animation
|
||||
# 1. Fade in the distribution
|
||||
image_mobject = make_dist_image_mobject_from_samples(
|
||||
true_samples,
|
||||
xlim=(self.x_range[0], self.x_range[1]),
|
||||
ylim=(self.y_range[0], self.y_range[1])
|
||||
)
|
||||
image_mobject.scale_to_fit_height(
|
||||
self.y_length
|
||||
)
|
||||
image_mobject.move_to(self.axes)
|
||||
fade_in_distribution = FadeIn(
|
||||
image_mobject,
|
||||
run_time=0.5
|
||||
)
|
||||
# 2. Start sampling the chain
|
||||
chain_sampling_animation = AnimationGroup(
|
||||
*transition_animations,
|
||||
lag_ratio=1.0,
|
||||
run_time=5.0
|
||||
)
|
||||
# 3. Convert the chain to points, excluding the warmup
|
||||
lines = VGroup(*lines)
|
||||
warm_up_points = VGroup(*warmup_points)
|
||||
fade_out_lines_and_warmup = AnimationGroup(
|
||||
Uncreate(lines),
|
||||
Uncreate(warm_up_points),
|
||||
lag_ratio=0.0
|
||||
)
|
||||
# Make the final animation
|
||||
animation_group = Succession(
|
||||
fade_in_distribution,
|
||||
chain_sampling_animation,
|
||||
fade_out_lines_and_warmup,
|
||||
lag_ratio=1.0
|
||||
)
|
||||
|
||||
return animation_group, VGroup(*lines)
|
||||
return animation_group
|
||||
|
@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
|
||||
matplotlib figure
|
||||
"""
|
||||
fig.tight_layout(pad=0)
|
||||
plt.axis('off')
|
||||
# plt.axis('off')
|
||||
fig.canvas.draw()
|
||||
# Save data into a buffer
|
||||
image_buffer = io.BytesIO()
|
||||
|
@ -15,8 +15,8 @@ plt.style.use('dark_background')
|
||||
# Make the specific scene
|
||||
config.pixel_height = 1200
|
||||
config.pixel_width = 1200
|
||||
config.frame_height = 10.0
|
||||
config.frame_width = 10.0
|
||||
config.frame_height = 7.0
|
||||
config.frame_width = 7.0
|
||||
|
||||
def test_metropolis_hastings_sampler(iterations=100):
|
||||
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
|
||||
|
Reference in New Issue
Block a user