SDSAC API
Algorithm
Stable Discrete Soft Actor-Critic (SDSAC).
Actor-critic algorithm for discrete action spaces based on Zhou et al. (2024), “Revisiting Discrete Soft Actor-Critic”. https://openreview.net/forum?id=EUF2R6VBeU
Key differences from continuous SAC and naive discrete SAC:
- The actor outputs a categorical distribution over discrete actions.
- Twin critics output Q-values for all actions given a state (no action input), enabling exact expectation computation.
- Double-average Q-learning: the target uses the mean (not the min) of the twin target critics. The actor objective also uses the mean of the twin online critics (instead of the min) for policy improvement.
- Q-clip: the critic loss is \(\max\left((Q - y)^2,\, \left(Q' + \mathrm{clip}(Q - Q', -c, c) - y\right)^2\right)\).
- Entropy-penalty: the actor loss includes \(\beta \cdot \tfrac{1}{2} \cdot \left(H_{\pi_{\mathrm{old}}} - H_{\pi}\right)^2\), where \(H_{\pi_{\mathrm{old}}}\) is stored in the replay buffer at collection time.
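The three mechanisms above can be sketched together in PyTorch. This is an illustrative sketch with random stand-in tensors; the variable names (`q1_t`, `q_mean`, etc.) are hypothetical and not part of the library's API:

```python
import torch

# Stand-in tensors: batch size B, A discrete actions (all names hypothetical).
B, A = 32, 4
torch.manual_seed(0)
alpha, gamma, c, beta = 0.2, 0.99, 0.5, 0.5

# 1) Double-average Q target: mean (not min) of the twin target critics,
#    with the exact expectation over the categorical next-state policy.
probs_next = torch.softmax(torch.randn(B, A), dim=1)
log_probs_next = torch.log(probs_next)
q1_t, q2_t = torch.randn(B, A), torch.randn(B, A)
rewards, dones = torch.randn(B, 1), torch.zeros(B, 1)
q_next = 0.5 * (q1_t + q2_t)
v_next = (probs_next * (q_next - alpha * log_probs_next)).sum(1, keepdim=True)
y = rewards + gamma * (1.0 - dones) * v_next

# 2) Q-clip critic loss for the taken actions: q is the online estimate,
#    q_old the reference estimate the update is clipped against.
q, q_old = torch.randn(B, 1), torch.randn(B, 1)
clipped = q_old + torch.clamp(q - q_old, -c, c)
critic_loss = torch.max((q - y) ** 2, (clipped - y) ** 2).mean()

# 3) Actor loss: mean of the online critics plus the entropy penalty
#    against the entropy stored at collection time.
probs = torch.softmax(torch.randn(B, A), dim=1)
log_probs = torch.log(probs)
q_mean = 0.5 * (torch.randn(B, A) + torch.randn(B, A))
entropy = -(probs * log_probs).sum(1, keepdim=True)
old_entropy = torch.rand(B, 1)
actor_loss = (probs * (alpha * log_probs - q_mean)).sum(1).mean() \
    + beta * 0.5 * ((old_entropy - entropy) ** 2).mean()
```

Because the expectation over actions is computed exactly, no reparameterization trick or sampled log-probability is needed, unlike in continuous SAC.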
- class sb3_soft.sdsac.sdsac.SDSAC(policy, env, learning_rate=0.0003, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, ent_coef='auto', target_update_interval=1, target_entropy='auto', beta=0.5, clip_range=0.5, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]
Bases:
OffPolicyAlgorithm
Stable Discrete Soft Actor-Critic (SD-SAC) from H. Zhou et al., “Revisiting Discrete Soft Actor-Critic,” Transactions on Machine Learning Research, Aug. 2024. https://openreview.net/forum?id=EUF2R6VBeU
An off-policy actor-critic algorithm for discrete action spaces that maintains separate actor and twin-critic networks and performs entropy-regularized updates using full-distribution expectations.
Compared with a naive discrete adaptation of SAC, SD-SAC adds three stabilisation mechanisms (Algorithm 1 in the paper):
- Double-average Q-learning – the Bellman target uses \(\mathrm{mean}(Q'_1, Q'_2)\) of the twin target critics instead of min. The actor objective likewise uses the mean of the online critics, \(\mathrm{mean}(Q_1, Q_2)\), in place of a clipped-min estimate.
- Q-clip – the critic loss is
\(\max\left((Q - y)^2,\, \left(Q' + \mathrm{clip}(Q - Q', -c, c) - y\right)^2\right)\).
- Entropy penalty – the actor loss adds
\(\beta \cdot 0.5 \cdot \left(H_{\pi_{\mathrm{old}}} - H_{\pi}\right)^2\), where \(H_{\pi_{\mathrm{old}}}\) is the policy entropy stored in the replay buffer at collection time.
Reference: Zhou et al. (2024) “Revisiting Discrete Soft Actor-Critic”.
- Parameters:
policy (str | type[SDSACPolicy]) – Policy model to use ("MlpPolicy", "CnnPolicy", …).
env (GymEnv | str) – Environment to learn from.
learning_rate (float | Schedule, default=3e-4) – Learning rate for all networks (actor, critic, and optionally alpha).
buffer_size (int, default=1_000_000) – Replay buffer capacity.
learning_starts (int, default=100) – Number of environment steps to collect before training starts.
batch_size (int, default=256) – Mini-batch size for each gradient update.
tau (float, default=0.005) – Polyak averaging coefficient for target network updates.
gamma (float, default=0.99) – Discount factor.
train_freq (int | tuple[int, str], default=1) – How often to update the model (in steps or episodes).
gradient_steps (int, default=1) – Gradient updates per rollout step. -1 means as many as environment steps collected.
replay_buffer_class (type[ReplayBuffer] | None, default=None) – Custom replay buffer class. Defaults to SDSACReplayBuffer.
replay_buffer_kwargs (dict | None, default=None) – Keyword arguments for the replay buffer.
optimize_memory_usage (bool, default=False) – Memory-efficient replay buffer variant.
n_steps (int, default=1) – Steps for n-step returns.
ent_coef (str | float, default="auto") – Entropy coefficient (temperature) \(\alpha\). "auto" enables automatic tuning ("auto_0.1" sets the initial value).
target_update_interval (int, default=1) – Gradient steps between target network updates.
target_entropy (str | float, default="auto") – Target entropy for automatic \(\alpha\) tuning. "auto" uses \(0.98 \log |\mathcal{A}|\), which may need adjustment for low-entropy environments.
beta (float, default=0.5) – Entropy-penalty coefficient \(\beta\) in the actor loss.
clip_range (float, default=0.5) – Clipping range \(c\) for the Q-clip critic loss.
stats_window_size (int, default=100) – Window size for rollout statistics.
tensorboard_log (str | None, default=None) – TensorBoard log directory.
policy_kwargs (dict | None, default=None) – Extra keyword arguments for policy construction.
verbose (int, default=0) – Verbosity level (0: silent, 1: info, 2: debug).
seed (int | None, default=None) – Random seed.
device (str | th.device, default="auto") – Computation device.
_init_setup_model (bool, default=True) – Whether to build networks on construction.
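A minimal usage sketch, assuming the sb3_soft package is installed and using a standard Gymnasium environment with a discrete action space (CartPole-v1 here is an illustrative choice; the import path follows the class path shown above):

```python
import gymnasium as gym
from sb3_soft.sdsac.sdsac import SDSAC

env = gym.make("CartPole-v1")
model = SDSAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    beta=0.5,        # entropy-penalty coefficient
    clip_range=0.5,  # Q-clip range c
    verbose=1,
)
model.learn(total_timesteps=10_000)

# Greedy evaluation after training.
obs, _ = env.reset()
action, _ = model.predict(obs, deterministic=True)
```

The defaults mirror SB3's SAC where applicable, so hyperparameters tuned for discrete DQN-style benchmarks may still require adjustment of `beta` and `target_entropy`.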
- policy_aliases = {'CnnPolicy': <class 'sb3_soft.sdsac.policies.CnnPolicy'>, 'MlpPolicy': <class 'sb3_soft.sdsac.policies.SDSACPolicy'>, 'MultiInputPolicy': <class 'sb3_soft.sdsac.policies.MultiInputPolicy'>}
- policy
- actor
- critic
- critic_target
- train(gradient_steps, batch_size=64)[source]
Sample the replay buffer and perform the updates (gradient steps on the actor and critic losses, followed by target network updates).
- predict(observation, state=None, episode_start=None, deterministic=False)[source]
Get the policy action from an observation (and optional hidden state). Includes convenience handling for different observation types (e.g. normalizing images).
- Parameters:
observation – the input observation
state – The last hidden states (can be None, used in recurrent policies)
episode_start – The last masks (can be None, used in recurrent policies); this corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.
deterministic – Whether or not to return deterministic actions.
- Returns:
the model’s action and the next hidden state (used in recurrent policies)
- learn(total_timesteps, callback=None, log_interval=4, tb_log_name='SDSAC', reset_num_timesteps=True, progress_bar=False)[source]
Return a trained model.
- Parameters:
total_timesteps – The total number of samples (env steps) to train on. Note: this is a lower bound, see issue #1150.
callback – callback(s) called at every step with state of the algorithm.
log_interval – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.
tb_log_name – the name of the run for TensorBoard logging
reset_num_timesteps – whether or not to reset the current timestep number (used in logging)
progress_bar – Display a progress bar using tqdm and rich.
- Returns:
the trained model
Policies
Policies for Stable Discrete Soft Actor-Critic (SDSAC).
Implements a discrete-action actor-critic architecture following Zhou et al. (2024), “Revisiting Discrete Soft Actor-Critic”.
The actor outputs a categorical distribution over discrete actions. Twin critics each output Q-values for all actions given a state.
- class sb3_soft.sdsac.policies.DiscreteActor(observation_space, action_space, net_arch, features_extractor, features_dim, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True)[source]
Bases:
BasePolicy
Actor (policy) network for discrete-action SAC.
Outputs a categorical distribution over the discrete action space via a softmax over learned logits.
- Parameters:
observation_space (spaces.Space) – Observation space.
action_space (spaces.Discrete) – Discrete action space.
net_arch (list[int]) – Network architecture (list of hidden layer sizes).
features_extractor (nn.Module) – Network used to extract features from observations.
features_dim (int) – Dimensionality of extracted features.
activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.
normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.
- action_space
- get_action_dist_params(obs)[source]
Compute action logits from observations.
- Parameters:
obs (PyTorchObs) – Batched observations.
- Returns:
Raw logits of shape (batch, n_actions).
- Return type:
th.Tensor
- get_action_probs(obs, epsilon=1e-08)[source]
Get action probabilities and log-probabilities.
- Parameters:
obs (PyTorchObs) – Batched observations.
epsilon (float, default=1e-8) – Unused placeholder for API compatibility.
- Returns:
Tuple (probs, log_probs), each of shape (batch, n_actions).
- Return type:
tuple[th.Tensor, th.Tensor]
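Because the critics output Q-values for every action, expectations under the policy can be computed exactly from these probabilities. A small sketch (tensor values are arbitrary; names are illustrative):

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])     # raw actor logits, (batch, n_actions)
log_probs = torch.log_softmax(logits, dim=1)  # numerically stable log-probabilities
probs = log_probs.exp()
q_values = torch.tensor([[1.0, 0.2, -0.3]])   # critic output for all actions
alpha = 0.2

# Exact soft state value: a sum over actions, no sampling needed.
v = (probs * (q_values - alpha * log_probs)).sum(dim=1)
```

Computing `probs` from `log_softmax` rather than `softmax` followed by `log` avoids taking the logarithm of near-zero probabilities.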
- class sb3_soft.sdsac.policies.DiscreteCritic(observation_space, action_space, net_arch, features_extractor, features_dim, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True, n_critics=2, share_features_extractor=True)[source]
Bases:
BaseModel
Twin Q-network critic for discrete-action SAC.
Each Q-network takes a state as input and outputs Q-values for every discrete action. Multiple networks (default: 2) are used to reduce overestimation via clipped double Q-learning.
- Parameters:
observation_space (spaces.Space) – Observation space.
action_space (spaces.Discrete) – Discrete action space.
net_arch (list[int]) – Network architecture for each Q-network.
features_extractor (BaseFeaturesExtractor) – Network used to extract features from observations.
features_dim (int) – Dimensionality of extracted features.
activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.
normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.
n_critics (int, default=2) – Number of Q-networks to create.
share_features_extractor (bool, default=True) – Whether the features extractor is shared with the actor. If True, gradients through it are blocked in the critic forward pass.
- features_extractor
- q_networks
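The state-in, all-actions-out structure can be sketched as a minimal stand-in module (not the library class; names and layer sizes are illustrative):

```python
import torch
import torch.nn as nn

class TwinDiscreteCritic(nn.Module):
    """Minimal stand-in: each head maps a state to Q-values for every action."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
        super().__init__()
        self.q_networks = nn.ModuleList(
            nn.Sequential(
                nn.Linear(obs_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, n_actions),
            )
            for _ in range(2)  # twin critics
        )

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # One (batch, n_actions) tensor per head.
        return tuple(q(obs) for q in self.q_networks)

critic = TwinDiscreteCritic(obs_dim=4, n_actions=2)
q1, q2 = critic(torch.randn(8, 4))
```

Emitting Q-values for all actions in one pass is what allows SD-SAC to take exact expectations over the categorical policy instead of evaluating Q at sampled actions.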
- class sb3_soft.sdsac.policies.SDSACPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]
Bases:
BasePolicy
Policy class (actor + twin critics) for discrete-action SAC.
- Parameters:
observation_space (spaces.Space) – Observation space.
action_space (spaces.Discrete) – Discrete action space.
lr_schedule (Schedule) – Learning rate schedule.
net_arch (Optional[Union[list[int], dict[str, list[int]]]], default=None) – Network architecture specification. Can be a list of integers (shared) or a dictionary with "pi" and "qf" keys.
activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.
features_extractor_class (type[BaseFeaturesExtractor], default=FlattenExtractor) – Features extractor class.
features_extractor_kwargs (Optional[dict[str, Any]], default=None) – Keyword arguments for the features extractor.
normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.
optimizer_class (type[th.optim.Optimizer], default=th.optim.Adam) – Optimizer class.
optimizer_kwargs (Optional[dict[str, Any]], default=None) – Additional optimizer keyword arguments.
n_critics (int, default=2) – Number of critic networks.
share_features_extractor (bool, default=False) – Whether to share the features extractor between actor and critic.
- actor
- critic
- critic_target
- forward(obs, deterministic=False)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- sb3_soft.sdsac.policies.MlpPolicy
alias of
SDSACPolicy
- class sb3_soft.sdsac.policies.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]
Bases:
SDSACPolicy
SDSAC policy with a NatureCNN features extractor.
- class sb3_soft.sdsac.policies.MultiInputPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]
Bases:
SDSACPolicy
SDSAC policy with a CombinedExtractor for dict observations.
Replay Buffers
Custom replay buffer for SD-SAC.
Extends the standard SB3 ReplayBuffer with per-transition storage
of the policy entropy at collection time (H_πold). This is required by
the entropy-penalty term in the SD-SAC actor loss.
- class sb3_soft.sdsac.buffers.SDSACReplayBufferSamples(observations, actions, next_observations, dones, rewards, old_entropies, discounts=None)[source]
Bases:
NamedTuple
Replay-buffer samples with an extra old_entropies field.
- observations
Batch of observations.
- Type:
th.Tensor | dict[str, th.Tensor]
- actions
Batch of actions.
- Type:
th.Tensor
- next_observations
Batch of next observations.
- Type:
th.Tensor | dict[str, th.Tensor]
- dones
Batch of done flags.
- Type:
th.Tensor
- rewards
Batch of rewards.
- Type:
th.Tensor
- old_entropies
Policy entropy at the time the transition was collected, shape (batch, 1).
- Type:
th.Tensor
- discounts
Per-sample discount factors (used by n-step buffers).
- Type:
th.Tensor | None
- observations
Alias for field number 0
- actions
Alias for field number 1
- next_observations
Alias for field number 2
- dones
Alias for field number 3
- rewards
Alias for field number 4
- old_entropies
Alias for field number 5
- discounts
Alias for field number 6
- class sb3_soft.sdsac.buffers.SDSACReplayBuffer(buffer_size, observation_space, action_space, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True)[source]
Bases:
ReplayBuffer
Replay buffer that additionally stores per-transition policy entropy.
The entropy is set via set_entropy() before each add() call. During sampling, the stored entropy is returned alongside the standard replay-buffer fields.
- Parameters:
buffer_size (int) – Maximum number of transitions to store.
observation_space (spaces.Space) – Observation space.
action_space (spaces.Space) – Action space.
device (Union[th.device, str], default="auto") – Device for returned tensors.
n_envs (int, default=1) – Number of parallel environments.
optimize_memory_usage (bool, default=False) – Memory-efficient variant (see SB3 docs).
handle_timeout_termination (bool, default=True) – Whether to handle TimeLimit.truncated in infos.
- old_entropies
- set_entropy(entropy)[source]
Stage entropy values to be written on the next add() call.
- Parameters:
entropy (np.ndarray) – Entropy for each environment, shape (n_envs,) or (n_envs, 1).
- add(obs, next_obs, action, reward, done, infos)[source]
Store a transition, including any staged entropy.
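The stage-then-write contract can be illustrated with a simplified stand-in class (not the library implementation; the internals shown are illustrative):

```python
import numpy as np

class EntropyStagingBuffer:
    """Simplified stand-in showing the set_entropy()-before-add() contract."""

    def __init__(self, size: int, n_envs: int = 1):
        self.old_entropies = np.zeros((size, n_envs), dtype=np.float32)
        self._staged = None
        self.pos = 0

    def set_entropy(self, entropy: np.ndarray) -> None:
        # Stage per-env entropy; it is written by the next add() call.
        self._staged = np.asarray(entropy, dtype=np.float32).reshape(-1)

    def add(self, transition) -> None:
        assert self._staged is not None, "call set_entropy() before add()"
        self.old_entropies[self.pos] = self._staged
        self.pos += 1
        self._staged = None  # require fresh staging for the next transition

buf = EntropyStagingBuffer(size=10)
buf.set_entropy(np.array([0.693]))  # e.g. entropy of a uniform 2-action policy
buf.add(transition=None)
```

Staging rather than passing the entropy as an extra add() argument keeps the buffer's add() signature compatible with the standard SB3 collection loop.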
- class sb3_soft.sdsac.buffers.SDSACDictReplayBuffer(buffer_size, observation_space, action_space, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True)[source]
Bases:
DictReplayBuffer
Dict replay buffer that additionally stores per-transition entropy.
- old_entropies
- set_entropy(entropy)[source]
Stage entropy values to be written on the next add() call.
- add(obs, next_obs, action, reward, done, infos)[source]
Store a transition, including any staged entropy.