En esta publicación, hablo sobre la motivación, las complejidades y los detalles de implementación de la construcción de Torchvista, un paquete de código abierto para visualizar interactivamente el pase hacia adelante de cualquier modelo de Pytorch desde los cuadernos basados en la web.
Para tener una idea del funcionamiento de Torchvista mientras lee esta publicación, puede consultar:
- Página de Github Si desea instalarlo a través de
pipy úselo en cuadernos basados en la web (Jupyter, Colab, Kaggle, VScode, etc.) - Un interactivo página de demostración con varios modelos conocidos visualizados
- A Google Colab tutorial
- Una demostración de video:
Motivación
Los modelos de Pytorch pueden ser muy grandes y complejos, y dar sentido al uno solo del código puede ser un ejercicio agotador e incluso intratable. Tener una visualización similar al gráfico es justo lo que necesitamos para facilitar esto.
Si bien existen herramientas como Netron, Pytorchviz y Torchview que lo hacen más fácil, mi motivación para construir Torchvista fue que descubrí que carecían de algunos o todos estos requisitos:
- Soporte de interacción: El gráfico visualizado debe ser interactivo y no una imagen estática. Debe ser una estructura que pueda hacer zoom, arrastrar, expandir/colapsar, etc. Los modelos pueden ser muy grandes, y si todo lo que ve es una imagen estática gigantesca del gráfico, ¿cómo puede realmente explorarlo?
- Exploración modular: Los modelos de Pytorch grandes son modulares en pensamiento e implementación. Por ejemplo, piense en un módulo que tiene un
Sequentialmódulo que contiene algunosAttentionbloques, que a su vez cada uno tiene bloques completamente conectados que contienenLinearCapas con funciones de activación, etc. La herramienta debe permitirle aprovechar esta estructura modular, y no solo presentar un gráfico de enlace tensor de bajo nivel.
- Soporte de cuaderno: Tendemos a prototipos y construyendo nuestros modelos en cuadernos. Si se proporcionó una herramienta como una aplicación independiente que requería que construyera su modelo y lo cargara para visualizarlo, es un bucle de retroalimentación demasiado largo. Por lo tanto, la herramienta tiene que trabajar idealmente desde los cuadernos.
- Soporte de depuración de errores: Mientras construyendo modelos desde cero, a menudo nos encontramos con muchos errores hasta que el modelo pueda ejecutar un pase completo de adelante a extremo. Por lo tanto, la herramienta de visualización debe ser tolerante al error y mostrarle un gráfico de visualización parcial incluso si hay errores, para que pueda depurar el error.
torch.cat Falló debido a formas tensoras no coincidentes- Rastreo de pases hacia adelante: Pytorch expone de forma nativa un gráfico de pase hacia atrás a través de su sistema Autograd, que el paquete pytorchviz Expone como un gráfico, pero esto es diferente del pase hacia adelante. Cuando construimos, estudiamos e imaginamos modelos, pensamos más sobre el pase hacia adelante, y esto puede ser muy útil para visualizar.
Edificio Torchvista
API básica
El objetivo era tener una API simple que funcione con casi cualquier modelo de Pytorch.
import torch
from transformers import XLNetModel
from torchvista import trace_model
model = XLNetModel.from_pretrained("xlnet-base-cased")
example_input = torch.randint(0, 32000, (1, 10))
# Trace it!
trace_model(model, example_input)
Con una línea de llamadas de código trace_model(<model_instance>, <input>) Debería producir una visualización interactiva del pase hacia adelante.
Pasos involucrados
Detrás de escena, Torchvista, cuando se llama, funciona en dos fases:
- Rastreo: Aquí es donde Torchvista extrae una estructura de datos de gráficos del pase hacia adelante del modelo. Pytorch no expone inherentemente esta estructura de gráficos (a pesar de que expone un gráfico para el pase hacia atrás), por lo que Torchvista tiene que construir esta estructura de datos por sí misma.
- Visualización: Una vez que se extrae el gráfico, Torchvista tiene que producir la visualización real como un gráfico interactivo. El trazador de Torchvista hace esto cargando un archivo HTML de plantilla (con JS incrustado dentro de él) e inyectando objetos de estructura de datos gráficos serializados como cadenas en la plantilla a ser cargadas posteriormente por el motor del navegador.
Rastreo
El rastreo se realiza esencialmente envolviendo (temporalmente) todas las operaciones tensoras importantes y conocidas, y los módulos de pytorch estándar. El objetivo de envolver es modificar las funciones para que cuando se les llame, también hagan la contabilidad necesaria para el rastreo.
Estructura del gráfico
El gráfico que extraemos del modelo es un gráfico dirigido donde:
- Los nodos son las diversas operaciones tensoras y los diversos módulos de Pytorch incorporados que se llaman durante el pase hacia adelante
- Además, los tensores de entrada y salida, y los tensores valorados constantes también son nodos en el gráfico.
- Existe una ventaja de un nodo a otro por cada tensor enviado desde el primero hasta el segundo.
- La etiqueta de borde es la dimensión del tensor asociado.
Pero, la estructura de nuestro gráfico puede ser más complicada porque la mayoría de los módulos de Pytorch llaman operaciones tensoras y, a veces, otros módulos ‘ forward método. Esto significa que tenemos que mantener una estructura gráfica que contenga información para explorarla visualmente en cualquier nivel de profundidad.
Por lo tanto, la estructura que los extractos de Torchvista incluyen dos estructuras de datos principales:
- Lista de adyacencia de las operaciones/módulos de nivel más bajo que se llaman.
input_0 -> [ linear ]
linear -> [ __add__ ]
__getitem__ -> [ __add__ ]
__add__ -> [ multi_head_attention_forward ]
multi_head_attention_forward -> [ dropout ]
dropout -> [ __add__ ]
- Mapa de jerarquía que mapea cada nodo a su contenedor de módulo principal (si está presente)
linear -> Linear
multi_head_attention_forward -> MultiheadAttention
MultiheadAttention -> TransformerEncoderLayer
TransformerEncoderLayer -> TransformerEncoder
Con ambos, podemos construir cualquier visión deseada del pase hacia adelante en la capa de visualización.
Operaciones y módulos de envoltura
Toda la idea detrás de la envoltura es hacer una contabilidad antes y después de la operación real, de modo que cuando se llama la operación, se llama a nuestra función envuelta y se lleva a cabo la contabilidad. Los objetivos de la contabilidad son:
- Registro de conexiones entre nodos basados en referencias de tensor.
- Registre las dimensiones del tensor para mostrar como etiquetas de borde.
- Jerarquía de módulos de registro para módulos en el caso en que los módulos se aniden entre sí
Aquí hay un fragmento de código simplificado de cómo funciona la envoltura:
original_operations = {}
def wrap_operation(module, operation):
original_operations[get_hashable_key(module, operation)] = operation
def wrapped_operation(*args, **kwargs):
# Do the necessary pre-call bookkeeping
do_pre_call_bookkeeping()
# Call the original operation
result = operation(*args, **kwargs)
do_post_call_bookkeeping()
return result
setattr(module, func_name, wrapped_operation)
for module, operation in LONG_LIST_OF_PYTORCH_OPS:
wrap_operation(module, operation)
Y cuando Trace_model está a punto de completar, debemos restablecer todo a su estado original:
for module, operation in LONG_LIST_OF_PYTORCH_OPS:
setattr(module, func_name, original_operations[get_hashable_key(module,
operation)])
Esto se hace de la misma manera para el forward() Métodos de módulos de pytorch incorporados como Linear, Conv2d etc.
Conexiones entre nodos
Como se indicó anteriormente, existe un borde entre dos nodos si se envió un tensor de uno a otro. Esto forma la base de crear conexiones entre nodos mientras construye el gráfico.
Aquí hay un fragmento de código simplificado de cómo funciona esto:
adj_list = {}
def do_post_call_bookkeeping(module, operation, tensor_output):
# Set a "marker" on the output tensor so that whoever consumes it
# knows which operation produced it
tensor_output._source_node = get_hashable_key(module, operation)
def do_pre_call_bookkeeping(module, operation, tensor_input):
source_node = tensor_input._source_node
# Add a link from the producer of the tensor to this node (the consumer)
adj_list[source_node].append(get_hashable_key(module, operation))
Mapa de jerarquía de módulos
Cuando envolvemos módulos, las cosas deben hacerse de manera un poco diferente para construir el mapa de la jerarquía del módulo. La idea es mantener una pila de módulos que se llaman actualmente para que la parte superior de la pila siempre represente en el padre inmediato en el mapa de la jerarquía.
Aquí hay un fragmento de código simplificado de cómo funciona esto:
hierarchy_map = {}
module_call_stack = []
def do_pre_call_bookkeeping_for_module(package, module, tensor_output):
# Add it to the stack
module_call_stack.append(get_hashable_key(package, module))
def do_post_call_bookkeeping_for_module(module, operation, tensor_input):
module_call_stack.pop()
# Top of the stack now is the parent node
hierarchy_map[get_hashable_key(package, module)] = module_call_stack[-1]
Visualización
Esta parte se maneja por completo en Javscript porque la visualización ocurre en los cuadernos basados en la web. Las bibliotecas clave que se usan aquí son:
- GraphViz: para generar el diseño para el gráfico (VIZ-JS es el puerto js)
- D3: para dibujar el gráfico interactivo en un lienzo
- ipython: para representar el contenido de HTML dentro de un cuaderno
Diseño de gráfico
Obtener el diseño para el gráfico correcto es un problema extremadamente complejo. El objetivo principal es que el gráfico tenga un “flujo” de los bordes de arriba a abajo, y lo más importante, para que no haya una superposición entre los diversos nodos, bordes y etiquetas de borde.
Esto se hace aún más complejo cuando trabajamos con un gráfico “jerárquico” donde hay cuadros de “contenedor” para módulos dentro de los cuales se muestran los nodos y subcomponentes subyacentes.
Afortunadamente, GraphViz (VIZ-JS) viene al rescate por nosotros. GraphViz usa un idioma llamado “Lenguaje de puntos“A través del cual especificamos cómo requerimos que se construya el diseño del gráfico.
Aquí hay una muestra de la sintaxis del punto para el gráfico anterior:
# Edges and nodes
"input_0" [width=1.2, height=0.5];
"output_0" [width=1.2, height=0.5];
"input_0" -> "linear_1"[label="(1, 16)", fontsize="10", edge_data_id="5623840688" ];
"linear_1" -> "layer_norm_1"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
"linear_1" -> "layer_norm_2"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
...
# Module hierarchy specified using clusters
subgraph cluster_FeatureEncoder_1 {
label="FeatureEncoder_1";
style=rounded;
subgraph cluster_MiddleBlock_1 {
label="MiddleBlock_1";
style=rounded;
subgraph cluster_InnerBlock_1 {
label="InnerBlock_1";
style=rounded;
subgraph cluster_LayerNorm_1 {
label="LayerNorm_1";
style=rounded;
"layer_norm_1";
}
subgraph cluster_TinyBranch_1 {
label="TinyBranch_1";
style=rounded;
subgraph cluster_MicroBranch_1 {
label="MicroBranch_1";
style=rounded;
subgraph cluster_Linear_2 {
label="Linear_2";
style=rounded;
"linear_2";
}
...
Una vez que esta representación de puntos se genera a partir de nuestra lista de adyacencia y mapa de jerarquía, GraphViz produce un diseño con posiciones y tamaños de todos los nodos y rutas para los bordes.
Representación
Una vez que se genera el diseño, D3 se usa para representar el gráfico visualmente. Todo se dibuja en un lienzo (que es fácil de hacer arrastrable y amplio), y establecemos varios controladores de eventos para detectar los clics de los usuarios.
Cuando el usuario realiza estos dos tipos de clics de expansión/colapso en los módulos (usando los botones ‘+’ ‘-‘), los registros de Torchvista en qué nodo se realizó la acción, y simplemente vuelve a renderizar el gráfico porque el diseño debe reconstruirse, y luego arrastra y se acerca automáticamente a un nivel apropiado basado en la posición registrada previa al click.
Renderizar un gráfico con D3 es un tema muy detallado y de lo contrario no es exclusivo de Torchvista, y por lo tanto, dejo los detalles de esta publicación.
[Bonus] Manejo de errores en modelos Pytorch
Cuando los usuarios rastrean sus modelos Pytorch (especialmente mientras desarrollan los modelos), a veces los modelos arrojan errores. Hubiera sido fácil para Torchvista simplemente rendirse cuando esto sucede y dejar que el usuario corrija el error primero antes de poder usar Torchvista. Pero Torchvista, en cambio, le da una mano al depurar estos errores haciendo el rastreo de mejor esfuerzo del modelo. La idea es simple: solo rastree el máximo que puede hasta que ocurra el error, y luego rinde el gráfico con tanto (con indicadores visuales que muestran dónde ocurrió el error), y luego solo aumente la excepción para que el usuario también pueda ver la StackTrace como lo haría normalmente.
Aquí hay un fragmento de código simplificado de cómo funciona esto:
def trace_model(...):
exception = None
try:
# All the tracing code
except Exception as e:
exception = e
finally:
# do all the necessary cleanups (unwrapping all the operations and modules)
if exception is not None:
raise exception
Concluir
Esta publicación arroja algo de luz sobre el viaje de construir un paquete de visualización de Pytorch. Primero hablamos sobre la motivación muy específica para construir dicha herramienta comparando con otras herramientas similares. Luego, discutimos el diseño e implementación de Torchvista en dos partes. La primera parte fue sobre el proceso de rastrear el pase hacia adelante de un modelo de Pytorch utilizando envoltura (temporal) de operaciones y módulos para extraer información detallada sobre el pase hacia adelante del modelo, incluidas no solo las conexiones entre varias operaciones, sino también la jerarquía del módulo. Luego, en la segunda parte, repasamos la capa de visualización y las complejidades de la generación de diseño, que se resolvieron utilizando la elección correcta de las bibliotecas.
Torchvista es de código abierto, y todas las contribuciones, incluidos los comentarios, los problemas y las solicitudes de extracción, son bienvenidas. Espero que Torchvista ayude a las personas de todos los niveles de especialización en la construcción y visualización de sus modelos (independientemente del tamaño del modelo), mostrando su trabajo y como una herramienta para educar a otros sobre modelos de aprendizaje automático.
Direcciones futuras
Las potenciales mejoras futuras para Torchvista incluyen:
- Agregar soporte para “rodar”, donde si la misma subestructura de un modelo se repite varias veces, se muestra solo una vez con un recuento de cuántas veces se repite
- Exploración sistemática de modelos de vanguardia para garantizar que todas sus operaciones tensoras estén adecuadamente cubiertas
- Soporte para exportar imágenes estáticas de modelos como archivos PNG o PDF
- Mejoras de eficiencia y velocidad
Referencias
- Bibliotecas de código abierto utilizadas:
- Lenguaje de puntos De GraphViz
- Otras herramientas de visualización similar:
- Torchvista:
Todas las imágenes a menos que se indique lo contrario el autor sea de lo contrario.