from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Optional, Tuple, Set

from shapely import Polygon, Point
from geometry_setup import GeometrySetup
import shapely
import numpy as np


class CrosswalkStatus(Enum):
    """Phases of the crosswalk state machine."""

    INACTIVE = "inactive"
    ACTIVE = "active"
    COOLDOWN = "cooldown"


@dataclass(frozen=True)
class CrosswalkConfig:
    """Immutable parameters and geometry for one crosswalk.

    Time values are compared directly against the simulation clock advanced
    in ``CrosswalkController.update`` (presumably seconds -- confirm with the
    caller).
    """

    # Not read anywhere in CrosswalkController; presumably used by whoever
    # builds trigger_area -- TODO confirm.
    trigger_dist: float
    # How long the crosswalk stays ACTIVE once triggered.
    activation_time: float
    # How long the crosswalk stays in COOLDOWN before returning to INACTIVE.
    cooldown_time: float
    # Minimum number of agents inside trigger_area needed to activate.
    min_agents_activation: int
    # Minimum time spent INACTIVE before an activation may happen.
    activation_delay: float
    crosswalk_area: Polygon
    # Region whose occupants count toward activation.
    trigger_area: Polygon


@dataclass
class CrosswalkState:
    """Snapshot of a controller returned by ``CrosswalkController.update``."""

    status: CrosswalkStatus
    is_active: bool
    # Time spent in the current ACTIVE phase (0.0 when not active).
    time_active: float
    # Time left in the current ACTIVE or COOLDOWN phase (0.0 when INACTIVE).
    time_remaining: float
    agents_in_trigger: int
    agents_waiting: int
    activation_count: int


class CrosswalkController:
    """State machine driving a pedestrian-triggered crosswalk.

    The controller cycles INACTIVE -> ACTIVE -> COOLDOWN -> INACTIVE. It
    activates once at least ``config.min_agents_activation`` agents are inside
    ``config.trigger_area`` and ``config.activation_delay`` has elapsed since
    it entered INACTIVE. It stays active for ``config.activation_time``, then
    cools down for ``config.cooldown_time``.
    """

    def __init__(self, config: CrosswalkConfig, current_setup: bool):
        """Create a controller in the INACTIVE state at time 0.

        Args:
            config: Crosswalk parameters and geometry.
            current_setup: Not used by this class; kept so existing callers
                continue to work.
        """
        self.config = config
        self.crosswalk_area = config.crosswalk_area
        self.trigger_area = config.trigger_area
        self.status = CrosswalkStatus.INACTIVE
        # Time at which the current status was entered; every phase timer is
        # measured relative to this.
        self.state_start_time = 0.0
        self.current_time = 0.0
        self.activation_count = 0
        # Agents currently inside the trigger area (rebuilt on every update).
        self.agents_in_trigger: Set[int] = set()
        # Agents that asked to cross (via can_agent_cross) while the crosswalk
        # was not active and they were inside the trigger area.
        self.agents_waiting: Set[int] = set()
        self.total_active_time = 0.0
        self.agents_served = 0

    def update(
        self,
        agent_pos: Dict[int, Tuple[float, float]],
        time_step: float,
        current_time: Optional[float] = None,
    ) -> CrosswalkState:
        """Advance the controller by one simulation step.

        Args:
            agent_pos: Mapping of agent id to (x, y) position.
            time_step: Duration of this step. It advances the internal clock
                when ``current_time`` is None, and is always what gets added
                to ``total_active_time``.
            current_time: Absolute simulation time. When given, it replaces
                the internal clock instead of adding ``time_step``.

        Returns:
            Snapshot of the controller after the step.
        """
        if current_time is not None:
            self.current_time = current_time
        else:
            self.current_time += time_step
        self.update_agents_in_trigger(agent_pos)
        self.update_state(time_step)
        self.update_statistics(time_step)
        return self.get_state()

    def update_agents_in_trigger(
        self, agent_pos: Dict[int, Tuple[float, float]]
    ) -> None:
        """Recompute ``agents_in_trigger`` from the given agent positions.

        The set is cleared and refilled in place.
        """
        self.agents_in_trigger.clear()
        if not agent_pos:
            return
        agent_ids = list(agent_pos.keys())
        # float64, not float32: float32 keeps only ~7 significant digits,
        # which is far coarser than the 1e-9 boundary tolerance used in
        # points_in_area. Agents standing on the trigger-area edge could
        # otherwise be rounded outside.
        positions = np.array(list(agent_pos.values()), dtype=np.float64)
        in_trigger = self.points_in_area(positions, self.trigger_area)
        self.agents_in_trigger.update(
            agent_id
            for agent_id, inside in zip(agent_ids, in_trigger)
            if inside
        )

    @staticmethod
    def points_in_area(points: np.ndarray, polygon: Polygon) -> np.ndarray:
        """Return a boolean mask telling which points lie in ``polygon``.

        Points on (or within 1e-9 of) the boundary count as inside. Plain
        ``contains`` excludes the boundary, so the polygon is first buffered
        by a tiny epsilon.

        Args:
            points: Array-like of shape (n, 2) holding x, y coordinates.
                Columns beyond the second are ignored.
            polygon: Area to test against.

        Returns:
            Boolean array of length n.
        """
        points = np.asarray(points, dtype=np.float64)
        if len(points) == 0:
            return np.array([], dtype=bool)
        buffered_area = polygon.buffer(1e-9)
        # Vectorized point-in-polygon test (Shapely 2.x). This is equivalent
        # to buffered_area.contains(Point(x, y)) for every row, without a
        # Python-level loop.
        return np.asarray(
            shapely.contains_xy(buffered_area, points[:, 0], points[:, 1]),
            dtype=bool,
        )

    def update_state(self, time_step: float) -> None:
        """Apply at most one state transition for the current time.

        ``time_step`` is not used here. Timing is derived from
        ``current_time - state_start_time``.
        """
        elapsed = self.current_time - self.state_start_time
        if self.status == CrosswalkStatus.ACTIVE:
            if elapsed >= self.config.activation_time:
                self.deactivate()
        elif self.status == CrosswalkStatus.COOLDOWN:
            if elapsed >= self.config.cooldown_time:
                self.status = CrosswalkStatus.INACTIVE
                self.state_start_time = self.current_time
        elif self.status == CrosswalkStatus.INACTIVE:
            # NOTE(review): activation_delay is measured from when the
            # controller entered INACTIVE, not from when agents first reached
            # the trigger area. After a long idle period, activation therefore
            # happens immediately. Confirm this is the intended meaning.
            enough_agents = (
                len(self.agents_in_trigger) >= self.config.min_agents_activation
            )
            if enough_agents and elapsed >= self.config.activation_delay:
                self.activate()

    def activate(self) -> None:
        """Enter ACTIVE, counting every waiting agent as served."""
        self.status = CrosswalkStatus.ACTIVE
        self.state_start_time = self.current_time
        self.activation_count += 1
        self.agents_served += len(self.agents_waiting)
        self.agents_waiting.clear()

    def deactivate(self) -> None:
        """Enter COOLDOWN starting at the current time."""
        self.status = CrosswalkStatus.COOLDOWN
        self.state_start_time = self.current_time

    def update_statistics(self, time_step: float) -> None:
        """Add ``time_step`` to ``total_active_time`` while ACTIVE.

        NOTE(review): when ``update`` is given an explicit ``current_time``,
        this still sums ``time_step``. The total is only accurate if callers
        keep the two consistent.
        """
        if self.status == CrosswalkStatus.ACTIVE:
            self.total_active_time += time_step

    def get_state(self) -> CrosswalkState:
        """Build a ``CrosswalkState`` snapshot of the current controller."""
        elapsed = self.current_time - self.state_start_time
        if self.status == CrosswalkStatus.ACTIVE:
            time_remaining = max(0.0, self.config.activation_time - elapsed)
        elif self.status == CrosswalkStatus.COOLDOWN:
            time_remaining = max(0.0, self.config.cooldown_time - elapsed)
        else:
            time_remaining = 0.0
        is_active = self.status == CrosswalkStatus.ACTIVE
        return CrosswalkState(
            status=self.status,
            is_active=is_active,
            time_active=elapsed if is_active else 0.0,
            time_remaining=time_remaining,
            agents_in_trigger=len(self.agents_in_trigger),
            agents_waiting=len(self.agents_waiting),
            activation_count=self.activation_count,
        )

    def can_agent_cross(self, agent_id: int) -> bool:
        """Return True if ``agent_id`` may cross now.

        Side effect: when the crosswalk is not active and the agent is inside
        the trigger area, the agent is recorded in ``agents_waiting``.
        """
        if self.status == CrosswalkStatus.ACTIVE:
            return True
        if agent_id in self.agents_in_trigger:
            self.agents_waiting.add(agent_id)
        return False

    @property
    def is_active(self) -> bool:
        """Whether the crosswalk is currently in the ACTIVE phase."""
        return self.status == CrosswalkStatus.ACTIVE

    def get_statistics(self) -> Dict[str, float]:
        """Return cumulative counters and current occupancy figures."""
        return {
            "total_activations": self.activation_count,
            "total_active_time": self.total_active_time,
            "agents_served": self.agents_served,
            "current_agents_in_trigger": len(self.agents_in_trigger),
            "current_agents_waiting": len(self.agents_waiting),
        }