Source code for pydims.indexing_functions
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 PyDims contributors (https://github.com/pydims)
"""
Indexing functions.
Includes functions from the "Indexing Functions" section of the Python Array API
standard.
"""
from .dimensioned_array import DimArr, DimensionedArray, DimensionError
[docs]
def take(x: DimArr, /, indices: DimensionedArray) -> DimArr:
"""
Returns elements of an array along an axis.
The indices must be 1-D and their single dimension defines the axis along which
to take elements.
Parameters
----------
x:
Input array.
indices:
Array of indices to extract from the input array. Must be 1-D.
Returns
-------
:
Array containing the elements of the input array at the specified indices.
"""
try:
axis = x.dims.index(indices.dim)
except ValueError:
raise DimensionError(
f"Indices dimension '{indices.dim}' not in data dimensions '{x.dims}'"
) from None
return x.__class__(
values=x.values.take(indices.values, axis=axis),
dims=x.dims,
unit=x.unit,
)
__all__ = ['take']