Source code for sb3_soft.sql.policies

from typing import Any, Optional

import torch as th
from gymnasium import spaces
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.dqn.policies import DQNPolicy, QNetwork
from torch import nn


class SoftQNetwork(QNetwork):
    """Q-network whose non-deterministic predictions are sampled from a softmax.

    Deterministic prediction picks the arg-max action. Otherwise an action is
    drawn from ``softmax(Q / temperature)`` over the action dimension.

    :param observation_space: Forwarded unchanged to ``QNetwork``.
    :param action_space: Forwarded unchanged to ``QNetwork``.
    :param features_extractor: Forwarded unchanged to ``QNetwork``.
    :param features_dim: Forwarded unchanged to ``QNetwork``.
    :param net_arch: Forwarded unchanged to ``QNetwork``.
    :param activation_fn: Forwarded unchanged to ``QNetwork``.
    :param normalize_images: Forwarded unchanged to ``QNetwork``.
    :param temperature: Divisor applied to the Q-values before the softmax.
        Values below ``1e-8`` are replaced by ``1e-8``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            features_extractor=features_extractor,
            features_dim=features_dim,
            net_arch=net_arch,
            activation_fn=activation_fn,
            normalize_images=normalize_images,
        )
        # Floor the temperature so the division in ``_predict`` never uses
        # zero or a negative value.
        self.temperature = max(temperature, 1e-8)

    def _predict(
        self, observation: PyTorchObs, deterministic: bool = True
    ) -> th.Tensor:
        """Choose one action index per entry along dim 0 of the Q-values.

        :param observation: Observation passed to the network's forward call.
        :param deterministic: If true, return the arg-max action; otherwise
            sample from the temperature-scaled softmax distribution.
        :return: Flattened tensor of selected action indices.
        """
        action_values = self(observation)

        if not deterministic:
            # Boltzmann distribution: softmax of Q-values over the temperature.
            action_probs = th.softmax(action_values / self.temperature, dim=1)
            sampled = th.multinomial(action_probs, num_samples=1)
            return sampled.reshape(-1)

        greedy = action_values.argmax(dim=1)
        return greedy.reshape(-1)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters plus ``temperature``."""
        params = super()._get_constructor_parameters()
        params["temperature"] = self.temperature
        return params
class SQLPolicy(DQNPolicy):
    """DQN-style policy whose Q-networks are ``SoftQNetwork`` instances.

    :param observation_space: Forwarded unchanged to ``DQNPolicy``.
    :param action_space: Forwarded unchanged to ``DQNPolicy``.
    :param lr_schedule: Forwarded unchanged to ``DQNPolicy``.
    :param net_arch: Forwarded unchanged to ``DQNPolicy``.
    :param activation_fn: Forwarded unchanged to ``DQNPolicy``.
    :param features_extractor_class: Forwarded unchanged to ``DQNPolicy``;
        defaults to ``FlattenExtractor``.
    :param features_extractor_kwargs: Forwarded unchanged to ``DQNPolicy``.
    :param normalize_images: Forwarded unchanged to ``DQNPolicy``.
    :param optimizer_class: Forwarded unchanged to ``DQNPolicy``.
    :param optimizer_kwargs: Forwarded unchanged to ``DQNPolicy``.
    :param temperature: Temperature handed to every ``SoftQNetwork`` created
        by :meth:`make_q_net`. Values below ``1e-8`` are replaced by ``1e-8``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        # Stored BEFORE the parent constructor runs. Presumably the parent
        # constructor builds the Q-networks through ``make_q_net``, which
        # reads this attribute -- confirm against DQNPolicy before reordering.
        self.temperature = max(temperature, 1e-8)
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def make_q_net(self) -> SoftQNetwork:
        """Build a ``SoftQNetwork`` on this policy's device.

        The network arguments come from ``_update_features_extractor`` applied
        to ``self.net_args``, with this policy's temperature added.

        :return: The newly created network, moved to ``self.device``.
        """
        q_net_kwargs = self._update_features_extractor(
            self.net_args, features_extractor=None
        )
        q_net_kwargs.update(temperature=self.temperature)
        q_net = SoftQNetwork(**q_net_kwargs)
        return q_net.to(self.device)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters plus ``temperature``."""
        params = super()._get_constructor_parameters()
        params["temperature"] = self.temperature
        return params
class SQLCnnPolicy(SQLPolicy):
    """``SQLPolicy`` whose default features extractor is ``NatureCNN``.

    Every argument is forwarded unchanged to ``SQLPolicy``; only the default
    value of ``features_extractor_class`` differs.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        # Pure pass-through: the subclass exists only for its default extractor.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            temperature=temperature,
        )
class SQLMultiInputPolicy(SQLPolicy):
    """``SQLPolicy`` whose default features extractor is ``CombinedExtractor``.

    Every argument is forwarded unchanged to ``SQLPolicy``; only the default
    value of ``features_extractor_class`` differs.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        # Pure pass-through: the subclass exists only for its default extractor.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            temperature=temperature,
        )
# Short module-level aliases for the three policy classes defined in this file.
MlpPolicy = SQLPolicy
CnnPolicy = SQLCnnPolicy
MultiInputPolicy = SQLMultiInputPolicy