mirror of
https://github.com/3b1b/manim.git
synced 2025-08-03 04:04:36 +08:00
NN Part 1 published
This commit is contained in:
386
nn/part1.py
386
nn/part1.py
@ -632,6 +632,8 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
||||
self.remove(self.network_mob)
|
||||
|
||||
def construct(self):
|
||||
self.force_skipping()
|
||||
|
||||
self.show_words()
|
||||
self.show_network()
|
||||
self.show_math()
|
||||
@ -757,6 +759,8 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
||||
def show_videos(self):
|
||||
network_mob = self.network_mob
|
||||
learning = self.learning_word
|
||||
structure = TextMobject("Structure")
|
||||
structure.highlight(YELLOW)
|
||||
videos = VGroup(*[
|
||||
VideoIcon().set_fill(RED)
|
||||
for x in range(2)
|
||||
@ -770,13 +774,17 @@ class LayOutPlan(TeacherStudentsScene, NetworkScene):
|
||||
network_mob.target.move_to(videos[0])
|
||||
learning.generate_target()
|
||||
learning.target.next_to(videos[1], UP)
|
||||
structure.next_to(videos[0], UP)
|
||||
structure.shift(0.5*SMALL_BUFF*UP)
|
||||
|
||||
self.revert_to_original_skipping_status()
|
||||
self.play(
|
||||
MoveToTarget(network_mob),
|
||||
MoveToTarget(learning)
|
||||
)
|
||||
self.play(
|
||||
DrawBorderThenFill(videos[0]),
|
||||
FadeIn(structure),
|
||||
self.get_student_changes(*["pondering"]*3)
|
||||
)
|
||||
self.dither()
|
||||
@ -1192,13 +1200,29 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
||||
network_mob = self.network_mob
|
||||
neurons = self.neurons
|
||||
layer = network_mob.layers[0]
|
||||
layer.save_state()
|
||||
layer.rotate(np.pi/2)
|
||||
layer.center()
|
||||
layer.brace_label.rotate_in_place(-np.pi/2)
|
||||
n = network_mob.max_shown_neurons/2
|
||||
|
||||
rows = VGroup(*[
|
||||
VGroup(*neurons[28*i:28*(i+1)])
|
||||
for i in range(28)
|
||||
])
|
||||
|
||||
self.play(
|
||||
FadeOut(self.braces),
|
||||
FadeOut(self.brace_labels),
|
||||
FadeOut(VGroup(*self.num_pixels_equation[:-1]))
|
||||
)
|
||||
|
||||
self.play(rows.space_out_submobjects, 1.2)
|
||||
self.play(
|
||||
rows.arrange_submobjects, RIGHT, buff = SMALL_BUFF,
|
||||
path_arc = np.pi/2,
|
||||
run_time = 2
|
||||
)
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
VGroup(*neurons[:n]),
|
||||
@ -1212,15 +1236,15 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
||||
VGroup(*neurons[-n:]),
|
||||
VGroup(*layer.neurons[-n:]),
|
||||
),
|
||||
FadeIn(self.corner_image)
|
||||
)
|
||||
self.play(
|
||||
ReplacementTransform(
|
||||
self.num_pixels_equation[-1],
|
||||
layer.brace_label
|
||||
),
|
||||
FadeIn(layer.brace)
|
||||
FadeIn(layer.brace),
|
||||
)
|
||||
self.play(layer.restore, FadeIn(self.corner_image))
|
||||
self.dither()
|
||||
for edge_group, layer in zip(network_mob.edge_groups, network_mob.layers[1:]):
|
||||
self.play(
|
||||
@ -1320,6 +1344,69 @@ class IntroduceEachLayer(PreviewMNistNetwork):
|
||||
self.remove_random_edges(0.7)
|
||||
self.feed_forward(self.image_vect)
|
||||
|
||||
class DiscussChoiceForHiddenLayers(TeacherStudentsScene):
|
||||
def construct(self):
|
||||
network_mob = MNistNetworkMobject(
|
||||
layer_to_layer_buff = 2.5,
|
||||
neuron_stroke_color = WHITE,
|
||||
)
|
||||
network_mob.scale_to_fit_height(4)
|
||||
network_mob.to_edge(UP, buff = LARGE_BUFF)
|
||||
layers = VGroup(*network_mob.layers[1:3])
|
||||
rects = VGroup(*map(SurroundingRectangle, layers))
|
||||
self.add(network_mob)
|
||||
|
||||
two_words = TextMobject("2 hidden layers")
|
||||
two_words.highlight(YELLOW)
|
||||
sixteen_words = TextMobject("16 neurons each")
|
||||
sixteen_words.highlight(MAROON_B)
|
||||
for words in two_words, sixteen_words:
|
||||
words.next_to(rects, UP)
|
||||
|
||||
neurons_anim = LaggedStart(
|
||||
Indicate,
|
||||
VGroup(*it.chain(*[layer.neurons for layer in layers])),
|
||||
rate_func = there_and_back,
|
||||
scale_factor = 2,
|
||||
color = MAROON_B,
|
||||
)
|
||||
|
||||
self.play(
|
||||
ShowCreation(rects),
|
||||
Write(two_words, run_time = 1),
|
||||
self.teacher.change, "raise_right_hand",
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
FadeOut(rects),
|
||||
ReplacementTransform(two_words, sixteen_words),
|
||||
neurons_anim
|
||||
)
|
||||
self.dither()
|
||||
self.play(self.teacher.change, "shruggie")
|
||||
self.change_student_modes("erm", "confused", "sassy")
|
||||
self.dither()
|
||||
self.student_says(
|
||||
"Why 2 \\\\ layers?",
|
||||
student_index = 1,
|
||||
bubble_kwargs = {"direction" : RIGHT},
|
||||
run_time = 1,
|
||||
target_mode = "raise_left_hand",
|
||||
)
|
||||
self.play(self.teacher.change, "happy")
|
||||
self.dither()
|
||||
self.student_says(
|
||||
"Why 16?",
|
||||
student_index = 0,
|
||||
run_time = 1,
|
||||
)
|
||||
self.play(neurons_anim, run_time = 3)
|
||||
self.play(
|
||||
self.teacher.change, "shruggie",
|
||||
RemovePiCreatureBubble(self.students[0]),
|
||||
)
|
||||
self.dither()
|
||||
|
||||
class MoreHonestMNistNetworkPreview(IntroduceEachLayer):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {
|
||||
@ -1597,7 +1684,7 @@ class BreakUpMacroPatterns(IntroduceEachLayer):
|
||||
|
||||
def show_upper_loop_activation(self):
|
||||
neuron = self.network_mob.layers[-2].neurons[0]
|
||||
words = TextMobject("Upper loop neuron...mabye...")
|
||||
words = TextMobject("Upper loop neuron...maybe...")
|
||||
words.scale(0.8)
|
||||
words.next_to(neuron, UP)
|
||||
words.shift(RIGHT)
|
||||
@ -1619,11 +1706,11 @@ class BreakUpMacroPatterns(IntroduceEachLayer):
|
||||
]
|
||||
|
||||
self.play(FadeIn(nine))
|
||||
self.add_foreground_mobject(self.patterns)
|
||||
self.play(
|
||||
ShowCreation(rect),
|
||||
Write(words)
|
||||
)
|
||||
self.add_foreground_mobject(self.patterns)
|
||||
self.feed_forward(np.random.random(784))
|
||||
self.dither(2)
|
||||
|
||||
@ -2204,7 +2291,7 @@ class IntroduceWeights(IntroduceEachLayer):
|
||||
pixels.next_to(neuron, RIGHT, LARGE_BUFF)
|
||||
rect = SurroundingRectangle(pixels, color = BLUE)
|
||||
|
||||
pixels_to_detect = self.get_pixels_to_detect()
|
||||
pixels_to_detect = self.get_pixels_to_detect(pixels)
|
||||
|
||||
self.play(
|
||||
FadeIn(rect),
|
||||
@ -2249,7 +2336,8 @@ class IntroduceWeights(IntroduceEachLayer):
|
||||
p_labels[-1].shift(SMALL_BUFF*RIGHT)
|
||||
|
||||
def get_alpha_func(i, start = 0):
|
||||
m = int(5*np.sin(2*np.pi*i/128.))
|
||||
# m = int(5*np.sin(2*np.pi*i/128.))
|
||||
m = random.randint(1, 10)
|
||||
return lambda a : start + (1-2*start)*np.sin(np.pi*a*m)**2
|
||||
|
||||
decimals = VGroup()
|
||||
@ -2660,12 +2748,16 @@ class IntroduceSigmoid(GraphScene):
|
||||
name = TextMobject("Sigmoid")
|
||||
name.next_to(ORIGIN, RIGHT, LARGE_BUFF)
|
||||
name.to_edge(UP)
|
||||
char = self.x_axis_label.replace("$", "")
|
||||
equation = TexMobject(
|
||||
"\\sigma(x) = \\frac{1}{1+e^{-x}}"
|
||||
"\\sigma(%s) = \\frac{1}{1+e^{-%s}}"%(char, char)
|
||||
)
|
||||
equation.next_to(name, DOWN)
|
||||
self.add(equation, name)
|
||||
|
||||
self.equation = equation
|
||||
self.sigmoid_name = name
|
||||
|
||||
def add_graph(self):
|
||||
graph = self.get_graph(
|
||||
lambda x : 1./(1+np.exp(-x)),
|
||||
@ -2675,6 +2767,8 @@ class IntroduceSigmoid(GraphScene):
|
||||
self.play(ShowCreation(graph))
|
||||
self.dither()
|
||||
|
||||
self.sigmoid_graph = graph
|
||||
|
||||
###
|
||||
|
||||
def show_part(self, x_min, x_max, color):
|
||||
@ -3494,9 +3588,9 @@ class IntroduceWeightMatrix(NetworkScene):
|
||||
"w_{%s, 0}"%i,
|
||||
"w_{%s, 1}"%i,
|
||||
"\\cdots",
|
||||
"w_{%s, k}"%i,
|
||||
"w_{%s, n}"%i,
|
||||
]))
|
||||
for i in "1", "n"
|
||||
for i in "1", "k"
|
||||
]
|
||||
dots_row = VGroup(*map(TexMobject, [
|
||||
"\\vdots", "\\vdots", "\\ddots", "\\vdots"
|
||||
@ -3591,12 +3685,65 @@ class IntroduceWeightMatrix(NetworkScene):
|
||||
FadeIn, VGroup(*result_terms[1:])
|
||||
))
|
||||
self.dither(2)
|
||||
self.show_meaning_of_lower_rows(
|
||||
arrow, brace, top_row_rect, result_terms
|
||||
)
|
||||
self.play(*map(FadeOut, [
|
||||
result_terms, result_brackets, equals,
|
||||
arrow, brace,
|
||||
top_row_rect, column_rect
|
||||
result_terms, result_brackets, equals, column_rect
|
||||
]))
|
||||
|
||||
def show_meaning_of_lower_rows(self, arrow, brace, row_rect, result_terms):
|
||||
n1, n2, nk = neurons = VGroup(*[
|
||||
self.network_mob.layers[1].neurons[i]
|
||||
for i in 0, 1, -1
|
||||
])
|
||||
for n in neurons:
|
||||
n.save_state()
|
||||
n.edges_in.save_state()
|
||||
|
||||
rect2 = SurroundingRectangle(result_terms[1])
|
||||
rectk = SurroundingRectangle(result_terms[-1])
|
||||
VGroup(rect2, rectk).highlight(WHITE)
|
||||
row2 = self.lower_matrix_rows[0]
|
||||
rowk = self.lower_matrix_rows[-1]
|
||||
|
||||
def show_edges(neuron):
|
||||
self.play(LaggedStart(
|
||||
ShowCreationThenDestruction,
|
||||
neuron.edges_in.copy().set_stroke(GREEN, 5),
|
||||
lag_ratio = 0.7,
|
||||
run_time = 1,
|
||||
))
|
||||
|
||||
self.play(
|
||||
row_rect.move_to, row2,
|
||||
n1.fade,
|
||||
n1.set_fill, None, 0,
|
||||
n1.edges_in.set_stroke, None, 1,
|
||||
n2.set_stroke, WHITE, 3,
|
||||
n2.edges_in.set_stroke, None, 3,
|
||||
ReplacementTransform(arrow, rect2),
|
||||
FadeOut(brace),
|
||||
)
|
||||
show_edges(n2)
|
||||
self.play(
|
||||
row_rect.move_to, rowk,
|
||||
n2.restore,
|
||||
n2.edges_in.restore,
|
||||
nk.set_stroke, WHITE, 3,
|
||||
nk.edges_in.set_stroke, None, 3,
|
||||
ReplacementTransform(rect2, rectk),
|
||||
)
|
||||
show_edges(nk)
|
||||
self.play(
|
||||
n1.restore,
|
||||
n1.edges_in.restore,
|
||||
nk.restore,
|
||||
nk.edges_in.restore,
|
||||
FadeOut(rectk),
|
||||
FadeOut(row_rect),
|
||||
)
|
||||
|
||||
def add_bias_vector(self):
|
||||
bias = self.bias
|
||||
bias_name = self.bias_name
|
||||
@ -4054,7 +4201,7 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
||||
content = self.content
|
||||
|
||||
video = VideoIcon()
|
||||
video.scale_to_fit_height(2)
|
||||
video.scale_to_fit_height(3)
|
||||
video.set_fill(RED, 0.8)
|
||||
video.next_to(morty, UP+LEFT)
|
||||
|
||||
@ -4098,12 +4245,12 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
||||
)
|
||||
bang = subscribe_word[1]
|
||||
subscribe_word.to_corner(DOWN+RIGHT)
|
||||
subscribe_word.shift(2*UP)
|
||||
subscribe_word.shift(3*UP)
|
||||
q_mark = TextMobject("?")
|
||||
q_mark.move_to(bang, LEFT)
|
||||
arrow = Arrow(ORIGIN, DOWN, color = RED, buff = 0)
|
||||
arrow.next_to(subscribe_word, DOWN)
|
||||
arrow.shift(RIGHT)
|
||||
arrow.shift(MED_LARGE_BUFF * RIGHT)
|
||||
|
||||
self.play(
|
||||
Write(subscribe_word),
|
||||
@ -4120,7 +4267,7 @@ class NextVideo(MoreHonestMNistNetworkPreview, PiCreatureScene):
|
||||
morty = self.pi_creature
|
||||
|
||||
network_mob, rect, video, words = self.video
|
||||
network_mob.generate_target()
|
||||
network_mob.generate_target(use_deepcopy = True)
|
||||
network_mob.target.scale_to_fit_height(5)
|
||||
network_mob.target.to_corner(UP+LEFT)
|
||||
neurons = VGroup(*network_mob.target.layers[-1].neurons[:2])
|
||||
@ -4216,6 +4363,209 @@ class NNPatreonThanks(PatreonThanks):
|
||||
]
|
||||
}
|
||||
|
||||
class PiCreatureGesture(PiCreatureScene):
|
||||
def construct(self):
|
||||
self.play(self.pi_creature.change, "raise_right_hand")
|
||||
self.dither(5)
|
||||
self.play(self.pi_creature.change, "happy")
|
||||
self.dither(4)
|
||||
|
||||
class IntroduceReLU(IntroduceSigmoid):
|
||||
CONFIG = {
|
||||
"x_axis_label" : "$a$"
|
||||
}
|
||||
def construct(self):
|
||||
self.setup_axes()
|
||||
self.add_title()
|
||||
self.add_graph()
|
||||
self.old_school()
|
||||
self.show_ReLU()
|
||||
self.label_input_regions()
|
||||
|
||||
def old_school(self):
|
||||
sigmoid_graph = self.sigmoid_graph
|
||||
sigmoid_title = VGroup(
|
||||
self.sigmoid_name,
|
||||
self.equation
|
||||
)
|
||||
cross = Cross(sigmoid_title)
|
||||
old_school = TextMobject("Old school")
|
||||
old_school.to_corner(UP+RIGHT)
|
||||
old_school.highlight(RED)
|
||||
arrow = Arrow(
|
||||
old_school.get_bottom(),
|
||||
self.equation.get_right(),
|
||||
color = RED
|
||||
)
|
||||
|
||||
self.play(ShowCreation(cross))
|
||||
self.play(
|
||||
Write(old_school, run_time = 1),
|
||||
GrowArrow(arrow)
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(
|
||||
ApplyMethod(
|
||||
VGroup(cross, sigmoid_title).shift,
|
||||
SPACE_WIDTH*RIGHT,
|
||||
rate_func = running_start
|
||||
),
|
||||
FadeOut(old_school),
|
||||
FadeOut(arrow),
|
||||
)
|
||||
self.play(ShowCreation(
|
||||
self.sigmoid_graph,
|
||||
rate_func = lambda t : smooth(1-t),
|
||||
remover = True
|
||||
))
|
||||
|
||||
def show_ReLU(self):
|
||||
graph = VGroup(
|
||||
Line(
|
||||
self.coords_to_point(-7, 0),
|
||||
self.coords_to_point(0, 0),
|
||||
),
|
||||
Line(
|
||||
self.coords_to_point(0, 0),
|
||||
self.coords_to_point(4, 4),
|
||||
),
|
||||
)
|
||||
graph.highlight(YELLOW)
|
||||
char = self.x_axis_label.replace("$", "")
|
||||
equation = TextMobject("ReLU($%s$) = max$(0, %s)$"%(char, char))
|
||||
equation.shift(SPACE_WIDTH*LEFT/2)
|
||||
equation.to_edge(UP)
|
||||
equation.add_background_rectangle()
|
||||
name = TextMobject("Rectified linear unit")
|
||||
name.move_to(equation)
|
||||
name.add_background_rectangle()
|
||||
|
||||
self.play(Write(equation))
|
||||
self.play(ShowCreation(graph), Animation(equation))
|
||||
self.dither(2)
|
||||
self.play(
|
||||
Write(name),
|
||||
equation.shift, DOWN
|
||||
)
|
||||
self.dither(2)
|
||||
|
||||
self.ReLU_graph = graph
|
||||
|
||||
def label_input_regions(self):
|
||||
l1, l2 = self.ReLU_graph
|
||||
neg_words = TextMobject("Inactive")
|
||||
neg_words.highlight(RED)
|
||||
neg_words.next_to(self.coords_to_point(-2, 0), UP)
|
||||
|
||||
pos_words = TextMobject("Same as $f(a) = a$")
|
||||
pos_words.highlight(GREEN)
|
||||
pos_words.next_to(
|
||||
self.coords_to_point(1, 1),
|
||||
DOWN+RIGHT
|
||||
)
|
||||
|
||||
self.revert_to_original_skipping_status()
|
||||
self.play(ShowCreation(l1.copy().highlight(RED)))
|
||||
self.play(Write(neg_words))
|
||||
self.dither()
|
||||
self.play(ShowCreation(l2.copy().highlight(GREEN)))
|
||||
self.play(Write(pos_words))
|
||||
self.dither(2)
|
||||
|
||||
class CompareSigmoidReLUOnDeepNetworks(PiCreatureScene):
|
||||
def construct(self):
|
||||
morty, lisha = self.morty, self.lisha
|
||||
sigmoid_graph = FunctionGraph(
|
||||
sigmoid,
|
||||
x_min = -5,
|
||||
x_max = 5,
|
||||
)
|
||||
sigmoid_graph.stretch_to_fit_width(3)
|
||||
sigmoid_graph.highlight(YELLOW)
|
||||
sigmoid_graph.next_to(lisha, UP+LEFT)
|
||||
sigmoid_graph.shift_onto_screen()
|
||||
sigmoid_name = TextMobject("Sigmoid")
|
||||
sigmoid_name.next_to(sigmoid_graph, UP)
|
||||
sigmoid_graph.add(sigmoid_name)
|
||||
|
||||
slow_learner = TextMobject("Slow learner")
|
||||
slow_learner.highlight(YELLOW)
|
||||
slow_learner.to_corner(UP+LEFT)
|
||||
slow_arrow = Arrow(
|
||||
slow_learner.get_bottom(),
|
||||
sigmoid_graph.get_top(),
|
||||
)
|
||||
|
||||
relu_graph = VGroup(
|
||||
Line(2*LEFT, ORIGIN),
|
||||
Line(ORIGIN, np.sqrt(2)*(RIGHT+UP)),
|
||||
)
|
||||
relu_graph.highlight(BLUE)
|
||||
relu_graph.next_to(lisha, UP+RIGHT)
|
||||
relu_name = TextMobject("ReLU")
|
||||
relu_name.move_to(relu_graph, UP)
|
||||
relu_graph.add(relu_name)
|
||||
|
||||
network_mob = NetworkMobject(Network(
|
||||
sizes = [6, 4, 5, 4, 3, 5, 2]
|
||||
))
|
||||
network_mob.scale(0.8)
|
||||
network_mob.to_edge(UP, buff = MED_SMALL_BUFF)
|
||||
network_mob.shift(RIGHT)
|
||||
edge_update = ContinualEdgeUpdate(
|
||||
network_mob, stroke_width_exp = 1,
|
||||
)
|
||||
|
||||
self.play(
|
||||
FadeIn(sigmoid_name),
|
||||
ShowCreation(sigmoid_graph),
|
||||
lisha.change, "raise_left_hand",
|
||||
morty.change, "pondering"
|
||||
)
|
||||
self.play(
|
||||
Write(slow_learner, run_time = 1),
|
||||
GrowArrow(slow_arrow)
|
||||
)
|
||||
self.dither()
|
||||
self.play(
|
||||
FadeIn(relu_name),
|
||||
ShowCreation(relu_graph),
|
||||
lisha.change, "raise_right_hand",
|
||||
morty.change, "thinking"
|
||||
)
|
||||
self.play(FadeIn(network_mob))
|
||||
self.add(edge_update)
|
||||
self.dither(10)
|
||||
|
||||
|
||||
|
||||
###
|
||||
def create_pi_creatures(self):
|
||||
morty = Mortimer()
|
||||
morty.shift(SPACE_WIDTH*RIGHT/2).to_edge(DOWN)
|
||||
lisha = PiCreature(color = BLUE_C)
|
||||
lisha.shift(SPACE_WIDTH*LEFT/2).to_edge(DOWN)
|
||||
self.morty, self.lisha = morty, lisha
|
||||
return morty, lisha
|
||||
|
||||
class ShowAmplify(PiCreatureScene):
|
||||
def construct(self):
|
||||
morty = self.pi_creature
|
||||
rect = ScreenRectangle(height = 5)
|
||||
rect.to_corner(UP+LEFT)
|
||||
rect.shift(DOWN)
|
||||
email = TextMobject("3blue1brown@amplifypartners.com")
|
||||
email.next_to(rect, UP)
|
||||
|
||||
self.play(
|
||||
ShowCreation(rect),
|
||||
morty.change, "raise_right_hand"
|
||||
)
|
||||
self.dither(2)
|
||||
self.play(Write(email))
|
||||
self.play(morty.change, "happy", rect)
|
||||
self.dither(10)
|
||||
|
||||
class Thumbnail(NetworkScene):
|
||||
CONFIG = {
|
||||
"network_mob_config" : {
|
||||
@ -4225,11 +4575,13 @@ class Thumbnail(NetworkScene):
|
||||
def construct(self):
|
||||
network_mob = self.network_mob
|
||||
network_mob.scale_to_fit_height(2*SPACE_HEIGHT - 1)
|
||||
for layer in network_mob.layers:
|
||||
layer.neurons.set_stroke(width = 5)
|
||||
|
||||
edge_update = ContinualEdgeUpdate(
|
||||
network_mob,
|
||||
max_stroke_width = 10,
|
||||
stroke_width_exp = 5,
|
||||
stroke_width_exp = 4,
|
||||
)
|
||||
edge_update.internal_time = 3
|
||||
edge_update.update(0)
|
||||
|
Reference in New Issue
Block a user