Added a sigmoid activation function and the ability to render activation functions above linear layers.

This commit is contained in:
Alec Helbling
2023-01-25 17:36:44 -05:00
parent 4948c0ea4e
commit 301b230c73
7 changed files with 57 additions and 10 deletions

View File

@ -63,12 +63,12 @@ class CombinedScene(ThreeDScene):
nn.move_to(ORIGIN)
self.add(nn)
# Make code snippet
code = make_code_snippet()
code.next_to(nn, DOWN)
self.add(code)
nn.move_to(ORIGIN)
# code = make_code_snippet()
# code.next_to(nn, DOWN)
# self.add(code)
# nn.move_to(ORIGIN)
# Move everything up
Group(nn, code).move_to(ORIGIN)
# Group(nn, code).move_to(ORIGIN)
# Play animation
forward_pass = nn.make_forward_pass_animation()
self.wait(1)

View File

@ -1,7 +1,13 @@
from manim_ml.neural_network.activation_functions.relu import ReLUFunction
from manim_ml.neural_network.activation_functions.sigmoid import SigmoidFunction
name_to_activation_function_map = {"ReLU": ReLUFunction}
name_to_activation_function_map = {
"ReLU": ReLUFunction,
"Sigmoid": SigmoidFunction
}
def get_activation_function_by_name(name):
    """Look up a registered activation-function class by name.

    Parameters
    ----------
    name : str
        A key of ``name_to_activation_function_map``
        (e.g. ``"ReLU"`` or ``"Sigmoid"``).

    Returns
    -------
    type
        The activation-function class registered under ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a registered activation function.
        (Previously enforced with ``assert``, which is silently
        stripped when Python runs with ``-O``.)
    """
    if name not in name_to_activation_function_map:
        raise ValueError(f"Unrecognized activation function {name}")
    return name_to_activation_function_map[name]

View File

@ -4,7 +4,6 @@ import random
import manim_ml.neural_network.activation_functions.relu as relu
class ActivationFunction(ABC, VGroup):
"""Abstract parent class for defining activation functions"""

View File

@ -0,0 +1,15 @@
from manim import *
import numpy as np
from manim_ml.neural_network.activation_functions.activation_function import (
ActivationFunction,
)
class SigmoidFunction(ActivationFunction):
    """Sigmoid (logistic) activation, plotted over x in [-5, 5], y in [0, 1]."""

    def __init__(self, function_name="Sigmoid", x_range=[-5, 5], y_range=[0, 1]):
        # Delegate name and plot ranges to the generic ActivationFunction base.
        super().__init__(function_name, x_range, y_range)

    def apply_function(self, x_val):
        """Evaluate the logistic function 1 / (1 + e^(-x)) at ``x_val``."""
        return 1.0 / (1.0 + np.exp(-x_val))

View File

@ -68,6 +68,11 @@ class Convolutional2DLayer(VGroupNeuralNetworkLayer, ThreeDLayer):
about_point=self.get_center(),
axis=ThreeDLayer.rotation_axis,
)
self.construct_activation_function()
def construct_activation_function(self):
"""Construct the activation function"""
# Add the activation function
if not self.activation_function is None:
# Check if it is a string

View File

@ -1,6 +1,8 @@
from manim import *
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
from manim_ml.neural_network.activation_functions import get_activation_function_by_name
from manim_ml.neural_network.activation_functions.activation_function import ActivationFunction
from manim_ml.neural_network.layers.parent_layers import VGroupNeuralNetworkLayer
class FeedForwardLayer(VGroupNeuralNetworkLayer):
"""Handles rendering a layer for a neural network"""
@ -18,6 +20,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
node_stroke_width=2.0,
rectangle_stroke_width=2.0,
animation_dot_color=RED,
activation_function=None,
**kwargs
):
super(VGroupNeuralNetworkLayer, self).__init__(**kwargs)
@ -32,6 +35,7 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
self.node_spacing = node_spacing
self.rectangle_fill_color = rectangle_fill_color
self.animation_dot_color = animation_dot_color
self.activation_function = activation_function
self.node_group = VGroup()
@ -68,6 +72,24 @@ class FeedForwardLayer(VGroupNeuralNetworkLayer):
# Add the objects to the class
self.add(self.surrounding_rectangle, self.node_group)
self.construct_activation_function()
def construct_activation_function(self):
    """Resolve ``self.activation_function`` into an ActivationFunction
    mobject and add it to this layer.

    ``self.activation_function`` may be:

    * ``None`` — no activation is drawn (no-op), or
    * ``str`` — resolved via ``get_activation_function_by_name`` and
      instantiated with its default arguments, or
    * an ``ActivationFunction`` instance — used as-is.

    After this call, ``self.activation_function`` holds the instantiated
    ``ActivationFunction`` (or stays ``None``).

    Raises
    ------
    TypeError
        If ``self.activation_function`` is neither ``None``, a ``str``,
        nor an ``ActivationFunction``. (Previously an ``assert``, which
        is stripped when Python runs with ``-O``.)
    """
    if self.activation_function is None:
        return
    if isinstance(self.activation_function, str):
        # Resolve the registry name (e.g. "Sigmoid") to a class, then
        # instantiate it with default plot ranges.
        activation_function = get_activation_function_by_name(
            self.activation_function
        )()
    elif isinstance(self.activation_function, ActivationFunction):
        activation_function = self.activation_function
    else:
        raise TypeError(
            "activation_function must be None, a registered name (str), or an "
            f"ActivationFunction, got {type(self.activation_function).__name__}"
        )
    # NOTE(review): the original comment said "Plot the function above the
    # rest of the layer", but no positioning code is visible here —
    # presumably the placement happens elsewhere; confirm against callers.
    self.activation_function = activation_function
    self.add(self.activation_function)
def make_dropout_forward_pass_animation(self, layer_args, **kwargs):
"""Makes a forward pass animation with dropout"""
# Make sure proper dropout information was passed

View File

@ -22,7 +22,7 @@ class CombinedScene(ThreeDScene):
ImageLayer(numpy_image, height=1.5),
Convolutional2DLayer(1, 7, filter_spacing=0.32),
Convolutional2DLayer(3, 5, 3, filter_spacing=0.32, activation_function="ReLU"),
FeedForwardLayer(3),
FeedForwardLayer(3, activation_function="Sigmoid"),
],
layer_spacing=0.25,
)