Bug fixes and linting for the activation functions addition.

This commit is contained in:
Alec Helbling
2023-01-25 08:40:32 -05:00
parent ce184af78e
commit f56620f047
42 changed files with 1275 additions and 387 deletions

View File

@ -4,12 +4,22 @@ import random
import manim_ml.neural_network.activation_functions.relu as relu
class ActivationFunction(ABC, VGroup):
"""Abstract parent class for defining activation functions"""
def __init__(self, function_name=None, x_range=[-1, 1], y_range=[-1, 1],
x_length=0.5, y_length=0.3, show_function_name=True, active_color=ORANGE,
plot_color=BLUE, rectangle_color=WHITE):
def __init__(
self,
function_name=None,
x_range=[-1, 1],
y_range=[-1, 1],
x_length=0.5,
y_length=0.3,
show_function_name=True,
active_color=ORANGE,
plot_color=BLUE,
rectangle_color=WHITE,
):
super(VGroup, self).__init__()
self.function_name = function_name
self.x_range = x_range
@ -25,7 +35,7 @@ class ActivationFunction(ABC, VGroup):
def construct_activation_function(self):
"""Makes the activation function"""
# Make an axis
# Make an axis
self.axes = Axes(
x_range=self.x_range,
y_range=self.y_range,
@ -35,17 +45,17 @@ class ActivationFunction(ABC, VGroup):
axis_config={
"include_numbers": False,
"stroke_width": 0.5,
"include_ticks": False
}
"include_ticks": False,
},
)
self.add(self.axes)
# Surround the axis with a rounded rectangle.
# Surround the axis with a rounded rectangle.
self.surrounding_rectangle = SurroundingRectangle(
self.axes,
corner_radius=0.05,
buff=0.05,
stroke_width=2.0,
stroke_color=self.rectangle_color
stroke_color=self.rectangle_color,
)
self.add(self.surrounding_rectangle)
# Plot function on axis by applying it and showing in given range
@ -53,17 +63,15 @@ class ActivationFunction(ABC, VGroup):
lambda x: self.apply_function(x),
x_range=self.x_range,
stroke_color=self.plot_color,
stroke_width=2.0
stroke_width=2.0,
)
self.add(self.graph)
# Add the function name
if self.show_function_name:
function_name_text = Text(
self.function_name,
font_size=12,
font="sans-serif"
self.function_name, font_size=12, font="sans-serif"
)
function_name_text.next_to(self.axes, UP*0.5)
function_name_text.next_to(self.axes, UP * 0.5)
self.add(function_name_text)
@abstractmethod
@ -78,29 +86,21 @@ class ActivationFunction(ABC, VGroup):
# TODO: Evaluate the function at the x_val and show a highlighted dot
animation_group = Succession(
AnimationGroup(
ApplyMethod(self.graph.set_color, self.active_color),
ApplyMethod(
self.graph.set_color,
self.active_color
self.surrounding_rectangle.set_stroke_color, self.active_color
),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color,
self.active_color
),
lag_ratio=0.0
lag_ratio=0.0,
),
Wait(1),
AnimationGroup(
ApplyMethod(self.graph.set_color, self.plot_color),
ApplyMethod(
self.graph.set_color,
self.plot_color
self.surrounding_rectangle.set_stroke_color, self.rectangle_color
),
ApplyMethod(
self.surrounding_rectangle.set_stroke_color,
self.rectangle_color
),
lag_ratio=0.0
lag_ratio=0.0,
),
lag_ratio=1.0
lag_ratio=1.0,
)
return animation_group
return animation_group