from typing import Any, Optional
import torch as th
from gymnasium import spaces
from stable_baselines3.common.torch_layers import (
BaseFeaturesExtractor,
CombinedExtractor,
FlattenExtractor,
NatureCNN,
)
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.dqn.policies import DQNPolicy, QNetwork
from torch import nn
[docs]
class SoftQNetwork(QNetwork):
    """Q-network that can sample actions from a softmax over its Q-values.

    A deterministic prediction picks the index of the largest Q-value. A
    non-deterministic prediction draws one index per row from
    ``softmax(Q / temperature)``.

    :param observation_space: Forwarded unchanged to ``QNetwork``.
    :param action_space: Forwarded unchanged to ``QNetwork``.
    :param features_extractor: Forwarded unchanged to ``QNetwork``.
    :param features_dim: Forwarded unchanged to ``QNetwork``.
    :param net_arch: Forwarded unchanged to ``QNetwork``.
    :param activation_fn: Forwarded unchanged to ``QNetwork``.
    :param normalize_images: Forwarded unchanged to ``QNetwork``.
    :param temperature: Divisor applied to the Q-values before the softmax.
        Values below ``1e-8`` are raised to ``1e-8``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        features_extractor: BaseFeaturesExtractor,
        features_dim: int,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            features_extractor=features_extractor,
            features_dim=features_dim,
            net_arch=net_arch,
            activation_fn=activation_fn,
            normalize_images=normalize_images,
        )
        # Floor the value at 1e-8 because ``_predict`` uses it as a divisor.
        self.temperature = max(temperature, 1e-8)

    def _predict(
        self, observation: PyTorchObs, deterministic: bool = True
    ) -> th.Tensor:
        """Choose one action index per row of Q-values.

        :param observation: Input passed to the network's forward call.
        :param deterministic: If true, return the arg-max along dimension 1;
            otherwise sample from the temperature-scaled softmax along
            dimension 1.
        :return: Flattened (1-D) tensor of action indices.
        """
        action_values = self(observation)
        if not deterministic:
            sampling_probs = (action_values / self.temperature).softmax(dim=1)
            # One draw per row; flatten the resulting single-column tensor.
            return sampling_probs.multinomial(num_samples=1).reshape(-1)
        return action_values.argmax(dim=1).reshape(-1)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters plus ``temperature``."""
        params = super()._get_constructor_parameters()
        params["temperature"] = self.temperature
        return params
[docs]
class SQLPolicy(DQNPolicy):
    """DQN policy whose Q-networks are ``SoftQNetwork`` instances.

    Every argument except ``temperature`` is forwarded unchanged to
    ``DQNPolicy``.

    :param observation_space: Forwarded to ``DQNPolicy``.
    :param action_space: Forwarded to ``DQNPolicy``.
    :param lr_schedule: Forwarded to ``DQNPolicy``.
    :param net_arch: Forwarded to ``DQNPolicy``.
    :param activation_fn: Forwarded to ``DQNPolicy``.
    :param features_extractor_class: Forwarded to ``DQNPolicy``; defaults to
        ``FlattenExtractor``.
    :param features_extractor_kwargs: Forwarded to ``DQNPolicy``.
    :param normalize_images: Forwarded to ``DQNPolicy``.
    :param optimizer_class: Forwarded to ``DQNPolicy``.
    :param optimizer_kwargs: Forwarded to ``DQNPolicy``.
    :param temperature: Softmax temperature handed to each ``SoftQNetwork``.
        Values below ``1e-8`` are raised to ``1e-8``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        # Must be set before the parent constructor runs: ``make_q_net`` reads
        # it, and the parent constructor presumably builds the networks
        # through ``make_q_net`` (TODO confirm against DQNPolicy).
        self.temperature = max(temperature, 1e-8)
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def make_q_net(self) -> SoftQNetwork:
        """Build a ``SoftQNetwork`` on this policy's device.

        The network arguments come from ``_update_features_extractor`` called
        with ``features_extractor=None``, with this policy's temperature
        added.

        :return: The newly constructed network, moved to ``self.device``.
        """
        q_net_kwargs = self._update_features_extractor(
            self.net_args, features_extractor=None
        )
        q_net_kwargs["temperature"] = self.temperature
        soft_q_net = SoftQNetwork(**q_net_kwargs)
        return soft_q_net.to(self.device)

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """Return the parent's constructor parameters plus ``temperature``."""
        params = super()._get_constructor_parameters()
        params["temperature"] = self.temperature
        return params
[docs]
class SQLCnnPolicy(SQLPolicy):
    """``SQLPolicy`` whose default features extractor is ``NatureCNN``.

    Apart from that default, the constructor forwards every argument
    unchanged to ``SQLPolicy``.

    :param observation_space: Forwarded to ``SQLPolicy``.
    :param action_space: Forwarded to ``SQLPolicy``.
    :param lr_schedule: Forwarded to ``SQLPolicy``.
    :param net_arch: Forwarded to ``SQLPolicy``.
    :param activation_fn: Forwarded to ``SQLPolicy``.
    :param features_extractor_class: Forwarded to ``SQLPolicy``; defaults to
        ``NatureCNN``.
    :param features_extractor_kwargs: Forwarded to ``SQLPolicy``.
    :param normalize_images: Forwarded to ``SQLPolicy``.
    :param optimizer_class: Forwarded to ``SQLPolicy``.
    :param optimizer_kwargs: Forwarded to ``SQLPolicy``.
    :param temperature: Forwarded to ``SQLPolicy``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        # The three leading arguments are passed positionally, in the order
        # ``SQLPolicy.__init__`` declares them.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            temperature=temperature,
        )
class SQLMultiInputPolicy(SQLPolicy):
    """``SQLPolicy`` whose default features extractor is ``CombinedExtractor``.

    Apart from that default, the constructor forwards every argument
    unchanged to ``SQLPolicy``, in the same way ``SQLCnnPolicy`` does.

    :param observation_space: Forwarded to ``SQLPolicy``. Annotated as
        ``spaces.Dict`` because ``CombinedExtractor`` presumably expects a
        dictionary observation space (TODO confirm).
    :param action_space: Forwarded to ``SQLPolicy``.
    :param lr_schedule: Forwarded to ``SQLPolicy``.
    :param net_arch: Forwarded to ``SQLPolicy``.
    :param activation_fn: Forwarded to ``SQLPolicy``.
    :param features_extractor_class: Forwarded to ``SQLPolicy``; defaults to
        ``CombinedExtractor``.
    :param features_extractor_kwargs: Forwarded to ``SQLPolicy``.
    :param normalize_images: Forwarded to ``SQLPolicy``.
    :param optimizer_class: Forwarded to ``SQLPolicy``.
    :param optimizer_kwargs: Forwarded to ``SQLPolicy``.
    :param temperature: Forwarded to ``SQLPolicy``.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        action_space: spaces.Discrete,
        lr_schedule: Schedule,
        net_arch: Optional[list[int]] = None,
        activation_fn: type[nn.Module] = nn.ReLU,
        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        temperature: float = 1.0,
    ) -> None:
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            temperature=temperature,
        )


# Short aliases for the three policy classes.
MlpPolicy = SQLPolicy
CnnPolicy = SQLCnnPolicy
MultiInputPolicy = SQLMultiInputPolicy