Files
2022-04-21 23:18:58 -04:00

39 lines
1.3 KiB
Python

from manim import *
from manim_ml.image import GrayscaleImageMobject
from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
class ImageLayer(NeuralNetworkLayer):
"""Single Image Layer for Neural Network"""
def __init__(self, numpy_image, height=1.5, **kwargs):
super().__init__(**kwargs)
self.numpy_image = numpy_image
if len(np.shape(self.numpy_image)) == 2:
# Assumed Grayscale
self.image_mobject = GrayscaleImageMobject(self.numpy_image, height=height)
elif len(np.shape(self.numpy_image)) == 3:
# Assumed RGB
self.image_mobject = ImageMobject(self.numpy_image).scale_to_fit_height(height)
self.add(self.image_mobject)
@override_animation(Create)
def _create_override(self, **kwargs):
debug_mode = False
if debug_mode:
return FadeIn(SurroundingRectangle(self.image_mobject))
return FadeIn(self.image_mobject)
def make_forward_pass_animation(self):
return FadeIn(self.image_mobject)
# def move_to(self, location):
# """Override of move to"""
# self.image_mobject.move_to(location)
def get_right(self):
"""Override get right"""
return self.image_mobject.get_right()
@property
def width(self):
return self.image_mobject.width