0rq6zxeod8xcj2giy.jpeg

Un tutorial de ajuste fino de MLLM utilizando el modelo Mini-InternVL de bolsillo más nuevo

Foto por Maarten van den Heuvel en desempaquetar

El mundo de los modelos de lenguajes grandes (LLM) está en constante evolución y surgen rápidamente nuevos avances. Un área interesante es el desarrollo de LLM multimodales (MLLM), capaces de comprender e interactuar tanto con textos como con imágenes. Esto abre un mundo de posibilidades para tareas como comprensión de documentos, respuesta visual a preguntas y más.

Recientemente escribí una publicación general sobre uno de esos modelos que puedes consultar aquí:

Pero en este, exploraremos una combinación poderosa: el modelo InternVL y la técnica de ajuste fino QLoRA. Nos centraremos en cómo podemos personalizar fácilmente dichos modelos para cualquier caso de uso específico. Usaremos estas herramientas para crear un canal de comprensión de recibos que extraiga información clave como el nombre de la empresa, la dirección y el monto total de la compra con alta precisión.

Este proyecto tiene como objetivo desarrollar un sistema que pueda extraer con precisión información específica de recibos escaneados, utilizando las capacidades de InternVL. La tarea presenta un desafío único, que requiere no solo un procesamiento sólido del lenguaje natural (PLN), sino también la capacidad de interpretar el diseño visual de la imagen de entrada. Esto nos permitirá crear un canal único, de extremo a extremo, sin OCR, que demuestre una fuerte generalización en documentos complejos.

Para entrenar y evaluar nuestro modelo, usaremos el SROIE conjunto de datos. SROIE proporciona 1000 imágenes de recibos escaneadas, cada una anotada con entidades clave como:

  • Empresa: El nombre de la tienda o negocio.
  • Fecha: La fecha de compra.
  • Dirección: La dirección de la tienda.
  • Total: El importe total pagado.
Fuente: https://arxiv.org/pdf/2103.10213.pdf.

Evaluaremos el rendimiento de nuestro modelo utilizando una puntuación de similitud difusa, una métrica que mide la similitud entre entidades predichas y reales. Esta métrica va de 0 (resultados irrelevantes) a 100 (predicciones perfectas).

InternVL es una familia de LLM multimodales de OpenGVLab, diseñada para sobresalir en tareas que involucran imágenes y texto. Su arquitectura combina un modelo de visión (como InternViT) con un modelo de lenguaje (como InternLM2 o Phi-3). Nos centraremos en la variante Mini-InternVL-Chat-2B-V1–5, una versión más pequeña que es muy adecuada para ejecutarse en GPU de consumo.

Puntos fuertes de InternVL:

  • Eficiencia: Su tamaño compacto permite un entrenamiento e inferencia eficientes.
  • Precisión: A pesar de ser más pequeño, logra un desempeño competitivo en varios benchmarks.
  • Capacidades multimodales: combina a la perfección la comprensión de imágenes y texto.

Demostración: puede explorar una demostración en vivo de InternVL aquí.

Para mejorar aún más el rendimiento de nuestro modelo, usaremos QLoRA, que es una técnica de ajuste que reduce significativamente el consumo de memoria y al mismo tiempo preserva el rendimiento. Así es como funciona:

  1. Cuantización: el LLM previamente entrenado se cuantifica con una precisión de 4 bits, lo que reduce su huella de memoria.
  2. Adaptadores de bajo rango (LoRA): en lugar de modificar todos los parámetros del modelo previamente entrenado, LoRA agrega adaptadores pequeños y entrenables a la red. Estos adaptadores capturan información específica de la tarea sin requerir cambios en el modelo principal.
  3. Entrenamiento eficiente: la combinación de cuantificación y LoRA permite un ajuste eficiente incluso en GPU con memoria limitada.

Profundicemos en el código. Primero, evaluaremos el rendimiento básico de Mini-InternVL-Chat-2B-V1–5 sin ningún ajuste:

quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

model = InternVLChatModel.from_pretrained(
args.path,
device_map={"": 0},
quantization_config=quant_config if args.quant else None,
torch_dtype=torch.bfloat16,
)

tokenizer = InternLM2Tokenizer.from_pretrained(args.path)
# set the max number of tiles in `max_num`

model.eval()

pixel_values = (
load_image(image_base_path / "X51005255805.jpg", max_num=6)
.to(torch.bfloat16)
.cuda()
)

generation_config = dict(
num_beams=1,
max_new_tokens=512,
do_sample=False,
)

# single-round single-image conversation
question = (
"Extract the company, date, address and total in json format."
"Respond with a valid JSON only."
)
# print(model)
response = model.chat(tokenizer, pixel_values, question, generation_config)

print(response)

El resultado:

```json
{
"company": "SAM SAM TRADING CO",
"date": "Fri, 29-12-2017",
"address": "67, JLN MENHAW 25/63 TNN SRI HUDA, 40400 SHAH ALAM",
"total": "RM 14.10"
}
```

Este código:

  1. Carga el modelo desde el centro Hugging Face.
  2. Carga una imagen de recibo de muestra y la convierte en un tensor.
  3. Formula una pregunta pidiendo al modelo que extraiga información relevante de la imagen.
  4. Ejecuta el modelo y genera la información extraída en formato JSON.

Esta evaluación de tiro cero muestra resultados impresionantes, logrando una puntuación promedio de similitud difusa de 74,24%. Esto demuestra la capacidad de InternVL para comprender recibos y extraer información sin ajustes.

Para aumentar aún más la precisión, ajustaremos el modelo utilizando QLoRA. Así es como lo implementamos:

_data = load_data(args.data_path, fold="train")

# Quantization Config
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

model = InternVLChatModel.from_pretrained(
path,
device_map={"": 0},
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)

tokenizer = InternLM2Tokenizer.from_pretrained(path)

# set the max number of tiles in `max_num`
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id

model.config.llm_config.use_cache = False

model = wrap_lora(model, r=128, lora_alpha=256)

training_data = SFTDataset(
data=_data, template=model.config.template, tokenizer=tokenizer
)

collator = CustomDataCollator(pad_token=tokenizer.pad_token_id, ignore_index=-100)

img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
print("img_context_token_id", img_context_token_id)
model.img_context_token_id = img_context_token_id
print("model.img_context_token_id", model.img_context_token_id)

train_params = TrainingArguments(
output_dir=str(BASE_PATH / "results_modified"),
num_train_epochs=EPOCHS,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
optim="paged_adamw_32bit",
save_steps=len(training_data) // 10,
logging_steps=len(training_data) // 50,
learning_rate=5e-4,
lr_scheduler_type="cosine",
warmup_steps=100,
weight_decay=0.001,
max_steps=-1,
group_by_length=False,
max_grad_norm=1.0,
)
# Trainer
fine_tuning = SFTTrainer(
model=model,
train_dataset=training_data,
dataset_text_field="###",
tokenizer=tokenizer,
args=train_params,
data_collator=collator,
max_seq_length=tokenizer.model_max_length,
)

print(fine_tuning.model.print_trainable_parameters())
# Training
fine_tuning.train()
# Save Model
fine_tuning.model.save_pretrained(refined_model)

Este código:

  1. Carga el modelo con la cuantificación habilitada.
  2. Envuelve el modelo con LoRA, agregando adaptadores entrenables.
  3. Crea un conjunto de datos a partir del conjunto de datos SROIE.
  4. Define argumentos de entrenamiento como la tasa de aprendizaje, el tamaño del lote y las épocas.
  5. Inicializa un entrenador para manejar el proceso de capacitación.
  6. Entrena el modelo en el conjunto de datos SROIE.
  7. Guarda el modelo ajustado.

Aquí hay una comparación de muestra entre el modelo base y el modelo ajustado QLoRA:

Ground Truth: 

{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4,JALAN PERJIRANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR.",
"total": "72.00"
}

Prediction Base: KO

```json
{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2016",
"address": "JM092487-D",
"total": "67.92"
}
```

Prediction QLoRA: OK

{
"company": "YONG TAT HARDWARE TRADING",
"date": "13/03/2018",
"address": "NO 4, JALAN PERUBANAN 10, TAMAN AIR BIRU, 81700 PASIR GUDANG, JOHOR",
"total": "72.00"
}

Después de realizar ajustes con QLoRA, nuestro modelo logra un notable 95,4% puntuación de similitud difusa, una mejora significativa con respecto al rendimiento inicial (74,24%). Esto demuestra el poder de QLoRA para aumentar la precisión del modelo sin requerir recursos informáticos masivos (entrenamiento de 15 minutos en 600 muestras en una GPU RTX 3080).

Hemos creado con éxito un sistema sólido de comprensión de recibos utilizando InternVL y QLoRA. Este enfoque muestra el potencial de los LLM multimodales para tareas del mundo real como el análisis de documentos y la extracción de información. En este caso de uso de ejemplo, obtuvimos 30 puntos en la calidad de la predicción utilizando unos cientos de ejemplos y unos pocos minutos de tiempo de cálculo en una GPU de consumo.

Puede encontrar la implementación del código completo para este proyecto. aquí.

El desarrollo de LLM multimodales apenas está comenzando y el futuro presenta posibilidades interesantes. El área del procesamiento automatizado de documentos tiene un inmenso potencial en la era de los MLLM. Estos modelos pueden revolucionar la forma en que extraemos información de contratos, facturas y otros documentos, y requieren datos de capacitación mínimos. Al integrar texto y visión, pueden analizar el diseño de documentos complejos con una precisión sin precedentes, allanando el camino para una gestión de la información más eficiente e inteligente.

El futuro de la IA es multimodal, e InternVL y QLoRA son herramientas poderosas que nos ayudan a desbloquear su potencial con un presupuesto informático pequeño.

Enlaces:

Código: https://github.com/CVxTz/doc-llm

Fuente del conjunto de datos: https://rrc.cvc.uab.es/?ch=13&com=introducción
Licencia de conjunto de datos: licenciada bajo un Licencia Creative Commons Atribución 4.0 Internacional.