153 lines
5.4 KiB
Python
153 lines
5.4 KiB
Python
from dataclasses import dataclass,field
|
|
from enum import Enum
|
|
from typing import List,Dict,Tuple,Set
|
|
from shapely import Polygon,Point
|
|
from geometry_setup import GeometrySetup
|
|
import shapely
|
|
import numpy as np
|
|
|
|
class CrosswalkStatus(Enum):
    """Phases of the crosswalk state machine.

    Each member's value is its lower-case name.
    """

    # Waiting for enough agents in the trigger area (and the activation delay).
    INACTIVE = "inactive"
    # Crossing is allowed.
    ACTIVE = "active"
    # Crossing has just ended; re-activation is blocked until this expires.
    COOLDOWN = "cooldown"
|
|
|
|
@dataclass(frozen=True)
class CrosswalkConfig:
    """Immutable parameters for one crosswalk controller.

    The declaration order below is also the positional order of the
    generated constructor. Time values share the units of the simulation
    clock passed to the controller.
    """

    # Stored for callers; not read anywhere in this module.
    trigger_dist: float
    # How long the crosswalk stays ACTIVE after activation.
    activation_time: float
    # How long the COOLDOWN phase lasts after an ACTIVE phase ends.
    cooldown_time: float
    # Minimum number of agents inside trigger_area required to activate.
    min_agents_activation: int
    # Minimum time that must pass in INACTIVE before activation is allowed.
    activation_delay: float
    # Area of the crossing itself (stored by the controller, not tested here).
    crosswalk_area: Polygon
    # Area whose occupants count toward activation.
    trigger_area: Polygon
|
|
|
|
@dataclass
class CrosswalkState:
    """Point-in-time snapshot produced by ``CrosswalkController.get_state``."""

    status: CrosswalkStatus  # current phase of the state machine
    is_active: bool  # True exactly when status is ACTIVE
    time_active: float  # time elapsed in the current ACTIVE phase, else 0.0
    time_remaining: float  # time left in ACTIVE or COOLDOWN, else 0.0
    agents_in_trigger: int  # number of agents currently in the trigger area
    agents_waiting: int  # number of agents recorded as waiting to cross
    activation_count: int  # activations since the controller was created
|
|
|
|
class CrosswalkController:
    """Time-driven state machine deciding when agents may use a crosswalk.

    The controller cycles INACTIVE -> ACTIVE -> COOLDOWN -> INACTIVE:

    * INACTIVE -> ACTIVE once at least ``config.min_agents_activation``
      agents are inside ``trigger_area`` and at least
      ``config.activation_delay`` has elapsed since INACTIVE was entered.
    * ACTIVE -> COOLDOWN once ``config.activation_time`` has elapsed.
    * COOLDOWN -> INACTIVE once ``config.cooldown_time`` has elapsed.

    At most one transition is applied per call to :meth:`update`.
    """

    def __init__(self, config: CrosswalkConfig, current_setup: bool):
        """Create a controller in the INACTIVE state at time 0.

        Args:
            config: Timing thresholds and the crosswalk / trigger polygons.
            current_setup: Accepted for interface compatibility; not used
                by this class.
        """
        self.config = config
        self.crosswalk_area = config.crosswalk_area
        self.trigger_area = config.trigger_area

        self.status = CrosswalkStatus.INACTIVE
        # Clock value at which the current status was entered.
        self.state_start_time = 0.0
        self.current_time = 0.0

        self.activation_count = 0
        # IDs of agents inside trigger_area; rebuilt on every update.
        self.agents_in_trigger: Set[int] = set()
        # IDs of agents that asked to cross (via can_agent_cross) while the
        # crosswalk was not active; cleared only on activation.
        self.agents_waiting: Set[int] = set()

        self.total_active_time = 0.0
        self.agents_served = 0

    def update(
        self,
        agent_pos: Dict[int, Tuple[float, float]],
        time_step: float,
        current_time: float = None,
    ) -> CrosswalkState:
        """Advance the controller by one simulation step.

        Args:
            agent_pos: Mapping of agent ID to its (x, y) position.
            time_step: Step length. Added to the internal clock when
                ``current_time`` is None, and to ``total_active_time`` while
                the crosswalk is ACTIVE after the transition check.
            current_time: Optional absolute clock value; when given it
                replaces the internal clock instead of advancing it.

        Returns:
            A snapshot of the controller after this step.
        """
        if current_time is None:
            self.current_time += time_step
        else:
            self.current_time = current_time

        self.update_agents_in_trigger(agent_pos)
        self.update_state(time_step)
        self.update_statistics(time_step)
        return self.get_state()

    def update_agents_in_trigger(
        self,
        agent_pos: Dict[int, Tuple[float, float]],
    ) -> None:
        """Recompute ``agents_in_trigger`` from the given positions."""
        self.agents_in_trigger.clear()
        if not agent_pos:
            return

        agent_ids = list(agent_pos.keys())
        # Keep float64: rounding to float32 introduces ~6e-8 relative error
        # (e.g. ~1.5e-9 at 0.1, ~5e-7 at 10), which exceeds the 1e-9
        # boundary tolerance in points_in_area and could drop agents that
        # stand exactly on the trigger boundary.
        positions = np.array(list(agent_pos.values()), dtype=np.float64)
        in_trigger = self.points_in_area(positions, self.trigger_area)

        self.agents_in_trigger.update(
            agent_id
            for agent_id, inside in zip(agent_ids, in_trigger)
            if inside
        )

    @staticmethod
    def points_in_area(
        points: np.ndarray,
        polygon: Polygon,
        tolerance: float = 1e-9,
    ) -> np.ndarray:
        """Return a boolean mask telling which points lie inside ``polygon``.

        The polygon is first buffered by ``tolerance`` so that points on its
        boundary (where plain ``contains`` is False) count as inside.

        Args:
            points: (x, y) pairs, e.g. an array of shape (N, 2).
            polygon: Area to test against.
            tolerance: Buffer distance applied to ``polygon``; defaults to
                the previously hard-coded 1e-9.

        Returns:
            Boolean array of length N.

        Raises:
            ValueError: If ``points`` is not a sequence of (x, y) pairs.
        """
        if len(points) == 0:
            return np.array([], dtype=bool)

        coords = np.asarray(points, dtype=np.float64)
        if coords.ndim != 2 or coords.shape[1] != 2:
            raise ValueError(
                f"points must have shape (N, 2), got {coords.shape}"
            )

        area = polygon.buffer(tolerance)
        # Prepared geometry + one vectorized predicate instead of creating a
        # Point per agent (shapely >= 2.0, which `from shapely import
        # Polygon` already requires).
        shapely.prepare(area)
        return np.asarray(
            shapely.contains_xy(area, coords[:, 0], coords[:, 1]),
            dtype=bool,
        )

    def update_state(self, time_step: float) -> None:
        """Apply at most one status transition based on elapsed time.

        ``time_step`` is not used; decisions depend on
        ``current_time - state_start_time``.
        """
        elapsed = self.current_time - self.state_start_time

        if self.status == CrosswalkStatus.ACTIVE:
            if elapsed >= self.config.activation_time:
                self.deactivate()
        elif self.status == CrosswalkStatus.COOLDOWN:
            if elapsed >= self.config.cooldown_time:
                self.status = CrosswalkStatus.INACTIVE
                self.state_start_time = self.current_time
        elif self.status == CrosswalkStatus.INACTIVE:
            # activation_delay counts from when INACTIVE was entered: time 0
            # initially, otherwise the end of the previous cooldown.
            if (
                len(self.agents_in_trigger) >= self.config.min_agents_activation
                and elapsed >= self.config.activation_delay
            ):
                self.activate()

    def activate(self) -> None:
        """Enter ACTIVE and count all waiting agents as served.

        Every agent in ``agents_waiting`` is counted, including any that
        have left the trigger area since they were recorded.
        """
        self.status = CrosswalkStatus.ACTIVE
        self.state_start_time = self.current_time
        self.activation_count += 1
        self.agents_served += len(self.agents_waiting)
        self.agents_waiting.clear()

    def deactivate(self) -> None:
        """Leave ACTIVE and enter COOLDOWN at the current time."""
        self.status = CrosswalkStatus.COOLDOWN
        self.state_start_time = self.current_time

    def update_statistics(self, time_step: float) -> None:
        """Accumulate ``time_step`` into ``total_active_time`` while ACTIVE."""
        if self.status == CrosswalkStatus.ACTIVE:
            self.total_active_time += time_step

    def get_state(self) -> CrosswalkState:
        """Build a :class:`CrosswalkState` snapshot of the controller."""
        elapsed = self.current_time - self.state_start_time

        if self.status == CrosswalkStatus.ACTIVE:
            time_remaining = max(0.0, self.config.activation_time - elapsed)
        elif self.status == CrosswalkStatus.COOLDOWN:
            time_remaining = max(0.0, self.config.cooldown_time - elapsed)
        else:
            time_remaining = 0.0

        active = self.status == CrosswalkStatus.ACTIVE
        return CrosswalkState(
            status=self.status,
            is_active=active,
            time_active=elapsed if active else 0.0,
            time_remaining=time_remaining,
            agents_in_trigger=len(self.agents_in_trigger),
            agents_waiting=len(self.agents_waiting),
            activation_count=self.activation_count,
        )

    def can_agent_cross(self, agent_id: int) -> bool:
        """Return True if crossing is currently allowed.

        When the crosswalk is not ACTIVE and the agent is inside the trigger
        area, the agent is also recorded in ``agents_waiting`` (side effect).
        """
        if self.status == CrosswalkStatus.ACTIVE:
            return True
        if agent_id in self.agents_in_trigger:
            self.agents_waiting.add(agent_id)
        return False

    @property
    def is_active(self) -> bool:
        """True when the crosswalk is in the ACTIVE state."""
        return self.status == CrosswalkStatus.ACTIVE

    def get_statistics(self) -> Dict[str, float]:
        """Return cumulative and current counters as a plain dict."""
        return {
            "total_activations": self.activation_count,
            "total_active_time": self.total_active_time,
            "agents_served": self.agents_served,
            "current_agents_in_trigger": len(self.agents_in_trigger),
            "current_agents_waiting": len(self.agents_waiting),
        }