Files

106 lines
3.4 KiB
Python

from manim import *
from abc import ABC, abstractmethod
import random
import manim_ml.neural_network.activation_functions.relu as relu
class ActivationFunction(ABC, VGroup):
"""Abstract parent class for defining activation functions"""
def __init__(
self,
function_name=None,
x_range=[-1, 1],
y_range=[-1, 1],
x_length=0.5,
y_length=0.3,
show_function_name=True,
active_color=ORANGE,
plot_color=BLUE,
rectangle_color=WHITE,
):
super(VGroup, self).__init__()
self.function_name = function_name
self.x_range = x_range
self.y_range = y_range
self.x_length = x_length
self.y_length = y_length
self.show_function_name = show_function_name
self.active_color = active_color
self.plot_color = plot_color
self.rectangle_color = rectangle_color
self.construct_activation_function()
def construct_activation_function(self):
"""Makes the activation function"""
# Make an axis
self.axes = Axes(
x_range=self.x_range,
y_range=self.y_range,
x_length=self.x_length,
y_length=self.y_length,
tips=False,
axis_config={
"include_numbers": False,
"stroke_width": 0.5,
"include_ticks": False,
},
)
self.add(self.axes)
# Surround the axis with a rounded rectangle.
self.surrounding_rectangle = SurroundingRectangle(
self.axes,
corner_radius=0.05,
buff=0.05,
stroke_width=2.0,
stroke_color=self.rectangle_color,
)
self.add(self.surrounding_rectangle)
# Plot function on axis by applying it and showing in given range
self.graph = self.axes.plot(
lambda x: self.apply_function(x),
x_range=self.x_range,
stroke_color=self.plot_color,
stroke_width=2.0,
)
self.add(self.graph)
# Add the function name
if self.show_function_name:
function_name_text = Text(
self.function_name, font_size=12, font="sans-serif"
)
function_name_text.next_to(self.axes, UP * 0.5)
self.add(function_name_text)
@abstractmethod
def apply_function(self, x_val):
"""Evaluates function at given x_val"""
if x_val == None:
x_val = random.uniform(self.x_range[0], self.x_range[1])
def make_evaluate_animation(self, x_val=None):
"""Evaluates the function at a random point in the x_range"""
# Highlight the graph
# TODO: Evaluate the function at the x_val and show a highlighted dot
animation_group = Succession(
AnimationGroup(
ApplyMethod(self.graph.set_color, self.active_color),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color, self.active_color
),
lag_ratio=0.0,
),
Wait(1),
AnimationGroup(
ApplyMethod(self.graph.set_color, self.plot_color),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color, self.rectangle_color
),
lag_ratio=0.0,
),
lag_ratio=1.0,
)
return animation_group