mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-20 03:57:40 +08:00
194 lines
7.1 KiB
Python
194 lines
7.1 KiB
Python
import random
|
|
from pathlib import Path
|
|
|
|
from PIL import Image
|
|
from manim import *
|
|
from manim_ml.neural_network.layers.embedding import EmbeddingLayer
|
|
from manim_ml.neural_network.layers.feed_forward import FeedForwardLayer
|
|
from manim_ml.neural_network.layers.image import ImageLayer
|
|
from manim_ml.neural_network.layers.vector import VectorLayer
|
|
|
|
from manim_ml.neural_network.neural_network import NeuralNetwork
|
|
|
|
# Asset paths below are resolved against this directory, which sits two
# levels above the folder that contains this file.
ROOT_DIR = Path(__file__).parents[2]

# Square output: 1080 x 1080 pixels, with frame dimensions of 8.3 x 8.3.
config.pixel_height = config.pixel_width = 1080
config.frame_height = config.frame_width = 8.3
|
|
|
|
class GAN(Mobject):
    """Generative Adversarial Network diagram.

    On construction this builds a generator network, a discriminator
    network and a ground-truth image layer, positions them relative to
    one another, and creates the title labels that are animated by the
    ``Create`` override.
    """

    def __init__(self):
        super().__init__()
        self.make_entities()
        self.place_entities()
        self.titles = self.make_titles()

    @staticmethod
    def _load_image_array(relative_path):
        """Return the image at ``ROOT_DIR / relative_path`` as a numpy array.

        The image is opened in a ``with`` block and converted while the
        file is still open, so the file handle is released afterwards
        instead of being left open.
        """
        with Image.open(ROOT_DIR / relative_path) as image:
            return np.asarray(image)

    def make_entities(self, image_height=1.2):
        """Makes all of the network entities.

        Creates ``self.fake_image_layer``, ``self.generator``,
        ``self.discriminator`` and ``self.ground_truth_layer`` and adds
        the generator, discriminator and ground-truth layer to this
        mobject.

        Parameters
        ----------
        image_height
            Height passed to both image layers.
        """
        # Make the fake image layer
        fake_image = self._load_image_array('assets/gan/fake_image.png')
        self.fake_image_layer = ImageLayer(
            fake_image,
            height=image_height,
            show_image_on_create=False,
        )
        # Make the Generator Network; the fake image layer is its last layer
        self.generator = NeuralNetwork([
            EmbeddingLayer(covariance=np.array([[3.0, 0], [0, 3.0]])).scale(1.3),
            FeedForwardLayer(3),
            FeedForwardLayer(5),
            self.fake_image_layer,
        ], layer_spacing=0.1)

        self.add(self.generator)
        # Make the Discriminator
        self.discriminator = NeuralNetwork([
            FeedForwardLayer(5),
            FeedForwardLayer(1),
            VectorLayer(1, value_func=lambda: random.uniform(0, 1)),
        ], layer_spacing=0.1)
        self.add(self.discriminator)
        # Make Ground Truth Dataset
        real_image = self._load_image_array('assets/gan/real_image.jpg')
        self.ground_truth_layer = ImageLayer(real_image, height=image_height)
        self.add(self.ground_truth_layer)

        # NOTE(review): scaling by 1 looks like a no-op; kept so the
        # construction sequence is unchanged -- confirm before removing.
        self.scale(1)

    def place_entities(self):
        """Positions entities in correct places."""
        # Place the ground truth image layer below the fake image layer
        self.ground_truth_layer.next_to(self.fake_image_layer, DOWN, 0.8)
        # Group the images
        image_group = Group(self.ground_truth_layer, self.fake_image_layer)
        # Move the discriminator to the right of the generator and align it
        # vertically with the image group
        self.discriminator.next_to(self.generator, RIGHT, 0.2)
        self.discriminator.match_y(image_group)

    def make_titles(self):
        """Makes titles for the different entities.

        The two image titles are added to this mobject as well as to the
        returned group; the overhead title and the probability title are
        only placed in the returned group.

        Returns
        -------
        VGroup
            Group holding all four title texts.
        """
        titles = VGroup()

        self.ground_truth_layer_title = Text("Real Image").scale(0.3)
        self.ground_truth_layer_title.next_to(self.ground_truth_layer, UP, 0.1)
        self.add(self.ground_truth_layer_title)
        titles.add(self.ground_truth_layer_title)
        self.fake_image_layer_title = Text("Fake Image").scale(0.3)
        self.fake_image_layer_title.next_to(self.fake_image_layer, UP, 0.1)
        self.add(self.fake_image_layer_title)
        titles.add(self.fake_image_layer_title)
        # Overhead title
        overhead_title = Text("Generative Adversarial Network").scale(0.75)
        overhead_title.shift(np.array([0, 3.5, 0]))
        titles.add(overhead_title)
        # Probability title, positioned relative to the second-to-last
        # entry of the discriminator's input layers
        self.probability_title = Text("Probability").scale(0.5)
        self.probability_title.move_to(self.discriminator.input_layers[-2])
        self.probability_title.shift(UP)
        self.probability_title.shift(RIGHT*1.05)
        titles.add(self.probability_title)

        return titles

    def make_highlight_generator_rectangle(self):
        """Makes a rectangle for highlighting the generator.

        Returns
        -------
        VGroup
            A rectangle surrounding the generator and the fake image
            title, together with a "Generator" label above it.
        """
        group = VGroup()

        generator_surrounding_group = Group(
            self.generator,
            self.fake_image_layer_title
        )

        generator_surrounding_rectangle = SurroundingRectangle(
            generator_surrounding_group,
            buff=0.1,
            stroke_width=4.0,
            color="#0FFF50"
        )
        group.add(generator_surrounding_rectangle)
        title = Text("Generator").scale(0.5)
        title.next_to(generator_surrounding_rectangle, UP, 0.2)
        group.add(title)

        return group

    def make_highlight_discriminator_rectangle(self):
        """Makes a rectangle for highlighting the discriminator.

        Returns
        -------
        VGroup
            A rectangle surrounding the discriminator, both image layers,
            the fake image title and the probability title, together with
            a "Discriminator" label above it.
        """
        discriminator_group = Group(
            self.discriminator,
            self.fake_image_layer,
            self.ground_truth_layer,
            self.fake_image_layer_title,
            self.probability_title
        )

        group = VGroup()

        discriminator_surrounding_rectangle = SurroundingRectangle(
            discriminator_group,
            buff=0.05,
            stroke_width=4.0,
            color="#0FFF50"
        )
        group.add(discriminator_surrounding_rectangle)
        title = Text("Discriminator").scale(0.5)
        title.next_to(discriminator_surrounding_rectangle, UP, 0.2)
        group.add(title)

        return group

    def make_generator_forward_pass(self):
        """Makes forward pass of the generator.

        Returns the generator's forward pass animation, requested with
        ``dist_theme="ellipse"``.
        """
        forward_pass = self.generator.make_forward_pass_animation(dist_theme="ellipse")

        return forward_pass

    def make_discriminator_forward_pass(self):
        """Makes forward pass of the discriminator.

        Returns the discriminator's forward pass animation.
        """
        disc_forward = self.discriminator.make_forward_pass_animation()

        return disc_forward

    @override_animation(Create)
    def _create_override(self):
        """Overrides create.

        Returns an ``AnimationGroup`` that creates the generator, the
        discriminator, the ground-truth layer and the titles.
        """
        animation_group = AnimationGroup(
            Create(self.generator),
            Create(self.discriminator),
            Create(self.ground_truth_layer),
            Create(self.titles)
        )
        return animation_group
|
|
|
|
class GANScene(Scene):
    """Scene that draws the GAN and walks through both forward passes."""

    def construct(self):
        # Build the diagram, enlarge it and nudge it slightly down and left.
        gan = (
            GAN()
            .scale(1.70)
            .move_to(ORIGIN)
            .shift(DOWN*0.35)
            .shift(LEFT*0.1)
        )
        self.play(Create(gan), run_time=3)

        # Generator first, then discriminator. Each stage draws a highlight,
        # plays the network's forward pass, and removes the highlight again.
        stages = (
            (gan.make_highlight_generator_rectangle,
             gan.make_generator_forward_pass),
            (gan.make_highlight_discriminator_rectangle,
             gan.make_discriminator_forward_pass),
        )
        for build_highlight, build_forward_pass in stages:
            highlight = build_highlight()
            self.play(Create(highlight), run_time=1)
            forward_pass = build_forward_pass()
            self.play(forward_pass, run_time=5)
            self.play(Uncreate(highlight), run_time=1)