SQL API
Algorithm
Soft Q-Learning (SQL) implementation for discrete action spaces, based on Haarnoja et al. (2017), “Reinforcement Learning with Deep Energy-Based Policies.” https://proceedings.mlr.press/v70/haarnoja17a.html
- class sb3_soft.sql.sql.SQL(policy, env, learning_rate=0.0001, buffer_size=1000000, learning_starts=100, batch_size=32, tau=1.0, gamma=0.99, train_freq=4, gradient_steps=1, action_noise=None, replay_buffer_class=None, replay_buffer_kwargs=None, optimize_memory_usage=False, n_steps=1, target_update_interval=10000, max_grad_norm=10, ent_coef='auto', target_entropy='auto', action_temperature=None, stats_window_size=100, tensorboard_log=None, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)[source]
Bases: OffPolicyAlgorithm

Discrete-action Soft Q-Learning from T. Haarnoja, H. Tang, P. Abbeel, and S. Levine, “Reinforcement Learning with Deep Energy-Based Policies,” Proceedings of the 34th International Conference on Machine Learning, PMLR, Jul. 2017, pp. 1352–1361. https://proceedings.mlr.press/v70/haarnoja17a.html
Extends SB3’s OffPolicyAlgorithm with an entropy-regularized Bellman backup and Boltzmann (softmax) action sampling.
- Parameters:
policy (str | type[SQLPolicy]) – Policy model to use (e.g., "MlpPolicy", "CnnPolicy").
env (GymEnv | str) – Environment to learn from.
learning_rate (float | Schedule, default=1e-4) – Learning rate (constant or schedule).
buffer_size (int, default=1_000_000) – Replay buffer capacity.
learning_starts (int, default=100) – Number of steps of random exploration before learning starts.
batch_size (int, default=32) – Minibatch size for each gradient update.
tau (float, default=1.0) – Polyak update coefficient for target network updates.
gamma (float, default=0.99) – Discount factor.
train_freq (int | tuple[int, str], default=4) – Update frequency in steps or episodes.
gradient_steps (int, default=1) – Number of gradient updates after each rollout.
action_noise (ActionNoise | None, default=None) – Action noise for exploration (only applicable to continuous action spaces).
replay_buffer_class (type[ReplayBuffer] | None, default=None) – Optional replay buffer implementation override.
replay_buffer_kwargs (dict[str, Any] | None, default=None) – Additional keyword arguments for replay buffer creation.
optimize_memory_usage (bool, default=False) – Whether to use the memory-efficient replay buffer variant.
n_steps (int, default=1) – Number of steps for n-step returns.
target_update_interval (int, default=10_000) – Environment steps between target network updates.
max_grad_norm (float, default=10) – Maximum gradient norm for clipping.
ent_coef (str | float, default="auto") – Temperature \(\alpha\) used in the soft Bellman target \(V(s') = \alpha \log \sum_a \exp(Q(s', a) / \alpha)\). Set to "auto" (or "auto_0.1") to learn it automatically.
target_entropy (str | float, default="auto") – Target policy entropy used when ent_coef is learned automatically. If "auto", uses \(0.98 \log(|\mathcal{A}|)\), but this may need adjustment for low-entropy environments.
action_temperature (float | None, default=None) – Temperature \(\tau\) for Boltzmann action sampling \(\pi(a \mid s) \propto \exp(Q(s, a) / \tau)\). If None, uses ent_coef.
stats_window_size (int, default=100) – Window size for rollout statistics logging.
tensorboard_log (str | None, default=None) – TensorBoard log directory.
policy_kwargs (dict[str, Any] | None, default=None) – Additional keyword arguments passed to the policy.
verbose (int, default=0) – Verbosity level (0: no output, 1: info, 2: debug).
seed (int | None, default=None) – Random seed.
device (torch.device | str, default="auto") – Device to run the model on.
_init_setup_model (bool, default=True) – Whether to build networks and optimizers during construction.
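The soft Bellman target controlled by ent_coef can be sketched as a log-sum-exp over action values. This is a minimal illustration in PyTorch; soft_value is a hypothetical helper, not part of the package's API:

```python
import torch


def soft_value(q_values: torch.Tensor, alpha: float) -> torch.Tensor:
    """Soft state value V(s) = alpha * log sum_a exp(Q(s, a) / alpha).

    q_values: tensor of shape (batch, n_actions).
    As alpha -> 0 this approaches max_a Q(s, a) (the hard DQN backup);
    larger alpha rewards keeping several actions close in value.
    """
    return alpha * torch.logsumexp(q_values / alpha, dim=-1)


q = torch.tensor([[1.0, 2.0, 3.0]])
print(soft_value(q, 1.0))   # slightly above max Q, ~3.41
print(soft_value(q, 0.01))  # ~3.0, close to the hard max
```

Note that the soft value is always at least the maximum Q-value; the gap grows with the temperature.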
- policy_aliases = {'CnnPolicy': <class 'sb3_soft.sql.policies.SQLCnnPolicy'>, 'MlpPolicy': <class 'sb3_soft.sql.policies.SQLPolicy'>, 'MultiInputPolicy': <class 'sb3_soft.sql.policies.SQLMultiInputPolicy'>}
- q_net
- q_net_target
- train(gradient_steps, batch_size=100)[source]
Sample the replay buffer and perform the updates (gradient steps and target-network updates).
- predict(observation, state=None, episode_start=None, deterministic=False)[source]
Get the policy action from an observation (and optional hidden state). Includes preprocessing to handle different observation types (e.g., normalizing images).
- Parameters:
observation – the input observation
state – The last hidden states (can be None, used in recurrent policies)
episode_start – The last masks (can be None, used in recurrent policies); these mark the beginning of episodes, where the hidden states of the RNN must be reset.
deterministic – Whether or not to return deterministic actions.
- Returns:
the model’s action and the next hidden state (used in recurrent policies)
- learn(total_timesteps, callback=None, log_interval=4, tb_log_name='SQL', reset_num_timesteps=True, progress_bar=False)[source]
Return a trained model.
- Parameters:
total_timesteps – The total number of samples (env steps) to train on. Note: this is a lower bound; see issue #1150.
callback – callback(s) called at every step with state of the algorithm.
log_interval – for on-policy algos (e.g., PPO, A2C, …) this is the number of training iterations (i.e., log_interval * n_steps * n_envs timesteps) before logging; for off-policy algos (e.g., TD3, SAC, …) this is the number of episodes before logging.
tb_log_name – the name of the run for TensorBoard logging
reset_num_timesteps – whether or not to reset the current timestep number (used in logging)
progress_bar – Display a progress bar using tqdm and rich.
- Returns:
the trained model
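A minimal end-to-end usage sketch, assuming sb3_soft and gymnasium are installed; CartPole-v1 stands in for any discrete-action environment:

```python
import gymnasium as gym
from sb3_soft.sql.sql import SQL

env = gym.make("CartPole-v1")

# ent_coef="auto" learns the temperature alpha against the target entropy
model = SQL("MlpPolicy", env, learning_rate=1e-4, ent_coef="auto", verbose=1)
model.learn(total_timesteps=10_000)

obs, _ = env.reset()
# deterministic=False keeps Boltzmann (softmax) action sampling
action, _ = model.predict(obs, deterministic=False)
```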
Policies
- class sb3_soft.sql.policies.SoftQNetwork(observation_space, action_space, features_extractor, features_dim, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, normalize_images=True, temperature=1.0)[source]
Bases: QNetwork

Q-network with Boltzmann sampling for stochastic action selection.
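Boltzmann sampling draws actions from a softmax over Q-values, so near-optimal actions retain probability mass. A minimal sketch in PyTorch; boltzmann_sample is an illustrative helper, not the network's actual method:

```python
import torch


def boltzmann_sample(q_values: torch.Tensor, temperature: float) -> torch.Tensor:
    """Sample actions with pi(a | s) proportional to exp(Q(s, a) / temperature).

    q_values: tensor of shape (batch, n_actions); returns a (batch,) action tensor.
    Low temperature approaches greedy argmax; high temperature approaches uniform.
    """
    probs = torch.softmax(q_values / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


q = torch.tensor([[0.0, 0.0, 10.0]])
print(boltzmann_sample(q, 0.1))  # almost surely picks the argmax action (index 2)
```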
- class sb3_soft.sql.policies.SQLPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]
Bases: DQNPolicy

DQN policy variant using a soft Q-network for stochastic sampling.
- class sb3_soft.sql.policies.SQLCnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]
Bases: SQLPolicy
- class sb3_soft.sql.policies.SQLMultiInputPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, features_extractor_class=<class 'stable_baselines3.common.torch_layers.CombinedExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, temperature=1.0)[source]
Bases: SQLPolicy
- sb3_soft.sql.policies.CnnPolicy
alias of
SQLCnnPolicy
- sb3_soft.sql.policies.MultiInputPolicy
alias of
SQLMultiInputPolicy