mirror of
https://github.com/3b1b/manim.git
synced 2025-07-30 21:44:19 +08:00
Up to InterpretFirstWeightMatrixRows in NN part 2
This commit is contained in:
@ -446,11 +446,14 @@ class Mobject(object):
|
||||
return self.color
|
||||
##
|
||||
|
||||
def save_state(self):
|
||||
def save_state(self, use_deepcopy = False):
|
||||
if hasattr(self, "saved_state"):
|
||||
#Prevent exponential growth of data
|
||||
self.saved_state = None
|
||||
self.saved_state = self.copy()
|
||||
if use_deepcopy:
|
||||
self.saved_state = self.deepcopy()
|
||||
else:
|
||||
self.saved_state = self.copy()
|
||||
return self
|
||||
|
||||
def restore(self):
|
||||
|
@ -21,12 +21,12 @@ import cPickle
|
||||
from nn.mnist_loader import load_data_wrapper
|
||||
|
||||
NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
|
||||
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_36")
|
||||
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_80")
|
||||
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU")
|
||||
PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases")
|
||||
IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
|
||||
# PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero"
|
||||
# DEFAULT_LAYER_SIZES = [28**2, 36, 10]
|
||||
# DEFAULT_LAYER_SIZES = [28**2, 80, 10]
|
||||
DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10]
|
||||
|
||||
class Network(object):
|
||||
|
20
nn/part1.py
20
nn/part1.py
@ -196,19 +196,22 @@ class NetworkMobject(VGroup):
|
||||
for l1, l2 in zip(self.layers[:-1], self.layers[1:]):
|
||||
edge_group = VGroup()
|
||||
for n1, n2 in it.product(l1.neurons, l2.neurons):
|
||||
edge = Line(
|
||||
n1.get_center(),
|
||||
n2.get_center(),
|
||||
buff = self.neuron_radius,
|
||||
stroke_color = self.edge_color,
|
||||
stroke_width = self.edge_stroke_width,
|
||||
)
|
||||
edge = self.get_edge(n1, n2)
|
||||
edge_group.add(edge)
|
||||
n1.edges_out.add(edge)
|
||||
n2.edges_in.add(edge)
|
||||
self.edge_groups.add(edge_group)
|
||||
self.add_to_back(self.edge_groups)
|
||||
|
||||
def get_edge(self, neuron1, neuron2):
|
||||
return Line(
|
||||
neuron1.get_center(),
|
||||
neuron2.get_center(),
|
||||
buff = self.neuron_radius,
|
||||
stroke_color = self.edge_color,
|
||||
stroke_width = self.edge_stroke_width,
|
||||
)
|
||||
|
||||
def get_active_layer(self, layer_index, activation_vector):
|
||||
layer = self.layers[layer_index].deepcopy()
|
||||
n_neurons = len(layer.neurons)
|
||||
@ -2980,6 +2983,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
|
||||
"max_stroke_width" : 3,
|
||||
"stroke_width_exp" : 7,
|
||||
"n_cycles" : 5,
|
||||
"colors" : [GREEN, GREEN, GREEN, RED],
|
||||
}
|
||||
def __init__(self, network_mob, **kwargs):
|
||||
digest_config(self, kwargs)
|
||||
@ -2988,7 +2992,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
|
||||
self.move_to_targets = []
|
||||
for edge in edges:
|
||||
edge.colors = [
|
||||
random.choice([GREEN, GREEN, GREEN, RED])
|
||||
random.choice(self.colors)
|
||||
for x in range(n_cycles)
|
||||
]
|
||||
msw = self.max_stroke_width
|
||||
|
1184
nn/part2.py
1184
nn/part2.py
File diff suppressed because it is too large
Load Diff
@ -242,9 +242,13 @@ class Arrow(Line):
|
||||
if len(args) == 1:
|
||||
args = (points[0]+UP+LEFT, points[0])
|
||||
Line.__init__(self, *args, **kwargs)
|
||||
self.add_tip()
|
||||
self.init_tip()
|
||||
if self.use_rectangular_stem and not hasattr(self, "rect"):
|
||||
self.add_rectangular_stem()
|
||||
self.init_colors()
|
||||
|
||||
def init_tip(self):
|
||||
self.tip = self.add_tip()
|
||||
|
||||
def add_tip(self, add_at_end = True):
|
||||
tip = VMobject(
|
||||
@ -253,11 +257,11 @@ class Arrow(Line):
|
||||
fill_color = self.color,
|
||||
fill_opacity = 1,
|
||||
stroke_color = self.color,
|
||||
stroke_width = 0,
|
||||
)
|
||||
self.set_tip_points(tip, add_at_end, preserve_normal = False)
|
||||
self.tip = tip
|
||||
self.add(self.tip)
|
||||
self.init_colors()
|
||||
self.add(tip)
|
||||
return tip
|
||||
|
||||
def add_rectangular_stem(self):
|
||||
self.rect = Rectangle(
|
||||
@ -283,6 +287,10 @@ class Arrow(Line):
|
||||
self.rectangular_stem_width,
|
||||
self.max_stem_width_to_tip_width_ratio*tip_base_width,
|
||||
)
|
||||
if hasattr(self, "second_tip"):
|
||||
start = center_of_mass(
|
||||
self.second_tip.get_anchors()[1:]
|
||||
)
|
||||
self.rect.set_points_as_corners([
|
||||
tip_base + perp_vect*width/2,
|
||||
start + perp_vect*width/2,
|
||||
@ -319,7 +327,6 @@ class Arrow(Line):
|
||||
if np.linalg.norm(v) == 0:
|
||||
v[0] = 1
|
||||
v *= tip_length/np.linalg.norm(v)
|
||||
|
||||
ratio = self.tip_width_to_length_ratio
|
||||
tip.set_points_as_corners([
|
||||
end_point,
|
||||
@ -374,9 +381,9 @@ class Vector(Arrow):
|
||||
Arrow.__init__(self, ORIGIN, direction, **kwargs)
|
||||
|
||||
class DoubleArrow(Arrow):
|
||||
def __init__(self, *args, **kwargs):
|
||||
Arrow.__init__(self, *args, **kwargs)
|
||||
self.add_tip(add_at_end = False)
|
||||
def init_tip(self):
|
||||
self.tip = self.add_tip()
|
||||
self.second_tip = self.add_tip(add_at_end = False)
|
||||
|
||||
class CubicBezier(VMobject):
|
||||
def __init__(self, points, **kwargs):
|
||||
|
@ -11,9 +11,9 @@ class DecimalNumber(VMobject):
|
||||
"num_decimal_points" : 2,
|
||||
"digit_to_digit_buff" : 0.05
|
||||
}
|
||||
def __init__(self, float_num, **kwargs):
|
||||
digest_config(self, kwargs)
|
||||
num_string = '%.*f'%(self.num_decimal_points, float_num)
|
||||
def __init__(self, number, **kwargs):
|
||||
digest_config(self, kwargs, locals())
|
||||
num_string = '%.*f'%(self.num_decimal_points, number)
|
||||
VMobject.__init__(self, *[
|
||||
TexMobject(char)
|
||||
for char in num_string
|
||||
@ -22,7 +22,7 @@ class DecimalNumber(VMobject):
|
||||
buff = self.digit_to_digit_buff,
|
||||
aligned_edge = DOWN
|
||||
)
|
||||
if float_num < 0:
|
||||
if number < 0:
|
||||
minus = self.submobjects[0]
|
||||
minus.next_to(
|
||||
self.submobjects[1], LEFT,
|
||||
@ -65,9 +65,9 @@ class ChangingDecimal(Animation):
|
||||
|
||||
def update_number(self, alpha):
|
||||
decimal = self.decimal_number
|
||||
new_number = self.number_update_func(alpha)
|
||||
new_decimal = DecimalNumber(
|
||||
self.number_update_func(alpha),
|
||||
num_decimal_points = self.num_decimal_points
|
||||
new_number, num_decimal_points = self.num_decimal_points
|
||||
)
|
||||
new_decimal.replace(decimal, dim_to_match = 1)
|
||||
new_decimal.highlight(decimal.get_color())
|
||||
@ -78,6 +78,7 @@ class ChangingDecimal(Animation):
|
||||
]
|
||||
for sm1, sm2 in zip(*families):
|
||||
sm1.interpolate(sm1, sm2, 1)
|
||||
self.mobject.number = new_number
|
||||
|
||||
def update_position(self):
|
||||
if self.position_update_func is not None:
|
||||
|
Reference in New Issue
Block a user