From 9b8bc021d63bc80687a3b724af5ec265881f05a3 Mon Sep 17 00:00:00 2001
From: Weihang Zheng
Date: Tue, 28 Nov 2023 14:44:16 -0500
Subject: [PATCH] add cycle detection and cycle prevention

---
 src/nepiada/hist_baseline.py | 234 +++++++++++++++++++++++++++++++++++
 src/nepiada/utils/agent.py   |   3 +
 src/nepiada/utils/config.py  |   4 +
 3 files changed, 241 insertions(+)
 create mode 100644 src/nepiada/hist_baseline.py

diff --git a/src/nepiada/hist_baseline.py b/src/nepiada/hist_baseline.py
new file mode 100644
index 0000000..e98c7fe
--- /dev/null
+++ b/src/nepiada/hist_baseline.py
@@ -0,0 +1,234 @@
+# Local imports
+import numpy as np
+import env.nepiada as nepiada
+from utils.config import BaselineConfig
+import pygame
+
+
+def calculate_cost(agent_name, target_neighbours, beliefs, grid_size):
+    """
+    Calculate the cost value for an agent given its beliefs.
+
+    Before calling this method, the agent's belief of itself should be
+    updated with its intended new position.
+
+    The cost mirrors the cost function in Diane and Prof. Pavel's paper,
+    modified for the discrete grid space: a global arrangement cost (mean
+    distance of all known agents from the grid centre) plus a
+    target-neighbour cost (deviation of each target neighbour from its
+    desired position relative to this agent).
+    """
+    arrangement_cost = 0
+    target_neighbor_cost = 0
+
+    # The global arrangement is centred on the middle of the grid.
+    target_x = grid_size / 2
+    target_y = grid_size / 2
+
+    length = len(beliefs)
+    # Calculate the global arrangement cost over agents with known positions
+    for _, agent_belief in beliefs.items():
+        if agent_belief is None:
+            length -= 1
+        else:
+            arrangement_cost += np.sqrt(
+                (agent_belief[0] - target_x) ** 2 + (agent_belief[1] - target_y) ** 2
+            )
+
+    # Guard against division by zero when no agent position is known
+    if length > 0:
+        arrangement_cost /= length
+
+    # Calculate the target neighbour cost
+    for curr_agent_name, target_relative_position in target_neighbours.items():
+        assert beliefs[agent_name] is not None
+
+        curr_agent_position = beliefs[curr_agent_name]
+        agent_position = beliefs[agent_name]
+
+        if curr_agent_position is None:
+            # Unknown neighbour position cannot contribute to the cost
+            continue
+        target_neighbor_cost += np.sqrt(
+            (curr_agent_position[0] - agent_position[0] - target_relative_position[0])
+            ** 2
+            + (curr_agent_position[1] - agent_position[1] - target_relative_position[1])
+            ** 2
+        )
+
+    # NOTE(review): both cost terms are currently weighted equally
+    return arrangement_cost + target_neighbor_cost
+
+
+def create_beliefs_with_obs(agent_name, observations, all_agents):
+    """
+    Create a new beliefs dict with the agent's observations filled in as the
+    groundtruth. If the agent has no observations of another agent,
+    the belief of that agent is set to None.
+    """
+    beliefs = {}
+
+    # Distinct loop variable avoids shadowing the agent_name parameter
+    for other_agent in all_agents:
+        if observations[other_agent]:
+            beliefs[other_agent] = observations[other_agent]
+        else:
+            beliefs[other_agent] = None
+
+    return beliefs
+
+
+def strip_extreme_values_and_update_beliefs(
+    D_value, incoming_messages, curr_beliefs, new_beliefs, agent_name, all_agents
+):
+    """
+    This function strips the extreme values from the incoming messages according
+    to the D value. It strips the D greater values compared to its current beliefs,
+    as well as the D lesser values compared to its current beliefs and updates the beliefs
+    with the average of the remaining communication messages. If no communication messages
+    are left for the remaining agents, the agent's new belief of the target agent's position
+    remains unchanged.
+    """
+    for current_agent in all_agents:
+        # Skip ourselves and agents we already observed directly
+        if current_agent == agent_name or new_beliefs[current_agent] is not None:
+            continue
+        if current_agent not in incoming_messages or incoming_messages[current_agent] is None:
+            # No incoming messages about this agent, keep previous state
+            new_beliefs[current_agent] = curr_beliefs[current_agent]
+            continue
+
+        # Gather the non-empty incoming messages about this agent's position
+        in_messages = []
+        for _, comm_message in incoming_messages[current_agent].items():
+            if comm_message is not None:
+                in_messages.append(comm_message)
+
+        if not in_messages:
+            # Every message slot was empty, keep previous state
+            new_beliefs[current_agent] = curr_beliefs[current_agent]
+            continue
+
+        if curr_beliefs[current_agent] is None:
+            # No prior estimate: average all incoming messages
+            x_pos_mean = sum(message[0] for message in in_messages) / len(in_messages)
+            y_pos_mean = sum(message[1] for message in in_messages) / len(in_messages)
+            new_beliefs[current_agent] = (x_pos_mean, y_pos_mean)
+            continue
+
+        x_pos_deviation = []
+        y_pos_deviation = []
+        for message in in_messages:
+            x_pos_deviation.append(message[0] - curr_beliefs[current_agent][0])
+            y_pos_deviation.append(message[1] - curr_beliefs[current_agent][1])
+
+        if len(x_pos_deviation) <= D_value * 2 or len(y_pos_deviation) <= D_value * 2:
+            # Not enough messages to strip
+            new_beliefs[current_agent] = curr_beliefs[current_agent]
+            continue
+
+        # Sort the deviations
+        x_pos_deviation.sort()
+        y_pos_deviation.sort()
+
+        # Remove the D lowest and D highest values
+        x_pos_deviation = x_pos_deviation[D_value:-D_value]
+        y_pos_deviation = y_pos_deviation[D_value:-D_value]
+
+        # Average the remaining values
+        x_pos_delta = sum(x_pos_deviation) / len(x_pos_deviation)
+        y_pos_delta = sum(y_pos_deviation) / len(y_pos_deviation)
+
+        # Update the beliefs
+        new_beliefs[current_agent] = (
+            curr_beliefs[current_agent][0] + x_pos_delta,
+            curr_beliefs[current_agent][1] + y_pos_delta,
+        )
+
+
+def is_observation_not_empty(observation):
+    """Return True if at least one agent position was observed."""
+    return any(value is not None for value in observation.values())
+
+
+def has_identical_observation(past_observations, current_observation):
+    """Return True if the current observation matches a past one (cycle)."""
+    return current_observation in past_observations
+
+
+def step(agent_name, agent_instance, observations, infos, env, config):
+    """
+    Compute the next action for an agent.
+
+    The current observation is recorded for cycle detection; if the same
+    observation was seen before, the agent stands still (action 0) to
+    break the cycle. Otherwise beliefs are rebuilt from observations,
+    refined with incoming messages, and the lowest-cost action is chosen.
+    """
+    # Store non-empty observations so repeated states (cycles) can be detected
+    if is_observation_not_empty(observations):
+        agent_instance.past_observations.append(observations)
+
+        # Cycle detected: stand still instead of looping forever
+        if has_identical_observation(agent_instance.past_observations[:-1], observations):
+            return 0
+
+    # Create new beliefs dict with observation information
+    new_beliefs = create_beliefs_with_obs(agent_name, observations, env.agents)
+
+    # If there are incoming messages, process them and update beliefs
+    if "incoming_messages" in infos:
+        strip_extreme_values_and_update_beliefs(
+            config.D,
+            infos["incoming_messages"],
+            agent_instance.beliefs,
+            new_beliefs,
+            agent_name,
+            env.agents,
+        )
+
+    # Calculate the cost for every possible action
+    action_costs = {}
+    for action in range(env.action_space(agent_name).n):
+        new_beliefs[agent_name] = (
+            agent_instance.p_pos[0] + config.possible_moves[action][0],
+            agent_instance.p_pos[1] + config.possible_moves[action][1],
+        )
+        action_costs[action] = calculate_cost(
+            agent_name, agent_instance.target_neighbour, new_beliefs, config.size
+        )
+
+    # Choose the action with the lowest cost and commit the resulting belief
+    min_action = min(action_costs, key=action_costs.get)
+    new_beliefs[agent_name] = (
+        agent_instance.p_pos[0] + config.possible_moves[min_action][0],
+        agent_instance.p_pos[1] + config.possible_moves[min_action][1],
+    )
+    agent_instance.beliefs = new_beliefs
+
+    return min_action
+
+
+def main():
+    """Run the baseline environment until all agents are done."""
+    env_config = BaselineConfig()
+
+    env = nepiada.parallel_env(config=env_config)
+    observations, infos = env.reset()
+
+    while env.agents:
+        actions = {}
+        for c_agent in env.agents:
+            curr_agent = infos[c_agent]["agent_instance"]
+            actions[c_agent] = step(
+                c_agent, curr_agent, observations[c_agent], infos[c_agent], env, env_config
+            )
+
+        observations, rewards, terminations, truncations, infos = env.step(actions)
+
+    env.close()
+    pygame.quit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/nepiada/utils/agent.py b/src/nepiada/utils/agent.py
index ab91298..c2be510 100644
--- a/src/nepiada/utils/agent.py
+++ b/src/nepiada/utils/agent.py
@@ -50,6 +50,9 @@ def __init__(self, type):
         # This dictionary stores the ideal distance from a drone's neighbour, based on relative_x and relative_y distance
         self.target_neighbour = {}
 
+        # stores past observations in an agent for cycle detection
+        self.past_observations = []
+
         print("Agent with uid " + str(self.uid) + " has been initialized")
 
     def set_target_neighbour(self, neighbour_name, distance):
diff --git a/src/nepiada/utils/config.py b/src/nepiada/utils/config.py
index a7d66fe..a886e28 100644
--- a/src/nepiada/utils/config.py
+++ b/src/nepiada/utils/config.py
@@ -50,6 +50,10 @@ class Config:
         2: (0, -1),
         3: (-1, 0),
         4: (1, 0),
+        5: (1, 1),
+        6: (1, -1),
+        7: (-1, 1),
+        8: (-1, -1),
     }
     empty_cell: int = -1
     global_reward_weight: int = 1