Source code for pysindy.utils.axes

"""
A module that defines one external class, AxesArray, to act like a numpy array
but keep track of axis definitions.  It aims to allow meaningful replacement
of magic numbers for axis conventions in code.  E.g::

   import numpy as np

   arr = AxesArray(np.ones((2,3,4)), {"ax_time": 0, "ax_spatial": [1, 2]})
   print(arr.axes)
   print(arr.ax_time)
   print(arr.n_time)
   print(arr.ax_spatial)
   print(arr.n_spatial)

Would show::

   {'ax_time': 0, 'ax_spatial': [1, 2]}
   0
   2
   [1, 2]
   (3, 4)

It is up to the user to handle the ``list[int] | int`` return values, but this
module has several functions to deal with the axes dictionary, internally
referred to as type ``CompatDict[T]``:

Appending an item to a ``CompatDict[T]``
   :py:func:`compat_dict_append`

Generating a ``CompatDict[int]`` of axes from list of axes names:
   :py:func:`fwd_from_names`

Create new ``CompatDict[int]`` from this ``AxesArray`` with new axis/axes added:
   :py:meth:`AxesArray.insert_axis`

Create new ``CompatDict[int]`` from this ``AxesArray`` with axis/axes removed:
   :py:meth:`AxesArray.remove_axis`


.. todo::

   Add developer documentation here.

The recommended way to refactor existing code to use AxesArrays is to add them
at the lowest level possible.  Enter debug mode and see how long the expected
axes persist throughout array operations.  When AxesArray loses track of the
correct axes, re-assign them with an AxesArray constructor (which only uses a
view of the data).

Starting at the macro level runs the risk of triggering a great many errors
from unimplemented functions.
"""
from __future__ import annotations

import copy
import warnings
from enum import Enum
from typing import Collection
from typing import Dict
from typing import get_args
from typing import List
from typing import Literal
from typing import NewType
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from typing import Union

import numpy as np
from numpy.typing import NDArray
from sklearn.base import TransformerMixin

# Registry of numpy functions overridden for AxesArray: maps a numpy function to
# its replacement.  Filled by the ``_implements`` decorator and consulted by
# ``AxesArray.__array_function__``.
HANDLED_FUNCTIONS = {}

# Warning category used when the number of labeled axes differs from the
# array's number of dimensions.
AxesWarning = type("AxesWarning", (SyntaxWarning,), {})
# Types allowed as a single basic index element.  A ``str`` element inserts a
# new length-one axis with that name (an AxesArray extension to numpy).
BasicIndexer = Union[slice, int, type(Ellipsis), None, str]
# Any single index element, including advanced (array or list) indexers.
Indexer = Union[BasicIndexer, NDArray, List]
# Index element type used after normalization by ``_standardize_indexer``.
StandardIndexer = Union[slice, int, None, NDArray[np.dtype(int)]]
OldIndex = NewType("OldIndex", int)  # Before moving advanced axes adjacent
KeyIndex = NewType("KeyIndex", int)  # Position within a standardized indexer
NewIndex = NewType("NewIndex", int)
T = TypeVar("T", bound=int)  # TODO: Bind to a non-sequence after type-negation PEP
# A single axis index, or a list of them when one name labels several axes.
ItemOrList = Union[T, List[T]]
# Axis name -> index or list of indexes, the format accepted by ``AxesArray``.
CompatDict = Dict[str, ItemOrList[T]]


class _Sentinels(Enum):
    """Unique placeholder values used while rewriting an indexing key.

    ``ADV_NAME`` marks a position where an axis produced by broadcasting the
    advanced indexers is inserted; ``ADV_REMOVE`` replaces each advanced
    indexer, marking an original axis that the indexing consumes.
    """

    ADV_NAME = object()
    ADV_REMOVE = object()


class _AxisMapping:
    """Convenience wrapper for a two-way map between axis names and indexes.

    Attributes:
        fwd_map: axis name -> sorted list of the axis indexes with that name.
        reverse_map: axis index -> axis name.  Only labeled indexes appear, so
            it can hold fewer entries than the array has dimensions (a case
            reported with ``AxesWarning`` at construction).
    """

    fwd_map: Dict[str, List[int]]
    reverse_map: Dict[int, str]

    def __init__(
        self,
        axes: dict[str, Union[int, Sequence[int]]],
        in_ndim: int,
    ):
        """Build both maps from a ``{name: index_or_indexes}`` dictionary.

        Arguments:
            axes: axis name to a single axis index or a sequence of indexes.
            in_ndim: number of dimensions of the labeled array.

        Raises:
            ValueError: if an index is labeled twice or is ``>= in_ndim``.
        """
        self.fwd_map = {}
        self.reverse_map = {}

        def coerce_sequence(obj):
            if isinstance(obj, Sequence):
                return sorted(obj)
            return [obj]

        for ax_name, ax_ids in axes.items():
            ax_ids = coerce_sequence(ax_ids)
            self.fwd_map[ax_name] = ax_ids
            for ax_id in ax_ids:
                old_name = self.reverse_map.get(ax_id)
                if old_name is not None:
                    raise ValueError(f"Assigned multiple definitions to axis {ax_id}")
                if ax_id >= in_ndim:
                    raise ValueError(
                        f"Assigned definition to axis {ax_id}, but array only has"
                        f" {in_ndim} axes"
                    )
                self.reverse_map[ax_id] = ax_name
        if len(self.reverse_map) != in_ndim:
            warnings.warn(
                f"{len(self.reverse_map)} axes labeled for array with {in_ndim} axes",
                AxesWarning,
            )

    @staticmethod
    def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]:
        """Like fwd_map, but unpack single-element axis lists"""
        axes = {}
        for k, v in in_dict.items():
            if len(v) == 1:
                axes[k] = v[0]
            else:
                axes[k] = v
        return axes

    @property
    def compat_axes(self):
        return self._compat_axes(self.fwd_map)

    def remove_axis(self, axis: Union[Collection[int], int, None] = None):
        """Create an axes dict from self with specified axis or axes
        removed and all greater axes decremented.  This can be passed to
        the constructor to create a new _AxisMapping

        Removing an index that carries no label is allowed: it only shifts
        the labeled axes above it.

        Arguments:
            axis: the axis index or axes indexes to remove.  By numpy
            ufunc convention, axis=None (default) removes _all_ axes.
            Negative indexes are resolved against ``self.ndim``.
        """
        if axis is None:
            return {}
        if not isinstance(axis, Collection):
            axis = [axis]
        removed = {ax_id if ax_id >= 0 else self.ndim + ax_id for ax_id in axis}
        new_axes: Dict[str, List[int]] = {}
        for ax_name, ax_ids in self.fwd_map.items():
            # Each surviving axis moves left once per removed axis below it.
            kept = [
                ax_id - sum(1 for rm_id in removed if rm_id < ax_id)
                for ax_id in ax_ids
                if ax_id not in removed
            ]
            if kept:  # names whose axes were all removed disappear
                new_axes[ax_name] = kept
        return self._compat_axes(new_axes)

    def insert_axis(self, axis: Union[Collection[int], int], new_name: str):
        """Create an axes dict from self with specified axis or axes
        added and all greater axes incremented.

        Arguments:
            axis: the axis index or axes indexes to add, given as positions
                in the resulting array.
            new_name: name of the inserted axis or axes.  If it already labels
                other axes, the new indexes are added to its list.
        """
        if not isinstance(axis, Collection):
            axis = [axis]
        inserts = sorted(axis)

        def shifted(ax_id: int) -> int:
            # Replay the insertions in ascending order: each one at or below
            # the (already shifted) position pushes this axis one to the right.
            for ins_id in inserts:
                if ins_id <= ax_id:
                    ax_id += 1
            return ax_id

        new_axes: Dict[str, List[int]] = {
            ax_name: [shifted(ax_id) for ax_id in ax_ids]
            for ax_name, ax_ids in self.fwd_map.items()
        }
        if inserts:
            new_axes[new_name] = sorted(new_axes.get(new_name, []) + inserts)
        return self._compat_axes(new_axes)

    @property
    def ndim(self):
        """Number of labeled axes."""
        return len(self.reverse_map)


class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
    """A numpy-like array that keeps track of the meaning of its axes.

    Limitations:

    * Not all numpy functions, such as ``np.flatten()``, have an
      implementation for ``AxesArray``.  In such cases a regular numpy array
      is returned.
    * For functions that are implemented for `AxesArray`, such as
      ``np.reshape()``, use the numpy function rather than the bound method
      (e.g. ``arr.reshape``)
    * Such functions may raise ``ValueError`` where numpy would not, when
      it is impossible to determine the output axis labels.

    Current array function implementations:

    * ``np.concatenate``
    * ``np.reshape``
    * ``np.transpose``
    * ``np.linalg.solve``
    * ``np.einsum``
    * ``np.tensordot``
    * ``np.ravel``
    * ``np.ix_``

    Indexing:
        AxesArray supports all of the basic and advanced indexing of numpy
        arrays, with the addition that new axes can be inserted with a string
        name for the axis.  E.g. ``arr = arr[..., "lineno"]`` (or simply
        ``arr["lineno"]`` for a leading axis) will add a length-one axis,
        along with the properties ``arr.ax_lineno`` and ``arr.n_lineno``.
        If ``None`` or ``np.newaxis`` are passed, the axis is named "unk".

    Parameters:
        input_array: the data to create the array.  Anything accepted by
            ``np.asarray``; the constructor only takes a view of the data.
        axes: A dictionary of axis labels to shape indices.  Axes labels must
            be of the format "ax_name".  indices can be either an int or a
            list of ints.

    Attributes:
        axes: dictionary of axis name to dimension index/indices
        ax_<ax_name>: lookup ax_name in axes
        n_<ax_name>: lookup shape of subarray defined by ax_name

    Raises:
        AxesWarning if axes does not match shape of input_array.
        ValueError if assigning the same axis index to multiple meanings or
        assigning an axis beyond ndim.
    """

    _ax_map: _AxisMapping

    def __new__(cls, input_array: NDArray, axes: CompatDict[int]):
        obj = np.asarray(input_array).view(cls)
        # Use the converted array's ndim so array-likes without a ``.shape``
        # attribute (e.g. nested lists) are accepted as well.
        obj._ax_map = _AxisMapping(axes, obj.ndim)
        return obj

    @property
    def axes(self):
        return self._ax_map.compat_axes

    @property
    def _reverse_map(self):
        return self._ax_map.reverse_map

    @property
    def shape(self):
        """Shape of array.  Unlike numpy ndarray, this is not assignable."""
        return super().shape

    def insert_axis(
        self, axis: Union[Collection[int], int], new_name: str
    ) -> CompatDict[int]:
        """Create the constructor axes dict from this array, with new axis/axes"""
        return self._ax_map.insert_axis(axis, new_name)

    def remove_axis(self, axis: Union[Collection[int], int]) -> CompatDict[int]:
        """Create the constructor axes dict from this array, without axis/axes"""
        return self._ax_map.remove_axis(axis)

    def __getattr__(self, name):
        """Resolve ``ax_<name>`` (axis indexes) and ``n_<name>`` (axis sizes).

        Only called when normal attribute lookup fails.  Always raises
        ``AttributeError`` for unknown names, so ``hasattr`` works.
        """
        # TODO: replace with structural pattern matching on Oct 2025 (3.9 EOL)
        # partition always yields three parts, even for names without "_"
        # such as "n", which previously raised IndexError.
        prefix, _, suffix = name.partition("_")
        if prefix == "ax":
            try:
                return self.axes[name]
            except KeyError:
                raise AttributeError(f"AxesArray has no axis '{name}'")
        if prefix == "n":
            try:
                ax_ids = self._ax_map.fwd_map["ax_" + suffix]
            except KeyError:
                raise AttributeError(f"AxesArray has no axis '{name}'")
            shape = tuple(self.shape[ax_id] for ax_id in ax_ids)
            if len(shape) == 1:
                return shape[0]
            return shape
        raise AttributeError(f"'{type(self)}' object has no attribute '{name}'")

    def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /):
        """Index like numpy, then recompute the axis labels of the result.

        String entries in ``key`` behave like ``None`` (insert a length-one
        axis) but give the new axis that name.
        """
        # numpy does not understand string axis names; index with None there
        if isinstance(key, tuple):
            base_indexer = tuple(None if isinstance(k, str) else k for k in key)
        elif isinstance(key, str):
            base_indexer = None
        else:
            base_indexer = key
        output = super().__getitem__(base_indexer)
        if not isinstance(output, AxesArray):
            return output  # return an element from the array
        in_dim = self.shape
        key, adv_inds = _standardize_indexer(self, key)
        bcast_nd, bcast_start_ax = _determine_adv_broadcasting(key, adv_inds)
        if adv_inds:
            key = _replace_adv_indexers(key, adv_inds, bcast_start_ax, bcast_nd)
        remove_axes, new_axes, adv_names = _apply_indexing(key, self._reverse_map)
        new_axes = _rename_broadcast_axes(new_axes, adv_names)
        new_map = _AxisMapping(
            self._ax_map.remove_axis(remove_axes), len(in_dim) - len(remove_axes)
        )
        for insert_counter, (new_ax_ind, new_ax_name) in enumerate(new_axes):
            new_map = _AxisMapping(
                new_map.insert_axis(new_ax_ind, new_ax_name),
                in_ndim=len(in_dim) - len(remove_axes) + (insert_counter + 1),
            )
        output._ax_map = new_map
        return output

    def __array_finalize__(self, obj) -> None:
        if obj is None:  # explicit construction via super().__new__()
            return
        # view from numpy array, called in constructor but also tests
        if all(
            (
                not isinstance(obj, AxesArray),
                self.shape == (),
                not hasattr(self, "_ax_map"),
            )
        ):
            self._ax_map = _AxisMapping({}, in_ndim=0)
        # required by ravel() and view() used in numpy testing.  Also for zeros_like...
        elif all(
            (
                isinstance(obj, AxesArray),
                hasattr(obj, "_ax_map"),
                not hasattr(self, "_ax_map"),
                self.shape == obj.shape,
            )
        ):
            self._ax_map = _AxisMapping(obj.axes, obj.ndim)
        # Using a poorly-initialized AxesArray
        # Occurs in MaskedArray.ravel, used in some plotting.  MaskedArray views
        # of AxesArray lose the axes attributes, and then the _ax_map attributes.
        # See numpy.ma.core:asanyarray
        elif all(
            (
                isinstance(obj, AxesArray),
                not hasattr(obj, "_ax_map"),
            )
        ):
            self._ax_map = _AxisMapping({"ax_unk": 0}, in_ndim=1)
        # maybe add errors for incompatible views?

    def __array_ufunc__(
        self, ufunc, method, *inputs, out=None, **kwargs
    ):  # this method is called whenever you use a ufunc
        """Run ``ufunc`` on plain ndarray views, then relabel the results.

        Reductions without ``keepdims`` drop the reduced axes' labels (all
        of them when ``axis=None``, returning the raw result); other methods
        reuse this array's labels.  Supplied ``out`` arrays are returned as-is.
        """
        args = []
        for input_ in inputs:
            if isinstance(input_, AxesArray):
                args.append(input_.view(np.ndarray))
            else:
                args.append(input_)

        outputs = out
        if outputs:
            out_args = []
            for output in outputs:
                if isinstance(output, AxesArray):
                    out_args.append(output.view(np.ndarray))
                else:
                    out_args.append(output)
            kwargs["out"] = tuple(out_args)
        else:
            outputs = (None,) * ufunc.nout

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented
        if method == "at":
            return
        if ufunc.nout == 1:
            results = (results,)
        # Truthiness (not ``is False``) so numpy booleans and 0 are honored.
        if method == "reduce" and not kwargs.get("keepdims", False):
            axes = None
            # kwargs only holds the arguments the caller gave; ufunc.reduce
            # defaults to axis=0.
            reduce_axis = kwargs.get("axis", 0)
            if reduce_axis is not None:
                axes = self._ax_map.remove_axis(axis=reduce_axis)
        else:
            axes = self.axes
        final_results = []
        for result, output in zip(results, outputs):
            if output is not None:
                final_results.append(output)
            elif axes is None:
                final_results.append(result)
            else:
                final_results.append(AxesArray(np.asarray(result), axes))
        results = tuple(final_results)
        return results[0] if len(results) == 1 else results

    def __array_function__(self, func, types, args, kwargs):
        """Dispatch to the AxesArray implementation registered for ``func``."""
        if func not in HANDLED_FUNCTIONS:
            return super(AxesArray, self).__array_function__(func, types, args, kwargs)
        if not all(issubclass(t, AxesArray) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)
def _implements(numpy_function):
    """Register an __array_function__ implementation for AxesArray objects.

    Used as a decorator: the decorated function is stored in
    ``HANDLED_FUNCTIONS`` under ``numpy_function`` and returned unchanged.
    """

    def register(implementation):
        HANDLED_FUNCTIONS[numpy_function] = implementation
        return implementation

    return register
@_implements(np.ravel)
def ravel(a, order="C"):
    """Flatten ``a`` into an AxesArray.

    A 1-D input keeps its axis labels; any other input is labeled "ax_unk".
    """
    flat = np.ravel(np.asarray(a), order=order)
    flat_axes = a.axes if len(a.shape) == 1 else {"ax_unk": 0}
    return AxesArray(flat, flat_axes)
@_implements(np.ix_)
def ix_(*args: AxesArray):
    """Open mesh from AxesArrays, naming each mesh axis after its input.

    ``np.ix_`` only accepts 1-D inputs, so the first axis name of each
    argument labels the corresponding axis of every output array.
    """
    meshes = np.ix_(*(np.asarray(arg) for arg in args))
    names = [[*arr.axes][0] for arr in args]
    mesh_axes = fwd_from_names(names)
    return tuple(AxesArray(mesh, mesh_axes) for mesh in meshes)
@_implements(np.concatenate)
def concatenate(arrays, axis=0, out=None, dtype=None, casting="same_kind"):
    """Join arrays along an existing axis, keeping their shared axis labels.

    Arguments match :func:`numpy.concatenate`.  With ``axis=None`` numpy
    flattens the inputs first; as in ``ravel``, the result keeps the labels
    only if every input is 1-D and is otherwise labeled "ax_unk".

    Raises:
        ValueError: if the AxesArray inputs are not all labeled identically.
    """
    parents = [np.asarray(obj) for obj in arrays]
    ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)]
    for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]):
        if ax1 != ax2:
            raise ValueError("Concatenating >1 AxesArray with incompatible axes")
    # Give numpy a plain view of ``out`` so this call is not dispatched back
    # through AxesArray.__array_function__; writes still land in ``out``.
    plain_out = np.asarray(out) if isinstance(out, AxesArray) else out
    result = np.concatenate(
        parents, axis, out=plain_out, dtype=dtype, casting=casting
    )
    if axis is None and any(parent.ndim != 1 for parent in parents):
        axes = {"ax_unk": 0}
    elif ax_list:
        axes = ax_list[0]
    else:  # only ``out`` carried axis labels
        axes = out.axes
    if isinstance(out, AxesArray):
        out._ax_map = _AxisMapping(axes, in_ndim=result.ndim)
    return AxesArray(result, axes=axes)
@_implements(np.reshape)
def reshape(
    a: AxesArray,
    newshape: int | tuple[int] | None = None,
    order="C",
    *,
    shape: int | tuple[int] | None = None,
):
    """Gives a new shape to an array without changing its data.

    Args:
        a: Array to be reshaped
        newshape: int or tuple of ints The new shape should be compatible
            with the original shape.  In addition, the axis labels must make
            sense when the data is translated to a new shape.  Currently, the
            only use case supported is to flatten an outer product of two or
            more axes with the same label and size.
        order: Must be "C"
        shape: keyword alias of ``newshape`` (the spelling numpy >= 2.1
            uses).  Give at most one of the two.

    Raises:
        TypeError: if both or neither of ``newshape`` and ``shape`` are given.
        ValueError: if ``order`` is not "C", if numpy cannot reshape the
            data, or if the axis labels of the result cannot be determined.
    """
    if newshape is not None and shape is not None:
        raise TypeError("reshape() got both 'newshape' and 'shape'")
    if newshape is None:
        newshape = shape
    if newshape is None:
        raise TypeError("reshape() missing required argument 'shape'")
    if order != "C":
        raise ValueError("AxesArray only supports reshaping in 'C' order currently.")
    out = np.reshape(np.asarray(a), newshape, order)  # handle any regular errors
    # Read the target shape back from numpy so that -1 entries, a bare int,
    # and numpy integer/array shapes are resolved exactly as numpy did.
    new_shape = out.shape
    reverse_map = a._ax_map.reverse_map
    new_axes = {}
    curr_base = 0
    for curr_new, new_len in enumerate(new_shape):
        if curr_base >= a.ndim:
            raise ValueError(
                "Cannot reshape an AxesArray this way. Adding a length-1 axis at"
                f" dimension {curr_new} not understood."
            )
        base_name = reverse_map[curr_base]
        if a.shape[curr_base] == new_len:
            compat_dict_append(new_axes, base_name, curr_new)
            curr_base += 1
        elif new_len == 1:
            raise ValueError(
                f"Cannot reshape an AxesArray this way. Inserting a new axis at"
                f" dimension {curr_new} of new shape is not supported"
            )
        else:  # outer product of consecutive axes that share base_name
            remaining = new_len
            while remaining > 1:
                if reverse_map[curr_base] != base_name:
                    raise ValueError(
                        "Cannot reshape an AxesArray this way. It would combine"
                        f" {base_name} with {reverse_map[curr_base]}"
                    )
                remaining, error = divmod(remaining, a.shape[curr_base])
                if error:
                    raise ValueError(
                        f"Cannot reshape an AxesArray this way. Array dimension"
                        f" {curr_base} has size {a.shape[curr_base]}, must divide into"
                        f" newshape dimension {curr_new} with size"
                        f" {new_len}."
                    )
                curr_base += 1
            compat_dict_append(new_axes, base_name, curr_new)
    return AxesArray(out, axes=new_axes)
@_implements(np.transpose)
def transpose(a: AxesArray, axes: Optional[Union[Tuple[int], List[int]]] = None):
    """Returns an array with axes transposed.

    Args:
        a: input array
        axes: As the numpy function, including negative entries that count
            from the last axis.  None reverses the axes.
    """
    out = np.transpose(np.asarray(a), axes)  # numpy validates ``axes``
    if axes is None:
        axes = range(a.ndim)[::-1]
    new_axes = {}
    old_reverse = a._ax_map.reverse_map
    for new_ind, old_ind in enumerate(axes):
        # Labels are keyed by non-negative index; normalize e.g. -1 -> ndim-1.
        compat_dict_append(new_axes, old_reverse[old_ind % a.ndim], new_ind)
    return AxesArray(out, new_axes)
@_implements(np.einsum)
def einsum(
    subscripts: str, *operands: AxesArray, out: Optional[NDArray] = None, **kwargs
) -> AxesArray:
    """Einstein summation that labels the output axes from the operands.

    Each output letter is named ``"ax_" + ...`` joining the distinct names
    (without their ``ax_`` prefixes) of the operand axes it indexes.
    Ellipsis (``...``) dimensions are right-aligned across operands, as numpy
    broadcasts them; a dimension whose operands agree on a name keeps it,
    otherwise the distinct names are joined the same way.

    Returns:
        An AxesArray, or the raw result (or ``out``) for 0-dimensional output.
    """
    # A plain view of ``out`` avoids re-dispatching through __array_function__.
    plain_out = np.asarray(out) if isinstance(out, AxesArray) else out
    calc = np.einsum(
        subscripts, *(np.asarray(arr) for arr in operands), out=plain_out, **kwargs
    )
    try:  # explicit mode
        lscripts, rscript = subscripts.split("->")
    except ValueError:  # implicit mode
        lscripts = subscripts
        # numpy's implicit output: broadcast dims first, then every label used
        # exactly once, in sorted order.
        rscript = "".join(
            sorted(c for c in set(subscripts) if subscripts.count(c) == 1 and c != ",")
        )
        if "..." in subscripts:
            rscript = "..." + rscript
    # 0-dimensional case:
    if rscript == "" or np.ndim(calc) == 0:
        return calc if out is None else out

    # assemble output reverse map
    allscript_names = _label_einsum_scripts(lscripts, operands)
    ell_names = [script_names.get("...", []) for script_names in allscript_names]
    ell_nd = max((len(names) for names in ell_names), default=0)
    out_names = []
    for char in rscript.replace("...", "."):
        if char == ".":
            for bcast_ax in range(ell_nd):
                # operands with fewer ellipsis dims align to the right
                src_names = [
                    names[bcast_ax - (ell_nd - len(names))]
                    for names in ell_names
                    if bcast_ax >= ell_nd - len(names)
                ]
                if len(set(src_names)) == 1:
                    out_names.append(src_names[0])
                else:
                    out_names.append("ax_" + _join_unique_names(src_names))
        else:
            ax_names = []
            for script_names in allscript_names:
                ax_names += script_names.get(char, [])
            ax_name = "ax_" + _join_unique_names(ax_names)
            out_names.append(ax_name)

    out_axes = fwd_from_names(out_names)
    if isinstance(out, AxesArray):
        out._ax_map = _AxisMapping(out_axes, calc.ndim)
    return AxesArray(calc, axes=out_axes)
def _join_unique_names(l_of_s: List[str]) -> str:
    """Join names in first-seen order without repeats, dropping "ax_" prefixes.

    E.g. ``["ax_a", "ax_b", "ax_a"]`` -> ``"a_b"``.
    """
    ordered_uniques = dict.fromkeys(l_of_s)
    return "_".join(
        ax_name[3:] if ax_name[:3] == "ax_" else ax_name for ax_name in ordered_uniques
    )


def _label_einsum_scripts(
    lscripts: str, operands: tuple[AxesArray, ...]
) -> List[Dict[str, List[str]]]:
    """Create a list of what axis name each script refers to in its operand.

    Arguments:
        lscripts: the comma-separated input subscripts of an einsum call.
        operands: the einsum operands, one per comma-separated subscript.

    Returns:
        One dict per operand mapping each subscript character to the list of
        axis names it indexes in that operand (several if the character
        repeats).  When the subscript contains "...", key ``"..."`` maps to
        the names of the axes the ellipsis covers.
    """
    allscript_names: List[Dict[str, List[str]]] = []
    for lscr, op in zip(lscripts.split(","), operands):
        script_names: Dict[str, List[str]] = {}
        allscript_names.append(script_names)
        reverse_map = op._ax_map.reverse_map
        # handle script ellipses
        ell_ind = lscr.find("...")
        if ell_ind == -1:
            ell_ind = len(lscr)
            ell_width = 0
        else:
            ell_width = op.ndim - (len(lscr) - 3)
            script_names["..."] = [
                reverse_map[ax_ind] for ax_ind in range(ell_ind, ell_ind + ell_width)
            ]
        # handle script non-ellipsis chars
        for ax_ind, char in enumerate(lscr):
            if char == ".":
                continue
            if ax_ind < ell_ind:
                scr_name = reverse_map[ax_ind]
            else:
                # past "...": skip its 3 characters, add the dims it spans
                scr_name = reverse_map[ax_ind - 3 + ell_width]
            script_names.setdefault(char, []).append(scr_name)
    return allscript_names
@_implements(np.linalg.solve)
def linalg_solve(a: AxesArray, b: AxesArray) -> AxesArray:
    """Solve ``a @ x = b`` and label ``x`` from the operands' axis names.

    ``x`` takes ``b``'s axis names, except that the last ``b`` axis aligned
    with ``a``'s leading axes is renamed after ``a``'s last axis.

    Raises:
        ValueError: if ``a``'s leading axis names do not line up with ``b``'s.
    """
    solution = np.linalg.solve(np.asarray(a), np.asarray(b))

    def ordered_names(arr):
        rev = arr._ax_map.reverse_map
        return [rev[idx] for idx in sorted(rev)]

    a_names = ordered_names(a)
    a_leading, a_last = a_names[:-1], a_names[-1]
    b_names = ordered_names(b)
    lo = max(b.ndim - a.ndim, 0)
    hi = lo + len(a_leading)
    if a_leading != b_names[lo:hi]:
        raise ValueError("Mismatch in operand axis names when aligning A and b")
    x_names = [*b_names[: hi - 1], a_last, *b_names[hi:]]
    return AxesArray(solution, fwd_from_names(x_names))
@_implements(np.tensordot)
def tensordot(
    a: AxesArray, b: AxesArray, axes: Union[int, Sequence[Sequence[int]]] = 2
) -> AxesArray:
    """Tensor dot product, computed by translating to the equivalent einsum."""
    subscripts = _tensordot_to_einsum(a.ndim, b.ndim, axes)
    return einsum(subscripts, a, b)
def _tensordot_to_einsum(
    a_ndim: int, b_ndim: int, axes: Union[int, Sequence[Sequence[int]]]
) -> str:
    """Build einsum subscripts equivalent to ``np.tensordot(a, b, axes)``.

    ``axes`` follows numpy: an int ``N`` contracts the last ``N`` axes of
    ``a`` with the first ``N`` of ``b`` (``N=0`` is the outer product);
    otherwise a pair whose first element lists axes of ``a`` and second
    lists axes of ``b`` (each element may be a single int).  The subscripts
    use implicit mode, so the free axes of ``a`` come first, then those of
    ``b``, each in their original order.

    Raises:
        ValueError: if the operands need more than 26 subscript letters, or
            the two axis lists have different lengths.
    """
    if a_ndim + b_ndim > 26:
        raise ValueError("Too many axes")
    letters = [chr(code) for code in range(97, 97 + a_ndim + b_ndim)]
    sub_a = "".join(letters[:a_ndim])
    sub_b_li = letters[a_ndim:]
    try:
        iter(axes)
    except TypeError:  # an int N
        axes_a, axes_b = list(range(-axes, 0)), list(range(0, axes))
    else:
        axes_a, axes_b = axes
        # numpy accepts a bare int in place of a one-element list of axes
        axes_a = [axes_a] if np.ndim(axes_a) == 0 else list(axes_a)
        axes_b = [axes_b] if np.ndim(axes_b) == 0 else list(axes_b)
    if len(axes_a) != len(axes_b):
        raise ValueError("shape-mismatch for sum")
    for a_ind, b_ind in zip(axes_a, axes_b):
        sub_b_li[b_ind] = sub_a[a_ind]
    sub_b = "".join(sub_b_li)
    sub = f"{sub_a},{sub_b}"
    return sub


def _standardize_indexer(
    arr: np.ndarray, key: Indexer | Sequence[Indexer]
) -> tuple[List[StandardIndexer], tuple[KeyIndex, ...]]:
    """Convert any legal numpy indexer to a "standard" form.

    Standard form is an equivalent indexer list with one entry per axis of
    ``arr``, plus any ``None``/string entries that insert new axes; the
    ellipsis is expanded into slices.  All advanced indexer elements are
    converted to numpy arrays, and boolean arrays are converted to integer
    arrays with obj.nonzero().

    Returns:
        A tuple of the normalized indexer as well as the indexes of
        advanced indexers
    """
    if isinstance(key, tuple):
        key = list(key)
    else:
        key = [key]
    if not any(ax_key is Ellipsis for ax_key in key):
        key = [*key, Ellipsis]

    new_key: List[Indexer] = []
    for ax_key in key:
        if not isinstance(ax_key, get_args(BasicIndexer)):
            ax_key = np.array(ax_key)
            if ax_key.dtype == np.dtype(np.bool_):
                new_key += ax_key.nonzero()
                continue
        new_key.append(ax_key)

    new_key = _expand_indexer_ellipsis(new_key, arr.ndim)
    # Can't identify position of advanced indexers before expanding ellipses
    adv_inds: List[KeyIndex] = []
    for key_ind, ax_key in enumerate(new_key):
        if isinstance(ax_key, np.ndarray):
            adv_inds.append(KeyIndex(key_ind))
    return new_key, tuple(adv_inds)


def _expand_indexer_ellipsis(key: List[Indexer], ndim: int) -> List[Indexer]:
    """Replace ellipsis in indexers with the appropriate amount of slice(None)"""
    # [...].index errors if list contains numpy array
    ellind = [ind for ind, val in enumerate(key) if val is ...][0]
    n_new_dims = sum(ax_key is None or isinstance(ax_key, str) for ax_key in key)
    n_ellipsis_dims = ndim - (len(key) - n_new_dims - 1)
    new_key = key[:ellind] + key[ellind + 1 :]
    new_key = new_key[:ellind] + (n_ellipsis_dims * [slice(None)]) + new_key[ellind:]
    return new_key


def _determine_adv_broadcasting(
    key: Sequence[StandardIndexer], adv_inds: Sequence[OldIndex]
) -> tuple[int, Optional[KeyIndex]]:
    """Calculate the shape and location for the result of advanced indexing.

    Returns:
        The number of broadcast dimensions, and the key position where they
        are placed: the first advanced index if all advanced indexes are
        adjacent, else 0 (None when there are no advanced indexes).
    """
    adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
    adv_indexers = [np.array(key[i]) for i in adv_inds]
    # ``ndim`` is the preferred spelling of the older ``nd`` attribute
    bcast_nd = np.broadcast(*adv_indexers).ndim
    bcast_start_axis = 0 if not adjacent else min(adv_inds) if adv_inds else None
    return bcast_nd, KeyIndex(bcast_start_axis)


def _rename_broadcast_axes(
    new_axes: List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]],
    adv_names: List[str],
) -> List[tuple[int, str]]:
    """Normalize sentinel and NoneType names"""

    def _calc_bcast_name(*names: str) -> str:
        if not names:
            return ""
        if all(a == b for a, b in zip(names[1:], names[:-1])):
            return names[0]
        names = [name[3:] for name in dict.fromkeys(names)]  # ordered deduplication
        return "ax_" + "_".join(names)

    bcast_name = _calc_bcast_name(*adv_names)
    renamed_axes = []
    for ax_ind, ax_name in new_axes:
        if ax_name is None:
            renamed_axes.append((ax_ind, "ax_unk"))
        elif ax_name is _Sentinels.ADV_NAME:
            renamed_axes.append((ax_ind, bcast_name))
        else:
            renamed_axes.append((ax_ind, "ax_" + ax_name))
    return renamed_axes


def _replace_adv_indexers(
    key: List[StandardIndexer],
    adv_inds: List[int],
    bcast_start_ax: int,
    bcast_nd: int,
) -> List[
    Union[None, str, int, Literal[_Sentinels.ADV_NAME], Literal[_Sentinels.ADV_REMOVE]]
]:
    """Swap advanced indexers for sentinels marking removed and added axes.

    Each advanced indexer in ``key`` is replaced in place by ``ADV_REMOVE``,
    then ``bcast_nd`` ``ADV_NAME`` entries are inserted at ``bcast_start_ax``.
    Returns the new list.
    """
    for adv_ind in adv_inds:
        key[adv_ind] = _Sentinels.ADV_REMOVE
    key = key[:bcast_start_ax] + bcast_nd * [_Sentinels.ADV_NAME] + key[bcast_start_ax:]
    return key


def _apply_indexing(
    key: tuple[StandardIndexer], reverse_map: Dict[int, str]
) -> tuple[
    List[int], List[tuple[int, None | str | Literal[_Sentinels.ADV_NAME]]], List[str]
]:
    """Determine where axes should be removed and added

    Only considers the basic indexers in key.  Numpy arrays are treated as
    slices, in that they don't affect the final dimensions of the output

    Returns:
        Original axis indexes to remove, ``(new index, name)`` pairs of axes to
        add, and the names of the axes consumed by advanced indexers.
    """
    remove_axes = []
    new_axes = []
    adv_names = []
    deleted_to_left = 0
    added_to_left = 0
    for key_ind, indexer in enumerate(key):
        if isinstance(indexer, int) or indexer is _Sentinels.ADV_REMOVE:
            orig_arr_axis = key_ind - added_to_left
            if indexer is _Sentinels.ADV_REMOVE:
                adv_names.append(reverse_map[orig_arr_axis])
            remove_axes.append(orig_arr_axis)
            deleted_to_left += 1
        elif (
            indexer is None
            or indexer is _Sentinels.ADV_NAME
            or isinstance(indexer, str)
        ):
            new_arr_axis = key_ind - deleted_to_left
            new_axes.append((new_arr_axis, indexer))
            added_to_left += 1
    return remove_axes, new_axes, adv_names
def comprehend_axes(x):
    """Guess default axis labels from an array's shape.

    The last axis is "ax_coord", the one before it "ax_time", and when
    ``x.ndim > 2`` every remaining leading axis is listed under "ax_spatial".
    """
    last = len(x.shape) - 1
    axes = {"ax_coord": last, "ax_time": last - 1}
    if x.ndim > 2:
        axes["ax_spatial"] = list(range(last - 1))
    return axes
class SampleConcatter(TransformerMixin):
    """Transformer that flattens sample axes and concatenates trajectories.

    Stateless: ``fit`` learns nothing and ``transform`` delegates to
    ``concat_sample_axis``.
    """

    def __init__(self):
        pass

    def fit(self, x_list, y_list=None):
        """No-op fit; returns self.

        ``y_list`` is optional so ``fit_transform(x_list)`` works: sklearn's
        ``TransformerMixin.fit_transform`` calls ``fit(X)`` without ``y`` when
        ``y`` is None.
        """
        return self

    def __sklearn_is_fitted__(self):
        return True

    def transform(self, x_list):
        """Flatten each array's sample axes and concatenate the results."""
        return concat_sample_axis(x_list)
def concat_sample_axis(x_list: List[AxesArray]):
    """Concatenate all trajectories and axes used to create samples.

    For each array, the axes labeled "ax_spatial", "ax_time" and "ax_sample"
    (whichever exist) are flattened into one "ax_sample" axis followed by
    "ax_coord"; the results are concatenated along that sample axis.

    NOTE(review): the flattening is a C-order reshape to
    ``(n_samples, n_coord)``, which presumes the sample axes precede
    ``ax_coord`` — confirm against callers.
    """
    sample_ax_names = ("ax_spatial", "ax_time", "ax_sample")
    new_axes = {"ax_sample": 0, "ax_coord": 1}
    new_arrs = []
    for x in x_list:
        sample_ax_inds = []
        for name in sample_ax_names:
            ax_inds = getattr(x, name, [])
            if isinstance(ax_inds, int):
                ax_inds = [ax_inds]
            sample_ax_inds += ax_inds
        # int(): np.prod([]) is the float 1.0, which reshape rejects.
        n_samples = int(np.prod([x.shape[ax] for ax in sample_ax_inds]))
        arr = AxesArray(x.reshape((n_samples, x.shape[x.ax_coord])), new_axes)
        new_arrs.append(arr)
    return np.concatenate(new_arrs, axis=new_arrs[0].ax_sample)
def wrap_axes(axes: dict, obj):
    """Add axes to object (usually, a sparse matrix)

    Each of "ax_spatial", "ax_time", "ax_sample" and "ax_coord" present in
    ``axes`` is set as an attribute of ``obj``; missing keys are skipped.
    Returns ``obj`` itself.
    """
    for ax_key in ("ax_spatial", "ax_time", "ax_sample", "ax_coord"):
        try:
            setattr(obj, ax_key, axes[ax_key])
        except KeyError:
            continue
    return obj
def compat_dict_append(
    compat_dict: CompatDict[T],
    key: str,
    item_or_list: ItemOrList[T],
) -> None:
    """Add an element or list of elements to ``compat_dict[key]`` in place.

    A missing key is set to ``item_or_list`` exactly as given.  Otherwise the
    old and new values are each promoted to lists and concatenated, old first.
    """
    try:
        existing = compat_dict[key]
    except KeyError:
        compat_dict[key] = item_or_list
        return
    new_items = item_or_list if isinstance(item_or_list, list) else [item_or_list]
    old_items = existing if isinstance(existing, list) else [existing]
    compat_dict[key] = old_items + new_items
def fwd_from_names(names: List[str]) -> CompatDict[int]:
    """Create mapping of name: [ax_1, ax_2, ...] from a list of axis names.

    Every value is the list of positions at which the name occurs in
    ``names``, e.g. ``["ax_a", "ax_b", "ax_a"]`` gives
    ``{"ax_a": [0, 2], "ax_b": [1]}``.
    """
    fwd_map: Dict[str, List[int]] = {}
    for ax_ind, name in enumerate(names):
        fwd_map.setdefault(name, []).append(ax_ind)
    return fwd_map