"""
author Shuhao Qi
"""
import hashlib
from itertools import product
from typing import TYPE_CHECKING, Generic, TypeVar
import gurobipy as grb
import numpy as np
from scipy.spatial.distance import pdist, squareform
from .abstraction import GridAbstraction, VelocityGridAbstraction
from .ltl import DFA
from .mdp import MDP, LightMDP
# Generic type variables for the state and action types of the underlying MDP.
_S = TypeVar("_S")
_A = TypeVar("_A")

# Sentinel "null" models used as defaults so that optional components
# (environment MDP, safety/co-safety DFAs) can be omitted by callers.
NULL_MDP = MDP([0], initial_state=0, actions=[0], transitions=np.array([[[1]]]))
NULL_LIGHT_MDP = LightMDP.null_mdp()
NULL_DFA = DFA.null_dfa()

if TYPE_CHECKING:
    from typing import Callable

    # Cost function signature: (mdp_state, label, abstraction) -> cost.
    CostFunction = Callable[[int, str, GridAbstraction | None], float]
class Product(Generic[_S, _A]):
    r"""
    The product MDP of $\mathcal{M}$, $A_{cs}$ and $A_s$ is defined as a tuple
    $P = (Z, z_0, A, \hat{T}, G, D)$, with the state set

    .. math::
        Z \coloneq S \times Q_{cs} \times Q_s
        A \coloneq \text{action set}
        z_0 \coloneq (\bar{s}_0, \delta_{cs}(q^0_{cs}, \mathcal{L}(\bar{s}_0)_{cs}), \delta_s(q^0_s, \mathcal{L}(\bar{s}_0)_s))
        \hat{T} \coloneq Z \times A \times Z \to [0,1]

    Args
    ----
    MDP:
        The Markov Decision Process representing system dynamics
    DFA_cs:
        The co-safety (liveness) DFA
    DFA_safe:
        The safety DFA
    """

    def __init__(self, MDP: "MDP[_S, _A]", DFA_cs: DFA, DFA_safe: DFA):
        self.MDP = MDP
        self.DFA_cs: "DFA" = DFA_cs
        self.DFA_safe: "DFA" = DFA_safe
        # Product states are enumerated in (mdp_state, cs_state, safe_state) order;
        # this ordering is what the index arithmetic in gen_product_transition()
        # and _state_to_index() relies on.
        # TODO: this could probably become a 3D numpy array for efficiency
        self.prod_state_set = tuple(product(self.MDP.states, self.DFA_cs.states, self.DFA_safe.states))
        # The product shares the MDP's action set unchanged.
        self.prod_action_set = self.MDP.actions
        self.prod_transitions = self.gen_product_transition()
        self.accepting_states, self.trap_states = self.gen_final_states()

    def update(self, MDP: "MDP[_S, _A]"):
        """Swap in a re-labelled MDP with identical dynamics and rebuild the product transitions.

        The transition tensor must be unchanged: only the labelling may differ,
        which in turn changes how DFA transitions are taken in the product.
        """
        assert np.array_equal(MDP.transitions, self.MDP.transitions), "MDP transitions do not match!"
        self.MDP = MDP
        self.prod_transitions = self.gen_product_transition()

    def gen_product_transition(self):
        """Build the product transition tensor of shape (|Z|, |A|, |Z|).

        A product transition ((s, q_cs, q_s), a, (s', q'_cs, q'_s)) inherits the
        probability of (s, a, s') in the MDP, where q'_cs / q'_s are the DFA
        successors under the label of s'.  Missing DFA transitions default to a
        self-loop (the .get(..., current_state) fallback below).
        """
        P_matrix = np.zeros((len(self.prod_state_set), len(self.prod_action_set), len(self.prod_state_set)))
        # Pre-compute all DFA transitions for all unique labels
        dfa_cs_cache = {}
        dfa_safe_cache = {}
        for label in set(self.MDP.labels):
            alphabet_cs = self.DFA_cs.get_alphabet(label)
            alphabet_s = self.DFA_safe.get_alphabet(label)
            dfa_cs_cache[label] = alphabet_cs
            dfa_safe_cache[label] = alphabet_s
        # Mixed-radix coefficients for flattening (s, s_cs, s_s) into one index.
        # NOTE(review): the `- 1` offsets below assume DFA states are numbered
        # 1..N — TODO confirm against the DFA implementation.
        s_coeff = len(self.DFA_cs.states) * len(self.DFA_safe.states)
        s_cs_coeff = len(self.DFA_safe.states)
        # Iterate through all non-zero transitions in the MDP
        combinations = np.nonzero(self.MDP.transitions)
        for s_idx, a_idx, next_s_idx in zip(*combinations):
            # Find the next state label and corresponding DFA transitions using the cached alphabets
            next_x_label = self.MDP.labels[next_s_idx]
            alphabet_cs = dfa_cs_cache[next_x_label]
            alphabet_s = dfa_safe_cache[next_x_label]
            for s_cs, s_s in product(self.DFA_cs.states, self.DFA_safe.states):
                # For each state, compute the next DFA states based on the current state and the label
                next_s_cs = self.DFA_cs.transitions.get((s_cs, alphabet_cs), s_cs)
                next_s_s = self.DFA_safe.transitions.get((s_s, alphabet_s), s_s)
                # Compute the current and next product state indices
                current_state_idx = s_idx * s_coeff + (s_cs - 1) * s_cs_coeff + (s_s - 1)
                next_state_idx = next_s_idx * s_coeff + (next_s_cs - 1) * s_cs_coeff + (next_s_s - 1)
                # Set the transition probability in the product MDP
                P_matrix[current_state_idx, a_idx, next_state_idx] = self.MDP.transitions[s_idx, a_idx, next_s_idx]
        return P_matrix

    def _state_to_index(self, state: "tuple[_S, int, int]") -> int:
        """Flatten a (mdp_state, cs_state, safe_state) triple into a product-state index.

        Must stay consistent with the enumeration order of ``prod_state_set`` and
        with the coefficients used in :meth:`gen_product_transition`.
        """
        x, s_cs, s_s = state
        return (
            self.MDP.state_to_idx(x) * len(self.DFA_cs.states) * len(self.DFA_safe.states)
            + (int(s_cs) - 1) * len(self.DFA_safe.states)
            + (int(s_s) - 1)
        )

    def gen_final_states(self):
        """Classify product states into accepting and trap sets.

        Returns
        -------
        (accepting_states, trap_states):
            Sets of product-state indices.  A state whose safety component is a
            sink is a trap; otherwise, a sink co-safety component marks acceptance.
            Safety violation takes precedence over co-safety acceptance.
        """
        accepting_states = set()
        trap_states = set()
        for n, (x, s_cs, s_s) in enumerate(self.prod_state_set):
            if self.DFA_safe.is_sink_state(s_s):
                trap_states.add(n)
            elif self.DFA_cs.is_sink_state(s_cs):
                accepting_states.add(n)
        return accepting_states, trap_states

    def gen_cost_map(
        self, cost_func: "CostFunction | None" = None, abstraction: "GridAbstraction | None" = None
    ) -> np.ndarray:
        """Evaluate ``cost_func`` on the MDP component of every product state.

        Returns an all-zero vector when no cost function is given.  The DFA
        components of the product state do not influence the cost.
        """
        cost_map = np.zeros(len(self.prod_state_set))
        if cost_func is None:
            return cost_map
        for n, (x, _, _) in enumerate(self.prod_state_set):
            # x, s_cs, s_s = self.prod_state_set[n]
            label = self.MDP.get_label(x)
            cost_map[n] = cost_func(x, label, abstraction)
        return cost_map

    def get_next_prod_state(
        self, mdp_state: "_S", last_product_state: "tuple[_S, int, int] | int"
    ) -> "tuple[int, tuple[_S, int, int]]":
        """Advance the DFA components given a newly observed MDP state.

        Args
        ----
        mdp_state:
            The MDP state just reached.
        last_product_state:
            Previous product state, either as a triple or as a flat index.

        Returns
        -------
        (index, state):
            Flat index and triple of the successor product state.  Missing DFA
            transitions default to staying in the current DFA state.
        """
        if not isinstance(last_product_state, tuple):
            last_product_state = self.prod_state_set[last_product_state]
        _, last_cs_state, last_safe_state = last_product_state
        label = self.MDP.get_label(mdp_state)
        alphabet_cs = self.DFA_cs.get_alphabet(label)
        alphabet_s = self.DFA_safe.get_alphabet(label)
        next_cs_state = self.DFA_cs.transitions.get((last_cs_state, alphabet_cs), last_cs_state)
        next_safe_state = self.DFA_safe.transitions.get((last_safe_state, alphabet_s), last_safe_state)
        cur_product_state = (mdp_state, next_cs_state, next_safe_state)
        product_state_index = self._state_to_index(cur_product_state)
        return product_state_index, cur_product_state
# Global scale factor applied to all computed transition times (1.0 = no scaling).
TIME_SHRINK_RATE = 1.0
class RiskLTL:
    """Risk-aware controller using LTL specifications and a Markov Decision Process.

    Solves the problem of finding a policy that satisfies LTL specifications
    while minimizing cost using a product automaton and Linear Programming.
    """

    def __init__(
        self,
        abs_model: GridAbstraction,
        MDP_sys: LightMDP,
        MDP_env: LightMDP = NULL_LIGHT_MDP,
        DFA_safe: DFA = NULL_DFA,
        DFA_sc: DFA = NULL_DFA,
        cost_func: "CostFunction | None" = None,
    ):
        self._abs_model = abs_model
        # `MDP_sys @ MDP_env` composes the system and environment MDPs
        # (relies on LightMDP defining __matmul__ as composition).
        self._prod = Product(MDP_sys @ MDP_env, DFA_sc, DFA_safe)
        self._cost_map = self._prod.gen_cost_map(cost_func=cost_func, abstraction=abs_model)

    @staticmethod
    def policy_from_file(file_path: str) -> np.ndarray:
        """Load a previously saved policy array from a ``.npy`` file."""
        return np.load(file_path)

    @property
    def cost_map(self):
        # Per-product-state cost vector built by Product.gen_cost_map.
        return self._cost_map

    @property
    def product(self):
        # The underlying product automaton (MDP x DFA_cs x DFA_safe).
        return self._prod

    def change_cost_func(self, cost_func: "CostFunction"):
        """Rebuild the cost map with a new cost function; dynamics are unchanged."""
        self._cost_map = self._prod.gen_cost_map(cost_func=cost_func, abstraction=self._abs_model)

    def update(self, sys_state: "tuple[int, int]", env_state: "str", risk_th: float):
        """Advance the product state with the latest observation and re-solve the LP.

        NOTE(review): ``self.prod_auto``, ``self.LP_prob`` and
        ``self._last_prod_state`` are not defined anywhere in this class (or its
        visible subclass) — as written this method raises AttributeError unless
        those attributes are injected externally.  TODO confirm intent; perhaps
        ``prod_auto`` should be ``self._prod``.
        """
        ego_prod_state_index, self._last_prod_state = self.prod_auto.get_next_prod_state(
            (sys_state, env_state), self._last_prod_state
        )
        occ_measure, policy_map = self.LP_prob.solve(
            self.prod_auto.prod_transitions,
            self._cost_map,
            ego_prod_state_index,
            self.prod_auto.accepting_states,
            risk_th,
            None,
        )
        risk = self.cal_risk(policy_map, self._cost_map)
        optimal_policy, Z = self.LP_prob.extract(occ_measure)
        decision_index = optimal_policy[ego_prod_state_index]
        return int(decision_index), optimal_policy, risk

    def cal_risk(self, policy_map, cost_map):
        """Return the policy's expected cost, scaled by the constant 1.5.

        NOTE(review): the 1.5 divisor is a magic constant with no visible
        justification — TODO confirm its meaning (possibly a normalization).
        """
        risk = 0
        for n in range(len(policy_map)):
            risk += policy_map[n] * cost_map[n] / 1.5
        return risk
class VelocityLPRiskLTL(RiskLTL):
    """Velocity-based LTL risk controller with a Gurobi Linear-Programming solver.

    Extends :class:`RiskLTL` to handle velocity-based abstractions and solves
    the risk-constrained reachability problem as an occupation-measure LP.
    """

    def __init__(
        self,
        abs_model: VelocityGridAbstraction,
        DFA_safe: DFA,
        DFA_sc: DFA,
        cost_func: "CostFunction | None" = None,
    ):
        # Narrow the attribute type declared on the parent class for type checkers.
        self._abs_model: VelocityGridAbstraction
        super().__init__(abs_model, abs_model.MDP, NULL_LIGHT_MDP, DFA_safe, DFA_sc, cost_func)
        self.state_num = len(self._prod.prod_state_set)
        self.action_num = len(self._prod.prod_action_set)
        # Pre-computed (state, action, state) travel-time tensor used to
        # discount transition probabilities in the LP.
        self.transition_time = self._compute_transition_time()

    @property
    def abstraction(self) -> VelocityGridAbstraction:
        """The velocity-grid abstraction backing this controller."""
        return self._abs_model

    @property
    def hash(self):
        """MD5 digest of the MDP's label sequence; changes whenever the labelling changes."""
        m = hashlib.md5()
        for label in self._abs_model.MDP.labels:
            m.update(label.encode("utf-8"))
        return m.hexdigest()

    def update(
        self,
        pos_speed: "tuple[float, float, float]",
        label_function: "Callable[[tuple[float, float, float]], str] | None" = None,
        cost_func: "CostFunction | None" = None,
    ):
        """Re-label the abstraction around the new state and refresh product and cost map.

        Args
        ----
        pos_speed:
            Current (x, y, speed) of the system.
        label_function:
            Optional labelling function mapping (x, y, speed) to a label string.
        cost_func:
            Optional cost function used to rebuild the cost map.
        """
        self._abs_model.update(pos_speed=pos_speed, label_function=label_function)
        self._prod.update(MDP=self._abs_model.MDP)
        self._cost_map = self._prod.gen_cost_map(cost_func=cost_func, abstraction=self._abs_model)

    def _compute_transition_time(self):
        """Estimate travel time between every pair of product states.

        Time is the Euclidean (x, y) distance between cell centres divided by
        the average of the two states' speeds (``inf`` when that average is
        zero), scaled by ``TIME_SHRINK_RATE`` and broadcast to shape
        (state_num, action_num, state_num) — the time is action-independent.
        """
        states_to_pos = np.empty((self.state_num, 3))
        for s in range(self.state_num):
            state, _, _ = self._prod.prod_state_set[s]
            (x, y), v = self._abs_model.state_to_pos_speed(state)
            states_to_pos[s] = np.array([x, y, v])
        # Adjust positions to the center of the grid cells and speed bins.
        states_to_pos[:, 0] += self._abs_model.grid_map.cell_width / 2
        states_to_pos[:, 1] += self._abs_model.grid_map.cell_height / 2
        # NOTE(review): 0.1 of the speed resolution is not the bin centre
        # (0.5 would be) — TODO confirm this offset is intentional.
        states_to_pos[:, 2] += self._abs_model.speed_resolution * 0.1
        # Pairwise Euclidean distance matrix over (x, y) positions.
        pairwise_distances = squareform(pdist(states_to_pos[:, :2], metric="euclidean"))
        # Average speed between every pair of states.
        velocity_matrix = (states_to_pos[:, 2][:, np.newaxis] + states_to_pos[:, 2][np.newaxis, :]) / 2
        # time = distance / speed; pairs with zero average speed take infinite time.
        transition_time = np.where(velocity_matrix != 0, pairwise_distances / velocity_matrix, float("inf"))
        transition_time *= TIME_SHRINK_RATE
        assert transition_time.shape == (self.state_num, self.state_num)
        # Make sure the output shape is (state_num, action_num, state_num).
        return np.broadcast_to(transition_time[:, np.newaxis, :], (self.state_num, self.action_num, self.state_num))

    def _resolve_initial_state(self, initial_state) -> int:
        """Map a caller-supplied initial state to a flat product-state index.

        Accepts ``None`` (use the abstraction's initial MDP state), a flat
        index, an (x, y, v) triple, or an (x, y, v, s_cs, s_s) 5-tuple with
        explicit DFA components.  Raises ValueError on anything else.
        """
        if initial_state is None:
            # (0, 1, 1): dummy previous MDP state with both DFAs in their initial state 1.
            return int(self._prod.get_next_prod_state(self._abs_model.MDP.initial_state, (0, 1, 1))[0])
        if isinstance(initial_state, int):
            return int(initial_state)
        if len(initial_state) == 3:
            mdp_state = self._abs_model.pos_speed_to_state(initial_state[:2], initial_state[2])
            return int(self._prod.get_next_prod_state(mdp_state, (0, 1, 1))[0])
        if len(initial_state) == 5:
            x, y, v, s_cs, s_s = initial_state
            # NOTE(review): pos_speed_to_state is called with a single triple here
            # but with (pos, speed) elsewhere — TODO confirm both overloads exist.
            mdp_state = self._abs_model.pos_speed_to_state((x, y, v))
            return int(self._prod.prod_state_set.index((mdp_state, int(s_cs), int(s_s))))
        raise ValueError("Invalid initial_state format")

    @staticmethod
    def _report_solver_failure(model, y, th_hard, th_soft):
        """Print diagnostics for a non-optimal Gurobi termination status."""
        if model.status == grb.GRB.INFEASIBLE:
            print("ERROR: LP problem is infeasible!")
            print("This usually means the constraints are too restrictive.")
            print("Suggestions:")
            print("1. Increase th_hard or th_soft thresholds")
            print("2. Check if the cost function values are too high")
            print("3. Verify that accepting states are reachable")
            # Compute the Irreducible Inconsistent Subsystem to pinpoint the conflict.
            model.computeIIS()
            print("IIS constraints:")
            for c in model.getConstrs():
                if c.IISConstr:
                    print(f"  {c.constrName}")
        elif model.status == grb.GRB.UNBOUNDED:
            print("ERROR: LP problem is unbounded!")
            # Print the non-zero components of the unbounded ray.
            unbounded_ray = model.getAttr(grb.GRB.Attr.UnbdRay, y)
            print("Unbounded ray:")
            for key in unbounded_ray:
                if unbounded_ray[key] != 0:
                    print(f"  {key}: {unbounded_ray[key]}")
        elif model.status == grb.GRB.INF_OR_UNBD:
            print("ERROR: LP problem is infeasible or unbounded!")
            print("This usually means the constraints are too restrictive or there's an issue with the formulation.")
            print("Suggestions:")
            print("1. Increase th_hard from 5 to 10 or higher")
            print("2. Increase th_soft from 0.8 to 2.0 or higher")
            print("3. Reduce cost function values")
            print(f"Current th_hard: {th_hard}, th_soft: {th_soft}")
        else:
            print(f"ERROR: LP solver returned status {model.status}")

    def solve(
        self, initial_state: "int | tuple[float, float, float] | None" = None, initial_guess: "np.ndarray | None" = None
    ):
        """Solve the risk-aware occupation-measure LP with Gurobi.

        Maximizes (discounted) probability mass flowing into accepting states,
        minus the soft-constraint relaxation ``z``, subject to flow-balance and
        cumulative-cost constraints.

        Args
        ----
        initial_state:
            See :meth:`_resolve_initial_state` for the accepted formats.
        initial_guess:
            Optional (state_num, action_num) warm-start for the occupation measure.

        Returns
        -------
        The occupation-measure values as a Gurobi tupledict keyed by
        (state, action), or ``None`` when the LP did not solve to optimality
        (diagnostics are printed in that case).
        """
        P = self._prod.prod_transitions
        S0 = self._resolve_initial_state(initial_state)
        gamma = 0.9  # discount factor
        th_hard = 5  # hard cap on soft threshold + relaxation
        th_soft = 0.8  # soft cumulative-cost threshold, relaxed by z
        # Context managers guarantee the Gurobi environment/model are disposed
        # (the original leaked them on every call).
        with grb.Env() as env, grb.Model("risk_lp", env=env) as model:
            # Occupation measure y[s, a] and soft-constraint relaxation z.
            y = model.addVars(self.state_num, self.action_num, vtype=grb.GRB.CONTINUOUS, lb=0.0, ub=1.0, name="x")
            z = model.addVar(vtype=grb.GRB.CONTINUOUS, name="z")
            # Set initial values for warm start if provided.
            if initial_guess is not None:
                for s in range(self.state_num):
                    for a in range(self.action_num):
                        y[s, a].start = initial_guess[s, a]
            # Discount each transition by gamma ** travel_time (inf times -> factor 0).
            discounted_P = gamma**self.transition_time * P
            # Flow-balance right-hand side: one unit injected at the initial state.
            rhs = np.zeros(self.state_num)
            rhs[S0] = 1
            for sn in range(self.state_num):
                # Outgoing flow from state sn.
                outgoing = grb.quicksum(y[sn, a] for a in range(self.action_num))
                # Incoming flow to sn, restricted to non-zero discounted transitions.
                incoming = grb.quicksum(
                    discounted_P[s, a, sn] * y[s, a] for s, a in zip(*np.nonzero(discounted_P[:, :, sn]))
                )
                model.addConstr(outgoing - incoming == rhs[sn])
            # Cost constraint — only states with non-zero cost can contribute.
            cost_map_idxs = np.nonzero(self._cost_map)
            cost_expr = grb.quicksum(
                self._cost_map[s] * y[s, a] for s in cost_map_idxs[0] for a in range(self.action_num)
            )
            model.addConstr(cost_expr <= th_soft + z)
            model.addConstr(z + th_soft <= th_hard)
            model.addConstr(z >= 0)
            # Objective: reward flow into accepting states, penalize the relaxation.
            # Accepting states need not be considered as sources — once reached,
            # the objective is already achieved.
            from_states = tuple(set(range(self.state_num)) - self._prod.accepting_states)
            filtered_P = P[
                np.ix_(
                    from_states,
                    tuple(range(self.action_num)),
                    tuple(self._prod.accepting_states),
                )
            ]
            obj_idxs = np.nonzero(filtered_P)
            obj = grb.quicksum(2.0 * filtered_P[s, a, sn] * y[from_states[s], a] for s, a, sn in zip(*obj_idxs))
            obj.addTerms(-1.0, z)
            model.setObjective(obj, grb.GRB.MAXIMIZE)
            model.optimize()
            # Fix(review): the original returned unconditionally here, leaving a
            # large block of diagnostics unreachable (and referencing an
            # undefined `vars` name). Diagnostics are now actually reachable.
            if model.status == grb.GRB.OPTIMAL:
                return model.getAttr(grb.GRB.Attr.X, y)
            self._report_solver_failure(model, y, th_hard, th_soft)
            return None

    def plot_policy(
        self, policy: np.ndarray, subsample: int = 1, ax: "plt.Axes | None" = None
    ) -> "tuple[Figure, Axes3D]":
        """
        Create a 3D visualization of the policy showing state transitions.

        This function plots the policy as a 3D vector field where:
        - X, Y axes represent spatial position
        - Z axis represents speed
        - Arrows show the transition from current state to next state based on the policy

        Args
        ----
        policy:
            Policy array where each index corresponds to a product state and each
            element is the action to take in that state (negative = no action)
        subsample:
            Plot every nth state to reduce clutter (default: 1, plot all states)
        ax:
            Optional existing 3D axes to draw into; a new figure is created otherwise.

        Returns
        -------
        tuple[Figure, Axes3D]
            Matplotlib figure and 3D axes objects
        """
        import matplotlib.pyplot as plt

        if ax is not None:
            fig = ax.get_figure()
        else:
            fig = plt.figure(figsize=(12, 10))
            ax = fig.add_subplot(111, projection="3d")
        # Arrow start points and direction components.
        x_positions = []
        y_positions = []
        z_positions = []
        u_directions = []  # x direction
        v_directions = []  # y direction
        w_directions = []  # z direction (speed)
        for prod_idx in range(0, len(policy), subsample):
            action: int = policy[prod_idx]
            if action < 0:
                continue  # No action defined for this state
            # Fix(review): the original re-bound the loop index to the unpacked
            # MDP state (shadowing); use a distinct name for clarity.
            mdp_state, _, _ = self.product.prod_state_set[prod_idx]
            (x, y), speed = self._abs_model.state_to_pos_speed(mdp_state)
            (tx, ty), tspeed = self._abs_model.state_action_to_next_state(mdp_state, action)
            # Only draw an arrow when the action actually moves the state.
            if (tx, ty) != (x, y) or tspeed != speed:
                x_positions.append(x)
                y_positions.append(y)
                z_positions.append(speed)
                u_directions.append(tx - x)
                v_directions.append(ty - y)
                w_directions.append(tspeed - speed)
        # Plot arrows using quiver.
        if len(x_positions) > 0:
            ax.quiver(
                x_positions,
                y_positions,
                z_positions,
                u_directions,
                v_directions,
                w_directions,
                arrow_length_ratio=0.3,
                alpha=0.7,
            )
        ax.set_xlabel("X Position", fontsize=12)
        ax.set_ylabel("Y Position", fontsize=12)
        ax.set_zlabel("Speed", fontsize=12)
        ax.set_title("Policy Visualization: State Transitions in 3D", fontsize=14, pad=20)
        ax.grid(True, alpha=0.3)
        # Adjust viewing angle for better visualization.
        ax.view_init(elev=20, azim=45)
        plt.tight_layout()
        return fig, ax

    def plot_policy_plane(
        self, policy: np.ndarray, initial_state: "tuple[float, float, float]", flexible_speed=False, ax=None
    ) -> "tuple[Figure, plt.Axes]":
        """Plot the single trajectory the policy induces from ``initial_state`` in the (x, y) plane.

        Follows the policy from the given (x, y, v) starting point, drawing one
        arrow per position change, until the policy gives no action, the state
        stops moving, or the product state repeats.

        Args
        ----
        policy:
            Policy array indexed by product state (negative = no action).
        initial_state:
            Starting (x, y, speed).
        flexible_speed:
            When the current speed bin has no action, try all other speed bins
            (highest first) at the same position.
        ax:
            Optional existing axes to draw into.
        """
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 8)) if ax is None else (ax.get_figure(), ax)
        x_positions = []
        y_positions = []
        u_directions = []  # x direction
        v_directions = []  # y direction
        current_state_idx = self._abs_model.pos_speed_to_state(initial_state[:2], initial_state[2])
        if current_state_idx is None:
            return fig, ax  # Invalid initial state, return empty plot
        # Both DFAs start in their initial state 1.
        s_cs = 1
        s_ss = 1
        prod_state_idx = self._prod._state_to_index((current_state_idx, s_cs, s_ss))
        while True:
            action: int = policy[prod_state_idx]
            if action < 0 and flexible_speed:
                # Try all speed options (highest first) to find a valid action.
                for i in range(self._abs_model.speed_range - 1, -1, -1):
                    temp_state_idx = self._abs_model.pos_speed_to_state(
                        initial_state[:2], i * self._abs_model.speed_resolution
                    )
                    temp_prod_state_idx = self._prod._state_to_index((temp_state_idx, s_cs, s_ss))
                    action = policy[temp_prod_state_idx]
                    if action >= 0:
                        break
            if action < 0:
                break  # No action defined for this state
            # Get current state's position and speed.
            (x, y), speed = self._abs_model.state_to_pos_speed(current_state_idx)
            (tx, ty), tspeed = self._abs_model.state_action_to_next_state(current_state_idx, action)
            if (tx, ty) != (x, y):
                # Add a new arrow for the position change.
                x_positions.append(x)
                y_positions.append(y)
                u_directions.append(tx - x)
                v_directions.append(ty - y)
            else:
                break  # No movement, likely a sink state
            current_state_idx = self._abs_model.pos_speed_to_state((tx, ty), tspeed)
            new_prod_state_idx, (_, s_cs, s_ss) = self._prod.get_next_prod_state(current_state_idx, (0, s_cs, s_ss))
            if new_prod_state_idx == prod_state_idx:
                break  # Product state repeated: trajectory has converged.
            prod_state_idx = new_prod_state_idx
        if len(x_positions) > 0:
            ax.quiver(
                x_positions,
                y_positions,
                u_directions,
                v_directions,
                angles="xy",
                scale_units="xy",
                scale=1,
                alpha=0.7,
            )
        ax.set_xlabel("X Position", fontsize=12)
        ax.set_ylabel("Y Position", fontsize=12)
        ax.set_title(
            f"Policy Visualization from {initial_state[0]:.2f}, {initial_state[1]:.2f}, {initial_state[2]:.2f}",
            fontsize=14,
            pad=20,
        )
        ax.set_aspect("equal", adjustable="box")
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        return fig, ax