Las herramientas de generación de imágenes están más de moda que nunca y nunca han sido más potentes. Modelos como PixArt Sigma y Flux.1 están a la vanguardia, gracias a sus modelos de peso abiertos y licencias permisivas. Esta configuración permite realizar modificaciones creativas, incluido el entrenamiento de LoRA sin compartir datos fuera de su computadora.
Sin embargo, trabajar con estos modelos puede ser un desafío si estás usando GPU más antiguas o con menos VRAM. Normalmente, existe un equilibrio entre calidad, velocidad y uso de VRAM. En esta publicación del blog, nos centraremos en optimizar la velocidad y el menor uso de VRAM, manteniendo al mismo tiempo la mayor calidad posible. Este enfoque funciona excepcionalmente bien para PixArt debido a su menor tamaño, pero los resultados pueden variar con Flux.1. Compartiré algunas soluciones alternativas para Flux.1 al final de esta publicación.
Tanto PixArt Sigma como Flux.1 se basan en transformadores, lo que significa que se benefician de las mismas técnicas de cuantificación que utilizan los modelos de lenguaje grandes (LLM). La cuantificación implica comprimir los componentes del modelo para utilizar menos memoria. Le permite mantener todos los componentes del modelo en la VRAM de la GPU simultáneamente, lo que genera velocidades de generación más rápidas en comparación con los métodos que mueven pesos entre la GPU y la CPU, lo que puede ralentizar las cosas.
¡Vamos a sumergirnos en la configuración!
Configuración de su entorno local
Primero, asegúrese de tener instalados los controladores de Nvidia y Anaconda.
A continuación, cree un entorno de Python e instale todos los requisitos principales:
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
Luego los difusores y las bibliotecas Quanto:
pip install pillow==10.3.0 loguru~=0.7.2 optimum-quanto==0.2.4 diffusers==0.30.0 transformers==4.44.2 accelerate==0.33.0 sentencepiece==0.2.0
Código de cuantificación
A continuación se muestra un script simple para comenzar a utilizar PixArt-Sigma:
from optimum.quanto import qint8, qint4, quantize, freeze
from diffusers import PixArtSigmaPipeline
import torchpipeline = PixArtSigmaPipeline.from_pretrained(
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
)
quantize(pipeline.transformer, weights=qint8)
freeze(pipeline.transformer)
quantize(pipeline.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipeline.text_encoder)
pipe = pipeline.to("cuda")
for i in range(2):
generator = torch.Generator(device="cpu").manual_seed(i)
prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator).images[0]
image.save(f"Sigma_{i}.png")
Comprender el guión: estos son los pasos principales de la implementación
- Importar las bibliotecas necesarias:Importamos bibliotecas para cuantificación, carga de modelos y manejo de GPU.
- Cargar el modelo:Primero cargamos el modelo PixArt Sigma en precisión media (float16) en la CPU.
- Cuantizar el modelo:Aplicamos cuantificación a los componentes del transformador y del codificador de texto del modelo. Aquí aplicamos diferentes niveles de cuantificación: la parte del codificador de texto se cuantifica en qint4 dado que es bastante grande. La parte de visión, si se cuantifica en qint8, haría que todo el pipeline utilice 7,5 GB de memoria VRAMsi no se cuantifica en absoluto, se utilizaría alrededor de 8,5 GB de memoria VRAM.
- Pasar a la GPU:Trasladamos el pipeline a la GPU
.to("cuda")Para un procesamiento más rápido. - Generar imágenes:Utilizamos el
pipepara generar imágenes basadas en un mensaje dado y guardar el resultado.
Ejecutando el script
Guarde el script y ejecútelo en su entorno. Debería ver una imagen generada en función del mensaje “Ciudad ciberpunk, pequeño cuervo negro, luces de neón, callejones oscuros, rascacielos, futurista, colores vibrantes, alto contraste, muy detallado” guardada como sigma_1.pngLa generación toma 6 segundos en una GPU RTX 3080.
Puedes lograr resultados similares con Flux.1 Schnell, a pesar de sus componentes adicionales, pero necesitaría una cuantificación más agresiva, lo que reduciría negativamente la calidad (a menos que tengas acceso a más VRAM, digamos 16 o 25 Gigas)
import torchfrom optimum.quanto import qint2, qint4, quantize, freeze
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
quantize(pipe.text_encoder, weights=qint4, exclude="proj_out")
freeze(pipe.text_encoder)
quantize(pipe.text_encoder_2, weights=qint2, exclude="proj_out")
freeze(pipe.text_encoder_2)
quantize(pipe.transformer, weights=qint4, exclude="proj_out")
freeze(pipe.transformer)
pipe = pipe.to("cuda")
for i in range(10):
generator = torch.Generator(device="cpu").manual_seed(i)
prompt = "Cyberpunk cityscape, small black crow, neon lights, dark alleys, skyscrapers, futuristic, vibrant colors, high contrast, highly detailed"
image = pipe(prompt, height=512, width=768, guidance_scale=3.5, generator=generator, num_inference_steps=4).images[0]
image.save(f"Schnell_{i}.png")
Podemos ver que la cuantificación del codificador de texto a qint2 y del transformador de visión a qint8 podría ser demasiado agresiva, lo que tuvo un impacto significativo en la calidad de Flux.1 Schnell.
Aquí hay algunas alternativas para ejecutar Flux.1 Schnell:
Si PixArt-Sigma no es suficiente para tus necesidades y no tienes suficiente VRAM para ejecutar Flux.1 con suficiente calidad, tienes dos opciones principales:
- Interfaz de usuario cómoda o Fragua:Estas son herramientas GUI que utilizan los entusiastas, que en su mayoría sacrifican la velocidad por la calidad.
- Replicar API:Cuesta 0,003 por generación de imagen para Schnell.
Despliegue
Me divertí un poco al implementar PixArt Sigma en una máquina más vieja que tengo. Aquí hay un breve resumen de cómo lo hice:
Primero la lista de componentes:
- HTMX y Tailwind:Son como la cara del proyecto. HTMX ayuda a que el sitio web sea interactivo sin mucho código adicional y Tailwind le da un aspecto atractivo.
- API rápida:Toma solicitudes del sitio web y decide qué hacer con ellas.
- Trabajador del apioPiense en esto como el trabajador esforzado. Recibe las órdenes de FastAPI y crea las imágenes.
- Caché/Publicación y suscripción de Redis:Es como un centro de comunicación. Ayuda a que las distintas partes del proyecto se comuniquen entre sí y recuerden cosas importantes.
- GCS (Almacenamiento en la nube de Google):Aquí es donde guardamos las imágenes terminadas.
Ahora bien, ¿cómo funcionan todos juntos? A continuación, se ofrece un resumen sencillo:
- Cuando visita el sitio web y realiza una solicitud, HTMX y Tailwind se aseguran de que se vea bien.
- FastAPI recibe la solicitud y le dice a Celery Worker qué tipo de imagen crear a través de Redis.
- El trabajador del apio se pone a trabajar, creando la imagen.
- Una vez que la imagen está lista, se almacena en GCS, por lo que es fácil acceder a ella.
URL del servicio: https://image-generación-app-340387183829.europe-west1.run.app
Conclusión
Al cuantificar los componentes del modelo, podemos reducir significativamente el uso de VRAM, manteniendo al mismo tiempo una buena calidad de imagen y mejorando la velocidad de generación. Este método es particularmente eficaz para modelos como PixArt Sigma. En el caso de Flux.1, si bien los resultados pueden ser mixtos, los principios de cuantificación siguen siendo aplicables.
Referencias: