Coverage for /dolfinx-env/lib/python3.12/site-packages/io4dolfinx/original_checkpoint.py: 99%

189 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-08 09:09 +0000

1# Copyright (C) 2024 Jørgen Schartum Dokken 

2# 

3# This file is part of io4dolfinx 

4# 

5# SPDX-License-Identifier: MIT 

6 

7from __future__ import annotations 

8 

9import logging 

10import typing 

11from pathlib import Path 

12 

13from mpi4py import MPI 

14 

15import dolfinx 

16import numpy as np 

17 

18from . import compat 

19from .backends import FileMode, get_backend 

20from .comm_helpers import numpy_to_mpi 

21from .structures import FunctionData, MeshData 

22from .utils import ( 

23 compute_insert_position, 

24 compute_local_range, 

25 index_owner, 

26 unroll_dofmap, 

27 unroll_insert_position, 

28) 

29 

# Public API of this module: the two writers that checkpoint data in input ordering.
__all__ = [
    "write_function_on_input_mesh",
    "write_mesh_input_order",
]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

32 

33 

def create_original_mesh_data(mesh: dolfinx.mesh.Mesh) -> MeshData:
    """
    Store mesh data locally on the output process, in the original input ordering.

    Each locally owned cell is sent to the process that owns its input (original)
    cell index, as computed by ``index_owner``. Likewise, each locally owned
    geometry node is sent to the process that owns its input node index. On the
    receiving process the data is placed by input index, so every process ends up
    with the contiguous input ranges ``local_cell_range`` and ``local_node_range``.

    This is collective over ``mesh.comm``. It creates two distributed-graph
    (neighbourhood) communicators and runs neighbourhood all-to-all exchanges on
    them, so every rank must call it.

    Args:
        mesh: The distributed mesh to write.

    Returns:
        A :class:`MeshData` with the following contents:

        - ``local_topology``: the geometry dofmap of this process' input cells,
          expressed in input node indices and ordered by input cell index.
        - ``local_geometry``: the coordinates (first ``gdim`` components) of this
          process' input nodes, ordered by input node index.

        Partition information is not stored (``store_partition=False``).
    """

    # 1. Send cell indices owned by current process to the process which owned its input

    # Get the input cell index for cells owned by this process
    num_owned_cells = mesh.topology.index_map(mesh.topology.dim).size_local
    original_cell_index = mesh.topology.original_cell_index[:num_owned_cells]

    # Compute owner of cells on this process based on the original cell index
    num_cells_global = mesh.topology.index_map(mesh.topology.dim).size_global
    output_cell_owner = index_owner(mesh.comm, original_cell_index, num_cells_global)
    local_cell_range = compute_local_range(mesh.comm, num_cells_global)

    # Compute outgoing edges from current process to outputting process
    # Computes the number of cells sent to each process at the same time
    # (np.unique returns the destination ranks sorted, with their counts)
    cell_destinations, _send_cells_per_proc = np.unique(output_cell_owner, return_counts=True)
    send_cells_per_proc = _send_cells_per_proc.astype(np.int32)
    del _send_cells_per_proc
    cell_to_output_comm = mesh.comm.Create_dist_graph(
        [mesh.comm.rank],
        [len(cell_destinations)],
        cell_destinations.tolist(),
        reorder=False,
    )
    cell_sources, cell_dests, _ = cell_to_output_comm.Get_dist_neighbors()
    # Sanity check: the graph's out-neighbours are exactly the destinations we passed in
    assert np.allclose(cell_dests, cell_destinations)

    # Compute number of receiving cells
    recv_cells_per_proc = np.zeros_like(cell_sources, dtype=np.int32)
    # Pad empty count arrays to length one before the exchange (presumably a workaround
    # for ranks with no in- or out-neighbours; the padded entry is a zero count)
    if len(send_cells_per_proc) == 0:
        send_cells_per_proc = np.zeros(1, dtype=np.int32)
    if len(recv_cells_per_proc) == 0:
        recv_cells_per_proc = np.zeros(1, dtype=np.int32)
    send_cells_per_proc = send_cells_per_proc.astype(np.int32)
    cell_to_output_comm.Neighbor_alltoall(send_cells_per_proc, recv_cells_per_proc)
    # Every input cell in this rank's range must be received exactly once
    assert recv_cells_per_proc.sum() == local_cell_range[1] - local_cell_range[0]
    # Pack and send cell indices (used for mapping topology dofmap later)
    cell_insert_position = compute_insert_position(
        output_cell_owner, cell_destinations, send_cells_per_proc
    )
    send_cells = np.empty_like(cell_insert_position, dtype=np.int64)
    send_cells[cell_insert_position] = original_cell_index
    recv_cells = np.empty(recv_cells_per_proc.sum(), dtype=np.int64)
    send_cells_msg = [send_cells, send_cells_per_proc, MPI.INT64_T]
    recv_cells_msg = [recv_cells, recv_cells_per_proc, MPI.INT64_T]
    cell_to_output_comm.Neighbor_alltoallv(send_cells_msg, recv_cells_msg)
    del send_cells_msg, recv_cells_msg, send_cells

    # Map received cells to the local index (position within this rank's input cell range)
    local_cell_index = recv_cells - local_cell_range[0]

    # 2. Create dofmap based on original geometry indices and re-order in the same order as original
    # cell indices on output process

    # Get original node index for all nodes (including ghosts) and convert dofmap to these indices
    original_node_index = mesh.geometry.input_global_indices
    _, num_nodes_per_cell = compat.dofmap(mesh).shape
    local_geometry_dofmap = compat.dofmap(mesh)[:num_owned_cells, :]
    global_geometry_dofmap = original_node_index[local_geometry_dofmap.reshape(-1)]

    # Unroll insert position for geometry dofmap (one slot per node of each cell)
    dofmap_insert_position = unroll_insert_position(cell_insert_position, num_nodes_per_cell)

    # Create and communicate connectivity in original geometry indices
    send_geometry_dofmap = np.empty_like(dofmap_insert_position, dtype=np.int64)
    send_geometry_dofmap[dofmap_insert_position] = global_geometry_dofmap
    del global_geometry_dofmap
    send_sizes_dofmap = send_cells_per_proc * num_nodes_per_cell
    recv_sizes_dofmap = recv_cells_per_proc * num_nodes_per_cell
    recv_geometry_dofmap = np.empty(recv_sizes_dofmap.sum(), dtype=np.int64)
    send_geometry_dofmap_msg = [send_geometry_dofmap, send_sizes_dofmap, MPI.INT64_T]
    recv_geometry_dofmap_msg = [recv_geometry_dofmap, recv_sizes_dofmap, MPI.INT64_T]
    cell_to_output_comm.Neighbor_alltoallv(send_geometry_dofmap_msg, recv_geometry_dofmap_msg)
    del send_geometry_dofmap_msg, recv_geometry_dofmap_msg

    # Reshape dofmap and sort by original cell index
    recv_dofmap = recv_geometry_dofmap.reshape(-1, num_nodes_per_cell)
    sorted_recv_dofmap = np.empty_like(recv_dofmap)
    sorted_recv_dofmap[local_cell_index] = recv_dofmap

    # 3. Move geometry coordinates to input process
    # Compute outgoing edges from current process and create neighbourhood communicator
    # Also create number of outgoing nodes at the same time
    num_owned_nodes = mesh.geometry.index_map().size_local
    num_nodes_global = mesh.geometry.index_map().size_global
    output_node_owner = index_owner(
        mesh.comm, original_node_index[:num_owned_nodes], num_nodes_global
    )

    node_destinations, _send_nodes_per_proc = np.unique(output_node_owner, return_counts=True)
    send_nodes_per_proc = _send_nodes_per_proc.astype(np.int32)
    del _send_nodes_per_proc

    geometry_to_owner_comm = mesh.comm.Create_dist_graph(
        [mesh.comm.rank],
        [len(node_destinations)],
        node_destinations.tolist(),
        reorder=False,
    )

    node_sources, node_dests, _ = geometry_to_owner_comm.Get_dist_neighbors()
    assert np.allclose(node_dests, node_destinations)

    # Compute send node insert positions
    send_nodes_position = compute_insert_position(
        output_node_owner, node_destinations, send_nodes_per_proc
    )
    # Coordinates are sent as 3 components per node (assumes mesh.geometry.x has
    # 3 columns, consistent with the reshape(-1, 3) on the receiving side)
    unrolled_nodes_positiion = unroll_insert_position(send_nodes_position, 3)

    send_coordinates = np.empty_like(unrolled_nodes_positiion, dtype=mesh.geometry.x.dtype)
    send_coordinates[unrolled_nodes_positiion] = mesh.geometry.x[:num_owned_nodes, :].reshape(-1)

    # Send and receive geometry sizes
    # NOTE(review): unlike the cell exchange above, these count arrays are not padded
    # to length one when empty; confirm whether ranks owning no nodes need that too.
    send_coordinate_sizes = (send_nodes_per_proc * 3).astype(np.int32)
    recv_coordinate_sizes = np.zeros_like(node_sources, dtype=np.int32)
    geometry_to_owner_comm.Neighbor_alltoall(send_coordinate_sizes, recv_coordinate_sizes)

    # Send node coordinates
    recv_coordinates = np.empty(recv_coordinate_sizes.sum(), dtype=mesh.geometry.x.dtype)
    mpi_type = numpy_to_mpi[recv_coordinates.dtype.type]
    send_coord_msg = [send_coordinates, send_coordinate_sizes, mpi_type]
    recv_coord_msg = [recv_coordinates, recv_coordinate_sizes, mpi_type]
    geometry_to_owner_comm.Neighbor_alltoallv(send_coord_msg, recv_coord_msg)
    del send_coord_msg, recv_coord_msg

    # Send node ordering for reordering the coordinates on output process
    send_nodes = np.empty(num_owned_nodes, dtype=np.int64)
    send_nodes[send_nodes_position] = original_node_index[:num_owned_nodes]

    # One input node index per received coordinate triple
    recv_indices = np.empty(recv_coordinate_sizes.sum() // 3, dtype=np.int64)
    send_nodes_msg = [send_nodes, send_nodes_per_proc, MPI.INT64_T]
    recv_nodes_msg = [recv_indices, recv_coordinate_sizes // 3, MPI.INT64_T]
    geometry_to_owner_comm.Neighbor_alltoallv(send_nodes_msg, recv_nodes_msg)

    # Compute local ordering of received nodes
    local_node_range = compute_local_range(mesh.comm, num_nodes_global)
    recv_indices -= local_node_range[0]

    # Sort geometry based on input index and strip to gdim
    gdim = mesh.geometry.dim
    recv_nodes = recv_coordinates.reshape(-1, 3)
    _geometry = np.empty(recv_nodes.shape, dtype=mesh.geometry.x.dtype)
    _geometry[recv_indices, :] = recv_nodes
    geometry = _geometry[:, :gdim].copy()
    del _geometry, recv_nodes

    # Every input node in this rank's range must have been received exactly once
    assert local_node_range[1] - local_node_range[0] == geometry.shape[0]
    cmap = compat.cmap(mesh)

    # Release both neighbourhood communicators (collective; same order on all ranks)
    cell_to_output_comm.Free()
    geometry_to_owner_comm.Free()

    # NOTE: Could in theory store partitioning information, but would not work nicely
    # as one would need to read this data rather than the xdmffile.
    # NOTE: Local geometry type hint skip is only required on DOLFINX<0.10 where
    # proper `dolfinx.mesh.Geometry` wrapper doesn't exist
    return MeshData(
        local_geometry=geometry,  # type: ignore[arg-type]
        local_geometry_pos=local_node_range,
        num_nodes_global=num_nodes_global,
        local_topology=sorted_recv_dofmap,
        local_topology_pos=local_cell_range,
        num_cells_global=num_cells_global,
        cell_type=mesh.topology.cell_name(),
        degree=cmap.degree,
        lagrange_variant=cmap.variant,
        store_partition=False,
        partition_processes=None,
        ownership_array=None,
        ownership_offset=None,
        partition_range=None,
        partition_global=None,
    )

210 

211 

def create_function_data_on_original_mesh(
    u: dolfinx.fem.Function, name: typing.Optional[str] = None
) -> FunctionData:
    """
    Create the data object needed to checkpoint ``u`` in the input-mesh cell ordering.

    The dofmap of every locally owned cell is sent, as global dof indices, to the
    process that owns the cell's original (input) index. On that process it is
    reordered by input cell index, and the cell permutation info is moved the same
    way. The function values are not redistributed: each process keeps its owned
    block of ``u.x.array`` together with the matching DOLFINx global dof range.

    This is collective over ``u.function_space.mesh.comm``; every rank must call it.

    Args:
        u: Function to checkpoint.
        name: Name stored with the data. If ``None``, ``u.name`` is used.

    Returns:
        A :class:`FunctionData` for this process' range of input cells. It holds:

        - the cell permutation info;
        - the flattened global dofmap and its offsets;
        - the locally owned dof values and their global range.
    """
    mesh = u.function_space.mesh

    # 1. Send cells owned by the current process to the process owning their input index.
    # FIXME: Cache this (it mirrors the first stage of create_original_mesh_data)
    num_owned_cells = mesh.topology.index_map(mesh.topology.dim).size_local
    original_cell_index = mesh.topology.original_cell_index[:num_owned_cells]

    # Compute owner of cells on this process based on the original cell index
    num_cells_global = mesh.topology.index_map(mesh.topology.dim).size_global
    output_cell_owner = index_owner(mesh.comm, original_cell_index, num_cells_global)
    local_cell_range = compute_local_range(mesh.comm, num_cells_global)
    num_cells_local = local_cell_range[1] - local_cell_range[0]

    # Outgoing edges (sorted destination ranks) and the number of cells sent to each
    cell_destinations, _send_cells_per_proc = np.unique(output_cell_owner, return_counts=True)
    send_cells_per_proc = _send_cells_per_proc.astype(np.int32)
    del _send_cells_per_proc
    cell_to_output_comm = mesh.comm.Create_dist_graph(
        [mesh.comm.rank],
        [len(cell_destinations)],
        cell_destinations.tolist(),
        reorder=False,
    )
    cell_sources, cell_dests, _ = cell_to_output_comm.Get_dist_neighbors()
    assert np.allclose(cell_dests, cell_destinations)

    # Exchange how many cells each neighbour receives.
    # NOTE(review): create_original_mesh_data pads empty count arrays to length one
    # before this exchange; this function does not. Confirm whether ranks owning zero
    # cells need the same treatment here.
    recv_cells_per_proc = np.zeros_like(cell_sources, dtype=np.int32)
    cell_to_output_comm.Neighbor_alltoall(send_cells_per_proc, recv_cells_per_proc)
    # Every input cell in this rank's range must be received exactly once
    assert recv_cells_per_proc.sum() == num_cells_local

    # Pack and send original cell indices, grouped by destination rank
    cell_insert_position = compute_insert_position(
        output_cell_owner, cell_destinations, send_cells_per_proc
    )
    send_cells = np.empty_like(cell_insert_position, dtype=np.int64)
    send_cells[cell_insert_position] = original_cell_index
    recv_cells = np.empty(recv_cells_per_proc.sum(), dtype=np.int64)
    send_cells_msg = [send_cells, send_cells_per_proc, MPI.INT64_T]
    recv_cells_msg = [recv_cells, recv_cells_per_proc, MPI.INT64_T]
    cell_to_output_comm.Neighbor_alltoallv(send_cells_msg, recv_cells_msg)
    del send_cells_msg, recv_cells_msg

    # Position of each received cell within this rank's input cell range
    local_cell_index = recv_cells - local_cell_range[0]

    # Pack and send cell permutation info, then order it by input cell index
    mesh.topology.create_entity_permutations()
    cell_permutation_info = mesh.topology.get_cell_permutation_info()[:num_owned_cells]
    send_perm = np.empty_like(send_cells, dtype=np.uint32)
    send_perm[cell_insert_position] = cell_permutation_info
    recv_perm = np.empty_like(recv_cells, dtype=np.uint32)
    send_perm_msg = [send_perm, send_cells_per_proc, MPI.UINT32_T]
    recv_perm_msg = [recv_perm, recv_cells_per_proc, MPI.UINT32_T]
    cell_to_output_comm.Neighbor_alltoallv(send_perm_msg, recv_perm_msg)
    cell_permutation_info = np.empty_like(recv_perm)
    cell_permutation_info[local_cell_index] = recv_perm

    # 2. Extract function data (array is the same, keeping global indices from DOLFINx).
    # The dofmap is moved by the original cell index, like the mesh geometry dofmap.
    dofmap = u.function_space.dofmap
    dofmap_bs = dofmap.bs
    index_map_bs = dofmap.index_map_bs

    # Unroll dofmap for block size. Each unrolled dof is then split into its
    # index-map entry and its component within the index-map block.
    unrolled_dofmap = unroll_dofmap(dofmap.list[:num_owned_cells, :], dofmap_bs)
    dmap_loc = (unrolled_dofmap // index_map_bs).reshape(-1)
    dmap_rem = (unrolled_dofmap % index_map_bs).reshape(-1)

    # Convert imap index to global index
    imap_global = dofmap.index_map.local_to_global(dmap_loc)
    dofmap_global = (imap_global * index_map_bs + dmap_rem).reshape(unrolled_dofmap.shape)
    num_dofs_per_cell = dofmap_global.shape[1]
    dofmap_insert_position = unroll_insert_position(cell_insert_position, num_dofs_per_cell)

    # Create and send array for global dofmap
    send_function_dofmap = np.empty(len(dofmap_insert_position), dtype=np.int64)
    send_function_dofmap[dofmap_insert_position] = dofmap_global.reshape(-1)
    send_sizes_dofmap = send_cells_per_proc * num_dofs_per_cell
    recv_size_dofmap = recv_cells_per_proc * num_dofs_per_cell
    recv_function_dofmap = np.empty(recv_size_dofmap.sum(), dtype=np.int64)
    cell_to_output_comm.Neighbor_alltoallv(
        [send_function_dofmap, send_sizes_dofmap, MPI.INT64_T],
        [recv_function_dofmap, recv_size_dofmap, MPI.INT64_T],
    )

    # Order the received per-cell dofs by input cell index. No defensive copy is
    # needed: the fancy-index assignment copies into a fresh array.
    shaped_dofmap = recv_function_dofmap.reshape(num_cells_local, num_dofs_per_cell)
    _final_dofmap = np.empty_like(shaped_dofmap)
    _final_dofmap[local_cell_index] = shaped_dofmap
    final_dofmap = _final_dofmap.reshape(-1)

    # Offsets of each cell's dofs in the globally concatenated dofmap. The IndexMap
    # is only used to obtain this rank's starting offset (local_range[0]).
    num_dofs_local_dmap = num_cells_local * num_dofs_per_cell
    dofmap_imap = dolfinx.common.IndexMap(mesh.comm, num_dofs_local_dmap)
    local_dofmap_offsets = np.arange(num_cells_local + 1, dtype=np.int64)
    local_dofmap_offsets[:] *= num_dofs_per_cell
    local_dofmap_offsets[:] += dofmap_imap.local_range[0]

    # Values stay in DOLFINx ownership: keep this rank's owned (unrolled) block
    num_dofs_local = dofmap.index_map.size_local * index_map_bs
    num_dofs_global = dofmap.index_map.size_global * index_map_bs
    local_range = np.asarray(dofmap.index_map.local_range, dtype=np.int64) * index_map_bs
    func_name = name if name is not None else u.name
    cell_to_output_comm.Free()
    return FunctionData(
        cell_permutations=cell_permutation_info,
        local_cell_range=local_cell_range,
        num_cells_global=num_cells_global,
        dofmap_array=final_dofmap,
        dofmap_offsets=local_dofmap_offsets,
        values=u.x.array[:num_dofs_local].copy(),
        dof_range=local_range,
        num_dofs_global=num_dofs_global,
        dofmap_range=dofmap_imap.local_range,
        global_dofs_in_dofmap=dofmap_imap.size_global,
        name=func_name,
    )

340 

341 

def write_function_on_input_mesh(
    filename: Path | str,
    u: dolfinx.fem.Function,
    time: float = 0.0,
    name: typing.Optional[str] = None,
    mode: FileMode = FileMode.append,
    backend_args: dict[str, typing.Any] | None = None,
    backend: str = "adios2",
):
    """
    Write function checkpoint (to be read with the input mesh).

    The function's dofmap is redistributed to the input (original) cell ordering
    before writing. This is collective over the mesh communicator.

    Note:
        Requires backend to implement {py:class}`io4dolfinx.backends.write_function`.

    Args:
        filename: The filename to write to
        u: The function to checkpoint
        time: Time-stamp associated with function at current write step
        name: Name of function. If None, the name of the function is used.
        mode: The mode to use (write or append)
        backend_args: Arguments to backend. Passed through the backend's
            ``get_default_backend_args`` before use.
        backend: Choice of backend module
    """
    # Resolve the name exactly as create_function_data_on_original_mesh does, so an
    # explicit empty-string name is logged as-is instead of falling back to u.name.
    func_name = name if name is not None else u.name
    logger.debug(
        f"Writing function on input mesh to {filename} at time {time} with name {func_name}"
    )
    logger.debug(f"Using backend {backend} with arguments {backend_args} and mode {mode}")
    mesh = u.function_space.mesh
    function_data = create_function_data_on_original_mesh(u, name)
    fname = Path(filename)

    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    backend_cls.write_function(
        fname,
        mesh.comm,
        function_data,
        time=time,
        mode=mode,
        backend_args=backend_args,
    )

384 

385 

def write_mesh_input_order(
    filename: Path | str,
    mesh: dolfinx.mesh.Mesh,
    time: float = 0.0,
    mode: FileMode = FileMode.write,
    backend: str = "adios2",
    backend_args: dict[str, typing.Any] | None = None,
):
    """
    Write mesh to checkpoint file in original input ordering.

    Topology and geometry are first redistributed so that each process holds a
    contiguous range of input cells and nodes. This is collective over
    ``mesh.comm``.

    Note:
        Requires backend to implement {py:class}`io4dolfinx.backends.write_mesh`.

    Args:
        filename: The filename to write to
        mesh: Mesh to checkpoint
        time: Time-stamp associated with the mesh at current write step
        mode: The mode to use (write or append)
        backend: Choice of backend module
        backend_args: Arguments to backend. Passed through the backend's
            ``get_default_backend_args`` before use.
    """
    logger.debug(f"Writing mesh in input order to {filename} at time {time}")
    logger.debug(f"Using backend {backend} with arguments {backend_args} and mode {mode}")
    mesh_data = create_original_mesh_data(mesh)
    fname = Path(filename)

    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    backend_cls.write_mesh(
        fname,
        mesh.comm,
        mesh_data,
        backend_args=backend_args,
        mode=mode,
        time=time,
    )