Updated the MCMC sampling code and added an MCMC example.

This commit is contained in:
Alec Helbling
2023-02-03 23:13:20 -05:00
parent 7538e2b620
commit 2b21261db7
4 changed files with 175 additions and 14 deletions

View File

@ -0,0 +1,75 @@
from manim import *
import scipy.stats
from manim_ml.diffusion.mcmc import MCMCAxes
import matplotlib.pyplot as plt
import numpy as np
# Use matplotlib's 'dark_background' style for figures created after this point.
plt.style.use('dark_background')
# Render configuration for this example: a square 720x720 pixel output
# over a 7.0 x 7.0 frame.
# NOTE(review): `config` presumably comes from `from manim import *` - confirm.
config.pixel_height = 720
config.pixel_width = 720
config.frame_height = 7.0
config.frame_width = 7.0
class MCMCWarmupScene(Scene):
    """Scene that animates Metropolis-Hastings sampling of a 2D Gaussian mixture.

    The target is the (unnormalized) sum of a broad and a narrow Gaussian.
    Samples drawn directly from both components are handed to the axes as
    ``true_samples`` alongside the log-density used by the sampler.
    """

    def construct(self):
        """Build the axes, create the sampling animation and play it."""
        # Mixture components as (mean, per-axis variance). They are defined
        # once so the density and the reference samples cannot drift apart.
        big_mean, big_variance = [-0.5, -0.5], [1.0, 1.0]
        little_mean, little_variance = [2.3, 1.9], [0.3, 0.3]
        samples_per_component = 10000

        # Freeze the distributions once; rebuilding them inside the log-density
        # would repeat the covariance setup on every sampler iteration.
        # A 1-D ``cov`` is interpreted by scipy as the diagonal of the
        # covariance matrix.
        big_gaussian = scipy.stats.multivariate_normal(
            mean=big_mean,
            cov=big_variance,
        )
        little_gaussian = scipy.stats.multivariate_normal(
            mean=little_mean,
            cov=little_variance,
        )

        def gaussian_mm_logpdf(x):
            """Return the log of the summed densities of both Gaussians at ``x``."""
            # log(p_big + p_little), computed in log space so that points far
            # from both modes do not underflow to log(0) = -inf.
            return np.logaddexp(
                big_gaussian.logpdf(x),
                little_gaussian.logpdf(x),
            )

        # Draw reference samples from each component (little first, then big).
        little_gaussian_samples = np.random.multivariate_normal(
            mean=little_mean,
            cov=np.diag(little_variance),
            size=samples_per_component,
        )
        big_gaussian_samples = np.random.multivariate_normal(
            mean=big_mean,
            cov=np.diag(big_variance),
            size=samples_per_component,
        )
        true_samples = np.concatenate(
            (little_gaussian_samples, big_gaussian_samples)
        )

        # Make the MCMC axes, sized to fill the 7.0 x 7.0 frame.
        axes = MCMCAxes(
            x_range=[-5, 5],
            y_range=[-5, 5],
            x_length=7.0,
            y_length=7.0,
        )
        axes.move_to(ORIGIN)
        self.play(Create(axes))

        # Make the chain sampling animation. ``sampling_kwargs`` is passed
        # through to the sampler by ``MCMCAxes``.
        chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
            log_prob_fn=gaussian_mm_logpdf,
            true_samples=true_samples,
            sampling_kwargs={
                "iterations": 2000,
                "warm_up": 50,
                "initial_location": np.array([-3.5, 3.5]),
                "sampling_seed": 4,
            },
        )
        self.play(chain_sampling_animation)
        self.wait(3)

View File

@ -2,16 +2,20 @@
Tool for animating Markov Chain Monte Carlo simulations in 2D.
"""
from manim import *
import matplotlib
import matplotlib.pyplot as plt
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
import numpy as np
import scipy
import scipy.stats
from tqdm import tqdm
import seaborn as sns
from manim_ml.utils.mobjects.probability import GaussianDistribution
######################## MCMC Algorithms #########################
def gaussian_proposal(x, sigma=1.0):
def gaussian_proposal(x, sigma=0.3):
"""
Gaussian proposal distribution.
@ -94,6 +98,7 @@ def metropolis_hastings_sampler(
iterations=25,
warm_up=0,
ndim=2,
sampling_seed=1
):
"""Samples using a Metropolis-Hastings sampler.
@ -119,7 +124,7 @@ def metropolis_hastings_sampler(
candidate_samples: np.ndarray
numpy array of the candidate samples for each time step
"""
assert warm_up == 0, "Warmup not implemented yet"
np.random.seed(sampling_seed)
# initialize chain, acceptance rate and lnprob
chain = np.zeros((iterations, ndim))
proposals = np.zeros((iterations, ndim))
@ -156,6 +161,43 @@ def metropolis_hastings_sampler(
#################### MCMC Visualization Tools ######################
def make_dist_image_mobject_from_samples(samples, ylim, xlim):
    """Render a 2D kernel-density plot of ``samples`` as an image mobject.

    Parameters
    ----------
    samples : array-like of shape (n_samples, >= 2)
        Column 0 is plotted on the x axis and column 1 on the y axis.
    ylim : sequence
        ``ylim[0]`` and ``ylim[1]`` are used as the lower/upper y limits.
    xlim : sequence
        ``xlim[0]`` and ``xlim[1]`` are used as the lower/upper x limits.

    Returns
    -------
    The value returned by ``convert_matplotlib_figure_to_image_mobject``
    for the rendered figure.

    Raises
    ------
    ValueError
        If ``samples`` is not two-dimensional with at least two columns
        (for example when it is ``None``).
    """
    # Validate before touching matplotlib so that bad input fails fast with a
    # clear message instead of an opaque indexing error.
    samples = np.asarray(samples)
    if samples.ndim != 2 or samples.shape[1] < 2:
        raise ValueError(
            "samples must be a 2D array with at least two columns, "
            f"got shape {samples.shape}"
        )
    # Select the non-interactive backend for off-screen rendering.
    matplotlib.use('Agg')
    # ``sns.displot`` is figure-level and creates its own figure, so no
    # ``plt.figure`` call is needed beforehand (it would only leak an unused
    # figure on every call).
    displot = sns.displot(
        x=samples[:, 0],
        y=samples[:, 1],
        cmap="Reds",
        kind="kde",
        norm=matplotlib.colors.LogNorm(),
    )
    # Configure the plot's own axes explicitly rather than relying on
    # pyplot's notion of the current axes.
    axis = displot.ax
    axis.set_ylim(ylim[0], ylim[1])
    axis.set_xlim(xlim[0], xlim[1])
    axis.axis('off')
    fig = displot.fig
    try:
        return convert_matplotlib_figure_to_image_mobject(fig)
    finally:
        # pyplot keeps figures alive until closed; release it once converted.
        plt.close(fig)
class Uncreate(Create):
    """``Create`` subclass whose three flags all default to ``True``.

    The flags ``reverse_rate_function``, ``introducer`` and ``remover`` are
    forwarded unchanged to ``Create.__init__`` together with any extra
    keyword arguments.
    """

    def __init__(
        self,
        mobject,
        reverse_rate_function: bool = True,
        introducer: bool = True,
        remover: bool = True,
        **kwargs,
    ) -> None:
        # Gather the explicit flags first, then hand everything to the parent.
        flags = {
            "reverse_rate_function": reverse_rate_function,
            "introducer": introducer,
            "remover": remover,
        }
        super().__init__(mobject, **flags, **kwargs)
class MCMCAxes(Group):
"""Container object for visualizing MCMC on a 2D axis"""
@ -166,7 +208,7 @@ class MCMCAxes(Group):
accept_line_color=GREEN,
reject_line_color=RED,
line_color=BLUE,
line_stroke_width=3,
line_stroke_width=2,
x_range=[-3, 3],
y_range=[-3, 3],
x_length=5,
@ -180,6 +222,10 @@ class MCMCAxes(Group):
self.line_color = line_color
self.line_stroke_width = line_stroke_width
# Make the axes
self.x_length = x_length
self.y_length = y_length
self.x_range = x_range
self.y_range = y_range
self.axes = Axes(
x_range=x_range,
y_range=y_range,
@ -290,6 +336,7 @@ class MCMCAxes(Group):
log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal,
show_dots=False,
true_samples=None,
sampling_kwargs={},
):
"""
@ -318,12 +365,14 @@ class MCMCAxes(Group):
"""
# Compute the chain samples using a Metropolis Hastings Sampler
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
log_prob_fn=log_prob_fn,
prop_fn=prop_fn,
**sampling_kwargs
)
# print(f"MCMC samples: {mcmc_samples}")
# print(f"Candidate samples: {candidate_samples}")
# Make the animation for visualizing the chain
animations = []
transition_animations = []
# Place the initial point
current_point = mcmc_samples[0]
current_point = Dot(
@ -332,10 +381,11 @@ class MCMCAxes(Group):
radius=self.dot_radius,
)
create_initial_point = Create(current_point)
animations.append(create_initial_point)
transition_animations.append(create_initial_point)
# Show the initial point's proposal distribution
# NOTE: visualize the warm up and the iterations
lines = []
warmup_points = []
num_iterations = len(mcmc_samples) + len(warm_up_samples)
for iteration in tqdm(range(1, num_iterations)):
next_sample = mcmc_samples[iteration]
@ -362,14 +412,50 @@ class MCMCAxes(Group):
transition_animation, line = self.make_transition_animation(
current_point, next_point, candidate_point
)
# Save assets
lines.append(line)
animations.append(transition_animation)
if iteration < len(warm_up_samples):
warmup_points.append(candidate_point)
# Add the transition animation
transition_animations.append(transition_animation)
# Setup for next iteration
current_point = next_point
# Make the final animation group
animation_group = AnimationGroup(
*animations,
# Overall MCMC animation
# 1. Fade in the distribution
image_mobject = make_dist_image_mobject_from_samples(
true_samples,
xlim=(self.x_range[0], self.x_range[1]),
ylim=(self.y_range[0], self.y_range[1])
)
image_mobject.scale_to_fit_height(
self.y_length
)
image_mobject.move_to(self.axes)
fade_in_distribution = FadeIn(
image_mobject,
run_time=0.5
)
# 2. Start sampling the chain
chain_sampling_animation = AnimationGroup(
*transition_animations,
lag_ratio=1.0,
run_time=5.0
)
# 3. Convert the chain to points, excluding the warmup
lines = VGroup(*lines)
warm_up_points = VGroup(*warmup_points)
fade_out_lines_and_warmup = AnimationGroup(
Uncreate(lines),
Uncreate(warm_up_points),
lag_ratio=0.0
)
# Make the final animation
animation_group = Succession(
fade_in_distribution,
chain_sampling_animation,
fade_out_lines_and_warmup,
lag_ratio=1.0
)
return animation_group, VGroup(*lines)
return animation_group

View File

@ -14,7 +14,7 @@ def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
matplotlib figure
"""
fig.tight_layout(pad=0)
plt.axis('off')
# plt.axis('off')
fig.canvas.draw()
# Save data into a buffer
image_buffer = io.BytesIO()

View File

@ -15,8 +15,8 @@ plt.style.use('dark_background')
# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1200
config.frame_height = 10.0
config.frame_width = 10.0
config.frame_height = 7.0
config.frame_width = 7.0
def test_metropolis_hastings_sampler(iterations=100):
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)