# Source code for sb3_soft.sdsac.buffers

"""Custom replay buffer for SD-SAC.

Extends the standard SB3 :class:`ReplayBuffer` with per-transition storage
of the policy entropy at collection time (``H_πold``).  This is required by
the entropy-penalty term in the SD-SAC actor loss.
"""

from typing import Any, NamedTuple, Optional, Union

import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.vec_env import VecNormalize


class SDSACReplayBufferSamples(NamedTuple):
    """Batch of replay transitions augmented with behaviour-policy entropy.

    Mirrors the standard SB3 replay-buffer sample tuple but carries two extra
    fields: ``old_entropies`` — the policy entropy recorded when each
    transition was collected (needed by the SD-SAC entropy-penalty term) —
    and an optional ``discounts`` entry used by n-step buffer variants.
    """

    # --- standard SB3 replay fields ---------------------------------------
    observations: Union[th.Tensor, dict[str, th.Tensor]]  # batch of observations
    actions: th.Tensor  # batch of actions
    next_observations: Union[th.Tensor, dict[str, th.Tensor]]  # batch of next observations
    dones: th.Tensor  # done flags
    rewards: th.Tensor  # rewards
    # --- SD-SAC extras -----------------------------------------------------
    old_entropies: th.Tensor  # H_πold at collection time, shape (batch, 1)
    discounts: Optional[th.Tensor] = None  # per-sample discount factors (n-step buffers)
class SDSACReplayBuffer(ReplayBuffer):
    """Replay buffer that also records per-transition policy entropy.

    Usage protocol: call :meth:`set_entropy` with the current policy's
    entropy immediately *before* each :meth:`add`; the staged value is
    written to the same buffer slot as the transition and returned again
    at sampling time inside :class:`SDSACReplayBufferSamples`.

    Parameters
    ----------
    buffer_size : int
        Maximum number of transitions to store.
    observation_space : spaces.Space
        Observation space.
    action_space : spaces.Space
        Action space.
    device : Union[th.device, str], default="auto"
        Device for returned tensors.
    n_envs : int, default=1
        Number of parallel environments.
    optimize_memory_usage : bool, default=False
        Memory-efficient variant (see SB3 docs).
    handle_timeout_termination : bool, default=True
        Whether to handle ``TimeLimit.truncated`` in ``infos``.
    """

    old_entropies: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ) -> None:
        super().__init__(
            buffer_size,
            observation_space,
            action_space,
            device=device,
            n_envs=n_envs,
            optimize_memory_usage=optimize_memory_usage,
            handle_timeout_termination=handle_timeout_termination,
        )
        # Entropy storage mirrors the layout of ``self.rewards``:
        # one float32 slot per (buffer position, env).
        self.old_entropies = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Value staged by set_entropy(); consumed (and cleared) by the next add().
        self._pending_entropy: Optional[np.ndarray] = None

    def set_entropy(self, entropy: np.ndarray) -> None:
        """Stage per-env entropy values for the next :meth:`add` call.

        Parameters
        ----------
        entropy : np.ndarray
            One entropy value per environment, shape ``(n_envs,)`` or
            ``(n_envs, 1)``.
        """
        staged = np.asarray(entropy, dtype=np.float32)
        self._pending_entropy = staged.reshape(self.n_envs)

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """Write any staged entropy at the current slot, then store the transition.

        The entropy must be written before ``super().add`` runs, because the
        parent call advances ``self.pos``.
        """
        staged = self._pending_entropy
        # Zero is the fallback when nothing was staged (e.g. random warmup).
        self.old_entropies[self.pos] = 0.0 if staged is None else staged
        self._pending_entropy = None
        super().add(obs, next_obs, action, reward, done, infos)

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> SDSACReplayBufferSamples:
        """Assemble a training batch, attaching the stored entropies."""
        # One randomly chosen env per sampled index, as in the SB3 parent.
        chosen_envs = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            # Memory-efficient layout: the next observation lives in the
            # following buffer slot (wrapping around at the end).
            raw_next = self.observations[(batch_inds + 1) % self.buffer_size, chosen_envs, :]
        else:
            raw_next = self.next_observations[batch_inds, chosen_envs, :]
        next_obs_np = self._normalize_obs(raw_next, env)
        assert not isinstance(next_obs_np, dict)

        obs_np = self._normalize_obs(self.observations[batch_inds, chosen_envs, :], env)
        assert not isinstance(obs_np, dict)

        # Mask out timeout-terminations so they are not treated as true dones.
        done_flags = self.dones[batch_inds, chosen_envs] * (
            1 - self.timeouts[batch_inds, chosen_envs]
        )
        reward_col = self._normalize_reward(
            self.rewards[batch_inds, chosen_envs].reshape(-1, 1), env
        )
        entropy_col = self.old_entropies[batch_inds, chosen_envs].reshape(-1, 1)

        return SDSACReplayBufferSamples(
            observations=self.to_torch(obs_np),
            actions=self.to_torch(self.actions[batch_inds, chosen_envs, :]),
            next_observations=self.to_torch(next_obs_np),
            dones=self.to_torch(done_flags).reshape(-1, 1),
            rewards=self.to_torch(reward_col),
            old_entropies=self.to_torch(entropy_col),
        )
class SDSACDictReplayBuffer(DictReplayBuffer):
    """Dict-observation variant of the SD-SAC replay buffer.

    Stores the behaviour policy's entropy alongside each transition; stage
    values with :meth:`set_entropy` immediately before each :meth:`add`.
    """

    old_entropies: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Dict,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ) -> None:
        super().__init__(
            buffer_size,
            observation_space,
            action_space,
            device=device,
            n_envs=n_envs,
            optimize_memory_usage=optimize_memory_usage,
            handle_timeout_termination=handle_timeout_termination,
        )
        # Same layout as ``self.rewards``: one float32 per (slot, env).
        self.old_entropies = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Staged by set_entropy(); consumed (and cleared) by the next add().
        self._pending_entropy: Optional[np.ndarray] = None

    def set_entropy(self, entropy: np.ndarray) -> None:
        """Stage entropy values to be written on the next :meth:`add` call."""
        staged = np.asarray(entropy, dtype=np.float32)
        self._pending_entropy = staged.reshape(self.n_envs)

    def add(  # type: ignore[override]
        self,
        obs: dict[str, np.ndarray],
        next_obs: dict[str, np.ndarray],
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """Write any staged entropy at the current slot, then store the transition.

        Entropy is written before ``super().add`` runs, since the parent
        call advances ``self.pos``.
        """
        staged = self._pending_entropy
        # Zero is the fallback when nothing was staged (e.g. random warmup).
        self.old_entropies[self.pos] = 0.0 if staged is None else staged
        self._pending_entropy = None
        super().add(obs, next_obs, action, reward, done, infos)

    def _get_samples(  # type: ignore[override]
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> SDSACReplayBufferSamples:
        """Assemble a dict-observation batch, attaching the stored entropies."""
        # One randomly chosen env per sampled index, as in the SB3 parent.
        chosen_envs = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        def _gather(storage: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
            # Pick the (slot, env) pair out of every sub-observation array.
            return {key: arr[batch_inds, chosen_envs, :] for key, arr in storage.items()}

        obs_norm = self._normalize_obs(_gather(self.observations), env)
        next_obs_norm = self._normalize_obs(_gather(self.next_observations), env)
        assert isinstance(obs_norm, dict)
        assert isinstance(next_obs_norm, dict)

        # Mask out timeout-terminations so they are not treated as true dones.
        done_flags = self.dones[batch_inds, chosen_envs] * (
            1 - self.timeouts[batch_inds, chosen_envs]
        )
        reward_col = self._normalize_reward(
            self.rewards[batch_inds, chosen_envs].reshape(-1, 1), env
        )
        entropy_col = self.old_entropies[batch_inds, chosen_envs].reshape(-1, 1)

        return SDSACReplayBufferSamples(
            observations={key: self.to_torch(arr) for key, arr in obs_norm.items()},
            actions=self.to_torch(self.actions[batch_inds, chosen_envs]),
            next_observations={key: self.to_torch(arr) for key, arr in next_obs_norm.items()},
            dones=self.to_torch(done_flags).reshape(-1, 1),
            rewards=self.to_torch(reward_col),
            old_entropies=self.to_torch(entropy_col),
        )