mirror of
https://github.com/3b1b/manim.git
synced 2025-07-30 21:44:19 +08:00
Up to InterpretFirstWeightMatrixRows in NN part 2
This commit is contained in:
@ -446,11 +446,14 @@ class Mobject(object):
|
|||||||
return self.color
|
return self.color
|
||||||
##
|
##
|
||||||
|
|
||||||
def save_state(self):
|
def save_state(self, use_deepcopy = False):
|
||||||
if hasattr(self, "saved_state"):
|
if hasattr(self, "saved_state"):
|
||||||
#Prevent exponential growth of data
|
#Prevent exponential growth of data
|
||||||
self.saved_state = None
|
self.saved_state = None
|
||||||
self.saved_state = self.copy()
|
if use_deepcopy:
|
||||||
|
self.saved_state = self.deepcopy()
|
||||||
|
else:
|
||||||
|
self.saved_state = self.copy()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def restore(self):
|
def restore(self):
|
||||||
|
@ -21,12 +21,12 @@ import cPickle
|
|||||||
from nn.mnist_loader import load_data_wrapper
|
from nn.mnist_loader import load_data_wrapper
|
||||||
|
|
||||||
NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
|
NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
|
||||||
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_36")
|
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_80")
|
||||||
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU")
|
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU")
|
||||||
PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases")
|
PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases")
|
||||||
IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
|
IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
|
||||||
# PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero"
|
# PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero"
|
||||||
# DEFAULT_LAYER_SIZES = [28**2, 36, 10]
|
# DEFAULT_LAYER_SIZES = [28**2, 80, 10]
|
||||||
DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10]
|
DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10]
|
||||||
|
|
||||||
class Network(object):
|
class Network(object):
|
||||||
|
20
nn/part1.py
20
nn/part1.py
@ -196,19 +196,22 @@ class NetworkMobject(VGroup):
|
|||||||
for l1, l2 in zip(self.layers[:-1], self.layers[1:]):
|
for l1, l2 in zip(self.layers[:-1], self.layers[1:]):
|
||||||
edge_group = VGroup()
|
edge_group = VGroup()
|
||||||
for n1, n2 in it.product(l1.neurons, l2.neurons):
|
for n1, n2 in it.product(l1.neurons, l2.neurons):
|
||||||
edge = Line(
|
edge = self.get_edge(n1, n2)
|
||||||
n1.get_center(),
|
|
||||||
n2.get_center(),
|
|
||||||
buff = self.neuron_radius,
|
|
||||||
stroke_color = self.edge_color,
|
|
||||||
stroke_width = self.edge_stroke_width,
|
|
||||||
)
|
|
||||||
edge_group.add(edge)
|
edge_group.add(edge)
|
||||||
n1.edges_out.add(edge)
|
n1.edges_out.add(edge)
|
||||||
n2.edges_in.add(edge)
|
n2.edges_in.add(edge)
|
||||||
self.edge_groups.add(edge_group)
|
self.edge_groups.add(edge_group)
|
||||||
self.add_to_back(self.edge_groups)
|
self.add_to_back(self.edge_groups)
|
||||||
|
|
||||||
|
def get_edge(self, neuron1, neuron2):
|
||||||
|
return Line(
|
||||||
|
neuron1.get_center(),
|
||||||
|
neuron2.get_center(),
|
||||||
|
buff = self.neuron_radius,
|
||||||
|
stroke_color = self.edge_color,
|
||||||
|
stroke_width = self.edge_stroke_width,
|
||||||
|
)
|
||||||
|
|
||||||
def get_active_layer(self, layer_index, activation_vector):
|
def get_active_layer(self, layer_index, activation_vector):
|
||||||
layer = self.layers[layer_index].deepcopy()
|
layer = self.layers[layer_index].deepcopy()
|
||||||
n_neurons = len(layer.neurons)
|
n_neurons = len(layer.neurons)
|
||||||
@ -2980,6 +2983,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
|
|||||||
"max_stroke_width" : 3,
|
"max_stroke_width" : 3,
|
||||||
"stroke_width_exp" : 7,
|
"stroke_width_exp" : 7,
|
||||||
"n_cycles" : 5,
|
"n_cycles" : 5,
|
||||||
|
"colors" : [GREEN, GREEN, GREEN, RED],
|
||||||
}
|
}
|
||||||
def __init__(self, network_mob, **kwargs):
|
def __init__(self, network_mob, **kwargs):
|
||||||
digest_config(self, kwargs)
|
digest_config(self, kwargs)
|
||||||
@ -2988,7 +2992,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
|
|||||||
self.move_to_targets = []
|
self.move_to_targets = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
edge.colors = [
|
edge.colors = [
|
||||||
random.choice([GREEN, GREEN, GREEN, RED])
|
random.choice(self.colors)
|
||||||
for x in range(n_cycles)
|
for x in range(n_cycles)
|
||||||
]
|
]
|
||||||
msw = self.max_stroke_width
|
msw = self.max_stroke_width
|
||||||
|
1184
nn/part2.py
1184
nn/part2.py
File diff suppressed because it is too large
Load Diff
@ -242,9 +242,13 @@ class Arrow(Line):
|
|||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
args = (points[0]+UP+LEFT, points[0])
|
args = (points[0]+UP+LEFT, points[0])
|
||||||
Line.__init__(self, *args, **kwargs)
|
Line.__init__(self, *args, **kwargs)
|
||||||
self.add_tip()
|
self.init_tip()
|
||||||
if self.use_rectangular_stem and not hasattr(self, "rect"):
|
if self.use_rectangular_stem and not hasattr(self, "rect"):
|
||||||
self.add_rectangular_stem()
|
self.add_rectangular_stem()
|
||||||
|
self.init_colors()
|
||||||
|
|
||||||
|
def init_tip(self):
|
||||||
|
self.tip = self.add_tip()
|
||||||
|
|
||||||
def add_tip(self, add_at_end = True):
|
def add_tip(self, add_at_end = True):
|
||||||
tip = VMobject(
|
tip = VMobject(
|
||||||
@ -253,11 +257,11 @@ class Arrow(Line):
|
|||||||
fill_color = self.color,
|
fill_color = self.color,
|
||||||
fill_opacity = 1,
|
fill_opacity = 1,
|
||||||
stroke_color = self.color,
|
stroke_color = self.color,
|
||||||
|
stroke_width = 0,
|
||||||
)
|
)
|
||||||
self.set_tip_points(tip, add_at_end, preserve_normal = False)
|
self.set_tip_points(tip, add_at_end, preserve_normal = False)
|
||||||
self.tip = tip
|
self.add(tip)
|
||||||
self.add(self.tip)
|
return tip
|
||||||
self.init_colors()
|
|
||||||
|
|
||||||
def add_rectangular_stem(self):
|
def add_rectangular_stem(self):
|
||||||
self.rect = Rectangle(
|
self.rect = Rectangle(
|
||||||
@ -283,6 +287,10 @@ class Arrow(Line):
|
|||||||
self.rectangular_stem_width,
|
self.rectangular_stem_width,
|
||||||
self.max_stem_width_to_tip_width_ratio*tip_base_width,
|
self.max_stem_width_to_tip_width_ratio*tip_base_width,
|
||||||
)
|
)
|
||||||
|
if hasattr(self, "second_tip"):
|
||||||
|
start = center_of_mass(
|
||||||
|
self.second_tip.get_anchors()[1:]
|
||||||
|
)
|
||||||
self.rect.set_points_as_corners([
|
self.rect.set_points_as_corners([
|
||||||
tip_base + perp_vect*width/2,
|
tip_base + perp_vect*width/2,
|
||||||
start + perp_vect*width/2,
|
start + perp_vect*width/2,
|
||||||
@ -319,7 +327,6 @@ class Arrow(Line):
|
|||||||
if np.linalg.norm(v) == 0:
|
if np.linalg.norm(v) == 0:
|
||||||
v[0] = 1
|
v[0] = 1
|
||||||
v *= tip_length/np.linalg.norm(v)
|
v *= tip_length/np.linalg.norm(v)
|
||||||
|
|
||||||
ratio = self.tip_width_to_length_ratio
|
ratio = self.tip_width_to_length_ratio
|
||||||
tip.set_points_as_corners([
|
tip.set_points_as_corners([
|
||||||
end_point,
|
end_point,
|
||||||
@ -374,9 +381,9 @@ class Vector(Arrow):
|
|||||||
Arrow.__init__(self, ORIGIN, direction, **kwargs)
|
Arrow.__init__(self, ORIGIN, direction, **kwargs)
|
||||||
|
|
||||||
class DoubleArrow(Arrow):
|
class DoubleArrow(Arrow):
|
||||||
def __init__(self, *args, **kwargs):
|
def init_tip(self):
|
||||||
Arrow.__init__(self, *args, **kwargs)
|
self.tip = self.add_tip()
|
||||||
self.add_tip(add_at_end = False)
|
self.second_tip = self.add_tip(add_at_end = False)
|
||||||
|
|
||||||
class CubicBezier(VMobject):
|
class CubicBezier(VMobject):
|
||||||
def __init__(self, points, **kwargs):
|
def __init__(self, points, **kwargs):
|
||||||
|
@ -11,9 +11,9 @@ class DecimalNumber(VMobject):
|
|||||||
"num_decimal_points" : 2,
|
"num_decimal_points" : 2,
|
||||||
"digit_to_digit_buff" : 0.05
|
"digit_to_digit_buff" : 0.05
|
||||||
}
|
}
|
||||||
def __init__(self, float_num, **kwargs):
|
def __init__(self, number, **kwargs):
|
||||||
digest_config(self, kwargs)
|
digest_config(self, kwargs, locals())
|
||||||
num_string = '%.*f'%(self.num_decimal_points, float_num)
|
num_string = '%.*f'%(self.num_decimal_points, number)
|
||||||
VMobject.__init__(self, *[
|
VMobject.__init__(self, *[
|
||||||
TexMobject(char)
|
TexMobject(char)
|
||||||
for char in num_string
|
for char in num_string
|
||||||
@ -22,7 +22,7 @@ class DecimalNumber(VMobject):
|
|||||||
buff = self.digit_to_digit_buff,
|
buff = self.digit_to_digit_buff,
|
||||||
aligned_edge = DOWN
|
aligned_edge = DOWN
|
||||||
)
|
)
|
||||||
if float_num < 0:
|
if number < 0:
|
||||||
minus = self.submobjects[0]
|
minus = self.submobjects[0]
|
||||||
minus.next_to(
|
minus.next_to(
|
||||||
self.submobjects[1], LEFT,
|
self.submobjects[1], LEFT,
|
||||||
@ -65,9 +65,9 @@ class ChangingDecimal(Animation):
|
|||||||
|
|
||||||
def update_number(self, alpha):
|
def update_number(self, alpha):
|
||||||
decimal = self.decimal_number
|
decimal = self.decimal_number
|
||||||
|
new_number = self.number_update_func(alpha)
|
||||||
new_decimal = DecimalNumber(
|
new_decimal = DecimalNumber(
|
||||||
self.number_update_func(alpha),
|
new_number, num_decimal_points = self.num_decimal_points
|
||||||
num_decimal_points = self.num_decimal_points
|
|
||||||
)
|
)
|
||||||
new_decimal.replace(decimal, dim_to_match = 1)
|
new_decimal.replace(decimal, dim_to_match = 1)
|
||||||
new_decimal.highlight(decimal.get_color())
|
new_decimal.highlight(decimal.get_color())
|
||||||
@ -78,6 +78,7 @@ class ChangingDecimal(Animation):
|
|||||||
]
|
]
|
||||||
for sm1, sm2 in zip(*families):
|
for sm1, sm2 in zip(*families):
|
||||||
sm1.interpolate(sm1, sm2, 1)
|
sm1.interpolate(sm1, sm2, 1)
|
||||||
|
self.mobject.number = new_number
|
||||||
|
|
||||||
def update_position(self):
|
def update_position(self):
|
||||||
if self.position_update_func is not None:
|
if self.position_update_func is not None:
|
||||||
|
Reference in New Issue
Block a user