Coverage for / dolfinx-env / lib / python3.12 / site-packages / io4dolfinx / comm_helpers.py: 100%
110 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-26 18:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-26 18:16 +0000
1from __future__ import annotations
3from mpi4py import MPI
5import numpy as np
6import numpy.typing as npt
8from .utils import compute_insert_position, compute_local_range, valid_function_types
# Public API of this module: the three communication helpers and the
# NumPy-to-MPI datatype lookup table.
__all__ = [
    "send_dofmap_and_recv_values",
    "send_and_recv_cell_perm",
    "send_dofs_and_recv_values",
    "numpy_to_mpi",
]
# NOTE(review): the string below comes after the imports and `__all__`, so it
# is a bare expression statement and not the module docstring. Move it to the
# top of the file if it is meant to populate `__doc__`.
"""
Helpers for sending and receiving values for checkpointing
"""
# Lookup table from a NumPy scalar type to the MPI datatype handle used when
# an array of that type is passed to an MPI call. The keys are scalar type
# classes, so index it with `array.dtype.type` (as the helpers in this module
# do), not with the `np.dtype` instance itself.
numpy_to_mpi = {
    np.float64: MPI.DOUBLE,
    np.float32: MPI.FLOAT,
    np.complex64: MPI.COMPLEX,
    np.complex128: MPI.DOUBLE_COMPLEX,
    np.int64: MPI.INT64_T,
    np.int32: MPI.INT32_T,
}
def send_dofmap_and_recv_values(
    comm: MPI.Intracomm,
    source_ranks: npt.NDArray[np.int32],
    dest_ranks: npt.NDArray[np.int32],
    output_owners: npt.NDArray[np.int32],
    dest_size: npt.NDArray[np.int32],
    input_cells: npt.NDArray[np.int64],
    dofmap_pos: npt.NDArray[np.int32],
    num_cells_global: np.int64,
    values: npt.NDArray[valid_function_types],
    dofmap_offsets: npt.NDArray[np.int32],
) -> npt.NDArray[valid_function_types]:
    """
    Retrieve the values of a set of dofmap entries from the processes holding them.

    Each requested entry is identified by an input cell (global index) and a
    position within that cell's dofmap. The (cell, position) pairs are sent to
    the processes listed in `output_owners`; each receiving process looks the
    entry up in its local `values` and sends the value back.

    Args:
        comm: The MPI communicator to create the Neighbourhood-communicator from
        source_ranks: Ranks that will send dofmap indices to current process
        dest_ranks: Ranks that will receive dofmap indices from current process
        output_owners: The owners of each dofmap entry on this process. The unique set of
            these entries should be the same as the dest_ranks.
        dest_size: The number of entries sent to each owner
        input_cells: A cell associated with the degree of freedom sent (global index).
        dofmap_pos: The local position in the dofmap. I.e.
            `dof = dofmap.links(input_cells)[dofmap_pos]`
        num_cells_global: Number of global cells
        values: Values currently held by this process. These are
            ordered (num_cells_local, num_dofs_per_cell), flattened row-major.
        dofmap_offsets: Local dofmap offsets to access the correct `values`.

    Returns:
        The requested values, ordered like `input_cells` / `dofmap_pos`
        (entry ``i`` is the value for ``(input_cells[i], dofmap_pos[i])``).
    """
    insert_position = compute_insert_position(output_owners, dest_ranks, dest_size)

    # Pack the cells and dofmap position for all dofs this process is distributing
    out_cells = np.zeros(len(output_owners), dtype=np.int64)
    out_cells[insert_position] = input_cells
    out_pos = np.zeros(len(output_owners), dtype=np.int32)
    out_pos[insert_position] = dofmap_pos

    # Compute map from the data index sent to each process and the local
    # number on the current process (used to undo the packing at the end)
    proc_to_dof = np.zeros_like(input_cells, dtype=np.int32)
    proc_to_dof[insert_position] = np.arange(len(input_cells), dtype=np.int32)
    del insert_position

    # Send sizes to create data structures for receiving from NeighAlltoAllv
    recv_size = np.zeros(len(source_ranks), dtype=np.int32)
    mesh_to_data_comm = comm.Create_dist_graph_adjacent(
        source_ranks.tolist(), dest_ranks.tolist(), reorder=False
    )
    mesh_to_data_comm.Neighbor_alltoall(dest_size, recv_size)

    # Prepare data-structures for receiving. The counts are int32; accumulate
    # them in int64 with NumPy rather than with the builtin `sum`, which loops
    # in Python and keeps the narrow integer type of the array elements.
    total_incoming = int(np.sum(recv_size, dtype=np.int64))
    inc_cells = np.zeros(total_incoming, dtype=np.int64)
    inc_pos = np.zeros(total_incoming, dtype=np.intc)

    # Send data
    s_msg = [out_cells, dest_size, MPI.INT64_T]
    r_msg = [inc_cells, recv_size, MPI.INT64_T]
    mesh_to_data_comm.Neighbor_alltoallv(s_msg, r_msg)

    s_msg = [out_pos, dest_size, MPI.INT32_T]
    r_msg = [inc_pos, recv_size, MPI.INT32_T]
    mesh_to_data_comm.Neighbor_alltoallv(s_msg, r_msg)
    mesh_to_data_comm.Free()

    local_input_range = compute_local_range(comm, num_cells_global)

    # Map values based on input cells and dofmap: shift the received global
    # cell indices by the first entry of the local range to index
    # `dofmap_offsets`, then add the position within the cell.
    local_cells = inc_cells - local_input_range[0]
    values_to_distribute = values[dofmap_offsets[local_cells] + inc_pos]

    # Send input dofs back to owning process (same graph, edges reversed)
    data_to_mesh_comm = comm.Create_dist_graph_adjacent(
        dest_ranks.tolist(), source_ranks.tolist(), reorder=False
    )

    total_outgoing = int(np.sum(dest_size, dtype=np.int64))
    incoming_global_dofs = np.zeros(total_outgoing, dtype=values.dtype)
    s_msg = [values_to_distribute, recv_size, numpy_to_mpi[values.dtype.type]]
    r_msg = [incoming_global_dofs, dest_size, numpy_to_mpi[values.dtype.type]]
    data_to_mesh_comm.Neighbor_alltoallv(s_msg, r_msg)

    # Sort incoming global dofs as they were inputted
    assert len(incoming_global_dofs) == len(input_cells)
    sorted_global_dofs = np.zeros_like(incoming_global_dofs, dtype=values.dtype)
    sorted_global_dofs[proc_to_dof] = incoming_global_dofs

    data_to_mesh_comm.Free()
    return sorted_global_dofs
def send_and_recv_cell_perm(
    cells: npt.NDArray[np.int64],
    perms: npt.NDArray[np.uint32],
    cell_owners: npt.NDArray[np.int32],
    comm: MPI.Intracomm,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.uint32]]:
    """
    Send global cell index and permutation to the rank given in `cell_owners`.

    Args:
        cells: The global input index of the cell
        perms: The corresponding cell permutation of the cell
        cell_owners: The rank to send the i-th entry of cells and perms to
        comm: The MPI communicator to generate the neighbourhood communicator from

    Returns:
        A tuple ``(inc_cells, inc_perm)`` with the global cell indices and the
        matching permutations received by this process.
    """
    dest_ranks, _dest_size = np.unique(cell_owners, return_counts=True)
    dest_size = _dest_size.astype(np.int32)
    del _dest_size

    # Only the outgoing edges are known locally; the graph constructor works
    # out which ranks send to this process.
    mesh_to_data = comm.Create_dist_graph(
        [comm.rank], [len(dest_ranks)], dest_ranks.tolist(), reorder=False
    )
    source, dest, _ = mesh_to_data.Get_dist_neighbors()
    assert np.allclose(dest, dest_ranks)
    insert_position = compute_insert_position(cell_owners, dest_ranks.astype(np.int32), dest_size)

    # Pack cells and permutations for sending
    out_cells = np.zeros_like(cells, dtype=np.int64)
    out_perm = np.zeros_like(perms, dtype=np.uint32)
    out_cells[insert_position] = cells
    out_perm[insert_position] = perms
    del insert_position

    # Send sizes to create data structures for receiving from NeighAlltoAllv
    recv_size = np.zeros_like(source, dtype=np.int32)
    mesh_to_data.Neighbor_alltoall(dest_size, recv_size)

    # Prepare data-structures for receiving. Accumulate the int32 counts in
    # int64 with NumPy instead of the builtin `sum` (Python-level loop that
    # keeps the narrow element type).
    total_incoming = int(np.sum(recv_size, dtype=np.int64))
    inc_cells = np.zeros(total_incoming, dtype=np.int64)
    inc_perm = np.zeros(total_incoming, dtype=np.uint32)

    # Send data
    s_msg = [out_cells, dest_size, MPI.INT64_T]
    r_msg = [inc_cells, recv_size, MPI.INT64_T]
    mesh_to_data.Neighbor_alltoallv(s_msg, r_msg)

    s_msg = [out_perm, dest_size, MPI.UINT32_T]
    r_msg = [inc_perm, recv_size, MPI.UINT32_T]
    mesh_to_data.Neighbor_alltoallv(s_msg, r_msg)
    mesh_to_data.Free()
    return inc_cells, inc_perm
def send_dofs_and_recv_values(
    input_dofmap: npt.NDArray[np.int64],
    dofmap_owners: npt.NDArray[np.int32],
    comm: MPI.Intracomm,
    input_array: npt.NDArray[valid_function_types],
    array_start: int,
) -> npt.NDArray[valid_function_types]:
    """
    Send a set of dofs (global index) to the process holding the DOF values to retrieve them.

    Args:
        input_dofmap: List of dofs (global index) that this process wants values for
        dofmap_owners: The process currently holding the values this process want to get.
        comm: MPI communicator
        input_array: Values for dofs
        array_start: The global starting index of `input_array`.

    Returns:
        The values for the requested dofs, ordered like `input_dofmap`
        (entry ``i`` is the value of dof ``input_dofmap[i]``).
    """
    dest_ranks, _dest_size = np.unique(dofmap_owners, return_counts=True)
    dest_size = _dest_size.astype(np.int32)
    del _dest_size

    dofmap_to_values = comm.Create_dist_graph(
        [comm.rank], [len(dest_ranks)], dest_ranks.tolist(), reorder=False
    )

    source, dest, _ = dofmap_to_values.Get_dist_neighbors()
    assert np.allclose(dest_ranks, dest)

    # Compute where each dof goes in the send buffer (grouped by destination)
    insert_position = compute_insert_position(dofmap_owners, dest_ranks, dest_size)

    # Pack dofs for sending
    out_dofs = np.zeros(len(dofmap_owners), dtype=np.int64)
    out_dofs[insert_position] = input_dofmap

    # Compute map from the data index sent to each process and the local number on
    # the current process
    proc_to_local = np.zeros_like(input_dofmap, dtype=np.int32)
    proc_to_local[insert_position] = np.arange(len(input_dofmap), dtype=np.int32)
    del insert_position

    # Send sizes to create data structures for receiving from NeighAlltoAllv.
    # The count buffers are temporarily padded to at least one entry and then
    # shrunk back; the in-place `resize` calls must stay in this order around
    # the collective.
    recv_size = np.zeros_like(source, dtype=np.int32)
    recv_size.resize(max(len(recv_size), 1))  # Minimal resize to work with ompi
    dest_size.resize(max(len(dest_size), 1))  # Minimal resize to work with ompi
    dofmap_to_values.Neighbor_alltoall(dest_size, recv_size)
    dest_size.resize(len(dest))
    recv_size.resize(len(source))

    # Send input dofs to processes holding input array. Accumulate the int32
    # counts in int64 with NumPy instead of the builtin `sum` (Python-level
    # loop that keeps the narrow element type).
    inc_dofs = np.zeros(int(np.sum(recv_size, dtype=np.int64)), dtype=np.int64)
    s_msg = [out_dofs, dest_size, MPI.INT64_T]
    r_msg = [inc_dofs, recv_size, MPI.INT64_T]
    dofmap_to_values.Neighbor_alltoallv(s_msg, r_msg)
    dofmap_to_values.Free()

    # Send back appropriate input values; requested global indices are shifted
    # by `array_start` to index the locally held `input_array`.
    if len(input_array) > 0:
        sending_values = input_array[inc_dofs - array_start]
    else:
        sending_values = np.zeros(0, dtype=input_array.dtype)

    # Reverse graph: values travel back along the edges the requests came in on
    values_to_dofmap = comm.Create_dist_graph_adjacent(dest, source, reorder=False)
    inc_values = np.zeros_like(out_dofs, dtype=input_array.dtype)
    s_msg_rev = [sending_values, recv_size, numpy_to_mpi[input_array.dtype.type]]
    r_msg_rev = [inc_values, dest_size, numpy_to_mpi[input_array.dtype.type]]
    values_to_dofmap.Neighbor_alltoallv(s_msg_rev, r_msg_rev)
    values_to_dofmap.Free()

    # Sort inputs according to local dof number (input process)
    values = np.empty_like(inc_values, dtype=input_array.dtype)
    values[proc_to_local] = inc_values
    return values