from dataclasses import dataclass, field
from functools import cached_property
from itertools import product
from typing import TYPE_CHECKING, Generic, TypeVar
import numpy as np
if TYPE_CHECKING:
from typing import Iterable
_S = TypeVar("_S")   # state type of the left-hand MDP
_SO = TypeVar("_SO")  # state type of the right-hand MDP in a product
_A = TypeVar("_A")   # action type of the left-hand MDP
_AO = TypeVar("_AO")  # action type of the right-hand MDP in a product


@dataclass
class MDP(Generic[_S, _A]):
    """A Markov Decision Process (MDP) representation.

    Represents an MDP as a tuple (states, initial_state, actions, transitions).
    Supports accessing and manipulating transition probabilities and provides
    MDP composition via the @ operator.

    Args
    ----
    states:
        Collection of states. Can contain any hashable type. Must be a
        reusable, sized collection (it is iterated more than once and
        ``len`` is taken), so do not pass a one-shot generator.
    initial_state:
        The initial state
    actions:
        Collection of actions. Can contain any hashable type
    transitions:
        3D array of state transition probabilities accessible via
        state_from -> action -> state_to
    labels:
        Labels for the states. Auto-generated from states if not provided
    """

    states: "Iterable[_S]"
    initial_state: "_S"
    actions: "Iterable[_A]"
    transitions: "np.ndarray[np.float64]"
    labels: "tuple[str, ...]" = field(default_factory=tuple)
    # Reverse lookup tables built in __post_init__; not part of the public API.
    _state_to_idx: "dict[_S, int]" = field(init=False, repr=False)
    _action_to_idx: "dict[_A, int]" = field(init=False, repr=False)

    @property
    def num_states(self) -> int:
        """Number of states in the MDP."""
        return len(self.states)

    @property
    def num_actions(self) -> int:
        """Number of actions in the MDP."""
        return len(self.actions)

    def __post_init__(self):
        """
        Initialize the MDP after dataclass creation.

        Creates internal dictionaries for mapping states and actions to indices,
        and generates default labels if none are provided.

        Raises
        ------
        AssertionError:
            If transition probabilities do not sum to 1
        """
        self._state_to_idx = {state: idx for idx, state in enumerate(self.states)}
        self._action_to_idx = {action: idx for idx, action in enumerate(self.actions)}
        self.labels = self.labels or tuple(str(i) for i in self.states)
        # assert np.allclose(np.sum(self.transitions, axis=2), 1), "Transition probabilities must sum to 1"

    def get_trans_prob_idx(self, state_idx: int, action_idx: int) -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a given state and action using indices.

        Args
        ----
        state_idx:
            Index of the source state
        action_idx:
            Index of the action

        Returns
        -------
        Array of transition probabilities to all possible next states
        """
        return self.transitions[state_idx, action_idx]

    def get_trans_prob(self, state: "_S", action: "_A") -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a given state and action.

        Args
        ----
        state:
            Source state
        action:
            Action to take

        Returns
        -------
        Array of transition probabilities to all possible next states,
        or a single-element zero array if the state or action is not known
        """
        if state not in self._state_to_idx or action not in self._action_to_idx:
            return np.array([0.0])
        state_idx = self._state_to_idx[state]
        action_idx = self._action_to_idx[action]
        return self.get_trans_prob_idx(state_idx, action_idx)

    def __getitem__(self, key: "tuple[_S, _A, _S]") -> float:
        """
        Get the transition probability from one state to another given an action.

        Allows accessing transition probabilities using syntax: mdp[state_from, action, state_to]

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)

        Returns
        -------
        Transition probability from source_state to destination_state given action,
        or 0.0 if any of the states or action is not valid
        """
        s_from, action, s_to = key
        if s_from not in self._state_to_idx or action not in self._action_to_idx or s_to not in self._state_to_idx:
            return 0.0
        state_idx = self._state_to_idx[s_from]
        action_idx = self._action_to_idx[action]
        next_state_idx = self._state_to_idx[s_to]
        return self.transitions[state_idx, action_idx, next_state_idx]

    def __setitem__(self, key: "tuple[_S, _A, _S]", prob: float):
        """
        Set the transition probability from one state to another given an action.

        Allows setting transition probabilities using syntax: mdp[state_from, action, state_to] = prob

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)
        prob:
            Transition probability to set

        Raises
        ------
        KeyError: If any of the states or action is not valid in the MDP
        """
        s_from, action, s_to = key
        if s_from not in self._state_to_idx or action not in self._action_to_idx or s_to not in self._state_to_idx:
            raise KeyError("Invalid state or action")
        state_idx = self._state_to_idx[s_from]
        action_idx = self._action_to_idx[action]
        next_state_idx = self._state_to_idx[s_to]
        self.transitions[state_idx, action_idx, next_state_idx] = prob

    def state_to_idx(self, state: "_S") -> int:
        """Convert a state to its index.

        Args
        ----
        state:
            State to convert

        Returns
        -------
        Index of the state
        """
        return self._state_to_idx[state]

    def action_to_idx(self, action: "_A") -> int:
        """Convert an action to its index.

        Args
        ----
        action:
            Action to convert

        Returns
        -------
        Index of the action
        """
        return self._action_to_idx[action]

    def get_label(self, state: "_S") -> str:
        """Get the label for a given state.

        Args
        ----
        state:
            State to get label for

        Returns
        -------
        Label of the state
        """
        return self.labels[self._state_to_idx[state]]

    @property
    def state_idxs(self):
        """
        Get an array of all state indices.

        Returns
        -------
        Array of state indices from 0 to len(states)-1
        """
        return np.arange(len(self.states))

    @property
    def action_idxs(self):
        """
        Get an array of all action indices.

        Returns
        -------
        Array of action indices from 0 to len(actions)-1
        """
        return np.arange(len(self.actions))

    def __matmul__(self, other: "MDP[_SO, _AO]") -> "MDP[tuple[_S, _SO], tuple[_A, _AO]]":
        """
        Compute the product (composition) of two MDPs using the @ operator.

        Creates a new MDP where states are tuples of states from both MDPs,
        actions are tuples of actions from both MDPs, and transition probabilities
        are the product of the individual transition probabilities.

        Args
        ----
        other:
            Another MDP to compose with this one

        Returns
        -------
        Product MDP with combined state and action spaces
        """
        prod_states = tuple(product(self.states, other.states))
        prod_actions = tuple(product(self.actions, other.actions))
        prod_labels = tuple(f"{s_label}_{o_label}" for s_label, o_label in product(self.labels, other.labels))
        # P((s1, s2) --(a1, a2)--> (t1, t2)) = P1(s1 --a1--> t1) * P2(s2 --a2--> t2).
        # The einsum output axes (i, k, a, b, j, l) reshape row-major into the same
        # ordering that itertools.product produces above, so no Python loop is needed.
        transitions = np.einsum(
            "iaj,kbl->ikabjl", self.transitions, other.transitions
        ).reshape(len(prod_states), len(prod_actions), len(prod_states))
        return self.__class__(
            states=prod_states,
            initial_state=(self.initial_state, other.initial_state),
            actions=prod_actions,
            transitions=transitions,
            labels=prod_labels,
        )
@dataclass
class LightMDP:
    """A lightweight Markov Decision Process (MDP) representation.

    Uses integer indices for states and actions instead of hashable objects.
    Provides the same interface as MDP but with reduced memory overhead.

    Args
    ----
    num_states:
        Number of states
    initial_state:
        Index of the initial state
    num_actions:
        Number of actions
    transitions:
        3D array of state transition probabilities accessible via
        state_from -> action -> state_to
    labels:
        Labels for the states. Auto-generated if not provided
    """

    num_states: int
    initial_state: int
    num_actions: int
    transitions: "np.ndarray[np.float64]"
    labels: "tuple[str, ...]" = field(default_factory=tuple)

    @classmethod
    def null_mdp(cls) -> "LightMDP":
        """Return an empty MDP that acts as the identity for composition (@)."""
        return cls(
            num_states=0,
            initial_state=0,
            num_actions=0,
            transitions=np.empty((0, 0, 0)),
            labels=(),
        )

    def __post_init__(self):
        """
        Initialize the MDP after dataclass creation.

        Generates default labels if none are provided and validates the
        transition matrix.

        Raises
        ------
        AssertionError:
            If labels/transition-matrix dimensions are inconsistent
        """
        self.labels = self.labels or tuple(str(i) for i in range(self.num_states))
        assert len(self.labels) == self.num_states, "Number of labels must match number of states"
        assert self.transitions.ndim == 3, "Transition matrix must be 3-dimensional"
        assert self.transitions.dtype == np.float64, "Transition matrix must be of type float64"
        assert self.transitions.shape == (
            self.num_states,
            self.num_actions,
            self.num_states,
        ), "Transition matrix shape must be (states, actions, states)"
        # assert np.allclose(np.sum(self.transitions, axis=2), 1), "Transition probabilities must sum to 1"

    @cached_property
    def states(self) -> "np.ndarray[int]":
        """
        Get an array of all states.

        Returns
        -------
        Array of state indices from 0 to num_states-1
        """
        return np.arange(self.num_states)

    @cached_property
    def actions(self) -> "np.ndarray[int]":
        """
        Get an array of all actions.

        Returns
        -------
        Array of action indices from 0 to num_actions-1
        """
        return np.arange(self.num_actions)

    def state_to_idx(self, state: "int") -> int:
        """Convert a state index (identity function for LightMDP).

        Args
        ----
        state:
            State index

        Returns
        -------
        The same state index
        """
        return state

    def action_to_idx(self, action: "int") -> int:
        """Convert an action index (identity function for LightMDP).

        Args
        ----
        action:
            Action index

        Returns
        -------
        The same action index
        """
        return action

    def get_trans_prob(self, state: "int", action: "int") -> "np.ndarray[np.float64]":
        """
        Get the transition probability distribution for a given state and action.

        Args
        ----
        state:
            Source state
        action:
            Action to take

        Returns
        -------
        Array of transition probabilities to all possible next states
        """
        return self.transitions[state, action]

    def __getitem__(self, key: "tuple[int, int, int]") -> float:
        """
        Get the transition probability from one state to another given an action.

        Allows accessing transition probabilities using syntax: mdp[state_from, action, state_to]

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)

        Returns
        -------
        Transition probability from source_state to destination_state given action
        """
        s_from, action, s_to = key
        return self.transitions[s_from, action, s_to]

    def __setitem__(self, key: "tuple[int, int, int]", prob: float):
        """
        Set the transition probability from one state to another given an action.

        Allows setting transition probabilities using syntax: mdp[state_from, action, state_to] = prob

        Args
        ----
        key:
            Tuple of (source_state, action, destination_state)
        prob:
            Transition probability to set
        """
        s_from, action, s_to = key
        self.transitions[s_from, action, s_to] = prob

    def get_label(self, state: "int") -> str:
        """Get the label for a given state.

        Args
        ----
        state:
            State index

        Returns
        -------
        Label of the state
        """
        return self.labels[state]

    def __matmul__(self, other: "LightMDP") -> "LightMDP":
        """
        Compute the product (composition) of two MDPs using the @ operator.

        Creates a new MDP whose states/actions are the flattened Cartesian
        products of both MDPs' states/actions (pair (i, k) maps to index
        i * other.num_* + k), and whose transition probabilities are the
        product of the individual transition probabilities.

        Args
        ----
        other:
            Another MDP to compose with this one

        Returns
        -------
        Product MDP with combined state and action spaces

        Raises
        ------
        AssertionError: If the resulting transition probabilities don't sum to 1
        """
        # The null MDP is the identity element of composition.
        if self.num_states == 0:
            return other
        if other.num_states == 0:
            return self
        n_states = self.num_states * other.num_states
        n_actions = self.num_actions * other.num_actions
        prod_labels = tuple(f"{s_label}_{o_label}" for s_label, o_label in product(self.labels, other.labels))
        # P((s1, s2) --(a1, a2)--> (t1, t2)) = P1(s1 --a1--> t1) * P2(s2 --a2--> t2).
        # Row-major reshape of einsum axes (i, k, a, b, j, l) yields exactly the
        # flattened pair indices i*n2+k / a*m2+b / j*n2+l used throughout.
        transitions = np.einsum(
            "iaj,kbl->ikabjl", self.transitions, other.transitions
        ).reshape(n_states, n_actions, n_states)
        prod_mdp = self.__class__(
            num_states=n_states,
            # The initial state must be the flattened pair index (an int),
            # not a tuple, to honour LightMDP's integer-index contract.
            initial_state=self.initial_state * other.num_states + other.initial_state,
            num_actions=n_actions,
            transitions=transitions,
            labels=prod_labels,
        )
        assert np.allclose(np.sum(prod_mdp.transitions, axis=2), 1), "Transition probabilities must sum to 1"
        return prod_mdp
def main():
    """Small demo: build two MDPs of each flavour and compose them with @."""
    a = MDP(
        states=["e", "p"],
        initial_state="e",
        actions=["stay", "go"],
        transitions=np.array([[[0.8, 0.2], [0.2, 0.8]], [[0.2, 0.8], [0.2, 0.8]]]),
    )
    b = MDP(
        states=["A", "B", "C"],
        initial_state="A",
        actions=[0],
        transitions=np.array([[[0.8, 0.2, 0.0]], [[0.2, 0.1, 0.7]], [[0.8, 0.1, 0.1]]]),
    )
    c = a @ b
    print(c)
    a = LightMDP(
        num_states=2,
        initial_state=0,
        num_actions=2,
        transitions=np.array([[[0.8, 0.2], [0.2, 0.8]], [[0.2, 0.8], [0.2, 0.8]]]),
    )
    b = LightMDP(
        num_states=3,
        initial_state=0,
        num_actions=1,
        transitions=np.array([[[0.8, 0.2, 0.0]], [[0.2, 0.1, 0.7]], [[0.8, 0.1, 0.1]]]),
    )
    c = a @ b
    # Composing with the null MDP is a no-op in either direction.
    print(a @ LightMDP.null_mdp())
    print(LightMDP.null_mdp() @ a)


if __name__ == "__main__":
    main()