Posterior update animations in eop/bayes

This commit is contained in:
Grant Sanderson
2017-06-07 14:34:39 -07:00
parent f6340f42d7
commit a21c96991f
4 changed files with 684 additions and 317 deletions

View File

@ -3,7 +3,7 @@ from helpers import *
from scene import Scene
from animation.animation import Animation
from animation.transform import Transform
from animation.transform import Transform, MoveToTarget
from mobject import Mobject
from mobject.vectorized_mobject import VGroup, VMobject, VectorizedPoint
@ -22,30 +22,53 @@ class SampleSpaceScene(Scene):
def add_sample_space(self, **config):
self.add(self.get_sample_space(**config))
def change_horizontal_division(self, p_list, **kwargs):
assert(hasattr(self.sample_space, "horizontal_parts"))
added_anims = kwargs.pop("added_anims", [])
new_division_kwargs = kwargs.pop("new_division_kwargs", {})
added_label_kwargs = kwargs.pop("label_kwargs", {})
def get_division_change_animations(
self, sample_space, parts, p_list,
dimension = 1,
new_label_kwargs = None,
**kwargs
):
if new_label_kwargs is None:
new_label_kwargs = {}
anims = []
p_list = sample_space.complete_p_list(p_list)
full_space = sample_space.full_space
curr_parts = self.sample_space.horizontal_parts
new_division_kwargs["colors"] = [
part.get_color() for part in curr_parts
]
new_parts = self.sample_space.get_horizontal_division(
p_list, **new_division_kwargs
)
anims = [Transform(curr_parts, new_parts)]
if hasattr(curr_parts, "labels"):
label_kwargs = curr_parts.label_kwargs
label_kwargs.update(added_label_kwargs)
new_labels = self.sample_space.get_subdivision_labels(
new_parts, **label_kwargs
vect = DOWN if dimension == 1 else RIGHT
parts.generate_target()
for part, p in zip(parts.target, p_list):
part.replace(full_space, stretch = True)
part.stretch(p, dimension)
parts.target.arrange_submobjects(vect, buff = 0)
parts.target.move_to(full_space)
anims.append(MoveToTarget(parts))
if hasattr(parts, "labels"):
label_kwargs = parts.label_kwargs
label_kwargs.update(new_label_kwargs)
new_braces, new_labels = sample_space.get_subdivision_braces_and_labels(
parts.target, **label_kwargs
)
anims.append(Transform(curr_parts.labels, new_labels))
anims += added_anims
anims += [
Transform(parts.braces, new_braces),
Transform(parts.labels, new_labels),
]
return anims
self.play(*anims, **kwargs)
def get_horizontal_division_change_animations(self, p_list, **kwargs):
assert(hasattr(self.sample_space, "horizontal_parts"))
return self.get_division_change_animations(
self.sample_space, self.sample_space.horizontal_parts, p_list,
dimension = 1,
**kwargs
)
def get_vertical_division_change_animations(self, p_list, **kwargs):
assert(hasattr(self.sample_space, "vertical_parts"))
return self.get_division_change_animations(
self.sample_space, self.sample_space.vertical_parts, p_list,
dimension = 0,
**kwargs
)
class SampleSpace(VGroup):
@ -55,7 +78,8 @@ class SampleSpace(VGroup):
"width" : 3,
"fill_color" : DARK_GREY,
"fill_opacity" : 0.8,
"stroke_width" : 0,
"stroke_width" : 0.5,
"stroke_color" : LIGHT_GREY,
},
"default_label_scale_val" : 0.7,
}
@ -76,12 +100,16 @@ class SampleSpace(VGroup):
def add_label(self, label):
self.label = label
def complete_p_list(self, p_list):
new_p_list = list(tuplify(p_list))
remainder = 1.0 - sum(new_p_list)
if abs(remainder) > EPSILON:
new_p_list.append(remainder)
return new_p_list
def get_division_along_dimension(self, p_list, dim, colors, vect):
p_list = list(tuplify(p_list))
if abs(1.0 - sum(p_list)) > EPSILON:
p_list.append(1.0 - sum(p_list))
p_list = self.complete_p_list(p_list)
colors = color_gradient(colors, len(p_list))
perp_dim = 1-dim
last_point = self.full_space.get_edge_center(-vect)
parts = VGroup()
@ -89,7 +117,7 @@ class SampleSpace(VGroup):
part = SampleSpace()
part.set_fill(color, 1)
part.replace(self.full_space, stretch = True)
part.stretch(factor, perp_dim)
part.stretch(factor, dim)
part.move_to(last_point, -vect)
last_point = part.get_edge_center(vect)
parts.add(part)
@ -100,14 +128,14 @@ class SampleSpace(VGroup):
colors = [GREEN_E, BLUE],
vect = DOWN
):
return self.get_division_along_dimension(p_list, 0, colors, vect)
return self.get_division_along_dimension(p_list, 1, colors, vect)
def get_vertical_division(
self, p_list,
colors = [MAROON_B, YELLOW],
vect = RIGHT
):
return self.get_division_along_dimension(p_list, 1, colors, vect)
return self.get_division_along_dimension(p_list, 0, colors, vect)
def divide_horizontally(self, *args, **kwargs):
self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
@ -117,38 +145,53 @@ class SampleSpace(VGroup):
self.vertical_parts = self.get_vertical_division(*args, **kwargs)
self.add(self.vertical_parts)
def get_subdivision_labels(self, parts, labels, direction, buff = SMALL_BUFF):
def get_subdivision_braces_and_labels(self, parts, labels, direction, buff = SMALL_BUFF):
label_brace_groups = VGroup()
label_mobs = VGroup()
braces = VGroup()
for label, part in zip(labels, parts):
brace = Brace(part, direction, min_num_quads = 1, buff = buff)
label_mob = TexMobject(label)
label_mob.scale(self.default_label_scale_val)
if isinstance(label, Mobject):
label_mob = label
else:
label_mob = TexMobject(label)
label_mob.scale(self.default_label_scale_val)
label_mob.next_to(brace, direction, buff)
full_label = VGroup(brace, label_mob)
part.add_label(full_label)
label_brace_groups.add(full_label)
parts.labels = label_brace_groups
braces.add(brace)
label_mobs.add(label_mob)
parts.braces = braces
parts.labels = label_mobs
parts.label_kwargs = {
"labels" : labels,
"direction" : direction,
"buff" : buff,
}
return label_brace_groups
return VGroup(parts.braces, parts.labels)
def get_side_labels(self, labels, direction = LEFT, **kwargs):
def get_side_braces_and_labels(self, labels, direction = LEFT, **kwargs):
assert(hasattr(self, "horizontal_parts"))
parts = self.horizontal_parts
return self.get_subdivision_labels(parts, labels, direction, **kwargs)
return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs)
def get_top_labels(self, labels, **kwargs):
def get_top_braces_and_labels(self, labels, **kwargs):
assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts
return self.get_subdivision_labels(parts, labels, UP, **kwargs)
return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)
def get_bototm_labels(self, labels, **kwargs):
def get_bottom_braces_and_labels(self, labels, **kwargs):
assert(hasattr(self, "vertical_parts"))
parts = self.vertical_parts
return self.get_subdivision_labels(parts, labels, DOWN, **kwargs)
return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)
def add_braces_and_labels(self):
for attr in "horizontal_parts", "vertical_parts":
if not hasattr(self, attr):
continue
parts = getattr(self, attr)
for subattr in "braces", "labels":
if hasattr(parts, subattr):
self.add(getattr(parts, subattr))
### Cards ###