# Copyright (C) 2023-2026 Jørgen Schartum Dokken
#
# This file is part of io4dolfinx
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import typing
from pathlib import Path
from typing import Any
from mpi4py import MPI
import basix
import dolfinx
import numpy as np
import numpy.typing as npt
import ufl
from packaging.version import Version
from .backends import FileMode, ReadMode, get_backend
from .comm_helpers import (
send_and_recv_cell_perm,
send_dofmap_and_recv_values,
send_dofs_and_recv_values,
)
from .readers import create_geometry_function_space
from .structures import ArrayData, FunctionData, MeshTagsData
from .utils import (
check_file_exists,
compute_dofmap_pos,
compute_local_range,
index_owner,
unroll_dofmap,
unroll_insert_position,
)
from .writers import prepare_meshdata_for_storage
from .writers import write_function as _internal_function_writer
from .writers import write_mesh as _internal_mesh_writer
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "read_mesh", "write_function", "read_function", "write_mesh",
    "read_meshtags", "write_meshtags", "read_attributes", "write_attributes",
]
[docs]
def write_attributes(
    filename: Path | str,
    comm: MPI.Intracomm,
    name: str,
    attributes: dict[str, np.ndarray],
    backend_args: dict[str, typing.Any] | None = None,
    backend: str = "adios2",
):
    """Store a group of named attributes in a file.

    Args:
        filename: Path to the file the attributes are written to.
        comm: MPI communicator used for the storage.
        name: Name under which the group of attributes is stored.
        attributes: Mapping from attribute name to the array to store.
        backend_args: Arguments for the backend, for instance the file type.
            They are passed through the selected backend's
            ``get_default_backend_args`` before use.
        backend: Name of the IO backend used for writing.
    """
    io_backend = get_backend(backend)
    resolved_args = io_backend.get_default_backend_args(backend_args)
    io_backend.write_attributes(filename, comm, name, attributes, resolved_args)
[docs]
def read_attributes(
    filename: Path | str,
    comm: MPI.Intracomm,
    name: str,
    backend_args: dict[str, typing.Any] | None = None,
    backend: str = "adios2",
) -> dict[str, typing.Any]:
    """Read a group of attributes from file.

    Args:
        filename: Path to file to read from
        comm: MPI communicator used in storage
        name: Name of the attributes
        backend_args: Arguments for backend, for instance file type.
        backend: What backend to use for reading.

    Returns:
        The attributes
    """
    # Validate the input file before handing it to the backend, consistent
    # with the other readers in this module (read_timestamps, read_function,
    # read_mesh).
    check_file_exists(filename)
    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    return backend_cls.read_attributes(filename, comm, name, backend_args)
[docs]
def read_timestamps(
    filename: Path | str,
    comm: MPI.Intracomm,
    function_name: str,
    backend_args: dict[str, typing.Any] | None = None,
    backend: str = "adios2",
) -> npt.NDArray[np.float64 | str]:  # type: ignore[type-var]
    """
    Read the time-stamps stored for a function in a checkpoint file.

    Args:
        filename: Path to the checkpoint file
        comm: MPI communicator
        function_name: Name of the function whose time-stamps are read
        backend_args: Arguments for the backend, for instance the file type.
        backend: Which backend to use for reading.

    Returns:
        The time-stamps
    """
    check_file_exists(filename)
    reader = get_backend(backend)
    return reader.read_timestamps(
        filename, comm, function_name, reader.get_default_backend_args(backend_args)
    )
[docs]
def _compute_owners(
    read_mode: ReadMode,
    comm: MPI.Intracomm,
    indices: npt.NDArray[np.int64],
    size_global: int,
) -> npt.NDArray[np.int32]:
    """Compute the rank that owns each entry of ``indices`` in the input data.

    Args:
        read_mode: The backend's read mode (serial or parallel).
        comm: MPI communicator the data is distributed over.
        indices: Global (input) indices to compute owners for.
        size_global: Global number of indices.

    Returns:
        The owning rank of each entry in ``indices``.

    Raises:
        NotImplementedError: If ``read_mode`` is neither serial nor parallel.
    """
    if read_mode == ReadMode.serial:
        # In serial read mode every entry is assigned to rank 0.
        return np.zeros(len(indices), dtype=np.int32)
    elif read_mode == ReadMode.parallel:
        return index_owner(comm, indices, size_global)
    raise NotImplementedError(f"{read_mode} not implemented")


def read_function(
    filename: Path | str,
    u: dolfinx.fem.Function,
    time: float = 0.0,
    name: str | None = None,
    backend_args: dict[str, Any] | None = None,
    backend: str = "adios2",
):
    """
    Read checkpoint from file and fill it into `u`.

    Args:
        filename: Path to checkpoint
        u: Function to fill
        time: Time-stamp associated with checkpoint
        name: If not provided, `u.name` is used to search through the input file for the function
        backend_args: Arguments for backend, for instance file type.
        backend: What backend to use for reading.

    Raises:
        NotImplementedError: If the backend's read mode is neither serial nor parallel.
    """
    check_file_exists(filename)
    mesh = u.function_space.mesh
    comm = mesh.comm
    if name is None:
        name = u.name
    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    read_mode = backend_cls.read_mode
    tdim = mesh.topology.dim
    # Needed both for the cell owners (Step 1/6) and for the local input range
    # (Step 5), so it must be defined independently of the read mode.
    num_cells_global = mesh.topology.index_map(tdim).size_global

    # ----------------------Step 1---------------------------------
    # Compute index of input cells and get cell permutation
    num_owned_cells = mesh.topology.index_map(tdim).size_local
    input_cells = mesh.topology.original_cell_index[:num_owned_cells]
    mesh.topology.create_entity_permutations()
    cell_perm = mesh.topology.get_cell_permutation_info()[:num_owned_cells]
    # Compute which process holds each input cell
    owners = _compute_owners(read_mode, comm, input_cells, num_cells_global)

    # -------------------Step 2------------------------------------
    # Send and receive global cell index and cell perm
    inc_cells, inc_perms = send_and_recv_cell_perm(input_cells, cell_perm, owners, comm)

    # -------------------Step 3-----------------------------------
    # Read dofmap from file and compute dof owners
    input_dofmap = backend_cls.read_dofmap(filename, comm, name, backend_args)
    input_dofs = input_dofmap.array.astype(np.int64)
    num_dofs_global = (
        u.function_space.dofmap.index_map.size_global * u.function_space.dofmap.index_map_bs
    )
    dof_owner = _compute_owners(read_mode, comm, input_dofs, num_dofs_global)

    # --------------------Step 4-----------------------------------
    # Read array from file and communicate them to input dofmap process
    input_array, starting_pos = backend_cls.read_dofs(filename, comm, name, time, backend_args)
    recv_array = send_dofs_and_recv_values(input_dofs, dof_owner, comm, input_array, starting_pos)

    # -------------------Step 5--------------------------------------
    # Invert permutation of input data based on input perm
    # Then apply current permutation to the local data
    element = u.function_space.element
    if element.needs_dof_transformations:
        bs = u.function_space.dofmap.bs
        # Read input cell permutations on dofmap process
        local_input_range = compute_local_range(comm, num_cells_global)
        input_local_cell_index = inc_cells - local_input_range[0]
        input_perms = backend_cls.read_cell_perms(comm, filename, backend_args)
        # Start by sorting data array by cell permutation
        num_dofs_per_cell = input_dofmap.offsets[1:] - input_dofmap.offsets[:-1]
        # NOTE(review): indexing [0] assumes this process received at least one cell.
        assert np.allclose(num_dofs_per_cell, num_dofs_per_cell[0])
        # Sort dofmap by input local cell index
        input_perms_sorted = input_perms[input_local_cell_index]
        unrolled_dofmap_position = unroll_insert_position(
            input_local_cell_index, num_dofs_per_cell[0]
        )
        dofmap_sorted_by_input = recv_array[unrolled_dofmap_position]
        # First invert input data to reference element then transform to current mesh
        element.Tt_apply(dofmap_sorted_by_input, input_perms_sorted, bs)
        element.Tt_inv_apply(dofmap_sorted_by_input, inc_perms, bs)
        # Compute invert permutation
        inverted_perm = np.empty_like(unrolled_dofmap_position)
        inverted_perm[unrolled_dofmap_position] = np.arange(
            len(unrolled_dofmap_position), dtype=inverted_perm.dtype
        )
        recv_array = dofmap_sorted_by_input[inverted_perm]

    # ------------------Step 6----------------------------------------
    # For each dof owned by a process, find the local position in the dofmap.
    V = u.function_space
    local_cells, dof_pos = compute_dofmap_pos(V)
    input_cells = V.mesh.topology.original_cell_index[local_cells]
    owners = _compute_owners(read_mode, comm, input_cells, num_cells_global)
    unique_owners, owner_count = np.unique(owners, return_counts=True)
    # FIXME: In C++ use NBX to find neighbourhood
    sub_comm = comm.Create_dist_graph(
        [comm.rank], [len(unique_owners)], unique_owners, reorder=False
    )
    source, dest, _ = sub_comm.Get_dist_neighbors()
    sub_comm.Free()
    owned_values = send_dofmap_and_recv_values(
        comm,
        np.asarray(source, dtype=np.int32),
        np.asarray(dest, dtype=np.int32),
        owners,
        owner_count.astype(np.int32),
        input_cells,
        dof_pos,
        num_cells_global,
        recv_array,
        input_dofmap.offsets,
    )
    u.x.array[: len(owned_values)] = owned_values
    u.x.scatter_forward()
[docs]
def read_mesh(
    filename: Path | str,
    comm: MPI.Intracomm,
    ghost_mode: dolfinx.mesh.GhostMode = dolfinx.mesh.GhostMode.shared_facet,
    time: float | str | None = 0.0,
    read_from_partition: bool = False,
    backend_args: dict[str, Any] | None = None,
    backend: str = "adios2",
    max_facet_to_cell_links: int = 2,
) -> dolfinx.mesh.Mesh:
    """
    Read a mesh from file into DOLFINx using the selected IO backend.

    Args:
        filename: Path to input file
        comm: The MPI communicator to distribute the mesh over
        ghost_mode: Ghost mode to use for mesh. Ignored when the data read
            from file contains a partition graph (expected when
            `read_from_partition` is `True`).
        time: Time stamp associated with mesh
        read_from_partition: Read mesh with partition from file
        backend_args: Arguments to the reader backend
        backend: What backend to use for reading.
        max_facet_to_cell_links: Maximum number of cells a facet
            can be connected to. Ignored under the same condition as `ghost_mode`.

    Returns:
        The distributed mesh
    """
    # Read in data in a distributed fashion
    check_file_exists(filename)
    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    # Let each backend handle what should be default behavior when reading mesh
    # with or without time stamp.
    dist_in_data = backend_cls.read_mesh_data(
        filename,
        comm,
        time=time,
        read_from_partition=read_from_partition,
        backend_args=backend_args,
    )
    # Create DOLFINx mesh
    element = basix.ufl.element(
        basix.ElementFamily.P,
        dist_in_data.cell_type,
        dist_in_data.degree,
        basix.LagrangeVariant(int(dist_in_data.lvar)),
        shape=(dist_in_data.x.shape[1],),
        dtype=dist_in_data.x.dtype,
    )
    domain = ufl.Mesh(element)
    if (partition_graph := dist_in_data.partition_graph) is not None:
        # Reuse the partition read from file instead of computing a new one.
        def partitioner(comm: MPI.Intracomm, n, m, topo):
            assert len(topo[0]) % (len(partition_graph.offsets) - 1) == 0
            # Newer DOLFINx expects the underlying C++ graph object.
            if Version(dolfinx.__version__) > Version("0.9.0"):
                return partition_graph._cpp_object
            else:
                return partition_graph
    else:
        try:
            partitioner = dolfinx.cpp.mesh.create_cell_partitioner(
                ghost_mode, max_facet_to_cell_links=max_facet_to_cell_links
            )
        except TypeError:
            # Fallback for DOLFINx versions whose create_cell_partitioner
            # does not accept `max_facet_to_cell_links`.
            partitioner = dolfinx.cpp.mesh.create_cell_partitioner(ghost_mode)
        # Should change to the commented code below when we require python
        # minimum version to be >=3.12 see https://github.com/python/cpython/pull/116198
        # import inspect
        # sig = inspect.signature(dolfinx.mesh.create_cell_partitioner)
        # part_kwargs = {}
        # if "max_facet_to_cell_links" in list(sig.parameters.keys()):
        #     part_kwargs["max_facet_to_cell_links"] = max_facet_to_cell_links
        # partitioner = dolfinx.cpp.mesh.create_cell_partitioner(ghost_mode, **part_kwargs)
    return dolfinx.mesh.create_mesh(
        comm,
        cells=dist_in_data.cells,
        x=dist_in_data.x,
        e=domain,
        partitioner=partitioner,
    )
[docs]
def write_mesh(
    filename: Path | str,
    mesh: dolfinx.mesh.Mesh,
    mode: FileMode = FileMode.write,
    time: float = 0.0,
    store_partition_info: bool = False,
    backend_args: dict[str, Any] | None = None,
    backend: str = "adios2",
):
    """
    Write a mesh to file.

    Args:
        filename: Path to save the mesh to
        mesh: The mesh to write to file
        mode: File mode, e.g. `FileMode.write` or `FileMode.append`.
        time: Time stamp associated with the mesh
        store_partition_info: Store mesh partitioning (including ghosting) to file
        backend_args: Arguments to the IO backend.
        backend: The backend to use for writing.
    """
    mesh_data = prepare_meshdata_for_storage(mesh=mesh, store_partition_info=store_partition_info)
    # Normalize to Path so string filenames work, as in write_function.
    _internal_mesh_writer(
        Path(filename),
        mesh.comm,
        mesh_data=mesh_data,
        time=time,
        backend_args=backend_args,
        backend=backend,
        mode=mode,
    )
[docs]
def write_function(
    filename: Path | str,
    u: dolfinx.fem.Function,
    time: float = 0.0,
    mode: FileMode = FileMode.append,
    name: str | None = None,
    backend_args: dict[str, Any] | None = None,
    backend: str = "adios2",
):
    """
    Write a function checkpoint to file.

    The dofmap of the locally owned cells is stored with global dof indices,
    together with the owned dof values and the cell permutation info.

    Args:
        filename: Path to write to
        u: Function to write to file
        time: Time-stamp for simulation
        mode: Write or append.
        name: Name of function to write. If None, the name of the function is used.
        backend_args: Arguments to the IO backend.
        backend: The backend to use
    """
    V = u.function_space
    dofmap = V.dofmap
    mesh = V.mesh
    mesh.topology.create_entity_permutations()
    cell_perm = mesh.topology.get_cell_permutation_info()
    cell_imap = mesh.topology.index_map(mesh.topology.dim)
    n_cells = cell_imap.size_local

    block = dofmap.bs
    imap_block = dofmap.index_map_bs
    dofs_per_cell = dofmap.list.shape[1]
    entries_per_cell = dofs_per_cell * block

    # Unroll the owned part of the dofmap for the dofmap block size, then split
    # each entry into an index-map position and a component within its block.
    unrolled = unroll_dofmap(dofmap.list[:n_cells, :], block)
    imap_local, component = np.divmod(unrolled, imap_block)
    global_dofmap = (
        dofmap.index_map.local_to_global(imap_local.reshape(-1)) * imap_block
        + component.reshape(-1)
    )

    # Index map over the unrolled dofmap entries; gives this process's offset
    # into the global dofmap array and the global number of entries.
    dofmap_entry_map = dolfinx.common.IndexMap(mesh.comm, n_cells * entries_per_cell)
    dofmap_offsets = (
        np.arange(n_cells + 1, dtype=np.int64) * entries_per_cell
        + dofmap_entry_map.local_range[0]
    )

    owned_dof_range = np.asarray(dofmap.index_map.local_range) * imap_block
    function_data = FunctionData(
        cell_permutations=cell_perm[:n_cells].copy(),
        local_cell_range=cell_imap.local_range,
        num_cells_global=cell_imap.size_global,
        dofmap_array=global_dofmap,
        dofmap_offsets=dofmap_offsets,
        dofmap_range=dofmap_entry_map.local_range,
        global_dofs_in_dofmap=dofmap_entry_map.size_global,
        values=u.x.array[: owned_dof_range[1] - owned_dof_range[0]].copy(),
        dof_range=owned_dof_range,
        num_dofs_global=dofmap.index_map.size_global * imap_block,
        name=name or u.name,
    )
    _internal_function_writer(
        Path(filename),
        mesh.comm,
        function_data,
        time,
        backend_args=backend_args,
        backend=backend,
        mode=mode,
    )
[docs]
def read_function_names(
    filename: Path | str,
    comm: MPI.Intracomm,
    backend_args: dict[str, Any] | None = None,
    backend: str = "h5py",
) -> list[str]:
    """Read all function names from a file.

    Args:
        filename: Path to file
        comm: MPI communicator to launch IO on.
        backend_args: Arguments to backend
        backend: What backend to use for reading.

    Returns:
        A list of function names.
    """
    # Same validation and argument normalization as the other readers in this module.
    check_file_exists(filename)
    backend_cls = get_backend(backend)
    backend_args = backend_cls.get_default_backend_args(backend_args)
    return backend_cls.read_function_names(filename, comm, backend_args=backend_args)
[docs]
def write_point_data(
    filename: Path | str,
    u: dolfinx.fem.Function,
    time: str | float | None,
    mode: FileMode,
    backend_args: dict[str, Any] | None,
    backend: str = "vtkhdf",
):
    """Write a function to file as point data by interpolating it into the geometry nodes.

    Args:
        filename: Path to file
        u: The function to store
        time: Time stamp
        mode: Append or write
        backend_args: The backend arguments
        backend: Which backend to use.

    Returns:
        The return value of the backend's ``write_data``.
    """
    num_components = int(np.prod(u.ufl_shape))
    V = create_geometry_function_space(u.function_space.mesh, num_components)
    interpolant = dolfinx.fem.Function(V, name=u.name, dtype=u.x.array.dtype)
    interpolant.interpolate(u)
    index_map = V.dofmap.index_map
    block_size = V.dofmap.index_map_bs
    # One row per dof; keep only the locally owned rows.
    owned = interpolant.x.array.reshape(-1, block_size)[: index_map.size_local]
    array_data = ArrayData(
        name=interpolant.name,
        values=owned,
        global_shape=(index_map.size_global, block_size),
        local_range=index_map.local_range,
        type="Point",
    )
    return get_backend(backend).write_data(
        filename,
        comm=interpolant.function_space.mesh.comm,
        mode=mode,
        time=time,
        array_data=array_data,
        backend_args=backend_args,
    )
[docs]
def write_cell_data(
    filename: Path | str,
    u: dolfinx.fem.Function,
    time: str | float | None,
    mode: FileMode,
    backend_args: dict[str, Any] | None,
    backend: str = "vtkhdf",
):
    """Write function to file by interpolating into cell midpoints.

    The function is interpolated into a ("DG", 0) space with the same value
    shape as ``u``, and the locally owned values are stored as cell data.

    Args:
        filename: Path to file
        u: The function to store
        time: Time stamp
        mode: Append or write
        backend_args: The backend arguments
        backend: Which backend to use.

    Returns:
        The return value of the backend's ``write_data``.
    """
    V = dolfinx.fem.functionspace(u.function_space.mesh, ("DG", 0, u.ufl_shape))
    v_out = dolfinx.fem.Function(V, name=u.name, dtype=u.x.array.dtype)
    v_out.interpolate(u)
    comm = v_out.function_space.mesh.comm
    data_shape = (V.dofmap.index_map.size_global, V.dofmap.index_map_bs)
    local_range = V.dofmap.index_map.local_range
    num_dofs_local = V.dofmap.index_map.size_local
    # One row per dof; keep only the locally owned rows.
    data = v_out.x.array.reshape(-1, V.dofmap.index_map_bs)[:num_dofs_local]
    ad = ArrayData(
        name=v_out.name, values=data, global_shape=data_shape, local_range=local_range, type="Cell"
    )
    backend_cls = get_backend(backend)
    return backend_cls.write_data(
        filename, comm=comm, mode=mode, time=time, array_data=ad, backend_args=backend_args
    )