Added changes to the MCMC sampling code. Added an MCMC example.

This commit is contained in:
Alec Helbling
2023-02-03 23:13:20 -05:00
parent 7538e2b620
commit 2b21261db7
4 changed files with 175 additions and 14 deletions

View File

@ -0,0 +1,75 @@
from manim import *
import scipy.stats
from manim_ml.diffusion.mcmc import MCMCAxes
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('dark_background')
# Make the specific scene
config.pixel_height = 720
config.pixel_width = 720
config.frame_height = 7.0
config.frame_width = 7.0
class MCMCWarmupScene(Scene):
def construct(self):
# Define the Gaussian Mixture likelihood
def gaussian_mm_logpdf(x):
"""Gaussian Mixture Model Log PDF"""
# Choose two arbitrary Gaussians
# Big Gaussian
big_gaussian_pdf = scipy.stats.multivariate_normal(
mean=[-0.5, -0.5],
cov=[1.0, 1.0]
).pdf(x)
# Little Gaussian
little_gaussian_pdf = scipy.stats.multivariate_normal(
mean=[2.3, 1.9],
cov=[0.3, 0.3]
).pdf(x)
# Sum their likelihoods and take the log
logpdf = np.log(big_gaussian_pdf + little_gaussian_pdf)
return logpdf
# Generate a bunch of true samples
true_samples = []
# Generate samples for little gaussian
little_gaussian_samples = np.random.multivariate_normal(
mean=[2.3, 1.9],
cov=[[0.3, 0.0], [0.0, 0.3]],
size=(10000)
)
big_gaussian_samples = np.random.multivariate_normal(
mean=[-0.5, -0.5],
cov=[[1.0, 0.0], [0.0, 1.0]],
size=(10000)
)
true_samples = np.concatenate((little_gaussian_samples, big_gaussian_samples))
# Make the MCMC axes
axes = MCMCAxes(
x_range=[-5, 5],
y_range=[-5, 5],
x_length=7.0,
y_length=7.0
)
axes.move_to(ORIGIN)
self.play(
Create(axes)
)
# Make the chain sampling animation
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
log_prob_fn=gaussian_mm_logpdf,
true_samples=true_samples,
sampling_kwargs={
"iterations": 2000,
"warm_up": 50,
"initial_location": np.array([-3.5, 3.5]),
"sampling_seed": 4
},
)
self.play(chain_sampling_animation)
self.wait(3)

View File

@ -2,16 +2,20 @@
Tool for animating Markov Chain Monte Carlo simulations in 2D. Tool for animating Markov Chain Monte Carlo simulations in 2D.
""" """
from manim import * from manim import *
import matplotlib
import matplotlib.pyplot as plt
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
import numpy as np import numpy as np
import scipy import scipy
import scipy.stats import scipy.stats
from tqdm import tqdm from tqdm import tqdm
import seaborn as sns
from manim_ml.utils.mobjects.probability import GaussianDistribution from manim_ml.utils.mobjects.probability import GaussianDistribution
######################## MCMC Algorithms ######################### ######################## MCMC Algorithms #########################
def gaussian_proposal(x, sigma=1.0): def gaussian_proposal(x, sigma=0.3):
""" """
Gaussian proposal distribution. Gaussian proposal distribution.
@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
iterations=25, iterations=25,
warm_up=0, warm_up=0,
ndim=2, ndim=2,
sampling_seed=1
): ):
"""Samples using a Metropolis-Hastings sampler. """Samples using a Metropolis-Hastings sampler.
@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
candidate_samples: np.ndarray candidate_samples: np.ndarray
numpy array of the candidate samples for each time step numpy array of the candidate samples for each time step
""" """
assert warm_up == 0, "Warmup not implemented yet" np.random.seed(sampling_seed)
# initialize chain, acceptance rate and lnprob # initialize chain, acceptance rate and lnprob
chain = np.zeros((iterations, ndim)) chain = np.zeros((iterations, ndim))
proposals = np.zeros((iterations, ndim)) proposals = np.zeros((iterations, ndim))
@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
#################### MCMC Visualization Tools ###################### #################### MCMC Visualization Tools ######################
def make_dist_image_mobject_from_samples(samples, ylim, xlim):
# Make the plot
matplotlib.use('Agg')
plt.figure(figsize=(10,10), dpi=100)
print(np.shape(samples[:, 0]))
displot = sns.displot(
x=samples[:, 0],
y=samples[:, 1],
cmap="Reds",
kind="kde",
norm=matplotlib.colors.LogNorm()
)
plt.ylim(ylim[0], ylim[1])
plt.xlim(xlim[0], xlim[1])
plt.axis('off')
fig = displot.fig
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
return image_mobject
class Uncreate(Create):
def __init__(
self,
mobject,
reverse_rate_function: bool = True,
introducer: bool = True,
remover: bool = True,
**kwargs,
) -> None:
super().__init__(
mobject,
reverse_rate_function=reverse_rate_function,
introducer=introducer,
remover=remover,
**kwargs,
)
class MCMCAxes(Group): class MCMCAxes(Group):
"""Container object for visualizing MCMC on a 2D axis""" """Container object for visualizing MCMC on a 2D axis"""
@ -166,7 +208,7 @@ class MCMCAxes(Group):
accept_line_color=GREEN, accept_line_color=GREEN,
reject_line_color=RED, reject_line_color=RED,
line_color=BLUE, line_color=BLUE,
line_stroke_width=3, line_stroke_width=2,
x_range=[-3, 3], x_range=[-3, 3],
y_range=[-3, 3], y_range=[-3, 3],
x_length=5, x_length=5,
@ -180,6 +222,10 @@ class MCMCAxes(Group):
self.line_color = line_color self.line_color = line_color
self.line_stroke_width = line_stroke_width self.line_stroke_width = line_stroke_width
# Make the axes # Make the axes
self.x_length = x_length
self.y_length = y_length
self.x_range = x_range
self.y_range = y_range
self.axes = Axes( self.axes = Axes(
x_range=x_range, x_range=x_range,
y_range=y_range, y_range=y_range,
@ -290,6 +336,7 @@ class MCMCAxes(Group):
log_prob_fn=MultidimensionalGaussianPosterior(), log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal, prop_fn=gaussian_proposal,
show_dots=False, show_dots=False,
true_samples=None,
sampling_kwargs={}, sampling_kwargs={},
): ):
""" """
@ -318,12 +365,14 @@ class MCMCAxes(Group):
""" """
# Compute the chain samples using a Metropolis Hastings Sampler # Compute the chain samples using a Metropolis Hastings Sampler
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler( mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs log_prob_fn=log_prob_fn,
prop_fn=prop_fn,
**sampling_kwargs
) )
# print(f"MCMC samples: {mcmc_samples}") # print(f"MCMC samples: {mcmc_samples}")
# print(f"Candidate samples: {candidate_samples}") # print(f"Candidate samples: {candidate_samples}")
# Make the animation for visualizing the chain # Make the animation for visualizing the chain
animations = [] transition_animations = []
# Place the initial point # Place the initial point
current_point = mcmc_samples[0] current_point = mcmc_samples[0]
current_point = Dot( current_point = Dot(
@ -332,10 +381,11 @@ class MCMCAxes(Group):
radius=self.dot_radius, radius=self.dot_radius,
) )
create_initial_point = Create(current_point) create_initial_point = Create(current_point)
animations.append(create_initial_point) transition_animations.append(create_initial_point)
# Show the initial point's proposal distribution # Show the initial point's proposal distribution
# NOTE: visualize the warm up and the iterations # NOTE: visualize the warm up and the iterations
lines = [] lines = []
warmup_points = []
num_iterations = len(mcmc_samples) + len(warm_up_samples) num_iterations = len(mcmc_samples) + len(warm_up_samples)
for iteration in tqdm(range(1, num_iterations)): for iteration in tqdm(range(1, num_iterations)):
next_sample = mcmc_samples[iteration] next_sample = mcmc_samples[iteration]
@ -362,14 +412,50 @@ class MCMCAxes(Group):
transition_animation, line = self.make_transition_animation( transition_animation, line = self.make_transition_animation(
current_point, next_point, candidate_point current_point, next_point, candidate_point
) )
# Save assets
lines.append(line) lines.append(line)
animations.append(transition_animation) if iteration < len(warm_up_samples):
warmup_points.append(candidate_point)
# Add the transition animation
transition_animations.append(transition_animation)
# Setup for next iteration # Setup for next iteration
current_point = next_point current_point = next_point
# Make the final animation group # Overall MCMC animation
animation_group = AnimationGroup( # 1. Fade in the distribution
*animations, image_mobject = make_dist_image_mobject_from_samples(
true_samples,
xlim=(self.x_range[0], self.x_range[1]),
ylim=(self.y_range[0], self.y_range[1])
)
image_mobject.scale_to_fit_height(
self.y_length
)
image_mobject.move_to(self.axes)
fade_in_distribution = FadeIn(
image_mobject,
run_time=0.5
)
# 2. Start sampling the chain
chain_sampling_animation = AnimationGroup(
*transition_animations,
lag_ratio=1.0,
run_time=5.0
)
# 3. Convert the chain to points, excluding the warmup
lines = VGroup(*lines)
warm_up_points = VGroup(*warmup_points)
fade_out_lines_and_warmup = AnimationGroup(
Uncreate(lines),
Uncreate(warm_up_points),
lag_ratio=0.0
)
# Make the final animation
animation_group = Succession(
fade_in_distribution,
chain_sampling_animation,
fade_out_lines_and_warmup,
lag_ratio=1.0 lag_ratio=1.0
) )
return animation_group, VGroup(*lines) return animation_group

View File

@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
matplotlib figure matplotlib figure
""" """
fig.tight_layout(pad=0) fig.tight_layout(pad=0)
plt.axis('off') # plt.axis('off')
fig.canvas.draw() fig.canvas.draw()
# Save data into a buffer # Save data into a buffer
image_buffer = io.BytesIO() image_buffer = io.BytesIO()

View File

@ -15,8 +15,8 @@ plt.style.use('dark_background')
# Make the specific scene # Make the specific scene
config.pixel_height = 1200 config.pixel_height = 1200
config.pixel_width = 1200 config.pixel_width = 1200
config.frame_height = 10.0 config.frame_height = 7.0
config.frame_width = 10.0 config.frame_width = 7.0
def test_metropolis_hastings_sampler(iterations=100): def test_metropolis_hastings_sampler(iterations=100):
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations) samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)