Source code for steps.API_2.saving

####################################################################################
#
#    STEPS - STochastic Engine for Pathway Simulation
#    Copyright (C) 2007-2023 Okinawa Institute of Science and Technology, Japan.
#    Copyright (C) 2003-2006 University of Antwerp, Belgium.
#    
#    See the file AUTHORS for details.
#    This file is part of STEPS.
#    
#    STEPS is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License version 3,
#    as published by the Free Software Foundation.
#    
#    STEPS is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#    GNU General Public License for more details.
#    
#    You should have received a copy of the GNU General Public License
#    along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#################################################################################   
###

import collections
import copy
import datetime
import enum
import functools
import itertools
import math
import numbers
import operator
import os
import pickle
import re
import sqlite3
import struct
import sys
import warnings

from xml.etree import ElementTree

import numpy
import steps

from . import _saving_optim as nsaving_optim
from . import geom as ngeom
from . import model as nmodel
from . import sim as nsim
from . import utils as nutils

__all__ = [
    'ResultSelector',
    'CustomResults',
    'DatabaseHandler',
    'DatabaseGroup',
    'SQLiteDBHandler',
    'SQLiteGroup',
    'HDF5Handler',
    'HDF5Group',
    'XDMFHandler',
]

###################################################################################################
# Exceptions


class ReadOnlyWriteError(Exception):
    """
    :meta private:
    """
    pass


class UnavailableDataError(Exception):
    """
    :meta private:
    """
    pass

###################################################################################################
# Result selectors


class _MetaData(nutils.MutableDictInterface):
    """
    Small utility class for handling metadata setting and getting through __getitem__ and
    __setitem__. Behaves like a dict but performs additional checks when setting values.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent
        self._dict = {}

    def _checkKey(self, key):
        """Check that the given key is valid."""
        if not isinstance(key, str):
            raise TypeError(
                f'MetaData can only be accessed by first specifying a string key, '
                f'got a {type(key)} instead.'
            )

    def _clear(self):
        self._dict = {}

    def __getitem__(self, key):
        """Return the metadata corresponding to key."""
        self._checkKey(key)
        if key not in self._dict:
            raise KeyError(f'No metadata corresponding to key {key}.')
        return nutils.nparray(self._dict[key])

    def __setitem__(self, key, val, _internal=False):
        """Set the metadata corresponding to key."""
        self._checkKey(key)
        if self._parent._savingStarted() and not _internal:
            raise Exception(f'Cannot save metadata once sim.newRun() has been called.')

        # Convert to list in case val is a generator
        lst = list(val)

        if len(lst) != self._parent._getEvalLen():
            raise Exception(
                f'Expected a list of length {self._parent._getEvalLen()}, got a list '
                f'of length {len(lst)}.'
            )
        if not all(isinstance(v, (numbers.Number, str)) or v is None for v in lst):
            raise TypeError(f'Metadata can only be composed of numbers and / or strings.')

        if key in self._dict and self._dict[key] != lst and not _internal:
            warnings.warn(
                f'The metadata associated with key {key} was already set, replacing with new values.'
            )

        self._dict[key] = lst

    def keys(self):
        return self._dict.keys()


class _LabelSelector:
    """Utility class for pointing to a specific label of a ResultSelector"""
    def __init__(self, sel, ind):
        self.sel = sel
        self.ind = ind


[docs]class ResultSelector: """Class to describe which data should be saved during simulation :param sim: The simulation for which we want to select data :type sim: :py:class:`steps.API_2.sim.Simulation` This class works in a way that is very similar to :py:class:`steps.API_2.sim.SimPath`, paths to the data that should be saved are built in the same way, using dot syntax. For :py:class:`steps.API_2.sim.SimPath`, the root of the path is the simulation itself and when a path is completed with a property (e.g. `Count`), it returns the actual value in the simulation. Since :py:class:`ResultSelector` aims at describing the data to be saved, we have to use a different root for our paths:: >>> sim.comp1.S1.Count 13 >>> rs = ResultSelector(sim) >>> rs.comp1.S1.Count comp1.S1.Count While the path whose root is the actual simulation returns a number, the path whose root is the result selector object does not. Any methods defined in :py:class:`steps.API_2.sim.SimPath` can be used to build result selector paths. In addition, result selectors can be combined using standard arithmetic operators (see :py:func:`ResultSelector.__add__`, etc.) and can be concatenated with ``<<`` (see :py:func:`ResultSelector.__lshift__`):: rs1 = rs.comp1.S1.Count + rs.comp1.S2.Count # This result selector will save a single # value that corresponds to the sum of S1 # and S2 counts in comp1. rs2 = rs.comp1.S1.Count << rs.comp1.S2.Count # This one will save 2 values, the count of # S1 in comp1 and the count of S2 in comp1. Result selectors can also transform data and only save the result of the transformation:: rs3 = rs.SUM(rs.TETS(tetlst).S1.Count) # This will save only one value: the total number # of S1 in all the tetrahedrons in tetlst. Once we defined all our result selectors, we need to add them to the :py:class:`steps.API_2.sim.Simulation` so that the corresponding data gets saved during simulation runs. This is done with e.g.:: sim.toSave(rs1, rs2, rs3, dt=0.01) # Save the three result selectors every 0.01 seconds. After simulations have been run, results can be accessed with the same result selector objects:: rs1.data[0] # Accessing the data saved during run 0 rs1.time[0] # The time points associated to each saving for run 0 Usage of result selectors is presented in more details in the user guide. """ def __init__(self, sim, *args, **kwargs): super().__init__(*args, **kwargs) if not isinstance(sim, nsim.Simulation): raise TypeError(f'Expected a Simulation object, got {sim} instead.') self.sim = sim self._dataHandler = _MemoryDataHandler(self) self._saveDt = None self._saveTpnts = None self._saveTind = None self._nextTime = math.inf self._addedToSim = False self._selectorInd = None self._optimGroupInd = None self._labels = None self._metaData = _MetaData(self) self._description = None # When the result selector gets distributed across MPI rank, it keeps a list of the indexes # of its values in the original result selector self._distrInds = None self._fullLen = None
[docs] @classmethod def FromFile(cls, path): """Load data that has been saved to a file :param path: Path to the file :type path: str Result selectors that have been saved to file (with the :py:func:`toFile` method), can then be loaded in a different python process and be used in the same way as in the simulaiton process. Usage:: rs1 = ResultSelector.FromFile('path/to/file') plt.plot(rs1.time[0], rs1.data[0]) plt.legend(rs1.labels) ... """ return _ReadOnlyResultSelector(_FileDataHandler(None, path))
@property def time(self): """Get the time points at which data saving was done for this result selector An accessor to the timepoints data that should then be indexed with square brackets notation. The underlying data it two dimensional; the first dimension corresponds to runs and the second to time. :type: Data accessor, read-only Usage assuming 5 runs of 1s with data saving every 10ms:: >>> rs1.time[0] # Time points of first run array([0., 0.01, 0.02, ..., 0.98, 0.99, 1.]) >>> rs1.time[0, -1] # Last time point of first run array(1.) >>> rs1.time[0][-1] # Same as above array(1.) >>> rs1.time[:, -1] # Last time point of all 5 runs array([1., 1., 1., 1., 1.]) >>> rs1.time[1:3, 0] # First time point of 2nd and 3rd runs array([0, 0]) >>> rs1.time[...] # All time points of all runs array([[0., 0.01, 0.02, ..., 0.98, 0.99, 1.], [0., 0.01, 0.02, ..., 0.98, 0.99, 1.], [0., 0.01, 0.02, ..., 0.98, 0.99, 1.], [0., 0.01, 0.02, ..., 0.98, 0.99, 1.], [0., 0.01, 0.02, ..., 0.98, 0.99, 1.]]) .. warning:: Although the type of this property implements square bracket element access, it is not a list or an array itself and does not directly contain the data. The data is only really accessed upon using the square bracket notation. To force the retrieval of all the data, it is possible to use the ellipsis notation in square brackets: ``rs.time[...]``. """ self._checkAddedToSim() return self._dataHandler.time() @property def data(self): """Get the data that was saved by this result selector An accessor to the data that should then be indexed with square brackets notation The underlying data it three dimensional; the first dimension corresponds to runs, the second to time, and the third to saved paths. :type: Data accessor, read-only Usage assuming 5 runs of 3s, saving 3 values every 1 ms:: >>> rs1.data[0] # Data from the first run array([[312., 221., 0.], [310., 219., 2.], [308., 217., 4.], ... [206., 115., 106.], [205., 114., 107.], [205., 114., 107.]]) >>> rs1.data[0, -1] # Data corresponding to the last time point of first run array([205., 114., 107.]) >>> rs1.data[0][-1] # Same as above array([205., 114., 107.]) >>> rs1.data[:, -1] # Data corresponding to the last time point of all 5 runs array([[205., 114., 107.], [189., 98., 123.], [188., 97., 124.], [185., 95., 127.], [198., 107., 114.]]) >>> rs1.data[0, :, 0] # First saved value for all time points of first run array([312., 310, 308, ..., 206, 205, 205]) >>> rs1.data[...] # All data from all runs array([[[312., 221., 0.], [310., 219., 2.], [308., 217., 4.], ..., [206., 115., 106.], [205., 114., 107.], [205., 114., 107.]], ... [[312., 221., 0.], [309., 218., 3.], [305., 214., 7.], ..., [199., 108., 113.], [199., 108., 113.], [198., 107., 114.]]]) .. warning:: Although the type of this property implements square bracket element access, it is not a list or an array itself and does not directly contain the data. The data is only really accessed upon using the square bracket notation. To force the retrieval of all the data, it is possible to use the ellipsis notation in square brackets: ``rs.data[...]``. """ self._checkAddedToSim() return self._dataHandler.data() @property def labels(self): """A list of descriptions of the values saved by the result selector :type: List[str] By default labels are automatically generated from the result selector. Assuming 3 saved values, one can access their values with:: >>> rs1.labels # Default values, built from the simulation paths used for saving ['comp.molA.Count', 'comp.molB.Count', 'comp.molC.Count'] The labels can also be set by the user but it needs to be done before :py:func:`steps.API_2.sim.Simulation.newRun` has been called. Assuming 3 saved value, one would write:: >>> rs1.labels = ['custom1', 'custom2', 'custom3'] Labels be saved to whichever support the result selector is being saved to (memory, file, database, etc.). """ return self._labels @labels.setter def labels(self, lbls): """Set custom labels.""" lbls = list(lbls) if len(lbls) != self._getEvalLen(): raise Exception( f'Expected a list of length {self._getEvalLen()}, got a list of length {len(lbls)}.' ) if self._dataHandler._savingStarted(): raise Exception(f'Cannot modify the labels once sim.newRun() has been called.') self._labels = lbls @property def metaData(self): """Meta data relative to the values saved by the result selector :type: Mapping[str, List[Union[str, int, float, None]]] This property allows the user to save additional static (i.e. not time-dependent) data about the values being saved by the result selector. It works as a mapping between arbitrary string keys and lists of values that have the same length as the number of values saved by the result selector. The meta data needs to be set before :py:func:`steps.API_2.sim.Simulation.newRun` has been called. Assuming 3 values saved, one could write:: >>> rs1.metaData['key1'] = ['str1', 'str2', 'str3'] >>> rs1.metaData['key2'] = [1, 2, 3] >>> rs1.metaData['key1'] array(['str1', 'str2', 'str3'], dtype='<U4') >>> 'key2' in rs1.metaData True >>> 'key3' in rs1.metaData False Like labels, meta data will be saved to whichever support the result selector is being saved to (memory, file, database, etc.). .. note:: Some path elements automatically define their own meta data, one can always check which meta data is already declared with e.g. ``print(rs1.metaData.keys())`` .. warning:: Although the type of this property implements square bracket key access, it is not a dict itself and does not directly contain the data. The data is only really accessed upon using the square bracket notation. However, it does implement ``keys()``, ``items()``, ``__iter__()`` and ``__contains__()`` so it can be used like a dict to some extent. """ return self._metaData @property def description(self): """String description of the result selector :type: str All results selectors have a default string description generated by STEPS. It can be modified by setting this property and the changes will be saved to whichever support the result selector is being saved to (memory, file, database, etc.). """ return self._description @description.setter def description(self, descr): """Set custom description""" if not isinstance(descr, str): raise TypeError(f'Expected a string as description, got {descr} instead.') self._description = descr
[docs] def toFile(self, path, buffering=-1): """Specify that the data should be saved to a file :param path: The path to the file :type path: str :param buffering: The buffering parameter passed to the ``open()`` function, see https://docs.python.org/3/library/functions.html#open for details :type buffering: int This method should be called before :py:func:`steps.API_2.sim.Simulation.newRun` has been called. The file is written in a custom binary format and can be read in a different python process by creating a result selector from file with :py:func:`ResultSelector.FromFile`. .. warning:: After all simulations are finished, depending on the buffering policy, it is possible that the file does not contain all the data. The data will be flushed to the file upon destruction of the result selector (when the python process ends for example). This should not create any issues for using the result selector in the process in which it was created (because the data that might not be written to file is kept in memory) but it could create issues when trying to load the file from another python process while the first one is still running. """ self._checkComplete() self._dataHandler = _FileDataHandler(self, path, self._getEvalLen(), buffering)
def _newRun(self): """Signal that a new run of the simulation started.""" self._dataHandler._newRun() # Initialize time save points if self._saveDt is not None: self._saveTind = 0 self._nextTime = 0 elif self._saveTpnts is not None and len(self._saveTpnts) > 0: self._saveTind = 0 self._nextTime = self._saveTpnts[0] else: self._saveTind = None self._nextTime = math.inf
[docs] def save(self): """Trigger saving of the result selector at the current simulation time Most saving should be done automatically by providing a ``dt`` or a ``timePoints`` argument to the :py:func:`steps.API_2.sim.Simulation.toSave` method but it is possible to manually decide when to save data by calling ``save()`` on a result selector during simulation. Usage:: for r in range(NBRUNS): sim.newRun() for t in timePoints: sim.run(t) rs1.save() # Saving values manually """ self._checkAddedToSim() self._save(self.sim.Time, (self.sim.Time, self.sim._runId))
[docs] def clear(self): """Discard all recorded data This method is only available for ResultSelectors that do not save data to files. """ self._checkComplete() self._dataHandler.clear()
def _toDB(self, dbhanlder): """ Specify that the data should be saved to a database. """ self._checkComplete() self._dataHandler = dbhanlder._getDataHandler(self) def _saveWithDt(self, dt): """Specify that the data needs to be saved every dt seconds.""" self._checkComplete() self._saveDt = dt self._saveTpnts = [] self._saveTind = 0 self._nextTime = 0 def _saveWithTpnts(self, tpnts): """Specify at which time points the data should be saved.""" self._saveTpnts = tpnts self._saveDt = None self._saveDtStart = None self._saveTind = 0 self._nextTime = self._saveTpnts[0] def _addedToSimulation(self, ind, rsGroupInd): """Specify that the result selector was added to a simulation with index ind.""" self._checkComplete() self._addedToSim = True self._selectorInd = ind self._optimGroupInd = rsGroupInd def _concat(self, other): """Concatenate two result selectors into a _ResultList.""" return _ResultList([self, other], self.sim) def _save(self, t, solvStateId=None): """Save the data using self._dataHandler.""" self._dataHandler.save(t, self._evaluate(solvStateId)) self._updateNextSaveTime() def _updateNextSaveTime(self): """Update the time of the next save.""" if self._saveTind is not None: self._saveTind += 1 if self._saveDt is not None: self._nextTime = self._saveTind * self._saveDt elif self._saveTind < len(self._saveTpnts): self._nextTime = self._saveTpnts[self._saveTind] else: self._nextTime = math.inf def _distribute(self): """Distribute the path across MPI ranks if it involves mesh elements of a distributed meshes.""" return self, False def _evaluate(self, solvStateId=None): """ Return a list of the values to save. An optional integer can be given to uniquely identify a solver state, this is useful for optimizing solver calls (i.e. not calling several times the same thing if the solver state did not change). """ pass def _getEvalLen(self): """Return the number of values that _evaluate() will return.""" pass
[docs] def __getattr__(self, name): """Redirect attribute access to a SimPath See :py:func:`steps.API_2.sim.SimPath.__getattr__`. .. note:: This method should not be called explicitely, it is only documented for clarity. :meta public: """ try: return super().__getattr__(name) except AttributeError: return getattr(_ResultPath(self.sim), name)
def _checkAddedToSim(self): """Check that the ResultSelector was added to the Simulation.""" if not self._addedToSim: raise Exception( f'Cannot access data from a ResultSelector that was not added to a ' f'simulation with the "toSave" method.' ) def _checkCompatible(self, other): """ Check that 'other' is a ResultSelector that is associated to the same simulation as self """ if not isinstance(other, ResultSelector): raise TypeError(f'Cannot combine a ResultSelector with {other}.') if self.sim != other.sim: raise Exception(f'Cannot combine ResultSelectors associated to different simulations.') self._checkComplete() other._checkComplete() def _checkComplete(self): """Raise an exception if the result selector is not complete.""" raise Exception(f'{self} is not a complete ResultSelector.') def _savingStarted(self): """Return whether data started being saved.""" return self._dataHandler._savingStarted() def _strDescr(self): """Return a default generic description of the ResultSelector.""" raise NotImplementedError() def _binaryOp(self, other, op, symetric=False, opStr='{0} {1}'): """Return a _ResultCombiner that represents the binary operation op.""" labelStrFunc=lambda s1, s2: opStr.format(s1, s2) if isinstance(other, numbers.Number): def opFunc(x): return [op(v, other) for v in x] return _ResultCombiner( opFunc, lambda x: x, [self], self.sim, labelArgFunc=lambda i, chld: (_LabelSelector(chld[0], i), other), labelStrFunc=labelStrFunc, metaDataFunc=lambda vals: vals, strDescr=opStr.format(self.description, other), ) elif isinstance(other, ResultSelector): self._checkCompatible(other) if other._getEvalLen() == 1: def opFunc(x): return [op(v, x[-1]) for v in x[:-1]] def mtdtFunc(vals): v2 = vals[-1] vals = [v1 if v1 == v2 else None for v1 in vals[:-1]] if any(v is not None for v in vals): return vals else: return None return _ResultCombiner( opFunc, lambda x: x - 1, [self, other], self.sim, labelArgFunc=lambda i, chld: (_LabelSelector(chld[0], i), _LabelSelector(chld[1], 0)), labelStrFunc=labelStrFunc, metaDataFunc=mtdtFunc, strDescr=opStr.format(self.description, other.description), ) elif symetric and self._getEvalLen() == 1: return other._binaryOp(self, op, True, opStr) elif other._getEvalLen() == self._getEvalLen(): n = self._getEvalLen() def opFunc(x): return [op(a, b) for a, b in zip(x[:n], x[n:])] def mtdtFunc(vals): vals1 = vals[:len(vals) // 2] vals2 = vals[len(vals) // 2:] ret = [v1 if v1 == v2 else None for v1, v2 in zip(vals1, vals2)] if any(v is not None for v in ret): return ret else: return None return _ResultCombiner( opFunc, lambda x: x // 2, [self, other], self.sim, labelArgFunc=lambda i, chld: (_LabelSelector(chld[0], i), _LabelSelector(chld[1], i)), labelStrFunc=labelStrFunc, metaDataFunc=mtdtFunc, strDescr=opStr.format(self.description, other.description), ) else: raise Exception( f'Cannot apply binary operation {opStr.format("","")}, ' f'incompatible output lengths: "{self}" has an output ' f'length of {self._getEvalLen()} while "{other}" has an ' f'output length of {other._getEvalLen()}.' ) else: raise TypeError(f'Cannot combine a resultSelector with {other} using {op}.')
[docs] def __lshift__(self, other): """Concatenate two result selectors with the ``<<`` operator :param other: The other result selector :type other: :py:class:`ResultSelector` :returns: The result selector resulting from the concatenation of both operands. Its length is thus the sum of both of the operands' lengths. :rtype: :py:class:`ResultSelector` Usage:: rs2 = rs.comp1.S1.Count << rs.comp1.S2.Count # rs2 will save 2 values, the count of S1 # in comp1 and the count of S2 in comp1. :meta public: """ self._checkCompatible(other) return self._concat(other)
[docs] def __mul__(self, other): """Multiply result selectors with the * operator :param other: The other result selector or a number :type other: Union[:py:class:`ResultSelector`, float] :returns: The result selector resulting from the multiplication of both operands. If both operands are result selectors and have the same size, this corresponds to the element-wise product of values. If one of the operand is a number or a result selector of length 1, all values of the result selector are multiplied with this single value. :rtype: :py:class:`ResultSelector` Usage:: rs3 = 10 * rs.TETS(tetlst).S1.Count # rs3 will save the number # of S1 in each tetrahedron # in tetlst, multiplied by # 10. rs4 = rs.TETS(tetlst).S1.Count * rs.TETS(tetlst).S2.Count # rs4 will save the product # of the number of S1 and # the number of S2 in each # tetrahedron in tetLst. :meta public: """ return self._binaryOp(other, operator.mul, symetric=True, opStr='({0} * {1})')
def __rmul__(self, other): return self._binaryOp(other, operator.mul, symetric=True, opStr='({1} * {0})')
[docs] def __truediv__(self, other): """Divide result selectors with the ``/`` operator :param other: The other result selector or a number :type other: Union[:py:class:`ResultSelector`, float] :returns: The result selector resulting from the division of both operands. If both operands are result selectors and have the same size, this corresponds to the element-wise division of values. If one of the operand is a number or a result selector of length 1, all values of the result selectors are divided by this single value (or this single value is divided by all values from the result selector, depending on order). :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.TETS(tetlst).S1.Count / 10 # rs3 will save the number # of S1 in each tetrahedron # in tetlst, divided by 10. rs4 = 1 / rs.TETS(tetlst).S1.Count # rs4 will save the inverse # of the number of S1 in # each tetrahedron in # tetlst, divided by 10. rs5 = rs.TETS(tetlst).S1.Count / rs.TETS(tetlst).S2.Count # rs5 will save the ratio of # S1 to S2 in each # tetrahedron in tetLst. :meta public: """ return self._binaryOp(other, operator.truediv, symetric=False, opStr='({0} / {1})')
def __rtruediv__(self, other): return self._binaryOp(other, lambda a, b: b / a, symetric=False, opStr='({1} / {0})')
[docs] def __add__(self, other): """Add result selectors with the ``+`` operator :param other: The other result selector or a number :type other: Union[:py:class:`ResultSelector`, float] :returns: The result selector resulting from the addition of both operands. If both operands are result selectors and have the same size, this corresponds to the element-wise addition of values. If one of the operand is a number or a result selector of length 1, this single value is added to all values of the result selector. :rtype: :py:class:`ResultSelector` Usage:: rs3 = 10 + rs.TETS(tetlst).S1.Count # rs3 will save the number # of S1 in each tetrahedron # in tetlst, increased by # 10. rs4 = rs.TETS(tetlst).S1.Count + rs.TETS(tetlst).S2.Count # rs4 will save the sum # of the number of S1 and # the number of S2 in each # tetrahedron in tetLst. :meta public: """ return self._binaryOp(other, operator.add, symetric=True, opStr='({0} + {1})')
def __radd__(self, other): return self._binaryOp(other, operator.add, symetric=True, opStr='({1} + {0})')
[docs] def __sub__(self, other): """Subtract result selectors with the ``-`` operator :param other: The other result selector or a number :type other: Union[:py:class:`ResultSelector`, float] :returns: The result selector resulting from the subtraction of both operands. If both operands are result selectors and have the same size, this corresponds to the element-wise subtraction of values. If one of the operand is a number or a result selector of length 1, this single value is subtracted from all values of the result selectors (or each value from the result selector is subtracted from the single value, depending on order). :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.TETS(tetlst).S1.Count - 10 # rs3 will save the number # of S1 in each tetrahedron # in tetlst, minus 10. rs4 = 10 - rs.TETS(tetlst).S1.Count # rs4 will save 10 minus the # number of S1 for each # tetrahedron in tetlst. rs5 = rs.TETS(tetlst).S1.Count - rs.TETS(tetlst).S2.Count # rs5 will save the number # of S1 minus the number of # S2 in each tetrahedron in # tetLst. :meta public: """ return self._binaryOp(other, operator.sub, symetric=False, opStr='({0} - {1})')
def __rsub__(self, other): return self._binaryOp(other, lambda a, b: b - a, symetric=False, opStr='({1} - {0})')
[docs] def __pow__(self, other): """Exponentiate result selectors with the ** operator :param other: The other result selector or a number :type other: Union[:py:class:`ResultSelector`, float] :returns: The result selector resulting from the exponentiation of both operands. If both operands are result selectors and have the same size, this corresponds to the element-wise exponentiation of values. If one of the operand is a number or a result selector of length 1, this single value is exponentiated by each value of the result selector (or each value in the result selector is exponentiated by the single value, depending on order). :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.TETS(tetlst).S1.Count ** 2 # rs3 will save the square # of the number of S1 in # each tetrahedron in # tetlst. :meta public: """ return self._binaryOp(other, operator.pow, symetric=False, opStr='({0} ** {1})')
# Needed for the heapq ordering in Simulation def __lt__(self, other): return True
[docs] @classmethod def SUM(cls, sel): """Sum of all values from a result selector :param sel: Result selector whose values should be summed :type sel: :py:class:`ResultSelector` :returns: A result selector with a single value that corresponds to the sum of the values in ``sel``. :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.SUM(rs.TETS(tetLst).S1.Count) # The total number of S1 in tetLst """ return _ResultCombiner( lambda x: [sum(x)], lambda x: 1, [sel], sel.sim, labelArgFunc=lambda _, d: tuple(c.description for c in d), labelStrFunc=lambda *args: f"SUM({' + '.join(args)})", strDescr=f'SUM({sel.description})', )
[docs] @classmethod def MIN(cls, sel): """Minimum of all values from a result selector :param sel: Result selector whose values should be used :type sel: :py:class:`ResultSelector` :returns: A result selector with a single value that corresponds to the minimum of the values in ``sel``. :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.MIN(rs.TETS(tetLst).S1.Count) # The minimum number of S1 in tetLst """ return _ResultCombiner( lambda x: [min(x)], lambda x: 1, [sel], sel.sim, labelArgFunc=lambda _, d: tuple(c.description for c in d), labelStrFunc=lambda *args: f"MIN({', '.join(args)})", strDescr=f'MIN({sel.description})', )
[docs] @classmethod def MAX(cls, sel): """Maximum of all values from a result selector :param sel: Result selector whose values should be used :type sel: :py:class:`ResultSelector` :returns: A result selector with a single value that corresponds to the maximum of the values in ``sel``. :rtype: :py:class:`ResultSelector` Usage:: rs3 = rs.MAX(rs.TETS(tetLst).S1.Count) # The maximum number of S1 in tetLst """ return _ResultCombiner( lambda x: [max(x)], lambda x: 1, [sel], sel.sim, labelArgFunc=lambda _, d: tuple(c.description for c in d), labelStrFunc=lambda *args: f"MAX({', '.join(args)})", strDescr=f'MAX({sel.description})', )
[docs] @classmethod def JOIN(cls, selectors): """Concatenate values from several result selectors :param selectors: Result selectors whose values should be concatenated :type selectors: Iterable[:py:class:`ResultSelector`] :returns: A result selector that corresponds to the concatenatin of the values in ``selectors``. :rtype: :py:class:`ResultSelector` Usage:: # The total number of species for each tetrahedron in tetLst rs3 = rs.JOIN(rs.SUM(rs.TET(tet).ALL(Species).Count) for tet in tetLst) """ sels = list(selectors) if len(sels) == 0: raise ValueError(f'At least one result selector should be supplied to JOIN().') if len(sels) > 1: for s in sels[1:]: sels[0]._checkCompatible(s) return _ResultList(sels, sels[0].sim)
[docs]class CustomResults(ResultSelector): """Class to manually save data This class helps to save data to file or databases in the same format as :py:class:`ResultSelector`. This is useful for cases in which the data to be saved is not easily describable with a standard :py:class:`ResultSelector`. Instead of describing the data to be saved, :py:class:`CustomResults` requires a list of the types of the data to be saved. The user then calls the `save` method with a list of values to be saved whenever it is needed. :param sim: The simulation for which we want to save data :type sim: :py:class:`steps.API_2.sim.Simulation` :param types: A list of types that describes the types of each data that will be saved upon calls `save`. :type types: List[Union[dict, list, str, float, int]] """ _TYPE_MAP = { dict: 'dict', list: 'list', str: 'str', float: None, int: None, } def __init__(self, sim, types, *args, **kwargs): super().__init__(sim, *args, **kwargs) self._types = [] for tpe in types: try: self._types.append(CustomResults._TYPE_MAP[tpe]) except KeyError: raise TypeError( f'Unsupported type {tpe}, types can only be one of:' f'{list(CustomResults._TYPE_MAP.keys())}.' ) self.metaData['value_type'] = self._types self.labels = [f'col{i}_{tpe.__name__}' for i, tpe in enumerate(types)] self.description = self._strDescr()
[docs] def save(self, values): """Save the provided values at the current simulation time Unlike other :py:class:`ResultsSelector`s, :py:class:`CustomResults` will only save values that are explicitely given to it through this method, along with the simulation time at the time of the call. :param values: A list of values that has the same length as the list of types used to create the :py:class:`CustomResults` object. :type values: List[Union[Dict, List, str, Number]] Usage:: cr1 = CustomResults([dict, list, float]) sim.toSave(cr1) # No dt or timepoints parameters sim.newRun() for t in timePoints: sim.run(t) myvalues = [{'key1': 92, 'key2': 6}, ['str1', 'str2'], 42.0] # Compute values rs1.save(myvalues) # Saving values manually """ self._checkAddedToSim() self._dataHandler.save(self.sim.Time, values)
def _saveWithDt(self, dt): raise NotImplementedError('CustomResults cannot be saved automatically.') def _saveWithTpnts(self, tpnts): raise NotImplementedError('CustomResults cannot be saved automatically.') def _concat(self, other): self._checkCompatible() def _getEvalLen(self): return len(self._types) def _strDescr(self): """Return a default generic description of the ResultSelector.""" return 'CustomResults' def __getattr__(self, name): raise AttributeError() def _checkComplete(self): pass def _checkCompatible(self, other): raise NotImplementedError('CustomResults cannot be combined with other result selectors.')
class _ResultPath(ResultSelector): """ Represents a SimPath to be saved during simulation. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.simpath = nsim.SimPath(self.sim) self._descriptor = None self._compiledFuncs = None self._len = None self._descr = [str(self.sim)] # When distributing result paths, the simpath can cover more data than needs to be saved. # We need to call non-distributed methods on all ranks but we only need to save the results # on rank 0. # _simPathMask defines a boolean mask determining which parts of the simpath that should be # saved locally, the other parts are not saved and the corresponding labels and metadata are # not saved either. self._simPathMask = None def _strDescr(self): """Return a default generic description of the ResultSelector.""" return '.'.join(self._descr[1:]) def _evaluate(self, solvStateId=None): """Return a list of the values to save.""" if self._compiledFuncs is None: self._compiledFuncs = self._descriptor._getFinalPathsFunction(self.simpath) if self._simPathMask is None: return [f(*args, **kwargs) for f, args, kwargs in self._compiledFuncs] else: res = [] for i, (f, args, kwargs) in enumerate(self._compiledFuncs): val = f(*args, **kwargs) if self._simPathMask[i]: res.append(val) return res def _getEvalLen(self): """Return the number of values that _evaluate() will return.""" return self._len def _checkComplete(self): """Raise an exception if the path is not complete.""" if self._descriptor is None: raise Exception(f'{self} is incomplete.') def _concat(self, other): """Concatenate two result selectors into a _ResultList.""" self._checkComplete() return super()._concat(other) def _binaryOp(self, other, op, symetric=False, opStr='{0} {1}'): """Return a _ResultCombiner that represents the binary operation op.""" self._checkComplete() return super()._binaryOp(other, op, symetric, opStr) def _distribute(self): """Distribute the path across MPI ranks if it involves mesh elements of a distributed meshes.""" if self._distrInds is not None: return self, False self._fullLen = self._getEvalLen() self.simpath, self._distrInds, spMask, changed = self.simpath._distribute() if changed: if self.simpath is None: return None, changed self._simPathMask = numpy.array(spMask) # Recompute length, labels and metadata self._finalize() return self, changed def _finalize(self): """Finalize the result path by computing length, labels, and metadata""" self._len = len([p for p in self.simpath]) # Compute labels self._labels = [] *_descr, endName = self._descr for descr in self.simpath._getDescriptions(tuple(_descr)): self._labels.append('.'.join(descr[1:] + (endName,))) # Compute automatic metadata mtdt = {} for i, path in enumerate(self.simpath._walk()): for key, lst in mtdt.items(): lst.append(None) if isinstance(path, nutils.SimPathCombiner): # If the path is a combination of paths, only consider metadata that is # common to all paths dct = {} for elem in path.paths[0]: dct.update(elem._simPathAutoMetaData()) for p in path.paths[1:]: dct2 = {} for elem in p: dct2.update(elem._simPathAutoMetaData().items()) dct = {key: dct[key] for key in dct.keys() & dct2.keys() if dct[key] == dct2[key]} for key, val in dct.items(): if key not in mtdt: mtdt[key] = [None] * (i + 1) mtdt[key][i] = val else: # Otherwise save all metadata for the path for elem in path: for key, val in elem._simPathAutoMetaData().items(): if key not in mtdt: mtdt[key] = [None] * (i + 1) mtdt[key][i] = val # Restrict data to simPathMask, if applicable if self._simPathMask is not None: self._len = sum(self._simPathMask) self._labels = [lbl for lbl, msk in zip(self._labels, self._simPathMask) if msk] mtdt = {key:[v for v, msk in zip(lst, self._simPathMask) if msk] for key, lst in mtdt.items()} # Set the metadata for key, lst in mtdt.items(): self.metaData.__setitem__(key, lst, _internal=True) # Add property metadata self.metaData.__setitem__('property', [self._endName] * self._len, _internal=True) # Add endname specific metadata if not already in result path if self._endName in nsim.SimPath._PATH_END_METADATA: for key, val in nsim.SimPath._PATH_END_METADATA[self._endName].items(): if key in self.metaData: row = self.metaData[key] # Only add where it is not defined for i, v in enumerate(row): row[i] = val if v is None else v else: row = [val] * self._len self.metaData.__setitem__(key, row, _internal=True) # Set the description self.description = self._strDescr() def __getattr__(self, name): self._descr.append(name) if name not in nsim.SimPath._endNames: self.simpath = getattr(self.simpath, name) else: self._descriptor = getattr(nsim.SimPath, name) self._endName = name self._finalize() return self def __call__(self, *args, **kwargs): self.simpath = self.simpath(*args, **kwargs) self._descr[-1] += f'({nutils.args2str(*args, **kwargs)})' return self def __getitem__(self, key): self.simpath = self.simpath[key] self._descr[-1] += f'[{nutils.key2str(key)}]' return self def __repr__(self): return self.description class _ResultList(ResultSelector): """Represents the concatenation of several ResultSelectors.""" def __init__(self, lst, *args, **kwargs): super().__init__(*args, **kwargs) self.children = lst self._finalize() # Do not initialize metadata directly self._metaData = None def _strDescr(self): """Return a default generic description of the ResultSelector.""" return ', '.join(c.description for c in self.children) def _evaluate(self, solvStateId=None): """Return a list of the values to save.""" res = [] for c in self.children: res += list(c._evaluate(solvStateId)) return res def _getEvalLen(self): """Return the number of values that _evaluate() will return.""" return self._evalLen def _concat(self, other): """Concatenate two result selectors into a _ResultList.""" if other.__class__ is _ResultList: return _ResultList(self.children + other.children, self.sim) else: return _ResultList(self.children + [other], self.sim) def _checkComplete(self): """Raise an exception if the result selector is not complete.""" for c in self.children: c._checkComplete() def _finalize(self): """Compute labels and evel length""" self._evalLen = 0 for c in self.children: self._evalLen += c._getEvalLen() # concatenate labels self._labels = [] for c in self.children: self._labels += c.labels self.description = self._strDescr() def _distribute(self): """Distribute the path across MPI ranks if it involves mesh elements of a distributed meshes.""" if self._distrInds is not None: return self, False self._fullLen = self._getEvalLen() newLst = [] globalIdx = 0 changed = False self._distrInds = [] for c in self.children: totLen = c._getEvalLen() nc, cChanged = c._distribute() if nc is not None: newLst.append(nc) distrInds = c._distrInds if c._distrInds is not None else range(c._getEvalLen()) self._distrInds += [globalIdx + idx for idx in distrInds] changed |= cChanged globalIdx += totLen self.children = newLst self._finalize() # Update metaData for key, vals in self.metaData.items(): if len(vals) != self._getEvalLen(): self._metaData.__setitem__(key, [vals[ind] for ind in self._distrInds], _internal=True) return self, changed def _computeMetaData(self): """Compute the concatenation of metadata.""" metaDataKeys = set() for c in self.children: metaDataKeys.update(c.metaData.keys()) mtdt = {} for key in metaDataKeys: lst = [] for c in self.children: try: lst += list(c.metaData[key]) except KeyError: lst += [None] * c._getEvalLen() mtdt[key] = lst return mtdt def _getAllTerminalChildren(self): """Return all children that are not ResultLists recursively""" for c in self.children: if isinstance(c, _ResultList): yield from c._getAllTerminalChildren() else: yield c @ResultSelector.metaData.getter def metaData(self): if self._metaData is None: self._metaData = _MetaData(self) mtdt = self._computeMetaData() for key, lst in mtdt.items(): self._metaData._dict[key] = lst return self._metaData def __repr__(self): return self.description class _ResultCombiner(_ResultList): """ Transforms results using function func that takes an iterable and outputs a list. function lenFunc should take the length of children output as an argument and return the length of the combiner output. """ def __init__(self, func, lenFunc, *args, labelArgFunc=None, labelStrFunc=None, metaDataFunc=None, strDescr=None, **kwargs): self.func = func super().__init__(*args, **kwargs) self._len = lenFunc(super()._getEvalLen()) self._labelArgFunc = labelArgFunc if labelArgFunc is not None else lambda i, chld: (f'{chld}[{i}]',) self._labelStrFunc = labelStrFunc if labelStrFunc is not None else lambda *args: ''.join(args) self._metadataFunc = metaDataFunc if metaDataFunc is not None else nutils.getValueIfAllIdentical if strDescr is not None: self.description = strDescr self._labels = [] for i in range(self._len): subLabels = self._labelArgFunc(i, self.children) args = [arg.sel.labels[arg.ind] if isinstance(arg, _LabelSelector) else arg for arg in subLabels] self._labels.append(self._labelStrFunc(*args)) def _concat(self, other): """Concatenate two result selectors into a _ResultList.""" return _ResultList([self, other], self.sim) def _strDescr(self): """Return a default generic description of the ResultSelector.""" return f"{self.func}({super()._strDescr()})" def _evaluate(self, solvStateId=None): """Return a list of the values to save.""" return self.func(super()._evaluate(solvStateId)) def _getEvalLen(self): """Return the number of values that _evaluate() will return.""" return self._len def _distribute(self): """Distribute the path across MPI ranks if it involves mesh elements of a distributed meshes.""" if self._distrInds is not None: return self, False self._fullLen = self._getEvalLen() if not nsim.MPI._shouldWrite: self._len = 0 self._labels = [] self._distrInds = [] if self._metaData is not None: self._metaData._clear() return self, True else: self._distrInds = list(range(self._getEvalLen())) return self, False @ResultSelector.metaData.getter def metaData(self): if self._metaData is None: self._metaData = _MetaData(self) mtdt = self._computeMetaData() # Only keep common metadata for key, lst in mtdt.items(): vals = self._metadataFunc(lst) if vals is not None: self._metaData._dict[key] = vals return self._metaData def __repr__(self): return self.description ################################################################################################### # Read only ResultSelectors class _ReadOnlyResultSelector: """ Only implement data access methods of ResultSelector """ def __init__(self, handler): self._dataHandler = handler @property def time(self): """Return an accessor to the timepoints data.""" return self._dataHandler.time() @property def data(self): """Return an accessor to the saved data.""" return self._dataHandler.data() @property def labels(self): """Return a list of strings describing the things being saved.""" return self._dataHandler.labels() @property def metaData(self): """Return the metadata associated to the ResultSelector.""" return self._dataHandler.metaData() @property def description(self): """Return a string describing the result selector""" return self._dataHandler.description() ################################################################################################### # Data handlers class _DataHandler(nutils.Versioned): """ Interface for data saving classes. """ DESCRIPTION_ADDED_VERSION_ABOVE = '5.0.0' def __init__(self, parent, *args, **kwargs): super().__init__(*args, **kwargs) self._parent = parent self._runId = -1 def time(self): """Return an accessor to the timepoints data.""" pass def data(self): """Return an accessor to the saved data.""" pass def labels(self): """Return the labels of the data being saved.""" raise NotImplementedError() @nutils.Versioned._versionRange(belowOrEq=DESCRIPTION_ADDED_VERSION_ABOVE) def description(self): """Return a description of the data being saved.""" raise NotImplementedError( f'Result selector description is not available for files saved with STEPS ' f'{_DataHandler.DESCRIPTION_ADDED_VERSION_ABOVE} or below. This file was ' f'saved with STEPS {self._version}.' ) def _newRun(self): """Signal that a new run of the simulation started.""" self._runId += 1 def save(self, t, row): """Save the data.""" pass def clear(self): """Discard all saved data""" raise NotImplementedError(f'clear() is not available for {self.__class__}.') def _savingStarted(self): """Return whether data started being saved.""" return self._runId >= 0 @classmethod def _checkCanAccess(cls): if not nsim.MPI._shouldWrite: raise Exception(f'Cannot access ResultSelector data out of the rank 0 process while using MPI.') class _MemoryDataHandler(_DataHandler): """ Data handler for saving data in memory. """ def __init__(self, parent, *args, **kwargs): super().__init__(parent, *args, **kwargs) self.saveData = [] self.saveTime = [] def time(self): """Return an accessor to the timepoints data.""" return _MemoryDataAccessor(self.saveTime, 2) def data(self): """Return an accessor to the saved data.""" return _MemoryDataAccessor(self.saveData, 3) @nutils.Versioned._versionRange(above=_DataHandler.DESCRIPTION_ADDED_VERSION_ABOVE) def description(self): """Return a description of the data being saved.""" return self._parent.description def _newRun(self): """Signal that a new run of the simulation started.""" super()._newRun() self.saveData.append([]) self.saveTime.append([]) def save(self, t, row): """Save the data.""" self.saveTime[-1].append(t) self.saveData[-1].append(copy.copy(row)) def clear(self): """Discard all saved data""" self.saveData = [] self.saveTime = [] class _FileDataHandler(_DataHandler): """ Data handler for saving data to files. """ HEADER_FORMAT = '>QQQQ' DATA_FORMAT = '>d' DEFAULT_BUFFER_SIZE = 4096 INT_SIZE = 4 SELECTOR_DESCRIPTION_STR = '__selector_description__' FILE_FORMAT_STR = '__steps_version__' FILE_FORMAT_OLDEST_VERSION = '3.6.0' RESERVED_KEY_NAMES = [SELECTOR_DESCRIPTION_STR, FILE_FORMAT_STR] def __init__(self, parent, path, evalLen=None, buffering=-1, *args, **kwargs): super().__init__(parent, *args, **kwargs) self._savePath = path self._readOnly = evalLen is None self._evalLen = evalLen if evalLen is not None else 1 self._saveFile = None self._saveBuffering = buffering self._fileHeaderInfo = None self._filePrevPos = None self.saveData = collections.deque([], self._getDequeMaxSize()) self.saveTime = [] self._labels = None self._metaData = None # TODO Not urgent: make labels and metadata readonly self._labelEndPos = None # If we are reading from a file, we need to set the version if self._readOnly: version = self.metaData(internal=True)._dict.get( _FileDataHandler.FILE_FORMAT_STR, _FileDataHandler.FILE_FORMAT_OLDEST_VERSION ) self._setVersion(version) def __del__(self): if hasattr(self, '_saveFile') and self._saveFile is not None: self._finalizeFile() def time(self): """Return an accessor to the timepoints data.""" self._checkCanAccess() self._finalizeFile() return _FileDataAccessor(self._savePath, parent=self, saveTime=True) def data(self): """Return an accessor to the saved data.""" self._checkCanAccess() self._finalizeFile() return _FileDataAccessor(self._savePath, parent=self, saveTime=False) def labels(self): """Return the labels of the data being saved.""" self._checkCanAccess() if self._labels is None: self._readLabelsAndMetaData() return self._labels def metaData(self, internal=False): """Return the metaData of the data being saved.""" self._checkCanAccess() if self._metaData is None: self._readLabelsAndMetaData() md = _MetaData(None) if internal: md._dict = self._metaData else: md._dict = { key:data for key, data in self._metaData.items() if key not in _FileDataHandler.RESERVED_KEY_NAMES } return md @nutils.Versioned._versionRange(above=_DataHandler.DESCRIPTION_ADDED_VERSION_ABOVE) def description(self): """Return a description of the data being saved.""" return self.metaData(internal=True)._dict.get(_FileDataHandler.SELECTOR_DESCRIPTION_STR) @property def _dataStartPos(self): if self._labelEndPos is None: self.labels() return self._labelEndPos def _newRun(self): """Signal that a new run of the simulation started.""" super()._newRun() self.saveData.clear() self.saveTime.append([]) if nsim.MPI._shouldWrite: self._writeRunHeader(self._runId, 0, 1 + self._evalLen) def save(self, t, row): """Save the data.""" self.saveTime[-1].append(t) self.saveData.append(list(row)) if nsim.MPI._shouldWrite: self._writeToFile(t, self.saveData[-1]) def _openFile(self): """Open the file in the correct mode.""" if self._saveFile is None: if self._fileHeaderInfo is None: self._saveFile = open(self._savePath, 'wb', buffering=self._saveBuffering) else: self._saveFile = open(self._savePath, 'r+b', buffering=self._saveBuffering) self._saveFile.seek(0, 2) return self._saveFile def _writeRunHeader(self, runId, nbRows, nbCols, writeNext=True): """Write the header line of a run.""" self._openFile() if self._fileHeaderInfo is not None: nxtPos = self._saveFile.seek(0, 1) self._saveFile.seek(self._filePrevPos, 0) if writeNext: self._fileHeaderInfo[3] = nxtPos self._saveFile.write(struct.pack(_FileDataHandler.HEADER_FORMAT, *self._fileHeaderInfo)) self._saveFile.seek(nxtPos, 0) self._saveFile.flush() else: self._writeLabelsAndMetaData() if writeNext: self._fileHeaderInfo = [runId, nbRows, nbCols, 0] self._filePrevPos = self._saveFile.seek(0, 1) self._saveFile.write(struct.pack(_FileDataHandler.HEADER_FORMAT, *self._fileHeaderInfo)) def _writeLabelsAndMetaData(self): """Write the labels and the metadata to the file header.""" # Labels lbls = self._parent.labels self._writeInt(len(lbls)) for l in lbls: self._writeStr(l) # MetaData mtdt = copy.copy(self._parent.metaData._dict) for keyname in _FileDataHandler.RESERVED_KEY_NAMES: if keyname in mtdt: raise Exception( f'The metadata contains the reserved key name "{keyname}"' ) mtdt[_FileDataHandler.FILE_FORMAT_STR] = steps.__version__ mtdt[_FileDataHandler.SELECTOR_DESCRIPTION_STR] = self._parent.description data = pickle.dumps(mtdt) self._writeInt(len(data)) self._saveFile.write(data) def _writeInt(self, i): """Write int i to the binary file.""" self._saveFile.write(i.to_bytes(_FileDataHandler.INT_SIZE, byteorder='big')) def _writeStr(self, s): """Write string s to the binary file.""" bs = s.encode('ascii') self._writeInt(len(bs)) self._saveFile.write(bs) @staticmethod def _readInt(f): """Read an int from binary file f.""" return int.from_bytes(f.read(_FileDataHandler.INT_SIZE), byteorder='big') @staticmethod def _readStr(f): """Read a string from binary file f.""" strLen = _FileDataHandler._readInt(f) return f.read(strLen).decode('ascii') def _readLabelsAndMetaData(self): """Open the file and read labels.""" with open(self._savePath, 'rb') as f: # Labels nbLbls = _FileDataHandler._readInt(f) self._labels = [] for i in range(nbLbls): self._labels.append(_FileDataHandler._readStr(f)) mtdtSz = _FileDataHandler._readInt(f) # TODO Not urgent: make the dict readonly self._metaData = pickle.loads(f.read(mtdtSz)) self._labelEndPos = f.seek(0, 1) @nutils.Versioned._versionRange(belowOrEq=FILE_FORMAT_OLDEST_VERSION) def _writeToFile(self, t, vals): """Write the data to file.""" self._openFile() self._saveFile.write(struct.pack('>d' + 'd' * len(vals), t, *vals)) self._fileHeaderInfo[1] += 1 @nutils.Versioned._versionRange(above=FILE_FORMAT_OLDEST_VERSION) def _writeToFile(self, t, vals): """Write the data to file.""" self._openFile() pickle.dump((t, vals), self._saveFile) self._fileHeaderInfo[1] += 1 def _finalizeFile(self): """Flush the file buffer and close the file.""" # Only write things if the result selector was created from a simulation, not a file path if nsim.MPI._shouldWrite and not self._readOnly: self._writeRunHeader(None, None, None, writeNext=False) self._saveFile.close() self._saveFile = None def _getDequeMaxSize(self): """Return the length of the buffer deque.""" if self._saveBuffering != -1: buf = self._saveBuffering else: buf = _FileDataHandler.DEFAULT_BUFFER_SIZE return max(1, buf // self._evalLen) class _DBDataHandler(_DataHandler): pass class _SQLiteDataHandler(_DBDataHandler): """ Data handler for saving to sqlite db file. """ TABLE_NAME_TEMPLATE = 'Group_{}_Selector_{}' COLUMN_NAME_TEMPLATE = 'Col_{} real' MTDT_STEPS_VERSION_STR = '__steps_version__' MTDT_OLDEST_VERSION = '5.0.0' RESERVED_KEY_NAMES = [MTDT_STEPS_VERSION_STR] def __init__( self, parent, dbh, commitFreq, *args, groupId=None, rsid=None, tableName=None, nbCols=None, **kwargs ): super().__init__(parent, *args, **kwargs) self._dbh = dbh self._conn = dbh._conn self._commitFreq = commitFreq self._commitInd = 0 self._initialized = False self._groupId = groupId self._rsid = rsid self._tableName = tableName self._nbCols = nbCols self._labels = None self._metaData = None if self._groupId is not None: # Load version if we are reading from a database version = self.metaData(internal=True).get( _SQLiteDataHandler.MTDT_STEPS_VERSION_STR, _SQLiteDataHandler.MTDT_OLDEST_VERSION ) self._setVersion(version) def time(self): """Return an accessor to the timepoints data.""" self._checkCanAccess() return _SQLiteDataAccessor( self._dbh, self._groupId, self._rsid, self._tableName, self._nbCols, saveTime=True ) def data(self): """Return an accessor to the saved data.""" self._checkCanAccess() return _SQLiteDataAccessor( self._dbh, self._groupId, self._rsid, self._tableName, self._nbCols, saveTime=False ) def labels(self): """Return the labels of the saved data.""" self._checkCanAccess() if self._labels is None: self._labels = self._dbh._labelsQuerry(self._groupId, self._rsid) return self._labels def metaData(self, internal=False): """Return the metadata of the saved data.""" self._checkCanAccess() if self._metaData is None: self._metaData = self._dbh._metaDataQuerry(self._groupId, self._rsid) if internal: return self._metaData else: return { key:data for key, data in self._metaData.items() if key not in _SQLiteDataHandler.RESERVED_KEY_NAMES } @nutils.Versioned._versionRange(above=_DataHandler.DESCRIPTION_ADDED_VERSION_ABOVE) def description(self): """Return a description of the data being saved.""" return self._dbh._descriptionQuerry(self._groupId, self._rsid) def _initialize(self): """ Create the table and initialize everything. Should only be called after the first newRun. """ lbls = self._parent.labels self._groupId = self._dbh._groupId self._rsid = self._parent._selectorInd colStr = ', '.join(_SQLiteDataHandler.COLUMN_NAME_TEMPLATE.format(i) for i in range(len(lbls))) self._nbCols = len(lbls) self._tableName = _SQLiteDataHandler.TABLE_NAME_TEMPLATE.format(self._groupId, self._rsid) self._insertStr = f"INSERT INTO {self._tableName} VALUES ({','.join('?'*(2+self._nbCols))});" # Check if the table already exists rows = self._conn.execute( f"SELECT name FROM sqlite_master WHERE type='table' AND name='{self._tableName}'" ).fetchall() if len(rows) == 0: # Create table self._conn.execute(f'CREATE TABLE {self._tableName} (runid int, time real, {colStr});') # Add table info to main table self._conn.execute( f'INSERT INTO {SQLiteDBHandler._RS_MAIN_TABLE_NAME} VALUES (?,?,?,?,?);', (self._groupId, self._rsid, self._parent.description, self._tableName, self._nbCols), ) # Add labels self._conn.executemany( f'INSERT INTO {SQLiteDBHandler._RS_LABEL_TABLE_NAME} VALUES (?,?,?,?);', [(self._groupId, self._rsid, i, lbl) for i, lbl in enumerate(lbls)], ) # Add MetaData mtdt = copy.copy(self._parent.metaData._dict) for keyname in _SQLiteDataHandler.RESERVED_KEY_NAMES: if keyname in mtdt: raise Exception( f'The metadata contains the reserved key name "{keyname}"' ) mtdt[_SQLiteDataHandler.MTDT_STEPS_VERSION_STR] = steps.__version__ self._conn.execute( f'INSERT INTO {SQLiteDBHandler._RS_META_DATA_TABLE_NAME} VALUES (?,?,?);', (self._groupId, self._rsid, pickle.dumps(mtdt)), ) else: # Initialize the runId to the last recorded one rid = self._conn.execute(f'SELECT MAX(runid) FROM {self._tableName}').fetchone()[0] if rid is not None: self._runId = rid self._conn.commit() self._cursor = self._conn.cursor() self._initialized = True def _newRun(self): """Signal that a new run of the simulation started.""" if not self._initialized and nsim.MPI._shouldWrite: self._initialize() super()._newRun() def save(self, t, row): """Save the data.""" if nsim.MPI._shouldWrite: self._cursor.execute(self._insertStr, (self._runId, t) + tuple(row)) self._commitInd += 1 if self._commitInd % self._commitFreq == 0: self._conn.commit() class _HDF5DataHandler(_DBDataHandler): """ Data handler for saving to HDF5 file. """ _RS_COLREMAPPING_NAME = 'ColumnRemapping' _LABELS_DSET_NAME = 'labels' _METADATA_GROUP_NAME = 'metaData' _RUNS_GROUP_NAME = 'runs' _DATA_DSET_NAME = 'data' _TIME_DSET_NAME = 'time' _RUN_GROUP_TEMPLATE = 'Run_{}' def __init__(self, dbh, parent, group, *args, **kwargs): super().__init__(parent, *args, **kwargs) self._dbh = dbh self._group = group self._initialized = False self._timeInd = None self._time = None self._data = None self._lbls = None # Vector representing the permutation that should be applied before saving the data to file # It is used by XDMF data handler to ensure contiguous blocks of data. self._colRemapping = None self._revColRemapping = None self._compObjInds = None if parent is None: # Load column remapping if we are reading data self._loadColumnRemap() self._initializeCompoundObjects() def time(self): """Return an accessor to the timepoints data.""" self._checkCanAccess() return _HDF5DataAccessor(self, True) def data(self): """Return an accessor to the saved data.""" self._checkCanAccess() return _HDF5DataAccessor(self, False) def labels(self): """Return the labels of the data being saved.""" if self._lbls is None: self._checkCanAccess() self._lbls = [lbl.decode('utf-8') for lbl in self._group[_HDF5DataHandler._LABELS_DSET_NAME]] return self._lbls def metaData(self): """Return the metadata of the saved data.""" self._checkCanAccess() return _HDF5MetaDataAccessor(self) @nutils.Versioned._versionRange(above=_DataHandler.DESCRIPTION_ADDED_VERSION_ABOVE) def description(self): """Return a description of the data being saved.""" return self._group.attrs[HDF5Handler._RS_DESCRIPTION_ATTR] def _checkCanAccess(self): if not self._dbh._shouldWrite: raise Exception( f'Cannot access HDF5 data out of the rank 0 process while using non-distributed ' f'simulation with MPI.' ) def _initialize(self): """ Create the subgroups and initialize everything. Should only be called after the first newRun. """ dskwargs = self._dbh._dataSetKWargs if _HDF5DataHandler._RUNS_GROUP_NAME not in self._group: self._group.create_group(_HDF5DataHandler._RUNS_GROUP_NAME, track_order=True) else: # Initialize the runId to the last recorded one self._runId = len(self._group[_HDF5DataHandler._RUNS_GROUP_NAME]) - 1 if _HDF5DataHandler._LABELS_DSET_NAME not in self._group: self._group.create_dataset( _HDF5DataHandler._LABELS_DSET_NAME, data=[lbl.encode('utf-8') for lbl in self._parent.labels], **dskwargs ) if _HDF5DataHandler._METADATA_GROUP_NAME not in self._group: mtdtGroup = self._group.create_group(_HDF5DataHandler._METADATA_GROUP_NAME) for key, vals in self._parent.metaData._dict.items(): # Try to handle different types if any(isinstance(v, str) for v in vals): vals = [v.encode('utf-8') for v in map(str, vals)] elif any(isinstance(v, numbers.Number) or v is None for v in vals): vals = [numpy.nan if v is None else v for v in vals] elif len(vals) > 0: raise TypeError( f'Metadata contains values that are not strings or numbers, they cannot ' f'be saved to HDF5 format.' ) if len(vals) > 0: mtdtGroup.create_dataset(key, data=vals, **dskwargs) self._loadColumnRemap() self._initializeCompoundObjects() self._initialized = True def _initializeCompoundObjects(self): """Initialize compound object handler, if applicable""" tpes = self.metaData().get('value_type', None) if tpes is not None: self._compObjInds = [i for i, tpe in enumerate(tpes) if tpe is not None] if len(self._compObjInds) == 0: self._compObjInds = None def _loadColumnRemap(self): """Load Column remapping, if available""" if _HDF5DataHandler._RS_COLREMAPPING_NAME in self._group: # Load column remapping if it was already saved to file colRemapping = numpy.array(self._group[_HDF5DataHandler._RS_COLREMAPPING_NAME]) if self._colRemapping is not None and any(list(colRemapping != self._colRemapping)): raise Exception( f'The column remapping saved in the HDF5 file for result selector ' f'{self._parent._selectorInd} ({self._parent}) is different from the one that was computed ' f'for this simulation. Try to save your data to a different HDF5 file.' ) self._colRemapping = colRemapping # Compute reverse mapping for data reading self._revColRemapping = numpy.array([0] * len(self._colRemapping)) for i, v in enumerate(self._colRemapping): self._revColRemapping[v] = i def _newRun(self): """Signal that a new run of the simulation started.""" dskwargs = self._dbh._dataSetKWargs if not self._initialized and self._group is not None: self._initialize() self._timeInd = -1 super()._newRun() if self._group is not None: n = self._parent._getEvalLen() runGroup = self._group[_HDF5DataHandler._RUNS_GROUP_NAME].create_group( _HDF5DataHandler._RUN_GROUP_TEMPLATE.format(self._runId) ) self._data = runGroup.create_dataset( _HDF5DataHandler._DATA_DSET_NAME, (1, n), maxshape=(None, n), dtype='d', **dskwargs ) self._time = runGroup.create_dataset( _HDF5DataHandler._TIME_DSET_NAME, (1,), maxshape=(None,), dtype='d', **dskwargs ) def save(self, t, row): """Save the data.""" if self._group is not None: self._timeInd += 1 if self._timeInd >= self._time.shape[0]: self._time.resize(self._timeInd + 1, axis=0) self._data.resize(self._timeInd + 1, axis=0) self._time[self._timeInd] = t if self._compObjInds is not None: for i in self._compObjInds: row[i] = self._dbh._compObjHandler.write(row[i]) if self._colRemapping is None: self._data[self._timeInd, :] = numpy.array(row) else: self._data[self._timeInd, :] = numpy.array(row)[self._colRemapping] class _HDF5DistribDataHandler(_HDF5DataHandler): """ Data handler for loading several HDF5 files that have been saved in a distributed way """ def __init__(self, hdfGroup, *args, **kwargs): super().__init__(*args, **kwargs) self._hdfGroup = hdfGroup self._colMap = numpy.array(self._group[HDF5Handler._RS_DIST_IND_MAP_NAME]) self._nbCols = self._colMap.shape[1] self._rsInd = self._group.attrs[HDF5Handler._RS_INDEX_ATTR] self._dbUid = self._hdfGroup.name self._usedRanks = None self._setUpRankFiles() def _setUpRankFiles(self): self._usedRanks = set(self._colMap[0,:]) for rnk in self._usedRanks: # Open HDF5 files if rnk not in self._dbh._distribRankDBHs: if rnk == nsim.MPI._rank: self._dbh._distribRankDBHs[rnk] = self._dbh else: self._dbh._distribRankDBHs[rnk] = HDF5Handler( HDF5Handler._DISTRIBUTED_HDF_SUFFIX.format(self._dbh._pathPrefix, rnk), hdf5FileKwArgs=self._dbh._fileKwArgs, version=self._version ) # Load read-only result selector rnkDBH = self._dbh._distribRankDBHs[rnk] rsDict = rnkDBH._distribRS.setdefault(self._dbUid, {}) if self._rsInd not in rsDict: rnkDBH._checkOpenFile() rsGroup = rnkDBH._file[self._dbUid][HDF5Handler._RS_GROUP_NAME.format(self._rsInd)] rsDict[self._rsInd] = _ReadOnlyResultSelector( _HDF5DataHandler(rnkDBH, None, rsGroup, version=self._version) ) def _initializeCompoundObjects(self): """Initialize compound object handler, if applicable""" pass def time(self): """Return an accessor to the timepoints data.""" # All times should be the same, can return the first one rnk = self._colMap[0, 0] return self._dbh._distribRankDBHs[rnk]._distribRS[self._dbUid][self._rsInd].time def data(self): """Return an accessor to the saved data.""" return _HDF5DistribDataAccessor(self) def labels(self): """Return the labels of the data being saved.""" if self._lbls is None: self._lbls = [''] * self._nbCols rnk2Lbls = {} for rnk in self._usedRanks: rnk2Lbls[rnk] = self._dbh._distribRankDBHs[rnk]._distribRS[self._dbUid][self._rsInd].labels for i, (rnk, locIdx) in enumerate(self._colMap.T): self._lbls[i] = rnk2Lbls[rnk][locIdx] return self._lbls def metaData(self): """Return the metadata of the saved data.""" return _HDF5DistribMetaDataAccessor(self) class _HDF5CompoundObjHandler(nutils.Versioned): """Utility class for writing compound objects to HDF groups Support writing python lists and dicts to HDF5 groups. """ _COMPOBJ_GROUP_NAME = 'CompoundObjects' _COMPOBJ_DSET_NAME = 'CompObjs' _IND_DTYPE = 'i' class _DATA_TYPE: INT = 0 FLOAT = 1 STRING = 2 LIST = 3 DICT = 4 _DATA_INFO = { # Data type: (dtype, dataset name) _DATA_TYPE.INT: ('i', 'Ints'), _DATA_TYPE.FLOAT: ('d', 'Floats'), _DATA_TYPE.STRING: ('B', 'Strings'), _DATA_TYPE.LIST: (_IND_DTYPE, 'Lists'), } def __init__(self, parentGroup, dbh, cachedTypes=[], readOnly=False, maxFullLoadSize=1024**2, **kwargs): super().__init__(**kwargs) self._parentGroup = parentGroup self._dbh = dbh self._caches = {tpe: {} for tpe in cachedTypes} self._group = None self._compDset = None self._dsets = None self._cacheInit = False self._readOnly = readOnly self._maxFullLoadSize = maxFullLoadSize self._setUp() def _setUp(self): if self._COMPOBJ_GROUP_NAME not in self._parentGroup: if self._readOnly: raise ReadOnlyWriteError() self._parentGroup.create_group(self._COMPOBJ_GROUP_NAME) self._group = self._parentGroup[self._COMPOBJ_GROUP_NAME] dskwargs = self._dbh._dataSetKWargs if self._COMPOBJ_DSET_NAME not in self._group: if self._readOnly: raise ReadOnlyWriteError() self._group.create_dataset( self._COMPOBJ_DSET_NAME, (0, 3), maxshape=(None, 3), dtype=self._IND_DTYPE, **dskwargs ) self._compDset = self._group[self._COMPOBJ_DSET_NAME] if self._readOnly and len(self._compDset) <= self._maxFullLoadSize: nutils._print('Loading full compound data dataset', 3) self._compDset = self._compDset[...] self._dsets = [] for tpe, (dtype, dsetName) in self._DATA_INFO.items(): if dsetName not in self._group: if self._readOnly: raise ReadOnlyWriteError() self._group.create_dataset(dsetName, (0,), maxshape=(None,), dtype=dtype, **dskwargs) self._dsets.append(self._group[dsetName]) if self._readOnly and len(self._dsets[-1]) <= self._maxFullLoadSize: nutils._print(f'Loading full compound data subdataset: {dsetName}', 3) self._dsets[-1] = self._dsets[-1][...] # Faster than if/else or match statement when needing to load a lot of objects ints, floats, strings, lists = self._dsets self._dataLoaders = [ lambda start, end: list(ints[start:end]) if end >= start else ints[start], lambda start, end: list(floats[start:end]) if end >= start else floats[start], lambda start, end: bytearray(strings[start:end]).decode('utf-8'), lambda start, end: [self.read(i) for i in lists[start:end]], lambda start, end: {k: v for k, v in zip(*[self.read(i) for i in lists[start:end]])}, ] def _getDataRanges(self, ind): """Return datasets and ranges that contain the data of the object with index ind.""" tpe, start, end = self._compDset[ind, :] if tpe in [self._DATA_TYPE.FLOAT, self._DATA_TYPE.INT, self._DATA_TYPE.STRING]: if end >= start: return (self._dsets[tpe], start, end) else: return (self._dsets[tpe], start, start + 1) elif tpe in [self._DATA_TYPE.LIST, self._DATA_TYPE.DICT]: return [self._getDataRanges(i) for i in self._dsets[self._DATA_TYPE.LIST][start:end]] else: raise NotImplementedError() def _getCacheKey(self, val): if isinstance(val, list): return tuple(self._getCacheKey(v) for v in val) elif isinstance(val, dict): return (self._getCacheKey(list(val.keys())), self._getCacheKey(list(val.values()))) return val def _pushData(self, tpe, data): dset = self._dsets[tpe] start = len(dset) end = start + len(data) dset.resize(end, axis=0) dset[start:end] = data return start, end def _pushCompound(self, tpe, start, end): ind = self._compDset.shape[0] self._compDset.resize(ind + 1, axis=0) self._compDset[ind,:] = [tpe, start, end] return ind def _pushObject(self, tpe, data, dataTpe=None): if dataTpe is None: dataTpe = tpe if tpe in self._caches: # Caching cache = self._caches[tpe] key = self._getCacheKey(data) if key not in cache: cache[key] = self._pushCompound(tpe, *self._pushData(dataTpe, data)) return cache[key] else: return self._pushCompound(tpe, *self._pushData(dataTpe, data)) def write(self, obj): if not self._cacheInit: # Initialize cache if len(self._caches) > 0: for i, (tpe, start, end) in enumerate(self._compDset[...]): if tpe in self._caches: key = self._getCacheKey(self.read(i)) self._caches[tpe][key] = i self._cacheInit = True if obj is None: return -1 elif isinstance(obj, numbers.Number): # Single value tpe = self._DATA_TYPE.FLOAT if isinstance(obj, float) else self._DATA_TYPE.INT start, end = self._pushData(tpe, [obj]) return self._pushCompound(tpe, start, -1) elif isinstance(obj, str): return self._pushObject(self._DATA_TYPE.STRING, list(obj.encode('utf-8'))) elif isinstance(obj, (list, tuple)): if all(isinstance(v, numbers.Number) for v in obj): # List of numbers if any(isinstance(v, float) for v in obj): return self._pushObject(self._DATA_TYPE.FLOAT, obj) else: # Integer list return self._pushObject(self._DATA_TYPE.INT, obj) else: # List of compounds return self._pushObject(self._DATA_TYPE.LIST, [self.write(v) for v in obj]) elif isinstance(obj, dict): indKeys = self.write(tuple(obj.keys())) indVals = self.write(list(obj.values())) return self._pushObject(self._DATA_TYPE.DICT, [indKeys, indVals], dataTpe=self._DATA_TYPE.LIST) else: raise TypeError(f'Unsupported type {type(obj)} cannot be added to the HDF5 file.') def read(self, ind): if ind < 0: return None try: tpe, start, end = self._compDset[ind, :] return self._dataLoaders[tpe](start, end) except IndexError: warnings.warn(f'Could not read compound object {ind}, returning None instead of its value.') return None class _XDMFDataHandler(_HDF5DataHandler): """ Data handler for writing xdmf files for each run while saving data to HDF5 file. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def _newRun(self): """Signal that a new run of the simulation started.""" super()._newRun() self._dbh._newRun(self._runId) def save(self, t, row): """Save the data.""" super().save(t, row) self._dbh._newTimeStep(t, self._parent, self._timeInd) ################################################################################################### # Data accessors def _sliceData(data, key): """ Slice multidimentional data in nested lists according to key. Use the numpy slicing conventions. 'key' should be a tuple that can only contain integers or slices. Return the data in nested lists. """ if len(key) == 1: return data[key[0]] k = key[0] if isinstance(k, slice): res = [] for sub in data[k]: res.append(_sliceData(sub, key[1:])) return res else: return _sliceData(data[k], key[1:]) class _MemoryDataAccessor: """ Data accessor for _MemoryDataHandler """ def __init__(self, data, nbDims): self._data = data self._nbDims = nbDims def __getitem__(self, key): key = nutils.formatKey(key, self._nbDims, forceSz=True) return nutils.nparray(_sliceData(self._data, key)) def __array__(self): return nutils.nparray(self._data) def __len__(self): return len(self._data) class _FileDataAccessor(nutils.Versioned): """ Data accessor for _FileDataHandler """ HEADER_SIZE = struct.calcsize(_FileDataHandler.HEADER_FORMAT) DATA_SIZE = struct.calcsize(_FileDataHandler.DATA_FORMAT) DEFAULT_MAXRUNID = sys.maxsize class UnexpectedEnd(Exception): pass # TODO Optimization: save the number of runs and related data and only update it if the file # was changed def __init__(self, fp, parent, saveTime=False): self._fp = fp self._saveTime = saveTime self._parentHandler = parent self._dataStartPos = parent._dataStartPos self._file = open(self._fp, 'rb') self._fileInfo = {} self._nbDims = 2 if saveTime else 3 self._setVersion(self._parentHandler._version) def __del__(self): if hasattr(self, '_file') and self._file is not None: self._file.close() def __len__(self): if 'len' not in self._fileInfo: pos = self._file.seek(0, 1) self._file.seek(self._dataStartPos) nb = 0 try: runId, nbRows, nbCols, nxt = struct.unpack( _FileDataHandler.HEADER_FORMAT, self._file.read(_FileDataAccessor.HEADER_SIZE) ) nb += 1 while nxt != 0: self._file.seek(nxt) runId, nbRows, nbCols, nxt = struct.unpack( _FileDataHandler.HEADER_FORMAT, self._file.read(_FileDataAccessor.HEADER_SIZE) ) nb += 1 except struct.error: pass self._file.seek(pos) self._fileInfo['len'] = nb return self._fileInfo['len'] def __getitem__(self, key, forceArray=False): key = nutils.formatKey(key, self._nbDims, forceSz=True) # If possible, try to access from memory if not self._parentHandler._readOnly: if self._saveTime: return nutils.nparray(_sliceData(self._parentHandler.saveTime, key)) elif self._parentHandler._fileHeaderInfo is not None: idxs = nutils.getSliceIds(key[0], sz=self._parentHandler._fileHeaderInfo[0] + 1) if len(idxs) == 1 and idxs[0] == self._parentHandler._fileHeaderInfo[0]: nbRows = self._parentHandler._fileHeaderInfo[1] lenDeque = len(self._parentHandler.saveData) inds = nutils.getSliceIds(key[1], sz=nbRows) if all(nbRows - lenDeque <= ti < nbRows for ti in inds): res = [self._parentHandler.saveData[ti - (nbRows - lenDeque)] for ti in inds] if forceArray: if isinstance(key[2], slice): return nutils.nparray([[row[key[2]] for row in res]]) else: return nutils.nparray([[[row[key[2]]] for row in res]]) else: mk = (slice(None) if isinstance(key[1], slice) else 0, key[2]) return nutils.nparray(_sliceData(res, mk)) # Otherwise, read from file res = [] # Find the number of runs first nbRuns = len(self) if nbRuns == 0: raise IndexError(f'Cannot access data, nothing has been written to the file.') # Read header self._file.seek(self._dataStartPos) runId, nbRows, nbCols, nxt = struct.unpack( _FileDataHandler.HEADER_FORMAT, self._file.read(_FileDataAccessor.HEADER_SIZE) ) # Iterate through runs for ind in nutils.getSliceIds(key[0], sz=nbRuns): while runId != ind: if nxt == 0: break pos = self._file.seek(nxt) try: runId, nbRows, nbCols, nxt = struct.unpack( _FileDataHandler.HEADER_FORMAT, self._file.read(_FileDataAccessor.HEADER_SIZE) ) except struct.error: break if runId != ind: if isinstance(key[0], numbers.Integral) or key[0].stop is not None: raise IndexError(f'Run {ind} is not in the file.') else: break # handle the cases in which the file was only partially written and nbRows == 0 if nxt == 0 and nbRows == 0: warnings.warn( f'Run {ind} from file {self._fp} was not correctly written to file, the ' f'corresponding data will be partial.' ) nbRows = None if (isinstance(key[1], numbers.Integral) and key[1] < 0) or ( isinstance(key[1], slice) and ( (key[1].start is not None and key[1].start < 0) or (key[1].stop is not None and key[1].stop < 0) ) ): raise IndexError('Cannot access partially written data using negative indices.') if nbRows is None: nbRows = _FileDataAccessor.DEFAULT_MAXRUNID rowInds = nutils.getSliceIds(key[1], sz=nbRows) res.append([]) # Read actual data try: for t, line in self._readRows(rowInds, nbCols): if self._saveTime: res[-1].append(t) else: line = line[key[2]] if isinstance(key[2], slice) else [line[key[2]]] res[-1].append(line) except _FileDataAccessor.UnexpectedEnd: if nxt == 0: break else: raise IndexError( f'Could not load time slice {key[1]} of run {ind} from {self._fp}.' f' The file might be corrupted.' ) if forceArray: return nutils.nparray(res) mk = tuple(slice(None) if isinstance(k, slice) else 0 for k in key) return nutils.nparray(_sliceData(res, mk)) @nutils.Versioned._versionRange(belowOrEq=_FileDataHandler.FILE_FORMAT_OLDEST_VERSION) def _readRows(self, rowInds, nbCols): datFormat = _FileDataHandler.DATA_FORMAT[0] + _FileDataHandler.DATA_FORMAT[1] * nbCols pos = self._file.seek(0, 1) try: for ti in rowInds: self._file.seek(pos + ti * nbCols * _FileDataAccessor.DATA_SIZE) if self._saveTime: t, *line = struct.unpack( _FileDataHandler.DATA_FORMAT, self._file.read(_FileDataAccessor.DATA_SIZE) ) else: t, *line = struct.unpack( datFormat, self._file.read(_FileDataAccessor.DATA_SIZE * nbCols) ) yield t, line except (EOFError, struct.error): raise _FileDataAccessor.UnexpectedEnd() @nutils.Versioned._versionRange(above=_FileDataHandler.FILE_FORMAT_OLDEST_VERSION) def _readRows(self, rowInds, nbCols): currti = -1 try: for ti in rowInds: # Find desired row while currti < ti: t, line = pickle.load(self._file) currti += 1 yield t, line except (EOFError, pickle.UnpicklingError): raise _FileDataAccessor.UnexpectedEnd() def __array__(self): return self.__getitem__(slice(None, None, None), forceArray=True) class _SQLiteDataAccessor: """ Data accessor for SQLite database """ def __init__(self, dbh, groupid, rsid, tabName, nbCols, saveTime=False): self._dbh = dbh self._groupid = groupid self._rsid = rsid self._tabName = tabName self._nbCols = nbCols self._saveTime = saveTime self._nbDims = 2 if saveTime else 3 self._colLst = [_SQLiteDataHandler.COLUMN_NAME_TEMPLATE.format(ci) for ci in range(self._nbCols)] def __getitem__(self, key, forceArray=False): key = nutils.formatKey(key, self._nbDims, forceSz=True) res = [] for ri in nutils.getSliceIds(key[0], sz=len(self)): if self._saveTime: timeDat = self._dbh._conn.execute( f'SELECT time FROM {self._tabName} WHERE runid={ri} ORDER BY time' ).fetchall() res.append([timeDat[i][0] for i in nutils.getSliceIds(key[1], len(timeDat))]) else: colStr = ','.join(self._colLst[i] for i in nutils.getSliceIds(key[2], self._nbCols)) allDat = self._dbh._conn.execute( f'SELECT {colStr} FROM {self._tabName} WHERE runid={ri} ORDER BY time' ).fetchall() res.append([list(allDat[i]) for i in nutils.getSliceIds(key[1], len(allDat))]) if forceArray: return nutils.nparray(res) mk = tuple(slice(None) if isinstance(k, slice) else 0 for k in key) return nutils.nparray(_sliceData(res, mk)) def __len__(self): return self._dbh._conn.execute(f'SELECT MAX(runid) FROM {self._tabName}').fetchone()[0] + 1 def __array__(self): return self.__getitem__(slice(None, None, None), forceArray=True) class _HDF5DataAccessor(nutils.Versioned): """ Data accessor for HDF5 files """ def __init__(self, handler, saveTime=False, **kwargs): super().__init__(**kwargs) self._handler = handler self._saveTime = saveTime self._nbDims = 2 if saveTime else 3 # Compound objects if self._handler._compObjInds is not None: if self._handler._revColRemapping is not None: self._compObjInds = set(self._handler._revColRemapping[self._handler._compObjInds]) else: self._compObjInds = set(self._handler._compObjInds) else: self._compObjInds = None def __getitem__(self, key, forceArray=False): key = nutils.formatKey(key, self._nbDims, forceSz=True) res = [] runs = self._handler._group[_HDF5DataHandler._RUNS_GROUP_NAME] for ri in nutils.getSliceIds(key[0], sz=len(self)): res.append([]) runGrp = runs[_HDF5DataHandler._RUN_GROUP_TEMPLATE.format(ri)] runTime = runGrp[_HDF5DataHandler._TIME_DSET_NAME] runData = runGrp[_HDF5DataHandler._DATA_DSET_NAME] for i in nutils.getSliceIds(key[1], runTime.shape[0]): if self._saveTime: res[-1].append(runTime[i]) else: remapKey = key[2] if self._handler._revColRemapping is not None: remapKey = self._handler._revColRemapping[key[2]] if self._compObjInds is not None: # Compound objects res[-1].append([]) if not hasattr(remapKey, '__iter__'): remapKey = nutils.getSliceIds(remapKey, runData.shape[1]) for k in remapKey: if k in self._compObjInds: obj = self._handler._dbh._compObjHandler.read(int(runData[i, k])) res[-1][-1].append(obj) else: res[-1][-1].append(runData[i, k]) else: # Float values if isinstance(remapKey, slice): res[-1].append(runData[i, remapKey]) elif hasattr(remapKey, '__iter__'): res[-1].append([runData[i, j] for j in remapKey]) else: res[-1].append([runData[i, remapKey]]) if forceArray: return nutils.nparray(res) mk = tuple(slice(None) if isinstance(k, slice) or hasattr(k, '__iter__') else 0 for k in key) return nutils.nparray(_sliceData(res, mk)) def __len__(self): return len(self._handler._group[_HDF5DataHandler._RUNS_GROUP_NAME]) def __array__(self): return self.__getitem__(slice(None, None, None), forceArray=True) class _HDF5DistribDataAccessor(_HDF5DataAccessor): """ Data accessor for HDF5 files """ def __init__(self, handler, **kwargs): super().__init__(handler, saveTime=False, **kwargs) def __getitem__(self, key, forceArray=False): key = nutils.formatKey(key, self._nbDims, forceSz=True) runKey, timeKey, colKey = key rsInd = self._handler._rsInd rnk2RsAndInds = {} allColInds = nutils.getSliceIds(colKey, sz=self._handler._nbCols) for i, ci in enumerate(allColInds): rnk, li = self._handler._colMap[:, ci] if rnk not in rnk2RsAndInds: rs = self._handler._dbh._distribRankDBHs[rnk]._distribRS[self._handler._dbUid][rsInd] rnk2RsAndInds[rnk] = (rs, [], []) rnk2RsAndInds[rnk][1].append(i) rnk2RsAndInds[rnk][2].append(li) res = None for rnk, (rs, resInds, locInds) in rnk2RsAndInds.items(): locData = rs.data.__getitem__((runKey, timeKey, locInds), forceArray=True) if res is None: nbRuns, nbTpts, _ = locData.shape res = numpy.zeros((nbRuns, nbTpts, len(allColInds))) res[:, :, resInds] = locData if forceArray: return nutils.nparray(res) mk = tuple(slice(None) if isinstance(k, slice) or hasattr(k, '__iter__') else 0 for k in key) return nutils.nparray(_sliceData(res, mk)) def __len__(self): # Number of runs should be the same in all files, return the first rnk = self._handler._colMap[0, 0] rs = self._handler._dbh._distribRankDBHs[rnk]._distribRS[self._handler._dbUid][self._handler._rsInd] return len(rs.data) class _HDF5MetaDataAccessor(nutils.Versioned, nutils.ReadOnlyDictInterface): """ Meta data accessor for HDF5 files """ def __init__(self, handler, **kwargs): super().__init__(**kwargs) self._handler = handler self._cache = {} def __getitem__(self, key): if key is Ellipsis: return {k: self[k] for k in self} if key not in self._cache: if key not in self._handler._group[_HDF5DataHandler._METADATA_GROUP_NAME]: raise KeyError(f'Cannot access metaData with key: {key}.') dset = self._handler._group[_HDF5DataHandler._METADATA_GROUP_NAME][key] if isinstance(dset[0], bytes): res = [v.decode('utf-8') for v in dset] # Try to convert strings to numbers, if possible for i, v in enumerate(res): for tpe in [int, float]: try: res[i] = tpe(v) break except ValueError: pass # Restore Nones res = [None if v == 'None' else v for v in res] else: res = [None if numpy.isnan(v) else v for v in dset] self._cache[key] = res return self._cache[key] def keys(self): for name in self._handler._group[_HDF5DataHandler._METADATA_GROUP_NAME]: yield name class _HDF5DistribMetaDataAccessor(_HDF5MetaDataAccessor): """ Meta data accessor for HDF5 files """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._rnk2Mtdt = {} dbuid = self._handler._dbUid rsInd = self._handler._rsInd for rnk in self._handler._usedRanks: self._rnk2Mtdt[rnk] = self._handler._dbh._distribRankDBHs[rnk]._distribRS[dbuid][rsInd].metaData self._keys = None def __getitem__(self, key): if key not in self._cache: res = [None] * self._handler._nbCols rnk2mtdt = {rnk: m[key] for rnk, m in self._rnk2Mtdt.items()} for i, (rnk, locIdx) in enumerate(self._handler._colMap.T): res[i] = rnk2mtdt[rnk][locIdx] self._cache[key] = res return self._cache[key] def keys(self): if self._keys is None: keys = set() for rnk, mtdt in self._rnk2Mtdt.items(): keys |= set(mtdt.keys()) self._keys = sorted(keys) return self._keys class _HDF5StaticDataAccessor(nutils.MutableDictInterface, nutils.Versioned): """ Static data accessor for HDF5 files """ def __init__(self, dbh, group, *args, **kwargs): super().__init__(*args, **kwargs) self._dbh = dbh gname = dbh._STATIC_DATA_GROUP_NAME if gname not in group: if self._dbh._file.mode == 'r': raise UnavailableDataError('There is no recorded static data in the file.') self._group = group.create_group(gname) else: self._group = group[gname] def _checkKey(self, key): if not isinstance(key, str): raise KeyError(f'Static data keys must be strings, got {key} instead.') def __setitem__(self, key, value): self._checkKey(key) if key in self._group.attrs: oldValue = self._dbh._compObjHandler.read(self._group.attrs[key]) if oldValue != value: raise Exception( f'The previously recorded value for {key} static data:\n{oldValue}' f'\nis different from the currently given {key} value:\n{value}' ) else: self._group.attrs[key] = self._dbh._compObjHandler.write(value) def __getitem__(self, key): self._checkKey(key) if key not in self._group.attrs: raise KeyError(f'There is no recorded value for {key} in the static data.') return self._dbh._compObjHandler.read(self._group.attrs[key]) def keys(self): return self._group.attrs.keys() ################################################################################################### # Database handlers
[docs]class DatabaseHandler: """Base class for all database handlers.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._parameters = None self._param2Groups = None def _getDataHandler(self, rs): """Return a _DBDataHandler for ResultSelector rs.""" pass def _getFilePaths(self): """Return a list of file paths managed by this rank""" return [] def __getitem__(self, key): """Access a run group from its unique identifier""" raise NotImplementedError() def __iter__(self): """Iterate over run groups in the database""" raise NotImplementedError() def _initializeParameters(self): """Setup parameters and param2Groups data structures""" if self._parameters is None or self._param2Groups is None: self._parameters = {} self._param2Groups = {} for group in self: for name, val in group.parameters.items(): self._parameters.setdefault(name, set()).add(val) self._param2Groups.setdefault((name, val), set()).add(group) @property def parameters(self): """All parameter values from all run groups A dictionary whose keys are parameter names and values are sets of possible values :rtype: Mapping[str, Set[Any]], read-only """ self._initializeParameters() return copy.deepcopy(self._parameters)
[docs] def filter(self, **kwargs): r"""Return all run groups that match the given parameter values :param \*\*kwargs: Keyword arguments specifying the values of parameters that the filtered run groups must match. :return: A set of run groups whose parameter values match the parameters supplied by keyword arguments :rtype: Set[DatabaseGroup] """ self._initializeParameters() res = None for keyVal in kwargs.items(): try: if res is None: res = copy.copy(self._param2Groups[keyVal]) else: res = res & self._param2Groups[keyVal] except KeyError: raise KeyError(f'Could not find any run group with {keyVal[0]} == {keyVal[1]}') return res if res is not None else set(self)
[docs] def get(self, **kwargs): r"""Get the run group that matches the given parameter values :param \*\*kwargs: Keyword arguments specifying the values of parameters that the run group must match. :return: The single run group that matches the given parameter values :rtype: DatabaseGroup If several or none of the groups match these values, an exception will be raised. """ groups = self.filter(**kwargs) if len(groups) != 1: paramVals = ' and '.join(f'{name} == {val}' for name, val in kwargs.items()) raise Exception( f'Expected a single run group to match ({paramVals}), got {len(groups)} run groups instead.' ) else: return next(iter(groups))
[docs]class DatabaseGroup: """Base class for all database run groups""" def __init__(self, dbh, *args, **kwargs): super().__init__(*args, **kwargs) self._dbh = dbh @property def name(self): """The unique identifier of the group :type: str, read-only """ raise NotImplementedError() def __hash__(self): return hash((self._dbh, self.name)) def __eq__(self, other): return isinstance(other, self.__class__) and (self._dbh, self.name) == (other._dbh, other.name)
[docs]class SQLiteDBHandler(DatabaseHandler): r"""SQLite database handler :param path: The path to the SQLite database file :type path: str :param \*args: Transmitted to :py:func:`sqlite3.connect`, see `documentation <https://docs.python.org/3/library/sqlite3.html#sqlite3.connect>`__ for details :param commitFreq: How frequently the data should be committed to the database. For example, this value is set to 10 by default which means that every 10 saving events, the data will be committed. If a result selector is saved every 10ms, it means the data will be committed to database every 100ms. :type commitFreq: int :param \*\*kwargs: Transmitted to :py:func:`sqlite3.connect`, see `documentation <https://docs.python.org/3/library/sqlite3.html#sqlite3.connect>`__ for details Handles the connection to a SQLite database and enables the saving of result selectors to that database. In contrast to the regular saving of result selectors (to memory or to file), it is possible to define groups of runs identified by a unique string so that the same database file can be used for several (sequential) runs of scripts. The database handler should be used as a context manager that wraps all simulation code. Inside this wrapped block, the user should call the :py:func:`steps.API_2.sim.Simulation.toDB` method to indicate that all results selectors associated to the simulation should be saved in the database. In this call, the user should provide the unique simulation group identifier as well as optional parameters that will also be saved to the database. Usage when saving:: sim.toSave(rs1, rs2, rs3, dt=0.01) # Add the result selectors to the # simulation. with SQLiteDBHandler(dbPath) as dbh: # Create database handler. sim.toDB(dbh, 'MySimulation', val1=1, val2=2) # Create a new group of runs in the # database with identifier # 'MySimulation' and save additional # parameters val1 and val2. for i in range(NBRUNS): # Run a series of runs, all of them sim.newRun() # being associated to the ... # 'MySimulation' group. sim.run(...) Note that after calling `sim.toDB(...)` it is still possible to force the saving of some result selectors to files by calling ``toFile(...)`` on them. Result selectors that contain a high number of values to save are probably better saved to a file. The name of the file can be added as a keyword parameter to the ``simtoDB(...)`` call to simplify loading. Usage when accessing data from the database:: with SQLiteDBHandler(dbPath) as dbh: # Create database handler. val1 = dbh['MySimulation'].val1 # Querying a parameter value from the # 'MySimulation' group. rs1, rs2, rs3 = dbh['MySimulation'].results # Querying the result selectors that # were saved for the 'MySimulation' # group. They are returned in the same # order ad they were added to the # simulation. plt.plot(rs1.time[0], rs1.data[0]) # The results selectors can be used as # if they had been declared in the same # process. """ _RS_MAIN_TABLE_NAME = 'ResultSelectors' _RS_LABEL_TABLE_NAME = 'Labels' _RS_META_DATA_TABLE_NAME = 'MetaData' _GROUP_TABLE_NAME = 'SimGroups' _DEFAULT_COMMIT_FREQ = 10 _GROUP_TABLE_KEYS = ['groupid', 'timestamp', 'uniqueid', 'nbselectors'] def __init__(self, path, *args, commitFreq=-1, **kwargs): super().__init__(*args, **kwargs) self._path = path # Only rank 0 should actually connect to the database if nsim.MPI._shouldWrite: self._conn = sqlite3.connect(path, *args, **kwargs) self._conn.row_factory = sqlite3.Row self._connected = True self._createTables() else: self._conn = None self._connected = False self._commitFreq = commitFreq if commitFreq > 0 else SQLiteDBHandler._DEFAULT_COMMIT_FREQ self._groupId = None def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._close() def __del__(self): self._close() def _close(self): """Commit and close the connection.""" if self._connected: self._conn.commit() self._conn.close() self._connected = False def _checkConnection(self): if not self._connected: if nsim.MPI._shouldWrite: raise Exception(f'The connection to the database has been closed.') else: raise Exception(f'Cannot access the database out of the rank 0 process while using MPI.') def _getFilePaths(self): """Return a list of file paths managed by this rank""" return [self._path] if self._conn is not None else [] def _getDataHandler(self, rs): """Return a _DBDataHandler for ResultSelector rs.""" return _SQLiteDataHandler(rs, self, self._commitFreq) def _newGroup(self, sim, uid, selectors, **kwargs): """Initialize the database and add a new run group.""" if nsim.MPI._shouldWrite: self._checkConnection() # Check if the group already exists rows = self._conn.execute( f"SELECT * FROM {SQLiteDBHandler._GROUP_TABLE_NAME} WHERE uniqueid = '{uid}'" ).fetchall() if len(rows) == 0: # If it doesn't, create it typeMap = {int: 'int', float: 'real', str: 'text', bytes: 'BLOB'} colNames = ['timestamp', 'uniqueid', 'nbselectors'] values = [datetime.datetime.now(), uid, len(selectors)] for colName, val in kwargs.items(): if type(val) not in typeMap: raise TypeError( f'Cannot process {colName}={val} because val is not from one of ' f'these types: {typeMap.keys()}' ) try: self._conn.execute( f'ALTER TABLE {SQLiteDBHandler._GROUP_TABLE_NAME} ADD COLUMN ' f'{colName} {typeMap[type(val)]}' ) except sqlite3.OperationalError: pass colNames.append(colName) values.append(val) c = self._conn.cursor() c.execute( f"INSERT INTO {SQLiteDBHandler._GROUP_TABLE_NAME}({','.join(colNames)}) " f"VALUES ({','.join('?'*len(values))})", values, ) self._groupId = c.lastrowid rows = self._conn.execute( f"SELECT * FROM {SQLiteDBHandler._GROUP_TABLE_NAME} WHERE uniqueid = '{uid}'" ).fetchall() group = SQLiteGroup(self, rows[0]) else: row = rows[0] group = SQLiteGroup(self, row) # get existing group id self._groupId = row[SQLiteDBHandler._GROUP_TABLE_KEYS.index('groupid')] # Checks parameters params = { k: v for k, v in zip(row.keys(), row) if k not in SQLiteDBHandler._GROUP_TABLE_KEYS and v is not None } if kwargs != params: raise Exception( f'The keyword arguments provided to the toDB method ({kwargs}) ' f'do not match with the keyword arguments in the database for ' f'the same unique identifier ({params}).' ) # Check selectors if row[SQLiteDBHandler._GROUP_TABLE_KEYS.index('nbselectors')] != len(selectors): idnbs = SQLiteDBHandler._GROUP_TABLE_KEYS.index('nbselectors') raise Exception( f'The {uid} run group saved in the database is associated with ' f'{row[idnbs]} resultSelectors while the current simulation is ' f'associated with {len(selectors)}' ) allRs = self._conn.execute( f'SELECT * FROM {SQLiteDBHandler._RS_MAIN_TABLE_NAME} ' f'WHERE groupid={self._groupId} ORDER BY rsid' ).fetchall() for dbrs, simrs in zip(allRs, selectors): groupId, rsid, descr, tabName, nbCols = dbrs dataHandler = _SQLiteDataHandler( None, self, None, groupId=groupId, rsid=rsid, tableName=tabName, nbCols=nbCols ) if simrs.description != descr: raise Exception( f'The result selector that was previously used for this ' f'unique identifier ({descr}) differs from the one being ' f'currently used ({simrs.description}).' ) if simrs._getEvalLen() != nbCols: raise Exception( f'The result selector that was previously used for this ' f'unique identifier had {nbCols} columns while the current ' f' one has {simrs._getEvalLen()} columns.' ) # check labels dblbls = dataHandler.labels() if simrs.labels != dblbls: raise Exception( f'The result selector that was previously used for this ' f'unique identifier had different column labels. Expected ' f'{dblbls} but got {simrs.labels} instead.' ) # check metadata dbmd = dataHandler.metaData() simmd = simrs.metaData._dict if simmd != dbmd: raise Exception( f'The result selector that was previously used for this ' f'unique identifier had different metadata. Expected ' f'{dbmd} but got {simmd} instead.' ) self._conn.commit() else: group = None return group, selectors def _labelsQuerry(self, groupId, rsid): """Return labels for ResultSelector rsid in group groupid.""" self._checkConnection() rows = self._conn.execute( f'SELECT label FROM ' f'{SQLiteDBHandler._RS_LABEL_TABLE_NAME} ' f'WHERE groupid={groupId} AND rsid={rsid} ' f'ORDER BY colind' ).fetchall() return [row[0] for row in rows] def _metaDataQuerry(self, groupId, rsid): """Return metadata for ResultSelector rsid in group groupid.""" self._checkConnection() dat = self._conn.execute( f'SELECT data FROM ' f'{SQLiteDBHandler._RS_META_DATA_TABLE_NAME} ' f'WHERE groupid={groupId} AND rsid={rsid} ' ).fetchone()[0] return pickle.loads(dat) def _descriptionQuerry(self, groupId, rsid): return self._conn.execute( f'SELECT descr FROM {SQLiteDBHandler._RS_MAIN_TABLE_NAME} ' f'WHERE groupid={groupId} AND rsid={rsid}' ).fetchone()[0] def _createTables(self): """Create the tables if they do not exist.""" self._checkConnection() self._conn.execute( f'CREATE TABLE IF NOT EXISTS {SQLiteDBHandler._GROUP_TABLE_NAME} ' f'(groupid INTEGER PRIMARY KEY AUTOINCREMENT, timestamp date, ' f'uniqueid text UNIQUE, nbselectors int);' ) self._conn.execute( f'CREATE TABLE IF NOT EXISTS {SQLiteDBHandler._RS_MAIN_TABLE_NAME} ' f'(groupid int, rsid int, descr text, tabName text, nbcols int);' ) self._conn.execute( f'CREATE TABLE IF NOT EXISTS {SQLiteDBHandler._RS_LABEL_TABLE_NAME} ' f'(groupid int, rsid int, colind int, label text);' ) self._conn.execute( f'CREATE TABLE IF NOT EXISTS {SQLiteDBHandler._RS_META_DATA_TABLE_NAME} ' f'(groupid int, rsid int, data blob);' ) self._conn.commit()
[docs] def __getitem__(self, key): """Access a SQLite group from its unique identifier :param key: Unique identifier to the group :type key: str :returns: The associated SQLite group :rtype: :py:class:`SQLiteGroup` See :py:class:`SQLiteDBHandler` for usage examples. Raises a ``KeyError`` if the key is not in the database. :meta public: """ self._checkConnection() if not isinstance(key, str): raise TypeError(f'Expected a unique identifier string, got {key} instead.') rows = self._conn.execute( f"SELECT * FROM {SQLiteDBHandler._GROUP_TABLE_NAME} WHERE uniqueid == '{key}'" ).fetchall() if len(rows) == 0: raise KeyError(f'{key} does not exist in {self._path}.') return SQLiteGroup(self, rows[0])
[docs] def __iter__(self): """Iterate over SQLite groups in the database Usage:: with SQLiteDBHandler(dbPath) as dbh: # Create database handler. for group in dbh: # Iterate over all groups val1 = group.val1 # Access group data :meta public: """ self._checkConnection() rows = self._conn.execute( f'SELECT * FROM {SQLiteDBHandler._GROUP_TABLE_NAME} ORDER BY groupid' ).fetchall() for row in rows: yield SQLiteGroup(self, row)
[docs]class SQLiteGroup(DatabaseGroup): """A class representing a group of runs in a SQLite database .. note:: This class should never be instantiated by the user, it is obtained through :py:class:`SQLiteDBHandler` instead. """ def __init__(self, dbh, row, *args, **kwargs): super().__init__(*args, dbh=dbh, **kwargs) self._dict = {} for k, v in zip(row.keys(), row): if v is not None: self._dict[k] = v
[docs] def __getattr__(self, name): """Attribute access for parameters of the group :param name: Name of the parameter, as defined in the original call to ``sim.toDB(...)`` :type name: str :returns: The corresponding parameter value See :py:class:`SQLiteDBHandler` for usage examples. :meta public: """ if name not in self._dict: raise AttributeError(f'{name} is not an attribute of {self}.') return self._dict[name]
@DatabaseGroup.name.getter def name(self): """The unique identifier of the group :type: str, read-only """ return self._dict['uniqueid'] @property def results(self): """A list of all result selectors that were saved :type: List[:py:class:`ResultSelector`], read-only The result selectors are returned in the same order as they were added to the simulation with the :py:func:`steps.API_2.sim.Simulation.toSave` method. See :py:class:`SQLiteDBHandler` for usage examples. """ res = [None] * self.nbselectors rows = self._dbh._conn.execute( f'SELECT * FROM {SQLiteDBHandler._RS_MAIN_TABLE_NAME} WHERE groupid={self.groupid}' ).fetchall() for groupId, rsid, descr, tableName, nbCols in rows: res[rsid] = _ReadOnlyResultSelector( _SQLiteDataHandler( None, self._dbh, None, groupId=groupId, rsid=rsid, tableName=tableName, nbCols=nbCols ) ) return res @property def parameters(self): """A dictionary of all parameters defined for this group :type: Mapping[str, Any], read-only Usage:: >>> with SQLiteDBHandler(dbPath) as dbh: ... dbh['MySimulation'].parameters {'val1': 1, 'val2': 2} """ return {k: v for k, v in self._dict.items() if k not in SQLiteDBHandler._GROUP_TABLE_KEYS} @property def staticData(self): """Not supported for SQLite databases. See :py:attr:`HDF5Handler.staticData` """ raise NotImplementedError('Static data saving is not supported with SQLite databases.')
[docs]class HDF5Handler(DatabaseHandler, nutils.Versioned): """HDF5 File handler :param pathPrefix: Path and prefix for the HDF5 file(s) (e.g. './data/HDF5Data' would yield one file named './data/HDF5Data.h5' when the simulation is not distributed and several files named './data/HDF5Data_rank0.h5', './data/HDF5Data_rank1.h5', etc. when the simulation is distributed. :type pathPrefix: str :param hdf5FileKwArgs: Keyword arguments transmitted to :py:func:`h5py.File`, see `documentation <https://docs.h5py.org/en/stable/high/file.html#h5py.File>`__ for details :type hdf5FileKwArgs: dict :param hdf5DatasetKwArgs: Keyword arguments transmitted to :py:func:`h5py.Group.create_dataset`, see `documentation <https://docs.h5py.org/en/stable/high/group.html#h5py.Group.create_dataset>`__ for details. Most notably, compression-related argument can be set there. :type hdf5FileKwArgs: dict :param internalKwArgs: Keyword arguments specific to the handling of HDF5 files by STEPS, currently only supports `maxFullLoadSize` which improves reading speed of lists or dictionaries saved in result selectors by fully loading some datasets in memory if their size is below `maxFullLoadSize`. :type internalKwArgs: dict Handles reading and writing to an HDF5 file and enables the saving of result selectors to that file. In contrast to the regular saving of result selectors (to memory or to file), it is possible to define groups of runs identified by a unique string so that the same HDF5 file can be used for several (sequential) runs of scripts. The HDF5Handler should be used as a context manager that wraps all simulation code. Inside this wrapped block, the user should call the :py:func:`steps.API_2.sim.Simulation.toDB` method to indicate that all results selectors associated to the simulation should be saved in the HDF5 file. In this call, the user should provide the unique simulation group identifier as well as optional parameters that will also be saved to the file. Usage when saving:: sim.toSave(rs1, rs2, rs3, dt=0.01) # Add the result selectors to the # simulation. with HDF5Handler('./path/to/Prefix') as hdf: # Create database handler. sim.toDB(hdf, 'MySimulation', val1=1, val2=2) # Create a new group of runs in the # HDF5 file with identifier # 'MySimulation' and save additional # parameters val1 and val2. for i in range(NBRUNS): # Run a series of runs, all of them sim.newRun() # being associated to the ... # 'MySimulation' group. sim.run(...) Note that, in contrast with :py:class:`SQLiteDBHandler`, there is no use for forcing the saving of some result selectors to files by calling ``toFile(...)`` on them. HDF5 files can contain high amounts of data. Usage when accessing data from the database:: with HDF5Handler('./path/to/Prefix') as hdf: # Create database handler. val1 = hdf['MySimulation'].val1 # Querying a parameter value from the # 'MySimulation' group. rs1, rs2, rs3 = hdf['MySimulation'].results # Querying the result selectors that # were saved for the 'MySimulation' # group. They are returned in the same # order as they were added to the # simulation. plt.plot(rs1.time[0], rs1.data[0]) # The results selectors can be used as # if they had been declared in the same # process. Note that :py:class:`XDMFHandler` inherits from :py:class:`HDF5Handler` and generates `.xmf` files that point to the HDF5 files and can be read by data visualization software such as `Paraview <https://www.paraview.org/>`_. """ _TIMESTAMP_ATTR_NAME = 'timestamp' _STEPS_VERSION_ATTR_NAME = 'steps_version' _NB_DISTR_RANKS_ATTR_NAME = 'nb_distributed_ranks' _GROUP_DEFAULT_ATTRS = [_TIMESTAMP_ATTR_NAME, _STEPS_VERSION_ATTR_NAME, _NB_DISTR_RANKS_ATTR_NAME] _RS_DESCRIPTION_ATTR = 'Description' _RS_INDEX_ATTR = 'RSIndex' _RS_GROUP_NAME = 'ResultSelector{}' _RS_DISTGROUP_NAME = 'DistributedResultSelector{}' _RS_DIST_IND_MAP_NAME = 'DistributedColumnMap' _STATIC_DATA_GROUP_NAME = 'staticData' _DISTRIBUTED_HDF_SUFFIX = '{}_rank{}' _HDF_EXTENSION = '.h5' def __init__(self, pathPrefix, hdf5FileKwArgs={}, hdf5DatasetKwArgs={}, internalKwArgs={}, **kwargs): super().__init__(**kwargs) self._pathPrefix = pathPrefix self._path = None self._file = None self._currGroup = None self._compObjHandler = None self._shouldWrite = nsim.MPI._shouldWrite self._fileKwArgs = hdf5FileKwArgs self._dataSetKWargs = hdf5DatasetKwArgs self._internalKwArgs = internalKwArgs self._nbSavingRanks = None # HDF5 database handlers needed when loading distributed data self._distribRankDBHs = {} # rank -> dbh # Local read-only result selectors needed when loading distributed data self._distribRS = {} # dbUID -> { rsInd -> rs } def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._close() def __del__(self): self._close() def _close(self): """Close the file""" if hasattr(self, '_file') and self._file is not None: self._file.close() for rnk, dbh in self._distribRankDBHs.items(): if rnk != nsim.MPI._rank: dbh._close() def _checkOpenFile(self, sim=None): if self._file is None: import h5py if sim is None: self._path = self._pathPrefix + HDF5Handler._HDF_EXTENSION if not os.path.isfile(self._path): self._path = ( HDF5Handler._DISTRIBUTED_HDF_SUFFIX.format(self._pathPrefix, 0) + HDF5Handler._HDF_EXTENSION ) if not os.path.isfile(self._path): raise Exception( f'Cannot load any HDF files with prefix {self._pathPrefix}.' ) self._file = h5py.File(self._path, 'r', **self._fileKwArgs) else: if sim._isDistributed(): self._path = HDF5Handler._DISTRIBUTED_HDF_SUFFIX.format(self._pathPrefix, nsim.MPI._rank) elif nsim.MPI._shouldWrite: self._path = self._pathPrefix else: raise Exception( f'Cannot access the HDF5 file out of the rank 0 process while using MPI.' ) self._path += HDF5Handler._HDF_EXTENSION self._file = h5py.File(self._path, 'a', **self._fileKwArgs) # Initialize compound object handler if self._compObjHandler is None: DTPE = _HDF5CompoundObjHandler._DATA_TYPE try: self._compObjHandler = _HDF5CompoundObjHandler( self._file, self, cachedTypes=[DTPE.INT, DTPE.STRING, DTPE.LIST, DTPE.DICT], readOnly=sim is None, **self._internalKwArgs ) except ReadOnlyWriteError: self._compObjHandler = None elif not self._file: raise Exception(f'The HDF5 was closed.') def _getFilePaths(self): """Return a list of file paths managed by this rank""" return [self._path] if self._path is not None else [] def _getRsHDFGroup(self, rs, groupNamePattern=None): if groupNamePattern is None: groupNamePattern = HDF5Handler._RS_GROUP_NAME if self._shouldWrite and rs._getEvalLen() > 0: rsName = groupNamePattern.format(rs._selectorInd) hdf5Group = self._currGroup._group if rsName not in hdf5Group: rsgroup = hdf5Group.create_group(rsName) rsgroup.attrs[HDF5Handler._RS_DESCRIPTION_ATTR] = rs.description rsgroup.attrs[HDF5Handler._RS_INDEX_ATTR] = rs._selectorInd return hdf5Group[rsName] else: return None def _getDataHandler(self, rs, groupNamePattern=None): """Return a _DBDataHandler for ResultSelector rs.""" if groupNamePattern is None: groupNamePattern = HDF5Handler._RS_GROUP_NAME return _HDF5DataHandler(self, rs, self._getRsHDFGroup(rs, groupNamePattern), version=self._version) def _checkSelectors(self, uid, selectors): """Check that selectors in the HDF5 file match the simulation selectors""" selectorNames = [ n for n in self._currGroup._group if n.startswith(HDF5Handler._RS_GROUP_NAME.format('')) ] nonEmptySelectors = [rs for rs in selectors if rs._getEvalLen() > 0] if len(selectorNames) != len(nonEmptySelectors): raise Exception( f'The {uid} run group saved in the database is associated with ' f'{len(selectorNames)} resultSelectors while the current simulation is ' f'associated with {len(nonEmptySelectors)}.' ) for rsName, simrs in zip(selectorNames, nonEmptySelectors): handler = self._getDataHandler(simrs) grouprs = self._currGroup._group[rsName] descr = grouprs.attrs[HDF5Handler._RS_DESCRIPTION_ATTR] rsInd = grouprs.attrs[HDF5Handler._RS_INDEX_ATTR] if simrs.description != descr: raise Exception( f'The result selector that was previously used for this ' f'unique identifier ({descr}) differs from the one being ' f'currently used ({simrs.description}).' ) if simrs._selectorInd != rsInd: raise Exception( f'The index result selector that was previously used for this ' f'unique identifier ({rsInd}) differs from the one being ' f'currently used ({simrs._selectorInd}).' ) if simrs._getEvalLen() != len(handler.data()[0, 0, :]): raise Exception( f'The result selector that was previously used for this unique identifier ' f'had {grouprs["data"].shape[2]} columns while the current one has ' f'{simrs._getEvalLen()} columns.' ) # check labels if handler.labels() != simrs.labels: raise Exception( f'The result selector that was previously used for this ' f'unique identifier had different column labels. Expected ' f'{handler.labels()} but got {simrs.labels} instead.' ) # check metadata filemd = handler.metaData() simmd = simrs.metaData._dict if simmd != filemd: raise Exception( f'The result selector that was previously used for this ' f'unique identifier had different metadata. Expected ' f'{filemd} but got {simmd} instead.' ) def _checkDistributedSelectors(self, uid, rsInd2DistColMap): """Check that the distributed column map of each distributed result selector in the HDF5 file matches the simulation values""" selectorNames = [ n for n in self._currGroup._group if n.startswith(HDF5Handler._RS_DISTGROUP_NAME.format('')) ] if len(selectorNames) != len(rsInd2DistColMap): if nsim.MPI._shouldWrite: raise Exception( f'The {uid} run group saved in the database is associated with ' f'{len(selectorNames)} resultSelectors while the current simulation is ' f'associated with {len(rsInd2DistColMap)}.' ) else: raise Exception( f'The HDF5 file associated with MPI rank {nsim.MPI._rank} contains ' f'{len(selectorNames)} distributed column remapping dataset while it should ' f'contain {len(rsInd2DistColMap)}.' ) for rsInd, rsName in enumerate(selectorNames): grouprs = self._currGroup._group[rsName] simColMap = rsInd2DistColMap[rsInd] rsColMap = numpy.array(grouprs[HDF5Handler._RS_DIST_IND_MAP_NAME]) if simColMap.shape[1] != rsColMap.shape[1]: raise Exception( f'The result selector that was previously used for this unique identifier ' f'had {rsColMap.shape[1]} columns while the current one has ' f'{simColMap.shape[1]} columns.' ) for i, ((rsRnk, rsIdx), (simRnk, simIdx)) in enumerate(zip(rsColMap.T, simColMap.T)): if (rsRnk, rsIdx) != (simRnk, simIdx): raise Exception( f'In Result selector {rsInd}, according to the data in the HDF5 file, column ' f'{i} should be mapped to rank {rsRnk} and local column {rsIdx}, but in the' f'current simulation, it is mapped to rank {simRnk} and local column {simIdx}.' ) def _distributeSelectors(self, sim, selectors): """Distribute result selectors Return a list of distributed selectors and, on rank 0, a map between result selector index and distributed column mapping: for each column in the non-distributed selector, it contains the rank in which the value will be saved and the local index in the distributed selector. """ optimGroups = [[]] rs2FullLen = {} for rs in selectors: # Distribute the result selector rs, changed = rs._distribute() if rs is not None: # Re-create optimization groups if len(optimGroups[-1]) == 0 or optimGroups[-1][-1][0]._optimGroupInd == rs._optimGroupInd: optimGroups[-1].append((rs, changed)) else: optimGroups.append([]) distribSelectors = [] for optGrp in optimGroups: if len(optGrp) > 0: distRs, distChanged = zip(*optGrp) # Reoptimize the calls if something changed after distribution if any(distChanged): distRs = nsaving_optim.OptimizeSelectors(sim, distRs) distribSelectors += distRs import mpi4py.MPI rsInd2DistColMap = {} # Build the mapping between full selector and local distributed selectors on rank 0 localDistrInds = {rs._selectorInd: (rs._fullLen, rs._distrInds) for rs in distribSelectors} allDistribInds = mpi4py.MPI.COMM_WORLD.gather(localDistrInds, root=0) if nsim.MPI._shouldWrite: allRsInds = set() rsFullLens = {} for dct in allDistribInds: for rsInd, (fullLen, _) in dct.items(): allRsInds.add(rsInd) rsFullLens[rsInd] = max(rsFullLens.get(rsInd, 0), fullLen) for rsInd in allRsInds: dcm = -1 * numpy.ones((2, rsFullLens[rsInd]), dtype=numpy.int64) for rnk, distInds in enumerate(allDistribInds): if rsInd in distInds: for localInd, ind in enumerate(distInds[rsInd][1]): dcm[0, ind] = rnk dcm[1, ind] = localInd rsInd2DistColMap[rsInd] = dcm return distribSelectors, rsInd2DistColMap def _newGroup(self, sim, uid, selectors, **kwargs): """Initialize the file and add a new run group.""" self._shouldWrite = nsim.MPI._shouldWrite or sim._isDistributed() self._nbSavingRanks = nsim.MPI._nhosts if sim._isDistributed() else 1 if self._shouldWrite: self._checkOpenFile(sim) # Check if the group already exists if uid not in self._file: # If it doesn't, create it group = self._file.create_group(uid, track_order=True) group.attrs[HDF5Handler._TIMESTAMP_ATTR_NAME] = str(datetime.datetime.now()) group.attrs[HDF5Handler._STEPS_VERSION_ATTR_NAME] = steps.__version__ group.attrs[HDF5Handler._NB_DISTR_RANKS_ATTR_NAME] = self._nbSavingRanks self._currGroup = HDF5Group(self, group, version=self._version) try: for argName, val in kwargs.items(): group.attrs[argName] = val except Exception as ex: # If attributes could not be set, delete the group before raising the exception self._currGroup = None del self._file[uid] raise ex # Pre-create distributed result selector groups if needed if sim._isDistributed() and self._nbSavingRanks > 1: # All ranks distribute their selectors distribSelectors, rsInd2DistColMap = self._distributeSelectors(sim, selectors) if nsim.MPI._shouldWrite: # Only rank 0 writes the full result selectors for rs in selectors: handler = self._getDataHandler(rs, groupNamePattern=HDF5Handler._RS_DISTGROUP_NAME) handler._group.create_dataset( HDF5Handler._RS_DIST_IND_MAP_NAME, data=rsInd2DistColMap[rs._selectorInd], **self._dataSetKWargs ) selectors = distribSelectors # Pre-create local result selector groups to be sure to have the correct order for rs in selectors: self._getRsHDFGroup(rs) else: group = self._file[uid] self._currGroup = HDF5Group(self, group, version=self._version) # Load version and check it matches the currently used version version = self._parseVersion(group.attrs[HDF5Handler._STEPS_VERSION_ATTR_NAME]) if version != self._version: raise Exception( f'Cannot add results to a group that was created with a different STEPS version. ' f'Group version: {version}, Current version: {self._version}' ) # Check that the number of distributed ranks is the same fileNbRanks = group.attrs[HDF5Handler._NB_DISTR_RANKS_ATTR_NAME] if fileNbRanks != self._nbSavingRanks: raise Exception( f'Cannot add results to a group that was created with a different number of MPI ' f'processes. The {uid} group was created with {fileNbRanks} while the current ' f'simulation is being run with {self._nbSavingRanks} processes.' ) # Checks parameters params = { k: v for k, v in group.attrs.items() if k not in HDF5Handler._GROUP_DEFAULT_ATTRS } if kwargs != params: raise Exception( f'The keyword arguments provided to the toDB method ({kwargs}) ' f'do not match with the keyword arguments in the HDF5 file for ' f'the same unique identifier ({params}).' ) # Check distributed result selectors if sim._isDistributed() and self._nbSavingRanks > 1: # All ranks distribute their selectors selectors, rsInd2DistColMap = self._distributeSelectors(sim, selectors) self._checkDistributedSelectors(uid, rsInd2DistColMap) # Check local result selectors self._checkSelectors(uid, selectors) else: self._currGroup = None self._close() return self._currGroup, selectors
[docs] def __getitem__(self, key): """Access an HDF5 group from its unique identifier :param key: Unique identifier to the group :type key: str :returns: The associated HDF5 group :rtype: :py:class:`HDF5Group` See :py:class:`HDF5Handler` for usage examples. Raises a ``KeyError`` if the key is not in the file. :meta public: """ self._checkOpenFile() if not isinstance(key, str): raise TypeError(f'Expected a unique identifier string, got {key} instead.') if key not in self._file: raise KeyError(f'{key} does not exist in {self._path}.') group = self._file[key] return HDF5Group(self, group, version=group.attrs[HDF5Handler._STEPS_VERSION_ATTR_NAME])
[docs] def __iter__(self): """Iterate over STEPS groups in the file Usage:: with HDF5Handler(filePath) as hdf: # Create database handler. for group in hdf: # Iterate over all groups val1 = group.val1 # Access group data Note that not all HDF5 groups in the file will be iterated on, only the ones that were added by STEPS. Other groups will be ignored by STEPS. :meta public: """ self._checkOpenFile() res = [] def visit(name, group): if all(attr in group.attrs for attr in HDF5Handler._GROUP_DEFAULT_ATTRS): res.append(HDF5Group(self, group, version=group.attrs[HDF5Handler._STEPS_VERSION_ATTR_NAME])) self._file.visititems(visit) for gr in res: yield gr
[docs]class HDF5Group(DatabaseGroup, nutils.Versioned): """A class representing a group of runs in an HDF5 file .. note:: This class should never be instantiated by the user, it is obtained through :py:class:`HDF5Handler` instead. """ def __init__(self, dbh, group, *args, **kwargs): super().__init__(*args, dbh=dbh, **kwargs) self._group = group
[docs] def __getattr__(self, name): """Attribute access for parameters of the group :param name: Name of the parameter, as defined in the original call to ``sim.toDB(...)`` :type name: str :returns: The corresponding parameter value See :py:class:`SQLiteDBHandler` for usage examples. :meta public: """ if name not in self._group.attrs: raise AttributeError(f'{name} is not an attribute of {self}.') return self._group.attrs[name]
@DatabaseGroup.name.getter def name(self): """The unique identifier of the group :type: str, read-only """ path = self._group.name.split('/') return path[-1] @property def results(self): """A list of all result selectors that were saved :type: List[:py:class:`ResultSelector`], read-only The result selectors are returned in the same order as they were added to the simulation with the :py:func:`steps.API_2.sim.Simulation.toSave` method. See :py:class:`HDF5Handler` for usage examples. """ # First check if distributed data available distrGroupNames = [ gn for gn in self._group if re.match(HDF5Handler._RS_DISTGROUP_NAME.format(r'\d+'), gn) is not None ] if len(distrGroupNames) > 0: return [ _ReadOnlyResultSelector( _HDF5DistribDataHandler(self, self._dbh, None, self._group[gn], version=self._version) ) for gn in distrGroupNames ] else: return [ _ReadOnlyResultSelector( _HDF5DataHandler(self._dbh, None, self._group[gn], version=self._version) ) for gn in self._group if re.match(HDF5Handler._RS_GROUP_NAME.format(r'\d+'), gn) is not None ] @property def parameters(self): """A dictionary of all parameters defined for this group :type: Mapping[str, Any], read-only Usage:: >>> with HDF5Handler(filePath) as hdf: ... hdf['MySimulation'].parameters {'val1': 1, 'val2': 2} """ return {k: v for k, v in self._group.attrs.items() if k not in HDF5Handler._GROUP_DEFAULT_ATTRS} @property def staticData(self): """A mutable mapping which contains static data specific to this run group :type: Mapping[str, Union[List, Dict, float, int, str]] Usage when writing data:: >>> with HDF5Handler(filePath) as hdf: >>> group = sim.toDB(hdf, 'RunGroup1') >>> group.staticData['StimPoints'] = [1, 5, 8] The static data that is saved must be specific to the whole run group. If the key associated to the data already exists in the static data, STEPS will check that the value given is the same as the one that was already saved. If not, an exception will be raised. Usage when reading data:: >>> with HDF5Handler(filePath) as hdf: >>> group = hdf['RunGroup1'] >>> group.staticData['StimPoints'] [1, 5, 8] Note that when using MPI, only rank 0 can access this property. """ if not nsim.MPI._shouldWrite: raise Exception(f'Only rank 0 can access staticData.') return _HDF5StaticDataAccessor(self._dbh, self._group, version=self._version)
[docs]class XDMFHandler(HDF5Handler): """XDMF / HDF5 File handler :param pathPrefix: Path and prefix for the HDF5 file(s), see :py:class:`HDF5Handler`. :type pathPrefix: str :param hdf5FileKwArgs: see :py:class:`HDF5Handler`. :type hdf5FileKwArgs: dict :param hdf5DatasetKwArgs: see :py:class:`HDF5Handler`. :type hdf5FileKwArgs: dict :param xdmfFolder: Path to the folder to which XDMF files should be written. If `None`, it uses the folder in which HDF5 files will be saved. :type xdmfFolder: Union[str, None] The `XDMF file format <https://www.xdmf.org/>`_ uses XML files with the `.xmf` extension to describe data saved in an HDF5 file. Scientific visualization tools like `Paraview <https://www.paraview.org/>`_ can read `.xmf` files and access the corresponding data in HDF5 files to display the mesh and the data associated with it. The :py:class:`XDMFHandler` database handler inherits from :py:class:`HDF5Handler` and thus behaves like an HDF5 database handler. The main difference is that :py:class:`XDMFHandler` also saves mesh information to the HDF5 file and generates `.xmf` XDMF files that describe all the data that can be visualized on meshes. Since it works in the same way as :py:class:`HDF5Handler`, usage examples can be seen there. Regardless on whether the data saved by a :py:class:`ResulSelector` is specific to mesh elements, it will be saved in the same way as if :py:class:`HDF5Handler` was used. Data that is specific to mesh elements (e.g. count of species in a tetrahedron, concentration of species in a region of interest, etc.) will be described in the `.xmf` file. The following mesh locations (see :py:class:`steps.API_2.sim.SimPath`) are supported: +---------------------+--------------------------+----------------+ | Location | Result selector | XDMF data type | +=====================+==========================+================+ | Tetrahedrons | ``rs.TETS(tetLst)...`` | Cell data | +---------------------+--------------------------+----------------+ | Triangles | ``rs.TRIS(triLst)...`` | Cell data | +---------------------+--------------------------+----------------+ | Vertices | ``rs.VERTS(vertLst)...`` | Node data | +---------------------+--------------------------+----------------+ | Regions of Interest | ``rs.ROIname...`` | Grid data | +---------------------+--------------------------+----------------+ | Compartments | ``rs.compName...`` | Grid data | +---------------------+--------------------------+----------------+ | Patches | ``rs.patchName...`` | Grid data | +---------------------+--------------------------+----------------+ Technical note: the result selectors that involve mesh data (data that will be described in the `.xmf` file) might be stored in the HDF5 file in a way that is different from how it is normally stored with :py:class:`HDF5Handler`. Notably, the order of :py:class:`ResultSelector` columns might be changed to allow contiguous data access when loading the data into scientific visualization softwares. These differences do not affect users as long as the data is read using :py:func:`HDF5Handler.__getitem__`, columns will be correctly reordered to match the original :py:class:`ResultSelector` order. """ _MESH_GROUP_NAME = 'mesh' _FILE_NAME_PATTERN = '{uid}_Run{run}_rank{rank}.xmf' _FULL_FILE_NAME_PATTERN = '{uid}_Run{run}_Full.xmf' def __init__(self, pathPrefix, hdf5FileKwArgs={}, hdf5DatasetKwArgs={}, xdmfFolder=None, **kwargs): super().__init__(pathPrefix, hdf5FileKwArgs, hdf5DatasetKwArgs, **kwargs) if xdmfFolder is None: xdmfFolder = os.path.dirname(pathPrefix) self._xdmfFolder = xdmfFolder self._xdmfTree = None self._fullXdmf = None self._currUID = None self._sim = None self._temporalGrids = None self._spatialGrids = None self._currTime = None self._currRun = None self._savedFilePaths = [] def _close(self): """Close the file""" super()._close() if self._xdmfTree is not None: self._writeXMLTree() self._xdmfTree = None def _getFilePaths(self): """Return a list of file paths managed by this rank""" savedFp = set(self._savedFilePaths) # To be saved soon: savedFp.add(self._getCurrXDMFFilePath()) if self._sim._isDistributed() and nsim.MPI._rank == 0: savedFp.add(self._getCurrFullXDMFFilePath()) return super()._getFilePaths() + sorted(savedFp) def _getDataHandler(self, rs, groupNamePattern=None): """Return a _DBDataHandler for ResultSelector rs.""" return _XDMFDataHandler(self, rs, self._getRsHDFGroup(rs, groupNamePattern)) def _newGroup(self, sim, uid, selectors, **kwargs): """Initialize the file and add a new run group.""" group, selectors = super()._newGroup(sim, uid, selectors, **kwargs) if self._xdmfTree is not None: self._writeXMLTree() self._currUID = uid self._sim = sim self._xdmfTree = None self._fullXdmf = None self._currRun = None if not isinstance(self._sim.geom, ngeom._BaseTetMesh): raise TypeError(f'XDMF data saving only works with tetrahedral meshes.') self._setUpSelectors(selectors) self._writeModelInfo() return group, selectors def _getAttributeName(self, val, loctpe): return f'{val} ({loctpe})' if loctpe is not None else val def _extractAndAddSpatialGridInfo(self, elem2infos, refCls, rsColRemaps, ROIgrids): elem2infos = elem2infos[refCls._locStr] lstCls = refCls._lstCls # Group by set of saved values elemMap = {} for elemIdx, infos in elem2infos.items(): key = frozenset(info[0:2] for info in infos) for info in infos: elemMap.setdefault(key, []).append((elemIdx, ) + info) lstKwArgs = dict(mesh=self._sim.geom) if self._sim._isDistributed(): lstKwArgs['local'] = True coveredElems = refCls._lstCls([], **lstKwArgs) # Treat set of saved values independently for rsSet, elemInfo in elemMap.items(): # Remove duplicate values from different result selectors elemInfo = {(idx, val): (rsId, rsPos) for idx, val, rsId, rsPos in elemInfo} # Get element list for each result selector rs2Elems = {} for (idx, val), (rsId, rsPos) in elemInfo.items(): rs2Elems.setdefault(rsId, set()).add(idx) # Compute the intersections of element lists grids = [] for rsId, elemInds in rs2Elems.items(): elemLst = lstCls(sorted(elemInds), **lstKwArgs) # Add intersection with ROI grids for roitets, roiRsVals, loc in ROIgrids[refCls]: inter = (roitets & elemLst) - coveredElems if len(inter) > 0: grids.append((inter, copy.copy(roiRsVals), loc)) coveredElems |= inter newGrids = [] for g, rsVals, loc in grids: if len(elemLst) > 0: inter = g & elemLst if len(inter) > 0: newGrids.append((inter, copy.copy(rsVals), loc)) g -= inter elemLst -= inter if len(g) > 0: newGrids.append((g, copy.copy(rsVals), loc)) if len(elemLst) > 0: coveredElems |= elemLst newGrids.append((elemLst, [], (refCls,))) grids = newGrids # Add The resulting element lists as our distinct spatial grids elem2Grid = {} for grid, rsVals, loc in grids: gridPos = len(self._spatialGrids[refCls]) self._spatialGrids[refCls].append((grid, rsVals, loc)) for elem in grid: elem2Grid[elem.idx] = gridPos # Group data by result selectors and grid rs2Infos = {} for (idx, val), (rsId, rsPos) in elemInfo.items(): rs2Infos.setdefault( (rsId, elem2Grid[idx]), {} ).setdefault( val, [] ).append((idx, rsPos)) # Add mapping between result selectors and spatial grids for (rsId, gridPos), val2elems in rs2Infos.items(): for val, idxLst in val2elems.items(): idxLst.sort(key=lambda x: x[0]) elemIdxs, rsPoss = zip(*idxLst) start = len(rsColRemaps.get(rsId, [])) center = 'Node' if refCls == ngeom.VertReference else 'Cell' self._spatialGrids[refCls][gridPos][1].append( (val, None, rsId, start, 1, len(elemIdxs), center) ) rsColRemaps.setdefault(rsId, []) rsColRemaps[rsId] += list(rsPoss) # Add portions of ROI grid that were not added previously for roiElems, roiRsVals, loc in ROIgrids[refCls]: remaining = roiElems - coveredElems if len(remaining) > 0: self._spatialGrids[refCls].append((remaining, roiRsVals, loc)) def _getValName(self, sel, i): if isinstance(sel, _ResultPath): loc, *objsval = sel._labels[i].split('.') return '.'.join(objsval) elif isinstance(sel, _ResultCombiner): subLabels = sel._labelArgFunc(i, sel.children) return sel._labelStrFunc( *[self._getValName(s.sel, s.ind) if isinstance(s, _LabelSelector) else s for s in subLabels] ) elif isinstance(sel, _ResultList): tot = 0 for c in sel.children: if i - tot < c._getEvalLen(): return self._getValName(c, i - tot) tot += c._getEvalLen() def _setUpSelectors(self, selectors): self._spatialGrids = { ngeom.TetReference: [], ngeom.TriReference: [], ngeom.VertReference: [], } self._rs2Grids = {(nsim.MPI._rank, rs._selectorInd): [] for rs in selectors} locStr2RefCls = { cls._locStr: cls for cls in [ngeom.TetReference, ngeom.TriReference, ngeom.VertReference] } # Extract info from result selectors # elem2infos will contain data using global ids in the case of distributed meshes distrElem2infos = { ngeom.TetReference._locStr: {}, ngeom.TriReference._locStr: {}, ngeom.VertReference._locStr: {}, } nonDistrElem2infos = { ngeom.ROI._locStr: {}, ngeom.Compartment._locStr: {}, ngeom.Patch._locStr: {}, } for rs in selectors: rsId = (nsim.MPI._rank, rs._selectorInd) if 'loc_type' in rs.metaData and 'loc_id' in rs.metaData: types = rs.metaData['loc_type'] inds = rs.metaData['loc_id'] ves_types = rs.metaData.get('vesicle_type', None) raft_types = rs.metaData.get('raft_type', None) for i, (tpe, ind) in enumerate(zip(types, inds)): if not all(tpes is None or tpes[i] is None for tpes in [ves_types, raft_types]): # Do not consider values that are linked with vesicles or rafts continue val = self._getValName(rs, i) # If the mesh is distributed, we need to use local inds if self._sim._isDistributed() and tpe in locStr2RefCls: ind = locStr2RefCls[tpe]._distCls._getToLocalFunc(self._sim.geom)(ind) if tpe in distrElem2infos: distrElem2infos[tpe].setdefault(ind, []).append((val, rsId, i)) elif tpe in nonDistrElem2infos: nonDistrElem2infos[tpe].setdefault(ind, []).append((val, rsId, i)) if self._sim._isDistributed(): # Update local nonDistrElem2infos with the ones from other ranks import mpi4py.MPI allNDElem2Infos = mpi4py.MPI.COMM_WORLD.allgather(nonDistrElem2infos) for rnk, e2i in enumerate(allNDElem2Infos): for locStr, dct in e2i.items(): for ind, lst in dct.items(): nonDistrElem2infos[locStr].setdefault(ind, []) nonDistrElem2infos[locStr][ind] += lst # Merge distributable and non-distributable elem2infos elem2infos = {**distrElem2infos, **nonDistrElem2infos} rsColRemaps = {} # Regions of interest # ROIgrids will contain data using local ids in the case of distributed meshes ROIgrids = { ngeom.TetReference: [], ngeom.TriReference: [], ngeom.VertReference: [], } # Pre-split the grids according to compartments and patches for comp in self._sim.geom.ALL(ngeom.Compartment): if self._sim._isDistributed(): compTets = comp.tets.toLocal() else: compTets = comp.tets if len(compTets) > 0: loc = (ngeom.TetReference, comp.name) ROIgrids[ngeom.TetReference].append((compTets, [], loc)) for patch in self._sim.geom.ALL(ngeom.Patch): if self._sim._isDistributed(): patchTris = patch.tris.toLocal() else: patchTris = patch.tris if len(patchTris) > 0: loc = (ngeom.TriReference, patch.name) ROIgrids[ngeom.TriReference].append((patchTris, [], loc)) # Treat compartments and patches data as ROI data for roiCls in [ngeom.ROI, ngeom.Compartment, ngeom.Patch]: for name, valsLst in elem2infos[roiCls._locStr].items(): zone = getattr(self._sim.geom, name) if roiCls == ngeom.ROI: elems = zone[:] elif roiCls == ngeom.Compartment: elems = zone.tets elif roiCls == ngeom.Patch: elems = zone.tris if self._sim._isDistributed(): elems = elems.toLocal() rsVals = [ (val, roiCls._locStr, rsId, rsPos, 1, 1, 'Grid') for val, rsId, rsPos in valsLst ] newGrids = [] for elems2, rsVals2, loc2 in ROIgrids[elems._refCls]: inter = elems & elems2 elems -= inter elems2 -= inter if len(inter) > 0: newGrids.append((inter, rsVals + rsVals2, loc2)) if len(elems2) > 0: newGrids.append((elems2, rsVals2, loc2)) if len(elems) > 0: newGrids.append((elems, rsVals, (elems._refCls,))) ROIgrids[elems._refCls] = newGrids # Tetrahedron grids self._extractAndAddSpatialGridInfo(elem2infos, ngeom.TetReference, rsColRemaps, ROIgrids) # Add empty tet grid at the end if self._sim._isDistributed(): tets = self._sim.geom.tets.toLocal() else: tets = self._sim.geom.tets for tetInds, *_ in self._spatialGrids[ngeom.TetReference]: tets -= tetInds if len(tets) > 0: self._spatialGrids[ngeom.TetReference].append((tets, [], (ngeom.TetReference,))) # Triangle grids self._extractAndAddSpatialGridInfo(elem2infos, ngeom.TriReference, rsColRemaps, ROIgrids) # Vertices self._extractAndAddSpatialGridInfo(elem2infos, ngeom.VertReference, rsColRemaps, ROIgrids) # Fill rs2Grids map for refCls, grids in self._spatialGrids.items(): for i, (elems, rsVals, _) in enumerate(grids): for val, loctpe, rsId, start, step, nVals, center in rsVals: self._rs2Grids.setdefault(rsId, []).append((val, loctpe, start, step, nVals, refCls, i, center)) if self._shouldWrite: # Write column remapping of result selectors, if needed for rs in selectors: n = rs._getEvalLen() colRemap = rsColRemaps.get((nsim.MPI._rank, rs._selectorInd), []) if len(colRemap) > 0: if len(colRemap) < n: colRemap += sorted(set(range(n)) - set(colRemap)) colRemap = numpy.array(colRemap, dtype=numpy.int64) # Write the column remap if different from neutral remap if any(a != b for a, b in zip(colRemap, range(len(colRemap)))): rsgroup = self._getRsHDFGroup(rs) if _HDF5DataHandler._RS_COLREMAPPING_NAME in rsgroup: if any(a != b for a, b in zip(rsgroup[_HDF5DataHandler._RS_COLREMAPPING_NAME], colRemap)): raise Exception( f'Column remapping was different for previous runs. Try saving to an ' f'empty HDF5 file.' ) else: rsgroup.create_dataset(_HDF5DataHandler._RS_COLREMAPPING_NAME, data=colRemap, **self._dataSetKWargs) if self._sim._isDistributed(): # Synchronize non-distributed result selector data import mpi4py.MPI allRS2Len = mpi4py.MPI.COMM_WORLD.allgather({rs._selectorInd: rs._getEvalLen() for rs in selectors}) allColRemaps = mpi4py.MPI.COMM_WORLD.allgather(rsColRemaps) self._grid2NonDistrRS = {} for (rnk, rsIdx), gridVals in self._rs2Grids.items(): if rnk != nsim.MPI._rank: for val, loctpe, start, step, nVals, gridCls, gridInd, center in gridVals: assert nVals == 1 localRemap = allColRemaps[nsim.MPI._rank].get((rnk, rsIdx), None) remoteRemap = allColRemaps[rnk].get((rnk, rsIdx), None) if localRemap is not None: start = localRemap[start] if remoteRemap is not None: start = remoteRemap.index(start) remoteRsLen = allRS2Len[rnk][rsIdx] self._grid2NonDistrRS.setdefault((gridCls, gridInd), []).append( (val, loctpe, start, center, rnk, rsIdx, remoteRsLen) ) def _getCurrXDMFFilePath(self): fileName = XDMFHandler._FILE_NAME_PATTERN.format( uid=self._currUID, run=self._currRun, rank=nsim.MPI._rank ) return os.path.join(self._xdmfFolder, fileName) def _getCurrFullXDMFFilePath(self): fileName = XDMFHandler._FULL_FILE_NAME_PATTERN.format( uid=self._currUID, run=self._currRun ) return os.path.join(self._xdmfFolder, fileName) def _writeXMLTree(self): if self._shouldWrite: tree = ElementTree.ElementTree(self._xdmfTree) filePath = self._getCurrXDMFFilePath() self._savedFilePaths.append(filePath) tree.write(filePath) if self._fullXdmf is not None: tree = ElementTree.ElementTree(self._fullXdmf) filePath = self._getCurrFullXDMFFilePath() self._savedFilePaths.append(filePath) tree.write(filePath) @staticmethod def _getHierarchicalParent(hierarchicalGrids, path): xmlPath = '/Xdmf/Domain/Grid' for i in range(len(path)): gridName = f'{path[i]._locStr}Grids' if i == 0 else path[i] xmlPath += f"/Grid[@Name='{gridName}']" if path[:i+1] not in hierarchicalGrids: # First element is a reference class, not a string, requires special treatment. hierarchicalGrids[path[:i+1]] = ElementTree.SubElement( hierarchicalGrids[path[:i]], 'Grid', Name=gridName, GridType='Collection', CollectionType='Spatial' ) return hierarchicalGrids[path], xmlPath @staticmethod def _createXMLDocument(): xdmfTree = ElementTree.Element('Xdmf', Version='2.0') xdmfTree.set('xmlns:xi', 'http://www.w3.org/2001/XInclude') dom = ElementTree.SubElement(xdmfTree, 'Domain') hierarchicalGrids = {} # Add root hierarchicalGrids[tuple()] = ElementTree.SubElement( dom, 'Grid', Name=f'SpatialGrids', GridType='Collection', CollectionType='Spatial' ) return xdmfTree, hierarchicalGrids def _newRun(self, rid): if self._currRun != rid: self._temporalGrids = { ngeom.TetReference: [], ngeom.TriReference: [], ngeom.VertReference: [], } # Write previous XML tree if self._xdmfTree is not None: self._writeXMLTree() # Create XML document self._xdmfTree, self._hierarchicalGrids = self._createXMLDocument() # Add temporal grids allLocations = [] for refCls, grids in self._spatialGrids.items(): for i, (elems, _, loc) in enumerate(grids): if self._shouldWrite: if self._sim._isDistributed(): gridName = f'{refCls._locStr}Grid{i}_rank{nsim.MPI._rank}' else: gridName = f'{refCls._locStr}Grid{i}' parentGrid, xmlPath = self._getHierarchicalParent(self._hierarchicalGrids, loc) xmlPath += f"/Grid[@Name='{gridName}']" # Add the grid tempGrid = ElementTree.SubElement( parentGrid, 'Grid', Name=gridName, GridType='Collection', CollectionType='Temporal' ) allLocations.append((loc, xmlPath)) else: tempGrid = None xmlPath = None self._temporalGrids[refCls].append((tempGrid, xmlPath, None, None)) # Write the grids for t=0, this way we are sure that all grids are written at least once # even if no data is associated to it. self._writeGrid(refCls, i, 0, 0) self._currTime = None self._currRun = rid if self._sim._isDistributed(): # Gather allLocations to rank 0 import mpi4py.MPI elemStr2ElemCls = {elemCls._locStr: elemCls for elemCls in self._temporalGrids.keys()} allLocations = [((loc[0]._locStr,) + loc[1:], path) for loc, path in allLocations] fileName = XDMFHandler._FILE_NAME_PATTERN.format( uid=self._currUID, run=self._currRun, rank=nsim.MPI._rank ) allInfos = mpi4py.MPI.COMM_WORLD.gather((fileName, allLocations), root=0) if nsim.MPI._rank == 0: # Write a common xmf file on rank 0 self._fullXdmf, fullhierarchicalGrids = self._createXMLDocument() for fileName, locations in allInfos: for loc, xmlPath in locations: loc = (elemStr2ElemCls[loc[0]],) + loc[1:] parent, _ = self._getHierarchicalParent(fullhierarchicalGrids, loc) ElementTree.SubElement(parent, 'xi:include', href=fileName, xpointer=f'xpointer({xmlPath})') def _getHyperSlab(self, parent, fileName, rsIdx, rsLen, tind, start, step, nVals): hyperslab = ElementTree.SubElement( parent, 'DataItem', ItemType='HyperSlab', Dimensions=f'{nVals}' ) dims = ElementTree.SubElement( hyperslab, 'DataItem', Dimensions='3 2', NumberType='Int', Format='XML' ) dims.text = f'{tind} {start} 1 {step} 1 {nVals}' data_item = ElementTree.SubElement( hyperslab, 'DataItem', DataType='Float', Dimensions=f'{tind+1} {rsLen}', Format='HDF', Precision='8' ) rsName = HDF5Handler._RS_GROUP_NAME.format(rsIdx) hdffn = os.path.basename(fileName) data_item.text = f'{hdffn}:{self._currUID}/{rsName}/runs/Run_{self._currRun}/data' return hyperslab def _newTimeStep(self, t, rs, tind): rsId = (nsim.MPI._rank, rs._selectorInd) for val, loctpe, start, step, nVals, gridCls, gridInd, center in self._rs2Grids[rsId]: tempGrid, xmlPath, currTime, currGrid = self._temporalGrids[gridCls][gridInd] if currTime != t: currGrid = self._writeGrid(gridCls, gridInd, t, tind) # Only triggered once per grid and per timestep if self._sim._isDistributed(): # Add attributes that are not in this rank if (gridCls, gridInd) in self._grid2NonDistrRS: for _val, _loctpe, _start, _center, rnk, rsIdx, rsLen in self._grid2NonDistrRS[(gridCls, gridInd)]: att = ElementTree.SubElement( currGrid, 'Attribute', Name=self._getAttributeName(_val, _loctpe), AttributeType='Scalar', Center=_center ) fileName = HDF5Handler._DISTRIBUTED_HDF_SUFFIX.format( self._pathPrefix, rnk) + HDF5Handler._HDF_EXTENSION hyperslab = self._getHyperSlab(att, fileName, rsIdx, rsLen, tind, _start, 1, 1) if self._shouldWrite: att = ElementTree.SubElement( currGrid, 'Attribute', Name=self._getAttributeName(val, loctpe), AttributeType='Scalar', Center=center ) if center == 'Grid': # Save the value in the xdmf file directly instead of using a hyperslab data_item = ElementTree.SubElement( att, 'DataItem', DataType='Float', Dimensions=f'1', Format='XML' ) data_item.text = str(rs.data[self._currRun, tind, start]) else: hyperslab = self._getHyperSlab( att, self._path, rs._selectorInd, rs._getEvalLen(), tind, start, step, nVals ) def _getHDF5SubGroup(self, name): if self._shouldWrite: if name not in self._currGroup._group: return self._currGroup._group.create_group(name) else: return self._currGroup._group[name] else: return None def _writeGrid(self, gridCls, gridInd, t, tind): meshGroup = self._getHDF5SubGroup(XDMFHandler._MESH_GROUP_NAME) # Write xdmf description of mesh data tpeMap = { ngeom.TetReference: (6, 5), ngeom.TriReference: (4, 4), ngeom.VertReference: (1, 1), } if self._sim._isDistributed(): gridName = f'{gridCls._locStr}Grid{gridInd}_rank{nsim.MPI._rank}' else: gridName = f'{gridCls._locStr}Grid{gridInd}' elemCode, elemColNb = tpeMap[gridCls] cond = f'{gridName}/XYZ' not in meshGroup if meshGroup is not None else None if nsim.MPI._usingMPI and not self._sim._isDistributed(): # If the mesh is not distributed, all ranks need to be involved in the calls to get mesh data import mpi4py.MPI cond = mpi4py.MPI.COMM_WORLD.bcast(cond, root=0) elems, _, loc = self._spatialGrids[gridCls][gridInd] if cond: allVerts = [] topo = [] if gridCls == ngeom.VertReference: allVerts = [numpy.array(vert) for vert in elems] vertInds = {vert: i for i, vert in enumerate(elems)} else: vertInds = {} for elem in elems: localInds = [] for vert in elem.verts: if vert not in vertInds: vertInds[vert] = len(allVerts) vertPos = numpy.array(vert) if self._shouldWrite: allVerts.append(vertPos) localInds.append(vertInds[vert]) if self._shouldWrite: topo.append(numpy.array([elemCode] + localInds)) if self._shouldWrite: meshGroup.create_dataset(f'{gridName}/vertInds', data=numpy.array([vert.idx for vert, _ in vertInds.items()]), **self._dataSetKWargs ) meshGroup.create_dataset(f'{gridName}/elemInds', data=numpy.array(elems.indices), **self._dataSetKWargs) meshGroup.create_dataset(f'{gridName}/XYZ', data=numpy.array(allVerts), **self._dataSetKWargs) if len(loc) > 0 and isinstance(loc[-1], str): meshGroup[gridName].attrs['loc_id'] = loc[-1] if len(topo) > 0: meshGroup.create_dataset(f'{gridName}/topology', data=numpy.array(topo), **self._dataSetKWargs) if self._shouldWrite: tempGrid, xmlPath, _, currGrid = self._temporalGrids[gridCls][gridInd] grid = ElementTree.SubElement( tempGrid, 'Grid', Name=f'{gridName}_{tind}', GridType='Uniform' ) ElementTree.SubElement(grid, 'Time', Value=f'{t}') if currGrid is None: nverts = meshGroup[f'{gridName}/XYZ'].shape[0] if f'{gridName}/topology' in meshGroup: nelems = meshGroup[f'{gridName}/topology'].shape[0] else: nelems = nverts if nelems != len(elems): raise Exception( f'Previous simulations were run with a different XDMF mesh splitting. ' f'Try saving to a different HDF5 file.' ) hdfFileName = os.path.basename(self._path) if gridCls == ngeom.VertReference: topo = ElementTree.SubElement( grid, 'Topology', TopologyType='PolyVertex', NumberOfElements=f'{nelems}' ) else: topo = ElementTree.SubElement( grid, 'Topology', TopologyType='Mixed', Dimensions=f'{nelems}' ) data_item = ElementTree.SubElement( topo, 'DataItem', DataType='Int', Dimensions=f'{nelems * elemColNb}', Format='HDF', Precision='8' ) data_item.text = f'{hdfFileName}:{self._currUID}/{XDMFHandler._MESH_GROUP_NAME}/{gridName}/topology' geo = ElementTree.SubElement(grid, 'Geometry', GeometryType='XYZ') data_item = ElementTree.SubElement( geo, 'DataItem', DataType='Float', Dimensions=f'{nverts * 3}', Format='HDF', Precision='8' ) data_item.text = f'{hdfFileName}:{self._currUID}/{XDMFHandler._MESH_GROUP_NAME}/{gridName}/XYZ' else: topo = ElementTree.SubElement(grid, 'Topology', Reference='XML') topo.text = f'{xmlPath}/Grid/Topology' geo = ElementTree.SubElement(grid, 'Geometry', Reference='XML') geo.text = f'{xmlPath}/Grid/Geometry' else: tempGrid, xmlPath, grid = None, None, None self._temporalGrids[gridCls][gridInd] = (tempGrid, xmlPath, t, grid) return grid def _writeModelInfo(self): """Write a dictionary containing model information that can be useful for visualization""" # Only rank 0 writes this data if nsim.MPI._shouldWrite: sd = self._currGroup.staticData # Write vesicle and raft diameters sd['Vesicles'] = {ves.name: {'Diameter': ves.Diameter} for ves in self._sim.model.ALL(nmodel.Vesicle)} sd['Rafts'] = {raft.name: {'Diameter': raft.Diameter} for raft in self._sim.model.ALL(nmodel.Raft)} try: sd['VesiclePaths'] = self._sim.solver._getAllPaths() except AttributeError: pass dct = {} for patch in self._sim.geom.ALL(ngeom.Patch): for zone in patch.ALL(ngeom.EndocyticZone): dct[zone.name] = {'patch': patch.name, 'tris': zone.tris.indices} sd['EndocyticZones'] = dct