Maixtchup: crea tu propia mezcla de expertos con Mergekit

El ascenso de los MoE

Imagen del autor — Generada con DALL-E

Desde el lanzamiento de Mistral AI del Mixtral-8x7B, ha habido un renovado interés en el mezcla de modelos expertos (MoE). Esta arquitectura explota subredes expertas entre las cuales solo algunas de ellas son seleccionadas y activadas por una red de enrutador durante la inferencia.

Los MoE son tan simples y flexibles que es fácil crear un MoE personalizado. En Hugging Face Hub, ahora podemos encontrar varios LLM de tendencia que son MoE personalizados, como mlabonne/phixtral-4x2_8.

Sin embargo, la mayoría de ellos no son MoE tradicionales creados desde cero, simplemente utilizan una combinación de LLM ya ajustados como expertos. Su creación fue fácil con kit de fusión (Licencia LGPL-3.0). Por ejemplo, los LLM de Phixtral se han creado con mergekit combinando varios Modelos Phi-2.

En este artículo veremos cómo se creó Phixtral. Aplicaremos el mismo proceso para crear nuestra propia mezcla de expertos, Maixtchup, utilizando varios modelos de Mistral 7B.

Para comprender rápidamente la arquitectura de alto nivel de un modelo, me gusta imprimirlo. Por ejemplo, para mlabonne/phixtral-4x2_8 (licencia MIT):

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"mlabonne/phixtral-4x2_8",
torch_dtype="auto",
load_in_4bit=True,
trust_remote_code=True
)
print(model)

Imprime:

PhiForCausalLM(
(transformer): PhiModel(
(embd): Embedding(
(wte): Embedding(51200, 2560)
(drop): Dropout(p=0.0, inplace=False)
)
(h): ModuleList(
(0-31): 32 x ParallelBlock(
(ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
(resid_dropout): Dropout(p=0.1, inplace=False)
(mixer): MHA(
(rotary_emb): RotaryEmbedding()
(Wqkv): Linear4bit(in_features=2560, out_features=7680, bias=True)
(out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)
(inner_attn): SelfAttention(
(drop): Dropout(p=0.0, inplace=False)
)
(inner_cross_attn): CrossAttention(
(drop): Dropout(p=0.0, inplace=False)
)
)
(moe): MoE(
(mlp): ModuleList(…