mirror of
https://github.com/3b1b/manim.git
synced 2025-08-01 17:29:06 +08:00
Posterior update animations in eop/bayes
This commit is contained in:
@ -3,7 +3,7 @@ from helpers import *
|
||||
from scene import Scene
|
||||
|
||||
from animation.animation import Animation
|
||||
from animation.transform import Transform
|
||||
from animation.transform import Transform, MoveToTarget
|
||||
|
||||
from mobject import Mobject
|
||||
from mobject.vectorized_mobject import VGroup, VMobject, VectorizedPoint
|
||||
@ -22,30 +22,53 @@ class SampleSpaceScene(Scene):
|
||||
def add_sample_space(self, **config):
|
||||
self.add(self.get_sample_space(**config))
|
||||
|
||||
def change_horizontal_division(self, p_list, **kwargs):
|
||||
assert(hasattr(self.sample_space, "horizontal_parts"))
|
||||
added_anims = kwargs.pop("added_anims", [])
|
||||
new_division_kwargs = kwargs.pop("new_division_kwargs", {})
|
||||
added_label_kwargs = kwargs.pop("label_kwargs", {})
|
||||
def get_division_change_animations(
|
||||
self, sample_space, parts, p_list,
|
||||
dimension = 1,
|
||||
new_label_kwargs = None,
|
||||
**kwargs
|
||||
):
|
||||
if new_label_kwargs is None:
|
||||
new_label_kwargs = {}
|
||||
anims = []
|
||||
p_list = sample_space.complete_p_list(p_list)
|
||||
full_space = sample_space.full_space
|
||||
|
||||
curr_parts = self.sample_space.horizontal_parts
|
||||
new_division_kwargs["colors"] = [
|
||||
part.get_color() for part in curr_parts
|
||||
]
|
||||
new_parts = self.sample_space.get_horizontal_division(
|
||||
p_list, **new_division_kwargs
|
||||
)
|
||||
anims = [Transform(curr_parts, new_parts)]
|
||||
if hasattr(curr_parts, "labels"):
|
||||
label_kwargs = curr_parts.label_kwargs
|
||||
label_kwargs.update(added_label_kwargs)
|
||||
new_labels = self.sample_space.get_subdivision_labels(
|
||||
new_parts, **label_kwargs
|
||||
vect = DOWN if dimension == 1 else RIGHT
|
||||
parts.generate_target()
|
||||
for part, p in zip(parts.target, p_list):
|
||||
part.replace(full_space, stretch = True)
|
||||
part.stretch(p, dimension)
|
||||
parts.target.arrange_submobjects(vect, buff = 0)
|
||||
parts.target.move_to(full_space)
|
||||
anims.append(MoveToTarget(parts))
|
||||
if hasattr(parts, "labels"):
|
||||
label_kwargs = parts.label_kwargs
|
||||
label_kwargs.update(new_label_kwargs)
|
||||
new_braces, new_labels = sample_space.get_subdivision_braces_and_labels(
|
||||
parts.target, **label_kwargs
|
||||
)
|
||||
anims.append(Transform(curr_parts.labels, new_labels))
|
||||
anims += added_anims
|
||||
anims += [
|
||||
Transform(parts.braces, new_braces),
|
||||
Transform(parts.labels, new_labels),
|
||||
]
|
||||
return anims
|
||||
|
||||
self.play(*anims, **kwargs)
|
||||
def get_horizontal_division_change_animations(self, p_list, **kwargs):
|
||||
assert(hasattr(self.sample_space, "horizontal_parts"))
|
||||
return self.get_division_change_animations(
|
||||
self.sample_space, self.sample_space.horizontal_parts, p_list,
|
||||
dimension = 1,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def get_vertical_division_change_animations(self, p_list, **kwargs):
|
||||
assert(hasattr(self.sample_space, "vertical_parts"))
|
||||
return self.get_division_change_animations(
|
||||
self.sample_space, self.sample_space.vertical_parts, p_list,
|
||||
dimension = 0,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
class SampleSpace(VGroup):
|
||||
@ -55,7 +78,8 @@ class SampleSpace(VGroup):
|
||||
"width" : 3,
|
||||
"fill_color" : DARK_GREY,
|
||||
"fill_opacity" : 0.8,
|
||||
"stroke_width" : 0,
|
||||
"stroke_width" : 0.5,
|
||||
"stroke_color" : LIGHT_GREY,
|
||||
},
|
||||
"default_label_scale_val" : 0.7,
|
||||
}
|
||||
@ -76,12 +100,16 @@ class SampleSpace(VGroup):
|
||||
def add_label(self, label):
|
||||
self.label = label
|
||||
|
||||
def complete_p_list(self, p_list):
|
||||
new_p_list = list(tuplify(p_list))
|
||||
remainder = 1.0 - sum(new_p_list)
|
||||
if abs(remainder) > EPSILON:
|
||||
new_p_list.append(remainder)
|
||||
return new_p_list
|
||||
|
||||
def get_division_along_dimension(self, p_list, dim, colors, vect):
|
||||
p_list = list(tuplify(p_list))
|
||||
if abs(1.0 - sum(p_list)) > EPSILON:
|
||||
p_list.append(1.0 - sum(p_list))
|
||||
p_list = self.complete_p_list(p_list)
|
||||
colors = color_gradient(colors, len(p_list))
|
||||
perp_dim = 1-dim
|
||||
|
||||
last_point = self.full_space.get_edge_center(-vect)
|
||||
parts = VGroup()
|
||||
@ -89,7 +117,7 @@ class SampleSpace(VGroup):
|
||||
part = SampleSpace()
|
||||
part.set_fill(color, 1)
|
||||
part.replace(self.full_space, stretch = True)
|
||||
part.stretch(factor, perp_dim)
|
||||
part.stretch(factor, dim)
|
||||
part.move_to(last_point, -vect)
|
||||
last_point = part.get_edge_center(vect)
|
||||
parts.add(part)
|
||||
@ -100,14 +128,14 @@ class SampleSpace(VGroup):
|
||||
colors = [GREEN_E, BLUE],
|
||||
vect = DOWN
|
||||
):
|
||||
return self.get_division_along_dimension(p_list, 0, colors, vect)
|
||||
return self.get_division_along_dimension(p_list, 1, colors, vect)
|
||||
|
||||
def get_vertical_division(
|
||||
self, p_list,
|
||||
colors = [MAROON_B, YELLOW],
|
||||
vect = RIGHT
|
||||
):
|
||||
return self.get_division_along_dimension(p_list, 1, colors, vect)
|
||||
return self.get_division_along_dimension(p_list, 0, colors, vect)
|
||||
|
||||
def divide_horizontally(self, *args, **kwargs):
|
||||
self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
|
||||
@ -117,38 +145,53 @@ class SampleSpace(VGroup):
|
||||
self.vertical_parts = self.get_vertical_division(*args, **kwargs)
|
||||
self.add(self.vertical_parts)
|
||||
|
||||
def get_subdivision_labels(self, parts, labels, direction, buff = SMALL_BUFF):
|
||||
def get_subdivision_braces_and_labels(self, parts, labels, direction, buff = SMALL_BUFF):
|
||||
label_brace_groups = VGroup()
|
||||
label_mobs = VGroup()
|
||||
braces = VGroup()
|
||||
for label, part in zip(labels, parts):
|
||||
brace = Brace(part, direction, min_num_quads = 1, buff = buff)
|
||||
label_mob = TexMobject(label)
|
||||
label_mob.scale(self.default_label_scale_val)
|
||||
if isinstance(label, Mobject):
|
||||
label_mob = label
|
||||
else:
|
||||
label_mob = TexMobject(label)
|
||||
label_mob.scale(self.default_label_scale_val)
|
||||
label_mob.next_to(brace, direction, buff)
|
||||
full_label = VGroup(brace, label_mob)
|
||||
part.add_label(full_label)
|
||||
label_brace_groups.add(full_label)
|
||||
parts.labels = label_brace_groups
|
||||
|
||||
braces.add(brace)
|
||||
label_mobs.add(label_mob)
|
||||
parts.braces = braces
|
||||
parts.labels = label_mobs
|
||||
parts.label_kwargs = {
|
||||
"labels" : labels,
|
||||
"direction" : direction,
|
||||
"buff" : buff,
|
||||
}
|
||||
return label_brace_groups
|
||||
return VGroup(parts.braces, parts.labels)
|
||||
|
||||
def get_side_labels(self, labels, direction = LEFT, **kwargs):
|
||||
def get_side_braces_and_labels(self, labels, direction = LEFT, **kwargs):
|
||||
assert(hasattr(self, "horizontal_parts"))
|
||||
parts = self.horizontal_parts
|
||||
return self.get_subdivision_labels(parts, labels, direction, **kwargs)
|
||||
return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs)
|
||||
|
||||
def get_top_labels(self, labels, **kwargs):
|
||||
def get_top_braces_and_labels(self, labels, **kwargs):
|
||||
assert(hasattr(self, "vertical_parts"))
|
||||
parts = self.vertical_parts
|
||||
return self.get_subdivision_labels(parts, labels, UP, **kwargs)
|
||||
return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)
|
||||
|
||||
def get_bototm_labels(self, labels, **kwargs):
|
||||
def get_bottom_braces_and_labels(self, labels, **kwargs):
|
||||
assert(hasattr(self, "vertical_parts"))
|
||||
parts = self.vertical_parts
|
||||
return self.get_subdivision_labels(parts, labels, DOWN, **kwargs)
|
||||
return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)
|
||||
|
||||
def add_braces_and_labels(self):
|
||||
for attr in "horizontal_parts", "vertical_parts":
|
||||
if not hasattr(self, attr):
|
||||
continue
|
||||
parts = getattr(self, attr)
|
||||
for subattr in "braces", "labels":
|
||||
if hasattr(parts, subattr):
|
||||
self.add(getattr(parts, subattr))
|
||||
|
||||
|
||||
### Cards ###
|
||||
|
Reference in New Issue
Block a user