def validate_rank_size(self, total_elements: int, size: int, num_elements: int) -> bool:
    """
    Check whether `total_elements` can be evenly distributed over `size` ranks.

    Two criteria must hold:
      1. `total_elements` is evenly divisible by `size` (whole elements per rank).
      2. Each cross-section of `num_elements` elements is covered by an integer
         number of ranks, i.e. `num_elements` is divisible by the per-rank count.

    Parameters
    ----------
    total_elements : int
        Total number of elements in the dataset.
    size : int
        Number of MPI ranks available.
    num_elements : int
        Number of elements in each cross-section.

    Returns
    -------
    bool
        True if the distribution satisfies both criteria, False otherwise.
        Diagnostic messages are printed on rank 0 only.
    """
    is_root = self.rt.comm.rank == 0

    # Criterion 1: every rank must receive the same whole number of elements.
    elements_per_rank = total_elements // size
    if total_elements % size != 0:
        if is_root:
            print(f"Error: Criteria 1 Failed, {total_elements} % {size} != 0 (elements_per_rank = {elements_per_rank})")
        return False

    # Guard: a zero per-rank count (e.g. total_elements == 0) would divide by
    # zero in criterion 2; treat it as an invalid distribution instead.
    if elements_per_rank == 0:
        if is_root:
            print(f"Error: elements_per_rank is 0 ({total_elements} elements over {size} ranks)")
        return False

    # Criterion 2: each cross-section must map onto an integer number of ranks.
    ranks_per_section = num_elements // elements_per_rank
    if num_elements % elements_per_rank != 0:
        if is_root:
            print(f"Error: Criteria 2 Failed, {num_elements} % {elements_per_rank} != 0 (ranks_per_section = {ranks_per_section})")
        return False

    if is_root:
        print(f"Both Criteria Passed: elements_per_rank = {elements_per_rank}, ranks_per_section = {ranks_per_section}")
    return True

def find_valid_sizes(self, total_elements: int, num_elements: int, min_size=150, max_size=2000):
    """
    Find rank counts that distribute the elements evenly.

    A size is valid when `total_elements % size == 0` and the resulting
    per-rank element count evenly divides `num_elements` (see
    `validate_rank_size`).

    Parameters
    ----------
    total_elements : int
        Total number of elements.
    num_elements : int
        Number of elements per cross-section.
    min_size : int, optional
        Smallest rank count to consider. Defaults to 150.
    max_size : int, optional
        Largest rank count to consider. Defaults to 2000.

    Returns
    -------
    list
        Sorted list of valid rank counts (possibly empty).
    """
    valid_sizes = []

    # Start at 1 even if a non-positive min_size is passed, to avoid
    # modulo-by-zero and meaningless negative sizes.
    for size in range(max(1, min_size), max_size + 1):
        if total_elements % size != 0:
            continue
        elements_per_rank = total_elements // size
        if elements_per_rank and num_elements % elements_per_rank == 0:
            valid_sizes.append(size)

    if not valid_sizes:
        print("No valid sizes found. Consider adjusting your min/max limits.")
    else:
        print("Valid size options:", valid_sizes)

    return valid_sizes

def get_periodicity_map(self, msh: "Mesh" = None, offset_vector: Tuple[int, int, int] = (0, 0, 0), num_elements: int = None, pattern_factor: int = 1):
    """
    Build the periodicity mapping for the given mesh.

    Identifies entities (vertices, edges, facets) that are periodic partners
    across all MPI ranks using a ring-exchange scheme, and updates the global
    shared maps for rank, element, and vertex/edge/facet accordingly.

    Notation: one "pattern cross-section" is formed by repeating a specific
    group of cross-sections; the number of cross-sections required to complete
    one full pattern is the `pattern_factor`.

    Parameters
    ----------
    msh : Mesh
        Mesh object containing element, vertex, edge-center and facet-center
        data. Required.
    offset_vector : Tuple[int, int, int]
        Offset applied to coordinates to generate the periodically shifted
        positions used for matching.
    num_elements : int
        Number of elements in each cross-section. Required.
    pattern_factor : int
        Number of cross-sections forming one complete pattern; must be >= 2.

    Raises
    ------
    ValueError
        If `msh` or `num_elements` is not provided.

    Returns
    -------
    None
        Updates `self.global_shared_*` maps and sets `self.check` to the
        (broadcast) validation result.
    """
    # Fail fast with a clear message instead of an opaque TypeError /
    # AttributeError further down.
    if msh is None or num_elements is None:
        raise ValueError("get_periodicity_map requires both 'msh' and 'num_elements'")

    comm = self.rt.comm

    # ---- Validate the rank/element distribution (rank 0 decides, then bcast) ----
    self.total_elements = num_elements * pattern_factor  # total elements in the file
    if comm.rank == 0:
        valid_size = self.validate_rank_size(self.total_elements, comm.size, num_elements)
        valid_pattern = pattern_factor >= 2
        if not valid_pattern:
            print("Invalid pattern factor: must be greater than or equal to 2.")
        if not valid_size and valid_pattern:
            # Help the user pick a workable rank count.
            self.find_valid_sizes(self.total_elements, num_elements)
        self.check = valid_size and valid_pattern
    else:
        self.check = None

    self.check = comm.bcast(self.check, root=0)
    if not self.check:
        return

    # Neighbours in the circular (ring) topology.
    prev_rank = (comm.rank - 1 + comm.size) % comm.size
    next_rank = (comm.rank + 1) % comm.size

    def are_coords_close(c1, c2, rtol=1e-5):
        """Component-wise equality within an absolute tolerance."""
        return all(math.isclose(a, b, rel_tol=0, abs_tol=rtol) for a, b in zip(c1, c2))

    def update_map(target_map, key, value):
        """Append `value` to the integer array stored at `key`, creating it if absent."""
        if key in target_map:
            target_map[key] = np.append(target_map[key], value)
        else:
            target_map[key] = np.array([value], dtype=int)

    def process_incoming_matches(tag, rank_map, elem_map, submap, sub_key_name):
        """
        Drain all pending "update your maps" messages carrying `tag` and apply
        them to the local maps, so both sides of every match are recorded.

        NOTE(review): `iprobe` after a Barrier is not guaranteed by MPI to see
        messages sent (blocking send) before the barrier on another rank —
        delivery may still be in flight. Confirm against the target MPI
        implementation or switch to a counted/handshaked exchange.
        """
        status = MPI.Status()
        while comm.iprobe(source=MPI.ANY_SOURCE, tag=tag):
            incoming = comm.recv(source=MPI.ANY_SOURCE, tag=tag, status=status)
            sender = status.Get_source()
            # "remote_*" fields in the message name OUR entity (the sender
            # matched its shifted copy of it against its own local entity).
            key_local = (incoming["remote_elem"], incoming[f"remote_{sub_key_name}"])
            update_map(rank_map, key_local, sender)
            update_map(elem_map, key_local, incoming["local_elem"])
            update_map(submap, key_local, incoming[f"local_{sub_key_name}"])

    def ring_match(local_centers, ring_tag, msg_tag, rank_map, elem_map, submap, sub_key_name):
        """
        Run the full ring-exchange matching for one entity type.

        1. Shift the local coordinates by `offset_vector`.
        2. Pass the shifted bundle around the ring (size-1 hops) so every rank
           sees every other rank's shifted entities.
        3. On each hop, compare the received shifted coordinates against the
           local originals; record matches locally and queue a symmetric
           update message back to the bundle's source rank.
        4. After the ring completes, drain incoming update messages so both
           sides of each match agree.
        """
        shifted = {
            key: (x + offset_vector[0], y + offset_vector[1], z + offset_vector[2])
            for key, (x, y, z) in local_centers.items()
        }

        current_bundle = shifted  # start with our own data
        for hop in range(comm.size - 1):
            received = comm.sendrecv(
                current_bundle, dest=next_rank, sendtag=ring_tag,
                source=prev_rank, recvtag=ring_tag,
            )
            # Rank that originally produced the bundle we just received.
            source_rank = (comm.rank - hop - 1) % comm.size

            pending_msgs = []
            for (elem1, sub1), shifted_coord in received.items():  # remote (shifted)
                for (elem2, sub2), original_coord in local_centers.items():  # local (original)
                    if are_coords_close(shifted_coord, original_coord):
                        key_local = (elem2, sub2)
                        # Local record: "this entity corresponds to source_rank's entity".
                        update_map(rank_map, key_local, source_rank)
                        update_map(elem_map, key_local, elem1)
                        update_map(submap, key_local, sub1)
                        # Queue the symmetric update for the source rank.
                        pending_msgs.append({
                            "remote_elem": elem1,
                            f"remote_{sub_key_name}": sub1,
                            "local_elem": elem2,
                            f"local_{sub_key_name}": sub2,
                        })

            # Notify the source rank within this hop: source_rank and the
            # pending list are per-iteration quantities.
            for msg in pending_msgs:
                comm.send(msg, dest=source_rank, tag=msg_tag)

            current_bundle = received  # keep the bundle moving around the ring

        comm.Barrier()
        process_incoming_matches(msg_tag, rank_map, elem_map, submap, sub_key_name)
        comm.Barrier()

    # ---- Vertices (any dimension) ----
    if msh.gdim >= 1:
        # Only vertices on the "incomplete" side of the periodic boundary.
        local_vertices = {
            (elem, vertex): tuple(msh.vertices[elem, vertex])
            for elem, vertex in zip(self.incomplete_evp_elem, self.incomplete_evp_vertex)
        }
        ring_match(
            local_vertices, ring_tag=99, msg_tag=199,
            rank_map=self.global_shared_evp_to_rank_map,
            elem_map=self.global_shared_evp_to_elem_map,
            submap=self.global_shared_evp_to_vertex_map,
            sub_key_name="vertex",
        )

    # ---- Edges (2D and 3D) ----
    if msh.gdim >= 2:
        # Only edges on the "incomplete" side of the periodic boundary.
        local_edge_centers = {
            (elem, edge): tuple(msh.edge_centers[elem, edge])
            for elem, edge in zip(self.incomplete_eep_elem, self.incomplete_eep_edge)
        }
        ring_match(
            local_edge_centers, ring_tag=98, msg_tag=198,
            rank_map=self.global_shared_eep_to_rank_map,
            elem_map=self.global_shared_eep_to_elem_map,
            submap=self.global_shared_eep_to_edge_map,
            sub_key_name="edge",
        )

    # ---- Facets (3D) ----
    if msh.gdim >= 3:
        # Only facets on the "incomplete" side of the periodic boundary.
        local_facet_centers = {
            (elem, facet): tuple(msh.facet_centers[elem, facet])
            for elem, facet in zip(self.unique_efp_elem, self.unique_efp_facet)
        }
        ring_match(
            local_facet_centers, ring_tag=97, msg_tag=197,
            rank_map=self.global_shared_efp_to_rank_map,
            elem_map=self.global_shared_efp_to_elem_map,
            submap=self.global_shared_efp_to_facet_map,
            sub_key_name="facet",
        )