import numpy as np

from animation.creation import ShowCreation
from animation.creation import Write
from animation.transform import ApplyFunction
from animation.transform import ApplyMethod
from animation.transform import ApplyPointwiseFunction
from animation.creation import FadeOut
from animation.transform import Transform
from mobject.mobject import Mobject
from mobject.svg.tex_mobject import TexMobject
from mobject.svg.tex_mobject import TextMobject
from mobject.types.vectorized_mobject import VGroup
from mobject.types.vectorized_mobject import VMobject
from scene.scene import Scene
from mobject.geometry import Arrow
from mobject.shape_matchers import BackgroundRectangle
from mobject.geometry import Circle
from mobject.geometry import Dot
from mobject.geometry import Line
from mobject.geometry import Vector
from topics.number_line import Axes
from topics.number_line import NumberPlane

from constants import *

# Scale factor applied to the coordinate column built by vector_coordinate_label.
VECTOR_LABEL_SCALE_FACTOR = 0.8


def _to_2d_string_array(matrix):
    """Return ``matrix`` as a 2-D numpy array of strings.

    A 1-D input is treated as a column vector, i.e. reshaped to (size, 1).
    ``astype(str)`` is used rather than the deprecated ``"string"`` alias,
    which produces bytes (not text) under Python 3.
    """
    matrix = np.array(matrix).astype(str)
    if matrix.ndim == 1:
        matrix = matrix.reshape((matrix.size, 1))
    return matrix


def matrix_to_tex_string(matrix):
    """Return a TeX string rendering ``matrix`` as a bracketed ``array``.

    Entries are converted with ``str``; a 1-D input becomes a column vector.
    Columns are centred (one ``c`` column spec per column).
    """
    matrix = _to_2d_string_array(matrix)
    n_cols = matrix.shape[1]
    prefix = "\\left[ \\begin{array}{%s}" % ("c" * n_cols)
    suffix = "\\end{array} \\right]"
    rows = [" & ".join(row) for row in matrix]
    return prefix + " \\\\ ".join(rows) + suffix


def matrix_to_mobject(matrix):
    """Return a TexMobject rendering ``matrix`` via matrix_to_tex_string."""
    return TexMobject(matrix_to_tex_string(matrix))


def vector_coordinate_label(vector_mob, integer_labels = True,
                            n_dim = 2, color = WHITE):
    """Build a column-matrix label showing the coordinates of a vector's tip.

    vector_mob:     mobject whose ``get_end()`` gives the tip position.
    integer_labels: if True, coordinates are rounded to ints before display.
    n_dim:          number of leading coordinates to show.
    color:          color applied to the label.

    The label is placed beside the tip: to its right when the tip's x
    coordinate is >= 0, otherwise to its left. A BackgroundRectangle is
    stored on ``label.rect`` and added behind the label.
    """
    vect = np.array(vector_mob.get_end())
    if integer_labels:
        vect = np.round(vect).astype(int)
    vect = vect[:n_dim]
    vect = vect.reshape((n_dim, 1))
    label = Matrix(vect, add_background_rectangles = True)
    label.scale(VECTOR_LABEL_SCALE_FACTOR)

    shift_dir = np.array(vector_mob.get_end())
    if shift_dir[0] >= 0:  # Pointing right: put the label's left edge at the tip.
        shift_dir -= label.get_left() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER*LEFT
    else:  # Pointing left: put the label's right edge at the tip.
        shift_dir -= label.get_right() + DEFAULT_MOBJECT_TO_MOBJECT_BUFFER*RIGHT
    label.shift(shift_dir)
    label.set_color(color)
    label.rect = BackgroundRectangle(label)
    label.add_to_back(label.rect)
    return label


class Matrix(VMobject):
    """A grid of mobjects laid out in rows/columns and wrapped in brackets."""
    CONFIG = {
        "v_buff" : 0.5,
        "h_buff" : 1,
        "add_background_rectangles" : False
    }

    def __init__(self, matrix, **kwargs):
        """Entries may be numbers, TeX strings, or mobjects.

        A 1-D input is treated as a column vector. Non-mobject entries are
        converted to strings and rendered as TexMobjects. The resulting 2-D
        grid of entry mobjects is kept on ``self.mob_matrix``.
        """
        VMobject.__init__(self, **kwargs)
        matrix = np.array(matrix)
        if matrix.ndim == 1:
            matrix = matrix.reshape((matrix.size, 1))
        # Only the top-left entry is inspected to decide whether the whole
        # grid already consists of mobjects.
        if not isinstance(matrix[0][0], Mobject):
            matrix = matrix.astype(str)
            matrix = self.string_matrix_to_mob_matrix(matrix)
        self.organize_mob_matrix(matrix)
        self.add(*matrix.flatten())
        self.add_brackets()
        self.center()
        self.mob_matrix = matrix
        if self.add_background_rectangles:
            for mob in matrix.flatten():
                mob.add_background_rectangle()

    def string_matrix_to_mob_matrix(self, matrix):
        """Return an object array of TexMobjects with the same shape as ``matrix``.

        The array is allocated with ``dtype=object`` and filled cell by cell so
        numpy never tries to treat a mobject as a nested sequence.
        """
        mob_matrix = np.empty(matrix.shape, dtype=object)
        for index in np.ndindex(*matrix.shape):
            mob_matrix[index] = TexMobject(matrix[index])
        return mob_matrix

    def organize_mob_matrix(self, matrix):
        """Position entries in a grid anchored at the top-left entry.

        First-row entries go to the right of their left neighbour (``h_buff``);
        every other entry goes below the entry above it (``v_buff``).
        """
        for i, row in enumerate(matrix):
            for j, mob in enumerate(row):
                if i == 0 and j == 0:
                    continue
                elif i == 0:
                    mob.next_to(matrix[i][j-1], RIGHT, self.h_buff)
                else:
                    mob.next_to(matrix[i-1][j], DOWN, self.v_buff)
        return self

    def add_brackets(self):
        """Add left/right brackets sized to the grid; store them on ``self.brackets``."""
        bracket_pair = TexMobject("\\big[ \\big]")
        bracket_pair.scale(2)
        bracket_pair.stretch_to_fit_height(self.get_height() + 0.5)
        l_bracket, r_bracket = bracket_pair.split()
        l_bracket.next_to(self, LEFT)
        r_bracket.next_to(self, RIGHT)
        self.add(l_bracket, r_bracket)
        self.brackets = VMobject(l_bracket, r_bracket)
        return self

    def set_color_columns(self, *colors):
        """Color column i with ``colors[i]``, for as many colors as are given."""
        for i, color in enumerate(colors):
            VMobject(*self.mob_matrix[:, i]).set_color(color)
        return self

    def add_background_to_entries(self):
        """Add a background rectangle behind each entry."""
        for mob in self.get_entries():
            mob.add_background_rectangle()
        return self

    def get_mob_matrix(self):
        """Return the 2-D array of entry mobjects."""
        return self.mob_matrix

    def get_entries(self):
        """Return all entries, in row-major order, grouped in one VMobject."""
        return VMobject(*self.get_mob_matrix().flatten())

    def get_brackets(self):
        """Return the VMobject holding the left and right brackets."""
        return self.brackets


class NumericalMatrixMultiplication(Scene):
    """Scene animating left_matrix * right_matrix, one product term at a time."""
    CONFIG = {
        "left_matrix" : [[1, 2], [3, 4]],
        "right_matrix" : [[5, 6], [7, 8]],
        "use_parens" : True,
    }

    def construct(self):
        """Lay out ``left right = result`` and animate the multiplication.

        1-D configs are treated as column vectors, as Matrix does.
        Raises ValueError if the inner dimensions do not agree.
        """
        left_string_matrix, right_string_matrix = [
            _to_2d_string_array(matrix)
            for matrix in (self.left_matrix, self.right_matrix)
        ]
        if right_string_matrix.shape[0] != left_string_matrix.shape[1]:
            raise ValueError("Incompatible shapes for matrix multiplication")

        left = Matrix(left_string_matrix)
        right = Matrix(right_string_matrix)
        result = self.get_result_matrix(
            left_string_matrix, right_string_matrix
        )

        self.organize_matrices(left, right, result)
        self.animate_product(left, right, result)

    def get_result_matrix(self, left, right):
        """Build a Matrix whose (a, b) entry spells out row a times column b.

        ``left`` and ``right`` are 2-D string arrays of shapes (m, k) and (k, n).
        Each entry is a TexMobject made from k parts, one per product term, e.g.
        ``(1)(5)``, ``+(2)(7)``; animate_product relies on ``entry.split()[c]``
        being the c-th of those terms.
        """
        (m, k), n = left.shape, right.shape[1]
        template = "(%s)(%s)" if self.use_parens else "%s%s"
        mob_matrix = np.empty((m, n), dtype=object)
        for a in range(m):
            for b in range(n):
                parts = [
                    ("" if c == 0 else "+") + template % (left[a][c], right[c][b])
                    for c in range(k)
                ]
                mob_matrix[a, b] = TexMobject(parts, next_to_buff = 0.1)
        return Matrix(mob_matrix)

    def add_lines(self, left, right):
        """Draw separator lines between rows of ``left`` and columns of ``right``."""
        line_kwargs = {
            "color" : BLUE,
            "stroke_width" : 2,
        }
        left_rows = [
            VMobject(*row) for row in left.get_mob_matrix()
        ]
        h_lines = VMobject()
        for row in left_rows[:-1]:
            h_line = Line(row.get_left(), row.get_right(), **line_kwargs)
            h_line.next_to(row, DOWN, buff = left.v_buff/2.)
            h_lines.add(h_line)

        right_cols = [
            VMobject(*col) for col in np.transpose(right.get_mob_matrix())
        ]
        v_lines = VMobject()
        for col in right_cols[:-1]:
            v_line = Line(col.get_top(), col.get_bottom(), **line_kwargs)
            v_line.next_to(col, RIGHT, buff = right.h_buff/2.)
            v_lines.add(v_line)

        self.play(ShowCreation(h_lines))
        self.play(ShowCreation(v_lines))
        self.wait()
        self.show_frame()

    def organize_matrices(self, left, right, result):
        """Arrange ``left right = result`` in a row fitted to the frame width."""
        equals = TexMobject("=")
        everything = VMobject(left, right, equals, result)
        everything.arrange_submobjects()
        everything.scale_to_fit_width(FRAME_WIDTH-1)
        self.add(everything)

    def animate_product(self, left, right, result):
        """Animate each product term flying from the operands into the result.

        Result entries start BLACK. For each (a, b), row a of ``left`` and
        column b of ``right`` are highlighted YELLOW. For each term c, circles
        move onto the two operand entries, copies of those entries transform
        into term c of result[a][b], and that term turns YELLOW. The previous
        term is faded back to WHITE during the next animation (via
        ``lagging_anims``).
        """
        l_matrix = left.get_mob_matrix()
        r_matrix = right.get_mob_matrix()
        result_matrix = result.get_mob_matrix()
        circle = Circle(
            radius = l_matrix[0][0].get_height(),
            color = GREEN
        )
        circles = VMobject(*[
            entry.get_point_mobject()
            for entry in (l_matrix[0][0], r_matrix[0][0])
        ])
        (m, k), n = l_matrix.shape, r_matrix.shape[1]
        for mob in result_matrix.flatten():
            mob.set_color(BLACK)
        lagging_anims = []
        for a in range(m):
            for b in range(n):
                for c in range(k):
                    l_matrix[a][c].set_color(YELLOW)
                    r_matrix[c][b].set_color(YELLOW)
                for c in range(k):
                    start_parts = VMobject(
                        l_matrix[a][c].copy(),
                        r_matrix[c][b].copy()
                    )
                    result_entry = result_matrix[a][b].split()[c]

                    new_circles = VMobject(*[
                        circle.copy().shift(part.get_center())
                        for part in start_parts.split()
                    ])
                    self.play(Transform(circles, new_circles))
                    self.play(
                        Transform(
                            start_parts,
                            result_entry.copy().set_color(YELLOW),
                            path_arc = -np.pi/2,
                            submobject_mode = "all_at_once",
                        ),
                        *lagging_anims
                    )
                    result_entry.set_color(YELLOW)
                    self.remove(start_parts)
                    lagging_anims = [
                        ApplyMethod(result_entry.set_color, WHITE)
                    ]

                for c in range(k):
                    l_matrix[a][c].set_color(WHITE)
                    r_matrix[c][b].set_color(WHITE)
        self.play(FadeOut(circles), *lagging_anims)
        self.wait()