from manim import *
from abc import ABC, abstractmethod


class NeuralNetworkLayer(ABC, Group):
    """Abstract Neural Network Layer class"""

    def __init__(self, text=None, *args, **kwargs):
        super(Group, self).__init__()
        self.title_text = kwargs["title"] if "title" in kwargs else " "
        self.title = Text(self.title_text, font_size=DEFAULT_FONT_SIZE / 3).scale(0.6)
        self.title.next_to(self, UP, 1.2)
        # self.add(self.title)

    @abstractmethod
    def make_forward_pass_animation(self, layer_args={}, **kwargs):
        pass

    @override_animation(Create)
    def _create_override(self):
        return AnimationGroup()

    def __repr__(self):
        return f"{type(self).__name__}"


class VGroupNeuralNetworkLayer(NeuralNetworkLayer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # self.camera = camera

    @abstractmethod
    def make_forward_pass_animation(self, **kwargs):
        pass

    @override_animation(Create)
    def _create_override(self):
        return super()._create_override()


class ThreeDLayer(ABC):
    """Abstract class for 3D layers"""

    # Angle of ThreeD layers is static context
    three_d_x_rotation = 90 * DEGREES  # -90 * DEGREES
    three_d_y_rotation = 0 * DEGREES  # -10 * DEGREES
    rotation_angle = 60 * DEGREES
    rotation_axis = [0.0, 0.9, 0.0]


class ConnectiveLayer(VGroupNeuralNetworkLayer):
    """Forward pass animation for a given pair of layers"""

    @abstractmethod
    def __init__(self, input_layer, output_layer, **kwargs):
        super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
        self.input_layer = input_layer
        self.output_layer = output_layer
        # Handle input and output class
        # assert isinstance(input_layer, self.input_class), f"{input_layer}, {self.input_class}"
        # assert isinstance(output_layer, self.output_class), f"{output_layer}, {self.output_class}"

    @abstractmethod
    def make_forward_pass_animation(self, run_time=2.0, layer_args={}, **kwargs):
        pass

    @override_animation(Create)
    def _create_override(self):
        return super()._create_override()

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            + f"input_layer={self.input_layer.__class__.__name__},"
            + f"output_layer={self.output_layer.__class__.__name__},"
            + ")"
        )


class BlankConnective(ConnectiveLayer):
    """Connective layer to be used when the given pair of layers is undefined"""

    def __init__(self, input_layer, output_layer, **kwargs):
        super().__init__(input_layer, output_layer, **kwargs)

    def make_forward_pass_animation(self, run_time=1.5, layer_args={}, **kwargs):
        return AnimationGroup(run_time=run_time)

    @override_animation(Create)
    def _create_override(self):
        return super()._create_override()