from typing import TYPE_CHECKING
import numpy as np
from symaware.base import Environment as BaseEnvironment
from symaware.base import get_logger, log
from symaware.extra.ros.utils import RosClient
from .entities import Entity
if TYPE_CHECKING:
# String type hinting to support python 3.9
from symaware.base.utils import AsyncLoopLock
class Environment(BaseEnvironment):
    """
    Environment that exchanges data with ROS through a ROS bridge client.

    Every entity added to this environment is initialised with the shared
    ``ros_client`` (see :meth:`_add_entity`).

    Args
    ----
    ros_client:
        Client used to connect to the ROS bridge to send and receive topic messages.
        The signature allows ``None``.
    async_loop_lock:
        Async loop lock to use for the environment
    """

    # Class-private logger; the name is mangled to ``_Environment__LOGGER``
    # consistently, so the ``@log(__LOGGER)`` decorators below resolve it.
    __LOGGER = get_logger(__name__, "ROS.Environment")

    def __init__(self, ros_client: "RosClient | None", async_loop_lock: "AsyncLoopLock | None" = None):
        super().__init__(async_loop_lock)
        # Set by ``initialise`` and cleared by ``stop``; guards repeated initialisation.
        self._is_ros_initialized = False
        # NOTE(review): not read or written anywhere else in this class;
        # presumably reserved for caching entity states. TODO confirm or remove.
        self._entity_states: dict = {}
        self._ros = ros_client

    @property
    def ros(self) -> "RosClient | None":
        """ROS client instance shared with the entities (``None`` if none was provided)."""
        return self._ros

    @log(__LOGGER)
    def get_entity_state(self, entity: Entity) -> np.ndarray:
        """
        Return the state of ``entity``.

        Currently a placeholder: the state is not yet read from ROS, so an
        empty array is always returned.

        Raises
        ------
        TypeError:
            If ``entity`` is not an :class:`Entity` of this package.
        """
        if not isinstance(entity, Entity):
            raise TypeError(f"Expected Entity, got {type(entity)}")
        # TODO: get info from ROS
        return np.array([])

    @log(__LOGGER)
    def _add_entity(self, entity: Entity) -> None:
        """
        Register ``entity`` by initialising it with this environment's ROS client.

        The environment is lazily initialised first if that has not happened yet.

        Raises
        ------
        TypeError:
            If ``entity`` is not an :class:`Entity` of this package.
        """
        if not isinstance(entity, Entity):
            raise TypeError(f"Expected Entity, got {type(entity)}")
        if not self._is_ros_initialized:
            self.initialise()
        entity.initialise(self._ros)

    def initialise(self) -> None:
        """
        Mark the environment as initialised, notifying ``"initialising"`` and ``"initialised"``.

        Calling it again before :meth:`stop` is a no-op. Note that this method
        only flips an internal flag; it does not use the ROS client itself.
        """
        if self._is_ros_initialized:
            return
        self._notify("initialising", self)
        self._is_ros_initialized = True
        self._notify("initialised", self)

    def step(self) -> None:
        """
        Advance the environment by one step.

        Runs the base-class step, then steps every entity in ``_agent_entities``,
        surrounded by ``"stepping"`` / ``"stepped"`` notifications.
        """
        self._notify("stepping", self)
        super().step()
        # NOTE(review): only entities in ``_agent_entities`` are stepped here;
        # presumably entities added without an agent are not — confirm against the base class.
        for entity in self._agent_entities.values():
            entity.step()
        self._notify("stepped", self)

    def stop(self) -> None:
        """
        Stop the environment, notifying ``"stopping"`` and ``"stopped"``.

        Clears the initialisation flag so a later :meth:`initialise` runs again.
        The ROS client is not touched here; presumably its lifecycle is managed
        by whoever created it — confirm.
        """
        self._notify("stopping", self)
        self._is_ros_initialized = False
        self._notify("stopped", self)