Source code for symaware.base.models.environment

import time
from abc import abstractmethod
from typing import TYPE_CHECKING

import numpy as np

from symaware.base.data import Identifier
from symaware.base.utils import AsyncLoopLockable, Publisher, get_logger

from .entity import Entity

if TYPE_CHECKING:
    # Forwards declaration
    # String type hinting to support python 3.8
    import sys
    from typing import Any, Callable

    from symaware.base.utils import AsyncLoopLock

    from ..agent import Agent

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    EnvironmentCallback: TypeAlias = Callable[["Environment"], Any]


class Environment(Publisher, AsyncLoopLockable):
    """
    Support class that lets multiple agents operate in the same environment.

    Derived classes implement the simulator-specific parts
    (:meth:`_add_entity`, :meth:`get_entity_state`, :meth:`initialise`,
    :meth:`step` and :meth:`stop`).
    This class keeps track of the entities that have been added and of which
    entity represents which agent.

    Args
    ----
    async_loop_lock:
        async loop lock to use with the environment
    """

    def __init__(self, async_loop_lock: "AsyncLoopLock | None" = None):
        Publisher.__init__(self)
        AsyncLoopLockable.__init__(self, async_loop_lock)
        # Entities already added to the environment, used to avoid adding/initialising them twice
        self._entities: "set[Entity]" = set()
        # Maps each agent's identifier to the entity representing it in the environment
        self._agent_entities: "dict[Identifier, Entity]" = {}
        # True while the loop started by async_run is active
        self._running = False
        # Wall-clock timestamp (time.time()) of the first step; -1.0 means "not started yet"
        self._start_time = -1.0

    @abstractmethod
    def _add_entity(self, entity: Entity):
        """
        Add an entity to the environment, initialising it.
        The actual implementation should be done in the derived class,
        based on the simulated environment API.
        The entity's :func:`initialise_entity` function should be called within this function
        with the appropriate arguments.

        Args
        ----
        entity:
            entity to initialise
        """
        pass

    @abstractmethod
    def get_entity_state(self, entity: Entity) -> np.ndarray:
        """
        Get the state of an entity in the environment.
        The actual implementation should be done in the derived class,
        based on the simulated environment API.

        Args
        ----
        entity:
            entity to get the state of

        Returns
        -------
            State of the entity within the environment
        """
        pass

    def get_agent_state(self, agent: "Identifier | Agent") -> np.ndarray:
        """
        Get the state of the entity associated with an agent.

        Args
        ----
        agent:
            the agent, or its identifier, whose state to retrieve.
            The agent must have been added with :meth:`add_agents`

        Returns
        -------
            State of the agent's entity within the environment,
            as returned by :meth:`get_entity_state`

        Raises
        ------
        KeyError:
            if no agent with that identifier has been added to the environment
        """
        agent_id = agent if isinstance(agent, Identifier) else agent.id
        return self.get_entity_state(self._agent_entities[agent_id])

    @property
    def entities(self) -> "set[Entity]":
        """Set of entities in the environment"""
        return self._entities

    @property
    def agent_states(self) -> "dict[Identifier, np.ndarray]":
        """Dictionary mapping agent identifiers to their states in the environment"""
        return {agent_id: self.get_agent_state(agent_id) for agent_id in self._agent_entities}

    @property
    def elapsed_time(self) -> float:
        """
        Wall-clock time, in seconds, elapsed since the simulation started.
        If the simulation has not started yet, it will return -1.
        Can be used to synchronise the agents in the environment.

        The start time is recorded by the base :meth:`step` the first time it runs.
        The value is therefore about 0 at the first step and then grows with real
        (``time.time()``) time. It is not derived from the simulator's own time step.
        """
        if self._start_time == -1:
            # Keep the return type consistent with the annotation (float, not int)
            return -1.0
        return time.time() - self._start_time

    @property
    def start_time(self) -> float:
        """
        Wall-clock timestamp (as returned by ``time.time()``) at which the simulation started.
        If the simulation has not started yet, it will return -1.
        """
        return self._start_time

    def add_agents(self, *agents: "Agent"):
        """
        High level interface this class exposes to add agents to the environment.
        Each agent's entity is added through :meth:`add_entities`, which relies on
        :meth:`_add_entity` to add it to the underlying physics engine.
        The agent's identifier is then associated with its entity, so that its state
        can be retrieved with :meth:`get_agent_state`.

        Args
        ----
        agents:
            agents to add to the environment
        """
        for agent in agents:
            self.add_entities(agent.entity)
            self._agent_entities[agent.id] = agent.entity

    def add_entities(self, *entities: Entity):
        """
        High level interface this class exposes to add entities to the environment.
        The :meth:`_add_entity` method will add each entity to the underlying physics engine.
        Once an entity is added, it is also stored in the internal set of entities,
        to avoid adding and initialising it multiple times.
        Entities already present are skipped with a warning.

        Args
        ----
        entities:
            one or more instances of :class:`symaware.base.Entity`

        Raises
        ------
        TypeError:
            if one of the arguments is not an :class:`symaware.base.Entity`.
            Entities preceding the invalid one will already have been added.
        """
        for entity in entities:
            if not isinstance(entity, Entity):
                raise TypeError(f"Expected entity, got {type(entity)}")
            if entity in self._entities:
                get_logger(__name__, "Environment").warning("Entity %s already present in the environment", entity)
                continue
            self._add_entity(entity)
            self._entities.add(entity)

    @abstractmethod
    def initialise(self):
        """
        Initialise the simulation, allocating the required resources.
        Should be called when the simulation has been set up and is ready to be run.
        Some environment implementations may call it automatically when the environment is created.
        It is their responsibility to ensure that the method is idempotent.
        """
        pass

    @abstractmethod
    def step(self):
        """
        Step the environment forward in time, updating the state of all the entities.
        It can be called repeatedly.

        The base implementation only records the wall-clock start time on its first call.
        That timestamp is what :attr:`start_time` and :attr:`elapsed_time` report.
        Derived classes should therefore call ``super().step()``.
        """
        if self._start_time == -1:
            self._start_time = time.time()

    @abstractmethod
    def stop(self):
        """
        Terminate the simulation, releasing the resources.
        Should be called when the simulation has been running manually and needs to be stopped.
        Some environment implementations may call it automatically when the environment is destroyed.
        It is their responsibility to ensure that the method is idempotent.

        Warning
        -------
        Depending on the simulator implementation, calling this method may invalidate all the entities
        previously added to the environment.
        In that case, entities and the environment should be recreated from scratch.
        """
        pass

    async def async_run(self):
        """
        Start the environment loop asynchronously.
        It will run the environment until :meth:`async_stop` is called.
        The frequency of the loop is determined by the :class:`.AsyncLoopLock`
        used to initialise the environment.
        """
        self._running = True
        # Wait for the first loop tick before stepping, then alternate step / wait
        await self.next_loop()
        while self._running:
            self.step()
            await self.next_loop()

    async def async_stop(self):
        """
        Gracefully stop the environment loop asynchronously.
        Once the last cycle is completed, the control is returned to the caller.
        The loop flag is cleared first, then the async loop lock is stopped,
        and finally :meth:`stop` is called to release the simulator resources.
        """
        self._running = False
        await AsyncLoopLockable.async_stop(self)
        self.stop()

    def add_on_initialising(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``initialising``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("initialising", callback)

    def remove_on_initialising(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``initialising``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("initialising", callback)

    def add_on_initialised(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``initialised``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("initialised", callback)

    def remove_on_initialised(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``initialised``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("initialised", callback)

    def add_on_stepping(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``stepping``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("stepping", callback)

    def remove_on_stepping(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``stepping``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("stepping", callback)

    def add_on_stepped(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``stepped``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("stepped", callback)

    def remove_on_stepped(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``stepped``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("stepped", callback)

    def add_on_stopping(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``stopping``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("stopping", callback)

    def remove_on_stopping(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``stopping``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("stopping", callback)

    def add_on_stopped(self, callback: "EnvironmentCallback"):
        """
        Add a callback to the event ``stopped``

        Args
        ----
        callback:
            Callback to add
        """
        self._add("stopped", callback)

    def remove_on_stopped(self, callback: "EnvironmentCallback"):
        """
        Remove a callback from the event ``stopped``

        Args
        ----
        callback:
            Callback to remove
        """
        self._remove("stopped", callback)