# Copyright (C) 2020-2026 Fraunhofer ITWM and Sebastian Blauth
#
# This file is part of cashocs.
#
# cashocs is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# cashocs is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with cashocs. If not, see <https://www.gnu.org/licenses/>.
"""Basic mesh generation."""
from __future__ import annotations
import collections
from collections.abc import Callable
import functools
from typing import Any, Literal, TYPE_CHECKING
import fenics
from mpi4py import MPI
import numpy as np
from cashocs import _exceptions
from cashocs import log
from cashocs import mpi
from cashocs.geometry import measure
if TYPE_CHECKING:
from cashocs import _typing
def _get_mesh_stats(
    mode: Literal["import", "generate"],
) -> Callable[..., Callable[..., _typing.MeshTuple]]:
    """A decorator for mesh importing / generating function which logs stats.

    The decorated function is run inside a ``log.begin`` / ``log.end`` block. After it
    returns, the geometric dimension of the mesh, the MPI size, the global number of
    vertices and cells, and the cell type are logged at info level.

    Args:
        mode: A string indicating whether the mesh is being generated or imported.
            The comparison is case-insensitive; anything other than ``"import"`` is
            treated as ``"generate"``.

    Returns:
        A decorator which adds the logging to a function returning a mesh tuple
        (whose first entry is the mesh).

    """
    is_import = mode.casefold() == "import"
    word = "importing" if is_import else "generating"
    worded = "imported" if is_import else "generated"

    def decorator_stats(
        func: Callable[..., _typing.MeshTuple],
    ) -> Callable[..., _typing.MeshTuple]:
        """A decorator for a mesh generating function.

        Args:
            func: The function to be decorated. Its return value's first entry must
                be the mesh.

        Returns:
            The decorated function.

        """

        @functools.wraps(func)
        def wrapper_stats(*args: Any, **kwargs: Any) -> _typing.MeshTuple:
            """Wrapper function for mesh generating functions.

            Args:
                *args: The arguments for the function.
                **kwargs: The keyword arguments for the function.

            Returns:
                The return value of the wrapped function, unchanged.

            """
            # An explicit ``comm`` keyword takes precedence; otherwise the last
            # positional MPI communicator is used. ``None`` falls back to the
            # global communicator, matching the defaults of the mesh functions.
            comm = None
            if "comm" in kwargs:
                comm = kwargs["comm"]
            else:
                for arg in args:
                    if isinstance(arg, MPI.Comm):
                        comm = arg
            if comm is None:
                comm = mpi.COMM_WORLD
            mpi_size = comm.size

            log.begin(f"{word.capitalize()} mesh.", level=log.INFO)
            try:
                value = func(*args, **kwargs)
                mesh = value[0]
                dim = mesh.geometry().dim()
                log.info(
                    f"Successfully {worded} {dim}-dimensional mesh on "
                    f"{mpi_size} CPU(s)."
                )
                log.info(
                    f"Mesh contains {mesh.num_entities_global(0):,} vertices and "
                    f"{mesh.num_entities_global(dim):,} cells of type "
                    f"{mesh.ufl_cell().cellname()}."
                )
            finally:
                # Always close the block opened by log.begin, so an exception
                # raised by ``func`` (e.g. invalid input) does not leave the
                # logging nesting unbalanced for all subsequent messages.
                log.end()
            return value

        return wrapper_stats

    return decorator_stats
[docs]
@_get_mesh_stats("generate")
def interval_mesh(
    n: int = 10,
    start: float = 0.0,
    end: float = 1.0,
    partitions: list[float] | None = None,
    comm: MPI.Comm | None = None,
) -> _typing.MeshTuple:
    r"""Creates a 1D interval mesh from x=start to x=end.

    This function creates a uniform mesh of a 1D interval, starting at ``start`` and
    ending at ``end``. The resulting mesh uses ``n`` sub-intervals to
    discretize the geometry. The boundary markers are as follows:

    - 1 corresponds to :math:`x=start`
    - 2 corresponds to :math:`x=end`

    All other facets (interior points) carry the marker 0.

    Args:
        n: Number of elements for discretizing the interval, default is 10
        start: The start of the interval, default is 0.0
        end: The end of the interval, default is 1.0
        partitions: Points in the interval at which a partition in subdomains should be
            made. The points must be strictly increasing and lie strictly between
            ``start`` and ``end``. The resulting volume measure is sorted ascendingly
            according to the sub-intervals defined in partitions (starting at 1).
            Defaults to ``None``.
        comm: MPI communicator that is to be used for creating the mesh.

    Returns:
        A tuple (mesh, subdomains, boundaries, dx, ds, dS), where mesh is the imported
        FEM mesh, subdomains is a mesh function for the subdomains, boundaries is a mesh
        function for the boundaries, dx is a volume measure, ds is a surface measure,
        and dS is a measure for the interior facets.

    Raises:
        InputError: If ``end`` is not larger than ``start``, or if ``partitions`` is
            not strictly increasing or contains points outside of ``(start, end)``.

    """
    if end <= start:
        raise _exceptions.InputError(
            "cashocs.geometry.interval_mesh", "end", "end needs to be larger than start"
        )
    if partitions is not None:
        if not all(x < y for x, y in zip(partitions[:-1], partitions[1:], strict=True)):
            raise _exceptions.InputError(
                "cashocs.geometry.interval_mesh",
                "partitions",
                "partitions must be strictly increasing",
            )
        # Points on or outside the boundary would yield empty or overlapping
        # sub-intervals, so the numbering described above would be meaningless.
        if not all(start < x < end for x in partitions):
            raise _exceptions.InputError(
                "cashocs.geometry.interval_mesh",
                "partitions",
                "partitions must lie strictly between start and end",
            )

    n = int(n)
    dim = 1
    if comm is None:
        comm = mpi.COMM_WORLD

    mesh = fenics.IntervalMesh(comm, n, start, end)
    physical_groups = {"dx": {}, "ds": {"start": 1, "end": 2}}

    subdomains = fenics.MeshFunction("size_t", mesh, dim=dim)
    boundaries = fenics.MeshFunction("size_t", mesh, dim=dim - 1)
    # dolfin does not initialize MeshFunction values on construction. Zero them so
    # that entities not hit by any mark() call (interior points, and cells that no
    # sub-interval fully contains) carry a well-defined marker instead of
    # arbitrary memory contents.
    subdomains.set_all(0)
    boundaries.set_all(0)

    x_min = fenics.CompiledSubDomain(
        "on_boundary && near(x[0], start, tol)", tol=fenics.DOLFIN_EPS, start=start
    )
    x_max = fenics.CompiledSubDomain(
        "on_boundary && near(x[0], end, tol)", tol=fenics.DOLFIN_EPS, end=end
    )
    x_min.mark(boundaries, 1)
    x_max.mark(boundaries, 2)

    if partitions is not None:
        # Sub-interval i (1-based) spans [padded[i - 1], padded[i]].
        padded_partitions = collections.deque(partitions)
        padded_partitions.appendleft(start)
        padded_partitions.append(end)
        for i in range(len(padded_partitions) - 1):
            start_point = padded_partitions[i]
            end_point = padded_partitions[i + 1]
            part = fenics.CompiledSubDomain(
                "x[0] >= start_point - eps && x[0] <= end_point + eps",
                start_point=start_point,
                end_point=end_point,
                eps=fenics.DOLFIN_EPS,
            )
            part.mark(subdomains, i + 1)
            physical_groups["dx"].update({str(i + 1): i + 1})
    else:
        subdomains.set_all(1)
        physical_groups["dx"].update({"all": 1})

    mesh.physical_groups = physical_groups

    dx = measure.NamedMeasure(
        "dx", mesh, subdomain_data=subdomains, physical_groups=physical_groups
    )
    ds = measure.NamedMeasure(
        "ds", mesh, subdomain_data=boundaries, physical_groups=physical_groups
    )
    dS = measure.NamedMeasure(  # pylint: disable=invalid-name
        "dS", mesh, subdomain_data=boundaries, physical_groups=physical_groups
    )

    return mesh, subdomains, boundaries, dx, ds, dS
[docs]
def regular_mesh(
    n: int = 10,
    length_x: float = 1.0,
    length_y: float = 1.0,
    length_z: float | None = None,
    diagonal: Literal["left", "right", "left/right", "right/left", "crossed"] = "right",
    comm: MPI.Comm | None = None,
) -> _typing.MeshTuple:
    r"""Creates a mesh corresponding to a rectangle or cube anchored at the origin.

    This is a thin convenience layer over :func:`regular_box_mesh`: the lower corner
    of the domain is the origin and the upper corner is given by ``length_x``,
    ``length_y`` and (in 3D) ``length_z``. The resulting mesh uses ``n`` elements
    along the shortest direction and accordingly many along the longer ones. The
    resulting domain is

    .. math::
        \begin{alignedat}{2}
        &[0, length_x] \times [0, length_y] \quad &&\text{ in } 2D, \\
        &[0, length_x] \times [0, length_y] \times [0, length_z] \quad &&\text{ in } 3D.
        \end{alignedat}

    The boundary markers are ordered as follows:

    - 1 corresponds to :math:`x=0`.
    - 2 corresponds to :math:`x=length_x`.
    - 3 corresponds to :math:`y=0`.
    - 4 corresponds to :math:`y=length_y`.
    - 5 corresponds to :math:`z=0` (only in 3D).
    - 6 corresponds to :math:`z=length_z` (only in 3D).

    Args:
        n: Number of elements in the shortest coordinate direction.
        length_x: Length in x-direction.
        length_y: Length in y-direction.
        length_z: Length in z-direction; when this is ``None`` the geometry is
            two-dimensional (default is ``None``).
        diagonal: This defines the type of diagonal used to create the box mesh in 2D.
            This can be one of ``"right"``, ``"left"``, ``"left/right"``,
            ``"right/left"`` or ``"crossed"``.
        comm: MPI communicator that is to be used for creating the mesh.

    Returns:
        A tuple (mesh, subdomains, boundaries, dx, ds, dS), where mesh is the imported
        FEM mesh, subdomains is a mesh function for the subdomains, boundaries is a mesh
        function for the boundaries, dx is a volume measure, ds is a surface measure,
        and dS is a measure for the interior facets.

    """
    # A z-extent (and hence a z-origin) only exists for three-dimensional meshes.
    three_dimensional = length_z is not None

    return regular_box_mesh(
        n=n,
        start_x=0.0,
        start_y=0.0,
        start_z=0.0 if three_dimensional else None,
        end_x=length_x,
        end_y=length_y,
        end_z=length_z,
        diagonal=diagonal,
        comm=comm,
    )
[docs]
@_get_mesh_stats("generate")
def regular_box_mesh(
    n: int = 10,
    start_x: float = 0.0,
    start_y: float = 0.0,
    start_z: float | None = None,
    end_x: float = 1.0,
    end_y: float = 1.0,
    end_z: float | None = None,
    diagonal: Literal["right", "left", "left/right", "right/left", "crossed"] = "right",
    comm: MPI.Comm | None = None,
) -> _typing.MeshTuple:
    r"""Creates a mesh corresponding to a rectangle or cube.

    This function creates a uniform mesh of either a rectangle or a cube, whose lower
    corner is given by the ``start_*`` arguments and whose upper corner is given by
    the ``end_*`` arguments. The resulting mesh uses ``n`` elements along the shortest
    direction and accordingly many along the longer ones. The resulting domain is

    .. math::
        \begin{alignedat}{2}
        &[start_x, end_x] \times [start_y, end_y] \quad &&\text{ in } 2D, \\
        &[start_x, end_x] \times [start_y, end_y] \times [start_z, end_z] \quad
        &&\text{ in } 3D.
        \end{alignedat}

    The boundary markers are ordered as follows:

    - 1 corresponds to :math:`x=start_x`.
    - 2 corresponds to :math:`x=end_x`.
    - 3 corresponds to :math:`y=start_y`.
    - 4 corresponds to :math:`y=end_y`.
    - 5 corresponds to :math:`z=start_z` (only in 3D).
    - 6 corresponds to :math:`z=end_z` (only in 3D).

    All interior facets carry the marker 0.

    Args:
        n: Number of elements in the shortest coordinate direction.
        start_x: Start of the x-interval.
        start_y: Start of the y-interval.
        start_z: Start of the z-interval, mesh is 2D if this is ``None`` (default is
            ``None``).
        end_x: End of the x-interval.
        end_y: End of the y-interval.
        end_z: End of the z-interval, mesh is 2D if this is ``None`` (default is
            ``None``).
        diagonal: This defines the type of diagonal used to create the box mesh in 2D.
            This can be one of ``"right"``, ``"left"``, ``"left/right"``,
            ``"right/left"`` or ``"crossed"``. It is not used in 3D.
        comm: MPI communicator that is to be used for creating the mesh.

    Returns:
        A tuple (mesh, subdomains, boundaries, dx, ds, dS), where mesh is the imported
        FEM mesh, subdomains is a mesh function for the subdomains, boundaries is a mesh
        function for the boundaries, dx is a volume measure, ds is a surface measure,
        and dS is a measure for the interior facets.

    Raises:
        InputError: If only one of ``start_z`` and ``end_z`` is given, or if some
            start value is not smaller than the corresponding end value.

    """
    n = int(n)
    if comm is None:
        comm = mpi.COMM_WORLD

    # The z-extent must be given completely (3D) or not at all (2D).
    if (start_z is None) != (end_z is None):
        raise _exceptions.InputError(
            "cashocs.geometry.regular_box_mesh",
            "start_z",
            "Incorrect input for the z-coordinate. "
            "Both start_z and end_z need to be specified.",
        )

    # (lower, upper) bounds per coordinate axis; their count is the dimension.
    bounds = [(start_x, end_x), (start_y, end_y)]
    if start_z is not None:
        bounds.append((start_z, end_z))
    dim = len(bounds)

    sizes = [upper - lower for lower, upper in bounds]
    _check_sizes(sizes)
    # Scale the element count so the shortest side receives exactly n elements.
    size_min = np.min(sizes)
    num_points = [int(np.round(length / size_min * n)) for length in sizes]

    physical_groups = {
        "dx": {"all": 1},
        "ds": {"left": 1, "right": 2, "bottom": 3, "top": 4},
    }
    if dim == 2:
        mesh = fenics.RectangleMesh(
            comm,
            fenics.Point(start_x, start_y),
            fenics.Point(end_x, end_y),
            num_points[0],
            num_points[1],
            diagonal=diagonal,
        )
    else:
        mesh = fenics.BoxMesh(
            comm,
            fenics.Point(start_x, start_y, start_z),
            fenics.Point(end_x, end_y, end_z),
            num_points[0],
            num_points[1],
            num_points[2],
        )
        physical_groups["ds"].update({"front": 5, "back": 6})
    mesh.physical_groups = physical_groups

    subdomains = fenics.MeshFunction("size_t", mesh, dim=dim)
    subdomains.set_all(1)

    boundaries = fenics.MeshFunction("size_t", mesh, dim=dim - 1)
    # dolfin does not initialize MeshFunction values on construction; without this
    # the interior facets (used by dS) would carry arbitrary markers.
    boundaries.set_all(0)

    # Mark the lower / upper face of every axis with 2 * axis + 1 / 2 * axis + 2,
    # in the same order as the numbering documented above.
    for axis, (lower, upper) in enumerate(bounds):
        for marker, coordinate in ((2 * axis + 1, lower), (2 * axis + 2, upper)):
            face = fenics.CompiledSubDomain(
                f"on_boundary && near(x[{axis}], coord, tol)",
                tol=fenics.DOLFIN_EPS,
                coord=coordinate,
            )
            face.mark(boundaries, marker)

    dx = measure.NamedMeasure(
        "dx", mesh, subdomain_data=subdomains, physical_groups=physical_groups
    )
    ds = measure.NamedMeasure(
        "ds", mesh, subdomain_data=boundaries, physical_groups=physical_groups
    )
    dS = measure.NamedMeasure(  # pylint: disable=invalid-name
        "dS", mesh, subdomain_data=boundaries, physical_groups=physical_groups
    )

    return mesh, subdomains, boundaries, dx, ds, dS
def _check_sizes(sizes: list[float]) -> None:
for size in sizes:
if size <= 0:
raise _exceptions.InputError(
"cashocs.geometry.regular_box_mesh",
"start_",
"The start values have to be smaller than the end values.",
)