import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, TypeVar
from matplotlib import pyplot as plt
from .component import Component
# Type of the data payload the manager sends to its plotter process and that
# ProcessPlotter.update_plot receives.
T = TypeVar("T")
if TYPE_CHECKING:
    # Only used in (string) type annotations; not imported at runtime.
    from multiprocessing.connection import Connection
class ProcessPlotter(Generic[T], ABC):
    """Abstract base class for plot generators running in separate processes.

    Provides a framework for creating interactive plots that update with data
    received through an inter-process communication pipe. Subclasses must
    implement :meth:`update_plot` and may override :meth:`initialise_plot`.

    A ``None`` message on the pipe is the stop sentinel: the plotter closes
    its figures and stops polling.

    Args
    ----
    interval:
        Update interval in seconds for the plot timer
    """

    def __init__(self, interval: float = 1.0):
        # matplotlib timers take their interval in milliseconds.
        self._interval = interval * 1000
        # The attributes below are only declared here; they are assigned by
        # run() / initialise_plot() inside the plotting process.
        self._pipe: "Connection"
        self._fig: "plt.Figure"
        self._ax: "plt.Axes"

    def stop(self):
        """Stop the plotter and close all matplotlib figures."""
        plt.close("all")

    def _timer_call_back(self):
        """Drain all pending messages from the pipe, then redraw the figure.

        Registered on the matplotlib timer created in :meth:`run`.

        Returns
        -------
        bool
            ``False`` once the ``None`` stop sentinel is received or the
            sending end of the pipe has been closed (after calling
            :meth:`stop`), ``True`` otherwise. A callback returning ``False``
            is removed from the matplotlib timer.
        """
        while self._pipe.poll():
            try:
                command = self._pipe.recv()
            except EOFError:
                # The sending end was closed without a stop sentinel. Nothing
                # more can ever arrive, so treat it exactly like a stop request
                # instead of raising on every subsequent timer tick.
                command = None
            if command is None:
                self.stop()
                return False
            self.update_plot(command)
        self._fig.canvas.draw()
        return True

    def run(self, pipe: "Connection"):
        """Run the plotter in the process.

        Intended as the ``target`` of the child process. It stores the pipe,
        builds the plot, starts a periodic timer that polls the pipe, and
        enters the matplotlib event loop.

        Args
        ----
        pipe:
            Pipe connection for receiving plot update data
        """
        self._pipe = pipe
        self.initialise_plot()
        timer = self._fig.canvas.new_timer(interval=int(self._interval))
        timer.add_callback(self._timer_call_back)
        timer.start()
        # In non-interactive mode plt.show() blocks until the figures are
        # closed (e.g. by stop()), which also keeps ``timer`` referenced.
        plt.show()

    def initialise_plot(self):
        """Override this method to set up the initial state of the plot.

        The default creates a single figure with one set of axes. Overrides
        must assign ``self._fig``, because :meth:`run` and the timer callback
        use its canvas.
        """
        self._fig, self._ax = plt.subplots()

    @abstractmethod
    def update_plot(self, data: T):
        """Override this method to update the plot with new data.

        Called in the plotting process once for every non-``None`` message
        received on the pipe, before the figure is redrawn.
        """
        pass
class ProcessPlotterManager(Component, Generic[T]):
    """Manager for ProcessPlotter instances running in separate processes.

    Manages the lifecycle of, and one-way communication with, a plot process,
    allowing asynchronous updates to plots from the main process.

    Args
    ----
    agent_id : int
        Identifier for the component
    async_loop_lock : Any
        Optional lock for async loop synchronization
    """

    def __init__(self, agent_id: int, async_loop_lock=None):
        super().__init__(agent_id=agent_id, async_loop_lock=async_loop_lock)
        # Declared only; assigned by initialise_plotter().
        self._plot_pipe: "Connection"
        self._plotter_process: mp.Process
        self._plotter_object: ProcessPlotter[T]

    def _open_pipe(self) -> "Connection | None":
        """Return the send end of the plot pipe, or None if unusable.

        The pipe is unusable if initialise_plotter() has not run yet or the
        pipe has already been closed.
        """
        pipe = getattr(self, "_plot_pipe", None)
        if pipe is None or pipe.closed:
            return None
        return pipe

    def initialise_plotter(self, plotter: "ProcessPlotter[T]"):
        """Initialize and start a plotter process.

        Args
        ----
        plotter:
            The plotter instance to run in a separate process
        """
        # One-way pipe: the first end can only receive (given to the child),
        # the second can only send (kept here).
        plotter_pipe, self._plot_pipe = mp.Pipe(False)
        self._plotter_object = plotter
        self._plotter_process = mp.Process(
            target=self._plotter_object.run,
            args=(plotter_pipe,),
            daemon=True,
        )
        self._plotter_process.start()
        # The child now owns the receiving end. Closing our copy means that
        # once the plotter exits, send() fails with BrokenPipeError. Keeping it
        # open would make writes keep succeeding into a pipe nobody reads,
        # until the OS buffer fills and send() blocks.
        plotter_pipe.close()

    def _update(self, data: "T | None"):
        """Send data to the plotter process.

        Silently does nothing when the plotter has not been initialised, has
        been stopped via async_stop(), or its process has gone away.

        Args
        ----
        data:
            Data to send to the plotter for updating the plot. ``None`` is
            the stop sentinel.
        """
        pipe = self._open_pipe()
        if pipe is None:
            return
        try:
            pipe.send(data)
        except (BrokenPipeError, ConnectionResetError):
            # The receiving process has exited. Close our end so later updates
            # short-circuit on the guard above instead of failing again.
            pipe.close()

    async def async_stop(self):
        """Asynchronously stop the plotter process.

        Sends the ``None`` stop sentinel to the plotter, then closes this side
        of the pipe so that later _update() calls become no-ops. It is safe to
        call this more than once, or before initialise_plotter().

        Returns
        -------
        None
        """
        self._update(None)
        pipe = self._open_pipe()
        if pipe is not None:
            # Data already written stays readable by the plotter after close.
            pipe.close()
        return None