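"""Generates a disentanglement grid from a trained variational autoencoder.

Encodes a sample of dataset images, bins their 2D latent means into a grid,
keeps one reconstructed image per occupied bin, and serializes the images and
their bin indices to disentanglement.pkl.
"""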
import pickle
import sys
import os

sys.path.append(os.environ["PROJECT_ROOT"])

from autoencoder_models.variational_autoencoder import VAE, load_dataset, load_vae_from_path
import matplotlib.pyplot as plt
import numpy as np
import torch
import scipy
import scipy.stats
import cv2

def binned_images(model_path, num_x_bins=6, plot=False):
    """Encodes images with the VAE, bins their 2D latent means into a square grid,
    and keeps one reconstructed image per occupied bin.

    Returns the kept images and their (y, x) bin indices unless plot=True,
    in which case the grid is shown with matplotlib instead.
    """
    latent_dim = 2
    model = load_vae_from_path(model_path, latent_dim)
    image_dataset = load_dataset(digit=2)
    # Compute embedding
    num_images = 500
    embedding = []
    images = []
    for i in range(num_images):
        image, _ = image_dataset[i]
        mean, _, recon, _ = model.forward(image)
        mean = mean.detach().numpy()
        recon = recon.detach().numpy()
        recon = recon.reshape(32, 32)
        images.append(recon.squeeze())
        if latent_dim > 2:
            mean = mean[:2]
        embedding.append(mean)
    images = np.stack(images)
    # Normalize the latent means to zero mean and unit variance
    tsne_points = np.array(embedding)
    tsne_points = (tsne_points - tsne_points.mean(axis=0)) / tsne_points.std(axis=0)
    # Make the visualization
    num_points = np.shape(tsne_points)[0]
    x_min = np.amin(tsne_points.T[0])
    y_min = np.amin(tsne_points.T[1])
    y_max = np.amax(tsne_points.T[1])
    x_max = np.amax(tsne_points.T[0])
    # Make the bins from the ranges
    # To keep the bins square, the same width is used for the x and y dimensions
    x_bins, step = np.linspace(x_min, x_max, num_x_bins, retstep=True)
    x_bins = x_bins.astype(float)
    num_y_bins = np.absolute(np.ceil((y_max - y_min) / step)).astype(int)
    y_bins = np.linspace(y_min, y_max, num_y_bins)
    # Sort the tsne_points into a 2D histogram
    tsne_points = tsne_points.squeeze()
    hist_obj = scipy.stats.binned_statistic_dd(
        tsne_points,
        np.arange(num_points),
        statistic='count',
        bins=[x_bins, y_bins],
        expand_binnumbers=True,
    )
    # Sample one point from each bucket
    binnumbers = hist_obj.binnumber
    num_x_bins = np.amax(binnumbers[0]) + 1
    num_y_bins = np.amax(binnumbers[1]) + 1
    binnumbers = binnumbers.T
    # Some regions contain no points, so track which bins are actually used
    used_mask = np.zeros((num_y_bins, num_x_bins))
    image_bins = np.zeros((num_y_bins, num_x_bins, 3, np.shape(images)[2], np.shape(images)[2]))
    for i, bin_num in enumerate(list(binnumbers)):
        used_mask[bin_num[1], bin_num[0]] = 1
        image_bins[bin_num[1], bin_num[0]] = images[i]
    # Plot a grid of the images
    fig, axs = plt.subplots(
        nrows=np.shape(y_bins)[0],
        ncols=np.shape(x_bins)[0],
        constrained_layout=False,
        dpi=50,
    )
    images = []
    bin_indices = []
    for y in range(num_y_bins):
        for x in range(num_x_bins):
            if used_mask[y, x] > 0.0:
                image = np.uint8(image_bins[y][x].squeeze() * 255)
                image = np.rollaxis(image, 0, 3)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                axs[num_y_bins - 1 - y][x].imshow(image)
                images.append(image)
                bin_indices.append((y, x))
            axs[y, x].axis('off')
    if plot:
        plt.axis('off')
        plt.show()
    else:
        return images, bin_indices

def generate_disentanglement(model_path="saved_models/model_dim2.pth"):
|
|
"""Generates disentanglement visualization and serializes it"""
|
|
# Disentanglement object
|
|
disentanglement_object = {}
|
|
# Make Disentanglement
|
|
images, bin_indices = binned_images(model_path)
|
|
disentanglement_object["images"] = images
|
|
disentanglement_object["bin_indices"] = bin_indices
|
|
# Serialize Images
|
|
with open("disentanglement.pkl", "wb") as f:
|
|
pickle.dump(disentanglement_object, f)
|
|
|
|
if __name__ == "__main__":
|
|
plot = False
|
|
if plot:
|
|
model_path = "saved_models/model_dim2.pth"
|
|
#uniform_image_sample(model_path)
|
|
binned_images(model_path)
|
|
else:
|
|
generate_disentanglement() |
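# Example (illustrative; assumes generate_disentanglement() has been run and
# disentanglement.pkl exists in the working directory): the serialized object
# can be loaded back as follows.
#
#     with open("disentanglement.pkl", "rb") as f:
#         disentanglement_object = pickle.load(f)
#     images = disentanglement_object["images"]            # list of uint8 RGB arrays
#     bin_indices = disentanglement_object["bin_indices"]  # list of (y, x) grid positions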