diff --git a/manim_ml/decision_tree/decision_tree.py b/manim_ml/decision_tree/decision_tree.py index 8c3e7ba..a4904ed 100644 --- a/manim_ml/decision_tree/decision_tree.py +++ b/manim_ml/decision_tree/decision_tree.py @@ -6,12 +6,11 @@ TODO reimplement the decision 2D decision tree surface drawing. """ from manim import * -from manim_ml.decision_tree.classification_areas import ( +from manim_ml.decision_tree.decision_tree_surface import ( compute_decision_areas, merge_overlapping_polygons, ) import manim_ml.decision_tree.helpers as helpers -from manim_ml.one_to_one_sync import OneToOneSync import numpy as np from PIL import Image @@ -329,6 +328,7 @@ class DecisionTreeDiagram(Group): # If it is not a leaf then remove the placeholder leaf node # then show the split node # If it is a leaf then just show the leaf node + pass pass @override_animation(Create) @@ -345,7 +345,7 @@ class DecisionTreeDiagram(Group): expand_tree_animation = self.make_expand_tree_animation(node_expand_order) return expand_tree_animation -class DecisionTreeContainer(OneToOneSync): +class DecisionTreeContainer(): """Connects the DecisionTreeDiagram to the DecisionTreeEmbedding""" def __init__(self, sklearn_tree, points, classes): diff --git a/manim_ml/decision_tree/decision_tree_surface.py b/manim_ml/decision_tree/decision_tree_surface.py index 8ac3c1e..0015086 100644 --- a/manim_ml/decision_tree/decision_tree_surface.py +++ b/manim_ml/decision_tree/decision_tree_surface.py @@ -3,7 +3,6 @@ import numpy as np from collections import deque from sklearn.tree import _tree as ctree - class AABB: """Axis-aligned bounding box""" @@ -20,7 +19,6 @@ class AABB: return left, right - def tree_bounds(tree, n_features=None): """Compute final decision rule for each node in tree""" if n_features is None: @@ -36,8 +34,13 @@ def tree_bounds(tree, n_features=None): queue.extend([l, r]) return aabbs - -def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None): +def compute_decision_areas( + tree_classifier, + maxrange, + x=0, + y=1, + n_features=None +): """Extract decision areas. tree_classifier: Instance of a sklearn.tree.DecisionTreeClassifier @@ -73,7 +76,6 @@ def compute_decision_areas(tree_classifier, maxrange, x=0, y=1, n_features=None) rectangles[:, [1, 3]] = np.minimum(rectangles[:, [1, 3]], maxrange[1::2]) return rectangles - def plot_areas(rectangles): for rect in rectangles: color = ["b", "r"][int(rect[4])] @@ -87,7 +89,6 @@ def plot_areas(rectangles): ) plt.gca().add_artist(rp) - def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]): # get all polygons of each color polygon_dict = { @@ -161,7 +162,6 @@ def merge_overlapping_polygons(all_polygons, colors=[BLUE, GREEN, ORANGE]): return_polygons.append(polygon) return return_polygons - class IrisDatasetPlot(VGroup): def __init__(self, iris): points = iris.data[:, 0:2] @@ -359,3 +359,4 @@ class DecisionTreeSurface(VGroup): # 1. Make a line split animation # 2. Create the relevant classification areas # and transform the old ones to them + pass diff --git a/manim_ml/decision_tree/helpers.py b/manim_ml/decision_tree/helpers.py index 50a7ed8..2ee94ec 100644 --- a/manim_ml/decision_tree/helpers.py +++ b/manim_ml/decision_tree/helpers.py @@ -66,8 +66,8 @@ def compute_bfs_traversal(tree): while len(queue) > 0: current_index = queue.pop(0) traversal_order.append(current_index) - left_child_index = self.tree.children_left[node_index] - right_child_index = self.tree.children_right[node_index] + left_child_index = tree.children_left[node_index] + right_child_index = tree.children_right[node_index] is_leaf_node = left_child_index == right_child_index if not is_leaf_node: queue.append(left_child_index) diff --git a/manim_ml/diffusion/mcmc.py b/manim_ml/diffusion/mcmc.py index f2429a7..929874e 100644 --- a/manim_ml/diffusion/mcmc.py +++ b/manim_ml/diffusion/mcmc.py @@ -9,8 +9,9 @@ from tqdm import tqdm from manim_ml.utils.mobjects.probability import GaussianDistribution +######################## MCMC Algorithms ######################### -def gaussian_proposal(x, sigma=0.2): +def gaussian_proposal(x, sigma=1.0): """ Gaussian proposal distribution. @@ -86,7 +87,6 @@ class MultidimensionalGaussianPosterior: else: return -1e6 - def metropolis_hastings_sampler( log_prob_fn=MultidimensionalGaussianPosterior(), prop_fn=gaussian_proposal, @@ -154,6 +154,7 @@ def metropolis_hastings_sampler( return chain, np.array([]), proposals +#################### MCMC Visualization Tools ###################### class MCMCAxes(Group): """Container object for visualizing MCMC on a 2D axis""" @@ -161,11 +162,15 @@ class MCMCAxes(Group): def __init__( self, dot_color=BLUE, - dot_radius=0.05, + dot_radius=0.02, accept_line_color=GREEN, reject_line_color=RED, - line_color=WHITE, - line_stroke_width=1, + line_color=BLUE, + line_stroke_width=3, + x_range=[-3, 3], + y_range=[-3, 3], + x_length=5, + y_length=5 ): super().__init__() self.dot_color = dot_color @@ -176,10 +181,10 @@ class MCMCAxes(Group): self.line_stroke_width = line_stroke_width # Make the axes self.axes = Axes( - x_range=[-3, 3], - y_range=[-3, 3], - x_length=12, - y_length=12, + x_range=x_range, + y_range=y_range, + x_length=x_length, + y_length=y_length, x_axis_config={"stroke_opacity": 0.0}, y_axis_config={"stroke_opacity": 0.0}, tips=False, @@ -214,7 +219,12 @@ class MCMCAxes(Group): return create_guassian def make_transition_animation( - self, start_point, end_point, candidate_point, run_time=0.1 + self, + start_point, + end_point, + candidate_point, + show_dots=True, + run_time=0.1 ) -> AnimationGroup: """Makes an transition animation for a single point on a Markov Chain @@ -224,6 +234,8 @@ class MCMCAxes(Group): Start point of the transition end_point : Dot End point of the transition + show_dots: boolean, optional + Whether or not to show the dots Returns ------- @@ -237,21 +249,33 @@ class MCMCAxes(Group): # point_is_rejected = not candidate_location == end_location point_is_rejected = False if point_is_rejected: - return AnimationGroup() + return AnimationGroup(), Dot().set_opacity(0.0) else: create_end_point = Create(end_point) - create_line = Create( - Line( - start_point, - end_point, - color=self.line_color, - stroke_width=self.line_stroke_width, - ) - ) - return AnimationGroup( - create_end_point, create_line, lag_ratio=1.0, run_time=run_time + line = Line( + start_point, + end_point, + color=self.line_color, + stroke_width=self.line_stroke_width, + buff=-0.1 ) + create_line = Create(line) + + if show_dots: + return AnimationGroup( + create_end_point, + create_line, + lag_ratio=1.0, + run_time=run_time + ), line + else: + return AnimationGroup( + create_line, + lag_ratio=1.0, + run_time=run_time + ), line + def show_ground_truth_gaussian(self, distribution): """ """ mean = distribution.mu @@ -265,6 +289,7 @@ class MCMCAxes(Group): self, log_prob_fn=MultidimensionalGaussianPosterior(), prop_fn=gaussian_proposal, + show_dots=False, sampling_kwargs={}, ): """ @@ -281,6 +306,8 @@ class MCMCAxes(Group): Function to compute proposal location, by default gaussian_proposal initial_location : list, optional initial location for the markov chain, by default None + show_dots : bool, optional + whether or not to show the dots on the screen, by default False iterations : int, optional number of iterations of the markov chain, by default 100 @@ -293,8 +320,8 @@ class MCMCAxes(Group): mcmc_samples, warm_up_samples, candidate_samples = metropolis_hastings_sampler( log_prob_fn=log_prob_fn, prop_fn=prop_fn, **sampling_kwargs ) - print(f"MCMC samples: {mcmc_samples}") - print(f"Candidate samples: {candidate_samples}") + # print(f"MCMC samples: {mcmc_samples}") + # print(f"Candidate samples: {candidate_samples}") # Make the animation for visualizing the chain animations = [] # Place the initial point @@ -308,30 +335,41 @@ class MCMCAxes(Group): animations.append(create_initial_point) # Show the initial point's proposal distribution # NOTE: visualize the warm up and the iterations + lines = [] num_iterations = len(mcmc_samples) + len(warm_up_samples) for iteration in tqdm(range(1, num_iterations)): next_sample = mcmc_samples[iteration] - print(f"Next sample: {next_sample}") + # print(f"Next sample: {next_sample}") candidate_sample = candidate_samples[iteration - 1] # Make the next point next_point = Dot( - self.axes.coords_to_point(next_sample[0], next_sample[1]), + self.axes.coords_to_point( + next_sample[0], + next_sample[1] + ), color=self.dot_color, radius=self.dot_radius, ) candidate_point = Dot( - self.axes.coords_to_point(candidate_sample[0], candidate_sample[1]), + self.axes.coords_to_point( + candidate_sample[0], + candidate_sample[1] + ), color=self.dot_color, radius=self.dot_radius, ) # Make a transition animation - transition_animation = self.make_transition_animation( + transition_animation, line = self.make_transition_animation( current_point, next_point, candidate_point ) + lines.append(line) animations.append(transition_animation) # Setup for next iteration current_point = next_point # Make the final animation group - animation_group = AnimationGroup(*animations, lag_ratio=1.0) + animation_group = AnimationGroup( + *animations, + lag_ratio=1.0 + ) - return animation_group + return animation_group, VGroup(*lines) diff --git a/manim_ml/neural_network/layers/convolutional_2d.py b/manim_ml/neural_network/layers/convolutional_2d.py index 5b51e81..b95c740 100644 --- a/manim_ml/neural_network/layers/convolutional_2d.py +++ b/manim_ml/neural_network/layers/convolutional_2d.py @@ -174,6 +174,7 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): ) self.construct_activation_function() + super().construct_layer(input_layer, output_layer, **kwargs) def construct_activation_function(self): """Construct the activation function""" diff --git a/manim_ml/neural_network/layers/embedding.py b/manim_ml/neural_network/layers/embedding.py index 14f4be4..6e07fe5 100644 --- a/manim_ml/neural_network/layers/embedding.py +++ b/manim_ml/neural_network/layers/embedding.py @@ -50,6 +50,7 @@ class EmbeddingLayer(VGroupNeuralNetworkLayer): self.latent_distribution = GaussianDistribution( self.axes, mean=self.mean, cov=self.covariance ) # Use defaults + super().construct_layer(input_layer, output_layer, **kwargs) def add_gaussian_distribution(self, gaussian_distribution): """Adds given GaussianDistribution to the list""" diff --git a/manim_ml/neural_network/layers/feed_forward.py b/manim_ml/neural_network/layers/feed_forward.py index 262046a..af3f01e 100644 --- a/manim_ml/neural_network/layers/feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward.py @@ -76,6 +76,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer): self.add(self.surrounding_rectangle, self.node_group) self.construct_activation_function() + super().construct_layer(input_layer, output_layer, **kwargs) def construct_activation_function(self): """Construct the activation function""" diff --git a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py index d55ae78..f94d585 100644 --- a/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py +++ b/manim_ml/neural_network/layers/feed_forward_to_feed_forward.py @@ -39,6 +39,7 @@ class FeedForwardToFeedForward(ConnectiveLayer): ): self.edges = self.construct_edges() self.add(self.edges) + super().construct_layer(input_layer, output_layer, **kwargs) def construct_edges(self): # Go through each node in the two layers and make a connecting line diff --git a/manim_ml/neural_network/layers/image.py b/manim_ml/neural_network/layers/image.py index 2d70e7e..e0e828b 100644 --- a/manim_ml/neural_network/layers/image.py +++ b/manim_ml/neural_network/layers/image.py @@ -1,21 +1,27 @@ from manim import * import numpy as np +from PIL import Image + from manim_ml.utils.mobjects.image import GrayscaleImageMobject from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer -from PIL import Image - class ImageLayer(NeuralNetworkLayer): """Single Image Layer for Neural Network""" - def __init__(self, numpy_image, height=1.5, show_image_on_create=True, **kwargs): + def __init__( + self, + numpy_image, + height=1.5, + show_image_on_create=True, + **kwargs + ): super().__init__(**kwargs) self.image_height = height self.numpy_image = numpy_image self.show_image_on_create = show_image_on_create - def construct_layer(self, input_layer, output_layer): + def construct_layer(self, input_layer, output_layer, **kwargs): """Construct layer method Parameters @@ -29,7 +35,8 @@ class ImageLayer(NeuralNetworkLayer): # Assumed Grayscale self.num_channels = 1 self.image_mobject = GrayscaleImageMobject( - self.numpy_image, height=self.image_height + self.numpy_image, + height=self.image_height ) elif len(np.shape(self.numpy_image)) == 3: # Assumed RGB @@ -38,6 +45,7 @@ class ImageLayer(NeuralNetworkLayer): self.image_height ) self.add(self.image_mobject) + super().construct_layer(input_layer, output_layer, **kwargs) @classmethod def from_path(cls, image_path, grayscale=True, **kwargs): diff --git a/manim_ml/neural_network/layers/max_pooling_2d.py b/manim_ml/neural_network/layers/max_pooling_2d.py index 50e8409..5e5f3ea 100644 --- a/manim_ml/neural_network/layers/max_pooling_2d.py +++ b/manim_ml/neural_network/layers/max_pooling_2d.py @@ -67,6 +67,8 @@ class MaxPooling2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer): input_layer.feature_map_size[0] / self.kernel_size, input_layer.feature_map_size[1] / self.kernel_size, ) + super().construct_layer(input_layer, output_layer, **kwargs) + def _make_output_feature_maps(self, num_input_feature_maps, input_feature_map_size): """Makes a set of output feature maps""" diff --git a/manim_ml/neural_network/layers/max_pooling_2d_to_convolutional_2d.py b/manim_ml/neural_network/layers/max_pooling_2d_to_convolutional_2d.py index dd6ce7b..43ba7dc 100644 --- a/manim_ml/neural_network/layers/max_pooling_2d_to_convolutional_2d.py +++ b/manim_ml/neural_network/layers/max_pooling_2d_to_convolutional_2d.py @@ -51,4 +51,4 @@ class MaxPooling2DToConvolutional2D(Convolutional2DToConvolutional2D): output_layer : NeuralNetworkLayer output layer """ - pass + super().construct_layer(input_layer, output_layer, **kwargs) diff --git a/manim_ml/neural_network/layers/parent_layers.py b/manim_ml/neural_network/layers/parent_layers.py index 711baff..0346cc5 100644 --- a/manim_ml/neural_network/layers/parent_layers.py +++ b/manim_ml/neural_network/layers/parent_layers.py @@ -1,7 +1,6 @@ from manim import * from abc import ABC, abstractmethod - class NeuralNetworkLayer(ABC, Group): """Abstract Neural Network Layer class""" @@ -28,7 +27,8 @@ class NeuralNetworkLayer(ABC, Group): output_layer : NeuralNetworkLayer following layer """ - pass + if "debug_mode" in kwargs and kwargs["debug_mode"]: + self.add(SurroundingRectangle(self)) @abstractmethod def make_forward_pass_animation(self, layer_args={}, **kwargs): @@ -41,7 +41,6 @@ class NeuralNetworkLayer(ABC, Group): def __repr__(self): return f"{type(self).__name__}" - class VGroupNeuralNetworkLayer(NeuralNetworkLayer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -55,7 +54,6 @@ class VGroupNeuralNetworkLayer(NeuralNetworkLayer): def _create_override(self): return super()._create_override() - class ThreeDLayer(ABC): """Abstract class for 3D layers""" diff --git a/manim_ml/neural_network/layers/triplet.py b/manim_ml/neural_network/layers/triplet.py index 12062e6..789da38 100644 --- a/manim_ml/neural_network/layers/triplet.py +++ b/manim_ml/neural_network/layers/triplet.py @@ -35,6 +35,7 @@ class TripletLayer(NeuralNetworkLayer): # Make the assets self.assets = self.make_assets() self.add(self.assets) + super().construct_layer(input_layer, output_layer, **kwargs) @classmethod def from_paths( diff --git a/manim_ml/neural_network/layers/vector.py b/manim_ml/neural_network/layers/vector.py index 0d536c6..ca41b84 100644 --- a/manim_ml/neural_network/layers/vector.py +++ b/manim_ml/neural_network/layers/vector.py @@ -18,6 +18,7 @@ class VectorLayer(VGroupNeuralNetworkLayer): output_layer: "NeuralNetworkLayer", **kwargs, ): + super().construct_layer(input_layer, output_layer, **kwargs) # Make the vector self.vector_label = self.make_vector() self.add(self.vector_label) diff --git a/manim_ml/neural_network/neural_network.py b/manim_ml/neural_network/neural_network.py index 895f60a..2b6b833 100644 --- a/manim_ml/neural_network/neural_network.py +++ b/manim_ml/neural_network/neural_network.py @@ -38,6 +38,7 @@ class NeuralNetwork(Group): title=" ", layout="linear", layout_direction="left_to_right", + debug_mode=False ): super(Group, self).__init__() self.input_layers_dict = self.make_input_layers_dict(input_layers) @@ -51,6 +52,7 @@ class NeuralNetwork(Group): self.created = False self.layout = layout self.layout_direction = layout_direction + self.debug_mode = debug_mode # TODO take layer_node_count [0, (1, 2), 0] # and make it have explicit distinct subspaces # Construct all of the layers @@ -124,9 +126,17 @@ class NeuralNetwork(Group): if layer_index > 0: prev_layer = self.input_layers[layer_index - 1] # Run the construct layer method for each - current_layer.construct_layer(prev_layer, next_layer) + current_layer.construct_layer( + prev_layer, + next_layer, + debug_mode=self.debug_mode + ) - def _place_layers(self, layout="linear", layout_direction="top_to_bottom"): + def _place_layers( + self, + layout="linear", + layout_direction="top_to_bottom" + ): """Creates the neural network""" # TODO implement more sophisticated custom layouts # Default: Linear layout @@ -224,10 +234,16 @@ class NeuralNetwork(Group): return animation_group def make_forward_pass_animation( - self, run_time=None, passing_flash=True, layer_args={}, **kwargs + self, + run_time=None, + passing_flash=True, + layer_args={}, + per_layer_animations=False, + **kwargs ): """Generates an animation for feed forward propagation""" all_animations = [] + per_layer_animations = {} per_layer_runtime = ( run_time / len(self.all_layers) if not run_time is None else None ) @@ -275,13 +291,19 @@ class NeuralNetwork(Group): break layer_forward_pass = AnimationGroup( - layer_forward_pass, connection_input_pass, lag_ratio=0.0 + layer_forward_pass, + connection_input_pass, + lag_ratio=0.0 ) all_animations.append(layer_forward_pass) + # Add the animation to per layer animation + per_layer_animations[layer] = layer_forward_pass # Make the animation group animation_group = Succession(*all_animations, lag_ratio=1.0) - - return animation_group + if per_layer_animations: + return per_layer_animations + else: + return animation_group @override_animation(Create) def _create_override(self, **kwargs): diff --git a/manim_ml/utils/mobjects/image.py b/manim_ml/utils/mobjects/image.py index e9c942f..b6ac47a 100644 --- a/manim_ml/utils/mobjects/image.py +++ b/manim_ml/utils/mobjects/image.py @@ -2,7 +2,6 @@ from manim import * import numpy as np from PIL import Image - class GrayscaleImageMobject(Group): """Mobject for creating images in Manim from numpy arrays""" @@ -15,9 +14,14 @@ class GrayscaleImageMobject(Group): # Convert grayscale to rgb version of grayscale input_image = np.repeat(input_image, 3, axis=0) input_image = np.rollaxis(input_image, 0, start=3) - self.image_mobject = ImageMobject(input_image, image_mode="RBG") + self.image_mobject = ImageMobject( + input_image, + image_mode="RBG", + ) self.add(self.image_mobject) - self.image_mobject.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"]) + self.image_mobject.set_resampling_algorithm( + RESAMPLING_ALGORITHMS["nearest"] + ) self.image_mobject.scale_to_fit_height(height) @classmethod diff --git a/manim_ml/utils/mobjects/plotting.py b/manim_ml/utils/mobjects/plotting.py new file mode 100644 index 0000000..296091b --- /dev/null +++ b/manim_ml/utils/mobjects/plotting.py @@ -0,0 +1,28 @@ +from manim import * +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +from PIL import Image +import io + +def convert_matplotlib_figure_to_image_mobject(fig, dpi=200): + """Takes a matplotlib figure and makes an image mobject from it + + Parameters + ---------- + fig : matplotlib figure + matplotlib figure + """ + fig.tight_layout(pad=0) + plt.axis('off') + fig.canvas.draw() + # Save data into a buffer + image_buffer = io.BytesIO() + plt.savefig(image_buffer, format='png', dpi=dpi) + # Reopen in PIL and convert to numpy + image = Image.open(image_buffer) + image = np.array(image) + # Convert it to an image mobject + image_mobject = ImageMobject(image, image_mode="RGB") + + return image_mobject \ No newline at end of file diff --git a/setup.py b/setup.py index e42ac64..97a8afa 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name="manim_ml", - version="0.0.16", - description=(" Machine Learning Animations in python using Manim."), + version="0.0.17", + description=("Machine Learning Animations in python using Manim."), packages=find_packages(), ) diff --git a/tests/control_data/plotting/matplotlib_to_image_mobject.npz b/tests/control_data/plotting/matplotlib_to_image_mobject.npz new file mode 100644 index 0000000..80f6c79 Binary files /dev/null and b/tests/control_data/plotting/matplotlib_to_image_mobject.npz differ diff --git a/tests/test_flow.py b/tests/test_flow.py deleted file mode 100644 index ef76f0c..0000000 --- a/tests/test_flow.py +++ /dev/null @@ -1,6 +0,0 @@ -from manim_ml.flow.flow import * - - -class TestScene(Scene): - def construct(self): - self.add(Rectangle()) diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index 745c42d..fad0ac3 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -4,30 +4,99 @@ from manim_ml.diffusion.mcmc import ( MultidimensionalGaussianPosterior, metropolis_hastings_sampler, ) +from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +import matplotlib +plt.style.use('dark_background') # Make the specific scene config.pixel_height = 1200 config.pixel_width = 1200 -config.frame_height = 12.0 -config.frame_width = 12.0 - +config.frame_height = 10.0 +config.frame_width = 10.0 def test_metropolis_hastings_sampler(iterations=100): samples, _, candidates = metropolis_hastings_sampler(iterations=iterations) assert samples.shape == (iterations, 2) +def plot_hexbin_gaussian_on_image_mobject( + sample_func, + xlim=(-4, 4), + ylim=(-4, 4) +): + # Fixing random state for reproducibility + np.random.seed(19680801) + n = 100_000 + samples = [] + for i in range(n): + samples.append(sample_func()) + samples = np.array(samples) + + x = samples[:, 0] + y = samples[:, 1] + + fig, ax0 = plt.subplots(1, figsize=(5, 5)) + + hb = ax0.hexbin(x, y, gridsize=50, cmap='gist_heat') + + ax0.set(xlim=xlim, ylim=ylim) + + return convert_matplotlib_figure_to_image_mobject(fig) class MCMCTest(Scene): - def construct(self): - axes = MCMCAxes() - self.play(Create(axes)) - gaussian_posterior = MultidimensionalGaussianPosterior( - mu=np.array([0.0, 0.0]), var=np.array([4.0, 2.0]) + + def construct( + self, + mu=np.array([0.0, 0.0]), + var=np.array([[1.0, 1.0]]) + ): + + def gaussian_sample_func(): + vals = np.random.multivariate_normal( + mu, + np.eye(2) * var, + 1 + )[0] + + return vals + + image_mobject = plot_hexbin_gaussian_on_image_mobject( + gaussian_sample_func ) - show_gaussian_animation = axes.show_ground_truth_gaussian(gaussian_posterior) - self.play(show_gaussian_animation) - chain_sampling_animation = axes.visualize_metropolis_hastings_chain_sampling( - log_prob_fn=gaussian_posterior, sampling_kwargs={"iterations": 1000} + self.add(image_mobject) + self.play(FadeOut(image_mobject)) + + axes = MCMCAxes( + x_range=[-4, 4], + y_range=[-4, 4], + ) + self.play( + Create(axes) ) - self.play(chain_sampling_animation) + gaussian_posterior = MultidimensionalGaussianPosterior( + mu=np.array([0.0, 0.0]), + var=np.array([1.0, 1.0]) + ) + + chain_sampling_animation, lines = axes.visualize_metropolis_hastings_chain_sampling( + log_prob_fn=gaussian_posterior, + sampling_kwargs={"iterations": 500}, + ) + + self.play( + chain_sampling_animation, + run_time=3.5 + ) + self.play( + FadeOut(lines) + ) + self.wait(1) + self.play( + FadeIn(image_mobject) + ) + + diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 0000000..f328564 --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,71 @@ + +from manim import * + +import matplotlib.pyplot as plt +import seaborn as sns +import matplotlib +plt.style.use('dark_background') + +from manim_ml.utils.mobjects.plotting import convert_matplotlib_figure_to_image_mobject +from manim_ml.utils.testing.frames_comparison import frames_comparison + +__module_test__ = "plotting" + +@frames_comparison +def test_matplotlib_to_image_mobject(scene): + # libraries & dataset + df = sns.load_dataset('iris') + # Custom the color, add shade and bandwidth + matplotlib.use('Agg') + plt.figure(figsize=(10,10), dpi=100) + displot = sns.displot( + x=df.sepal_width, + y=df.sepal_length, + cmap="Reds", + kind="kde", + ) + plt.axis('off') + fig = displot.fig + image_mobject = convert_matplotlib_figure_to_image_mobject(fig) + # Display the image mobject + scene.add(image_mobject) + +class TestMatplotlibToImageMobject(Scene): + + def construct(self): + # Make a matplotlib plot + # libraries & dataset + df = sns.load_dataset('iris') + # Custom the color, add shade and bandwidth + matplotlib.use('Agg') + plt.figure(figsize=(10,10), dpi=100) + displot = sns.displot( + x=df.sepal_width, + y=df.sepal_length, + cmap="Reds", + kind="kde", + ) + plt.axis('off') + fig = displot.fig + image_mobject = convert_matplotlib_figure_to_image_mobject(fig) + # Display the image mobject + self.add(image_mobject) + + +class HexabinScene(Scene): + + def construct(self): + # Fixing random state for reproducibility + np.random.seed(19680801) + n = 100_000 + x = np.random.standard_normal(n) + y = x + 1.0 * np.random.standard_normal(n) + xlim = -4, 4 + ylim = -4, 4 + + fig, ax0 = plt.subplots(1, figsize=(5, 5)) + + hb = ax0.hexbin(x, y, gridsize=50, cmap='inferno') + ax0.set(xlim=xlim, ylim=ylim) + + self.add(convert_matplotlib_figure_to_image_mobject(fig))