SQL API

Algorithm

Soft Q-Learning (SQL) implementation for discrete action spaces, based on Haarnoja et al. (2017), “Reinforcement Learning with Deep Energy-Based Policies”: https://proceedings.mlr.press/v70/haarnoja17a.html

class sb3_soft.sql.sql.SQL(policy, env, learning_rate=0.0001, buffer_size=1000000, learning_starts=100, batch_size=32, tau=1.0, gamma=0.99, train_freq=4, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, target_update_interval=10000, max_grad_norm=10, ent_coef='auto', target_entropy='auto', action_temperature=None, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]

Bases: OffPolicyAlgorithm

Discrete-action Soft Q-Learning from T. Haarnoja, H. Tang, P. Abbeel, and S. Levine, “Reinforcement Learning with Deep Energy-Based Policies,” Proceedings of the 34th International Conference on Machine Learning, PMLR, Jul. 2017, pp. 1352–1361. https://proceedings.mlr.press/v70/haarnoja17a.html

Extends SB3’s OffPolicyAlgorithm with an entropy-regularized Bellman backup and Boltzmann (softmax) action sampling.

Parameters:
  • policy (str | type[SQLPolicy]) – Policy model to use (e.g., "MlpPolicy", "CnnPolicy").

  • env (GymEnv | str) – Environment to learn from.

  • learning_rate (float | Schedule, default=1e-4) – Learning rate (constant or schedule).

  • buffer_size (int, default=1_000_000) – Replay buffer capacity.

  • learning_starts (int, default=100) – Number of steps of random exploration before learning starts.

  • batch_size (int, default=32) – Minibatch size for each gradient update.

  • tau (float, default=1.0) – Polyak update coefficient for target network updates.

  • gamma (float, default=0.99) – Discount factor.

  • train_freq (int | tuple[int, str], default=4) – Update frequency in steps or episodes.

  • gradient_steps (int, default=1) – Number of gradient updates after each rollout.

  • action_noise (ActionNoise | None, default=None) – Action noise for exploration (only applicable to continuous action spaces).

  • replay_buffer_class (type[ReplayBuffer] | None, default=None) – Optional replay buffer implementation override.

  • replay_buffer_kwargs (dict[str, Any] | None, default=None) – Additional keyword arguments for replay buffer creation.

  • optimize_memory_usage (bool, default=False) – Whether to use the memory-efficient replay buffer variant.

  • n_steps (int, default=1) – Number of steps for n-step returns.

  • target_update_interval (int, default=10_000) – Environment steps between target network updates.

  • max_grad_norm (float, default=10) – Maximum gradient norm for clipping.

  • ent_coef (str | float, default="auto") – Temperature \(\alpha\) used in the soft Bellman target \(V(s') = \alpha \log \sum_a \exp(Q(s', a) / \alpha)\). Set to "auto" (or "auto_0.1") to learn it automatically.

  • target_entropy (str | float, default="auto") – Target policy entropy used when ent_coef is learned automatically. If "auto", uses \(0.98 \log(|\mathcal{A}|)\), but this may need adjustment for low-entropy environments.

  • action_temperature (float | None, default=None) – Temperature \(\tau\) for Boltzmann action sampling \(\pi(a \mid s) \propto \exp(Q(s, a) / \tau)\). If None, falls back to ent_coef.

  • stats_window_size (int, default=100) – Window size for rollout statistics logging.

  • tensorboard_log (str | None, default=None) – TensorBoard log directory.

  • policy_kwargs (dict[str, Any] | None, default=None) – Additional keyword arguments passed to the policy.

  • verbose (int, default=0) – Verbosity level (0: no output, 1: info, 2: debug).

  • seed (int | None, default=None) – Random seed.

  • device (torch.device | str, default="auto") – Device to run the model on.

  • _init_setup_model (bool, default=True) – Whether to build networks and optimizers during construction.
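A minimal end-to-end usage sketch built from the constructor signature above (hyperparameter values are illustrative, not tuned; assumes gymnasium's CartPole-v1 is available):

```python
from sb3_soft.sql import SQL

# Learn the temperature alpha automatically (initialized at 0.1) and use a
# separate, fixed Boltzmann temperature for action sampling.
model = SQL(
    "MlpPolicy",
    "CartPole-v1",
    learning_rate=1e-4,
    buffer_size=100_000,
    ent_coef="auto_0.1",
    action_temperature=0.5,
    verbose=1,
)
model.learn(total_timesteps=50_000)
model.save("sql_cartpole")
```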

policy_aliases = {'CnnPolicy': <class 'sb3_soft.sql.policies.SQLCnnPolicy'>, 'MlpPolicy': <class 'sb3_soft.sql.policies.SQLPolicy'>, 'MultiInputPolicy': <class 'sb3_soft.sql.policies.SQLMultiInputPolicy'>}
q_net
q_net_target
train(gradient_steps, batch_size=100)[source]

Sample the replay buffer and perform the updates: gradient steps on the soft Q-network, followed by target-network updates.
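The entropy-regularized backup used in these updates can be sketched in a few lines. This is an illustrative standalone NumPy version (not the library's internal code) of the soft value \(V(s') = \alpha \log \sum_a \exp(Q(s', a) / \alpha)\) and the resulting one-step TD target:

```python
import numpy as np

def soft_value(q_next, alpha):
    """Soft state value: alpha * log sum_a exp(Q(s', a) / alpha).

    Uses the log-sum-exp trick for numerical stability.
    q_next: array of shape (batch, n_actions).
    """
    z = q_next / alpha
    z_max = z.max(axis=1, keepdims=True)
    return alpha * (z_max + np.log(np.exp(z - z_max).sum(axis=1, keepdims=True))).squeeze(1)

def soft_td_target(rewards, dones, q_next, alpha=0.1, gamma=0.99):
    """One-step soft Bellman target: r + gamma * (1 - done) * V(s')."""
    return rewards + gamma * (1.0 - dones) * soft_value(q_next, alpha)

# As alpha -> 0 the soft value approaches max_a Q(s', a), recovering the hard DQN target.
q = np.array([[1.0, 2.0, 0.5]])
print(soft_value(q, 1e-6))  # close to 2.0
```

Note that the soft value always upper-bounds \(\max_a Q(s', a)\), since the entropy bonus is nonnegative.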

predict(observation, state=None, episode_start=None, deterministic=False)[source]

Get the policy action from an observation (and optional hidden state). Includes preprocessing to handle different observation types (e.g., normalizing images).

Parameters:
  • observation – the input observation

  • state – The last hidden states (can be None, used in recurrent policies)

  • episode_start – The last masks (can be None, used in recurrent policies). This corresponds to the beginning of episodes, where the hidden states of the RNN must be reset.

  • deterministic – Whether or not to return deterministic actions.

Returns:

the model’s action and the next hidden state (used in recurrent policies)
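A usage sketch for inference (assumes `model` is a trained SQL instance and that gymnasium is available): for this algorithm, the default stochastic mode samples from the Boltzmann distribution over Q-values, while deterministic=True picks the greedy action.

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, _ = env.reset()

# Default: stochastic Boltzmann sampling, pi(a|s) proportional to exp(Q(s, a) / tau).
action, _state = model.predict(obs)

# Greedy evaluation: argmax_a Q(s, a) instead of sampling.
action, _state = model.predict(obs, deterministic=True)
```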

learn(total_timesteps, callback=None, log_interval=4, tb_log_name='SQL', reset_num_timesteps=True, progress_bar=False)[source]

Return a trained model.

Parameters:
  • total_timesteps – The total number of samples (env steps) to train on. Note: this is a lower bound; see SB3 issue #1150.

  • callback – Callback(s) called at every step with the current state of the algorithm.

  • log_interval – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.

  • tb_log_name – the name of the run for TensorBoard logging

  • reset_num_timesteps – whether or not to reset the current timestep number (used in logging)

  • progress_bar – Display a progress bar using tqdm and rich.

Returns:

the trained model

Policies

class sb3_soft.sql.policies.SoftQNetwork(observation_space, action_space, features_extractor, features_dim, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True, temperature=1.0)[source]

Bases: QNetwork

Q-network with Boltzmann sampling for stochastic action selection.
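The sampling rule can be sketched in NumPy: actions are drawn from a softmax over Q-values, with the temperature interpolating between greedy (\(\tau \to 0\)) and uniform (\(\tau \to \infty\)) behaviour. This is an illustrative standalone implementation, not the network's actual forward pass:

```python
import numpy as np

def boltzmann_probs(q_values, temperature=1.0):
    """Softmax over Q-values: pi(a|s) proportional to exp(Q(s, a) / temperature)."""
    z = q_values / temperature
    z = z - z.max()  # stabilize the exponentials
    p = np.exp(z)
    return p / p.sum()

q = np.array([1.0, 2.0, 0.5])
print(boltzmann_probs(q, temperature=0.1))    # nearly one-hot on action 1 (greedy limit)
print(boltzmann_probs(q, temperature=100.0))  # nearly uniform

# Sampling an action from the resulting distribution:
rng = np.random.default_rng(0)
action = rng.choice(len(q), p=boltzmann_probs(q, temperature=1.0))
```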

class sb3_soft.sql.policies.SQLPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]

Bases: DQNPolicy

DQN policy variant using a soft Q-network for stochastic sampling.

make_q_net()[source]

class sb3_soft.sql.policies.SQLCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]

Bases: SQLPolicy

class sb3_soft.sql.policies.SQLMultiInputPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]

Bases: SQLPolicy

sb3_soft.sql.policies.MlpPolicy

alias of SQLPolicy

sb3_soft.sql.policies.CnnPolicy

alias of SQLCnnPolicy

sb3_soft.sql.policies.MultiInputPolicy

alias of SQLMultiInputPolicy