Control Continuo con DRL: Implementación de TD3 para Robótica y Sistemas Autónomos en PyTorch
Contexto del Problema: La Necesidad del Control Fino en IA
En el ámbito de la Inteligencia Artificial, el Aprendizaje por Refuerzo (RL) ha demostrado un éxito notable en tareas con espacios de acción discretos, como los juegos de Atari o Go. Sin embargo, muchos problemas del mundo real, especialmente en robótica, vehículos autónomos y automatización industrial, requieren que un agente realice acciones que varían continuamente, como ajustar el par de un motor, el ángulo de dirección de un vehículo o la fuerza de agarre de un manipulador. [1, 7, 17] Aquí es donde el control continuo se vuelve fundamental. A diferencia de las acciones discretas (por ejemplo, 'arriba', 'abajo'), las acciones continuas (por ejemplo, un valor entre -1 y 1) exigen ajustes precisos y matizados. [1, 2, 7]
Los algoritmos tradicionales de Q-Learning, que funcionan bien en espacios de acción discretos al encontrar el valor Q máximo entre un conjunto finito de acciones, no son directamente aplicables a espacios de acción continuos debido a la infinidad de posibles acciones. [1, 19] Evaluar cada valor o incluso una muestra densa se vuelve computacionalmente inviable. [17] Esto llevó al desarrollo de métodos basados en gradientes de política y arquitecturas Actor-Crítico, que pueden aprender una función que mapea directamente estados a acciones continuas. [1, 2, 9]
Fundamento Teórico: De DDPG a TD3
Para abordar el control continuo, surgió el algoritmo Deep Deterministic Policy Gradient (DDPG). DDPG es un algoritmo Actor-Crítico off-policy que combina las ideas de Q-Learning con gradientes de política deterministas. [2, 7, 12] Utiliza una red Actor para producir acciones deterministas dado un estado, y una red Crítico para estimar el valor Q de un par estado-acción. [2, 7]
Aunque DDPG fue un avance significativo, sufría de problemas de estabilidad y sobreestimación del valor Q, lo que podía llevar a políticas subóptimas. [4, 5, 13] La sobreestimación ocurre porque el Crítico tiende a ser demasiado optimista sobre los valores de las acciones, lo que a su vez guía al Actor hacia acciones que parecen mejores de lo que realmente son. [4, 5, 14]
Para mitigar estas limitaciones, se introdujo el algoritmo Twin Delayed Deep Deterministic Policy Gradient (TD3). [5, 13] TD3 es una mejora directa sobre DDPG, incorporando tres innovaciones clave: [5, 6, 13]
- Doble Crítico (Twin Critics): En lugar de un solo Crítico, TD3 utiliza dos redes Crítico. Para calcular el valor objetivo, toma el mínimo de los valores Q estimados por ambos Críticos. Esto reduce el sesgo de sobreestimación que afectaba a DDPG. [5, 6, 13, 14]
- Actualizaciones de Política Retrasadas (Delayed Policy Updates): El Actor (política) y las redes objetivo se actualizan con menos frecuencia que las redes Crítico. Esto asegura que la política se base en estimaciones de valor Q más estables y fiables, mejorando la estabilidad del entrenamiento. [5, 6, 13]
- Suavizado de la Política Objetivo (Target Policy Smoothing): Se añade ruido recortado a las acciones objetivo antes de calcular los valores Q objetivo. Esto hace que el proceso de aprendizaje sea más robusto al ruido y evita que la política explote errores en la función Q, lo que resulta en políticas más suaves y fiables. [4, 5, 6, 13]
Estas mejoras hacen de TD3 un algoritmo más estable y robusto para tareas de control continuo, superando consistentemente a DDPG en entornos complejos. [5, 20]
Implementación Práctica: TD3 en PyTorch
Implementar TD3 en PyTorch implica definir las redes Actor y Crítico, un búfer de repetición y el bucle de entrenamiento. A continuación, se presenta un esqueleto de código para entender los componentes clave.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 1. Definición de las Redes Actor y Crítico
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.l1 = nn.Linear(state_dim, 256)
self.l2 = nn.Linear(256, 256)
self.l3 = nn.Linear(256, action_dim)
self.max_action = max_action
def forward(self, state):
x = torch.relu(self.l1(state))
x = torch.relu(self.l2(x))
return self.max_action * torch.tanh(self.l3(x))
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
# Q1 network
self.l1 = nn.Linear(state_dim + action_dim, 256)
self.l2 = nn.Linear(256, 256)
self.l3 = nn.Linear(256, 1)
# Q2 network
self.l4 = nn.Linear(state_dim + action_dim, 256)
self.l5 = nn.Linear(256, 256)
self.l6 = nn.Linear(256, 1)
def forward(self, state, action):
sa = torch.cat([state, action], 1)
q1 = torch.relu(self.l1(sa))
q1 = torch.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = torch.relu(self.l4(sa))
q2 = torch.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
def Q1(self, state, action):
sa = torch.cat([state, action], 1)
q1 = torch.relu(self.l1(sa))
q1 = torch.relu(self.l2(q1))
q1 = self.l3(q1)
return q1
# 2. Búfer de Repetición (Replay Buffer)
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.position = 0
def push(self, state, action, reward, next_state, done):
if len(self.buffer) < self.capacity:
self.buffer.append(None)
self.buffer[self.position] = (state, action, reward, next_state, done)
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
state, action, reward, next_state, done = map(np.stack, zip(*batch))
return state, action, reward, next_state, done
def __len__(self):
return len(self.buffer)
# 3. Clase del Agente TD3
class TD3:
def __init__(self, state_dim, action_dim, max_action, lr_actor=3e-4, lr_critic=3e-4,
gamma=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
self.actor = Actor(state_dim, action_dim, max_action)
self.actor_target = Actor(state_dim, action_dim, max_action)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_actor)
self.critic = Critic(state_dim, action_dim)
self.critic_target = Critic(state_dim, action_dim)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_critic)
self.max_action = max_action
self.gamma = gamma
self.tau = tau
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.policy_freq = policy_freq
self.total_it = 0
def select_action(self, state):
state = torch.FloatTensor(state.reshape(1, -1))
return self.actor(state).cpu().data.numpy().flatten()
def train(self, replay_buffer, batch_size=100):
self.total_it += 1
# Muestrear del búfer de repetición
state, action, reward, next_state, done = replay_buffer.sample(batch_size)
state = torch.FloatTensor(state)
action = torch.FloatTensor(action)
reward = torch.FloatTensor(reward).reshape(-1, 1)
next_state = torch.FloatTensor(next_state)
done = torch.FloatTensor(1 - done).reshape(-1, 1)
# Calcular el valor Q objetivo
with torch.no_grad():
# Añadir ruido a la política objetivo para suavizado
noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
# Doble Crítico: tomar el mínimo de los dos Q-values
target_Q1, target_Q2 = self.critic_target(next_state, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + (done * self.gamma * target_Q)
# Actualizar los Críticos
current_Q1, current_Q2 = self.critic(state, action)
critic_loss = nn.MSELoss()(current_Q1, target_Q) + nn.MSELoss()(current_Q2, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Actualizar el Actor (retrasado)
if self.total_it % self.policy_freq == 0:
actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Actualizar redes objetivo (soft update)
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def save(self, filename):
torch.save(self.critic.state_dict(), filename + "_critic")
torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
torch.save(self.actor.state_dict(), filename + "_actor")
torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
def load(self, filename):
self.critic.load_state_dict(torch.load(filename + "_critic"))
self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
self.actor.load_state_dict(torch.load(filename + "_actor"))
self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
Este código muestra la estructura básica. La red Actor produce acciones continuas usando una función de activación tanh para escalar la salida al rango de acción deseado. Las dos redes Crítico calculan los valores Q, y el entrenamiento del Actor se basa en el valor Q de una de las redes Crítico. Las actualizaciones retrasadas y el ruido en la política objetivo se implementan directamente en el método train. [5, 6]
Aplicaciones Reales: Más Allá de los Juegos
La capacidad de TD3 para manejar espacios de acción continuos lo hace invaluable en una variedad de aplicaciones del mundo real: [3, 7, 21]
- Robótica: Control de brazos robóticos para tareas de manipulación fina, como ensamblaje o agarre de objetos con precisión. [3, 7, 21] También en la locomoción de robots humanoides o cuadrúpedos, donde se requieren ajustes continuos de las articulaciones. [17]
- Vehículos Autónomos: Optimización de la trayectoria, planificación de movimiento y control de la dirección y el acelerador en tiempo real. [7, 21]
- Sistemas de Control Industrial: Regulación de procesos complejos en manufactura, energía o química, donde las variables de control son continuas.
- Simulaciones y Juegos: Creación de agentes con comportamientos más naturales y fluidos en entornos simulados, como personajes de juegos con movimientos realistas.
- Finanzas Cuantitativas: Estrategias de trading algorítmico que requieren ajustes continuos de posiciones o asignaciones de capital. [21]
Mejores Prácticas: Optimizando el Entrenamiento de TD3
Para obtener el mejor rendimiento de TD3 en entornos de producción, considera las siguientes mejores prácticas:
- Normalización de Estados y Acciones: Escalar los estados y las acciones a un rango consistente (por ejemplo, [-1, 1]) puede mejorar la estabilidad y la velocidad de convergencia del entrenamiento.
- Exploración: Aunque TD3 utiliza ruido en la política objetivo, es crucial añadir ruido de exploración (por ejemplo, ruido gaussiano o de Ornstein-Uhlenbeck) a las acciones durante la recolección de datos para asegurar una exploración adecuada del espacio de acción, especialmente al principio del entrenamiento. [6, 12]
- Tamaño del Búfer de Repetición: Un búfer de repetición grande es esencial para los algoritmos off-policy, ya que permite al agente aprender de una amplia variedad de experiencias pasadas y reduce la correlación entre las muestras. [2, 5]
- Frecuencia de Actualización de la Política: Experimenta con el parámetro
policy_freq. Un valor de 2 (actualizar el Actor cada dos actualizaciones del Crítico) es una recomendación común, pero puede variar según el entorno. [5, 6] - Ajuste de Hiperparámetros: Los hiperparámetros como las tasas de aprendizaje (
lr_actor,lr_critic), el factor de descuento (gamma), el factor de suavizado (tau), y los parámetros del ruido (policy_noise,noise_clip) son cruciales y a menudo requieren un ajuste fino para cada entorno específico. - Entornos de Simulación: Utiliza entornos de simulación como OpenAI Gym con MuJoCo o PyBullet para probar y depurar tus implementaciones antes de desplegarlas en hardware real. [5, 7, 10]
Aprendizaje Futuro: Más Allá de TD3
TD3 es un algoritmo robusto, pero el campo del Aprendizaje por Refuerzo para control continuo sigue evolucionando. Para profundizar, considera explorar:
- Soft Actor-Critic (SAC): Otro algoritmo off-policy de última generación para control continuo que incorpora la maximización de la entropía en su objetivo, fomentando una exploración más eficiente y políticas más robustas. [7, 17]
- Proximal Policy Optimization (PPO): Un algoritmo on-policy que ha demostrado ser muy efectivo y estable en una amplia gama de tareas de control continuo. [2, 7]
- Aprendizaje por Refuerzo Multi-Agente: Cómo aplicar estos conceptos en entornos donde múltiples agentes interactúan.
- Aprendizaje por Refuerzo Jerárquico: Descomponer problemas complejos en subtareas para facilitar el aprendizaje.
- RL en el Mundo Real: Desafíos y soluciones para la implementación de RL en sistemas físicos, incluyendo la eficiencia de la muestra y la transferencia de la simulación al mundo real (sim-to-real). [1, 7]
Dominar TD3 proporciona una base sólida para abordar problemas de control continuo en IA/ML, abriendo la puerta a la creación de sistemas autónomos más inteligentes y capaces.