El VAE vainilla muestra grupos distintos, mientras que el CVAE tiene una distribución más homogénea. Vanilla VAE codifica la clase y la variación de clase en el espacio latente ya que no se proporciona ninguna señal condicional. Sin embargo, el CVAE no necesita aprender a distinguir clases y el espacio latente puede centrarse en la variación dentro de las clases. Por lo tanto, un CVAE puede potencialmente aprender más información, ya que no depende de tener que aprender el condicionamiento básico de clase.
Se crearon dos arquitecturas modelo para probar la generación de imágenes. La primera arquitectura fue un CVAE convolucional con un enfoque condicional de concatenación. Todas las redes se crearon para imágenes Fashion-MNIST de tamaño 28×28 (784 píxeles en total).
class ConcatConditionalVAE(nn.Module):
def __init__(self, latent_dim=128, num_classes=10):
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten()
)
self.flatten_size = 128 * 4 * 4
# Conditional embedding
self.label_embedding = nn.Embedding(num_classes, 32)
# Latent space (with concatenated condition)
self.fc_mu = nn.Linear(self.flatten_size + 32, latent_dim)
self.fc_var = nn.Linear(self.flatten_size + 32, latent_dim)
# Decoder
self.decoder_input = nn.Linear(latent_dim + 32, 4 * 4 * 128)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
nn.Sigmoid()
)
def encode(self, x, c):
x = self.encoder(x)
c = self.label_embedding(c)
# Concatenate condition with encoded input
x = torch.cat([x, c], dim=1)
mu = self.fc_mu(x)
log_var = self.fc_var(x)
return mu, log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z, c):
c = self.label_embedding(c)
# Concatenate condition with latent vector
z = torch.cat([z, c], dim=1)
z = self.decoder_input(z)
z = z.view(-1, 128, 4, 4)
return self.decoder(z)
def forward(self, x, c):
mu, log_var = self.encode(x, c)
z = self.reparameterize(mu, log_var)
return self.decode(z, c), mu, log_var
El codificador CVAE consta de 3 capas convolucionales, cada una seguida de una no linealidad ReLU. Luego se aplana la salida del codificador. Luego, el número de clase pasa a través de una capa de incrustación y se agrega a la salida del codificador. Luego se utiliza el truco de reparametrización con 2 capas lineales para obtener μ y σ en el espacio latente. Una vez muestreado, la salida del espacio latente reparametrizado se pasa al decodificador ahora concatenado con la salida de la capa de incrustación del número de clase. El decodificador consta de 3 capas convolucionales transpuestas. Los dos primeros contienen una no linealidad ReLU y la última capa contiene una no linealidad sigmoidea. La salida del decodificador es una imagen generada de 28×28.
La otra arquitectura del modelo sigue el mismo enfoque pero agregando la entrada condicional en lugar de concatenar. Una pregunta importante fue si agregar o concatenar conducirá a mejores resultados de reconstrucción o generación.
class AdditiveConditionalVAE(nn.Module):
def __init__(self, latent_dim=128, num_classes=10):
super().__init__()
self.latent_dim = latent_dim
self.num_classes = num_classes# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(1, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten()
)
self.flatten_size = 128 * 4 * 4
# Conditional embedding
self.label_embedding = nn.Embedding(num_classes, self.flatten_size)
# Latent space (without concatenation)
self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
self.fc_var = nn.Linear(self.flatten_size, latent_dim)
# Decoder condition embedding
self.decoder_label_embedding = nn.Embedding(num_classes, latent_dim)
# Decoder
self.decoder_input = nn.Linear(latent_dim, 4 * 4 * 128)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
nn.Sigmoid()
)
def encode(self, x, c):
x = self.encoder(x)
c = self.label_embedding(c)
# Add condition to encoded input
x = x + c
mu = self.fc_mu(x)
log_var = self.fc_var(x)
return mu, log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z, c):
# Add condition to latent vector
c = self.decoder_label_embedding(c)
z = z + c
z = self.decoder_input(z)
z = z.view(-1, 128, 4, 4)
return self.decoder(z)
def forward(self, x, c):
mu, log_var = self.encode(x, c)
z = self.reparameterize(mu, log_var)
return self.decode(z, c), mu, log_var
Se utiliza la misma función de pérdida para todos los CVAE de la ecuación que se muestra arriba.
def loss_function(recon_x, x, mu, logvar):
"""Computes the loss = -ELBO = Negative Log-Likelihood + KL Divergence.
Args:
recon_x: Decoder output.
x: Ground truth.
mu: Mean of Z
logvar: Log-Variance of Z
"""
BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
Para evaluar las imágenes generadas por modelos, se utilizan habitualmente tres métricas cuantitativas. El error cuadrático medio (MSE) se calculó sumando los cuadrados de la diferencia entre la imagen generada y una imagen real del terreno en píxeles. La medida del índice de similitud estructural (SSIM) es una métrica que evalúa la calidad de la imagen comparando dos imágenes en función de la información estructural, la luminancia y el contraste. [3]. SSIM se puede utilizar para comparar imágenes de cualquier tamaño, mientras que MSE es relativo al tamaño de píxel. La puntuación SSIM varía de -1 a 1, donde 1 indica imágenes idénticas. La distancia de inicio de Frechet (FID) es una métrica para cuantificar el realismo y la diversidad de las imágenes generadas. Como FID es una medida de distancia, puntuaciones más bajas son indicativas de una mejor reconstrucción de un conjunto de imágenes.
Antes de ampliar a texto completo a imagen, reconstrucción y generación de imágenes CVAE en Fashion-MNIST. Fashion-MNIST es un conjunto de datos similar a MNIST que consta de un conjunto de entrenamiento de 60.000 ejemplos y un conjunto de prueba de 10.000 ejemplos. Cada ejemplo es una imagen en escala de grises de 28×28, asociada con una etiqueta de 10 clases. [4].
Se crearon funciones de preprocesamiento para extraer la palabra clave relevante que contiene el nombre de la clase de la coincidencia de expresiones regulares de texto breve de entrada. Se utilizaron descriptores adicionales (sinónimos) para la mayoría de las clases para dar cuenta de artículos de moda similares incluidos en cada clase (por ejemplo, abrigo y chaqueta).
classes = {
'Shirt':0,
'Top':0,
'Trouser':1,
'Pants':1,
'Pullover':2,
'Sweater':2,
'Hoodie':2,
'Dress':3,
'Coat':4,
'Jacket':4,
'Sandal':5,
'Shirt':6,
'Sneaker':7,
'Shoe':7,
'Bag':8,
'Ankle boot':9,
'Boot':9
}def word_to_text(input_str, classes, model, device):
label = class_embedding(input_str, classes)
if label == -1: return Exception("No valid label")
samples = sample_images(model, num_samples=4, label=label, device=device)
plot_samples(samples, input_str, torch.tensor([label]))
return
def class_embedding(input_str, classes):
for key in list(classes.keys()):
template = f'(?i)\\b{key}\\b'
output = re.search(template, input_str)
if output: return classes[key]
return -1
Luego, el nombre de la clase se convirtió a su número de clase y se usó como entrada condicional para el CVAE. Para generar una imagen, la etiqueta de clase extraída de la breve descripción del texto se pasa al decodificador con muestras aleatorias de una distribución gaussiana para ingresar la variable desde el espacio latente.
Antes de probar la generación, se prueba la reconstrucción de imágenes para garantizar la funcionalidad del CVAE. Debido a la creación de una red convolucional con imágenes de 28×28, la red se puede entrenar en menos de una hora con menos de 100 épocas.
Las reconstrucciones contienen la forma general de las imágenes reales del terreno, pero en la imagen faltan características nítidas y de alta frecuencia. Cualquier texto o patrón de diseño complejo aparece borroso en la salida del modelo. Al ingresar cualquier texto breve que contenga una clase de Fashion-MNIST, se generan resultados que se asemejan a imágenes reconstruidas.
Las imágenes generadas tienen un MSE de 11 y un SSIM de 0,76. Éstas constituyen buenas generaciones, lo que significa que en imágenes simples y pequeñas, los CVAE pueden generar imágenes de calidad. Las GAN y los DDPM producirán imágenes de mayor calidad con características complejas, pero las CVAE pueden manejar casos simples.
Al ampliar la generación de imágenes a texto de cualquier longitud, se necesitarían métodos más sólidos además de la coincidencia de expresiones regulares. Para hacer esto, se utiliza el CLIP de Open AI para convertir texto en un vector de incrustación de alta dimensión. El modelo de incrustación se utiliza en su configuración ViT-B/32, que genera incrustaciones de longitud 512. Una limitación del modelo CLIP es que tiene una longitud máxima de token de 77, y los estudios muestran una longitud efectiva aún menor de 20. [5]. Por lo tanto, en los casos en que el texto de entrada contiene varias oraciones, el texto se divide por oración y se pasa a través del codificador CLIP. Las incrustaciones resultantes se promedian para crear la incrustación de salida final.
Un modelo de texto largo requiere datos de entrenamiento mucho más complicados que Fashion-MNIST, por lo que se utilizó el conjunto de datos COCO. El conjunto de datos COCO tiene anotaciones (que no son completamente sólidas pero que se analizarán más adelante) que se pueden pasar a CLIP para obtener incrustaciones. Sin embargo, las imágenes COCO tienen un tamaño de 640×480, lo que significa que incluso con transformaciones de recorte, se necesita una red más grande. Las arquitecturas de adición y concatenación de entradas condicionales se prueban para la generación de texto largo a imagen, pero el enfoque de concatenación se muestra aquí:
class cVAE(nn.Module):
def __init__(self, latent_dim=128):
super().__init__()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.clip_model, _ = clip.load("ViT-B/32", device=device)
self.clip_model.eval()
for param in self.clip_model.parameters():
param.requires_grad = False
self.latent_dim = latent_dim
# Modified encoder for 128x128 input
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, 4, stride=2, padding=1), # 64x64
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=1), # 32x32
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64, 128, 4, stride=2, padding=1), # 16x16
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4x4
nn.BatchNorm2d(512),
nn.ReLU(),
nn.Flatten()
)
self.flatten_size = 512 * 4 * 4 # Flattened size from encoder
# Process CLIP embeddings for encoder
self.condition_processor_encoder = nn.Sequential(
nn.Linear(512, 1024)
)
self.fc_mu = nn.Linear(self.flatten_size + 1024, latent_dim)
self.fc_var = nn.Linear(self.flatten_size + 1024, latent_dim)
self.decoder_input = nn.Linear(latent_dim + 512, 512 * 4 * 4)
# Modified decoder for 128x128 output
self.decoder = nn.Sequential(
nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), # 8x8
nn.BatchNorm2d(256),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 16x16
nn.BatchNorm2d(128),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), # 32x32
nn.BatchNorm2d(64),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), # 64x64
nn.BatchNorm2d(32),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1), # 128x128
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 3, 3, stride=1, padding=1), # 128x128
nn.Sigmoid()
)
def encode_condition(self, text):
with torch.no_grad():
embeddings = []
for sentence in text:
embeddings.append(self.clip_model.encode_text(clip.tokenize(sentence).to('cuda')).type(torch.float32))
return torch.mean(torch.stack(embeddings), dim=0)
def encode(self, x, c):
x = self.encoder(x)
c = self.condition_processor_encoder(c)
x = torch.cat([x, c], dim=1)
return self.fc_mu(x), self.fc_var(x)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z, c):
z = torch.cat([z, c], dim=1)
z = self.decoder_input(z)
z = z.view(-1, 512, 4, 4)
return self.decoder(z)
def forward(self, x, c):
mu, log_var = self.encode(x, c)
z = self.reparameterize(mu, log_var)
return self.decode(z, c), mu, log_var
Otro punto importante de investigación fue la generación y reconstrucción de imágenes de diferentes tamaños. Específicamente, modificar las imágenes COCO para que tengan un tamaño de 64×64, 128×128 y 256×256. Después de entrenar la red, primero se deben probar los resultados de la reconstrucción.
Todos los tamaños de imagen conducen a un fondo reconstruido con algunos contornos característicos y colores correctos. Sin embargo, a medida que aumenta el tamaño de la imagen, se pueden recuperar más funciones. Esto tiene sentido, ya que aunque llevará mucho más tiempo entrenar un modelo con un tamaño de imagen más grande, el modelo puede capturar y aprender más información.
Con la generación de imágenes, es extremadamente difícil generar imágenes de alta calidad. La mayoría de las imágenes tienen fondos hasta cierto punto y características borrosas en la imagen. Esto sería de esperar para la generación de imágenes a partir de un CVAE. Esto ocurre tanto en la concatenación como en la suma de la entrada condicional, pero el enfoque concatenado funciona mejor. Es probable que esto se deba a que las entradas condicionales concatenadas no interferirán con funciones importantes y garantizarán que la información se conserve de forma distintiva. Las condiciones pueden ignorarse si son irrelevantes. Sin embargo, las entradas condicionales aditivas pueden interferir con las funciones existentes y estropear completamente la red al actualizar los pesos durante la retropropagación.
Todas las imágenes generadas por COCO tienen un SSIM mucho más bajo, de aproximadamente 0,4, en comparación con el SSIM de Fashion-MNIST. MSE es proporcional al tamaño de la imagen, por lo que es difícil cuantificar las diferencias. El FID para las generaciones de imágenes COCO está en los 200 como prueba adicional de que las imágenes generadas por COCO CVAE no son sólidas.
La mayor limitación al intentar utilizar CVAE para la generación de imágenes es, bueno, el CVAE. La cantidad de información que puede contenerse y reconstruirse/generarse depende en gran medida del tamaño del espacio latente. Un espacio latente que sea demasiado pequeño no capturará ninguna información significativa y es proporcional al tamaño de la imagen de salida. Una imagen de 28×28 necesita un espacio latente mucho más pequeño que una imagen de 64×64 (ya que cuadra proporcionalmente el tamaño de la imagen). Sin embargo, un espacio latente más grande que la imagen real agrega información innecesaria y en ese punto simplemente crea un mapeo 1 a 1. Para el conjunto de datos COCO, se necesita un espacio latente de al menos 512 para capturar algunas características. Y si bien los CVAE son modelos generativos, un codificador y decodificador convolucional es una red bastante rudimentaria. El estilo de entrenamiento de una GAN o el complejo proceso de eliminación de ruido de un DDPM permiten una generación de imágenes mucho más complicada.
Otra limitación importante en la generación de imágenes es el conjunto de datos con el que se entrena. Aunque el conjunto de datos COCO tiene anotaciones, estas no están muy detalladas. Para entrenar modelos generativos complejos, se debe utilizar un conjunto de datos diferente para el entrenamiento. COCO no proporciona ubicaciones ni información excesiva para detalles de antecedentes. Un vector de características complejo del codificador CLIP no se puede utilizar de manera efectiva para un CVAE en COCO.
Aunque los CVAE y la generación de imágenes en COCO tienen sus limitaciones, crea un modelo de generación de imágenes viable. Se pueden proporcionar más códigos y detalles, ¡simplemente comuníquese!
[1] Kingma, Diederik P, et. Alabama. “Bayes variacionales de codificación automática”. arXiv:1312.6114 (2013).
[2] Sohn, Kihyuk, et. Alabama. “Aprendizaje de la representación de resultados estructurados mediante modelos generativos condicionales profundos”. Procedimientos de NeurIPS (2015).
[3] Nilsson, J., et. Alabama. “Comprensión de ssim”. arXiv:2102.12037 (2020).
[4] Xiao, Han, et. Alabama. “Fashion-mnist: un nuevo conjunto de datos de imágenes para comparar algoritmos de aprendizaje automático”. arXiv:2403.15378 (2024) (licencia MIT).
[5] Zhang, B., et. Alabama. “Clip largo: desbloquear la capacidad de texto largo del clip”. arXiv:2403.15378 (2024).
¡Una referencia a los socios del proyecto de mi grupo Jake Hession (Consultor de Deloitte), Ashley Hong (Google SWE) y Julian Kuppel (Quant)!