mirror of
https://github.com/helblazer811/ManimML.git
synced 2025-05-23 13:36:26 +08:00
Bug fixes and linting for the activation functions addition.
This commit is contained in:
@ -4,12 +4,22 @@ import random
|
||||
|
||||
import manim_ml.neural_network.activation_functions.relu as relu
|
||||
|
||||
|
||||
class ActivationFunction(ABC, VGroup):
|
||||
"""Abstract parent class for defining activation functions"""
|
||||
|
||||
def __init__(self, function_name=None, x_range=[-1, 1], y_range=[-1, 1],
|
||||
x_length=0.5, y_length=0.3, show_function_name=True, active_color=ORANGE,
|
||||
plot_color=BLUE, rectangle_color=WHITE):
|
||||
def __init__(
|
||||
self,
|
||||
function_name=None,
|
||||
x_range=[-1, 1],
|
||||
y_range=[-1, 1],
|
||||
x_length=0.5,
|
||||
y_length=0.3,
|
||||
show_function_name=True,
|
||||
active_color=ORANGE,
|
||||
plot_color=BLUE,
|
||||
rectangle_color=WHITE,
|
||||
):
|
||||
super(VGroup, self).__init__()
|
||||
self.function_name = function_name
|
||||
self.x_range = x_range
|
||||
@ -25,7 +35,7 @@ class ActivationFunction(ABC, VGroup):
|
||||
|
||||
def construct_activation_function(self):
|
||||
"""Makes the activation function"""
|
||||
# Make an axis
|
||||
# Make an axis
|
||||
self.axes = Axes(
|
||||
x_range=self.x_range,
|
||||
y_range=self.y_range,
|
||||
@ -35,17 +45,17 @@ class ActivationFunction(ABC, VGroup):
|
||||
axis_config={
|
||||
"include_numbers": False,
|
||||
"stroke_width": 0.5,
|
||||
"include_ticks": False
|
||||
}
|
||||
"include_ticks": False,
|
||||
},
|
||||
)
|
||||
self.add(self.axes)
|
||||
# Surround the axis with a rounded rectangle.
|
||||
# Surround the axis with a rounded rectangle.
|
||||
self.surrounding_rectangle = SurroundingRectangle(
|
||||
self.axes,
|
||||
corner_radius=0.05,
|
||||
buff=0.05,
|
||||
stroke_width=2.0,
|
||||
stroke_color=self.rectangle_color
|
||||
stroke_color=self.rectangle_color,
|
||||
)
|
||||
self.add(self.surrounding_rectangle)
|
||||
# Plot function on axis by applying it and showing in given range
|
||||
@ -53,17 +63,15 @@ class ActivationFunction(ABC, VGroup):
|
||||
lambda x: self.apply_function(x),
|
||||
x_range=self.x_range,
|
||||
stroke_color=self.plot_color,
|
||||
stroke_width=2.0
|
||||
stroke_width=2.0,
|
||||
)
|
||||
self.add(self.graph)
|
||||
# Add the function name
|
||||
if self.show_function_name:
|
||||
function_name_text = Text(
|
||||
self.function_name,
|
||||
font_size=12,
|
||||
font="sans-serif"
|
||||
self.function_name, font_size=12, font="sans-serif"
|
||||
)
|
||||
function_name_text.next_to(self.axes, UP*0.5)
|
||||
function_name_text.next_to(self.axes, UP * 0.5)
|
||||
self.add(function_name_text)
|
||||
|
||||
@abstractmethod
|
||||
@ -78,29 +86,21 @@ class ActivationFunction(ABC, VGroup):
|
||||
# TODO: Evaluate the function at the x_val and show a highlighted dot
|
||||
animation_group = Succession(
|
||||
AnimationGroup(
|
||||
ApplyMethod(self.graph.set_color, self.active_color),
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.active_color
|
||||
self.surrounding_rectangle.set_stroke_color, self.active_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.active_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
lag_ratio=0.0,
|
||||
),
|
||||
Wait(1),
|
||||
AnimationGroup(
|
||||
ApplyMethod(self.graph.set_color, self.plot_color),
|
||||
ApplyMethod(
|
||||
self.graph.set_color,
|
||||
self.plot_color
|
||||
self.surrounding_rectangle.set_stroke_color, self.rectangle_color
|
||||
),
|
||||
ApplyMethod(
|
||||
self.surrounding_rectangle.set_stroke_color,
|
||||
self.rectangle_color
|
||||
),
|
||||
lag_ratio=0.0
|
||||
lag_ratio=0.0,
|
||||
),
|
||||
lag_ratio=1.0
|
||||
lag_ratio=1.0,
|
||||
)
|
||||
|
||||
return animation_group
|
||||
return animation_group
|
||||
|
Reference in New Issue
Block a user