Source code for sb3_soft.sql.sql

"""
Soft Q-Learning (SQL) implementation for discrete action spaces, based on
Haarnoja et al. (2017) "Reinforcement Learning with Deep Energy-Based Policies"
https://proceedings.mlr.press/v70/haarnoja17a.html
"""

from typing import Any, ClassVar, Optional, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import QNetwork
from torch.nn import functional as F

from .policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SQLPolicy

SelfSQL = TypeVar("SelfSQL", bound="SQL")


class SQL(OffPolicyAlgorithm):
    """Discrete-action Soft Q-Learning from T. Haarnoja, H. Tang, P. Abbeel, and
    S. Levine, "Reinforcement Learning with Deep Energy-Based Policies,"
    Proceedings of the 34th International Conference on Machine Learning,
    PMLR, Jul. 2017, pp. 1352–1361.
    https://proceedings.mlr.press/v70/haarnoja17a.html

    Extends SB3's ``OffPolicyAlgorithm`` with an entropy-regularized Bellman
    backup and Boltzmann (softmax) action sampling.

    Parameters
    ----------
    policy : str | type[SQLPolicy]
        Policy model to use (e.g., ``"MlpPolicy"``, ``"CnnPolicy"``).
    env : GymEnv | str
        Environment to learn from.
    learning_rate : float | Schedule, default=1e-4
        Learning rate (constant or schedule).
    buffer_size : int, default=1_000_000
        Replay buffer capacity.
    learning_starts : int, default=100
        Number of steps of random exploration before learning starts.
    batch_size : int, default=32
        Minibatch size for each gradient update.
    tau : float, default=1.0
        Polyak update coefficient for target network updates.
    gamma : float, default=0.99
        Discount factor.
    train_freq : int | tuple[int, str], default=4
        Update frequency in steps or episodes.
    gradient_steps : int, default=1
        Number of gradient updates after each rollout.
    action_noise : ActionNoise | None, default=None
        Action noise for exploration (only applicable to continuous action
        spaces).
    replay_buffer_class : type[ReplayBuffer] | None, default=None
        Optional replay buffer implementation override.
    replay_buffer_kwargs : dict[str, Any] | None, default=None
        Additional keyword arguments for replay buffer creation.
    optimize_memory_usage : bool, default=False
        Whether to use the memory-efficient replay buffer variant.
    n_steps : int, default=1
        Number of steps for n-step returns.
    target_update_interval : int, default=10_000
        Environment steps between target network updates.
    max_grad_norm : float, default=10
        Maximum gradient norm for clipping.
    ent_coef : str | float, default="auto"
        Temperature :math:`\\alpha` used in the soft Bellman target
        :math:`V(s') = \\alpha \\log \\sum_a \\exp(Q(s', a) / \\alpha)`.
        Set to ``"auto"`` (or ``"auto_0.1"``) to learn it automatically.
    target_entropy : str | float, default="auto"
        Target policy entropy used when ``ent_coef`` is learned automatically.
        If ``"auto"``, uses :math:`0.98 \\log(|\\mathcal{A}|)`, but this may
        need adjustment for low-entropy environments.
    action_temperature : float | None, default=None
        Temperature :math:`\\tau` for Boltzmann action sampling
        :math:`\\pi(a \\mid s) \\propto \\exp(Q(s, a) / \\tau)`.
        If ``None``, uses ``ent_coef``.
    stats_window_size : int, default=100
        Window size for rollout statistics logging.
    tensorboard_log : str | None, default=None
        TensorBoard log directory.
    policy_kwargs : dict[str, Any] | None, default=None
        Additional keyword arguments passed to the policy.
    verbose : int, default=0
        Verbosity level (0: no output, 1: info, 2: debug).
    seed : int | None, default=None
        Random seed.
    device : torch.device | str, default="auto"
        Device to run the model on.
    _init_setup_model : bool, default=True
        Whether to build networks and optimizers during construction.
    """

    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Aliases into self.policy, assigned by _create_aliases() in _setup_model().
    q_net: QNetwork
    q_net_target: QNetwork

    def __init__(
        self,
        policy: Union[str, type[SQLPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, tuple[int, str]] = 4,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        n_steps: int = 1,
        target_update_interval: int = 10_000,
        max_grad_norm: float = 10,
        ent_coef: Union[str, float] = "auto",
        target_entropy: Union[str, float] = "auto",
        action_temperature: float | None = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ) -> None:
        # Set entropy-related attributes before super().__init__(): when
        # _init_setup_model is True, _setup_model() (called below) reads them.
        self.target_entropy = target_entropy
        self.ent_coef = ent_coef
        self.log_ent_coef: Optional[th.Tensor] = None
        self.ent_coef_optimizer: Optional[th.optim.Adam] = None
        # When no explicit sampling temperature is given, the Boltzmann
        # temperature tracks ent_coef — including its learned value; see
        # train(), which calls _set_action_temperature() after each alpha step.
        self._temperature_follows_ent_coef = action_temperature is None

        init_ent_coef: float
        if isinstance(ent_coef, str):
            if not ent_coef.startswith("auto"):
                raise ValueError(
                    "ent_coef must be a float or start with 'auto' (e.g. 'auto_0.1'). "
                    f"Got: {ent_coef}"
                )
            init_ent_coef = 1.0
            if "_" in ent_coef:
                # e.g. "auto_0.1" -> initial coefficient 0.1
                init_ent_coef = float(ent_coef.split("_")[1])
            if init_ent_coef <= 0:
                raise ValueError(
                    f"The initial value of ent_coef must be > 0, got {init_ent_coef}"
                )
        else:
            init_ent_coef = float(ent_coef)
            if init_ent_coef <= 0:
                raise ValueError(f"ent_coef must be > 0, got {ent_coef}")

        if self._temperature_follows_ent_coef:
            self.action_temperature = init_ent_coef
        else:
            assert action_temperature is not None
            self.action_temperature = float(action_temperature)
            if self.action_temperature <= 0:
                raise ValueError(
                    f"action_temperature must be > 0 when provided, got {self.action_temperature}"
                )

        # Copy so the caller's dict is not mutated; forward the sampling
        # temperature to the policy unless the caller already provided one.
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)
        policy_kwargs.setdefault("temperature", self.action_temperature)

        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm
        # Counts _on_step() calls to schedule target-network updates.
        self._n_calls = 0

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            n_steps=n_steps,
            policy_kwargs=policy_kwargs,
            use_sde=False,
            sde_sample_freq=-1,
            use_sde_at_warmup=False,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
        )

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the policy/networks and resolve entropy-related settings."""
        super()._setup_model()
        self._create_aliases()
        # Batch-norm running statistics (if the network has any) are copied to
        # the target net with coefficient 1.0 in _on_step() rather than
        # polyak-averaged.
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(
            self.q_net_target, ["running_"]
        )
        if self.target_entropy == "auto":
            assert isinstance(self.action_space, spaces.Discrete)
            # Slightly below the maximum entropy log(|A|) of a uniform
            # discrete policy.
            self.target_entropy = float(0.98 * np.log(self.action_space.n))
        else:
            self.target_entropy = float(self.target_entropy)
        if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
            init_value = 1.0
            if "_" in self.ent_coef:
                init_value = float(self.ent_coef.split("_")[1])
            assert init_value > 0.0, (
                "The initial value of ent_coef must be greater than 0"
            )
            # Optimize log(alpha) so that alpha stays strictly positive.
            self.log_ent_coef = th.log(
                th.ones(1, device=self.device) * init_value
            ).requires_grad_(True)
            self.ent_coef_optimizer = th.optim.Adam(
                [self.log_ent_coef], lr=self.lr_schedule(1)
            )
        else:
            # Fixed coefficient: stored as a tensor for use in train().
            self.ent_coef_tensor = th.tensor(float(self.ent_coef), device=self.device)

    def _set_action_temperature(self, value: float) -> None:
        """Propagate a new Boltzmann sampling temperature to the policy nets."""
        # Clamp to avoid a zero/negative temperature collapsing the softmax.
        value = max(float(value), 1e-8)
        self.action_temperature = value
        if isinstance(self.policy, SQLPolicy):
            self.policy.temperature = value
            setattr(self.policy.q_net, "temperature", value)
            setattr(self.policy.q_net_target, "temperature", value)

    def _create_aliases(self) -> None:
        """Expose the policy's networks as top-level attributes."""
        assert isinstance(self.policy, SQLPolicy)
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """Update the target network if needed."""
        self._n_calls += 1
        # target_update_interval counts environment steps; divide by n_envs to
        # convert to _on_step() calls (at least every call).
        if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
            polyak_update(
                self.q_net.parameters(), self.q_net_target.parameters(), self.tau
            )
            # Copy running stats (coefficient 1.0), see GH issue #996 in SB3.
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` soft-Q updates on replayed minibatches."""
        # Switch to train mode (affects batch norm / dropout).
        self.policy.set_training_mode(True)
        optimizers: list[th.optim.Optimizer] = [self.policy.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        losses: list[float] = []
        ent_coef_losses: list[float] = []
        ent_coefs: list[float] = []
        last_batch_entropy: float | None = None

        for _ in range(gradient_steps):
            replay_data = self.replay_buffer.sample(  # type: ignore[union-attr]
                batch_size, env=self._vec_normalize_env
            )

            # Entropy of the Boltzmann policy pi(a|s) ~ exp(Q(s,a)/temperature),
            # used only in detached form for the entropy-coefficient loss.
            q_obs = self.q_net(replay_data.observations)
            log_probs = th.log_softmax(q_obs / self.action_temperature, dim=1)
            probs = log_probs.exp()
            entropy = -(probs * log_probs).sum(dim=1, keepdim=True)

            ent_coef_loss: Optional[th.Tensor] = None
            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                ent_coef_tensor = th.exp(self.log_ent_coef.detach())
                assert isinstance(self.target_entropy, float)
                # Minimizing this raises alpha when entropy is below the
                # target and lowers it when entropy exceeds the target.
                ent_coef_loss = (
                    self.log_ent_coef * (entropy.detach() - self.target_entropy)
                ).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef_tensor = self.ent_coef_tensor
            ent_coefs.append(ent_coef_tensor.item())

            if ent_coef_loss is not None and self.ent_coef_optimizer is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()
                assert self.log_ent_coef is not None
                # Use the freshly updated alpha for this step's Bellman target.
                ent_coef_tensor = th.exp(self.log_ent_coef.detach())
                if self._temperature_follows_ent_coef:
                    self._set_action_temperature(ent_coef_tensor.item())

            # Per-sample discounts come from n-step replay buffers when
            # available; otherwise fall back to the scalar gamma.
            discounts = (
                replay_data.discounts
                if replay_data.discounts is not None
                else self.gamma
            )
            with th.no_grad():
                # Soft state value: V(s') = alpha * logsumexp(Q(s',.) / alpha).
                next_q_values = self.q_net_target(replay_data.next_observations)
                next_v_values = ent_coef_tensor * th.logsumexp(
                    next_q_values / ent_coef_tensor, dim=1
                )
                next_v_values = next_v_values.reshape(-1, 1)
                # Soft Bellman target; dones mask out the bootstrap term.
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * discounts * next_v_values
                )

            # NOTE(review): observations pass through q_net a second time here
            # (q_obs above computed the entropy term); reusing q_obs would
            # change batch-norm running-stat updates — confirm before merging.
            current_q_values = self.q_net(replay_data.observations)
            # Q-values of the actions actually taken in the replay buffer.
            current_q_values = th.gather(
                current_q_values, dim=1, index=replay_data.actions.long()
            )
            # Huber loss, as in DQN.
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(float(loss.item()))

            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm before the optimizer step.
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

            last_batch_entropy = entropy.mean().item()

        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
        if last_batch_entropy is not None:
            self.logger.record("train/entropy", last_batch_entropy)
        self.logger.record("train/action_temperature", self.action_temperature)

    def predict(
        self,
        observation: Union[np.ndarray, dict[str, np.ndarray]],
        state: Optional[tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
        """Return the policy's action for an observation (delegates to the policy)."""
        return self.policy.predict(observation, state, episode_start, deterministic)

    def learn(
        self: SelfSQL,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "SQL",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfSQL:
        """Run the standard off-policy training loop (see ``OffPolicyAlgorithm.learn``)."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> list[str]:
        # q_net / q_net_target are aliases into self.policy and are rebuilt on
        # load, so exclude them from pickling.
        return [*super()._excluded_save_params(), "q_net", "q_net_target"]

    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
        """Declare which state dicts and tensors get saved with the model."""
        state_dicts = ["policy", "policy.optimizer"]
        if self.ent_coef_optimizer is not None:
            # Learned alpha: persist both log(alpha) and its optimizer state.
            state_dicts.append("ent_coef_optimizer")
            saved_pytorch_variables = ["log_ent_coef"]
        else:
            saved_pytorch_variables = ["ent_coef_tensor"]
        return state_dicts, saved_pytorch_variables