Inspirado en el reciente video de YouTube de Andrej Kapathy sobre Reproduzcamos GPT-2 (124M)Me gustaría reconstruirlo con la mayoría de las optimizaciones de entrenamiento en Jax. Jax está diseñado para una velocidad de cálculo altamente eficiente, y es bastante interesante comparar Pytorch con su optimización de entrenamiento reciente, y Jax con sus bibliotecas relacionadas como Flax (API de capas para entrenamiento de redes neuronales para Jax) y Optax (una biblioteca de optimización y procesamiento de gradientes para JAX). Aprenderemos rápidamente qué es Jax y reconstruiremos el GPT con Jax. Al final, compararemos el token/seg con entrenamiento multiGPU entre Pytorch y Jax.
¿Qué es Jax?
Basado en su leer el documentoJAX es una biblioteca de Python para computación de matrices orientada a aceleradores y transformación de programas, diseñada para computación numérica de alto rendimiento y aprendizaje automático a gran escala. Me gustaría presentar JAX con su nombre. Mientras que algunos lo llaman Just Another XLA (Álgibra lineal acelerada), prefiero llamarla J(it) A(autograd) X(LA) para demostrar su capacidad de alta eficiencia.
J — Compilación Just-in-time (JIT). Cuando ejecuta su función de Python, Jax la convierte en un conjunto primitivo de operaciones llamado Jaxpr. Luego, la expresión Jaxpr se convertirá en una entrada para XLA, que compila los scripts de nivel inferior para producir un ejecutable optimizado para el dispositivo de destino (CPU, GPU o TPU).
A — Autograd. Calcular gradientes es una parte fundamental de los métodos de aprendizaje automático modernos, y puedes simplemente llamarlos jax.grad() para obtener gradientes que permitan optimizar los modelos.
X — XLA. Este es un compilador de aprendizaje automático de código abierto para aceleradores de CPU, GPU y ML. En general, XLA realiza varias optimizaciones integradas y pasos de análisis en el HLO estable Luego envía el cálculo de HLO a un backend para realizar más optimizaciones a nivel de HLO. El backend luego realiza la generación de código específico del objetivo.
Esas son solo algunas características clave de JAX, pero también tiene muchas API fáciles de usar similares a numpy. jax.numpy y vectorización automática con jax.vmap y paralelice sus códigos en múltiples dispositivos a través de jax.pmap Cubriremos más conceptos y aplicaciones de Jax en los próximos blogs, pero ahora ¡reproduzcamos el NanoGPT con Jax!
De la atención al transformador
GPT es un modelo de transformador de solo decodificador y el bloque de construcción clave es el módulo de Atención. Primero podemos definir una clase de datos de configuración de modelo para guardar los hiperparámetros del modelo, de modo que el módulo del modelo pueda consumirlos de manera eficiente para inicializar la arquitectura del modelo. De manera similar al modelo GPT de 124M, aquí inicializamos un decodificador de transformador de 12 capas con 12 cabezas y un tamaño de vocabulario de 50257 tokens, cada uno de los cuales tiene una dimensión de incrustación de 768. El tamaño de bloque para el cálculo de la atención es 1024.
from dataclasses import dataclass@dataclass
class ModelConfig:
vocab_size: int = 50257
n_head: int = 12
n_embd: int = 768
block_size: int = 1024
n_layer: int = 12
dropout_rate: float = 0.1
A continuación, se llega al componente clave del modelo de transformador: la atención. La idea es procesar las entradas en tres matrices de ponderación: clave, consulta y valor. Aquí nos basamos en la flax una biblioteca de API de entrenamiento y capa Jax para inicializar la matriz de 3 pesos, simplemente llamando a la flax.linen.Dense Como se mencionó, Jax tiene muchas API similares a numpy, por lo que reformulamos las salidas después de la matriz de peso con jax.numpy.reshape de [batch_size, sequence_length, embedding_dim] a [batch_size, sequence_length, num_head, embedding_dim / num_head]Dado que necesitamos hacer una multiplicación de matrices en las matrices de clave y valor, jax también tiene jax.numpy.matmul API y jax.numpy.transpose (transponer la matriz clave para la multiplicación).
Tenga en cuenta que debemos colocar una máscara en la matriz de atención para evitar fugas de información (evitar que los tokens anteriores tengan acceso a los tokens posteriores). jax.numpy.tril ayuda a construir una matriz de triángulos inferiores y jax.numpy.where ¿Puedes completar el número infinito para que obtengamos 0 después de softmax? jax.nn.softmax Los códigos completos de atención multicabezal se pueden encontrar a continuación.
from flax import linen as nn
import jax.numpy as jnpclass CausalSelfAttention(nn.Module):
config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=True):
assert len(x.shape) == 3
b, l, d = x.shape
q = nn.Dense(self.config.n_embd)(x)
k = nn.Dense(self.config.n_embd)(x)
v = nn.Dense(self.config.n_embd)(x)
# q*k / sqrt(dim) -> softmax -> @v
q = jnp.reshape(q, (b, l, d//self.config.n_head , self.config.n_head))
k = jnp.reshape(k, (b, l, d//self.config.n_head , self.config.n_head))
v = jnp.reshape(v, (b, l, d//self.config.n_head , self.config.n_head))
norm = jnp.sqrt(list(jnp.shape(k))[-1])
attn = jnp.matmul(q,jnp.transpose(k, (0,1,3,2))) / norm
mask = jnp.tril(attn)
attn = jnp.where(mask[:,:,:l,:l], attn, float("-inf"))
probs = jax.nn.softmax(attn, axis=-1)
y = jnp.matmul(probs, v)
y = jnp.reshape(y, (b,l,d))
y = nn.Dense(self.config.n_embd)(y)
return y
Es posible que notes que no hay __init__ o forward métodos como se puede ver en pytorch. Esto es lo especial de jax, donde se pueden definir explícitamente las capas con setup métodos, o definirlos implícitamente dentro del paso hacia adelante agregando nn.compact encima de __call__ método. [ref]
A continuación, construyamos la capa MLP y la capa de bloque, que incluye la capa densa, la función de activación Gelu, LayerNorm y Dropout. Nuevamente, flax.linen tiene las API de capa para ayudarnos a construir el módulo. Tenga en cuenta que pasaremos un deterministic Variable booleana para controlar diferentes comportamientos durante el entrenamiento o la evaluación para algunas capas como Dropout.
class MLP(nn.Module):config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=True):
x = nn.Dense(self.config.n_embd*4)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
x = nn.Dense(self.config.n_embd)(x)
x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
return x
class Block(nn.Module):
config: ModelConfig
@nn.compact
def __call__(self, x):
x = nn.LayerNorm()(x)
x = x + CausalSelfAttention(self.config)(x)
x = nn.LayerNorm()(x)
x = x + MLP(self.config)(x)
return x
Ahora usemos los bloques anteriores para construir el NanoGPT:
Dadas las entradas de una secuencia de identificadores de token, utilizamos el flax.linen.Embed capa para obtener incrustaciones de posición e incrustaciones de token. Luego las pasamos al módulo Bloque N veces, donde N es el número de capas definidas en la configuración del modelo. Al final, asignamos las salidas del último Bloque a las probabilidades de cada token en el vocabulario para predecir el siguiente token. Además del avance __call__ método, también vamos a crear un init métodos para obtener las entradas ficticias para obtener los parámetros del modelo.
class GPT(nn.Module):config: ModelConfig
@nn.compact
def __call__(self, x, deterministic=False):
B, T = x.shape
assert T <= self.config.block_size
pos = jnp.arange(0, T)[None]
pos_emb = nn.Embed(self.config.block_size, self.config.n_embd)(pos)
wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
tok_emb = wte(x)
x = tok_emb + pos_emb
for _ in range(self.config.n_layer):
x = Block(self.config)(x)
x = nn.LayerNorm()(x)
logits = nn.Dense(config.n_embd, config.vocab_size)
# logits = wte.attend(x) # parameter sharing
return logits
def init(self, rng):
tokens = jnp.zeros((1, self.config.block_size), dtype=jnp.uint16)
params = jax.jit(super().init, static_argnums=(2,))(rng, tokens, True)
return params
Ahora, vamos a variar la cantidad de parámetros: primero inicializamos la clase de datos de configuración del modelo y la clave aleatoria, luego creamos entradas ficticias y las introducimos en el modelo GPT. Luego, utilizamos jax.util.treemap API para crear una función de parámetro de conteo. Obtuvimos 124439808 (124M) parámetros, la misma cantidad que GPT2 de Huggingface, ¡BOOM!
Cargador de datos y bucle de entrenamiento
Ahora vamos a sobreajustar un conjunto de datos pequeño. Para hacerlo comparable en el video de Andrej sobre Pytorch NanoGPT, usemos el juguete conjunto de datos que compartió en su video. Usamos el tokenizador GPT2 de tiktoken biblioteca para tokenizar todos los textos del archivo de entrada y convertir los tokens en jax.numpy.array para el entrenamiento modelo de Jax.
class DataLoader:
def __init__(self, B, T):
self.current_position = 0
self.B = B
self.T = Twith open("input.txt","r") as f:
text = f.read()
enc = tiktoken.get_encoding("gpt2")
self.tokens = jnp.array(enc.encode(text))
print(f"loaded {len(self.tokens)} tokens in the datasets" )
print(f" 1 epoch = {len(self.tokens)//(B*T)} batches")
def next_batch(self):
B,T = self.B, self.T
buf = self.tokens[self.current_position:self.current_position+B*T+1]
x,y = jnp.reshape(buf[:-1],(B,T)), jnp.reshape(buf[1:],(B,T))
self.current_position += B*T
if self.current_position + B*T+1 > len(self.tokens):
self.current_position = 0
return x,y
A continuación, olvidemos primero el entrenamiento distribuido y la optimización y simplemente creemos un bucle de entrenamiento simple para una comprobación de la coherencia. Lo primero que hay que hacer después de inicializar el modelo es crear un TrenEstadoun estado del modelo en el que podemos actualizar los parámetros y gradientes. TrainState toma tres entradas importantes: apply_fn (función de avance del modelo), params (parámetros del modelo del método init) y tx (una transformación de gradiente de Optax).
Luego usamos la función train_step para actualizar el estado del modelo (gradientes y parámetros) para continuar con el entrenamiento del modelo. Optax Proporcionar la entropía cruzada softmax como función de pérdida para la siguiente tarea de predicción de tokens, y jax.value_and_grad Calcula los gradientes y el valor de pérdida para la función de pérdida. Finalmente, actualizamos el estado del modelo con los nuevos parámetros utilizando el apply_gradients Asignación de funciones. [ref] ¡No olvides modificar la función train_step para reducir la sobrecarga de cálculo!
def init_train_state(key, config) -> TrainState:
model = GPT(config)
params = model.init(key)
optimizer = optax.adamw(3e-4, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1)
train_state = TrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer)
return train_state@jax.jit
def train_step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, TrainState]:
def loss_fn(params: FrozenDict) -> jnp.ndarray:
logits = state.apply_fn(params, x, False)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss
loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state.params)
new_state = state.apply_gradients(grads=grads)
return loss, new_state
Ahora todo está listo para el bucle de entrenamiento del pobre. Vamos a comprobar el valor de la pérdida. La predicción del modelo debería ser mejor que la estimación aleatoria, por lo que la pérdida debería ser inferior a -ln(1/50257)≈10,825. Lo que esperamos del sobreajuste de un único lote es que: al principio la pérdida sea cercana a 10,825, luego baje hasta cerca de 0. Tomemos un lote de (x, y) y ejecutemos el bucle de entrenamiento 50 veces. También agrego un logaritmo similar para calcular la velocidad de entrenamiento.
Como podemos ver, el valor de pérdida es exactamente el que esperábamos y el rendimiento de entrenamiento es de alrededor de 400–500 k token/seg. Lo cual ya es 40 veces más rápido que la versión inicial de Pytorch sin ninguna optimización en el video de Andrej. Tenga en cuenta que ejecutamos los scripts de Jax en 1 GPU A100, lo que debería eliminar la diferencia de hardware para la comparación de velocidad. No hay .to(device) cosas para mover tu modelo o datos desde la CPU host a la GPU del dispositivo, ¡lo cual es uno de los beneficios de Jax!
Así que eso es todo y lo logramos. Haremos que el entrenamiento sea 10 veces más rápido en la Parte 2 con más optimizaciones…
Parte 2:¡El viaje de la optimización del entrenamiento a 1350k tokens/seg en una sola GPU!
“A menos que se indique lo contrario, todas las imágenes son del autor”