0pv9sbx7qza1mvhfi.jpeg

Optimización del uso de aceleradores de entrenamiento de IA limitados: Parte 2

Foto por Adrien Aletti en desempaquetar

Esta publicación fue creada en colaboración con Max Rabin.

Esta es la segunda parte de una serie de publicaciones sobre el tema de maximizar la utilidad de los escasos recursos de IA. En el Primer comentario Observamos las crecientes limitaciones en la capacidad de ampliar los recursos de IA a voluntad y, como consecuencia, la creciente tendencia de los equipos de desarrollo de IA a garantizar la capacidad informática de la IA por medios como la creación de una granja de servidores de IA interna y/o la reserva de recursos. Instancias dedicadas en la nube. La escasez de recursos informáticos de IA motiva el diseño de soluciones de programación especializadas para minimizar el tiempo de inactividad y priorizar las cargas de trabajo críticas. Por favor vea nuestro Publicación anterior en el que propusimos una lista detallada de requisitos para tales soluciones. El enfoque que adoptamos fue aprovechar el sistema basado en prioridades existente. planificador eso viene con Kubernetes y alinear nuestro flujo de trabajo de desarrollo de capacitación con su uso. En esta publicación exploramos la opción de mantener nuestro marco existente para entrenar modelos de IA y mejorarlo con nuestra propia implementación personalizada de un programador basado en prioridades. Es importante destacar que la necesidad de este tipo de solución a menudo está motivada no solo por la escasez de recursos de IA, sino también por el deseo de aumentar el control sobre la orquestación y priorización de las cargas de trabajo de capacitación para reducir los costos de desarrollo. Por ejemplo, incluso en un escenario de capacidad abundante, puede optar por limitar su uso a un número fijo de instancias de capacitación para limitar su gasto en capacitación.

Para los fines de esta publicación, asumiremos que nuestro marco de capacitación preferido es el servicio administrado de AWS para la capacitación de modelos de IA. Amazon SageMaker. La solución que propondremos utilizará servicios adicionales de AWS, como AmazonDynamoDB y AWS Lambda. La elección de demostrar nuestra solución utilizando los servicios de AWS no debe considerarse como un respaldo. Hay muchas ofertas de servicios basados ​​en la nube disponibles y la mejor para usted dependerá de los detalles particulares de su proyecto. Se pueden diseñar soluciones similares a la que describiremos en otros entornos basados ​​en la nube y/o utilizar servicios alternativos basados ​​en la nube.

Tradicionalmente, iniciaríamos un trabajo de capacitación de SageMaker utilizando el SDK de Python de Amazon SageMaker. En el bloque de código siguiente utilizamos SageMaker SDK (versión 2.208) para ejecutar una carga de trabajo de entrenamiento de PyTorch en una única instancia de tipo. p5.48xgrande.

desde sagemaker.pytorch importar PyTorch

# definir trabajo
estimador = PyTorch(
role='',
punto_entrada='tren.py',
instancia_type='ml.p5.48xlarge',
recuento_instancia=1,
framework_version='2.0.1',
py_version='py310',
etiquetas =[{'Key': 'priority', 'Value': '100'}
)

# start job
estimator.fit()

When the estimator.fit() function is called, the SageMaker library uploads our code to Amazon S3 and then transforms the request to a boto3 SageMaker client create_training_job request (see here).

This method for starting up training jobs is dependent on the availability of the requested resources for its success. In our scenario of scarce AI resources, it is likely to fail more often than not. Although this can be partially mitigated by retaining provisioned compute instances for successive workloads, the API does not provide the appropriate tooling for maximizing their utility. Let’s suppose that we wish to utilize precisely two p5.48xlarge instances. To simplify our discussion, let’s assume that each training workload runs on a single instance. Typically, during an AI model development cycle there will be periods when there are more than two training workloads that are waiting to be processed. The existing API would try to start up a third p5.48xlarge instance and would most likely fail due to its limited availability. Even when there is instance availability, we may wish to limit our training to just our two designated instances to increase our control over the costs of training.

We require a new API for submitting jobs for training, one that does not immediately start up a new p5.48xlarge instance, but rather enters the jobs to a priority queue. And we need an associated job scheduler that manages the use of our two resources while prioritizing critical workloads.

Importantly, please note that as of the time of this writing, Amazon SageMaker does not support the option of training on reserved Amazon EC2 instances. And although Amazon SageMaker Savings Plans has similar properties to instance reservations, it does not guarantee instance capacity. In a previous post we addressed this limitation and proposed using SageMaker managed warm pools as an alternative method for retaining access to provisioned instances. For the remainder of the post, we will assume that we are able to attain two instances of our choice whether it be through this or some other method.

In this section we will describe the components of our proposed solution. We will use the AWS Serverless Application Model (SAM) specification. More specifically, we will create an AWS SAM template YAML file and gradually add the AWS resources that we need. Please see the documentation for details on how to define and deploy serverless solutions using AWS SAM.

AWS Architecture Diagram (by Author)

We start by using Amazon API Gateway to define a private REST API for submitting training job requests. We name the API training-job-queue. Later, we will add a POST method called add-job and modify our training-job creation code to use this method instead of the SageMaker client create_training_job API. The code block below contains the definition of the private API resource in SAM. In practice you will likely want to specify access limitations to the API and/or a method of authorization.

AWSTemplateFormatVersion: '2010-09-09'
Transform: AWS::Serverless-2016-10-31

Resources:
InternalAPI:
Type: AWS::Serverless::Api
# Auth: # Add access control to API
EndpointConfiguration:
Type: PRIVATE
# VPCEndpointIds: # Specify VPC Endpoint(s)
Name: training-job-queue
StageName: prod

Define an AWS DynamoDB Table for Storing Training Job Requests

We will use an Amazon DynamoDB table named sagemaker-queue to store the submitted training workloads. Each entry will have the following fields:

  1. jobName: Stores the unique name of the training job.
  2. entryTime: Stores the date and time that the job was added.
  3. jobState: Stores the current state of the training job. The valid values are ‘pending’, ‘running’, and ‘preempted’.
  4. priority: Stores an integer value representing the relative priority of the job.
  5. jobDetails: Stores the details of the job request.

We define our DynamoDB table in our SAM template YAML file using the AWS::Serverless::SimpleTable resource.

  DynamoSMQueue:
Type: AWS::Serverless::SimpleTable
Properties:
PrimaryKey:
Name: jobName
Type: String
TableName: sagemaker-queue

We define a function that creates a table entry from a given training job request. We assume that request contains the same contents as the input to the create_training_job API in JSON format. We further assume that the priority of the workload is entered as a key-value tag in the training job definition.

import json, boto3, datetime

dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table('sagemaker-queue')

def add_job_entry(job_json):
job_details = json.loads(job_json)

# extract job_name
job_name = job_details['TrainingJobName']
print(f'agregar entrada {nombre_trabajo}')

# obtener la hora actual
tiempo_entrada = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

# la prioridad predeterminada es 0
prioridad = 0

# prioridad de actualización basada en etiquetas
etiquetas = detalles_trabajo['Tags']
para etiqueta en etiquetas:
si etiqueta['Key'] == 'prioridad':
prioridad = int(etiqueta['Value'])
romper

# crear entrada
entrada = {
'nombre_trabajo': nombre_trabajo,
'hora_entrada': hora_entrada,
'jobState': 'pendiente',
'prioridad': prioridad,
'Detalles del trabajo': job_json
}
table.put_item(Item=entry) #TODO manejar errores
print(f'Trabajo agregado {nombre_trabajo} a la cola')

La API REST agregar trabajo El método que pronto definiremos será programado para llamar al agregar_entrada_trabajo función.

Definimos una segunda función que extrae los trabajos pendientes de la base de datos y los devuelve en orden de prioridad. En el caso de que varios trabajos tengan la misma prioridad, se ordenan según el tiempo que llevan esperando en la cola.

from boto3.dynamodb.conditions import Attr

# Get a list of all pending jobs sorted by priority
def get_pending_jobs():
response = table.scan(
ProjectionExpression='jobName, priority, entryTime',
FilterExpression=Attr('jobState').ne('running')
)
jobs = response.get('Items', [])

# sort jobs, first by priority (descending) and then by entryTime
sorted_jobs = sorted(jobs,
key=lambda x: (-x['priority'], x['entryTime']))

return sorted_jobs

Las siguientes funciones de utilidad serán útiles en las siguientes secciones.

# Get a jobName -> priority mapping of all running jobs
def get_running_jobs_dict():
# Get all running jobs
response = table.scan(
ProjectionExpression="jobName, priority",
FilterExpression=Attr('jobState').eq('running')
)
jobs = response.get('Items', [])

running_jobs = {job['jobName']: job['priority'] for job in jobs}

return running_jobs

# Print the queue state
def print_queue_state():
response = table.scan(
ProjectionExpression='jobName, jobState, priority'
)
jobs = response.get('Items', [])

print_table = []
for job in jobs:
print_table.append([job['jobName'], job['jobState'], job['priority']])

# sort by priority
sorted_table = sorted(print_table,
key=lambda x: -x[2])
# Print the table
from tabulate import tabulate
print(tabulate(sorted_table, headers=['Job Name', 'State', 'Priority']))

# get job details
def get_job_details(job_name):
response = table.get_item(
Key={'jobName': job_name},
ProjectionExpression='jobDetails'
)
return json.loads(response.get('Item').get('jobDetails'))

# get job state or None if the job does not exist
def get_job_state(job_name):
response = table.get_item(
Key={'jobName': job_name},
ProjectionExpression='jobState'
)
job = response.get('Item')
return job.get('jobState') if job else None

# update the job state
def update_job_state(job_name, new_state):
table.update_item(
Key={'jobName': job_name},
UpdateExpression="SET jobState = :new_state",
ExpressionAttributeValues={":new_state": new_state}
)
print(f'Update job {job_name} to {new_state}')

# remove a job entry
def remove_job(job_name):
table.delete_item(
Key={'jobName': job_name}
)
print(f'Removed job {job_name} from queue')

Tanto nuestra elección de DynamoDB como su uso (por ejemplo, nuestro uso del Escanear API en lugar de la Consulta API) suponen que el número total de trabajos en nuestra cola será de docenas, como máximo. Para una solución a mayor escala, puede que le convenga una base de datos más pesada (por ejemplo, una que realice la operación de clasificación por usted) o un uso más sofisticado de DynamoDB (por ejemplo, consulte aquí).

Definir el administrador de cola de trabajos de capacitación

El componente principal de nuestra solución es el programador de trabajos de capacitación. Aquí implementamos un administrador bastante simple que realiza los siguientes pasos:

  1. Extraiga la lista de trabajos en cola, ordenados por prioridad. Si no existe ninguno, regrese.
  2. Descubra la capacidad de instancia no utilizada. Por cada instancia gratuita, inicie un trabajo pendiente en SageMaker. Si no quedan trabajos después de eso, regresa.
  3. Calcule el número de trabajos de SageMaker en el Parada estado. Si es mayor que el número de trabajos pendientes, regrese.
  4. Evaluar la necesidad de adelantarse en la ejecución de trabajos de SageMaker comparando sus prioridades a los de nuestros trabajos pendientes.
# set the limit on total number of instances/jobs
MAX_CAPACITY = 2

sagemaker = boto3.client('sagemaker')

# apply a queue stamp to identify that the job came from the queue
def apply_qstamp(job_name):
return f'{job_name}-qstamp-{datetime.now().strftime("%d%H%M")}'

# strip the queue stamp
def strip_qstamp(job_name):
return job_name.split('-qstamp-')[0]

# start a SageMaker job and update job entry in queue
def start_job(job_name):
print(f'start job {job_name}')
job_details = get_job_details(job_name)
job_details['TrainingJobName'] = apply_qstamp(job_name)
if(job_details):
# start job with detail from queue
# (you may optinally overwrite fields such as the iam role)
response = sagemaker.create_training_job(**job_details)
if response['ResponseMetadata']['HTTPStatusCode'] == 200:
print(f'started job {job_name}')
update_job_state(job_name, 'running')

# preempt a SageMaker job and update job entry in queue
def preempt_job(job_name):
print(f'preempt job {job_name}')
response = sagemaker.stop_training_job(TrainingJobName=job_name)
if response['ResponseMetadata']['HTTPStatusCode'] == 200:
print(f'preempted job {job_name}')
update_job_state(strip_qstamp(job_name), 'preempted')

# get SageMaker jobs
def get_sagemaker_jobs(status):
running = sagemaker.list_training_jobs(StatusEquals=status)
return running.get('TrainingJobSummaries', [])

# queue manager
def manage_queue():
# extract pending jobs to run
pending = get_pending_jobs()

if not pending:
return

if len(pending) > MAX_CAPACITY:
pending = pending[:MAX_CAPACITY]

# get running sagemaker jobs
running = get_sagemaker_jobs('InProgress')
total_running = len(running)

# get stopping sagemaker jobs
stopping = get_sagemaker_jobs('Stopping')
total_stopping = len(stopping)

# calculate the number of free instances
free_slots = MAX_CAPACITY - total_running - total_stopping

jobs_to_start = min(len(pending), free_slots)

# for each free instance, start a job
for i in range(jobs_to_start):
start_job(pending[i].get('jobName'))

still_pending = pending[jobs_to_start:]

if not still_pending:
return

# assume that 'total_stopping' number of jobs will start soon
test_for_preemption = len(still_pending) - total_stopping
if test_for_preemption <= 0:
return

# check if preemption is required
test_priority = still_pending[total_stopping:]

running_jobs = get_running_jobs_dict()
priority_dict = {}
for job in running:
job_name = job['TrainingJobName']
priority_dict[job_name] = running_jobs[strip_qstamp(job_name)]

# sort running jobs from lowest to highest priority
sorted_running = sorted(priority_dict.items(), key=lambda item: item[1])

index = 0
while index < test_for_preemption and \
test_priority[index].get('priority') > sorted_running[index][1]:
preempt_job(sorted_running[index][0])
index = index + 1

Notas importantes:

  1. Nuestra implementación es altamente optimista en el sentido de que asumimos que todos los trabajos que se insertan son válidos y que podremos iniciarlos en SageMaker sin problemas. En la práctica, se debe agregar un manejo de errores adecuado (por ejemplo, eliminar trabajos defectuosos de la cola con el registro adecuado).
  2. En un entorno de producción, necesitaríamos tener en cuenta la probabilidad de que se produzca un condición de carrera cuando nuestra administrador_de_colas es desencadenado por múltiples eventos simultáneos. Hay varias maneras de abordar este problema (por ejemplo, ver aquí), incluida la aplicación de la atomicidad (por ejemplo, estableciendo nuestra Concurrencia de funciones lambda a uno), utilizando algún tipo de mecanismo de bloqueo (por ejemplo, como se hace aquí), o haciendo nuestra función idempotente. Aquí hemos adoptado el enfoque de lo que llamamos «idempotencia optimista», donde confiamos en el uso apropiado de la API y en la idempotencia de nuestras llamadas subyacentes a las API de SageMaker.
  3. Destacamos que nuestra implementación es ingenua. En la práctica, recomendamos un algoritmo más sofisticado que 1) tenga en cuenta el uso de diferentes tipos de instancias y trabajos que requieren más de una instancia, 2) tenga en cuenta todos los casos extremos y 3) se adapte a las necesidades específicas de su proyecto.

Definir la función AWS Lambda

El siguiente componente de la solución es la función Lambda. El siguiente bloque de código incluye el Sam definición de nuestra función sin servidor. Programamos la función para que se ejecute en dos tipos diferentes de eventos: cualquier llamada a agregar trabajo en nuestra puerta de enlace API privada y un cambiar el estado de un trabajo de capacitación de SageMaker.

  ManagedTrainingJobQueue:
Type: AWS::Serverless::Function
Properties:
CodeUri: job-queue/ # the directory containing our index.py file
Handler: index.lambda_handler
Runtime: python3.12
Architectures:
- arm64 # use graviton
Policies: # allow access to SageMaker and DynamoDB
- !Sub "arn:${AWS::Partition}:iam::aws:policy/AmazonSageMakerFullAccess"
- DynamoDBCrudPolicy:
TableName: !Ref DynamoSMQueue
Events:
CreateTraining:
Type: Api
Properties:
Path: /add-job
Method: post
RestApiId: !Ref InternalAPI
SageMakerEvent:
Type: EventBridgeRule
Properties:
Pattern:
source:
- aws.sagemaker
detail-type:
- SageMaker Training Job State Change
detail:
TrainingJobStatus:
- "Completed"
- "Failed"
- "Stopped"

El lambda_handler La función se implementa de la siguiente manera:

def lambda_handler(event, context):
# identify source of event and take appropriate action
if 'requestContext' in event and 'apiId' in event['requestContext']:
print('Lambda triggerred by API Gateway')
job_details = json.loads(event.get('body'))
add_job_entry(job_details)
elif 'source' in event and event['source'] == 'aws.sagemaker':
print('Lambda triggerred by SageMaker job state change')
job_name = event['detail']['TrainingJobName']
job_status = event['detail']['TrainingJobStatus']
print(f'{job_name} status changed to {job_status}')

# strip qstamp from job_name
job_name = strip_qstamp(job_name)

if job_status in ['Completed' , 'Failed']:
remove_job(job_name)
elif job_status == 'Stopped':
# check if it was manually stopped or preempted by queue manager
if get_job_state(job_name) == 'preempted':
print(f'job {job_name} preemption completed')
else:
print(f'job {job_name} {job_status}, remove from queue')
remove_job(job_name)

# in all cases invoke queue manager
manage_queue()

Interceptar la solicitud de creación de trabajo de capacitación

La modificación final necesaria para completar nuestra solución es interceptar la llamada a SageMaker. crear_trabajo_de_formación API y redirigirlo a nuestro agregar trabajo método. Hacemos esto anulando el _intercept_create_request función de la Clase de sesión de SageMaker:

from sagemaker.pytorch import PyTorch
from sagemaker.session import Session
import requests, logging
logger = logging.getLogger('sagemaker')

def submit_to_training_queue(job):
logger.info(f'Adding training-job {job['TrainingJobName']} to queue')
logger.debug('train request: {json.dumps(job, indent=4)}')

vpce='<vpc endpoint>' # insert id of vpc endpoint
region='us-east-1' # specify region
url=f'https://{vpce}.execute-api.{region}.vpce.amazonaws.com/prod/add-job'
headers = {'x-apigw-api-id': '<api-id>'} # insert api gateway id

# submit job
response = requests.post(url, headers=headers, json=job)

class QueueTrainingJobSession(Session):
def _intercept_create_request(self, request, create, func_name = None):
"""This function intercepts the create job request

Args:
request (dict): the create job request
create (functor): a functor calls the sagemaker client create method
func_name (str): the name of the function needed intercepting
"""
if func_name == 'train':
submit_to_training_queue(request)
else:
super()._intercept_create_request(request,create,func_name)

# define job
estimator = PyTorch(
role='<sagemaker role>',
entry_point='train.py',
instance_type='ml.p5.48xlarge',
instance_count=1,
framework_version='2.0.1',
py_version='py310',
tags=[{'Key': 'priority', 'Value': '100'},
keep_alive_period_in_seconds=60, # keep warm for 1 minute
# use our custom Session class
sagemaker_session=QueueTrainingJobSession()
)

estimator.fit(wait=False)

Para probar nuestra solución presentamos la siguiente secuencia de trabajos. Después de cada llamada imprimimos el estado de la cola (usando el print_queue_state función) y duerma durante veinte segundos.

  1. Inicie el trabajo 1 con prioridad 1.
  2. Inicie el trabajo 2 con prioridad 2.
  3. Inicie el trabajo 3 con prioridad 1.
  4. Inicie el trabajo 4 con prioridad 3.

Los primeros dos trabajos se envían inmediatamente a SageMaker y se actualizan al correr estado. Dado que el tercer trabajo tiene baja prioridad y tenemos exactamente dos instancias de capacitación, permanece en el pendiente estado y espera su turno. Después de enviar los primeros tres trabajos, el estado de la cola aparece como:

Job Name    State      Priority
---------- ------- ----------
job2 running 2
job1 running 1
job3 pending 1

El cuarto trabajo que enviamos tiene una prioridad más alta que todos los trabajos en la cola. En consecuencia, el trabajo en ejecución con la prioridad más baja, trabajo1, se adelanta. El trabajo de SageMaker correspondiente se detiene y una vez que se libera la instancia, el estado de la cola pasa a ser:

Job Name    State        Priority
---------- --------- ----------
job4 running 3
job2 running 2
job1 preempted 1
job3 pending 1

El trabajo de SageMaker en ejecución trabajo2 es el primero en terminar, trabajo2 se elimina de la cola y se reanuda nuestro trabajo adelantado:

Job Name    State      Priority
---------- ------- ----------
job4 running 3
job1 running 1
job3 pending 1

Una vez trabajo4 se completa, también se elimina de la cola, dejando espacio para trabajo3. Los trabajos restantes también se ejecutan hasta su finalización, lo que finalmente deja nuestra cola vacía.

La creciente dificultad para adquirir capacidad informática de IA ha obligado a los equipos de desarrollo de IA a reevaluar los procesos que utilizan para entrenar modelos de IA. El enfoque que hemos demostrado en esta publicación es aumentar las API tradicionales para modelos de entrenamiento con una cola de prioridad personalizada y un programador de trabajos asociado. Es importante destacar que la propuesta que hemos presentado debe considerarse como un plan general, no como una solución digna de producción. Se requerirían modificaciones y mejoras apropiadas para abordar las necesidades específicas de su proyecto.