Add option for StringMobject to only render one svg

This commit is contained in:
Grant Sanderson
2022-12-20 12:23:19 -08:00
parent 9c106eb873
commit 6176bcd45a

View File

@ -9,6 +9,7 @@ from scipy.spatial.distance import cdist
from manimlib.constants import WHITE from manimlib.constants import WHITE
from manimlib.logger import log from manimlib.logger import log
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_hex from manimlib.utils.color import color_to_hex
from manimlib.utils.color import hex_to_int from manimlib.utils.color import hex_to_int
@ -55,12 +56,17 @@ class StringMobject(SVGMobject, ABC):
base_color: ManimColor = WHITE, base_color: ManimColor = WHITE,
isolate: Selector = (), isolate: Selector = (),
protect: Selector = (), protect: Selector = (),
# When set to true, only the labelled svg is
# rendered, and its contents are used directly
# for the body of this String Mobject
use_labelled_svg: bool = False,
**kwargs **kwargs
): ):
self.string = string self.string = string
self.base_color = base_color or WHITE self.base_color = base_color or WHITE
self.isolate = isolate self.isolate = isolate
self.protect = protect self.protect = protect
self.use_labelled_svg = use_labelled_svg
self.parse() self.parse()
super().__init__( super().__init__(
@ -72,47 +78,34 @@ class StringMobject(SVGMobject, ABC):
) )
self.labels = [submob.label for submob in self.submobjects] self.labels = [submob.label for submob in self.submobjects]
def get_file_path(self) -> str: def get_file_path(self, is_labelled: bool = False) -> str:
original_content = self.get_content(is_labelled=False) is_labelled = is_labelled or self.use_labelled_svg
return self.get_file_path_by_content(original_content) return self.get_file_path_by_content(self.get_content(is_labelled))
@abstractmethod @abstractmethod
def get_file_path_by_content(self, content: str) -> str: def get_file_path_by_content(self, content: str) -> str:
return "" return ""
def generate_mobject(self) -> None: def assign_labels_by_color(self, mobjects: list[VMobject]) -> None:
super().generate_mobject() """
Assuming each mobject in the list `mobjects` has a fill color
meant to represent a numerical label, this assigns those
those numerical labels to each mobject as an attribute
"""
labels_count = len(self.labelled_spans) labels_count = len(self.labelled_spans)
if labels_count == 1: if labels_count == 1:
for submob in self.submobjects: for mob in mobjects:
submob.label = 0 mob.label = 0
return return
labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path)
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg " + \
"to the original svg. Skip the labelling process."
)
for submob in self.submobjects:
submob.label = 0
return
self.rearrange_submobjects_by_positions(labelled_svg)
unrecognizable_colors = [] unrecognizable_colors = []
for submob, labelled_svg_submob in zip( for mob in mobjects:
self.submobjects, labelled_svg.submobjects label = hex_to_int(color_to_hex(mob.get_fill_color()))
):
label = hex_to_int(color_to_hex(
labelled_svg_submob.get_fill_color()
))
if label >= labels_count: if label >= labels_count:
unrecognizable_colors.append(label) unrecognizable_colors.append(label)
label = 0 label = 0
submob.label = label mob.label = label
if unrecognizable_colors: if unrecognizable_colors:
log.warning( log.warning(
"Unrecognizable color labels detected (%s). " + \ "Unrecognizable color labels detected (%s). " + \
@ -123,26 +116,59 @@ class StringMobject(SVGMobject, ABC):
) )
) )
def mobjects_from_file(self, file_path: str) -> list[VMobject]:
submobs = super().mobjects_from_file(file_path)
if self.use_labelled_svg:
# This means submobjects are colored according to spans
self.assign_labels_by_color(submobs)
return submobs
# Otherwise, submobs are not colored, so generate a new list
# of submobject which are and use those for labels
unlabelled_submobs = submobs
labelled_content = self.get_content(is_labelled=True)
labelled_file = self.get_file_path_by_content(labelled_content)
labelled_submobs = super().mobjects_from_file(labelled_file)
self.labelled_submobs = labelled_submobs
self.unlabelled_submobs = unlabelled_submobs
self.assign_labels_by_color(labelled_submobs)
self.rearrange_submobjects_by_positions(labelled_submobs, unlabelled_submobs)
for usm, lsm in zip(unlabelled_submobs, labelled_submobs):
usm.label = lsm.label
if len(unlabelled_submobs) != len(labelled_submobs):
log.warning(
"Cannot align submobjects of the labelled svg " + \
"to the original svg. Skip the labelling process."
)
for usm in unlabelled_submobs:
usm.label = 0
return unlabelled_submobs
return unlabelled_submobs
def rearrange_submobjects_by_positions( def rearrange_submobjects_by_positions(
self, labelled_svg: SVGMobject self, labelled_submobs: list[VMobject], unlabelled_submobs: list[VMobject],
) -> None: ) -> None:
# Rearrange submobjects of `labelled_svg` so that """
# each submobject is labelled by the nearest one of `labelled_svg`. Rearrange `labeleled_submobjects` so that each submobject
# The correctness cannot be ensured, since the svg may is labelled by the nearest one of `unlabelled_submobs`.
# change significantly after inserting color commands. The correctness cannot be ensured, since the svg may
if not labelled_svg.submobjects: change significantly after inserting color commands.
"""
if len(labelled_submobs) == 0:
return return
labelled_svg.replace(self) labelled_svg = VGroup(*labelled_submobs)
labelled_svg.replace(VGroup(*unlabelled_submobs))
distance_matrix = cdist( distance_matrix = cdist(
[submob.get_center() for submob in self.submobjects], [submob.get_center() for submob in unlabelled_submobs],
[submob.get_center() for submob in labelled_svg.submobjects] [submob.get_center() for submob in labelled_submobs]
) )
_, indices = linear_sum_assignment(distance_matrix) _, indices = linear_sum_assignment(distance_matrix)
labelled_svg.set_submobjects([ labelled_submobs[:] = [labelled_submobs[index] for index in indices]
labelled_svg.submobjects[index]
for index in indices
])
# Toolkits # Toolkits