""" Here is a animated explanatory figure for the "Oracle Guided Image Synthesis with Relative Queries" paper. """ from pathlib import Path from manim import * from manim_ml.neural_network.layers import triplet from manim_ml.neural_network.layers.image import ImageLayer from manim_ml.neural_network.layers.paired_query import PairedQueryLayer from manim_ml.neural_network.layers.triplet import TripletLayer from manim_ml.neural_network.neural_network import NeuralNetwork from manim_ml.neural_network.layers import FeedForwardLayer, EmbeddingLayer from manim_ml.neural_network.layers.util import get_connective_layer import os from manim_ml.probability import GaussianDistribution # Make the specific scene config.pixel_height = 1200 config.pixel_width = 1900 config.frame_height = 6.0 config.frame_width = 6.0 ROOT_DIR = Path(__file__).parents[3] class Localizer: """ Holds the localizer object, which contains the queries, images, etc. needed to represent a localization run. """ def __init__(self, axes): # Set dummy values for these self.index = -1 self.axes = axes self.num_queries = 3 self.assets_path = ROOT_DIR / "assets/oracle_guidance" self.ground_truth_image_path = self.assets_path / "ground_truth.jpg" self.ground_truth_location = np.array([2, 3]) # Prior distribution print("initial gaussian") self.prior_distribution = GaussianDistribution( self.axes, mean=np.array([0.0, 0.0]), cov=np.array([[3, 0], [0, 3]]), dist_theme="ellipse", color=GREEN, ) # Define the query images and embedded locations # Contains image paths [(positive_path, negative_path), ...] self.query_image_paths = [ ( os.path.join(self.assets_path, "positive_1.jpg"), os.path.join(self.assets_path, "negative_1.jpg"), ), ( os.path.join(self.assets_path, "positive_2.jpg"), os.path.join(self.assets_path, "negative_2.jpg"), ), ( os.path.join(self.assets_path, "positive_3.jpg"), os.path.join(self.assets_path, "negative_3.jpg"), ), ] # Contains 2D locations for each image [([2, 3], [2, 4]), ...] self.query_locations = [ (np.array([-1, -1]), np.array([1, 1])), (np.array([1, -1]), np.array([-1, 1])), (np.array([0.3, -0.6]), np.array([-0.5, 0.7])), ] # Make the covariances for each query self.query_covariances = [ (np.array([[0.3, 0], [0.0, 0.2]]), np.array([[0.2, 0], [0.0, 0.2]])), (np.array([[0.2, 0], [0.0, 0.2]]), np.array([[0.2, 0], [0.0, 0.2]])), (np.array([[0.2, 0], [0.0, 0.2]]), np.array([[0.2, 0], [0.0, 0.2]])), ] # Posterior distributions over time GaussianDistribution objects self.posterior_distributions = [ GaussianDistribution( self.axes, dist_theme="ellipse", color=GREEN, mean=np.array([-0.3, -0.3]), cov=np.array([[5, -4], [-4, 6]]), ).scale(0.6), GaussianDistribution( self.axes, dist_theme="ellipse", color=GREEN, mean=np.array([0.25, -0.25]), cov=np.array([[3, -2], [-2, 4]]), ).scale(0.35), GaussianDistribution( self.axes, dist_theme="ellipse", color=GREEN, mean=np.array([0.4, -0.35]), cov=np.array([[1, 0], [0, 1]]), ).scale(0.3), ] # Some assumptions assert len(self.query_locations) == len(self.query_image_paths) assert len(self.query_locations) == len(self.posterior_distributions) def __iter__(self): return self def __next__(self): """Steps through each localization time instance""" if self.index < len(self.query_image_paths): self.index += 1 else: raise StopIteration # Return query_paths, query_locations, posterior out_tuple = ( self.query_image_paths[self.index], self.query_locations[self.index], self.posterior_distributions[self.index], self.query_covariances[self.index], ) return out_tuple class OracleGuidanceVisualization(Scene): def __init__(self): super().__init__() self.neural_network, self.embedding_layer = self.make_vae() self.localizer = iter(Localizer(self.embedding_layer.axes)) self.subtitle = None self.title = None # Set image paths # VAE embedding animation image paths self.assets_path = ROOT_DIR / "assets/oracle_guidance" self.input_embed_image_path = os.path.join(self.assets_path, "input_image.jpg") self.output_embed_image_path = os.path.join( self.assets_path, "output_image.jpg" ) def make_vae(self): """Makes a simple VAE architecture""" embedding_layer = EmbeddingLayer(dist_theme="ellipse") self.encoder = NeuralNetwork( [ FeedForwardLayer(5), FeedForwardLayer(3), embedding_layer, ] ) self.decoder = NeuralNetwork( [ FeedForwardLayer(3), FeedForwardLayer(5), ] ) neural_network = NeuralNetwork([self.encoder, self.decoder]) neural_network.shift(DOWN * 0.4) return neural_network, embedding_layer @override_animation(Create) def _create_animation(self): animation_group = AnimationGroup(Create(self.neural_network)) return animation_group def insert_at_start(self, layer, create=True): """Inserts a layer at the beggining of the network""" # Note: Does not move the rest of the network current_first_layer = self.encoder.all_layers[0] # Get connective layer connective_layer = get_connective_layer(layer, current_first_layer) # Insert both layers self.encoder.all_layers.insert(0, layer) self.encoder.all_layers.insert(1, connective_layer) # Move layers to the correct location # TODO: Fix this cause its hacky layer.shift(DOWN * 0.4) layer.shift(LEFT * 2.35) # Make insert animation if not create: animation_group = AnimationGroup(Create(connective_layer)) else: animation_group = AnimationGroup(Create(layer), Create(connective_layer)) self.play(animation_group) def remove_start_layer(self): """Removes the first layer of the network""" first_layer = self.encoder.all_layers.remove_at_index(0) first_connective = self.encoder.all_layers.remove_at_index(0) # Make remove animations animation_group = AnimationGroup( FadeOut(first_layer), FadeOut(first_connective) ) self.play(animation_group) def insert_at_end(self, layer): """Inserts the given layer at the end of the network""" current_last_layer = self.decoder.all_layers[-1] # Get connective layer connective_layer = get_connective_layer(current_last_layer, layer) # Insert both layers self.decoder.all_layers.add(connective_layer) self.decoder.all_layers.add(layer) # Move layers to the correct location # TODO: Fix this cause its hacky layer.shift(DOWN * 0.4) layer.shift(RIGHT * 2.35) # Make insert animation animation_group = AnimationGroup(Create(layer), Create(connective_layer)) self.play(animation_group) def remove_end_layer(self): """Removes the lsat layer of the network""" first_layer = self.decoder.all_layers.remove_at_index(-1) first_connective = self.decoder.all_layers.remove_at_index(-1) # Make remove animations animation_group = AnimationGroup( FadeOut(first_layer), FadeOut(first_connective) ) self.play(animation_group) def change_title(self, text, title_location=np.array([0, 1.25, 0]), font_size=24): """Changes title to the given text""" if self.title is None: self.title = Text(text, font_size=font_size) self.title.move_to(title_location) self.add(self.title) self.play(Write(self.title), run_time=1) self.wait(1) return self.play(Unwrite(self.title)) new_title = Text(text, font_size=font_size) new_title.move_to(self.title) self.title = new_title self.wait(0.1) self.play(Write(self.title)) def change_subtitle(self, text, title_location=np.array([0, 0.9, 0]), font_size=20): """Changes subtitle to the next algorithm step""" if self.subtitle is None: self.subtitle = Text(text, font_size=font_size) self.subtitle.move_to(title_location) self.play(Write(self.subtitle)) return self.play(Unwrite(self.subtitle)) new_title = Text(text, font_size=font_size) new_title.move_to(title_location) self.subtitle = new_title self.wait(0.1) self.play(Write(self.subtitle)) def make_embed_input_image_animation(self, input_image_path, output_image_path): """Makes embed input image animation""" # insert the input image at the begginging input_image_layer = ImageLayer.from_path(input_image_path) input_image_layer.scale(0.6) current_first_layer = self.encoder.all_layers[0] # Get connective layer connective_layer = get_connective_layer(input_image_layer, current_first_layer) # Insert both layers self.encoder.all_layers.insert(0, input_image_layer) self.encoder.all_layers.insert(1, connective_layer) # Move layers to the correct location # TODO: Fix this cause its hacky input_image_layer.shift(DOWN * 0.4) input_image_layer.shift(LEFT * 2.35) # Play full forward pass forward_pass = self.neural_network.make_forward_pass_animation( layer_args={ self.encoder: { self.embedding_layer: { "dist_args": { "cov": np.array([[1.5, 0], [0, 1.5]]), "mean": np.array([0.5, 0.5]), "dist_theme": "ellipse", "color": ORANGE, } } } } ) self.play(forward_pass) # insert the output image at the end output_image_layer = ImageLayer.from_path(output_image_path) output_image_layer.scale(0.6) self.insert_at_end(output_image_layer) # Remove the input and output layers self.remove_start_layer() self.remove_end_layer() # Remove the latent distribution self.play(FadeOut(self.embedding_layer.latent_distribution)) def make_localization_time_step(self, old_posterior): """ Performs one query update for the localization procedure Procedure: a. Embed query input images b. Oracle is asked a query c. Query is embedded d. Show posterior update e. Show current recomendation """ # Helper functions def embed_query_to_latent_space(query_locations, query_covariance): """Makes animation for a paired query""" # Assumes first layer of neural network is a PairedQueryLayer # Make the embedding animation # Wait self.play(Wait(1)) # Embed the query to latent space self.change_subtitle("3. Embed the Query in Latent Space") # Make forward pass animation self.embedding_layer.paired_query_mode = True # Make embedding embed query animation embed_query_animation = self.encoder.make_forward_pass_animation( run_time=5, layer_args={ self.embedding_layer: { "positive_dist_args": { "cov": query_covariance[0], "mean": query_locations[0], "dist_theme": "ellipse", "color": BLUE, }, "negative_dist_args": { "cov": query_covariance[1], "mean": query_locations[1], "dist_theme": "ellipse", "color": RED, }, } }, ) self.play(embed_query_animation) # Access localizer information query_paths, query_locations, posterior_distribution, query_covariances = next( self.localizer ) positive_path, negative_path = query_paths # Make subtitle for present user with query self.change_subtitle("2. Present User with Query") # Insert the layer into the encoder query_layer = PairedQueryLayer.from_paths( positive_path, negative_path, grayscale=False ) query_layer.scale(0.5) self.insert_at_start(query_layer) # Embed query to latent space query_to_latent_space_animation = embed_query_to_latent_space( query_locations, query_covariances ) # Wait self.play(Wait(1)) # Update the posterior self.change_subtitle("4. Update the Posterior") # Remove the old posterior self.play(ReplacementTransform(old_posterior, posterior_distribution)) """ self.play( self.embedding_layer.remove_gaussian_distribution(self.localizer.posterior_distribution) ) """ # self.embedding_layer.add_gaussian_distribution(posterior_distribution) # self.localizer.posterior_distribution = posterior_distribution # Remove query layer self.remove_start_layer() # Remove query ellipses fade_outs = [] for dist in self.embedding_layer.gaussian_distributions: self.embedding_layer.gaussian_distributions.remove(dist) fade_outs.append(FadeOut(dist)) if not len(fade_outs) == 0: fade_outs = AnimationGroup(*fade_outs) self.play(fade_outs) return posterior_distribution def make_generate_estimate_animation(self, estimate_image_path): """Makes the generate estimate animation""" # Change embedding layer mode self.embedding_layer.paired_query_mode = False # Sample from posterior distribution # self.embedding_layer.latent_distribution = self.localizer.posterior_distribution emb_to_ff_ind = self.neural_network.all_layers.index_of(self.encoder) embedding_to_ff = self.neural_network.all_layers[emb_to_ff_ind + 1] self.play(embedding_to_ff.make_forward_pass_animation()) # Pass through decoder self.play(self.decoder.make_forward_pass_animation(), run_time=1) # Create Image layer after the decoder output_image_layer = ImageLayer.from_path(estimate_image_path) output_image_layer.scale(0.5) self.insert_at_end(output_image_layer) # Wait self.wait(1) # Remove the image at the end print(self.neural_network) self.remove_end_layer() def make_triplet_forward_animation(self): """Make triplet forward animation""" # Make triplet layer anchor_path = os.path.join(self.assets_path, "anchor.jpg") positive_path = os.path.join(self.assets_path, "positive.jpg") negative_path = os.path.join(self.assets_path, "negative.jpg") triplet_layer = TripletLayer.from_paths( anchor_path, positive_path, negative_path, grayscale=False, font_size=100, buff=1.05, ) triplet_layer.scale(0.10) self.insert_at_start(triplet_layer) # Make latent triplet animation self.play( self.encoder.make_forward_pass_animation( layer_args={ self.embedding_layer: { "triplet_args": { "anchor_dist": { "cov": np.array([[0.3, 0], [0, 0.3]]), "mean": np.array([0.7, 1.4]), "dist_theme": "ellipse", "color": BLUE, }, "positive_dist": { "cov": np.array([[0.2, 0], [0, 0.2]]), "mean": np.array([0.8, -0.4]), "dist_theme": "ellipse", "color": GREEN, }, "negative_dist": { "cov": np.array([[0.4, 0], [0, 0.25]]), "mean": np.array([-1, -1.2]), "dist_theme": "ellipse", "color": RED, }, } } }, run_time=3, ) ) def construct(self): """ Makes the whole visualization. 1. Create the Architecture a. Create the traditional VAE architecture with images 2. The Localization Procedure 3. The Training Procedure """ # 1. Create the Architecture self.neural_network.scale(1.2) create_vae = Create(self.neural_network) self.play(create_vae, run_time=3) # Make changing title self.change_title("Oracle Guided Image Synthesis\n with Relative Queries") # 2. The Localization Procedure self.change_title("The Localization Procedure") # Make algorithm subtitle self.change_subtitle("Algorithm Steps") # Wait self.play(Wait(1)) # Make prior distribution subtitle self.change_subtitle("1. Calculate Prior Distribution") # Draw the prior distribution self.play(Create(self.localizer.prior_distribution)) old_posterior = self.localizer.prior_distribution # For N queries update the posterior for query_index in range(self.localizer.num_queries): # Make localization iteration old_posterior = self.make_localization_time_step(old_posterior) self.play(Wait(1)) if not query_index == self.localizer.num_queries - 1: # Repeat self.change_subtitle("5. Repeat") # Wait a second self.play(Wait(1)) # Generate final estimate self.change_subtitle("5. Generate Estimate Image") # Generate an estimate image estimate_image_path = os.path.join(self.assets_path, "estimate_image.jpg") self.make_generate_estimate_animation(estimate_image_path) self.wait(1) # Remove old posterior self.play(FadeOut(old_posterior)) # 3. The Training Procedure self.change_title("The Training Procedure") # Make training animation # Do an Image forward pass self.change_subtitle("1. Unsupervised Image Reconstruction") self.make_embed_input_image_animation( self.input_embed_image_path, self.output_embed_image_path ) self.wait(1) # Do triplet forward pass self.change_subtitle("2. Triplet Loss in Latent Space") self.make_triplet_forward_animation() self.wait(1)