Source code for gym_tl_tools.wrapper

import importlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, SupportsFloat

import numpy as np
from gymnasium import Env, ObservationWrapper
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Dict, Discrete
from gymnasium.utils import RecordConstructorArgs
from numpy.typing import NDArray
from pydantic import (
    BaseModel,
    ConfigDict,
    SerializationInfo,
    computed_field,
    model_serializer,
)
from typing_extensions import TypedDict

from gym_tl_tools.automaton import Automaton, Predicate
from gym_tl_tools.parser import Parser


class RewardConfigDict(TypedDict, total=False):
    """
    Configuration dictionary for the reward structure in the TLObservationReward wrapper.

    This dictionary is used to define how rewards are computed based on the automaton's
    state and the environment's info.

    Attributes
    ----------
    terminal_state_reward : float
        Reward given when the automaton reaches a terminal state, in addition to the
        automaton state transition reward.
    state_trans_reward_scale : float
        Scale factor for the reward based on the automaton's state transition robustness.
        This is applied to the robustness computed from the automaton's transition.
        If the transition leads to a trap state, the reward is set to be negative,
        scaled by this factor.
    dense_reward : bool
        Whether to use dense rewards (True) or sparse rewards (False).
        Dense rewards provide a continuous reward signal based on the robustness of the
        transition to the next non-trap state.
    dense_reward_scale : float
        Scale factor for the dense reward. This is applied to the computed dense reward
        based on the robustness to the next non-trap automaton state.
        If dense rewards are enabled, this factor scales the reward returned by the automaton.
    """

    terminal_state_reward: float
    state_trans_reward_scale: float
    dense_reward: bool
    dense_reward_scale: float
class RewardConfig(BaseModel):
    """Validated reward configuration with default values; see RewardConfigDict for field descriptions."""

    terminal_state_reward: float = 5
    state_trans_reward_scale: float = 100
    dense_reward: bool = False
    dense_reward_scale: float = 0.01
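# Illustrative sketch (not part of the library API): how the two reward-config forms
# relate. A plain `RewardConfigDict` can be passed to `TLObservationReward`, which
# validates it into a `RewardConfig`; keys left out fall back to the defaults above.
#
#     custom_reward: RewardConfigDict = {"dense_reward": True, "dense_reward_scale": 0.05}
#     validated = RewardConfig(**custom_reward)
#     assert validated.terminal_state_reward == 5   # default kept
#     assert validated.dense_reward_scale == 0.05   # override applied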
class BaseVarValueInfoGenerator(Generic[ObsType, ActType], ABC):
    """
    Base class for generating variable values from the environment's observation and info.

    This class should be subclassed to implement custom logic for extracting variable values
    that are used to evaluate atomic predicates in temporal logic formulas.

    The `get_var_values` method should return a dictionary where keys are variable names
    used in predicate formulas (e.g., "d_robot_goal", "d_robot_obstacle") and values are
    their corresponding numerical values extracted from the environment's observation and
    info.

    Example
    -------
    ```python
    class MyVarValueGenerator(BaseVarValueInfoGenerator):
        def get_var_values(self, env, obs, info):
            return {
                "d_robot_goal": info.get("d_robot_goal", float("inf")),
                "d_robot_obstacle": obs.get("obstacle_distance", 0.0),
                "robot_speed": np.linalg.norm(obs["velocity"]) if "velocity" in obs else 0.0,
            }
    ```
    """
    @abstractmethod
    def get_var_values(
        self, env: Env[ObsType, ActType], obs: ObsType, info: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Extract variable values from the environment's observation and info.

        This method should be implemented by subclasses to provide the variable values
        needed for evaluating atomic predicates in temporal logic formulas.

        Parameters
        ----------
        env : Env[ObsType, ActType]
            The environment from which to extract variable values.
        obs : ObsType
            The current observation from the environment.
        info : dict[str, Any]
            The info dictionary containing additional information from the environment.

        Returns
        -------
        dict[str, Any]
            A dictionary where keys are variable names used in predicate formulas and
            values are their corresponding numerical values.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
class TLObservationRewardConfig(BaseModel, Generic[ObsType, ActType]):
    """
    Configuration for the TLObservationReward wrapper.

    This class defines the parameters used to configure the TLObservationReward wrapper,
    including the temporal logic specification, atomic predicates, variable value generator,
    reward configuration, and other settings.

    Attributes
    ----------
    tl_spec : str
        The temporal logic specification (e.g., LTL formula) to be used for the automaton.
    atomic_predicates : list[Predicate]
        List of atomic predicates used in the TL formula.
    var_value_info_generator_cls : str
        Dotted import path of a BaseVarValueInfoGenerator subclass that extracts variable
        values from the environment's observation and info.
    var_value_info_generator_args : dict[str, Any]
        Keyword arguments passed to the generator class when it is instantiated.
    reward_config : RewardConfig
        Configuration for the reward structure.
    early_termination : bool
        Whether to terminate episodes when the automaton reaches terminal states.
    parser : Parser
        Parser for TL expressions.
    dict_aut_state_key : str
        Key for the automaton state in the observation dictionary.
    """

    tl_spec: str = ""
    atomic_predicates: list[Predicate]
    var_value_info_generator_cls: str
    var_value_info_generator_args: dict[str, Any] = {}
    reward_config: RewardConfig = RewardConfig()
    early_termination: bool = True
    parser: Parser = Parser()
    dict_aut_state_key: str = "aut_state"

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @computed_field
    @property
    def var_value_info_generator(self) -> BaseVarValueInfoGenerator[ObsType, ActType]:
        """
        Instantiate the variable value info generator based on the provided class and arguments.

        Returns
        -------
        BaseVarValueInfoGenerator[ObsType, ActType]
            An instance of the variable value info generator.
        """
        if isinstance(self.var_value_info_generator_cls, str):
            # If it's a string, assume it's a dotted class path and import it dynamically
            module_name, class_name = self.var_value_info_generator_cls.rsplit(".", 1)
            module = importlib.import_module(module_name)
            generator_cls = getattr(module, class_name)
        else:
            generator_cls = self.var_value_info_generator_cls
        return generator_cls(**self.var_value_info_generator_args)
    @model_serializer
    def serialize_model(self, info: SerializationInfo | None = None) -> dict[str, Any]:
        """Custom model serializer to handle atomic_predicates serialization with context."""
        # Handle atomic_predicates based on context
        atomic_predicates_data = self.atomic_predicates
        # if info and info.context:
        #     excluded = info.context.get("excluded", [])
        #     if "atomic_predicates" not in excluded:
        #         atomic_predicates_data = [
        #             pred.model_dump() for pred in self.atomic_predicates
        #         ]
        # else:
        #     atomic_predicates_data = [
        #         pred.model_dump() for pred in self.atomic_predicates
        #     ]

        data = {
            "tl_spec": self.tl_spec,
            "atomic_predicates": atomic_predicates_data,
            "var_value_info_generator": self.var_value_info_generator,
            "reward_config": self.reward_config,
            "early_termination": self.early_termination,
            "parser": self.parser,
            "dict_aut_state_key": self.dict_aut_state_key,
        }
        return data
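# Illustrative sketch: building a TLObservationRewardConfig from plain data. The dotted
# path "my_project.generators.MyVarValueGenerator" and its arguments are hypothetical;
# the computed `var_value_info_generator` field imports that class lazily and
# instantiates it with `var_value_info_generator_args` when accessed.
#
#     config = TLObservationRewardConfig(
#         tl_spec="F(goal_reached) & G(!obstacle_hit)",
#         atomic_predicates=[
#             Predicate(name="goal_reached", formula="d_robot_goal < 1.0"),
#             Predicate(name="obstacle_hit", formula="d_robot_obstacle < 1.0"),
#         ],
#         var_value_info_generator_cls="my_project.generators.MyVarValueGenerator",
#         var_value_info_generator_args={"goal_key": "d_robot_goal"},
#     )
#     generator = config.var_value_info_generator  # imported and instantiated on access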
class TLObservationReward(
    ObservationWrapper[
        dict[str, ObsType | np.int64 | NDArray[np.number]], ActType, ObsType
    ],
    RecordConstructorArgs,
    Generic[ObsType, ActType],
):
    """
    A wrapper for Gymnasium environments that augments observations with the state of a
    temporal logic automaton, and computes rewards based on satisfaction of temporal logic
    (TL) specifications.

    This wrapper is designed for environments where the agent's objective is specified using
    temporal logic (e.g., LTL). It integrates an automaton (from a TL formula) into the
    observation and reward structure, enabling RL agents to learn tasks with complex
    temporal requirements.

    Usage
    -----
    1. **Define Atomic Predicates**:
        Create a list of `Predicate` objects, each representing an atomic proposition in
        your TL formula. Each predicate has a `name` (used in the TL formula) and a
        `formula` (Boolean expression string).
        Example:
        ```python
        from gym_tl_tools import Predicate

        atomic_predicates = [
            Predicate(name="goal_reached", formula="d_robot_goal < 1.0"),
            Predicate(name="obstacle_hit", formula="d_robot_obstacle < 1.0"),
        ]
        ```

    2. **Create a Variable Value Generator**:
        Implement a subclass of `BaseVarValueInfoGenerator` to extract variable values from
        the environment's observation and info for evaluating the atomic predicates.
        Example:
        ```python
        from gym_tl_tools import BaseVarValueInfoGenerator

        class MyVarValueGenerator(BaseVarValueInfoGenerator):
            def get_var_values(self, env, obs, info):
                # Extract variables needed for predicate evaluation
                return {
                    "d_robot_goal": info.get("d_robot_goal", float("inf")),
                    "d_robot_obstacle": info.get("d_robot_obstacle", float("inf")),
                }

        var_generator = MyVarValueGenerator()
        ```

    3. **Specify the Temporal Logic Formula**:
        Write your TL specification as a string, using the names of your atomic predicates.
        Example:
        ```python
        tl_spec = "F(goal_reached) & G(!obstacle_hit)"
        ```

    4. **Wrap Your Environment**:
        Pass your environment, TL specification, atomic predicates, and variable value
        generator to `TLObservationReward`.
        Example:
        ```python
        from gym_tl_tools import TLObservationReward

        wrapped_env = TLObservationReward(
            env,
            tl_spec=tl_spec,
            atomic_predicates=atomic_predicates,
            var_value_info_generator=var_generator,
        )
        ```

    5. **Observation Structure**:
        - The wrapper augments each observation with the current automaton state.
        - If the original observation space is a
          [Dict](https://gymnasium.farama.org/main/api/spaces/composite/#gymnasium.spaces.Dict),
          the automaton state is added as a new key (default: `"aut_state"`).
        - ~~If the original observation space is a
          [Tuple](https://gymnasium.farama.org/main/api/spaces/composite/#gymnasium.spaces.Tuple),
          the automaton state is appended.~~
        - Otherwise, the observation is wrapped in a
          [Dict](https://gymnasium.farama.org/main/api/spaces/composite/#gymnasium.spaces.Dict)
          with keys `"obs"` and `"aut_state"`.

    6. **Reward Calculation**:
        - At each step, the wrapper computes the reward based on the automaton's transition,
          which reflects progress toward (or violation of) the TL specification.
        - The automaton state is updated according to the values of the atomic predicates,
          which are evaluated using the variable values provided by the
          `var_value_info_generator`.

    7. **Reset and Step**:
        - On `reset()`, the automaton is reset to its initial state, and the initial
          observation is augmented.
        - On `step(action)`, the automaton transitions based on the variable values
          extracted by the `var_value_info_generator`, and the reward is computed
          accordingly.
    Notes
    -----
    - The wrapper adds additional information to the `info` dict, including:
        - `is_success`: Whether the automaton has reached a goal state.
        - `is_failure`: Whether the automaton has reached a trap state.
        - `is_aut_terminated`: Whether the automaton has been terminated.

    Parameters
    ----------
    env : gymnasium.Env
        The environment to wrap.
    tl_spec : str
        The temporal logic specification (e.g., LTL formula) to be used for the automaton.
    atomic_predicates : list[gym_tl_tools.Predicate]
        List of atomic predicates used in the TL formula. Each predicate has a `name`
        (used in the TL formula) and a `formula` (Boolean expression string for evaluation).
    var_value_info_generator : BaseVarValueInfoGenerator[ObsType, ActType]
        An instance of a subclass of `BaseVarValueInfoGenerator` that extracts variable
        values from the environment's observation and info for evaluating atomic predicates.
    reward_config : RewardConfigDict, optional
        Configuration for the reward structure (default values provided).
    early_termination : bool, optional
        Whether to terminate episodes when the automaton reaches terminal states
        (default: True).
    parser : Parser, optional
        Parser for TL expressions (default: new instance of `Parser`).
    dict_aut_state_key : str, optional
        Key for the automaton state in the observation dictionary (default: "aut_state").

    Example
    -------
    ```python
    from gym_tl_tools import Predicate, BaseVarValueInfoGenerator, TLObservationReward

    # Define atomic predicates
    atomic_predicates = [
        Predicate(name="goal_reached", formula="d_robot_goal < 1.0"),
        Predicate(name="obstacle_hit", formula="d_robot_obstacle < 1.0"),
    ]

    # Create variable value generator
    class MyVarValueGenerator(BaseVarValueInfoGenerator):
        def get_var_values(self, env, obs, info):
            return {
                "d_robot_goal": info.get("d_robot_goal", float("inf")),
                "d_robot_obstacle": info.get("d_robot_obstacle", float("inf")),
            }

    var_generator = MyVarValueGenerator()

    tl_spec = "F(goal_reached) & G(!obstacle_hit)"

    wrapped_env = TLObservationReward(
        env,
        tl_spec=tl_spec,
        atomic_predicates=atomic_predicates,
        var_value_info_generator=var_generator,
    )

    obs, info = wrapped_env.reset()
    done = False
    while not done:
        action = wrapped_env.action_space.sample()
        obs, reward, terminated, truncated, info = wrapped_env.step(action)
        done = terminated or truncated
    ```
    """

    def __init__(
        self,
        env: Env[ObsType, ActType],
        tl_spec: str,
        atomic_predicates: list[Predicate] | list[dict[str, str]],
        var_value_info_generator: BaseVarValueInfoGenerator[ObsType, ActType],
        reward_config: RewardConfigDict = {
            "terminal_state_reward": 5,
            "state_trans_reward_scale": 100,
            "dense_reward": False,
            "dense_reward_scale": 0.01,
        },
        early_termination: bool = True,
        parser: Parser = Parser(),
        dict_aut_state_key: str = "aut_state",
    ):
        """
        Initialize the TLObservationReward wrapper.

        Parameters
        ----------
        env : Env[ObsType, ActType]
            The environment to wrap.
        tl_spec : str
            The temporal logic specification to be used with the automaton.
        atomic_predicates : list[gym_tl_tools.Predicate]
            A list of atomic predicates that define the conditions for the automaton.
            Each predicate should have a `name` (used in the TL formula) and a `formula`
            (Boolean expression string),
            e.g. [Predicate(name="goal_reached", formula="d_robot_goal < 1.0")].
        var_value_info_generator : BaseVarValueInfoGenerator[ObsType, ActType]
            An instance of a subclass of BaseVarValueInfoGenerator that extracts variable
            values from the environment's observation and info.
            This is used to evaluate the atomic predicates in the automaton.
            The `get_var_values` method should return a dictionary where keys are variable
            names used in predicate formulas and values are their corresponding numerical
            values.
        reward_config : RewardConfigDict = {
                "terminal_state_reward": 5,
                "state_trans_reward_scale": 100,
                "dense_reward": False,
                "dense_reward_scale": 0.01,
            }
            A dictionary containing the configuration for the reward structure.
            - `terminal_state_reward` : float
                Reward given when the automaton reaches a terminal state, in addition to the
                automaton state transition reward.
            - `state_trans_reward_scale` : float
                Scale factor for the reward based on the automaton's state transition
                robustness. This is applied to the robustness computed from the automaton's
                transition. If the transition leads to a trap state, the reward is set to be
                negative, scaled by this factor.
            - `dense_reward` : bool
                Whether to use dense rewards (True) or sparse rewards (False). Dense rewards
                provide a continuous reward signal based on the robustness of the transition
                to the next non-trap state.
            - `dense_reward_scale` : float
                Scale factor for the dense reward. This is applied to the computed dense
                reward based on the robustness to the next non-trap automaton state. If
                dense rewards are enabled, this factor scales the reward returned by the
                automaton.
        early_termination : bool = True
            Whether to terminate the episode when the automaton reaches a terminal state.
        parser : Parser = gym_tl_tools.Parser()
            An instance of the Parser class for parsing temporal logic expressions.
            Defaults to a new instance of Parser.
        dict_aut_state_key : str = "aut_state"
            The key under which the automaton state will be stored in the observation space.
            Defaults to "aut_state".
        """
        match tl_spec:
            case "0" | "1" | "":
                raise ValueError(
                    f"The temporal logic specification cannot be '0', '1', or an empty string (given tl_spec={tl_spec}). "
                    "Please provide a valid temporal logic formula."
                )
        RecordConstructorArgs.__init__(
            self,
            tl_spec=tl_spec,
            atomic_predicates=atomic_predicates,
            parser=parser,
            reward_config=reward_config,
            dict_aut_state_key=dict_aut_state_key,
            early_termination=early_termination,
            var_value_info_generator=var_value_info_generator,
        )
        ObservationWrapper.__init__(self, env)

        predicates: list[Predicate] = []
        for pred in atomic_predicates:
            if isinstance(pred, dict):
                predicates.append(Predicate(**pred))
            elif isinstance(pred, Predicate):
                predicates.append(pred)
            else:
                raise TypeError(
                    f"Expected atomic_predicates to be a list of Predicate or dict, got {type(pred)}."
                )
        self.var_value_info_generator: BaseVarValueInfoGenerator[ObsType, ActType] = (
            var_value_info_generator
        )
        self.parser = parser
        self.automaton = Automaton(tl_spec, predicates, parser=parser)
        self.reward_config = RewardConfig(**reward_config)
        self.early_termination: bool = early_termination

        aut_state_space = Discrete(self.automaton.num_states)
        self._append_data_func: Callable[
            [ObsType, int], dict[str, ObsType | np.int64 | NDArray]
        ]
        # Build the augmented observation space
        match env.observation_space:
            case Dict():
                assert dict_aut_state_key not in env.observation_space.spaces, (
                    f"Key '{dict_aut_state_key}' already exists in the observation space. "
                    "Please choose a different key."
                )
                observation_space = Dict(
                    {
                        **env.observation_space.spaces,
                        dict_aut_state_key: aut_state_space,
                    }
                )
                self._append_data_func = lambda obs, aut_state: {
                    **obs,
                    dict_aut_state_key: aut_state,
                }
            # case Tuple():
            #     observation_space = Tuple(
            #         env.observation_space.spaces + (aut_state_space,)
            #     )
            #     self._append_data_func = lambda obs, aut_state: obs + (aut_state,)
            case _:
                observation_space = Dict(
                    {"obs": env.observation_space, dict_aut_state_key: aut_state_space}
                )
                self._append_data_func = lambda obs, aut_state: {
                    "obs": obs,
                    dict_aut_state_key: np.int64(aut_state),
                }
        self.observation_space = observation_space
        self._obs_postprocess_func = lambda obs: obs
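    # Illustrative sketch (hypothetical env, not part of this module): how __init__
    # augments a non-Dict observation space. Given a base Box space, the wrapped space
    # and observations look like:
    #
    #     base_env.observation_space
    #         -> Box(-1.0, 1.0, (4,), float32)
    #     TLObservationReward(base_env, ...).observation_space
    #         -> Dict({"obs": Box(-1.0, 1.0, (4,), float32), "aut_state": Discrete(num_states)})
    #
    # and each emitted observation becomes {"obs": <original array>, "aut_state": np.int64(q)}.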
    def observation(
        self, observation: ObsType
    ) -> dict[str, ObsType | np.int64 | NDArray]:
        """
        Process the observation to include the automaton state.

        Parameters
        ----------
        observation : ObsType
            The original observation from the environment.

        Returns
        -------
        new_obs : dict[str, ObsType | np.int64 | NDArray]
            The processed observation with the automaton state appended.
        """
        aut_state = self.automaton.current_state
        new_obs: dict[str, ObsType | np.int64 | NDArray] = self._append_data_func(
            observation, aut_state
        )
        return new_obs
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = {
            "forward_aut_on_reset": True,
            "obs": {},
            "info": {},
        },
    ) -> tuple[dict[str, ObsType | np.int64 | NDArray], dict[str, Any]]:
        """
        Reset the environment and return the initial observation.

        Parameters
        ----------
        seed : int | None, optional
            Random seed for reproducibility.
        options : dict[str, Any] | None, optional
            Additional options for resetting the environment.

        Returns
        -------
        new_obs : dict[str, ObsType | np.int64 | NDArray]
            The initial observation with the automaton state.
        info : dict[str, Any]
            Additional information from the reset. Should contain the variable keys and
            values that define the atomic predicates.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        # Update the info dict with the variable values
        info = self._var_value_info_update(obs, info)
        self.automaton.reset(seed=seed)
        if options:
            if options.get("forward_aut_on_reset", True):
                obs = options["obs"] if options.get("obs", {}) else obs
                info = options["info"] if options.get("info", {}) else info
                self.forward_aut(obs, info)
        info = self._success_info_update(info)
        new_obs = self.observation(obs)
        return new_obs, info
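    # Illustrative sketch: reset() accepts wrapper-specific options. Setting
    # "forward_aut_on_reset" to False leaves the automaton in its initial state, while
    # "obs"/"info" let callers supply pre-computed values used for the first automaton
    # update (the keys shown are the ones reset() reads; the values are hypothetical).
    #
    #     obs, info = wrapped_env.reset(
    #         options={"forward_aut_on_reset": True, "obs": {}, "info": {"d_robot_goal": 3.2}}
    #     )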
    def forward_aut(self, obs: ObsType, info: dict[str, Any]) -> None:
        """
        Forward the automaton based on the current observation and info.

        This method updates the automaton state based on the current observation and info.
        It is useful for manually stepping the automaton without taking an action in the
        environment.

        Parameters
        ----------
        obs : ObsType
            The current observation from the environment.
        info : dict[str, Any]
            The info dictionary containing variable values for the atomic predicates.
        """
        curr_state: int = self.automaton.current_state
        new_state: int | None = None
        # Update the automaton state based on the info
        info = self._var_value_info_update(obs, info)
        while new_state != curr_state:
            curr_state = self.automaton.current_state
            _, new_state = self.automaton.step(info, **self.reward_config.model_dump())
        # Update the automaton state
        if new_state is not None:
            self.automaton.current_state = new_state
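    # Illustrative sketch: forward_aut() can also be called manually, e.g. after changing
    # the underlying environment state outside of step(). The info keys shown are the
    # hypothetical predicate variables from the class docstring example.
    #
    #     wrapped_env.forward_aut(obs, {"d_robot_goal": 0.5, "d_robot_obstacle": 4.0})
    #     wrapped_env.is_aut_terminated  # True if the automaton has reached a terminal state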
    @property
    def is_aut_terminated(self) -> bool:
        """
        Check if the automaton is in a terminal state.

        Returns
        -------
        bool
            True if the automaton is in a terminal state, False otherwise.
        """
        return self.automaton.is_terminated

    def _var_value_info_update(
        self, obs: ObsType, info: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Update the info dictionary with variable values based on the observation.

        This method is used to update the info dictionary with the values of the atomic
        predicates based on the current observation. It is called during the step and reset
        methods.

        Parameters
        ----------
        obs : ObsType
            The current observation from the environment.
        info : dict[str, Any]
            The info dictionary to be updated with variable values.

        Returns
        -------
        info : dict[str, Any]
            The updated info dictionary containing the variable keys and values that define
            the atomic predicates.
        """
        info_updates = self.var_value_info_generator.get_var_values(self.env, obs, info)
        info.update(info_updates)
        return info

    def _success_info_update(self, info: dict[str, Any]) -> dict[str, Any]:
        """
        Update the info dictionary with success information.

        This method updates the info dictionary to include whether the automaton has reached
        a goal state or has been terminated due to reaching a trap state.

        Parameters
        ----------
        info : dict[str, Any]
            The info dictionary to be updated.

        Returns
        -------
        info : dict[str, Any]
            The updated info dictionary with success information. The following keys are
            added or updated:
            - is_success: Whether the automaton has reached a goal state.
            - is_failure: Whether the automaton has reached a trap state.
            - is_aut_terminated: Whether the automaton has been terminated.
        """
        info.update(
            {
                "is_success": self.automaton.current_state in self.automaton.goal_states,
                "is_failure": self.automaton.current_state in self.automaton.trap_states,
                "is_aut_terminated": self.automaton.is_terminated,
            }
        )
        return info
    def step(
        self, action: ActType
    ) -> tuple[
        dict[str, ObsType | np.int64 | NDArray],
        SupportsFloat,
        bool,
        bool,
        dict[str, Any],
    ]:
        """
        Take a step in the environment with the given action.

        Parameters
        ----------
        action : ActType
            The action to take in the environment.

        Returns
        -------
        new_obs : dict[str, ObsType | np.int64 | NDArray]
            The new observation after taking the action.
        reward : SupportsFloat
            The reward received from the environment.
        terminated : bool
            Whether the episode has terminated.
        truncated : bool
            Whether the episode has been truncated.
        info : dict[str, Any]
            Additional information from the step. Should contain the variable keys and
            values that define the atomic predicates.
        """
        obs, orig_reward, terminated, truncated, info = self.env.step(action)
        info = self._var_value_info_update(obs, info)
        info.update({"original_reward": orig_reward})
        reward, next_aut_state = self.automaton.step(
            info, **self.reward_config.model_dump()
        )
        if (
            self.early_termination
            and next_aut_state
            in self.automaton.goal_states + self.automaton.trap_states
        ):
            terminated = True
        info = self._success_info_update(info)
        new_obs = self.observation(obs)
        return new_obs, reward, terminated, truncated, info