En este tutorial, implementamos un agente de aprendizaje por refuerzo utilizando RLax, una biblioteca orientada a la investigación desarrollada por Google DeepMind para crear algoritmos de aprendizaje por refuerzo con JAX. Combinamos RLax con JAX, Haiku y Optax para construir un agente Deep Q-Learning (DQN) que aprende a resolver el entorno CartPole. En lugar de utilizar un marco de RL completamente empaquetado, ensamblamos el proceso de capacitación nosotros mismos para que podamos comprender claramente cómo interactúan los componentes centrales del aprendizaje por refuerzo. Definimos la red neuronal, creamos un búfer de reproducción, calculamos errores de diferencia temporal con RLax y entrenamos al agente mediante optimización basada en gradientes. Además, nos centramos en comprender cómo RLax proporciona primitivas RL reutilizables que se pueden integrar en canales de aprendizaje por refuerzo personalizados. Usamos JAX para cálculo numérico eficiente, Haiku para modelado de redes neuronales y Optax para optimización.
num_actions = env.action_space.n def q_network(x): mlp = hk.Sequential([
hk.Linear(128), jax.nn.relu,
hk.Linear(128), jax.nn.relu,
hk.Linear(num_actions),
]) devuelve mlp(x) q_net = hk. without_apply_rng(hk.transform(q_network)) dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32) rng = jax.random.PRNGKey(seed) params = q_net.init(rng, dummy_obs) target_params = optimizador de parámetros = optax.chain( optax.clip_by_global_norm(10.0), optax.adam(3e-4), ) opt_state = optimizador.init(params)
Instalamos las bibliotecas necesarias e importamos todos los módulos necesarios para el proceso de aprendizaje por refuerzo. Inicializamos el entorno, definimos la arquitectura de la red neuronal usando Haiku y configuramos la red Q que predice valores de acción. También inicializamos la red y los parámetros de la red objetivo, así como el optimizador que se utilizará durante el entrenamiento.
Definimos la estructura de transición e implementamos un búfer de reproducción para almacenar experiencias pasadas del entorno. Creamos funciones para agregar transiciones y minilotes de muestra que luego se usarán para capacitar al agente. También implementamos la estrategia de exploración codiciosa de épsilon.
Definimos las funciones básicas de aprendizaje utilizadas durante el entrenamiento. Calculamos errores de diferencia temporal utilizando la primitiva Q-learning de RLax y calculamos la pérdida utilizando la función de pérdida de Huber. Luego implementamos el paso de entrenamiento que calcula gradientes, aplica actualizaciones del optimizador y devuelve métricas de entrenamiento.
para ep en rango (episodios): obs, _ = eval_env.reset(seed=seed + 1000 + ep) hecho = Falso truncado = Falso recompensa_total = 0.0 mientras no (hecho o truncado): q_values = q_net.apply(params, obs[None, :]) acción = int(jnp.argmax(q_values[0])) next_obs, recompensa, hecho, truncado, _ = eval_env.step(action) total_reward += recompensa obs = next_obs return.append(total_reward) return float(np.mean(returns)) num_frames = 40000 tamaño_de_lote = 128 pasos_de calentamiento = 1000 train_every = 4 eval_every = 2000 gamma = 0,99 tau = 0,01 max_grad_updates_per_step = 1 obs, _ = env.reset(seed=seed) episodio_retorno = 0,0 episodio_retornos = []
eval_returns = []
pérdidas = []
td_means = []
q_significa = []
pasos_evaluación = []
hora_inicio = hora.hora()
Definimos la función de evaluación que mide el desempeño del agente. Configuramos los hiperparámetros de entrenamiento, incluida la cantidad de fotogramas, el tamaño del lote, el factor de descuento y la tasa de actualización de la red objetivo. También inicializamos variables que rastrean las estadísticas de entrenamiento, incluidos los retornos de los episodios, las pérdidas y las métricas de evaluación.
done = False truncated = False total_reward = 0.0 render_env = gym.make(“CartPole-v1″, render_mode=”rgb_array”) obs, _ = render_env.reset(seed=999) mientras no (hecho o truncado): frame = render_env.render() frames.append(frame) q_values = q_net.apply(params, obs[None, :]) acción = int(jnp.argmax(q_values[0])) obs, recompensa, hecho, truncado, _ = render_env.step(action) total_reward += recompensa render_env.close() print(f”Retorno del episodio de demostración: {total_reward:.2f}”) intente: importar matplotlib.animation como animación desde IPython.display importar HTML, mostrar fig = plt.figure(figsize=(6, 4)) patch = plt.imshow(frames[0]) plt.axis(“off”) def animate(i): patch.set_data(frames[i]) return (parche,) anim = animación.FuncAnimation(fig, animate, frames=len(frames), intervalo=30, blit=True) display(HTML(anim.to_jshtml())) plt.close(fig) excepto excepción como e: print(“Visualización de animación omitida:”, e)
Ejecutamos el ciclo completo de capacitación de aprendizaje por refuerzo. Actualizamos periódicamente los parámetros de la red, evaluamos el desempeño del agente y registramos métricas para su visualización. Además, trazamos los resultados del entrenamiento y presentamos un episodio de demostración para observar cómo se comporta el agente entrenado.
En conclusión, creamos un agente completo de Deep Q-Learning combinando RLax con el moderno ecosistema de aprendizaje automático basado en JAX. Diseñamos una red neuronal para estimar valores de acción, implementar la repetición de experiencias para estabilizar el aprendizaje y calcular errores de TD utilizando la primitiva Q-learning de RLax. Durante el entrenamiento, actualizamos los parámetros de la red mediante optimización basada en gradientes y evaluamos periódicamente el agente para realizar un seguimiento de las mejoras de rendimiento. Además, vimos cómo RLax permite un enfoque modular para el aprendizaje por refuerzo al proporcionar componentes algorítmicos reutilizables en lugar de algoritmos completos. Esta flexibilidad nos permite experimentar fácilmente con diferentes arquitecturas, reglas de aprendizaje y estrategias de optimización. Al ampliar esta base, podemos crear agentes más avanzados, como Double DQN, modelos de aprendizaje por refuerzo distributivo y métodos de actor-crítico, utilizando las mismas primitivas de RLax.
Consulte el cuaderno completo aquí. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 120.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.