"""Markov Decision Process utilities for symaware.simulators.carla (module ``mdp``)."""

from dataclasses import dataclass, field
from functools import cached_property
from itertools import product
from typing import TYPE_CHECKING, Generic, TypeVar

import numpy as np

if TYPE_CHECKING:
    from typing import Iterable

# Generic type parameters for MDP: _S/_A are the state/action types of an
# MDP instance; _SO/_AO are those of the right-hand operand in an @ product.
_S = TypeVar("_S")
_SO = TypeVar("_SO")
_A = TypeVar("_A")
_AO = TypeVar("_AO")


@dataclass
class MDP(Generic[_S, _A]):
    """A Markov Decision Process (MDP) representation.

    Represents an MDP as a tuple (states, initial_state, actions, transitions).
    Supports accessing and manipulating transition probabilities and provides
    MDP composition via the @ operator.

    Args
    ----
    states:
        Collection of states. Can contain any hashable type
    initial_state:
        The initial state
    actions:
        Collection of actions. Can contain any hashable type
    transitions:
        3D array of state transition probabilities accessible via
        state_from -> action -> state_to
    labels:
        Labels for the states. Auto-generated from states if not provided
    """

    states: "Iterable[_S]"
    initial_state: "_S"
    actions: "Iterable[_A]"
    transitions: "np.ndarray[np.float64]"
    labels: "tuple[str, ...]" = field(default_factory=tuple)
    # Internal index maps, built in __post_init__.
    _state_to_idx: "dict[_S, int]" = field(init=False, repr=False)
    _action_to_idx: "dict[_A, int]" = field(init=False, repr=False)

    @property
    def num_states(self) -> int:
        """Number of states in the MDP."""
        return len(self.states)

    @property
    def num_actions(self) -> int:
        """Number of actions in the MDP."""
        return len(self.actions)

    def __post_init__(self):
        """
        Initialize the MDP after dataclass creation.

        Materializes ``states`` and ``actions`` into tuples (so one-shot
        iterables such as generators are safe to pass and ``len()`` works),
        builds the state/action index maps and generates default labels when
        none are provided.
        """
        # Materialize once: a generator argument would otherwise be exhausted
        # by the first enumeration, silently breaking num_states/__matmul__.
        self.states = tuple(self.states)
        self.actions = tuple(self.actions)
        self._state_to_idx = {state: idx for idx, state in enumerate(self.states)}
        self._action_to_idx = {action: idx for idx, action in enumerate(self.actions)}
        self.labels = self.labels or tuple(str(i) for i in self.states)
        # assert np.allclose(np.sum(self.transitions, axis=2), 1), "Transition probabilities must sum to 1"

    def get_trans_prob_idx(self, state_idx: int, action_idx: int) -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a state/action pair by index.

        Args
        ----
        state_idx:
            Index of the source state
        action_idx:
            Index of the action

        Returns
        -------
        Array of transition probabilities to all possible next states
        """
        return self.transitions[state_idx, action_idx]

    def get_trans_prob(self, state: "_S", action: "_A") -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a given state and action.

        Args
        ----
        state:
            Source state
        action:
            Action to take

        Returns
        -------
        Array of transition probabilities to all possible next states, or the
        single-element array ``[0.0]`` if the state or action is not known
        """
        if state not in self._state_to_idx or action not in self._action_to_idx:
            return np.array([0.0])
        return self.get_trans_prob_idx(self._state_to_idx[state], self._action_to_idx[action])

    def __getitem__(self, key: "tuple[_S, _A, _S]") -> float:
        """
        Get the transition probability via ``mdp[state_from, action, state_to]``.

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)

        Returns
        -------
        Transition probability, or 0.0 if any state or the action is unknown
        """
        s_from, action, s_to = key
        if s_from not in self._state_to_idx or action not in self._action_to_idx or s_to not in self._state_to_idx:
            return 0.0
        return self.transitions[
            self._state_to_idx[s_from], self._action_to_idx[action], self._state_to_idx[s_to]
        ]

    def __setitem__(self, key: "tuple[_S, _A, _S]", prob: float):
        """
        Set the transition probability via ``mdp[state_from, action, state_to] = prob``.

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)
        prob:
            Transition probability to set

        Raises
        ------
        KeyError:
            If any of the states or the action is not valid in the MDP
        """
        s_from, action, s_to = key
        if s_from not in self._state_to_idx or action not in self._action_to_idx or s_to not in self._state_to_idx:
            raise KeyError("Invalid state or action")
        self.transitions[
            self._state_to_idx[s_from], self._action_to_idx[action], self._state_to_idx[s_to]
        ] = prob

    def state_to_idx(self, state: "_S") -> int:
        """Return the index of ``state``.

        Raises ``KeyError`` for unknown states.
        """
        return self._state_to_idx[state]

    def action_to_idx(self, action: "_A") -> int:
        """Return the index of ``action``.

        Raises ``KeyError`` for unknown actions.
        """
        return self._action_to_idx[action]

    def get_label(self, state: "_S") -> str:
        """Return the label associated with ``state``."""
        return self.labels[self._state_to_idx[state]]

    @property
    def state_idxs(self):
        """Array of state indices from 0 to num_states - 1."""
        return np.arange(len(self.states))

    @property
    def action_idxs(self):
        """Array of action indices from 0 to num_actions - 1."""
        return np.arange(len(self.actions))

    def __matmul__(self, other: "MDP[_SO, _AO]") -> "MDP[tuple[_S, _SO], tuple[_A, _AO]]":
        """
        Compute the product (composition) of two MDPs using the @ operator.

        The product MDP has tuple states and tuple actions; its transition
        probabilities are the products of the component probabilities:
        T[(s1,s2),(a1,a2),(t1,t2)] = T1[s1,a1,t1] * T2[s2,a2,t2].

        Args
        ----
        other:
            Another MDP to compose with this one

        Returns
        -------
        Product MDP with combined state and action spaces
        """
        prod_states = tuple(product(self.states, other.states))
        prod_actions = tuple(product(self.actions, other.actions))
        prod_labels = tuple(f"{s_label}_{o_label}" for s_label, o_label in product(self.labels, other.labels))
        # Vectorized tensor product; the reshape's row-major pairing
        # (x1 * len2 + x2) matches the ordering `itertools.product` uses above,
        # so this replaces the naive O(S^2 * A * S^2) Python loop exactly.
        transitions = np.einsum("iak,jbl->ijabkl", self.transitions, other.transitions).reshape(
            len(prod_states), len(prod_actions), len(prod_states)
        )
        # assert np.allclose(np.sum(transitions, axis=2), 1), "Transition probabilities must sum to 1"
        return self.__class__(
            states=prod_states,
            initial_state=(self.initial_state, other.initial_state),
            actions=prod_actions,
            transitions=transitions,
            labels=prod_labels,
        )
@dataclass
class LightMDP:
    """A lightweight Markov Decision Process (MDP) representation.

    Uses integer indices for states and actions instead of hashable objects.
    Provides the same interface as MDP but with reduced memory overhead.

    Args
    ----
    num_states:
        Number of states
    initial_state:
        Index of the initial state
    num_actions:
        Number of actions
    transitions:
        3D array of state transition probabilities accessible via
        state_from -> action -> state_to
    labels:
        Labels for the states. Auto-generated if not provided
    """

    num_states: int
    initial_state: int
    num_actions: int
    transitions: "np.ndarray[np.float64]"
    labels: "tuple[str, ...]" = field(default_factory=tuple)

    @classmethod
    def null_mdp(cls) -> "LightMDP":
        """Return an empty MDP, the identity element of the @ product."""
        return cls(
            num_states=0,
            initial_state=0,
            num_actions=0,
            transitions=np.empty((0, 0, 0)),
            labels=(),
        )

    def __post_init__(self):
        """
        Initialize the MDP after dataclass creation.

        Generates default labels if none are provided and validates the
        transition matrix (dimensionality, dtype and shape).

        Raises
        ------
        AssertionError:
            If labels or the transition matrix are inconsistent with
            num_states/num_actions
        """
        self.labels = self.labels or tuple(str(i) for i in range(self.num_states))
        assert len(self.labels) == self.num_states, "Number of labels must match number of states"
        assert self.transitions.ndim == 3, "Transition matrix must be 3-dimensional"
        assert self.transitions.dtype == np.float64, "Transition matrix must be of type float64"
        assert self.transitions.shape == (
            self.num_states,
            self.num_actions,
            self.num_states,
        ), "Transition matrix shape must be (states, actions, states)"
        # assert np.allclose(np.sum(self.transitions, axis=2), 1), "Transition probabilities must sum to 1"

    @cached_property
    def states(self) -> "np.ndarray[int]":
        """Array of state indices from 0 to num_states - 1."""
        return np.arange(self.num_states)

    @cached_property
    def actions(self) -> "np.ndarray[int]":
        """Array of action indices from 0 to num_actions - 1."""
        return np.arange(self.num_actions)

    def state_to_idx(self, state: "int") -> int:
        """Convert a state to its index (identity function for LightMDP)."""
        return state

    def action_to_idx(self, action: "int") -> int:
        """Convert an action to its index (identity function for LightMDP)."""
        return action

    def get_trans_prob(self, state: "int", action: "int") -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a given state and action.

        Args
        ----
        state:
            Source state index
        action:
            Action index

        Returns
        -------
        Array of transition probabilities to all possible next states

        Raises
        ------
        IndexError:
            If the state or action index is out of range
        """
        return self.transitions[state, action]

    def __getitem__(self, key: "tuple[int, int, int]") -> float:
        """
        Get the transition probability via ``mdp[state_from, action, state_to]``.

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state) indices

        Returns
        -------
        Transition probability from source to destination given the action
        """
        s_from, action, s_to = key
        return self.transitions[s_from, action, s_to]

    def __setitem__(self, key: "tuple[int, int, int]", prob: float):
        """
        Set the transition probability via ``mdp[state_from, action, state_to] = prob``.

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state) indices
        prob:
            Transition probability to set
        """
        s_from, action, s_to = key
        self.transitions[s_from, action, s_to] = prob

    def get_label(self, state: "int") -> str:
        """Return the label associated with the state index."""
        return self.labels[state]

    def __matmul__(self, other: "LightMDP") -> "LightMDP":
        """
        Compute the product (composition) of two MDPs using the @ operator.

        The product MDP flattens state pairs (s1, s2) to s1 * |S2| + s2 and
        action pairs (a1, a2) to a1 * |A2| + a2; its transition probabilities
        are the products of the component probabilities.

        Args
        ----
        other:
            Another LightMDP to compose with this one

        Returns
        -------
        Product MDP with combined state and action spaces

        Raises
        ------
        AssertionError:
            If the resulting transition probabilities don't sum to 1
        """
        # The null MDP acts as the identity of the product.
        if self.num_states == 0:
            return other
        if other.num_states == 0:
            return self
        num_states = self.num_states * other.num_states
        num_actions = self.num_actions * other.num_actions
        prod_labels = tuple(f"{s_label}_{o_label}" for s_label, o_label in product(self.labels, other.labels))
        # T[(s1,s2),(a1,a2),(t1,t2)] = T1[s1,a1,t1] * T2[s2,a2,t2]; the
        # row-major reshape matches the flat-index convention above, replacing
        # the naive O(S^2 * A * S^2) Python loop with one vectorized call.
        transitions = np.einsum("iak,jbl->ijabkl", self.transitions, other.transitions).reshape(
            num_states, num_actions, num_states
        )
        prod_mdp = self.__class__(
            num_states=num_states,
            # Bug fix: the initial state is the flat index s1 * |S2| + s2,
            # not a tuple — LightMDP states are plain integers.
            initial_state=self.initial_state * other.num_states + other.initial_state,
            num_actions=num_actions,
            transitions=transitions,
            labels=prod_labels,
        )
        assert np.allclose(np.sum(prod_mdp.transitions, axis=2), 1), "Transition probabilities must sum to 1"
        return prod_mdp
def main():
    """Small demo: build two MDPs of each flavour and compose them with @."""
    mdp_left = MDP(
        states=["e", "p"],
        initial_state="e",
        actions=["stay", "go"],
        transitions=np.array([[[0.8, 0.2], [0.2, 0.8]], [[0.2, 0.8], [0.2, 0.8]]]),
    )
    mdp_right = MDP(
        states=["A", "B", "C"],
        initial_state="A",
        actions=[0],
        transitions=np.array([[[0.8, 0.2, 0.0]], [[0.2, 0.1, 0.7]], [[0.8, 0.1, 0.1]]]),
    )
    print(mdp_left @ mdp_right)

    light_left = LightMDP(
        num_states=2,
        initial_state=0,
        num_actions=2,
        transitions=np.array([[[0.8, 0.2], [0.2, 0.8]], [[0.2, 0.8], [0.2, 0.8]]]),
    )
    light_right = LightMDP(
        num_states=3,
        initial_state=0,
        num_actions=1,
        transitions=np.array([[[0.8, 0.2, 0.0]], [[0.2, 0.1, 0.7]], [[0.8, 0.1, 0.1]]]),
    )
    # Exercise the product (result discarded, as in the original demo).
    light_left @ light_right
    # Composing with the null MDP is the identity.
    print(light_left @ LightMDP.null_mdp())
    print(LightMDP.null_mdp() @ light_left)
# Run the demo only when this module is executed as a script.
if __name__ == "__main__":
    main()