Source code for gym_tl_tools.automaton

import re
from typing import Literal

import numpy as np
import spot
from pydantic import BaseModel
from spot import twa as Twa

from gym_tl_tools.parser import Parser


[docs] class Transition(BaseModel): """ Represents a transition in a finite state automaton from a LTL formula. Attributes ---------- next_state : int The next automaton state after the transition. condition : str The condition that triggers the transition represented as a Boolean expression. e.g. "psi_1 & psi_2" or "psi_1 | !psi_2". is_trapped_next : bool Indicates if the next state is a trap state (i.e., no further transitions are possible). Defaults to False. """ condition: str next_state: int is_trapped_next: bool = False
[docs] class Predicate(BaseModel): """ Represents a predicate in a TL formula. Attributes ---------- name : str The name of the atomic predicate. e.g. "psi_goal_robot". formula : str The formula of the atomic predicate, which can be a Boolean expression. e.g. "d_robot_goal < 5". """ name: str formula: str
[docs] class RobNextStatePair(BaseModel): """ Represents a pair of robustness value and transition. Attributes ---------- robustness : float The robustness value of the transition. next_state: int The next state of the automaton after the transition. """ robustness: float next_state: int
[docs] class RobustnessCounter(BaseModel): robustness: float ind: int
[docs] class AutomatonStateCounter(BaseModel): state: int ind: int
[docs] class Edge: """ Represents an edge in a finite state automaton from a LTL formula. Attributes ---------- state : int The automaton state. transitions : list[Transition] The transitions from this state to other states. is_inf_zero_acc : bool Indicates if the edge has an Inf(0) acceptance condition. is_terminal_state : bool Indicates if the edge is a terminal state (i.e., no further transitions are possible). is_trap_state : bool Indicates if the edge is a trap state (i.e., all transitions lead to trap states). """ def __init__(self, raw_state: str, atomic_pred_names: list[str]): """ Parameters ---------- raw_state : str The raw state string from the automaton in HOA format. atomic_pred_names : list[str] The names of the atomic predicates used in the transitions. """ self.state: int = int(*re.findall("State: (\d).*\[", raw_state)) self.transitions: list[Transition] = [ Transition( condition=replace_atomic_pred_ids_to_names( add_or_parentheses(re.findall("\[(.*)\]", transition)[0]), atomic_pred_names, ), next_state=int(*re.findall("\[.*\] (\d)", transition)), ) for transition in re.findall( "(\[.*\] \d+)\n", raw_state.replace("[", "\n[") ) ] # Check if the edge has an Inf(0) acceptance condition self.is_inf_zero_acc = True if "{0}" in raw_state else False # Check if the edge is a terminal state. If so, elimite 't' transition # looping back to itself is_terminal_state: bool = True for i, transition in enumerate(self.transitions): if transition.next_state != self.state: is_terminal_state = False else: if transition.condition == "t": self.transitions.pop(i) else: pass self.is_terminal_state = is_terminal_state # A terminal state is a trap state if it doesn't have Inf(0) acceptance self.is_trap_state: bool = ( True if not self.is_inf_zero_acc and self.is_terminal_state else False )
[docs] class Automaton: def __init__( self, tl_spec: str, atomic_predicates: list[Predicate], parser: Parser = Parser(), ) -> None: """ Initialize the Automaton with a temporal logic specification and atomic predicates. Parameters ---------- tl_spec : str The temporal logic specification in LTL format. e.g. "F(psi_1 & psi_2) | G(psi_3)". atomic_predicates : list[Predicate] A list of atomic predicates used in the temporal logic specification. Each predicate should be an instance of the Predicate class. e.g. [Predicate("psi_1", "d_robot_goal < 5"), Predicate("psi_2", "d_robot_goal < 10")]. parser : Parser = gym_tl_tools.parser.Parser() An instance of the Parser class used to parse the temporal logic specification. Defaults to a new instance of Parser. """ self.tl_spec: str = tl_spec self.atomic_predicates: list[Predicate] = atomic_predicates self.parser: Parser = parser aut: Twa = spot.translate(tl_spec, "Buchi", "state-based", "complete") aut_hoa: str = aut.to_str("hoa") self.num_states: int = int(aut.num_states()) # type: ignore self.start: int = int(*re.findall("Start: (\d+)\n", aut_hoa)) num_used_aps: int = int(*re.findall('AP: (\d+) "', aut_hoa)) used_aps: list[str] = ( re.findall("AP: \d+ (.*)", aut_hoa)[0].replace('"', "").split() ) if len(used_aps) != num_used_aps: raise ValueError( f"Number of used atomic predicates ({len(used_aps)}) does not match the number of atomic predicates in the automaton ({num_used_aps})." ) else: self.used_aps: list[str] = used_aps self.acc_name: str = re.findall("acc-name: (.*)\n", aut_hoa)[0] self.acceptance: str = re.findall("Acceptance: (.*)\n", aut_hoa)[0] self.properties = re.findall("properties: (.*)\n", aut_hoa) aut_hoa_states = ( aut_hoa.replace("\n", "") .replace("State", "\nState") .replace("--END--", "\n") ) raw_states: list[str] = [ raw_state + "\n" for raw_state in re.findall("(State:.*)\n", aut_hoa_states) ] untrapped_edges: list[Edge] = [ Edge(raw_state, self.used_aps) for raw_state in raw_states ] # Check if a transition of an edge leads to a trap state and if all transitions # of an edge leads to a trap state (if so, this edge is also a trap state) trap_states: list[int] = [ edge.state for edge in untrapped_edges if edge.is_trap_state ] goal_states: list[int] = [] # for _ in range(len(untrapped_edges)): for i, edge in enumerate(untrapped_edges): if edge.is_trap_state: pass elif edge.is_inf_zero_acc and edge.is_terminal_state: # If the edge has Inf(0) acceptance and is a terminal state, # it is a goal state goal_states.append(edge.state) else: is_trapped_in_all_trans = True for j, transition in enumerate(edge.transitions): if transition.next_state in trap_states: if edge.is_inf_zero_acc: edge.transitions.pop(j) goal_states.append(edge.state) else: edge.transitions[j] = Transition( condition=transition.condition, next_state=transition.next_state, is_trapped_next=True, ) else: is_trapped_in_all_trans = False if is_trapped_in_all_trans and edge.transitions: untrapped_edges[i].is_trap_state = True trap_states.append(edge.state) else: pass self.goal_states: tuple[int, ...] = tuple(goal_states) self.trap_states: tuple[int, ...] = tuple(trap_states) self.edges: tuple[Edge, ...] = tuple(untrapped_edges)
[docs] def reset(self, seed: int | None = None) -> None: """ Reset the automaton to its initial state. """ self.current_state: int = self.start self.status: Literal["intermediate", "goal", "trap"] = "intermediate" # Initialize Numpy RNG if a seed is provided self._np_rng: np.random.Generator = np.random.default_rng(seed)
[docs] def step( self, var_value_dict: dict[str, float], terminal_state_reward: float = 5, state_trans_reward_scale: float = 100, dense_reward: bool = False, dense_reward_scale: float = 0.01, ) -> tuple[float, int]: """ Step the automaton to the next state based on the current predicate values. This function updates the current state of the automaton based on the values of variables defining atomic predicates provided in `var_value_dict`. Parameters ---------- var_value_dict : dict[str, float] A dictionary mapping the variable names used in the atomic predicate definitions to their current values. e.g. {"d_robot_goal": 3.0, "d_robot_obstacle": 1.0, "d_robot_goal": 0.5}. terminal_state_reward : float = 5 Returns ------- reward : float The reward for the current step. new_state : int The new automaton state. """ if not hasattr(self, "current_state"): raise ValueError( " Error: The automaton has not been reset. Please call the reset() method before stepping." ) ap_rob_dict: dict[str, float] = { atom_pred.name: self.parser.tl2rob(atom_pred.formula, var_value_dict) for atom_pred in self.atomic_predicates } reward, new_state = self.tl_reward( ap_rob_dict, self.current_state, terminal_state_reward=terminal_state_reward, state_trans_reward_scale=state_trans_reward_scale, dense_reward=dense_reward, dense_reward_scale=dense_reward_scale, ) # Update the current state of the automaton self.current_state = new_state # Update the status of the automaton if new_state in self.goal_states: self.status = "goal" elif new_state in self.trap_states: self.status = "trap" else: self.status = "intermediate" return reward, new_state
[docs] def tl_reward( self, ap_rob_dict: dict[str, float], curr_aut_state: int, terminal_state_reward: float = 5, state_trans_reward_scale: float = 100, dense_reward: bool = False, dense_reward_scale: float = 0.01, ) -> tuple[float, int]: """ Calculate the reward of the step from a given automaton. Parameters ---------- ap_rob_dict: dict[str, float] A dictionary mapping the names of atomic predicates to their robustness values. curr_aut_state: int The current state of the automaton. dense_reward: bool If True, the reward is calculated based on the maximum robustness of non-trap transitions. terminal_state_reward: float The reward given when the automaton reaches a terminal state. Returns ------- reward : float The reward of the step based on the MDP and automaton states. next_aut_state : int The resultant automaton state after the step. """ if curr_aut_state in self.goal_states: return (terminal_state_reward, curr_aut_state) elif curr_aut_state in self.trap_states: return (-terminal_state_reward, curr_aut_state) else: curr_edge = self.edges[curr_aut_state] transitions = curr_edge.transitions ( pos_trap_rob_trans_pairs, pos_non_trap_rob_trans_pairs, zero_trap_rob_trans_pairs, zero_non_trap_rob_trans_pairs, non_trap_rob_trans_pairs, ) = self.transition_robustness(transitions, ap_rob_dict) trans_rob: float next_aut_state: int if len(pos_trap_rob_trans_pairs) == 1: # If there is only one positive transition robustness leading to a trap state trans_rob = pos_trap_rob_trans_pairs[0].robustness next_aut_state = pos_trap_rob_trans_pairs[0].next_state elif len(pos_non_trap_rob_trans_pairs) == 1: # If there is only one positive transition robustness leading to a non-trap state trans_rob = pos_non_trap_rob_trans_pairs[0].robustness next_aut_state = pos_non_trap_rob_trans_pairs[0].next_state elif zero_trap_rob_trans_pairs: # If there are no positive transition robustnesses, but there are zero transition robustnesses # leading to a trap state, select the first one trans_rob = zero_trap_rob_trans_pairs[0].robustness next_aut_state = zero_trap_rob_trans_pairs[0].next_state elif zero_non_trap_rob_trans_pairs: # If there are no positive transition robustnesses, but there are zero transition robustnesses # leading to a non-trap state, select the first one selected_index: int = int( self._np_rng.integers(0, len(zero_non_trap_rob_trans_pairs)) ) trans_rob = zero_non_trap_rob_trans_pairs[selected_index].robustness next_aut_state = zero_non_trap_rob_trans_pairs[ selected_index ].next_state else: raise ValueError( "Error: No singular positive or zero transition robustnesses found for the current automaton state.", "Current automaton state:", curr_aut_state, "Positive trap transitions:", pos_trap_rob_trans_pairs, "Positive non-trap transitions:", pos_non_trap_rob_trans_pairs, "Zero trap transitions:", zero_trap_rob_trans_pairs, "Zero non-trap transitions:", zero_non_trap_rob_trans_pairs, ) # Calculate the reward reward: float # Weight for reward calculation # alpha: float = 0.7 # beta: float = 0.5 gamma: float = dense_reward_scale delta: float = state_trans_reward_scale if next_aut_state == curr_aut_state: # non_trap_robs.remove(trans_rob) # reward = -gamma * ( # beta * 1 / max(non_trap_robs) - (1 - beta) * 1 / max(trap_robs) # ) # reward = gamma * ( # beta * 1 / max(non_trap_robs) - (1 - beta) * 1 / max(trap_robs) # ) if dense_reward: non_trap_robs: list[float] = [ pair.robustness for pair in non_trap_rob_trans_pairs ] # ] + [pair.robustness for pair in zero_non_trap_rob_trans_pairs] non_trap_robs.remove(trans_rob) reward = gamma * max(non_trap_robs) else: reward = 0 else: if next_aut_state not in self.trap_states: reward = ( delta * trans_rob ) # alpha * trans_rob - (1 - alpha) * max(trap_robs) elif next_aut_state in self.trap_states: reward = ( -delta * trans_rob ) # -(1 - alpha) * max(non_trap_robs) - alpha * trans_rob else: raise ValueError( "Error: the transition robustness doesn't exit in the robustness set." ) return reward, next_aut_state
[docs] def transition_robustness( self, transitions: list[Transition], ap_rob_dict: dict[str, float], ) -> tuple[ list[RobNextStatePair], list[RobNextStatePair], list[RobNextStatePair], list[RobNextStatePair], list[RobNextStatePair], ]: """ Calculate the robustness of transitions in the automaton. This function computes the robustness values for each transition based on the conditions of the transitions and the robustness values of the atomic predicates. Parameters ---------- transitions: list[Transition] A list of transitions in the automaton. ap_rob_dict: dict[str, float] A dictionary mapping the names of atomic predicates to their robustness values. Returns ------- pos_trap_rob_trans_pairs : list[RobNextStatePair] A list of pairs of positive robustness values and transitions that lead to trap states. pos_non_trap_rob_trans_pairs : list[RobNextStatePair] A list of pairs of positive robustness values and transitions that do not lead to trap states. zero_trap_rob_trans_pairs : list[RobNextStatePair] A list of pairs of zero robustness values and transitions that lead to trap states. zero_non_trap_rob_trans_pairs : list[RobNextStatePair] A list of pairs of zero robustness values and transitions that do not lead to trap states. """ robs: list[float] = [] pos_trap_rob_trans_pairs: list[RobNextStatePair] = [] pos_non_trap_rob_trans_pairs: list[RobNextStatePair] = [] zero_trap_rob_trans_pairs: list[RobNextStatePair] = [] zero_non_trap_rob_trans_pairs: list[RobNextStatePair] = [] non_trap_rob_trans_pairs: list[RobNextStatePair] = [] for trans in transitions: rob: float = self.parser.tl2rob(trans.condition, ap_rob_dict) robs.append(rob) match (rob, trans.is_trapped_next): case (0, True): zero_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) case (0, False): zero_non_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) non_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) case (rob, True) if rob > 0: pos_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) case (rob, False) if rob > 0: pos_non_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) non_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) case (rob, False) if rob < 0: non_trap_rob_trans_pairs.append( RobNextStatePair(robustness=rob, next_state=trans.next_state) ) case _: pass return ( pos_trap_rob_trans_pairs, pos_non_trap_rob_trans_pairs, zero_trap_rob_trans_pairs, zero_non_trap_rob_trans_pairs, non_trap_rob_trans_pairs, )
@property def is_goaled(self) -> bool: """ Check if the current state of the automaton is a goal state. Returns ------- bool True if the current state is a goal state, False otherwise. """ return self.status == "goal" @property def is_trapped(self) -> bool: """ Check if the current state of the automaton is a trap state. Returns ------- bool True if the current state is a trap state, False otherwise. """ return self.status == "trap" @property def is_terminated(self) -> bool: """ Check if the automaton is in a terminal state (goal or trap). Returns ------- bool True if the automaton is in a terminal state, False otherwise. """ return self.is_goaled or self.is_trapped
[docs] def add_or_parentheses(condition: str) -> str: or_locs: list[int] = [m.start() for m in re.finditer("\|", condition)] condition_list: list[str] = list(condition) new_condition: str = condition if or_locs: for loc in or_locs: if condition_list[loc - 1] == " " and condition_list[loc + 1] == " ": condition_list[loc - 1] = ")" condition_list[loc + 1] = "(" else: pass new_condition = "(" + "".join(condition_list) + ")" else: pass return new_condition
[docs] def replace_atomic_pred_ids_to_names(condition: str, atom_prop_names: list[str]): spec: str = condition # Replace atomic props in numeric IDs to their actual names (ex. '0'->'psi_0') for i, atom_prop_name in enumerate(reversed(atom_prop_names)): target = str(len(atom_prop_names) - 1 - i) spec = spec.replace(target, atom_prop_name) return spec