Mirror of https://github.com/helblazer811/ManimML.git, synced 2025-05-18 03:05:23 +08:00

Commit: Added an MCMC example and the ability to view matplotlib plots.
@@ -6,12 +6,11 @@
TODO reimplement the 2D decision tree surface drawing.
"""
from manim import *
from manim_ml.decision_tree.classification_areas import (
from manim_ml.decision_tree.decision_tree_surface import (
compute_decision_areas,
merge_overlapping_polygons,
)
import manim_ml.decision_tree.helpers as helpers
from manim_ml.one_to_one_sync import OneToOneSync

import numpy as np
from PIL import Image
@@ -330,6 +329,7 @@ class DecisionTreeDiagram(Group):
# then show the split node
# If it is a leaf then just show the leaf node
pass
pass

@override_animation(Create)
def create_decision_tree(self, traversal_order="bfs"):
@@ -345,7 +345,7 @@ class DecisionTreeDiagram(Group):
expand_tree_animation = self.make_expand_tree_animation(node_expand_order)
return expand_tree_animation

class DecisionTreeContainer(OneToOneSync):
class DecisionTreeContainer():
"""Connects the DecisionTreeDiagram to the DecisionTreeEmbedding"""

def __init__(self, sklearn_tree, points, classes):
@@ -3,7 +3,6 @@ import numpy as np
from collections import deque
from sklearn.tree import _tree as ctree


class AABB:
"""Axis-aligned bounding box"""

@@ -20,7 +19,6 @@ class AABB:

return left, right


def tree_bounds(tree, n_features=None):
"""Compute final decision rule for each node in tree"""
if n_features is None:
@@ -36,8 +34,13 @@ def tree_bounds(tree, n_features=None):
queue.extend([l, r])
return aabbs


def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None):
def compute_decision_areas(
tree_classifier,
maxrange,
x=0,
y=1,
n_features=None
):
"""Extract decision areas.

tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier
@@ -73,7 +76,6 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None)
rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2])
return rectangles


def plot_areas(rectangles):
for rect in rectangles:
color = ["b", "r"][int(rect[4])]
@@ -87,7 +89,6 @@ def plot_areas(rectangles):
)
plt.gca().add_artist(rp)


def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
# get all polygons of each color
polygon_dict = {
@@ -161,7 +162,6 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]):
return_polygons.append(polygon)
return return_polygons


class IrisDatasetPlot(VGroup):
def __init__(self, iris):
points = iris.data[:, 0:2]
@@ -359,3 +359,4 @@ class DecisionTreeSurface(VGroup):
# 1. Make a line split animation
# 2. Create the relevant classification areas
# and transform the old ones to them
pass
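
For orientation, a minimal sketch of calling compute_decision_areas as defined above; the [xmin, xmax, ymin, ymax] reading of maxrange is inferred from the clipping against maxrange[1::2], and the concrete numbers are illustrative:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# Train on the two features that will be plotted (x=0, y=1).
clf = DecisionTreeClassifier().fit(iris.data[:, 0:2], iris.target)
# Each returned row is assumed to be [xmin, xmax, ymin, ymax, class].
rectangles = compute_decision_areas(clf, maxrange=[4, 8, 1.5, 5], x=0, y=1)
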
@@ -66,8 +66,8 @@ def compute_bfs_traversal(tree):
while len(queue) > 0:
current_index = queue.pop(0)
traversal_order.append(current_index)
left_child_index = self.tree.children_left[node_index]
right_child_index = self.tree.children_right[node_index]
left_child_index = tree.children_left[node_index]
right_child_index = tree.children_right[node_index]
is_leaf_node = left_child_index == right_child_index
if not is_leaf_node:
queue.append(left_child_index)
@@ -9,8 +9,9 @@ from tqdm import tqdm

from manim_ml.utils.mobjects.probability import GaussianDistribution

######################## MCMC Algorithms #########################

def gaussian_proposal(x, sigma=0.2):
def gaussian_proposal(x, sigma=1.0):
"""
Gaussian proposal distribution.
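
The body of gaussian_proposal is not part of this hunk; a random-walk proposal of this form typically amounts to the following sketch (the actual implementation in mcmc.py may differ in details such as returning a proposal correction factor):

import numpy as np

def gaussian_proposal_sketch(x, sigma=1.0):
    # Perturb the current location with zero-mean Gaussian noise;
    # sigma is the step size of the random walk (raised to 1.0 above).
    return x + np.random.normal(0.0, sigma, size=np.shape(x))
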
@@ -86,7 +87,6 @@ class MultidimensionalGaussianPosterior:
else:
return -1e6


def metropolis_hastings_sampler(
log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal,
@@ -154,6 +154,7 @@ def metropolis_hastings_sampler(

return chain, np.array([]), proposals
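
For context, the update loop that produces the (chain, warm-up samples, proposals) triple returned above reduces, for a symmetric Gaussian proposal, to roughly the following acceptance step; this is a hedged sketch of the algorithm rather than the exact body of metropolis_hastings_sampler:

import numpy as np

def mh_step_sketch(current, log_prob_fn, prop_fn):
    # Propose a candidate and accept it with probability
    # min(1, p(candidate) / p(current)), computed in log space.
    candidate = prop_fn(current)
    log_alpha = log_prob_fn(candidate) - log_prob_fn(current)
    if np.log(np.random.uniform()) < log_alpha:
        return candidate, candidate  # accepted: the chain moves
    return current, candidate        # rejected: the chain stays put
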
#################### MCMC Visualization Tools ######################

class MCMCAxes(Group):
"""Container object for visualizing MCMC on a 2D axis"""
@@ -161,11 +162,15 @@ class MCMCAxes(Group):
def __init__(
self,
dot_color=BLUE,
dot_radius=0.05,
dot_radius=0.02,
accept_line_color=GREEN,
reject_line_color=RED,
line_color=WHITE,
line_stroke_width=1,
line_color=BLUE,
line_stroke_width=3,
x_range=[-3, 3],
y_range=[-3, 3],
x_length=5,
y_length=5
):
super().__init__()
self.dot_color = dot_color
@@ -176,10 +181,10 @@ class MCMCAxes(Group):
self.line_stroke_width = line_stroke_width
# Make the axes
self.axes = Axes(
x_range=[-3, 3],
y_range=[-3, 3],
x_length=12,
y_length=12,
x_range=x_range,
y_range=y_range,
x_length=x_length,
y_length=y_length,
x_axis_config={"stroke_opacity": 0.0},
y_axis_config={"stroke_opacity": 0.0},
tips=False,
@@ -214,7 +219,12 @@ class MCMCAxes(Group):
return create_guassian

def make_transition_animation(
self, start_point, end_point, candidate_point, run_time=0.1
self,
start_point,
end_point,
candidate_point,
show_dots=True,
run_time=0.1
) -> AnimationGroup:
"""Makes a transition animation for a single point on a Markov Chain

@@ -224,6 +234,8 @@ class MCMCAxes(Group):
Start point of the transition
end_point : Dot
End point of the transition
show_dots: boolean, optional
Whether or not to show the dots

Returns
-------
@@ -237,20 +249,32 @@ class MCMCAxes(Group):
# point_is_rejected = not candidate_location == end_location
point_is_rejected = False
if point_is_rejected:
return AnimationGroup()
return AnimationGroup(), Dot().set_opacity(0.0)
else:
create_end_point = Create(end_point)
create_line = Create(
Line(
line = Line(
start_point,
end_point,
color=self.line_color,
stroke_width=self.line_stroke_width,
buff=-0.1
)
)

create_line = Create(line)

if show_dots:
return AnimationGroup(
create_end_point, create_line, lag_ratio=1.0, run_time=run_time
)
create_end_point,
create_line,
lag_ratio=1.0,
run_time=run_time
), line
else:
return AnimationGroup(
create_line,
lag_ratio=1.0,
run_time=run_time
), line
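
make_transition_animation now returns the Line mobject along with the AnimationGroup, so callers can keep the drawn segments and remove them once the chain has been shown; a rough usage sketch with illustrative variable names:

animation, line = axes.make_transition_animation(
    start_point, end_point, candidate_point, show_dots=True
)
scene.play(animation)
scene.play(FadeOut(line))  # the segment can be cleared after the chain is drawn
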

def show_ground_truth_gaussian(self, distribution):
""" """
@@ -265,6 +289,7 @@ class MCMCAxes(Group):
self,
log_prob_fn=MultidimensionalGaussianPosterior(),
prop_fn=gaussian_proposal,
show_dots=False,
sampling_kwargs={},
):
"""
@@ -281,6 +306,8 @@ class MCMCAxes(Group):
Function to compute proposal location, by default gaussian_proposal
initial_location : list, optional
initial location for the markov chain, by default None
show_dots : bool, optional
whether or not to show the dots on the screen, by default False
iterations : int, optional
number of iterations of the markov chain, by default 100

@@ -293,8 +320,8 @@ class MCMCAxes(Group):
mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler(
log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs
)
print(f"MCMC samples: {mcmc_samples}")
print(f"Candidate samples: {candidate_samples}")
# print(f"MCMC samples: {mcmc_samples}")
# print(f"Candidate samples: {candidate_samples}")
# Make the animation for visualizing the chain
animations = []
# Place the initial point
@@ -308,30 +335,41 @@ class MCMCAxes(Group):
animations.append(create_initial_point)
# Show the initial point's proposal distribution
# NOTE: visualize the warm up and the iterations
lines = []
num_iterations = len(mcmc_samples) + len(warm_up_samples)
for iteration in tqdm(range(1, num_iterations)):
next_sample = mcmc_samples[iteration]
print(f"Next sample: {next_sample}")
# print(f"Next sample: {next_sample}")
candidate_sample = candidate_samples[iteration - 1]
# Make the next point
next_point = Dot(
self.axes.coords_to_point(next_sample[0], next_sample[1]),
self.axes.coords_to_point(
next_sample[0],
next_sample[1]
),
color=self.dot_color,
radius=self.dot_radius,
)
candidate_point = Dot(
self.axes.coords_to_point(candidate_sample[0], candidate_sample[1]),
self.axes.coords_to_point(
candidate_sample[0],
candidate_sample[1]
),
color=self.dot_color,
radius=self.dot_radius,
)
# Make a transition animation
transition_animation = self.make_transition_animation(
transition_animation, line = self.make_transition_animation(
current_point, next_point, candidate_point
)
lines.append(line)
animations.append(transition_animation)
# Setup for next iteration
current_point = next_point
# Make the final animation group
animation_group = AnimationGroup(*animations, lag_ratio=1.0)
animation_group = AnimationGroup(
*animations,
lag_ratio=1.0
)

return animation_group
return animation_group, VGroup(*lines)
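
Putting the pieces together, a scene can now play the chain animation and then fade the accumulated lines out; this sketch mirrors the updated test_mcmc.py further down (the scene name is illustrative):

class MCMCSketchScene(Scene):
    def construct(self):
        axes = MCMCAxes(x_range=[-4, 4], y_range=[-4, 4])
        self.play(Create(axes))
        posterior = MultidimensionalGaussianPosterior(
            mu=np.array([0.0, 0.0]), var=np.array([1.0, 1.0])
        )
        animation, lines = axes.visualize_metropolis_hastings_chain_sampling(
            log_prob_fn=posterior, sampling_kwargs={"iterations": 500}
        )
        self.play(animation, run_time=3.5)
        self.play(FadeOut(lines))
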
@@ -174,6 +174,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
)

self.construct_activation_function()
super().construct_layer(input_layer, output_layer, **kwargs)

def construct_activation_function(self):
"""Construct the activation function"""
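
construct_activation_function is called just before the super().construct_layer(...) call that forwards **kwargs (including the new debug_mode) to the parent class. At the usage level this corresponds to passing an activation to the layer, roughly as below; the constructor arguments are illustrative and assume the activation_function keyword the library exposes in its examples:

conv = Convolutional2DLayer(1, 7, 3, activation_function="ReLU")
ff = FeedForwardLayer(3, activation_function="Sigmoid")
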
@@ -50,6 +50,7 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer):
self.latent_distribution = GaussianDistribution(
self.axes, mean=self.mean, cov=self.covariance
) # Use defaults
super().construct_layer(input_layer, output_layer, **kwargs)

def add_gaussian_distribution(self, gaussian_distribution):
"""Adds given GaussianDistribution to the list"""
@@ -76,6 +76,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
self.add(self.surrounding_rectangle, self.node_group)

self.construct_activation_function()
super().construct_layer(input_layer, output_layer, **kwargs)

def construct_activation_function(self):
"""Construct the activation function"""
@@ -39,6 +39,7 @@ class FeedForwardToFeedForward(ConnectiveLayer):
):
self.edges = self.construct_edges()
self.add(self.edges)
super().construct_layer(input_layer, output_layer, **kwargs)

def construct_edges(self):
# Go through each node in the two layers and make a connecting line
@@ -1,21 +1,27 @@
from manim import *
import numpy as np
from PIL import Image

from manim_ml.utils.mobjects.image import GrayscaleImageMobject
from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer

from PIL import Image


class ImageLayer(NeuralNetworkLayer):
"""Single Image Layer for Neural Network"""

def __init__(self, numpy_image, height=1.5, show_image_on_create=True, **kwargs):
def __init__(
self,
numpy_image,
height=1.5,
show_image_on_create=True,
**kwargs
):
super().__init__(**kwargs)
self.image_height = height
self.numpy_image = numpy_image
self.show_image_on_create = show_image_on_create

def construct_layer(self, input_layer, output_layer):
def construct_layer(self, input_layer, output_layer, **kwargs):
"""Construct layer method

Parameters
@@ -29,7 +35,8 @@ class ImageLayer(NeuralNetworkLayer):
# Assumed Grayscale
self.num_channels = 1
self.image_mobject = GrayscaleImageMobject(
self.numpy_image, height=self.image_height
self.numpy_image,
height=self.image_height
)
elif len(np.shape(self.numpy_image)) == 3:
# Assumed RGB
@@ -38,6 +45,7 @@ class ImageLayer(NeuralNetworkLayer):
self.image_height
)
self.add(self.image_mobject)
super().construct_layer(input_layer, output_layer, **kwargs)

@classmethod
def from_path(cls, image_path, grayscale=True, **kwargs):
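
Given the reshaped constructor and the from_path classmethod above, typical usage looks roughly like the following (the array shape and file name are illustrative):

import numpy as np

# From a raw array; a 2D array is treated as grayscale in construct_layer.
image_layer = ImageLayer(np.random.randint(0, 255, (28, 28)), height=1.5)

# Or from a file on disk via the classmethod.
image_layer = ImageLayer.from_path("digit.png", grayscale=True)
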
@@ -67,6 +67,8 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
input_layer.feature_map_size[0] / self.kernel_size,
input_layer.feature_map_size[1] / self.kernel_size,
)
super().construct_layer(input_layer, output_layer, **kwargs)


def _make_output_feature_maps(self, num_input_feature_maps, input_feature_map_size):
"""Makes a set of output feature maps"""
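
A quick sanity check of the feature-map arithmetic above: a pooling layer with kernel_size=2 halves each spatial dimension of the incoming feature maps.

input_feature_map_size = (28, 28)
kernel_size = 2
output_feature_map_size = (
    input_feature_map_size[0] / kernel_size,  # 14.0
    input_feature_map_size[1] / kernel_size,  # 14.0
)
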
@@ -51,4 +51,4 @@ class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D):
output_layer : NeuralNetworkLayer
output layer
"""
pass
super().construct_layer(input_layer, output_layer, **kwargs)
@@ -1,7 +1,6 @@
from manim import *
from abc import ABC, abstractmethod


class NeuralNetworkLayer(ABC, Group):
"""Abstract Neural Network Layer class"""

@@ -28,7 +27,8 @@ class NeuralNetworkLayer(ABC, Group):
output_layer : NeuralNetworkLayer
following layer
"""
pass
if "debug_mode" in kwargs and kwargs["debug_mode"]:
self.add(SurroundingRectangle(self))

@abstractmethod
def make_forward_pass_animation(self, layer_args={}, **kwargs):
@@ -41,7 +41,6 @@ class NeuralNetworkLayer(ABC, Group):
def __repr__(self):
return f"{type(self).__name__}"


class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -55,7 +54,6 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
def _create_override(self):
return super()._create_override()


class ThreeDLayer(ABC):
"""Abstract class for 3D layers"""
@@ -35,6 +35,7 @@ class TripletLayer(NeuralNetworkLayer):
# Make the assets
self.assets = self.make_assets()
self.add(self.assets)
super().construct_layer(input_layer, output_layer, **kwargs)

@classmethod
def from_paths(
@@ -18,6 +18,7 @@ class VectorLayer(VGroupNeuralNetworkLayer):
output_layer: "NeuralNetworkLayer",
**kwargs,
):
super().construct_layer(input_layer, output_layer, **kwargs)
# Make the vector
self.vector_label = self.make_vector()
self.add(self.vector_label)
@@ -38,6 +38,7 @@ class NeuralNetwork(Group):
title=" ",
layout="linear",
layout_direction="left_to_right",
debug_mode=False
):
super(Group, self).__init__()
self.input_layers_dict = self.make_input_layers_dict(input_layers)
@@ -51,6 +52,7 @@ class NeuralNetwork(Group):
self.created = False
self.layout = layout
self.layout_direction = layout_direction
self.debug_mode = debug_mode
# TODO take layer_node_count [0, (1, 2), 0]
# and make it have explicit distinct subspaces
# Construct all of the layers
@@ -124,9 +126,17 @@ class NeuralNetwork(Group):
if layer_index > 0:
prev_layer = self.input_layers[layer_index - 1]
# Run the construct layer method for each
current_layer.construct_layer(prev_layer, next_layer)
current_layer.construct_layer(
prev_layer,
next_layer,
debug_mode=self.debug_mode
)

def _place_layers(self, layout="linear", layout_direction="top_to_bottom"):
def _place_layers(
self,
layout="linear",
layout_direction="top_to_bottom"
):
"""Creates the neural network"""
# TODO implement more sophisticated custom layouts
# Default: Linear layout
@@ -224,10 +234,16 @@ class NeuralNetwork(Group):
return animation_group

def make_forward_pass_animation(
self, run_time=None, passing_flash=True, layer_args={}, **kwargs
self,
run_time=None,
passing_flash=True,
layer_args={},
per_layer_animations=False,
**kwargs
):
"""Generates an animation for feed forward propagation"""
all_animations = []
per_layer_animations = {}
per_layer_runtime = (
run_time / len(self.all_layers) if not run_time is None else None
)
@@ -275,12 +291,18 @@ class NeuralNetwork(Group):
break

layer_forward_pass = AnimationGroup(
layer_forward_pass, connection_input_pass, lag_ratio=0.0
layer_forward_pass,
connection_input_pass,
lag_ratio=0.0
)
all_animations.append(layer_forward_pass)
# Add the animation to per layer animation
per_layer_animations[layer] = layer_forward_pass
# Make the animation group
animation_group = Succession(*all_animations, lag_ratio=1.0)

if per_layer_animations:
return per_layer_animations
else:
return animation_group

@override_animation(Create)
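
Combined with the parent-layer change above, the new debug_mode flag makes every layer draw a SurroundingRectangle around itself during construction, which helps when checking layouts. A minimal sketch (layer arguments are illustrative):

nn = NeuralNetwork(
    [
        FeedForwardLayer(3),
        FeedForwardLayer(5),
        FeedForwardLayer(3),
    ],
    debug_mode=True,  # forwarded to each layer's construct_layer as debug_mode=True
)
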
@@ -2,7 +2,6 @@ from manim import *
import numpy as np
from PIL import Image


class GrayscaleImageMobject(Group):
"""Mobject for creating images in Manim from numpy arrays"""

@@ -15,9 +14,14 @@ class GrayscaleImageMobject(Group):
# Convert grayscale to rgb version of grayscale
input_image = np.repeat(input_image, 3, axis=0)
input_image = np.rollaxis(input_image, 0, start=3)
self.image_mobject = ImageMobject(input_image, image_mode="RBG")
self.image_mobject = ImageMobject(
input_image,
image_mode="RBG",
)
self.add(self.image_mobject)
self.image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
self.image_mobject.set_resampling_algorithm(
RESAMPLING_ALGORITHMS["nearest"]
)
self.image_mobject.scale_to_fit_height(height)

@classmethod
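
For reference, the shape bookkeeping behind the grayscale-to-RGB conversion above, assuming the array has already been expanded to shape (1, H, W) earlier in the constructor (that step is outside this hunk):

import numpy as np

img = np.zeros((1, 28, 28))           # (1, H, W) single-channel image
img = np.repeat(img, 3, axis=0)       # (3, 28, 28): channel repeated three times
img = np.rollaxis(img, 0, start=3)    # (28, 28, 3): channels-last, as ImageMobject expects
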
28  manim_ml/utils/mobjects/plotting.py  Normal file
@@ -0,0 +1,28 @@
from manim import *
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
import io

def convert_matplotlib_figure_to_image_mobject(fig, dpi=200):
"""Takes a matplotlib figure and makes an image mobject from it

Parameters
----------
fig : matplotlib figure
matplotlib figure
"""
fig.tight_layout(pad=0)
plt.axis('off')
fig.canvas.draw()
# Save data into a buffer
image_buffer = io.BytesIO()
plt.savefig(image_buffer, format='png', dpi=dpi)
# Reopen in PIL and convert to numpy
image = Image.open(image_buffer)
image = np.array(image)
# Convert it to an image mobject
image_mobject = ImageMobject(image, image_mode="RGB")

return image_mobject
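
This helper is what the commit title means by viewing matplotlib plots. A minimal way to use it in a scene, mirroring the new tests below (the scene name and plot are illustrative):

import numpy as np
import matplotlib.pyplot as plt
from manim import Scene
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject

class MatplotlibPlotScene(Scene):
    def construct(self):
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.hexbin(np.random.randn(10_000), np.random.randn(10_000), gridsize=50)
        self.add(convert_matplotlib_figure_to_image_mobject(fig))
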
4  setup.py
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages

setup(
name="manim_ml",
version="0.0.16",
description=(" Machine Learning Animations in python using Manim."),
version="0.0.17",
description=("Machine Learning Animations in python using Manim."),
packages=find_packages(),
)
BIN  tests/control_data/plotting/matplotlib_to_image_mobject.npz  Normal file
Binary file not shown.
@@ -1,6 +0,0 @@
from manim_ml.flow.flow import *


class TestScene(Scene):
def construct(self):
self.add(Rectangle())
@@ -4,30 +4,99 @@ from manim_ml.diffusion.mcmc import (
MultidimensionalGaussianPosterior,
metropolis_hastings_sampler,
)
from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
plt.style.use('dark_background')

# Make the specific scene
config.pixel_height = 1200
config.pixel_width = 1200
config.frame_height = 12.0
config.frame_width = 12.0

config.frame_height = 10.0
config.frame_width = 10.0

def test_metropolis_hastings_sampler(iterations=100):
samples, _, candidates = metropolis_hastings_sampler(iterations=iterations)
assert samples.shape == (iterations, 2)

def plot_hexbin_gaussian_on_image_mobject(
sample_func,
xlim=(-4, 4),
ylim=(-4, 4)
):
# Fixing random state for reproducibility
np.random.seed(19680801)
n = 100_000
samples = []
for i in range(n):
samples.append(sample_func())
samples = np.array(samples)

x = samples[:, 0]
y = samples[:, 1]

fig, ax0 = plt.subplots(1, figsize=(5, 5))

hb = ax0.hexbin(x, y, gridsize=50, cmap='gist_heat')

ax0.set(xlim=xlim, ylim=ylim)

return convert_matplotlib_figure_to_image_mobject(fig)

class MCMCTest(Scene):
def construct(self):
axes = MCMCAxes()
self.play(Create(axes))
gaussian_posterior = MultidimensionalGaussianPosterior(
mu=np.array([0.0, 0.0]), var=np.array([4.0, 2.0])

def construct(
self,
mu=np.array([0.0, 0.0]),
var=np.array([[1.0, 1.0]])
):

def gaussian_sample_func():
vals = np.random.multivariate_normal(
mu,
np.eye(2) * var,
1
)[0]

return vals

image_mobject = plot_hexbin_gaussian_on_image_mobject(
gaussian_sample_func
)
show_gaussian_animation = axes.show_ground_truth_gaussian(gaussian_posterior)
self.play(show_gaussian_animation)
chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling(
log_prob_fn=gaussian_posterior, sampling_kwargs={"iterations": 1000}
self.add(image_mobject)
self.play(FadeOut(image_mobject))

axes = MCMCAxes(
x_range=[-4, 4],
y_range=[-4, 4],
)
self.play(
Create(axes)
)

self.play(chain_sampling_animation)
gaussian_posterior = MultidimensionalGaussianPosterior(
mu=np.array([0.0, 0.0]),
var=np.array([1.0, 1.0])
)

chain_sampling_animation, lines = axes.visualize_metropolis_hastings_chain_sampling(
log_prob_fn=gaussian_posterior,
sampling_kwargs={"iterations": 500},
)

self.play(
chain_sampling_animation,
run_time=3.5
)
self.play(
FadeOut(lines)
)
self.wait(1)
self.play(
FadeIn(image_mobject)
)
71  tests/test_plotting.py  Normal file
@@ -0,0 +1,71 @@

from manim import *

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
plt.style.use('dark_background')

from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject
from manim_ml.utils.testing.frames_comparison import frames_comparison

__module_test__ = "plotting"

@frames_comparison
def test_matplotlib_to_image_mobject(scene):
# libraries & dataset
df = sns.load_dataset('iris')
# Custom the color, add shade and bandwidth
matplotlib.use('Agg')
plt.figure(figsize=(10,10), dpi=100)
displot = sns.displot(
x=df.sepal_width,
y=df.sepal_length,
cmap="Reds",
kind="kde",
)
plt.axis('off')
fig = displot.fig
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
# Display the image mobject
scene.add(image_mobject)

class TestMatplotlibToImageMobject(Scene):

def construct(self):
# Make a matplotlib plot
# libraries & dataset
df = sns.load_dataset('iris')
# Custom the color, add shade and bandwidth
matplotlib.use('Agg')
plt.figure(figsize=(10,10), dpi=100)
displot = sns.displot(
x=df.sepal_width,
y=df.sepal_length,
cmap="Reds",
kind="kde",
)
plt.axis('off')
fig = displot.fig
image_mobject = convert_matplotlib_figure_to_image_mobject(fig)
# Display the image mobject
self.add(image_mobject)


class HexabinScene(Scene):

def construct(self):
# Fixing random state for reproducibility
np.random.seed(19680801)
n = 100_000
x = np.random.standard_normal(n)
y = x + 1.0 * np.random.standard_normal(n)
xlim = -4, 4
ylim = -4, 4

fig, ax0 = plt.subplots(1, figsize=(5, 5))

hb = ax0.hexbin(x, y, gridsize=50, cmap='inferno')
ax0.set(xlim=xlim, ylim=ylim)

self.add(convert_matplotlib_figure_to_image_mobject(fig))