"""
|
|
Here I thought it would be interesting to visualize the cross attention
|
|
maps produced by the UNet of a text-to-image diffusion model.
|
|
|
|
The key thing I want to show is how images and text are broken up into tokens (patches for images),
|
|
and those tokens are used to compute a cross attention score, which can be interpreted
|
|
as a 2D heatmap over the image patches for each text token.
|
|
|
|
|
|
Necessary operations:
|
|
1. [X] Split an image into patches and "expand" the patches outward to highlight that they are
|
|
separate.
|
|
2. Split text into tokens using the tokenizer and display them.
|
|
3. Compute the cross attention scores (using DAAM) for each word/token and display them as a heatmap.
|
|
4. Overlay the heatmap over the image patches. Show the overlaying as a transition animation.
|
|
"""

import torch
from manim import *
import numpy as np
from typing import List
from daam import trace, set_seed
from diffusers import StableDiffusionPipeline
import torchvision
import matplotlib.pyplot as plt
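

# A toy sketch (not part of the animation) of the cross attention computation
# described in the module docstring: image-patch queries attend over text-token
# keys, and softmax(Q @ K^T / sqrt(d)) yields, per patch, a weight for each
# text token. Reshaping one token's column of weights to the patch grid gives
# that token's heatmap. Names and shapes here are illustrative, not the
# pipeline's internals (DAAM aggregates the real maps from the UNet).
def _toy_cross_attention_heatmaps(queries, keys, grid_width):
    """queries: (grid_width**2, d) patch queries; keys: (num_tokens, d) token keys."""
    d = queries.shape[-1]
    scores = queries @ keys.T / d**0.5         # (num_patches, num_tokens)
    attention = torch.softmax(scores, dim=-1)  # each patch distributes weight over tokens
    # Column t of `attention`, reshaped to the patch grid, is the heatmap for token t
    return attention.T.reshape(keys.shape[0], grid_width, grid_width)


# e.g. _toy_cross_attention_heatmaps(torch.randn(32 * 32, 64), torch.randn(7, 64), 32)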


class ImagePatches(Group):
    """Container object for a grid of ImageMobjects."""

    def __init__(self, image: ImageMobject, grid_width=4):
        """
        Parameters
        ----------
        image : ImageMobject
            image to split into patches
        grid_width : int, optional
            number of patches per row, by default 4
        """
        self.image = image
        self.grid_width = grid_width
        super().__init__()
        self.patch_dict = self._split_image_into_patches(image, grid_width)

    def _split_image_into_patches(self, image, grid_width):
        """Splits the image into a grid of patches.

        Parameters
        ----------
        image : ImageMobject
            image to split into patches
        grid_width : int
            number of patches per row
        """
        patch_dict = {}
        # Get a pixel array of the image mobject
        pixel_array = image.pixel_array
        # Get the height and width of the image
        height, width = pixel_array.shape[:2]
        # Split the image into an equal grid of patches
        assert height == width, "Image must be square"
        assert height % grid_width == 0, "Image width must be divisible by grid_width"
        patch_width, patch_height = width // grid_width, height // grid_width

        for patch_i in range(self.grid_width):
            for patch_j in range(self.grid_width):
                # Get patch pixels (patch_i indexes rows, patch_j indexes columns)
                i_start, i_end = patch_i * patch_height, (patch_i + 1) * patch_height
                j_start, j_end = patch_j * patch_width, (patch_j + 1) * patch_width
                patch_pixels = pixel_array[i_start:i_end, j_start:j_end, :]
                # Make the patch
                image_patch = ImageMobject(patch_pixels)
                # Scale the patch and move it to its location on the original image
                image_width = image.get_width()
                image_center = image.get_center()
                image_patch.scale(image_width / grid_width / image_patch.get_width())
                patch_manim_space_width = image_patch.get_width()

                # Offset the patch from the image center by its (row, column) position
                patch_center = image_center.copy()
                patch_center += (patch_j - self.grid_width / 2 + 0.5) * patch_manim_space_width * RIGHT
                patch_center -= (patch_i - self.grid_width / 2 + 0.5) * patch_manim_space_width * UP
                image_patch.move_to(patch_center)

                self.add(image_patch)
                patch_dict[(patch_i, patch_j)] = image_patch

        return patch_dict


class ExpandPatches(Animation):

    def __init__(self, image_patches: ImagePatches, expansion_scale=2.0):
        """
        Parameters
        ----------
        image_patches : ImagePatches
            set of image patches
        expansion_scale : float, optional
            scale factor for expansion, by default 2.0
        """
        self.image_patches = image_patches
        self.expansion_scale = expansion_scale
        super().__init__(image_patches)

    def interpolate_submobject(self, submobject, starting_submobject, alpha):
        """
        Parameters
        ----------
        submobject : ImageMobject
            current patch
        starting_submobject : ImageMobject
            starting patch
        alpha : float
            interpolation alpha
        """
        patch = submobject
        starting_patch_center = starting_submobject.get_center()
        image_center = self.image_patches.image.get_center()
        # The direction of expansion is the vector from the image center to the patch center
        starting_patch_vector = starting_patch_center - image_center
        final_patch_center = image_center + starting_patch_vector * self.expansion_scale
        # Linearly interpolate between the starting and final patch centers
        current_location = alpha * final_patch_center + (1 - alpha) * starting_patch_center
        patch.move_to(current_location)
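

# Illustrative usage of the two classes above inside a Scene (a sketch, assuming
# an ImageMobject `img` is already on screen):
#
#     patches = ImagePatches(img, grid_width=8)
#     self.remove(img)
#     self.play(ExpandPatches(patches, expansion_scale=1.5))
#     self.play(ExpandPatches(patches, expansion_scale=1 / 1.5))  # contract back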


class TokenizedText(Group):
    """Tokenizes the given text and displays the tokens as a list of Text mobjects."""

    def __init__(self, text, tokenizer=None):
        self.text = text
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            # TODO: load the default stable diffusion tokenizer here
            raise NotImplementedError("Tokenizer must be provided")

        self.token_strings = self._tokenize_text(text)
        self.token_mobjects = self.make_text_mobjects(self.token_strings)

        super().__init__()
        self.add(*self.token_mobjects)
        self.arrange(RIGHT, buff=-0.05, aligned_edge=UP)

        # Map each token string to its mobject (a repeated token keeps its last occurrence)
        token_dict = {}
        for token_index, token_string in enumerate(self.token_strings):
            token_dict[token_string] = self.token_mobjects[token_index]
        self.token_dict = token_dict

    def _tokenize_text(self, text):
        """Tokenize the text using the tokenizer.

        Parameters
        ----------
        text : str
            text to tokenize
        """
        tokens = self.tokenizer.tokenize(text)
        # Strip the end-of-word markers appended by the CLIP BPE tokenizer
        tokens = [token.replace("</w>", "") for token in tokens]
        return tokens

    def make_text_mobjects(self, tokens_list: List[str]):
        """Converts the tokens into a list of Text mobjects."""
        # Pad each token with phantom characters ("l" and "g"), colored black so
        # they are invisible, to give every token the same ascender/descender height.
        tokens_list = ["l" + token + "g" for token in tokens_list]
        token_objects = [
            Text(token, t2c={"[-1:]": "#000000", "[0:1]": "#000000"}, font_size=64)
            for token in tokens_list
        ]

        return token_objects
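

# Note: with the CLIP BPE tokenizer used by Stable Diffusion, in-vocabulary words
# come back as single tokens ending in "</w>" (e.g. "horse</w>"), while rarer
# words are split into several subword tokens; _tokenize_text strips the "</w>"
# markers before display. Exact splits depend on the tokenizer vocabulary.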


def compute_stable_diffusion_cross_attention_heatmaps(
    pipe,
    prompt: str,
    seed: int = 2,
    map_resolution=(32, 32),
):
    """Computes the cross attention heatmaps for the given prompt.

    Parameters
    ----------
    pipe : StableDiffusionPipeline
        the text-to-image pipeline to trace
    prompt : str
        the prompt
    seed : int, optional
        random seed, by default 2
    map_resolution : tuple, optional
        resolution to downscale maps to, by default (32, 32)

    Returns
    -------
    tuple
        the generated PIL image and a dict mapping each token string to an
        ImageMobject of its cross attention heatmap
    """
    # Get tokens
    tokens = pipe.tokenizer.tokenize(prompt)
    tokens = [token.replace("</w>", "") for token in tokens]
    # Set torch seeds
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gen = set_seed(seed)  # for reproducibility

    heatmap_dict = {}
    image = None

    with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
        # DAAM hooks the UNet's cross attention layers while the pipeline runs
        with trace(pipe) as tc:
            out = pipe(prompt, num_inference_steps=30, generator=gen)
            image = out.images[0]
            global_heat_map = tc.compute_global_heat_map()
            for token in tokens:
                word_heat_map = global_heat_map.compute_word_heat_map(token)
                heatmap = word_heat_map.heatmap

                # Downscale the heatmap to the patch grid resolution
                heatmap = heatmap.unsqueeze(0).unsqueeze(0)
                heatmap = torchvision.transforms.Resize(
                    map_resolution,
                    interpolation=torchvision.transforms.InterpolationMode.NEAREST,
                )(heatmap)
                heatmap = heatmap.squeeze(0).squeeze(0)
                # Normalize to [0, 1] and convert to an RGBA uint8 array
                heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
                cmap = plt.get_cmap("inferno")
                heatmap = (cmap(heatmap.float().cpu().numpy()) * 255).astype(np.uint8)
                # Make an image mobject for the heatmap
                heatmap = ImageMobject(heatmap)
                heatmap.set_resampling_algorithm(RESAMPLING_ALGORITHMS["nearest"])
                heatmap_dict[token] = heatmap

    return image, heatmap_dict
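

# Illustrative call (a sketch, assuming a CUDA device and locally cached weights):
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         "stabilityai/stable-diffusion-2-base"
#     ).to("cuda")
#     image, heatmaps = compute_stable_diffusion_cross_attention_heatmaps(
#         pipe, "Astronaut riding a horse on the moon"
#     )
#     heatmaps["horse"]  # ImageMobject heatmap for the token "horse"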


# Make the scene
config.pixel_height = 1200
config.pixel_width = 1900
config.frame_height = 30.0
config.frame_width = 30.0
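
# A typical manim invocation to render this scene (an assumption; flags vary
# by manim version):
#     manim -pqh <this file> StableDiffusionCrossAttentionScene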


class StableDiffusionCrossAttentionScene(Scene):

    def construct(self):
        # Load the pipeline
        model_id = "stabilityai/stable-diffusion-2-base"
        pipe = StableDiffusionPipeline.from_pretrained(model_id)
        prompt = "Astronaut riding a horse on the moon"
        # Compute the image and heatmaps
        image, heatmap_dict = compute_stable_diffusion_cross_attention_heatmaps(pipe, prompt)
        # 1. Show the text prompt and the corresponding generated image
        image_mobject = ImageMobject(image)
        image_mobject.shift(DOWN)
        image_mobject.scale(0.7)
        image_mobject.shift(LEFT * 7)
        self.add(image_mobject)
        paragraph_prompt = '"Astronaut riding a\nhorse on the moon"'
        text_prompt = Paragraph(paragraph_prompt, alignment="center", font_size=64)
        text_prompt.next_to(image_mobject, UP, buff=1.5)
        prompt_title = Text("Prompt", font_size=72)
        prompt_title.next_to(text_prompt, UP, buff=0.5)
        self.play(Create(prompt_title))
        self.play(Create(text_prompt))
        # Make an arrow from the text to the image
        arrow = Arrow(text_prompt.get_bottom(), image_mobject.get_top(), buff=0.5)
        self.play(GrowArrow(arrow))
        self.wait(1)
        self.remove(arrow)
        # 2. Show the image being split into patches
        patches = ImagePatches(image_mobject, grid_width=32)
        # Expand and contract
        self.remove(image_mobject)
        self.play(ExpandPatches(patches, expansion_scale=1.2))
        # Draw lines to the bottom row of patches to label the visual tokens
        visual_token_label = Text("Visual Tokens", font_size=72)
        visual_token_label.next_to(image_mobject, DOWN, buff=1.5)
        arrow_animations = []
        arrows = []
        for i in range(patches.grid_width):
            patch = patches.patch_dict[(patches.grid_width - 1, i)]
            arrow = Line(visual_token_label.get_top(), patch.get_bottom(), buff=0.3)
            arrow_animations.append(Create(arrow))
            arrows.append(arrow)
        self.play(AnimationGroup(*arrow_animations, FadeIn(visual_token_label), lag_ratio=0))
        self.wait(1)
        self.play(FadeOut(*arrows, visual_token_label))
        self.play(ExpandPatches(patches, expansion_scale=1 / 1.2))
        # 3. Show the text prompt being tokenized
        tokenized_text = TokenizedText(prompt, pipe.tokenizer)
        tokenized_text.shift(RIGHT * 7)
        tokenized_text.shift(DOWN * 7.5)
        self.play(FadeIn(tokenized_text))
        # Label the textual tokens
        token_label = Text("Textual Tokens", font_size=72)
        token_label.next_to(tokenized_text, DOWN, buff=0.5)
        self.play(Create(token_label))
        arrow_animations = []
        arrows = []
        for token_name, token_mobject in tokenized_text.token_dict.items():
            arrow = Line(token_label.get_top(), token_mobject.get_bottom(), buff=0.3)
            arrow_animations.append(Create(arrow))
            arrows.append(arrow)
        self.play(AnimationGroup(*arrow_animations, lag_ratio=0))
        self.wait(1)
        self.play(FadeOut(*arrows, token_label))
        # 4. Show the cross attention heatmap for each highlighted token
        previous_heatmap = None
        for token in ["astronaut", "horse", "riding"]:
            surrounding_rectangle = SurroundingRectangle(tokenized_text.token_dict[token], buff=0.1)
            self.add(surrounding_rectangle)
            heatmap = heatmap_dict[token]
            heatmap.height = image_mobject.get_height()
            if previous_heatmap is None:
                heatmap.shift(RIGHT * 7 + DOWN)
            else:
                heatmap.move_to(previous_heatmap)
                self.play(FadeOut(previous_heatmap))
            self.play(FadeIn(heatmap))
            self.wait(3)
            self.remove(surrounding_rectangle)
            previous_heatmap = heatmap