Implementación de Deep Q-Learning (DQN) desde cero utilizando RLax JAX Haiku y Optax para capacitar a un agente de aprendizaje por refuerzo CartPole

En este tutorial, implementamos un agente de aprendizaje por refuerzo utilizando RLax, una biblioteca orientada a la investigación desarrollada por Google DeepMind para crear algoritmos de aprendizaje por refuerzo con JAX. Combinamos RLax con JAX, Haiku y Optax para construir un agente Deep Q-Learning (DQN) que aprende a resolver el entorno CartPole. En lugar de utilizar un marco de RL completamente empaquetado, ensamblamos el proceso de capacitación nosotros mismos para que podamos comprender claramente cómo interactúan los componentes centrales del aprendizaje por refuerzo. Definimos la red neuronal, creamos un búfer de reproducción, calculamos errores de diferencia temporal con RLax y entrenamos al agente mediante optimización basada en gradientes. Además, nos centramos en comprender cómo RLax proporciona primitivas RL reutilizables que se pueden integrar en canales de aprendizaje por refuerzo personalizados. Usamos JAX para cálculo numérico eficiente, Haiku para modelado de redes neuronales y Optax para optimización.

!pip -q instalar “jax[cpu]” dm-haiku optax rlax gimnasio matplotlib numpy import os.environ[“XLA_PYTHON_CLIENT_PREALLOCATE”] = “falso” importar tiempo de importación aleatorio desde clases de datos importar clase de datos desde colecciones importar deque importar gimnasio como gimnasio importar haiku como hk importar jax importar jax.numpy como jnp importar matplotlib.pyplot como plt importar numpy como np importar optax importar rlax semilla = 42 random.seed(seed) np.random.seed(seed) env = gimnasio.make(“CartPole-v1”) eval_env = gimnasio.make(“CartPole-v1”) obs_dim = env.observation_space.shape[0]
num_actions = env.action_space.n def q_network(x): mlp = hk.Sequential([
hk.Linear(128), jax.nn.relu,
hk.Linear(128), jax.nn.relu,
hk.Linear(num_actions),
]) devuelve mlp(x) q_net = hk. without_apply_rng(hk.transform(q_network)) dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32) rng = jax.random.PRNGKey(seed) params = q_net.init(rng, dummy_obs) target_params = optimizador de parámetros = optax.chain( optax.clip_by_global_norm(10.0), optax.adam(3e-4), ) opt_state = optimizador.init(params)

Instalamos las bibliotecas necesarias e importamos todos los módulos necesarios para el proceso de aprendizaje por refuerzo. Inicializamos el entorno, definimos la arquitectura de la red neuronal usando Haiku y configuramos la red Q que predice valores de acción. También inicializamos la red y los parámetros de la red objetivo, así como el optimizador que se utilizará durante el entrenamiento.

@dataclass clase Transición: obs: np.ndarray acción: int recompensa: descuento flotante: flotante next_obs: np.ndarray hecho: clase flotante ReplayBuffer: def __init__(self, capacidad): self.buffer = deque(maxlen=capacidad) def add(self, *args): self.buffer.append(Transition(*args)) def muestra(self, tamaño_lote): lote = muestra aleatoria (self.buffer, tamaño de lote) obs = np.stack ([t.obs for t in batch]).astype(np.float32) acción = np.array([t.action for t in batch]dtype=np.int32) recompensa = np.array([t.reward for t in batch]dtype=np.float32) descuento = np.array([t.discount for t in batch]dtype=np.float32) next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32) hecho = np.array([t.done for t in batch]dtype=np.float32) return { “obs”: obs, “action”: acción, “recompensa”: recompensa, “descuento”: descuento, “next_obs”: next_obs, “done”: hecho, } def __len__(self): return len(self.buffer) replay = ReplayBuffer(capacidad=50000) def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000): mix = min(frame_idx / decay_frames, 1.0) return eps_start + mix * (eps_end – eps_start) def select_action(params, obs, epsilon): if random.random() < epsilon: return env.action_space.sample() q_values = q_net.apply(params, obs[None, :]) devuelve int(jnp.argmax(q_values[0]))

Definimos la estructura de transición e implementamos un búfer de reproducción para almacenar experiencias pasadas del entorno. Creamos funciones para agregar transiciones y minilotes de muestra que luego se usarán para capacitar al agente. También implementamos la estrategia de exploración codiciosa de épsilon.

@jax.jit def soft_update(target_params, online_params, tau): return jax.tree_util.tree_map(lambda t, s: (1.0 – tau) * t + tau * s, target_params, online_params) def lote_td_errors(params, target_params, lote): q_tm1 = q_net.apply(params, lote[“obs”]) q_t = q_net.apply(target_params, lote[“next_obs”]) td_errors = jax.vmap( lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2) )(q_tm1, lote[“action”]lote[“reward”]lote[“discount”]q_t) return td_errors @jax.jit def train_step(params, target_params, opt_state, lote): def loss_fn(p): td_errors = lote_td_errors(p, target_params, lote) pérdida = jnp.mean(rlax.huber_loss(td_errors, delta=1.0)) metrics = { “loss”: pérdida, “td_abs_mean”: jnp.mean(jnp.abs(td_errors)), “q_mean”: jnp.mean(q_net.apply(p, lote[“obs”])), } pérdida de retorno, métricas (pérdida, métricas), grads = jax.value_and_grad(loss_fn, has_aux=True)(params) actualizaciones, opt_state = optimizador.update(grads, opt_state, params) params = optax.apply_updates(params, actualizaciones) return params, opt_state, métricas

Definimos las funciones básicas de aprendizaje utilizadas durante el entrenamiento. Calculamos errores de diferencia temporal utilizando la primitiva Q-learning de RLax y calculamos la pérdida utilizando la función de pérdida de Huber. Luego implementamos el paso de entrenamiento que calcula gradientes, aplica actualizaciones del optimizador y devuelve métricas de entrenamiento.

def evalua_agent(parámetros, episodios=5): devuelve = []
para ep en rango (episodios): obs, _ = eval_env.reset(seed=seed + 1000 + ep) hecho = Falso truncado = Falso recompensa_total = 0.0 mientras no (hecho o truncado): q_values ​​= q_net.apply(params, obs[None, :]) acción = int(jnp.argmax(q_values[0])) next_obs, recompensa, hecho, truncado, _ = eval_env.step(action) total_reward += recompensa obs = next_obs return.append(total_reward) return float(np.mean(returns)) num_frames = 40000 tamaño_de_lote = 128 pasos_de calentamiento = 1000 train_every = 4 eval_every = 2000 gamma = 0,99 tau = 0,01 max_grad_updates_per_step = 1 obs, _ = env.reset(seed=seed) episodio_retorno = 0,0 episodio_retornos = []
eval_returns = []
pérdidas = []
td_means = []
q_significa = []
pasos_evaluación = []

hora_inicio = hora.hora()

Definimos la función de evaluación que mide el desempeño del agente. Configuramos los hiperparámetros de entrenamiento, incluida la cantidad de fotogramas, el tamaño del lote, el factor de descuento y la tasa de actualización de la red objetivo. También inicializamos variables que rastrean las estadísticas de entrenamiento, incluidos los retornos de los episodios, las pérdidas y las métricas de evaluación.

para frame_idx en rango(1, num_frames + 1): epsilon = epsilon_by_frame(frame_idx) action = select_action(params, obs.astype(np.float32), epsilon) next_obs, recompensa, hecho, truncado, _ = env.step(action) terminal = hecho o truncado descuento = 0.0 si terminal else gamma replay.add( obs.astype(np.float32), acción, float(recompensa), float(descuento), next_obs.astype(np.float32), float(terminal), ) obs = next_obs episodio_return += recompensa si terminal: episodio_returns.append(episode_return) obs, _ = env.reset() episodio_return = 0.0 si len(repetición) >= pasos_calentamiento y frame_idx % train_every == 0: para _ en rango(max_grad_updates_per_step): lote_np = replay.sample(batch_size) lote = {k: jnp.asarray(v) para k, v en lote_np.items()} params, opt_state, metrics = train_step(params, target_params, opt_state, lote) target_params = soft_update(target_params, params, tau) pérdidas.append(float(métricas[“loss”])) td_means.append(float(métricas[“td_abs_mean”])) q_means.append(float(métricas[“q_mean”])) si frame_idx % eval_every == 0: avg_eval_return = evalua_agent(params, episodios=5) eval_returns.append(avg_eval_return) eval_steps.append(frame_idx) reciente_train = np.mean(episode_returns[-10:]) si episodio_returns else 0.0 pérdida_reciente = np.mean(pérdidas[-100:]) si las pérdidas son 0,0 print( f”step={frame_idx:6d} | epsilon={epsilon:.3f} | ” f”recent_train_return={recent_train:7.2f} | ” f”eval_return={avg_eval_return:7.2f} | ” f”recent_loss={recent_loss:.5f} | buffer={len(replay)}” ) transcurrido = tiempo.tiempo() – tiempo_inicio final_eval = evaluar_agente(params, episodios=10) print(“\nEntrenamiento completo”) print(f”Tiempo transcurrido: {transcurrido:.1f} segundos”) print(f”Regreso de la evaluación final de 10 episodios: {final_eval:.2f}”) plt.figure(figsize=(14, 4)) plt.subplot(1, 3, 1) plt.plot(episode_returns) plt.title(“Vuelve el episodio de entrenamiento”) plt.xlabel(“Episodio”) plt.ylabel(“Regreso”) plt.subplot(1, 3, 2) plt.plot(eval_steps, eval_returns) plt.title(“Devoluciones de evaluación”) plt.xlabel(“Pasos del entorno”) plt.ylabel(“Retorno promedio”) plt.subplot(1, 3, 3) plt.plot(losses, label=”Pérdida”) plt.plot(td_means, label=”|Error TD| Media”) plt.title(“Métricas de optimización”) plt.xlabel(“Actualizaciones de gradiente”) plt.legend() plt.tight_layout() plt.show() obs, _ = eval_env.reset(seed=999) frames = []
done = False truncated = False total_reward = 0.0 render_env = gym.make(“CartPole-v1″, render_mode=”rgb_array”) obs, _ = render_env.reset(seed=999) mientras no (hecho o truncado): frame = render_env.render() frames.append(frame) q_values ​​= q_net.apply(params, obs[None, :]) acción = int(jnp.argmax(q_values[0])) obs, recompensa, hecho, truncado, _ = render_env.step(action) total_reward += recompensa render_env.close() print(f”Retorno del episodio de demostración: {total_reward:.2f}”) intente: importar matplotlib.animation como animación desde IPython.display importar HTML, mostrar fig = plt.figure(figsize=(6, 4)) patch = plt.imshow(frames[0]) plt.axis(“off”) def animate(i): patch.set_data(frames[i]) return (parche,) anim = animación.FuncAnimation(fig, animate, frames=len(frames), intervalo=30, blit=True) display(HTML(anim.to_jshtml())) plt.close(fig) excepto excepción como e: print(“Visualización de animación omitida:”, e)

Ejecutamos el ciclo completo de capacitación de aprendizaje por refuerzo. Actualizamos periódicamente los parámetros de la red, evaluamos el desempeño del agente y registramos métricas para su visualización. Además, trazamos los resultados del entrenamiento y presentamos un episodio de demostración para observar cómo se comporta el agente entrenado.

En conclusión, creamos un agente completo de Deep Q-Learning combinando RLax con el moderno ecosistema de aprendizaje automático basado en JAX. Diseñamos una red neuronal para estimar valores de acción, implementar la repetición de experiencias para estabilizar el aprendizaje y calcular errores de TD utilizando la primitiva Q-learning de RLax. Durante el entrenamiento, actualizamos los parámetros de la red mediante optimización basada en gradientes y evaluamos periódicamente el agente para realizar un seguimiento de las mejoras de rendimiento. Además, vimos cómo RLax permite un enfoque modular para el aprendizaje por refuerzo al proporcionar componentes algorítmicos reutilizables en lugar de algoritmos completos. Esta flexibilidad nos permite experimentar fácilmente con diferentes arquitecturas, reglas de aprendizaje y estrategias de optimización. Al ampliar esta base, podemos crear agentes más avanzados, como Double DQN, modelos de aprendizaje por refuerzo distributivo y métodos de actor-crítico, utilizando las mismas primitivas de RLax.

Consulte el cuaderno completo aquí. Además, no dude en seguirnos en Twitter y no olvide unirse a nuestro SubReddit de más de 120.000 ML y suscribirse a nuestro boletín. ¡Esperar! estas en telegrama? Ahora también puedes unirte a nosotros en Telegram.