Added Mobject.set_color_by_rgba_func

This commit is contained in:
Grant Sanderson
2021-02-25 08:46:56 -08:00
parent 8fcb069808
commit d06b3769b8
2 changed files with 52 additions and 11 deletions

View File

@ -504,7 +504,7 @@ class Mobject(object):
self.refresh_has_updater_status() self.refresh_has_updater_status()
if call_updater: if call_updater:
self.update() self.update(dt=0)
return self return self
def remove_updater(self, update_function): def remove_updater(self, update_function):
@ -841,7 +841,30 @@ class Mobject(object):
# Color functions # Color functions
def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True): def set_rgba_array(self, rgba_array, name="rgbas", recurse=False):
for mob in self.get_family(recurse):
mob.data[name] = np.array(rgba_array)
return self
def set_color_by_rgba_func(self, func, recurse=True):
"""
Func should take in a point in R3 and output an rgba value
"""
for mob in self.get_family(recurse):
rgba_array = [func(point) for point in mob.get_points()]
mob.set_rgba_array(rgba_array)
return self
def set_color_by_rgb_func(self, func, opacity=1, recurse=True):
"""
Func should take in a point in R3 and output an rgb value
"""
for mob in self.get_family(recurse):
rgba_array = [[*func(point), opacity] for point in mob.get_points()]
mob.set_rgba_array(rgba_array)
return self
def set_rgba_array_by_color(self, color=None, opacity=None, name="rgbas", recurse=True):
if color is not None: if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)]) rgbs = np.array([color_to_rgb(c) for c in listify(color)])
if opacity is not None: if opacity is not None:
@ -870,8 +893,8 @@ class Mobject(object):
return self return self
def set_color(self, color, opacity=None, recurse=True): def set_color(self, color, opacity=None, recurse=True):
self.set_rgba_array(color, opacity, recurse=False) self.set_rgba_array_by_color(color, opacity, recurse=False)
# Recurse to submobjects differently from how set_rgba_array # Recurse to submobjects differently from how set_rgba_array_by_color
# in case they implement set_color differently # in case they implement set_color differently
if recurse: if recurse:
for submob in self.submobjects: for submob in self.submobjects:
@ -879,7 +902,7 @@ class Mobject(object):
return self return self
def set_opacity(self, opacity, recurse=True): def set_opacity(self, opacity, recurse=True):
self.set_rgba_array(color=None, opacity=opacity, recurse=False) self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False)
if recurse: if recurse:
for submob in self.submobjects: for submob in self.submobjects:
submob.set_opacity(opacity, recurse=True) submob.set_opacity(opacity, recurse=True)

View File

@ -106,24 +106,42 @@ class VMobject(Mobject):
self.set_flat_stroke(self.flat_stroke) self.set_flat_stroke(self.flat_stroke)
return self return self
def set_rgba_array(self, rgba_array, name=None, recurse=False):
if name is None:
names = ["fill_rgba", "stroke_rgba"]
else:
names = [name]
for name in names:
super().set_rgba_array(rgba_array, name, recurse)
return self
def set_fill(self, color=None, opacity=None, recurse=True): def set_fill(self, color=None, opacity=None, recurse=True):
self.set_rgba_array(color, opacity, 'fill_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse)
return self return self
def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True): def set_stroke(self, color=None, width=None, opacity=None, background=None, recurse=True):
self.set_rgba_array(color, opacity, 'stroke_rgba', recurse) self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse)
if width is not None: if width is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.data['stroke_width'] = np.array([ if isinstance(width, np.ndarray):
[width] for width in listify(width) arr = width.reshape((len(width), 1))
]) else:
arr = np.array([[w] for w in listify(width)])
mob.data['stroke_width'] = arr
if background is not None: if background is not None:
for mob in self.get_family(recurse): for mob in self.get_family(recurse):
mob.draw_stroke_behind_fill = background mob.draw_stroke_behind_fill = background
return self return self
def align_stroke_width_data_to_points(self, recurse=True):
for mob in self.get_family(recurse):
mob.data["stroke_width"] = resize_with_interpolation(
mob.data["stroke_width"], len(mob.get_points())
)
def set_style(self, def set_style(self,
fill_color=None, fill_color=None,
fill_opacity=None, fill_opacity=None,
@ -259,7 +277,7 @@ class VMobject(Mobject):
return self.get_fill_color() return self.get_fill_color()
def has_stroke(self): def has_stroke(self):
return any(self.get_stroke_widths()) and any(self.get_stroke_opacities()) return self.get_stroke_widths().any() and self.get_stroke_opacities().any()
def has_fill(self): def has_fill(self):
return any(self.get_fill_opacities()) return any(self.get_fill_opacities())