mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-17 18:55:54 +08:00
81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
from manim import *
|
|
import numpy as np
|
|
from manim_ml.utils.mobjects.image import GrayscaleImageMobject
|
|
from manim_ml.neural_network.layers.parent_layers import NeuralNetworkLayer
|
|
|
|
from PIL import Image
|
|
|
|
|
|
class ImageLayer(NeuralNetworkLayer):
|
|
"""Single Image Layer for Neural Network"""
|
|
|
|
def __init__(self, numpy_image, height=1.5, show_image_on_create=True, **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.image_height = height
|
|
self.numpy_image = numpy_image
|
|
self.show_image_on_create = show_image_on_create
|
|
|
|
def construct_layer(self, input_layer, output_layer):
|
|
"""Construct layer method
|
|
|
|
Parameters
|
|
----------
|
|
input_layer :
|
|
Input layer
|
|
output_layer :
|
|
Output layer
|
|
"""
|
|
if len(np.shape(self.numpy_image)) == 2:
|
|
# Assumed Grayscale
|
|
self.num_channels = 1
|
|
self.image_mobject = GrayscaleImageMobject(
|
|
self.numpy_image, height=self.image_height
|
|
)
|
|
elif len(np.shape(self.numpy_image)) == 3:
|
|
# Assumed RGB
|
|
self.num_channels = 3
|
|
self.image_mobject = ImageMobject(self.numpy_image).scale_to_fit_height(
|
|
self.image_height
|
|
)
|
|
self.add(self.image_mobject)
|
|
|
|
@classmethod
|
|
def from_path(cls, image_path, grayscale=True, **kwargs):
|
|
"""Creates a query using the paths"""
|
|
# Load images from path
|
|
image = Image.open(image_path)
|
|
numpy_image = np.asarray(image)
|
|
# Make the layer
|
|
image_layer = cls(numpy_image, **kwargs)
|
|
|
|
return image_layer
|
|
|
|
@override_animation(Create)
|
|
def _create_override(self, **kwargs):
|
|
debug_mode = False
|
|
if debug_mode:
|
|
return FadeIn(SurroundingRectangle(self.image_mobject))
|
|
if self.show_image_on_create:
|
|
return FadeIn(self.image_mobject)
|
|
else:
|
|
return AnimationGroup()
|
|
|
|
def make_forward_pass_animation(self, layer_args={}, **kwargs):
|
|
return AnimationGroup()
|
|
|
|
def get_right(self):
|
|
"""Override get right"""
|
|
return self.image_mobject.get_right()
|
|
|
|
def scale(self, scale_factor, **kwargs):
|
|
"""Scales the image mobject"""
|
|
self.image_mobject.scale(scale_factor)
|
|
|
|
@property
|
|
def width(self):
|
|
return self.image_mobject.width
|
|
|
|
@property
|
|
def height(self):
|
|
return self.image_mobject.height
|