mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-08-06 17:29:45 +08:00
Added changes to the MCMC sampling code. Added an MCMC example.
This commit is contained in:
75
examples/mcmc/warmup_mcmc.py
Normal file
75
examples/mcmc/warmup_mcmc.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from manim import *
|
||||||
|
|
||||||
|
import scipy.stats
|
||||||
|
from manim_ml.diffusion.mcmc import MCMCAxes
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
plt.style.use('dark_background')
|
||||||
|
|
||||||
|
# Make the specific scene
|
||||||
|
config.pixel_height = 720
|
||||||
|
config.pixel_width = 720
|
||||||
|
config.frame_height = 7.0
|
||||||
|
config.frame_width = 7.0
|
||||||
|
|
||||||
|
class MCMCWarmupScene(Scene):
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
# Define the Gaussian Mixture likelihood
|
||||||
|
def gaussian_mm_logpdf(x):
|
||||||
|
"""Gaussian Mixture Model Log PDF"""
|
||||||
|
# Choose two arbitrary Gaussians
|
||||||
|
# Big Gaussian
|
||||||
|
big_gaussian_pdf = scipy.stats.multivariate_normal(
|
||||||
|
mean=[-0.5, -0.5],
|
||||||
|
cov=[1.0, 1.0]
|
||||||
|
).pdf(x)
|
||||||
|
# Little Gaussian
|
||||||
|
little_gaussian_pdf = scipy.stats.multivariate_normal(
|
||||||
|
mean=[2.3, 1.9],
|
||||||
|
cov=[0.3, 0.3]
|
||||||
|
).pdf(x)
|
||||||
|
# Sum their likelihoods and take the log
|
||||||
|
logpdf = np.log(big_gaussian_pdf + little_gaussian_pdf)
|
||||||
|
|
||||||
|
return logpdf
|
||||||
|
|
||||||
|
# Generate a bunch of true samples
|
||||||
|
true_samples = []
|
||||||
|
# Generate samples for little gaussian
|
||||||
|
little_gaussian_samples = np.random.multivariate_normal(
|
||||||
|
mean=[2.3, 1.9],
|
||||||
|
cov=[[0.3, 0.0], [0.0, 0.3]],
|
||||||
|
size=(10000)
|
||||||
|
)
|
||||||
|
big_gaussian_samples = np.random.multivariate_normal(
|
||||||
|
mean=[-0.5, -0.5],
|
||||||
|
cov=[[1.0, 0.0], [0.0, 1.0]],
|
||||||
|
size=(10000)
|
||||||
|
)
|
||||||
|
true_samples = np.concatenate((little_gaussian_samples, big_gaussian_samples))
|
||||||
|
# Make the MCMC axes
|
||||||
|
axes = MCMCAxes(
|
||||||
|
x_range=[-5, 5],
|
||||||
|
y_range=[-5, 5],
|
||||||
|
x_length=7.0,
|
||||||
|
y_length=7.0
|
||||||
|
)
|
||||||
|
axes.move_to(ORIGIN)
|
||||||
|
self.play(
|
||||||
|
Create(axes)
|
||||||
|
)
|
||||||
|
# Make the chain sampling animation
|
||||||
|
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
|
||||||
|
log_prob_fn=gaussian_mm_logpdf,
|
||||||
|
true_samples=true_samples,
|
||||||
|
sampling_kwargs={
|
||||||
|
"iterations": 2000,
|
||||||
|
"warm_up": 50,
|
||||||
|
"initial_location": np.array([-3.5, 3.5]),
|
||||||
|
"sampling_seed": 4
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.play(chain_sampling_animation)
|
||||||
|
self.wait(3)
|
@ -2,16 +2,20 @@
|
|||||||
Tool for animating Markov Chain Monte Carlo simulations in 2D.
|
Tool for animating Markov Chain Monte Carlo simulations in 2D.
|
||||||
"""
|
"""
|
||||||
from manim import *
|
from manim import *
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
from manim_ml.utils.mobjects.probability import GaussianDistribution
|
from manim_ml.utils.mobjects.probability import GaussianDistribution
|
||||||
|
|
||||||
######################## MCMC Algorithms #########################
|
######################## MCMC Algorithms #########################
|
||||||
|
|
||||||
def gaussian_proposal(x, sigma=1.0):
|
def gaussian_proposal(x, sigma=0.3):
|
||||||
"""
|
"""
|
||||||
Gaussian proposal distribution.
|
Gaussian proposal distribution.
|
||||||
|
|
||||||
@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
|
|||||||
iterations=25,
|
iterations=25,
|
||||||
warm_up=0,
|
warm_up=0,
|
||||||
ndim=2,
|
ndim=2,
|
||||||
|
sampling_seed=1
|
||||||
):
|
):
|
||||||
"""Samples using a Metropolis-Hastings sampler.
|
"""Samples using a Metropolis-Hastings sampler.
|
||||||
|
|
||||||
@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
|
|||||||
candidate_samples: np.ndarray
|
candidate_samples: np.ndarray
|
||||||
numpy array of the candidate samples for each time step
|
numpy array of the candidate samples for each time step
|
||||||
"""
|
"""
|
||||||
assert warm_up == 0, "Warmup not implemented yet"
|
np.random.seed(sampling_seed)
|
||||||
# initialize chain, acceptance rate and lnprob
|
# initialize chain, acceptance rate and lnprob
|
||||||
chain = np.zeros((iterations, ndim))
|
chain = np.zeros((iterations, ndim))
|
||||||
proposals = np.zeros((iterations, ndim))
|
proposals = np.zeros((iterations, ndim))
|
||||||
@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
|
|||||||
|
|
||||||
#################### MCMC Visualization Tools ######################
|
#################### MCMC Visualization Tools ######################
|
||||||
|
|
||||||
|
def make_dist_image_mobject_from_samples(samples, ylim, xlim):
|
||||||
|
# Make the plot
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
plt.figure(figsize=(10,10), dpi=100)
|
||||||
|
print(np.shape(samples[:, 0]))
|
||||||
|
displot = sns.displot(
|
||||||
|
x=samples[:, 0],
|
||||||
|
y=samples[:, 1],
|
||||||
|
cmap="Reds",
|
||||||
|
kind="kde",
|
||||||
|
norm=matplotlib.colors.LogNorm()
|
||||||
|
)
|
||||||
|
plt.ylim(ylim[0], ylim[1])
|
||||||
|
plt.xlim(xlim[0], xlim[1])
|
||||||
|
plt.axis('off')
|
||||||
|
fig = displot.fig
|
||||||
|
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
|
||||||
|
|
||||||
|
return image_mobject
|
||||||
|
|
||||||
|
class Uncreate(Create):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mobject,
|
||||||
|
reverse_rate_function: bool = True,
|
||||||
|
introducer: bool = True,
|
||||||
|
remover: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
mobject,
|
||||||
|
reverse_rate_function=reverse_rate_function,
|
||||||
|
introducer=introducer,
|
||||||
|
remover=remover,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
class MCMCAxes(Group):
|
class MCMCAxes(Group):
|
||||||
"""Container object for visualizing MCMC on a 2D axis"""
|
"""Container object for visualizing MCMC on a 2D axis"""
|
||||||
|
|
||||||
@ -166,7 +208,7 @@ class MCMCAxes(Group):
|
|||||||
accept_line_color=GREEN,
|
accept_line_color=GREEN,
|
||||||
reject_line_color=RED,
|
reject_line_color=RED,
|
||||||
line_color=BLUE,
|
line_color=BLUE,
|
||||||
line_stroke_width=3,
|
line_stroke_width=2,
|
||||||
x_range=[-3, 3],
|
x_range=[-3, 3],
|
||||||
y_range=[-3, 3],
|
y_range=[-3, 3],
|
||||||
x_length=5,
|
x_length=5,
|
||||||
@ -180,6 +222,10 @@ class MCMCAxes(Group):
|
|||||||
self.line_color = line_color
|
self.line_color = line_color
|
||||||
self.line_stroke_width = line_stroke_width
|
self.line_stroke_width = line_stroke_width
|
||||||
# Make the axes
|
# Make the axes
|
||||||
|
self.x_length = x_length
|
||||||
|
self.y_length = y_length
|
||||||
|
self.x_range = x_range
|
||||||
|
self.y_range = y_range
|
||||||
self.axes = Axes(
|
self.axes = Axes(
|
||||||
x_range=x_range,
|
x_range=x_range,
|
||||||
y_range=y_range,
|
y_range=y_range,
|
||||||
@ -290,6 +336,7 @@ class MCMCAxes(Group):
|
|||||||
log_prob_fn=MultidimensionalGaussianPosterior(),
|
log_prob_fn=MultidimensionalGaussianPosterior(),
|
||||||
prop_fn=gaussian_proposal,
|
prop_fn=gaussian_proposal,
|
||||||
show_dots=False,
|
show_dots=False,
|
||||||
|
true_samples=None,
|
||||||
sampling_kwargs={},
|
sampling_kwargs={},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -318,12 +365,14 @@ class MCMCAxes(Group):
|
|||||||
"""
|
"""
|
||||||
# Compute the chain samples using a Metropolis Hastings Sampler
|
# Compute the chain samples using a Metropolis Hastings Sampler
|
||||||
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
|
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
|
||||||
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
|
log_prob_fn=log_prob_fn,
|
||||||
|
prop_fn=prop_fn,
|
||||||
|
**sampling_kwargs
|
||||||
)
|
)
|
||||||
# print(f"MCMC samples: {mcmc_samples}")
|
# print(f"MCMC samples: {mcmc_samples}")
|
||||||
# print(f"Candidate samples: {candidate_samples}")
|
# print(f"Candidate samples: {candidate_samples}")
|
||||||
# Make the animation for visualizing the chain
|
# Make the animation for visualizing the chain
|
||||||
animations = []
|
transition_animations = []
|
||||||
# Place the initial point
|
# Place the initial point
|
||||||
current_point = mcmc_samples[0]
|
current_point = mcmc_samples[0]
|
||||||
current_point = Dot(
|
current_point = Dot(
|
||||||
@ -332,10 +381,11 @@ class MCMCAxes(Group):
|
|||||||
radius=self.dot_radius,
|
radius=self.dot_radius,
|
||||||
)
|
)
|
||||||
create_initial_point = Create(current_point)
|
create_initial_point = Create(current_point)
|
||||||
animations.append(create_initial_point)
|
transition_animations.append(create_initial_point)
|
||||||
# Show the initial point's proposal distribution
|
# Show the initial point's proposal distribution
|
||||||
# NOTE: visualize the warm up and the iterations
|
# NOTE: visualize the warm up and the iterations
|
||||||
lines = []
|
lines = []
|
||||||
|
warmup_points = []
|
||||||
num_iterations = len(mcmc_samples) + len(warm_up_samples)
|
num_iterations = len(mcmc_samples) + len(warm_up_samples)
|
||||||
for iteration in tqdm(range(1, num_iterations)):
|
for iteration in tqdm(range(1, num_iterations)):
|
||||||
next_sample = mcmc_samples[iteration]
|
next_sample = mcmc_samples[iteration]
|
||||||
@ -362,14 +412,50 @@ class MCMCAxes(Group):
|
|||||||
transition_animation, line = self.make_transition_animation(
|
transition_animation, line = self.make_transition_animation(
|
||||||
current_point, next_point, candidate_point
|
current_point, next_point, candidate_point
|
||||||
)
|
)
|
||||||
|
# Save assets
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
animations.append(transition_animation)
|
if iteration < len(warm_up_samples):
|
||||||
|
warmup_points.append(candidate_point)
|
||||||
|
|
||||||
|
# Add the transition animation
|
||||||
|
transition_animations.append(transition_animation)
|
||||||
# Setup for next iteration
|
# Setup for next iteration
|
||||||
current_point = next_point
|
current_point = next_point
|
||||||
# Make the final animation group
|
# Overall MCMC animation
|
||||||
animation_group = AnimationGroup(
|
# 1. Fade in the distribution
|
||||||
*animations,
|
image_mobject = make_dist_image_mobject_from_samples(
|
||||||
|
true_samples,
|
||||||
|
xlim=(self.x_range[0], self.x_range[1]),
|
||||||
|
ylim=(self.y_range[0], self.y_range[1])
|
||||||
|
)
|
||||||
|
image_mobject.scale_to_fit_height(
|
||||||
|
self.y_length
|
||||||
|
)
|
||||||
|
image_mobject.move_to(self.axes)
|
||||||
|
fade_in_distribution = FadeIn(
|
||||||
|
image_mobject,
|
||||||
|
run_time=0.5
|
||||||
|
)
|
||||||
|
# 2. Start sampling the chain
|
||||||
|
chain_sampling_animation = AnimationGroup(
|
||||||
|
*transition_animations,
|
||||||
|
lag_ratio=1.0,
|
||||||
|
run_time=5.0
|
||||||
|
)
|
||||||
|
# 3. Convert the chain to points, excluding the warmup
|
||||||
|
lines = VGroup(*lines)
|
||||||
|
warm_up_points = VGroup(*warmup_points)
|
||||||
|
fade_out_lines_and_warmup = AnimationGroup(
|
||||||
|
Uncreate(lines),
|
||||||
|
Uncreate(warm_up_points),
|
||||||
|
lag_ratio=0.0
|
||||||
|
)
|
||||||
|
# Make the final animation
|
||||||
|
animation_group = Succession(
|
||||||
|
fade_in_distribution,
|
||||||
|
chain_sampling_animation,
|
||||||
|
fade_out_lines_and_warmup,
|
||||||
lag_ratio=1.0
|
lag_ratio=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
return animation_group, VGroup(*lines)
|
return animation_group
|
||||||
|
@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
|
|||||||
matplotlib figure
|
matplotlib figure
|
||||||
"""
|
"""
|
||||||
fig.tight_layout(pad=0)
|
fig.tight_layout(pad=0)
|
||||||
plt.axis('off')
|
# plt.axis('off')
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
# Save data into a buffer
|
# Save data into a buffer
|
||||||
image_buffer = io.BytesIO()
|
image_buffer = io.BytesIO()
|
||||||
|
@ -15,8 +15,8 @@ plt.style.use('dark_background')
|
|||||||
# Make the specific scene
|
# Make the specific scene
|
||||||
config.pixel_height = 1200
|
config.pixel_height = 1200
|
||||||
config.pixel_width = 1200
|
config.pixel_width = 1200
|
||||||
config.frame_height = 10.0
|
config.frame_height = 7.0
|
||||||
config.frame_width = 10.0
|
config.frame_width = 7.0
|
||||||
|
|
||||||
def test_metropolis_hastings_sampler(iterations=100):
|
def test_metropolis_hastings_sampler(iterations=100):
|
||||||
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
|
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
|
||||||
|
Reference in New Issue
Block a user