Files
2022-08-28 19:11:56 -07:00

194 lines
7.1 KiB
Python

import random
from pathlib import Path
from PIL import Image
from manim import *
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.layers.vector import VectorLayer
from manim_ml.neural_network.neural_network import NeuralNetwork
# Base directory for asset lookups: the third ancestor of this file's path.
# Paths such as 'assets/gan/...' are resolved against it.
ROOT_DIR = Path(__file__).parents[2]

# Square output resolution, in pixels.
config.pixel_height = config.pixel_width = 1080
# Square frame with matching height and width.
config.frame_height = config.frame_width = 8.3
class GAN(Mobject):
    """Generative Adversarial Network diagram.

    Builds a generator network (ending in a fake-image layer), a
    discriminator network, a ground-truth image layer and the text titles
    that label them, and positions them relative to one another.
    """

    def __init__(self):
        super().__init__()
        self.make_entities()
        self.place_entities()
        self.titles = self.make_titles()

    @staticmethod
    def _load_image_array(relative_path):
        """Return the image at ``ROOT_DIR / relative_path`` as a numpy array."""
        # The array is built inside the ``with`` block so the pixel data is
        # read before the underlying file handle is closed.
        with Image.open(ROOT_DIR / relative_path) as image:
            return np.asarray(image)

    @staticmethod
    def _make_highlight_group(members, label, buff):
        """Return a VGroup holding a rectangle around ``members`` and a title.

        Args:
            members: mobjects the rectangle should surround.
            label: text shown above the rectangle.
            buff: padding passed to ``SurroundingRectangle``.
        """
        surrounding_rectangle = SurroundingRectangle(
            Group(*members),
            buff=buff,
            stroke_width=4.0,
            color="#0FFF50",
        )
        title = Text(label).scale(0.5)
        title.next_to(surrounding_rectangle, UP, 0.2)
        return VGroup(surrounding_rectangle, title)

    def make_entities(self, image_height=1.2):
        """Create the generator, the discriminator and the image layers.

        Args:
            image_height: height given to both the fake and the real
                image layer.
        """
        # Fake image layer; it doubles as the generator's final layer.
        self.fake_image_layer = ImageLayer(
            self._load_image_array('assets/gan/fake_image.png'),
            height=image_height,
            show_image_on_create=False,
        )
        # Generator network
        self.generator = NeuralNetwork(
            [
                EmbeddingLayer(covariance=np.array([[3.0, 0], [0, 3.0]])).scale(1.3),
                FeedForwardLayer(3),
                FeedForwardLayer(5),
                self.fake_image_layer,
            ],
            layer_spacing=0.1,
        )
        self.add(self.generator)
        # Discriminator network; its last layer displays a random value
        # drawn uniformly from [0, 1].
        self.discriminator = NeuralNetwork(
            [
                FeedForwardLayer(5),
                FeedForwardLayer(1),
                VectorLayer(1, value_func=lambda: random.uniform(0, 1)),
            ],
            layer_spacing=0.1,
        )
        self.add(self.discriminator)
        # Ground-truth ("real") image layer
        self.ground_truth_layer = ImageLayer(
            self._load_image_array('assets/gan/real_image.jpg'),
            height=image_height,
        )
        self.add(self.ground_truth_layer)

    def place_entities(self):
        """Position the entities relative to the generator."""
        # Real image goes below the fake image.
        self.ground_truth_layer.next_to(self.fake_image_layer, DOWN, 0.8)
        image_group = Group(self.ground_truth_layer, self.fake_image_layer)
        # Discriminator goes to the right of the generator, vertically
        # aligned with the center of the two images.
        self.discriminator.next_to(self.generator, RIGHT, 0.2)
        self.discriminator.match_y(image_group)

    def make_titles(self):
        """Create the text titles and return them as a VGroup.

        The two image titles are also added to this mobject.
        """
        titles = VGroup()

        self.ground_truth_layer_title = Text("Real Image").scale(0.3)
        self.ground_truth_layer_title.next_to(self.ground_truth_layer, UP, 0.1)
        self.add(self.ground_truth_layer_title)
        titles.add(self.ground_truth_layer_title)

        self.fake_image_layer_title = Text("Fake Image").scale(0.3)
        self.fake_image_layer_title.next_to(self.fake_image_layer, UP, 0.1)
        self.add(self.fake_image_layer_title)
        titles.add(self.fake_image_layer_title)

        # NOTE(review): unlike the image titles, the overhead and
        # probability titles below are only put in ``titles`` and are not
        # added to this mobject via ``self.add`` -- presumably intentional
        # so they keep their own placement; confirm.

        # Overhead title
        overhead_title = Text("Generative Adversarial Network").scale(0.75)
        overhead_title.shift(np.array([0, 3.5, 0]))
        titles.add(overhead_title)

        # Probability title, placed up and to the right of the
        # discriminator's second-to-last input layer.
        self.probability_title = Text("Probability").scale(0.5)
        self.probability_title.move_to(self.discriminator.input_layers[-2])
        self.probability_title.shift(UP)
        self.probability_title.shift(RIGHT*1.05)
        titles.add(self.probability_title)

        return titles

    def make_highlight_generator_rectangle(self):
        """Return a VGroup (rectangle + title) highlighting the generator."""
        return self._make_highlight_group(
            (self.generator, self.fake_image_layer_title),
            "Generator",
            buff=0.1,
        )

    def make_highlight_discriminator_rectangle(self):
        """Return a VGroup (rectangle + title) highlighting the discriminator."""
        return self._make_highlight_group(
            (
                self.discriminator,
                self.fake_image_layer,
                self.ground_truth_layer,
                self.fake_image_layer_title,
                self.probability_title,
            ),
            "Discriminator",
            buff=0.05,
        )

    def make_generator_forward_pass(self):
        """Return the generator's forward pass animation."""
        return self.generator.make_forward_pass_animation(dist_theme="ellipse")

    def make_discriminator_forward_pass(self):
        """Return the discriminator's forward pass animation."""
        return self.discriminator.make_forward_pass_animation()

    @override_animation(Create)
    def _create_override(self):
        """Return the animation used when ``Create`` is applied to a GAN."""
        return AnimationGroup(
            Create(self.generator),
            Create(self.discriminator),
            Create(self.ground_truth_layer),
            Create(self.titles),
        )
class GANScene(Scene):
    """Scene that draws the GAN and animates each network in turn."""

    def construct(self):
        """Create the GAN, then highlight and run the generator and discriminator."""
        # Enlarge the diagram, center it, then nudge it down and to the left.
        gan = GAN().scale(1.70)
        gan.move_to(ORIGIN).shift(DOWN*0.35).shift(LEFT*0.1)
        self.play(Create(gan), run_time=3)

        # Each stage pairs a highlight builder with a forward-pass builder;
        # the generator stage runs first, then the discriminator stage.
        stages = (
            (gan.make_highlight_generator_rectangle, gan.make_generator_forward_pass),
            (gan.make_highlight_discriminator_rectangle, gan.make_discriminator_forward_pass),
        )
        for build_highlight, build_forward_pass in stages:
            # Draw the highlight
            highlight = build_highlight()
            self.play(Create(highlight), run_time=1)
            # Run the forward pass
            forward_pass = build_forward_pass()
            self.play(forward_pass, run_time=5)
            # Remove the highlight
            self.play(Uncreate(highlight), run_time=1)