import numpy as np

from manim import *
from manim_ml.neural_network.layers.convolutional3d import Convolutional3DLayer
from manim_ml.neural_network.layers.image import ImageLayer
from manim_ml.neural_network.layers.parent_layers import (
    ThreeDLayer,
    VGroupNeuralNetworkLayer,
)
from manim_ml.gridded_rectangle import GriddedRectangle


class ImageToConvolutional3DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
    """Connective layer that animates an ImageLayer into a Convolutional3DLayer.

    During the forward pass the input layer's image mobject is rotated by
    ``ThreeDLayer.rotation_angle`` about ``ThreeDLayer.rotation_axis``,
    scaled to the width of the output layer's first feature map, faded,
    and moved onto that feature map.
    """

    input_class = ImageLayer
    output_class = Convolutional3DLayer

    def __init__(
        self, input_layer: ImageLayer, output_layer: Convolutional3DLayer, **kwargs
    ):
        super().__init__(input_layer, output_layer, **kwargs)
        self.input_layer = input_layer
        self.output_layer = output_layer

    def make_forward_pass_animation(self, run_time=5, layer_args=None, **kwargs):
        """Return the animation that maps the input image onto the conv layer.

        Parameters
        ----------
        run_time : float, optional
            Accepted for interface compatibility; currently unused (the
            individual sub-animations set their own run times).
        layer_args : dict, optional
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        Animation
            The animation produced by ``grayscale_image_animation``.

        Raises
        ------
        ValueError
            If the input layer's ``num_channels`` is neither 1 nor 3.
        """
        num_image_channels = self.input_layer.num_channels
        # 3-channel (RGB) images deliberately reuse the grayscale animation
        # because ``rbg_image_animation`` is not implemented yet.
        # TODO: dispatch 3-channel images to rbg_image_animation once it exists.
        if num_image_channels in (1, 3):
            return self.grayscale_image_animation()
        raise ValueError(
            f"Unrecognized number of image channels: {num_image_channels}"
        )

    def rbg_image_animation(self):
        """Animation for a 3-channel image (not implemented yet).

        Raises
        ------
        NotImplementedError
            Always.
        """
        # TODO get each color channel and turn it into an image
        # TODO create image mobjects for each channel and transform
        # it to the feature maps of the output_layer
        raise NotImplementedError()

    def grayscale_image_animation(self):
        """Animate the input image onto the first output feature map.

        The returned ``Succession`` runs, in order:

        1. rotate the image about its own center into the 3D layer orientation,
        2. scale it so its larger side equals the feature map's width,
        3. lower its opacity to 0.2,
        4. move it onto the feature map.

        Returns
        -------
        Succession
            The composed animation.
        """
        image_mobject = self.input_layer.image_mobject
        target_feature_map = self.output_layer.feature_maps[0]
        # Rotate the flat image about its own center so that it matches the
        # orientation used by the 3D layers.
        rotation = ApplyMethod(
            image_mobject.rotate,
            ThreeDLayer.rotation_angle,
            ThreeDLayer.rotation_axis,
            image_mobject.get_center(),
            run_time=0.5,
        )
        # Fade the image so the feature map remains visible beneath it.
        set_opacity = ApplyMethod(image_mobject.set_opacity, 0.2, run_time=0.5)
        # Scale so the larger of width/height matches the feature map width,
        # which keeps the whole image inside the feature map's footprint.
        max_width_height = max(image_mobject.width, image_mobject.height)
        scale_factor = target_feature_map.width / max_width_height
        scale_image = ApplyMethod(image_mobject.scale, scale_factor, run_time=0.5)
        # Move the image onto the target feature map.
        move_image = ApplyMethod(image_mobject.move_to, target_feature_map)
        # Play the steps one after another (scale runs before fade).
        return Succession(
            rotation,
            scale_image,
            set_opacity,
            move_image,
        )

    def scale(self, scale_factor, **kwargs):
        """Scale the layer, returning the parent implementation's result.

        The original override discarded the return value. Returning it keeps
        chained calls working (manim's ``Mobject.scale`` returns ``self``).
        """
        return super().scale(scale_factor, **kwargs)

    @override_animation(Create)
    def _create_override(self, **kwargs):
        """Make ``Create`` on this connective layer an empty animation."""
        return AnimationGroup()