Up to InterpretFirstWeightMatrixRows in NN part 2

This commit is contained in:
Grant Sanderson
2017-10-12 17:38:25 -07:00
parent fd59591000
commit d138ffd353
6 changed files with 1210 additions and 41 deletions

View File

@ -446,11 +446,14 @@ class Mobject(object):
return self.color return self.color
## ##
def save_state(self): def save_state(self, use_deepcopy = False):
if hasattr(self, "saved_state"): if hasattr(self, "saved_state"):
#Prevent exponential growth of data #Prevent exponential growth of data
self.saved_state = None self.saved_state = None
self.saved_state = self.copy() if use_deepcopy:
self.saved_state = self.deepcopy()
else:
self.saved_state = self.copy()
return self return self
def restore(self): def restore(self):

View File

@ -21,12 +21,12 @@ import cPickle
from nn.mnist_loader import load_data_wrapper from nn.mnist_loader import load_data_wrapper
NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__)) NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_36") # PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_80")
# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU") # PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU")
PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases") PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases")
IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map") IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
# PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero" # PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero"
# DEFAULT_LAYER_SIZES = [28**2, 36, 10] # DEFAULT_LAYER_SIZES = [28**2, 80, 10]
DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10] DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10]
class Network(object): class Network(object):

View File

@ -196,19 +196,22 @@ class NetworkMobject(VGroup):
for l1, l2 in zip(self.layers[:-1], self.layers[1:]): for l1, l2 in zip(self.layers[:-1], self.layers[1:]):
edge_group = VGroup() edge_group = VGroup()
for n1, n2 in it.product(l1.neurons, l2.neurons): for n1, n2 in it.product(l1.neurons, l2.neurons):
edge = Line( edge = self.get_edge(n1, n2)
n1.get_center(),
n2.get_center(),
buff = self.neuron_radius,
stroke_color = self.edge_color,
stroke_width = self.edge_stroke_width,
)
edge_group.add(edge) edge_group.add(edge)
n1.edges_out.add(edge) n1.edges_out.add(edge)
n2.edges_in.add(edge) n2.edges_in.add(edge)
self.edge_groups.add(edge_group) self.edge_groups.add(edge_group)
self.add_to_back(self.edge_groups) self.add_to_back(self.edge_groups)
def get_edge(self, neuron1, neuron2):
return Line(
neuron1.get_center(),
neuron2.get_center(),
buff = self.neuron_radius,
stroke_color = self.edge_color,
stroke_width = self.edge_stroke_width,
)
def get_active_layer(self, layer_index, activation_vector): def get_active_layer(self, layer_index, activation_vector):
layer = self.layers[layer_index].deepcopy() layer = self.layers[layer_index].deepcopy()
n_neurons = len(layer.neurons) n_neurons = len(layer.neurons)
@ -2980,6 +2983,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
"max_stroke_width" : 3, "max_stroke_width" : 3,
"stroke_width_exp" : 7, "stroke_width_exp" : 7,
"n_cycles" : 5, "n_cycles" : 5,
"colors" : [GREEN, GREEN, GREEN, RED],
} }
def __init__(self, network_mob, **kwargs): def __init__(self, network_mob, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
@ -2988,7 +2992,7 @@ class ContinualEdgeUpdate(ContinualAnimation):
self.move_to_targets = [] self.move_to_targets = []
for edge in edges: for edge in edges:
edge.colors = [ edge.colors = [
random.choice([GREEN, GREEN, GREEN, RED]) random.choice(self.colors)
for x in range(n_cycles) for x in range(n_cycles)
] ]
msw = self.max_stroke_width msw = self.max_stroke_width

File diff suppressed because it is too large Load Diff

View File

@ -242,9 +242,13 @@ class Arrow(Line):
if len(args) == 1: if len(args) == 1:
args = (points[0]+UP+LEFT, points[0]) args = (points[0]+UP+LEFT, points[0])
Line.__init__(self, *args, **kwargs) Line.__init__(self, *args, **kwargs)
self.add_tip() self.init_tip()
if self.use_rectangular_stem and not hasattr(self, "rect"): if self.use_rectangular_stem and not hasattr(self, "rect"):
self.add_rectangular_stem() self.add_rectangular_stem()
self.init_colors()
def init_tip(self):
self.tip = self.add_tip()
def add_tip(self, add_at_end = True): def add_tip(self, add_at_end = True):
tip = VMobject( tip = VMobject(
@ -253,11 +257,11 @@ class Arrow(Line):
fill_color = self.color, fill_color = self.color,
fill_opacity = 1, fill_opacity = 1,
stroke_color = self.color, stroke_color = self.color,
stroke_width = 0,
) )
self.set_tip_points(tip, add_at_end, preserve_normal = False) self.set_tip_points(tip, add_at_end, preserve_normal = False)
self.tip = tip self.add(tip)
self.add(self.tip) return tip
self.init_colors()
def add_rectangular_stem(self): def add_rectangular_stem(self):
self.rect = Rectangle( self.rect = Rectangle(
@ -283,6 +287,10 @@ class Arrow(Line):
self.rectangular_stem_width, self.rectangular_stem_width,
self.max_stem_width_to_tip_width_ratio*tip_base_width, self.max_stem_width_to_tip_width_ratio*tip_base_width,
) )
if hasattr(self, "second_tip"):
start = center_of_mass(
self.second_tip.get_anchors()[1:]
)
self.rect.set_points_as_corners([ self.rect.set_points_as_corners([
tip_base + perp_vect*width/2, tip_base + perp_vect*width/2,
start + perp_vect*width/2, start + perp_vect*width/2,
@ -319,7 +327,6 @@ class Arrow(Line):
if np.linalg.norm(v) == 0: if np.linalg.norm(v) == 0:
v[0] = 1 v[0] = 1
v *= tip_length/np.linalg.norm(v) v *= tip_length/np.linalg.norm(v)
ratio = self.tip_width_to_length_ratio ratio = self.tip_width_to_length_ratio
tip.set_points_as_corners([ tip.set_points_as_corners([
end_point, end_point,
@ -374,9 +381,9 @@ class Vector(Arrow):
Arrow.__init__(self, ORIGIN, direction, **kwargs) Arrow.__init__(self, ORIGIN, direction, **kwargs)
class DoubleArrow(Arrow): class DoubleArrow(Arrow):
def __init__(self, *args, **kwargs): def init_tip(self):
Arrow.__init__(self, *args, **kwargs) self.tip = self.add_tip()
self.add_tip(add_at_end = False) self.second_tip = self.add_tip(add_at_end = False)
class CubicBezier(VMobject): class CubicBezier(VMobject):
def __init__(self, points, **kwargs): def __init__(self, points, **kwargs):

View File

@ -11,9 +11,9 @@ class DecimalNumber(VMobject):
"num_decimal_points" : 2, "num_decimal_points" : 2,
"digit_to_digit_buff" : 0.05 "digit_to_digit_buff" : 0.05
} }
def __init__(self, float_num, **kwargs): def __init__(self, number, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs, locals())
num_string = '%.*f'%(self.num_decimal_points, float_num) num_string = '%.*f'%(self.num_decimal_points, number)
VMobject.__init__(self, *[ VMobject.__init__(self, *[
TexMobject(char) TexMobject(char)
for char in num_string for char in num_string
@ -22,7 +22,7 @@ class DecimalNumber(VMobject):
buff = self.digit_to_digit_buff, buff = self.digit_to_digit_buff,
aligned_edge = DOWN aligned_edge = DOWN
) )
if float_num < 0: if number < 0:
minus = self.submobjects[0] minus = self.submobjects[0]
minus.next_to( minus.next_to(
self.submobjects[1], LEFT, self.submobjects[1], LEFT,
@ -65,9 +65,9 @@ class ChangingDecimal(Animation):
def update_number(self, alpha): def update_number(self, alpha):
decimal = self.decimal_number decimal = self.decimal_number
new_number = self.number_update_func(alpha)
new_decimal = DecimalNumber( new_decimal = DecimalNumber(
self.number_update_func(alpha), new_number, num_decimal_points = self.num_decimal_points
num_decimal_points = self.num_decimal_points
) )
new_decimal.replace(decimal, dim_to_match = 1) new_decimal.replace(decimal, dim_to_match = 1)
new_decimal.highlight(decimal.get_color()) new_decimal.highlight(decimal.get_color())
@ -78,6 +78,7 @@ class ChangingDecimal(Animation):
] ]
for sm1, sm2 in zip(*families): for sm1, sm2 in zip(*families):
sm1.interpolate(sm1, sm2, 1) sm1.interpolate(sm1, sm2, 1)
self.mobject.number = new_number
def update_position(self): def update_position(self):
if self.position_update_func is not None: if self.position_update_func is not None: