SDSAC API

Algorithm

Stable Discrete Soft Actor-Critic (SDSAC).

Actor-critic algorithm for discrete action spaces based on Zhou et al. (2024), “Revisiting Discrete Soft Actor-Critic”. https://openreview.net/forum?id=EUF2R6VBeU

Key differences from continuous SAC and naive discrete SAC:

  • The actor outputs a categorical distribution over discrete actions.

  • Twin critics output Q-values for all actions given a state (no action input), enabling exact expectation computation.

  • Double-average Q-learning: the target uses mean (not min) of the twin target critics.

  • The actor objective also uses the mean of twin online critics (instead of min) for policy improvement.

  • Q-clip: the critic loss is

    \(\max\left((Q - y)^2,\, \left(Q' + \mathrm{clip}(Q - Q', -c, c) - y\right)^2\right)\), where \(Q\) is the online Q-value, \(Q'\) the target Q-value, \(y\) the Bellman target, and \(c\) the clip range.

  • Entropy-penalty: the actor loss includes

    \(\beta \cdot \tfrac{1}{2} \cdot \left(H_{\pi_{\mathrm{old}}} - H_{\pi}\right)^2\), where \(H_{\pi_{\mathrm{old}}}\) is stored in the replay buffer at collection time.
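The double-average target can be sketched as follows. This is an illustrative snippet, not the package's internal code; all tensor and function names are assumptions. It shows how the categorical policy allows the soft state value to be computed as an exact expectation over actions rather than a sampled estimate:

```python
import torch

# Hedged sketch (illustrative names, not the package's internals): the
# SD-SAC Bellman target computed as an exact expectation over discrete
# actions, using the double-average of the twin target critics.
def soft_bellman_target(q1_t, q2_t, probs, log_probs, rewards, dones,
                        gamma=0.99, alpha=0.2):
    q_t = 0.5 * (q1_t + q2_t)  # mean of twin target critics, not min
    # Exact soft state value: E_a[Q(s', a) - alpha * log pi(a|s')].
    v_next = (probs * (q_t - alpha * log_probs)).sum(dim=1, keepdim=True)
    return rewards + (1.0 - dones) * gamma * v_next
```

Because the critics output Q-values for every action, the expectation needs no sampling or reparameterisation trick.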

class sb3_soft.sdsac.sdsac.SDSAC(policy, env, learning_rate=0.0003, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, ent_coef='auto', target_update_interval=1, target_entropy='auto', beta=0.5, clip_range=0.5, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]

Bases: OffPolicyAlgorithm

Stable Discrete Soft Actor-Critic (SD-SAC) from H. Zhou et al., “Revisiting Discrete Soft Actor-Critic,” Transactions on Machine Learning Research, Aug. 2024. https://openreview.net/forum?id=EUF2R6VBeU

An off-policy actor-critic algorithm for discrete action spaces that maintains separate actor and twin-critic networks and performs entropy-regularized updates using full-distribution expectations.

Compared with a naive discrete adaptation of SAC, SD-SAC adds three stabilisation mechanisms (Algorithm 1 in the paper):

  1. Double-average Q-learning – the Bellman target uses

    \(\mathrm{mean}(Q'_1, Q'_2)\) of the twin target critics instead of min. The actor objective likewise uses the mean of online critics \(\mathrm{mean}(Q_1, Q_2)\) in place of a clipped-min estimate.

  2. Q-clip – the critic loss is

    \(\max\left((Q - y)^2,\, \left(Q' + \mathrm{clip}(Q - Q', -c, c) - y\right)^2\right)\).

  3. Entropy penalty – the actor loss adds

    \(\beta \cdot \tfrac{1}{2} \cdot \left(H_{\pi_{\mathrm{old}}} - H_{\pi}\right)^2\), where \(H_{\pi_{\mathrm{old}}}\) is the policy entropy stored in the replay buffer at collection time.
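The Q-clip and entropy-penalty terms can be written directly from the formulas above. This is a hedged sketch with illustrative helper names, not the package's actual loss functions:

```python
import torch

# Hedged sketch of the two stabilising loss terms (illustrative names).
def q_clip_loss(q, q_target_net, y, c=0.5):
    # Elementwise max of the plain TD error and the clipped-update error,
    # matching max((Q - y)^2, (Q' + clip(Q - Q', -c, c) - y)^2).
    clipped = q_target_net + torch.clamp(q - q_target_net, -c, c)
    return torch.max((q - y) ** 2, (clipped - y) ** 2).mean()

def entropy_penalty(entropy, old_entropy, beta=0.5):
    # Penalise squared drift of the policy entropy from its value at
    # collection time (H_pi_old, read back from the replay buffer).
    return beta * 0.5 * ((old_entropy - entropy) ** 2).mean()
```

The max in the critic loss keeps the larger of the two errors, so the clipped estimate can never make the loss smaller than the plain TD error.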

Reference: Zhou et al. (2024) “Revisiting Discrete Soft Actor-Critic”.

Parameters:
  • policy (str | type[SDSACPolicy]) – Policy model to use ("MlpPolicy", "CnnPolicy", …).

  • env (GymEnv | str) – Environment to learn from.

  • learning_rate (float | Schedule, default=3e-4) – Learning rate for all networks (actor, critic, and optionally alpha).

  • buffer_size (int, default=1_000_000) – Replay buffer capacity.

  • learning_starts (int, default=100) – Number of environment steps to collect before training starts.

  • batch_size (int, default=256) – Mini-batch size for each gradient update.

  • tau (float, default=0.005) – Polyak averaging coefficient for target network updates.

  • gamma (float, default=0.99) – Discount factor.

  • train_freq (int | tuple[int, str], default=1) – How often to update the model (in steps or episodes).

  • gradient_steps (int, default=1) – Gradient updates per rollout step. -1 means as many as environment steps collected.

  • replay_buffer_class (type[ReplayBuffer] | None, default=None) – Custom replay buffer class. Defaults to SDSACReplayBuffer.

  • replay_buffer_kwargs (dict | None, default=None) – Keyword arguments for the replay buffer.

  • optimize_memory_usage (bool, default=False) – Memory-efficient replay buffer variant.

  • n_steps (int, default=1) – Steps for n-step returns.

  • ent_coef (str | float, default="auto") – Entropy coefficient (temperature) \(\alpha\). "auto" enables automatic tuning ("auto_0.1" sets the initial value).

  • target_update_interval (int, default=1) – Gradient steps between target network updates.

  • target_entropy (str | float, default="auto") – Target entropy for automatic \(\alpha\) tuning. "auto" uses \(0.98 \log |\mathcal{A}|\), but this may need adjustment for low-entropy environments.

  • beta (float, default=0.5) – Entropy-penalty coefficient \(\beta\) in the actor loss.

  • clip_range (float, default=0.5) – Clipping range \(c\) for the Q-clip critic loss.

  • stats_window_size (int, default=100) – Window size for rollout statistics.

  • tensorboard_log (str | None, default=None) – TensorBoard log directory.

  • policy_kwargs (dict | None, default=None) – Extra keyword arguments for policy construction.

  • verbose (int, default=0) – Verbosity level (0: silent, 1: info, 2: debug).

  • seed (int | None, default=None) – Random seed.

  • device (str | th.device, default="auto") – Computation device.

  • _init_setup_model (bool, default=True) – Whether to build networks on construction.
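A minimal usage sketch, assuming the `sb3_soft` package is installed alongside stable-baselines3 and gymnasium; the environment id is illustrative:

```python
# Hedged usage sketch: assumes `sb3_soft` is installed together with
# stable-baselines3 and gymnasium; any discrete-action environment works.
from sb3_soft.sdsac.sdsac import SDSAC

model = SDSAC(
    "MlpPolicy",
    "CartPole-v1",       # illustrative discrete-action environment
    ent_coef="auto",     # automatic temperature tuning
    beta=0.5,            # entropy-penalty coefficient
    clip_range=0.5,      # Q-clip range c
    verbose=1,
)
model.learn(total_timesteps=10_000)
```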

policy_aliases = {'CnnPolicy': <class 'sb3_soft.sdsac.policies.CnnPolicy'>, 'MlpPolicy': <class 'sb3_soft.sdsac.policies.SDSACPolicy'>, 'MultiInputPolicy': <class 'sb3_soft.sdsac.policies.MultiInputPolicy'>}
policy
actor
critic
critic_target
train(gradient_steps, batch_size=64)[source]

Sample the replay buffer and perform the updates (gradient descent on the critic, actor, and entropy-coefficient losses, plus target-network updates).

predict(observation, state=None, episode_start=None, deterministic=False)[source]

Get the policy action from an observation (and optional hidden state). Includes convenience handling for different observation types (e.g. normalizing images).

Parameters:
  • observation – the input observation

  • state – The last hidden states (can be None, used in recurrent policies)

  • episode_start – The last masks (can be None, used in recurrent policies); this corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.

  • deterministic – Whether or not to return deterministic actions.

Returns:

the model’s action and the next hidden state (used in recurrent policies)

learn(total_timesteps, callback=None, log_interval=4, tb_log_name='SDSAC', reset_num_timesteps=True, progress_bar=False)[source]

Return a trained model.

Parameters:
  • total_timesteps – The total number of samples (env steps) to train on. Note: this is a lower bound; see issue #1150.

  • callback – callback(s) called at every step with state of the algorithm.

  • log_interval – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.

  • tb_log_name – the name of the run for TensorBoard logging

  • reset_num_timesteps – whether or not to reset the current timestep number (used in logging)

  • progress_bar – Display a progress bar using tqdm and rich.

Returns:

the trained model

Policies

Policies for Stable Discrete Soft Actor-Critic (SDSAC).

Implements a discrete-action actor-critic architecture following Zhou et al. (2024), “Revisiting Discrete Soft Actor-Critic”.

The actor outputs a categorical distribution over discrete actions. Twin critics each output Q-values for all actions given a state.

class sb3_soft.sdsac.policies.DiscreteActor(observation_space, action_space, net_arch, features_extractor, features_dim, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True)[source]

Bases: BasePolicy

Actor (policy) network for discrete-action SAC.

Outputs a categorical distribution over the discrete action space via a softmax over learned logits.

Parameters:
  • observation_space (spaces.Space) – Observation space.

  • action_space (spaces.Discrete) – Discrete action space.

  • net_arch (list[int]) – Network architecture (list of hidden layer sizes).

  • features_extractor (nn.Module) – Network used to extract features from observations.

  • features_dim (int) – Dimensionality of extracted features.

  • activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.

  • normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.

action_space
get_action_dist_params(obs)[source]

Compute action logits from observations.

Parameters:

obs (PyTorchObs) – Batched observations.

Returns:

Raw logits of shape (batch, n_actions).

Return type:

th.Tensor

get_action_probs(obs, epsilon=1e-08)[source]

Get action probabilities and log-probabilities.

Parameters:
  • obs (PyTorchObs) – Batched observations.

  • epsilon (float, default=1e-8) – Unused placeholder for API compatibility.

Returns:

Tuple (probs, log_probs), each of shape (batch, n_actions).

Return type:

tuple[th.Tensor, th.Tensor]

forward(obs, deterministic=False)[source]

Select actions from observations.

Parameters:
  • obs (PyTorchObs) – Batched observations.

  • deterministic (bool, default=False) – If True, return greedy actions (argmax).

Returns:

Selected action indices of shape (batch,).

Return type:

th.Tensor
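The actor head described above can be sketched in a few lines. This is an illustrative snippet (the tensors are placeholders, not the module's real internals); it mirrors what get_action_probs() and forward() return, with log_softmax used for numerical stability:

```python
import torch
import torch.nn.functional as F

# Hedged sketch of the categorical actor head (illustrative tensors):
# logits -> probabilities and log-probabilities over discrete actions.
logits = torch.randn(4, 3)                    # (batch, n_actions)
probs = F.softmax(logits, dim=1)
log_probs = F.log_softmax(logits, dim=1)      # numerically stable log-probs

greedy = probs.argmax(dim=1)                  # deterministic=True path
sampled = torch.distributions.Categorical(probs=probs).sample()
entropy = -(probs * log_probs).sum(dim=1)     # per-state policy entropy
```

The entropy computed this way is what an SD-SAC collection loop would stage into the replay buffer as \(H_{\pi_{\mathrm{old}}}\).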

class sb3_soft.sdsac.policies.DiscreteCritic(observation_space, action_space, net_arch, features_extractor, features_dim, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True, n_critics=2, share_features_extractor=True)[source]

Bases: BaseModel

Twin Q-network critic for discrete-action SAC.

Each Q-network takes a state as input and outputs Q-values for every discrete action. Multiple networks (default: 2) are used to reduce overestimation via clipped double Q-learning.

Parameters:
  • observation_space (spaces.Space) – Observation space.

  • action_space (spaces.Discrete) – Discrete action space.

  • net_arch (list[int]) – Network architecture for each Q-network.

  • features_extractor (BaseFeaturesExtractor) – Network used to extract features from observations.

  • features_dim (int) – Dimensionality of extracted features.

  • activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.

  • normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.

  • n_critics (int, default=2) – Number of Q-networks to create.

  • share_features_extractor (bool, default=True) – Whether the features extractor is shared with the actor. If True, gradients through it are blocked in the critic forward pass.

features_extractor
q_networks
forward(obs)[source]

Compute Q-values for all actions from all critic networks.

Parameters:

obs (th.Tensor) – Batched observations.

Returns:

Q-value tensors, one per critic, each of shape (batch, n_actions).

Return type:

tuple[th.Tensor, …]

q1_forward(obs)[source]

Compute Q-values using only the first critic network.

Parameters:

obs (th.Tensor) – Batched observations.

Returns:

Q-values of shape (batch, n_actions).

Return type:

th.Tensor
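With critics that output Q-values for every action, the SD-SAC actor objective reduces to an exact expectation. A hedged sketch with illustrative tensors (not the package's training code), using the mean of the twin critics as the algorithm prescribes:

```python
import torch

# Hedged sketch (illustrative tensors): discrete SAC actor objective with
# the mean of the twin online critics, computed as an exact expectation.
q1 = torch.tensor([[1.0, 3.0]])               # (batch, n_actions)
q2 = torch.tensor([[3.0, 1.0]])
probs = torch.tensor([[0.25, 0.75]])
log_probs = probs.log()
alpha = 0.1                                   # entropy temperature

q_mean = 0.5 * (q1 + q2)                      # mean, not a clipped min
# E_a[alpha * log pi(a|s) - Q(s, a)] under the categorical policy.
actor_loss = (probs * (alpha * log_probs - q_mean)).sum(dim=1).mean()
```

Because the expectation is exact, the actor gradient is unbiased without sampling actions or a reparameterisation trick.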

class sb3_soft.sdsac.policies.SDSACPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]

Bases: BasePolicy

Policy class (actor + twin critics) for discrete-action SAC.

Parameters:
  • observation_space (spaces.Space) – Observation space.

  • action_space (spaces.Discrete) – Discrete action space.

  • lr_schedule (Schedule) – Learning rate schedule.

  • net_arch (Optional[Union[list[int], dict[str, list[int]]]], default=None) – Network architecture specification. Can be a list of integers (shared) or a dictionary with "pi" and "qf" keys.

  • activation_fn (type[nn.Module], default=nn.ReLU) – Activation function.

  • features_extractor_class (type[BaseFeaturesExtractor], default=FlattenExtractor) – Features extractor class.

  • features_extractor_kwargs (Optional[dict[str, Any]], default=None) – Keyword arguments for the features extractor.

  • normalize_images (bool, default=True) – Whether to normalize images by dividing by 255.

  • optimizer_class (type[th.optim.Optimizer], default=th.optim.Adam) – Optimizer class.

  • optimizer_kwargs (Optional[dict[str, Any]], default=None) – Additional optimizer keyword arguments.

  • n_critics (int, default=2) – Number of critic networks.

  • share_features_extractor (bool, default=False) – Whether to share the features extractor between actor and critic.

actor
critic
critic_target
make_actor(features_extractor=None)[source]
make_critic(features_extractor=None)[source]
forward(obs, deterministic=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

set_training_mode(mode)[source]

Put the policy in either training or evaluation mode.

This affects certain modules, such as batch normalisation and dropout.

Parameters:

mode – If True, set to training mode; otherwise set to evaluation mode.

sb3_soft.sdsac.policies.MlpPolicy

alias of SDSACPolicy

class sb3_soft.sdsac.policies.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]

Bases: SDSACPolicy

SDSAC policy with NatureCNN features extractor.

class sb3_soft.sdsac.policies.MultiInputPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=False)[source]

Bases: SDSACPolicy

SDSAC policy with CombinedExtractor for dict observations.

Replay Buffers

Custom replay buffer for SD-SAC.

Extends the standard SB3 ReplayBuffer with per-transition storage of the policy entropy at collection time (\(H_{\pi_{\mathrm{old}}}\)). This is required by the entropy-penalty term in the SD-SAC actor loss.

class sb3_soft.sdsac.buffers.SDSACReplayBufferSamples(observations, actions, next_observations, dones, rewards, old_entropies, discounts=None)[source]

Bases: NamedTuple

Replay-buffer samples with an extra old_entropies field.

observations

Batch of observations.

Type:

th.Tensor | dict[str, th.Tensor]

actions

Batch of actions.

Type:

th.Tensor

next_observations

Batch of next observations.

Type:

th.Tensor | dict[str, th.Tensor]

dones

Batch of done flags.

Type:

th.Tensor

rewards

Batch of rewards.

Type:

th.Tensor

old_entropies

Policy entropy at the time the transition was collected, shape (batch, 1).

Type:

th.Tensor

discounts

Per-sample discount factors (used by n-step buffers).

Type:

th.Tensor | None

observations

Alias for field number 0

actions

Alias for field number 1

next_observations

Alias for field number 2

dones

Alias for field number 3

rewards

Alias for field number 4

old_entropies

Alias for field number 5

discounts

Alias for field number 6

class sb3_soft.sdsac.buffers.SDSACReplayBuffer(buffer_size, observation_space, action_space, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True)[source]

Bases: ReplayBuffer

Replay buffer that additionally stores per-transition policy entropy.

The entropy is set via set_entropy() before each add() call. During sampling, the stored entropy is returned alongside the standard replay-buffer fields.

Parameters:
  • buffer_size (int) – Maximum number of transitions to store.

  • observation_space (spaces.Space) – Observation space.

  • action_space (spaces.Space) – Action space.

  • device (Union[th.device, str], default="auto") – Device for returned tensors.

  • n_envs (int, default=1) – Number of parallel environments.

  • optimize_memory_usage (bool, default=False) – Memory-efficient variant (see SB3 docs).

  • handle_timeout_termination (bool, default=True) – Whether to handle TimeLimit.truncated in infos.

old_entropies
set_entropy(entropy)[source]

Stage entropy values to be written on the next add() call.

Parameters:

entropy (np.ndarray) – Entropy for each environment, shape (n_envs,) or (n_envs, 1).

add(obs, next_obs, action, reward, done, infos)[source]

Store a transition, including any staged entropy.
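The set_entropy()/add() handshake can be illustrated with a self-contained sketch. This is not the package's actual class, only the staging pattern it describes: a value is staged once per step and consumed by the next add() call.

```python
import numpy as np

# Hedged, self-contained sketch of the entropy-staging pattern
# (illustrative class; the real buffer also stores the usual transition
# fields such as observations, actions, rewards, and dones).
class StagedEntropyBuffer:
    def __init__(self, buffer_size, n_envs=1):
        self.old_entropies = np.zeros((buffer_size, n_envs), dtype=np.float32)
        self.pos = 0
        self._staged = None

    def set_entropy(self, entropy):
        # Accepts shape (n_envs,) or (n_envs, 1).
        self._staged = np.asarray(entropy, dtype=np.float32).reshape(-1)

    def add(self):
        # Write the staged entropy for this transition, then clear the stage.
        self.old_entropies[self.pos] = self._staged
        self.pos += 1
        self._staged = None
```

Only the entropy bookkeeping is shown; in the real buffer, add() also writes the standard replay fields.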

class sb3_soft.sdsac.buffers.SDSACDictReplayBuffer(buffer_size, observation_space, action_space, device='auto', n_envs=1, optimize_memory_usage=False, handle_timeout_termination=True)[source]

Bases: DictReplayBuffer

Dict replay buffer that additionally stores per-transition entropy.

old_entropies
set_entropy(entropy)[source]

Stage entropy values to be written on the next add() call.

add(obs, next_obs, action, reward, done, infos)[source]

Store a transition, including any staged entropy.