diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 82fb1605..6499d0af 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -214,7 +214,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.get_submob_groups() + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index f824663d..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -160,77 +160,67 @@ class TransformMatchingStrings(AnimationGroup): } def __init__(self, - source_mobject: LabelledString, - target_mobject: LabelledString, + source: LabelledString, + target: LabelledString, **kwargs ): digest_config(self, kwargs) - assert isinstance(source_mobject, LabelledString) - assert isinstance(target_mobject, LabelledString) + assert isinstance(source, LabelledString) + assert isinstance(target, LabelledString) anims = [] - rest_source_indices = list(range(len(source_mobject.submobjects))) - rest_target_indices = list(range(len(target_mobject.submobjects))) + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) + + def get_indices_lists(mobject, parts): + return [ + [ + mobject.labelled_submobjects.index(submob) + for submob in part + ] + for part in parts + ] def add_anims_from(anim_class, func, source_args, target_args=None): if target_args is None: target_args = source_args.copy() for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source_mobject, source_arg) - target_parts = func(target_mobject, target_arg) - source_indices_lists = source_mobject.indices_lists_of_parts( - source_parts - ) - target_indices_lists = target_mobject.indices_lists_of_parts( - target_parts - ) - filtered_source_indices_lists = list(filter( + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( lambda indices_list: all([ - index in rest_source_indices + index in source_indices for index in indices_list - ]), source_indices_lists + ]), get_indices_lists(source, source_parts) )) - filtered_target_indices_lists = list(filter( + target_indices_lists = list(filter( lambda indices_list: all([ - index in rest_target_indices + index in target_indices for index in indices_list - ]), target_indices_lists + ]), get_indices_lists(target, target_parts) )) - if not all([ - filtered_source_indices_lists, - filtered_target_indices_lists - ]): + if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) - for index in it.chain(*filtered_source_indices_lists): - rest_source_indices.remove(index) - for index in it.chain(*filtered_target_indices_lists): - rest_target_indices.remove(index) + for index in it.chain(*source_indices_lists): + source_indices.remove(index) + for index in it.chain(*target_indices_lists): + target_indices.remove(index) - def get_common_substrs(func): + def get_common_substrs(substrs_from_source, substrs_from_target): return sorted([ - substr for substr in func(source_mobject) - if substr and substr in func(target_mobject) + substr for substr in substrs_from_source + if substr and substr in substrs_from_target ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): - if not isinstance(keys, tuple): - keys = (keys,) - indices = [] + if isinstance(keys, str): + keys = [keys] + result = VGroup() for key in keys: - if isinstance(key, int): - indices.append(key) - elif isinstance(key, range): - indices.extend(key) - elif isinstance(key, str): - all_parts = mobject.get_parts_by_string(key) - indices.extend(it.chain(*[ - mobject.indices_of_part(part) for part in all_parts - ])) - else: + if not isinstance(key, str): raise TypeError(key) - return VGroup(VGroup(*[ - mobject[index] for index in remove_list_redundancies(indices) - ])) + result.add(*mobject.get_parts_by_string(key)) + return result add_anims_from( ReplacementTransform, get_parts_from_keys, @@ -239,38 +229,32 @@ class TransformMatchingStrings(AnimationGroup): add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_string, - get_common_substrs(LabelledString.get_specified_substrs) + get_common_substrs( + source.specified_substrs, + target.specified_substrs + ) ) add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_group_substr, - get_common_substrs(LabelledString.get_group_substrs) + get_common_substrs( + source.group_substrs, + target.group_substrs + ) ) - fade_source = VGroup(*[ - source_mobject[index] - for index in rest_source_indices - ]) - fade_target = VGroup(*[ - target_mobject[index] - for index in rest_target_indices - ]) + rest_source = VGroup(*[source[index] for index in source_indices]) + rest_target = VGroup(*[target[index] for index in target_indices]) if self.transform_mismatches: - anims.append(ReplacementTransform( - fade_source, - fade_target, - **kwargs - )) + anims.append( + ReplacementTransform(rest_source, rest_target, **kwargs) + ) else: - anims.append(FadeOutToPoint( - fade_source, - target_mobject.get_center(), - **kwargs - )) - anims.append(FadeInFromPoint( - fade_target, - source_mobject.get_center(), - **kwargs - )) + anims.append( + FadeOutToPoint(rest_source, target.get_center(), **kwargs) + ) + anims.append( + FadeInFromPoint(rest_target, source.get_center(), **kwargs) + ) super().__init__(*anims) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index a2a9f889..32d468a9 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -57,6 +57,7 @@ class LabelledString(_StringSVG): self.pre_parse() self.parse() super().__init__(**kwargs) + self.post_parse() def get_file_path(self) -> str: return self.get_file_path_(use_plain_file=False) @@ -108,6 +109,20 @@ class LabelledString(_StringSVG): self.label_span_list = self.get_label_span_list() self.check_overlapping() + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() + + def copy(self): + return self.deepcopy() + # Toolkits def get_substr(self, span: Span) -> str: @@ -118,10 +133,14 @@ class LabelledString(_StringSVG): ) -> Iterable[re.Match]: return re.compile(pattern, flags).finditer(self.string, **kwargs) - def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).search(self.string, **kwargs) - def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).match(self.string, **kwargs) def find_spans(self, pattern: str, **kwargs) -> list[Span]: @@ -275,9 +294,7 @@ class LabelledString(_StringSVG): def color_to_label(color: ManimColor) -> int: rgb_tuple = color_to_int_rgb(color) rgb = LabelledString.rgb_to_int(rgb_tuple) - if rgb == 16777215: # white - return -1 - return rgb + return rgb - 1 @abstractmethod def get_begin_color_command_str(int_rgb: int) -> str: @@ -321,12 +338,11 @@ class LabelledString(_StringSVG): return [] def get_specified_spans(self) -> list[Span]: - spans = [ - self.full_span, - *self.internal_specified_spans, - *self.external_specified_spans, - *self.find_substrs(self.isolate) - ] + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) shrinked_spans = list(filter( lambda span: span[0] < span[1], [self.shrink_span(span) for span in spans] @@ -381,6 +397,9 @@ class LabelledString(_StringSVG): # Post-parsing + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = dict.fromkeys(self.command_spans, "") return self.get_replaced_substr(span, span_repl_dict) @@ -391,17 +410,14 @@ class LabelledString(_StringSVG): for span in self.specified_spans ]) - def get_group_span_items(self) -> tuple[list[int], list[Span]]: - submob_labels = [submob.label for submob in self.submobjects] - if not submob_labels: - return [], [] - return tuple(zip(*self.compress_neighbours(submob_labels))) - - def get_group_substrs(self) -> list[str]: - group_labels, _ = self.get_group_span_items() - if not group_labels: + def get_group_items(self) -> list[tuple[str, VGroup]]: + if not self.labelled_submobject_items: return [] + labels, labelled_submobjects = zip(*self.labelled_submobject_items) + group_labels, labelled_submob_spans = zip( + *self.compress_neighbours(labels) + ) ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span for label in group_labels @@ -425,16 +441,27 @@ class LabelledString(_StringSVG): interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - return [ + group_substrs = [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in shrinked_spans ] + submob_groups = VGroup(*[ + VGroup(*labelled_submobjects[slice(*submob_span)]) + for submob_span in labelled_submob_spans + ]) + return list(zip(group_substrs, submob_groups)) - def get_submob_groups(self) -> VGroup: - _, submob_spans = self.get_group_span_items() + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span in submob_spans + group + for group_substr, group in self.group_items + if group_substr == substr ]) # Selector @@ -477,7 +504,7 @@ class LabelledString(_StringSVG): span_begin = next_begin return result - def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -487,10 +514,10 @@ class LabelledString(_StringSVG): ) ]) ] - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) + return VGroup(*[ + submob for label, submob in self.labelled_submobject_items + if label in labels + ]) def get_parts_by_string( self, substr: str, case_sensitive: bool = True, **kwargs @@ -499,19 +526,10 @@ class LabelledString(_StringSVG): if not case_sensitive: flags |= re.I return VGroup(*[ - self.get_parts_by_custom_span(span, **kwargs) + self.get_part_by_custom_span(span, **kwargs) for span in self.find_substr(substr, flags=flags) ]) - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group, group_substr in zip( - self.get_submob_groups(), self.get_group_substrs() - ) - if group_substr == substr - ]) - def get_part_by_string( self, substr: str, index: int = 0, **kwargs ) -> VMobject: @@ -528,13 +546,5 @@ class LabelledString(_StringSVG): self.set_color_by_string(substr, color, **kwargs) return self - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - return [self.submobjects.index(submob) for submob in part] - - def indices_lists_of_parts( - self, parts: Iterable[Iterable[VMobject]] - ) -> list[list[int]]: - return [self.indices_of_part(part) for part in parts] - def get_string(self) -> str: return self.string diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 341db072..5668b183 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -209,7 +209,7 @@ class MTex(LabelledString): right_brace_indices, cmd_end, n_braces ) + 1 if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" + repl_str = "\\" + cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) @@ -270,7 +270,7 @@ class MTex(LabelledString): ] return [ (span, ( - self.get_begin_color_command_str(label), + self.get_begin_color_command_str(label + 1), self.get_end_color_command_str() )) for label, span in enumerate(extended_label_span_list) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 8dbd05cc..76ae8e38 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -485,12 +485,12 @@ class MarkupText(LabelledString): if not use_plain_file: attr_dict_items = [ (span, { - key: WHITE if key in COLOR_RELATED_KEYS else val + key: BLACK if key in COLOR_RELATED_KEYS else val for key, val in attr_dict.items() }) for span, attr_dict in self.predefined_attr_dicts ] + [ - (span, {"foreground": self.rgb_int_to_hex(label)}) + (span, {"foreground": self.rgb_int_to_hex(label + 1)}) for label, span in enumerate(self.label_span_list) ] else: