# Mirror of https://github.com/helblazer811/ManimML.git
# Synced 2025-05-18 03:05:23 +08:00 - 102 lines - 3.5 KiB - Python
import numpy as np
|
|
|
|
from manim import *
|
|
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
|
|
from manim_ml.neural_network.layers.image import ImageLayer
|
|
from manim_ml.neural_network.layers.parent_layers import (
|
|
ThreeDLayer,
|
|
VGroupNeuralNetworkLayer,
|
|
)
|
|
from manim_ml.gridded_rectangle import GriddedRectangle
|
|
|
|
|
|
class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
    """Connective layer that animates an image onto a 3D convolutional layer.

    The forward pass dispatches on ``input_layer.num_channels``. Only the
    single-channel (grayscale) path is implemented; the three-channel path
    is a stub that raises ``NotImplementedError``.
    """

    # Layer types this connective layer sits between.
    input_class = ImageLayer
    output_class = Convolutional3DLayer

    def __init__(
        self, input_layer: ImageLayer, output_layer: Convolutional3DLayer, **kwargs
    ):
        """Store the two layers being connected.

        Args:
            input_layer: Image layer whose ``image_mobject`` gets animated.
            output_layer: Convolutional layer whose first feature map is
                the animation target.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(input_layer, output_layer, **kwargs)
        self.input_layer = input_layer
        self.output_layer = output_layer

    def make_forward_pass_animation(self, run_time=5, layer_args=None, **kwargs):
        """Build the animation mapping the image to the convolutional layer.

        Args:
            run_time: Currently unused by this method; the individual steps
                carry their own run times.
            layer_args: Currently unused by this method. Defaults to ``None``
                rather than a shared mutable ``{}``.
            **kwargs: Currently unused by this method.

        Returns:
            The animation produced by the channel-specific helper.

        Raises:
            NotImplementedError: For 3-channel images (not implemented yet).
            ValueError: If the image has neither 1 nor 3 channels.
        """
        num_image_channels = self.input_layer.num_channels
        if num_image_channels == 3:
            return self.rbg_image_animation()
        if num_image_channels == 1:
            return self.grayscale_image_animation()
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(
            f"Unrecognized number of image channels: {num_image_channels}"
        )

    def rbg_image_animation(self):
        """Handle the animation for a 3-channel image (not implemented).

        Raises:
            NotImplementedError: Always.
        """
        # TODO get each color channel and turn it into an image
        # TODO create image mobjects for each channel and transform
        # it to the feature maps of the output_layer
        raise NotImplementedError()

    def grayscale_image_animation(self):
        """Handle the animation for a 1-channel image.

        Runs four steps in succession on the input layer's image mobject:
        rotate it, scale it, fade it, and move it onto the first feature
        map of the output layer.

        Returns:
            A ``Succession`` of the four ``ApplyMethod`` animations.
        """
        image_mobject = self.input_layer.image_mobject
        target_feature_map = self.output_layer.feature_maps[0]

        # Rotate the image by the angle/axis shared by all ThreeDLayers.
        # The image's current center is passed as the third positional
        # argument of ``rotate`` (``about_point`` in Manim's Mobject.rotate).
        rotation = ApplyMethod(
            image_mobject.rotate,
            ThreeDLayer.rotation_angle,
            ThreeDLayer.rotation_axis,
            image_mobject.get_center(),
            run_time=0.5,
        )
        # Fade the image so it reads as an overlay on the feature map.
        set_opacity = ApplyMethod(image_mobject.set_opacity, 0.2, run_time=0.5)
        # Scale so that the larger of the image's width/height matches the
        # feature map's rectangle width.
        max_width_height = max(image_mobject.width, image_mobject.height)
        scale_factor = target_feature_map.rectangle_width / max_width_height
        scale_image = ApplyMethod(image_mobject.scale, scale_factor, run_time=0.5)
        # Move the image onto the feature map (no explicit run_time, so
        # ApplyMethod's default applies).
        move_image = ApplyMethod(image_mobject.move_to, target_feature_map)

        # Order matters: rotate, then scale, then fade, then move.
        return Succession(
            rotation,
            scale_image,
            set_opacity,
            move_image,
        )

    def scale(self, scale_factor, **kwargs):
        """Scale this layer by delegating to the parent implementation.

        Returns:
            Whatever the parent ``scale`` returns, so that method chaining
            (``layer.scale(...).shift(...)``) is not broken by this override.
        """
        return super().scale(scale_factor, **kwargs)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Override for ``Create``: return an empty ``AnimationGroup``."""
        return AnimationGroup()