Source code for pytta.functions

# -*- coding: utf-8 -*-
"""
Set of useful functions of general purpouses when using PyTTa.

Includes reading and writing wave files, seeing the audio IO
devices available and few signal processing tools.

Available functions:

    >>> pytta.list_devices()
    >>> pytta.read_wav(fileName)
    >>> pytta.write_wav(fileName, signalObject)
    >>> pytta.merge(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.slipt(signalObj)
    >>> pytta.fft_convolve(signalObj1, signalObj2)
    >>> pytta.find_delay(signalObj1, signalObj2)
    >>> pytta.corr_coef(signalObj1, signalObj2)
    >>> pytta.resample(signalObj, newSamplingRate)
    >>> pytta.peak_time(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.plot_time(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.plot_time_dB(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.plot_freq(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.plot_bars(signalObj1, signalObj2, ..., signalObjN)
    >>> pytta.save(fileName, obj1, ..., objN)
    >>> pytta.load(fileName)

For further information, check the function specific documentation.

"""

import os
import time
import json
from scipy.io import wavfile as wf
import scipy.io as sio
import numpy as np
import sounddevice as sd
import scipy.signal as ss
import scipy.fftpack as sfft
import zipfile as zf
import h5py
from typing import Union, List
from pytta.classes import SignalObj, ImpulsiveResponse, \
                    RecMeasure, PlayRecMeasure, FRFMeasure, \
                    Analysis, OctFilter
from pytta.classes._base import ChannelsList, ChannelObj
from pytta.generate import measurement  # TODO: Change to class instantiation.
from pytta import _h5utils as _h5
import copy as cp
from warnings import warn
from pytta import _plot as plot
# For backwards compatibility purposes. Planned to get out of here
from pytta.utils.maths import fft_degree as new_fft_degree


[docs]def list_devices(): """ Shortcut to sounddevice.query_devices(). Made to exclude the need of importing Sounddevice directly just to find out which audio devices can be used. >>> pytta.list_devices() Returns ------- A tuple containing all available audio devices. """ return sd.query_devices()
def print_devices(): """ Print the devices list to stdout. Returns ------- None. """ return print(list_devices())
[docs]def get_device_from_user() -> Union[List[int], int]: """ Print the device list and query for a number input of the device, or devices. Returns ------- Union[List[int], int] Practical interface for querying devices to be used within scripts. """ print_devices() device = [int(dev.strip()) for dev in input("Input the device number: ").split(',')] if len(device) == 1: device = device[0] text = "Device is:" else: text = "Devices are:" print(text, device) return device
[docs]def read_wav(fileName): """Read a wave file into a SignalObj.""" samplingRate, data = wf.read(fileName) if data.dtype == 'int16': data = data/(2**15) if data.dtype == 'int32': data = data/(2**31) signal = SignalObj(data, 'time', samplingRate=samplingRate) return signal
[docs]def write_wav(fileName, signalIn): """Write a SignalObj into a single wave file.""" samplingRate = signalIn.samplingRate data = signalIn.timeSignal return wf.write(fileName if '.wav' in fileName else fileName+'.wav', samplingRate, data)
[docs]def SPL(signal, nthOct=3, minFreq=100, maxFreq=4000): """ Calculate the `signal`'s Sound Pressure Level The calculations are made by frequency bands and ranges from `minFreq` to `maxFreq` with `nthOct` bands per octave. Returns ------- Analysis: The sound pressure level packed into an Analysis object. """ with OctFilter(order=4, nthOct=nthOct, minFreq=minFreq, maxFreq=maxFreq, base=10, refFreq=1000, samplingRate=signal.samplingRate) as ofb: fsignal = ofb.filter(signal) out = [] for filtsignal in fsignal: out.append(Analysis('L', nthOct, minFreq, maxFreq, filtsignal.spl())) return out if len(out) > 1 else out[0]
[docs]def merge(signal1, *signalObjects): """Gather all channels of the signalObjs given as input arguments into a single SignalObj.""" j = 1 freqMin = cp.deepcopy(signal1.freqMin) freqMax = cp.deepcopy(signal1.freqMax) comment = cp.deepcopy(signal1.comment) channels = cp.deepcopy(signal1.channels) timeSignal = cp.deepcopy(signal1.timeSignal) for inObj in signalObjects: if signal1.samplingRate != inObj.samplingRate: message = '\ \n To merge signals they must have the same sampling rate!\ \n SignalObj 1 and '+str(j+1)+' have different sampling rates.' raise AttributeError(message) if signal1.numSamples != inObj.numSamples: message = '\ \n To merge signals they must have the same length!\ \n SignalObj 1 and '+str(j+1)+' have different lengths.' raise AttributeError(message) comment = comment + ' / ' + inObj.comment for ch in inObj.channels._channels: channels.append(ch) timeSignal = np.hstack((timeSignal, inObj.timeSignal)) j += 1 newSignal = SignalObj(timeSignal, domain='time', samplingRate=signal1.samplingRate, freqMin=freqMin, freqMax=freqMax, comment=comment) channels.conform_to() newSignal.channels = channels return newSignal
def split(*signalObjects, channels: list = None) -> list: """ Split the provided SignalObjs' channels into several SignalObjs. If the 'channels' input argument is given, split the specified channel numbers of each SignalObj, otherwise split all channels. Arguments (default), (type): ----------------------------- * non-keyword arguments (), (SignalObj) * channels (None), (list): specified channels to split from the provided SignalObjs; Return (type): -------------- * spltdChs (list): a list containing SignalObjs for each split channel; """ spltdChs = [] for sigObj in signalObjects: moreSpltdChs = sigObj.split(channels=channels) spltdChs.extend(moreSpltdChs) return spltdChs
[docs]def fft_convolve(signal1, signal2): """ Use scipy.signal.fftconvolve() to convolve two time domain signals. >>> convolution = pytta.fft_convolve(signal1,signal2) """ # Fs = signal1.Fs conv = ss.fftconvolve(signal1.timeSignal, signal2.timeSignal) signal = SignalObj(conv, 'time', signal1.samplingRate) return signal
[docs]def find_delay(signal1, signal2): """ Cross Correlation alternative. More efficient fft based method to calculate time shift between two signals. >>> shift = pytta.find_delay(signal1,signal2) """ if signal1.N != signal2.N: return print('Signal1 and Signal2 must have the same length') else: freqSignal1 = signal1.freqSignal freqSignal2 = sfft.fft(np.flipud(signal2.timeSignal)) convoluted = np.real(sfft.ifft(freqSignal1 * freqSignal2)) convShifted = sfft.fftshift(convoluted) zeroIndex = int(signal1.numSamples / 2) - 1 shift = zeroIndex - np.argmax(convShifted) return shift
[docs]def corr_coef(signal1, signal2): """Finds the correlation coefficient between two SignalObjs using the numpy.corrcoef() function.""" coef = np.corrcoef(signal1.timeSignal, signal2.timeSignal) return coef[0, 1]
[docs]def resample(signal, newSamplingRate): """ Resample the timeSignal of the input SignalObj to the given sample rate using the scipy.signal.resample() function """ newSignalSize = np.int(signal.timeLength*newSamplingRate) resampled = ss.resample(signal.timeSignal[:], newSignalSize) newSignal = SignalObj(resampled, "time", newSamplingRate) return newSignal
[docs]def peak_time(signal): """ Return the time at signal's amplitude peak. """ if not isinstance(signal, SignalObj): raise TypeError('Signal must be an SignalObj.') peaks_time = [] for chindex in range(signal.numChannels): maxamp = max(np.abs(signal.timeSignal[:, chindex])) maxindex = np.where(signal.timeSignal[:, chindex] == np.abs(maxamp))[0] maxtime = signal.timeVector[maxindex][0] peaks_time.append(maxtime) if signal.numChannels > 1: return peaks_time else: return peaks_time[0]
[docs]def fft_degree(*args,**kwargs): """ DEPRECATED ---------- Being replaced by pytta.utils.maths.fft_degree on version 0.1.0. Power-of-two value that can be used to calculate the total number of samples of the signal. >>> numSamples = 2**fftDegree Parameters ---------- * timeLength (float = 0): Value, in seconds, of the time duration of the signal or recording. * samplingRate (int = 1): Value, in samples per second, that the data will be captured or emitted. Returns ------- fftDegree (float = 0): Power of 2 that can be used to calculate number of samples. """ warn(DeprecationWarning("Function 'pytta.fft_degree' is DEPRECATED and " + "being replaced by pytta.utils.maths.fft_degree" + " on version 0.1.0")) return new_fft_degree(*args, **kwargs)
[docs]def plot_time(*sigObjs, xLabel:str=None, yLabel:str=None, yLim:list=None, xLim:list=None, title:str=None, decimalSep:str=',', timeUnit:str='s'): """Plot provided SignalObjs togheter in time domain. Saves xLabel, yLabel, and title when provided for the next plots. Parameters (default), (type): ----------- * sigObjs (), (SignalObj): non-keyworded input arguments with N SignalObjs. * xLabel (None), (str): x axis label. * yLabel (None), (str): y axis label. * yLim (None), (list): inferior and superior limits. >>> yLim = [-100, 100] * xLim (None), (str): left and right limits. >>> xLim = [0, 15] * title (None), (str): plot title. * decimalSep (','), (str): may be dot or comma. >>> decimalSep = ',' # in Brazil * timeUnit ('s'), (str): 'ms' or 's'. Return: -------- matplotlib.figure.Figure object. """ realSigObjs = _remove_non_(SignalObj, sigObjs, msgPrefix='plot_time:') if len(realSigObjs) > 0: fig = plot.time(realSigObjs, xLabel, yLabel, yLim, xLim, title, decimalSep, timeUnit) return fig else: return
[docs]def plot_time_dB(*sigObjs, xLabel:str=None, yLabel:str=None, yLim:list=None, xLim:list=None, title:str=None, decimalSep:str=',', timeUnit:str='s'): """Plot provided SignalObjs togheter in decibels in time domain. Parameters (default), (type): ----------- * sigObjs (), (SignalObj): non-keyworded input arguments with N SignalObjs. * xLabel ('Time [s]'), (str): x axis label. * yLabel ('Amplitude'), (str): y axis label. * yLim (), (list): inferior and superior limits. >>> yLim = [-100, 100] * xLim (), (list): left and right limits >>> xLim = [0, 15] * title (), (str): plot title * decimalSep (','), (str): may be dot or comma. >>> decimalSep = ',' # in Brazil * timeUnit ('s'), (str): 'ms' or 's'. Return: -------- matplotlib.figure.Figure object. """ realSigObjs = \ _remove_non_(SignalObj, sigObjs, msgPrefix='plot_time_dB:') if len(realSigObjs) > 0: fig = plot.time_dB(realSigObjs, xLabel, yLabel, yLim, xLim, title, decimalSep, timeUnit) return fig else: return
[docs]def plot_freq(*sigObjs, smooth:bool=False, xLabel:str=None, yLabel:str=None, yLim:list=None, xLim:list=None, title:str=None, decimalSep:str=','): """Plot provided SignalObjs magnitudes togheter in frequency domain. Parameters (default), (type): ----------------------------- * sigObjs (), (SignalObj): non-keyworded input arguments with N SignalObjs. * xLabel ('Time [s]'), (str): x axis label. * yLabel ('Amplitude'), (str): y axis label. * yLim (), (list): inferior and superior limits. >>> yLim = [-100, 100] * xLim (), (list): left and right limits >>> xLim = [15, 21000] * title (), (str): plot title * decimalSep (','), (str): may be dot or comma. >>> decimalSep = ',' # in Brazil Return: -------- matplotlib.figure.Figure object. """ realSigObjs = \ _remove_non_(SignalObj, sigObjs, msgPrefix='plot_freq:') if len(realSigObjs) > 0: fig = plot.freq(realSigObjs, smooth, xLabel, yLabel, yLim, xLim, title, decimalSep) return fig else: return
[docs]def plot_bars(*analyses, xLabel:str=None, yLabel:str=None, yLim:list=None, xLim:list=None, title:str=None, decimalSep:str=',', barWidth:float=0.75, errorStyle:str=None, forceZeroCentering:bool=False, overlapBars:bool=False, color:list=None): """Plot the analysis data in fractinal octave bands. Parameters (default), (type): ----------------------------- * analyses (), (SignalObj): non-keyworded input arguments with N SignalObjs. * xLabel ('Time [s]'), (str): x axis label. * yLabel ('Amplitude'), (str): y axis label. * yLim (), (list): inferior and superior limits. >>> yLim = [-100, 100] * xLim (), (list): bands limits. >>> xLim = [100, 10000] * title (), (str): plot title * decimalSep (','), (str): may be dot or comma. >>> decimalSep = ',' # in Brazil * barWidth (0.75), float: width of the bars from one fractional octave band. 0 < barWidth < 1. * errorStyle ('standard'), str: error curve style. May be 'laza' or None/'standard'. * forceZeroCentering ('False'), bool: force centered bars at Y zero. * overlapBars ('False'), bool: overlap bars. No side by side bars of different data. * color (None), list: list containing the color of each Analysis. Return: -------- matplotlib.figure.Figure object. """ analyses = _remove_non_(Analysis, analyses, msgPrefix='plot_bars:') if len(analyses) > 0: fig = plot.bars(analyses, xLabel, yLabel, yLim, xLim, title, decimalSep, barWidth, errorStyle, forceZeroCentering, overlapBars, color) return fig else: return
[docs]def plot_spectrogram(*sigObjs, winType:str='hann', winSize:int=1024, overlap:float=0.5, xLabel:str=None, yLabel:str=None, yLim:list=None, xLim:list=None, title:str=None, decimalSep:str=','): """ Plots provided SignalObjs spectrogram. Parameters (default), (type): ----------------------------- * sigObjs (), (SignalObj): non-keyworded input arguments with N SignalObjs. * winType ('hann'), (str): window type for the time slicing. * winSize (1024), (int): window size in samples * overlap (0.5), (float): window overlap in % * xLabel ('Time [s]'), (str): x axis label. * yLabel ('Frequency [Hz]'), (str): y axis label. * yLim (), (list): inferior and superior frequency limits. >>> yLim = [20, 1000] * xLim (), (list): left and right time limits >>> xLim = [1, 3] * title (), (str): plot title * decimalSep (','), (str): may be dot or comma. >>> decimalSep = ',' # in Brazil Return: -------- List of matplotlib.figure.Figure objects for each item in curveData. """ realSigObjs = \ _remove_non_(SignalObj, sigObjs, msgPrefix='plot_spectrogram:') if len(realSigObjs) > 0: figs = plot.spectrogram(realSigObjs, winType, winSize, overlap, xLabel, yLabel, xLim, yLim, title, decimalSep) return figs else: return
[docs]def plot_waterfall(*sigObjs, step=2 ** 9, n=2 ** 13, fmin=None, fmax=None, pmin=None, pmax=None, tmax=None, xaxis='linear', time_tick=None, freq_tick=None, mag_tick=None, tick_fontsize=None, fpad=1, delta=60, dBref=2e-5, fill_value='pmin', fill_below=True, overhead=3, winAlpha=0, plots=['waterfall'], show=True, cmap='jet', alpha=[1, 1], saveFig=False, figRatio=[1, 1, 1], figsize=(950, 950), camera=[2, 1, 2]): """ This function was gently sent by Rinaldi Polese Petrolli. # TO DO Keyword Arguments: step {int} -- [description] (default: {10}) xLim {list} -- [description] (default: {None}) Pmin {int} -- [description] (default: {20}) Pmax {[type]} -- [description] (default: {None}) tmin {int} -- [description] (default: {0}) tmax {[type]} -- [description] (default: {None}) azim {int} -- [description] (default: {-72}) elev {int} -- [description] (default: {14}) cmap {str} -- [description] (default: {'jet'}) winPlot {bool} -- [description] (default: {False}) waterfallPlot {bool} -- [description] (default: {True}) fill {bool} -- [description] (default: {True}) lines {bool} -- [description] (default: {False}) alpha {int} -- [description] (default: {1}) figsize {tuple} -- [description] (default: {(20, 8)}) winAlpha {int} -- [description] (default: {0}) removeGridLines {bool} -- [description] (default: {False}) saveFig {bool} -- [description] (default: {False}) bar {bool} -- [description] (default: {False}) width {float} -- [description] (default: {0.70}) size {int} -- [description] (default: {3}) lcol {[type]} -- [description] (default: {None}) filtered {bool} -- [description] (default: {True}) Returns: [type] -- [description] """ realSigObjs = \ _remove_non_(SignalObj, sigObjs, msgPrefix='plot_waterfall:') if len(realSigObjs) > 0: figs = plot.waterfall(realSigObjs, step, n, fmin, fmax, pmin, pmax, tmax, xaxis, time_tick, freq_tick, mag_tick, tick_fontsize, fpad, delta, dBref, fill_value, fill_below, overhead, winAlpha, plots, show, cmap, alpha, saveFig, figRatio, figsize, camera) return figs else: return
def _remove_non_(dataType, dataSet, msgPrefix:str='_remove_non_SignalObjs:'): if isinstance(dataSet, (list, tuple)): newDataSet = [] for idx, item in enumerate(dataSet): if isinstance(item, dataType): newDataSet.append(item) elif isinstance(item, ImpulsiveResponse) and \ dataType.__name__ == 'SignalObj': newDataSet.append(item.systemSignal) else: print("{}: skipping object {} as it isn't a {}." .format(msgPrefix, idx+1, dataType.__name__)) if isinstance(dataSet, tuple): newDataSet = tuple(newDataSet) return newDataSet
[docs]def save(fileName: str = time.ctime(time.time()), *PyTTaObjs): """ Main save function for .hdf5 and .pytta files. The file format is chose by the extension applied to the fileName. If no extension is provided choose the default file format (.hdf5). For more information on saving PyTTa objects in .hdf5 format see pytta.functions._h5_save documentation. For more information on saving PyTTa objects in .pytta format see pytta.functions.pytta_save' documentation. (DEPRECATED) """ # default file format defaultFormat = '.hdf5' # Checking the chosen file format if fileName.split('.')[-1] == 'hdf5': _h5_save(fileName, *PyTTaObjs) elif fileName.split('.')[-1] == 'pytta': # DEPRECATED warn(DeprecationWarning("'.pytta' format is DEPRECATED and being " + "replaced by '.hdf5'.")) pytta_save(fileName, *PyTTaObjs) else: print("File extension must be '.hdf5'.\n" + "Applying the default extension.") fileName += defaultFormat save(fileName, *PyTTaObjs)
[docs]def load(fileName: str): """ Main save function for .pytta and .hdf5 files. """ if fileName.split('.')[-1] == 'hdf5': output = _h5_load(fileName) elif fileName.split('.')[-1] == 'pytta': warn(DeprecationWarning("'.pytta' format is DEPRECATED and being " + "replaced by '.hdf5'.")) output = pytta_load(fileName) else: ValueError('pytta.load only works with *.hdf5 or *.pytta files.') return output
def pytta_save(fileName: str = time.ctime(time.time()), *PyTTaObjs): """ Saves any number of PyTTaObj subclasses' objects to fileName.pytta file. Just calls .save() method of each class and packs them all into a major .pytta file along with a Meta.json file containing the fileName of each saved object. """ if fileName.split('.')[-1] == 'pytta': fileName = fileName.replace('.pytta', '') meta = {} with zf.ZipFile(fileName + '.pytta', 'w') as zdir: for idx, obj in enumerate(PyTTaObjs): sobj = obj.pytta_save('obj' + str(idx)) meta['obj' + str(idx)] = sobj zdir.write(sobj) os.remove(sobj) with open('Meta.json', 'w') as f: json.dump(meta, f, indent=4) zdir.write('Meta.json') os.remove('Meta.json') return fileName + '.pytta' def pytta_load(fileName: str): """ Loads .pytta files and parses it's types to the correct objects. """ if fileName.split('.')[-1] == 'pytta': with zf.ZipFile(fileName, 'r') as zdir: objects = zdir.namelist() for obj in objects: if obj.split('.')[-1] == 'json': meta = obj zdir.extractall() output = __parse_load(meta) else: raise ValueError("pytta_load function only works with *.pytta files") return output def __parse_load(className): name = className.split('.')[0] jsonFile = open(className, 'r') openJson = json.load(jsonFile) if name == 'SignalObj': openMat = sio.loadmat(openJson['timeSignalAddress']) out = SignalObj(openMat['timeSignal'], domain=openJson['lengthDomain'], samplingRate=openJson['samplingRate'], freqMin=openJson['freqLims'][0], freqMax=openJson['freqLims'][1], comment=openJson['comment']) out.channels = __parse_channels(openJson['channels'], out.channels) os.remove(openJson['timeSignalAddress']) elif name == 'ImpulsiveResponse': ir = pytta_load(openJson['SignalAddress']['ir']) out = ImpulsiveResponse(ir=ir, **openJson['methodInfo']) os.remove(openJson['SignalAddress']['ir']) elif name == 'RecMeasure': inch = list(np.arange(len(openJson['inChannels']))) out = RecMeasure(device=openJson['device'], inChannels=inch, lengthDomain='samples', fftDegree=openJson['fftDegree']) out.inChannels = __parse_channels(openJson['inChannels'], out.inChannels) elif name == 'PlayRecMeasure': inch = list(1 + np.arange(len(openJson['inChannels']))) excit = pytta_load(openJson['excitationAddress']) out = PlayRecMeasure(excitation=excit, device=openJson['device'], inChannels=inch) out.inChannels = __parse_channels(openJson['inChannels'], out.inChannels) os.remove(openJson['excitationAddress']) elif name == 'FRFMeasure': inch = list(1 + np.arange(len(openJson['inChannels']))) excit = pytta_load(openJson['excitationAddress']) out = FRFMeasure(excitation=excit, device=openJson['device'], inChannels=inch) out.inChannels = __parse_channels(openJson['inChannels'], out.inChannels) os.remove(openJson['excitationAddress']) elif name == 'Meta': out = [] for val in openJson.values(): out.append(pytta_load(val)) os.remove(val) os.remove(className) jsonFile.close() return out def __parse_channels(chDict, chList): ch = 1 for key in chDict.keys(): chList[ch].num = key chList[ch].unit = chDict[key]['unit'] chList[ch].name = chDict[key]['name'] chList[ch].CF = chDict[key]['calib'][0] chList[ch].calibCheck\ = chDict[key]['calib'][1] ch += 1 return chList def _h5_save(fileName: str, *PyTTaObjs): """ Open an hdf5 file, create groups for each PyTTa object, pass it to the own object and it saves itself inside the group. >>> pytta._h5_save(fileName, PyTTaObj_1, PyTTaObj_2, ..., PyTTaObj_n) Dictionaries can also be passed as a PyTTa object. An hdf5 group will be created for each dictionary and its PyTTa objects will be saved. To ensure the diciontary name will be saved, create the key 'dictName' inside it with its name in a string as the value. This function will take this key and use as variable name for the dict. Lists can also be passed as a PyTTa object. An hdf5 group will be created for each list and its PyTTa objects will be saved. To ensure the list name will be saved, append to the list a string containing its name. This function will take the first string found and use it as variable name for the list. """ # Checking if filename has .hdf5 extension if fileName.split('.')[-1] != 'hdf5': fileName += '.hdf5' with h5py.File(fileName, 'w') as f: # Save the version to the HDF5 file f.attrs['GENERATED_BY'] = 'PyTTa' f.attrs['LONG_DESCR'] = 'HDF5 file generated by the PyTTa toolbox' f.attrs['FILE_SYS_VERSION'] = 1 # Dict for counting equal names for correctly renaming totalPObjCount = 0 # Counter for total groups savedPObjCount = 0 # Counter for loaded objects for idx, pObj in enumerate(PyTTaObjs): packTotalPObjCount, packSavedPObjCount = \ __h5_pack(f, pObj, idx) totalPObjCount, savedPObjCount = \ totalPObjCount + packTotalPObjCount, \ savedPObjCount + packSavedPObjCount # Final message plural1 = 's' if savedPObjCount > 1 else '' plural2 = 's' if totalPObjCount > 1 else '' print('Saved inside the hdf5 file {} PyTTa object{}' .format(savedPObjCount, plural1) + ' of {} object{} provided.'.format(totalPObjCount, plural2)) return fileName def __h5_pack(rootH5Group, pObj, objDesc): """ __h5_pack packs a PyTTa object or dict into its respective HDF5 group. """ if isinstance(pObj, (SignalObj, ImpulsiveResponse, RecMeasure, PlayRecMeasure, FRFMeasure, Analysis)): # Creation name if isinstance(objDesc, str): creationName = objDesc else: creationName = pObj.creation_name # Check if creation_name was already used creationName = __h5_pack_count_and_rename(creationName, rootH5Group) # create obj's group objH5Group = rootH5Group.create_group(creationName) # save the obj inside its group pObj._h5_save(objH5Group) return (1, 1) elif isinstance(pObj, dict): # Creation name if 'dictName' in pObj: creationName = pObj.pop('dictName') elif isinstance(objDesc, str): creationName = objDesc else: creationName = 'noNameDict' creationName = __h5_pack_count_and_rename(creationName, rootH5Group) print("Saving the dict '{}'.".format(creationName)) # create obj's group objH5Group = rootH5Group.create_group(creationName) objH5Group.attrs['class'] = 'dict' # Saving each key of the dict inside the hdf5 group totalPObjCount = 0 savedPObjCount = 0 for key, pObjFromDict in pObj.items(): packTotalPObjCount, packSavedPObjCount = \ __h5_pack(objH5Group, pObjFromDict, key) totalPObjCount, savedPObjCount = \ totalPObjCount + packTotalPObjCount, \ savedPObjCount + packSavedPObjCount return (totalPObjCount, savedPObjCount) elif isinstance(pObj, list): # Creation name creationName = None for idx, item in enumerate(pObj): if isinstance(item, str): creationName = item pObj.pop(idx) continue if creationName is None: if isinstance(objDesc, str): creationName = objDesc else: creationName = 'noNameList' creationName = __h5_pack_count_and_rename(creationName, rootH5Group) print("Saving the list '{}'.".format(creationName)) # create obj's group objH5Group = rootH5Group.create_group(creationName) objH5Group.attrs['class'] = 'list' # Saving each item of the list inside the hdf5 group totalPObjCount = 0 savedPObjCount = 0 for idx, pObjFromList in enumerate(pObj): packTotalPObjCount, packSavedPObjCount = \ __h5_pack(objH5Group, pObjFromList, str(idx)) totalPObjCount, savedPObjCount = \ totalPObjCount + packTotalPObjCount, \ savedPObjCount + packSavedPObjCount return totalPObjCount, savedPObjCount else: print("Only PyTTa objects and dicts/lists with PyTTa objects " + "can be saved through this function. Skipping " + "object '" + str(objDesc) + "'.") return (1, 0) def __h5_pack_count_and_rename(creationName, h5Group): # Check if creation_name was already used objNameCount = 1 newCreationName = cp.copy(creationName) while newCreationName in h5Group: objNameCount += 1 newCreationName = \ creationName + '_' + str(objNameCount) creationName = newCreationName return creationName def _h5_load(fileName: str): """ Load an hdf5 file and recreate the PyTTa objects. """ # Checking if the file is an hdf5 file if fileName.split('.')[-1] != 'hdf5': raise ValueError("_h5_load function only works with *.hdf5 files") f = h5py.File(fileName, 'r') # Check if it is a PyTTa-like hdf5 file try: if 'GENERATED_BY' not in f.attrs.keys() or \ f.attrs['GENERATED_BY'] != "PyTTa": raise NotImplementedError except: # raise NotImplementedError("Only PyTTa-like hdf5 files can be loaded.") warn(DeprecationWarning("'GENERATED_BY' tag couldn't be found in " + "the .hdf5 file. Still trying to load " + "because of legacy PyTTa HDF5 files.")) loadedObjects = {} objCount = 0 # Counter for loaded objects totCount = 0 # Counter for total groups for PyTTaObjName, PyTTaObjH5Group in f.items(): totCount += 1 try: loadedObjects[PyTTaObjName] = __h5_unpack(PyTTaObjH5Group) objCount += 1 except NotImplementedError: print("Skipping hdf5 group named {} as it ".format(PyTTaObjName) + "isn't an PyTTa object group.") f.close() # Final message plural1 = 's' if objCount > 1 else '' plural2 = 's' if totCount > 1 else '' print('Imported {} PyTTa object-like group'.format(objCount) + plural1 + ' of {} group'.format(totCount) + plural2 + ' inside the hdf5 file.') return loadedObjects def __h5_unpack(objH5Group): """ Unpack an HDF5 group into its respective PyTTa object """ if objH5Group.attrs['class'] == 'SignalObj': # PyTTaObj attrs unpacking samplingRate = objH5Group.attrs['samplingRate'] freqMin = _h5.none_parser(objH5Group.attrs['freqMin']) freqMax = _h5.none_parser(objH5Group.attrs['freqMax']) lengthDomain = objH5Group.attrs['lengthDomain'] comment = objH5Group.attrs['comment'] # SignalObj attr unpacking channels = eval(objH5Group.attrs['channels']) # Added with an if for compatibility issues if 'signalType' in objH5Group.attrs: signalType = _h5.attr_parser(objH5Group.attrs['signalType']) else: signalType = 'power' # Creating and conforming SignalObj SigObj = SignalObj(signalArray=np.array(objH5Group['timeSignal']), domain='time', signalType=signalType, samplingRate=samplingRate, freqMin=freqMin, freqMax=freqMax, comment=comment) SigObj.channels = channels SigObj.lengthDomain = lengthDomain return SigObj elif objH5Group.attrs['class'] == 'ImpulsiveResponse': systemSignal = __h5_unpack(objH5Group['systemSignal']) method = objH5Group.attrs['method'] winType = objH5Group.attrs['winType'] winSize = objH5Group.attrs['winSize'] overlap = objH5Group.attrs['overlap'] IR = ImpulsiveResponse(method=method, winType=winType, winSize=winSize, overlap=overlap, ir=systemSignal) return IR elif objH5Group.attrs['class'] == 'RecMeasure': # PyTTaObj attrs unpacking samplingRate = objH5Group.attrs['samplingRate'] freqMin = _h5.none_parser(objH5Group.attrs['freqMin']) freqMax = _h5.none_parser(objH5Group.attrs['freqMax']) comment = objH5Group.attrs['comment'] lengthDomain = objH5Group.attrs['lengthDomain'] fftDegree = objH5Group.attrs['fftDegree'] timeLength = objH5Group.attrs['timeLength'] # Measurement attrs unpacking device = _h5.list_w_int_parser(objH5Group.attrs['device']) inChannels = eval(objH5Group.attrs['inChannels']) blocking = objH5Group.attrs['blocking'] # Recreating the object rObj = measurement(kind='rec', device=device, inChannels=inChannels, blocking=blocking, samplingRate=samplingRate, freqMin=freqMin, freqMax=freqMax, comment=comment, lengthDomain=lengthDomain, fftDegree=fftDegree, timeLength=timeLength) return rObj elif objH5Group.attrs['class'] == 'PlayRecMeasure': # PyTTaObj attrs unpacking samplingRate = objH5Group.attrs['samplingRate'] freqMin = _h5.none_parser(objH5Group.attrs['freqMin']) freqMax = _h5.none_parser(objH5Group.attrs['freqMax']) comment = objH5Group.attrs['comment'] lengthDomain = objH5Group.attrs['lengthDomain'] fftDegree = objH5Group.attrs['fftDegree'] timeLength =objH5Group.attrs['timeLength'] # Measurement attrs unpacking device = _h5.list_w_int_parser(objH5Group.attrs['device']) inChannels = eval(objH5Group.attrs['inChannels']) outChannels = eval(objH5Group.attrs['outChannels']) blocking = objH5Group.attrs['blocking'] # PlayRecMeasure attrs unpacking excitation = __h5_unpack(objH5Group['excitation']) outputAmplification = objH5Group.attrs['outputAmplification'] # Recreating the object prObj = measurement(kind='playrec', excitation=excitation, outputAmplification=outputAmplification, device=device, inChannels=inChannels, outChannels=outChannels, blocking=blocking, samplingRate=samplingRate, freqMin=freqMin, freqMax=freqMax, comment=comment) return prObj elif objH5Group.attrs['class'] == 'FRFMeasure': # PyTTaObj attrs unpacking samplingRate = objH5Group.attrs['samplingRate'] freqMin = _h5.none_parser(objH5Group.attrs['freqMin']) freqMax = _h5.none_parser(objH5Group.attrs['freqMax']) comment = objH5Group.attrs['comment'] lengthDomain = objH5Group.attrs['lengthDomain'] fftDegree = objH5Group.attrs['fftDegree'] timeLength = objH5Group.attrs['timeLength'] # Measurement attrs unpacking device = _h5.list_w_int_parser(objH5Group.attrs['device']) inChannels = eval(objH5Group.attrs['inChannels']) outChannels = eval(objH5Group.attrs['outChannels']) blocking = objH5Group.attrs['blocking'] # PlayRecMeasure attrs unpacking excitation = __h5_unpack(objH5Group['excitation']) outputAmplification = objH5Group.attrs['outputAmplification'] # FRFMeasure attrs unpacking method = _h5.none_parser(objH5Group.attrs['method']) winType = _h5.none_parser(objH5Group.attrs['winType']) winSize = _h5.none_parser(objH5Group.attrs['winSize']) overlap = _h5.none_parser(objH5Group.attrs['overlap']) # Recreating the object frfObj = measurement(kind='frf', method=method, winType=winType, winSize=winSize, overlap=overlap, excitation=excitation, outputAmplification=outputAmplification, device=device, inChannels=inChannels, outChannels=outChannels, blocking=blocking, samplingRate=samplingRate, freqMin=freqMin, freqMax=freqMax, comment=comment) return frfObj elif objH5Group.attrs['class'] == 'Analysis': # Analysis attrs unpacking anType = _h5.attr_parser(objH5Group.attrs['anType']) nthOct = _h5.attr_parser(objH5Group.attrs['nthOct']) minBand = _h5.attr_parser(objH5Group.attrs['minBand']) maxBand = _h5.attr_parser(objH5Group.attrs['maxBand']) comment = _h5.attr_parser(objH5Group.attrs['comment']) title = _h5.attr_parser(objH5Group.attrs['title']) dataLabel = _h5.attr_parser(objH5Group.attrs['dataLabel']) errorLabel = _h5.attr_parser(objH5Group.attrs['errorLabel']) xLabel = _h5.attr_parser(objH5Group.attrs['xLabel']) yLabel = _h5.attr_parser(objH5Group.attrs['yLabel']) # Analysis data unpacking data = np.array(objH5Group['data']) # If error in save moment was None no group was created for it if 'error' in objH5Group: error = np.array(objH5Group['error']) else: error = None # Recreating the object anObject = Analysis(anType=anType, nthOct=nthOct, minBand=minBand, maxBand=maxBand, data=data, dataLabel=dataLabel, error=error, errorLabel=errorLabel, comment=comment, xLabel=xLabel, yLabel=yLabel, title=title) return anObject elif objH5Group.attrs['class'] == 'dict': dictObj = {} for PyTTaObjName, PyTTaObjH5Group in objH5Group.items(): dictObj[PyTTaObjName] = __h5_unpack(PyTTaObjH5Group) return dictObj elif objH5Group.attrs['class'] == 'list': dictObj = {} for idx, PyTTaObjH5Group in objH5Group.items(): dictObj[int(idx)] = __h5_unpack(PyTTaObjH5Group) idxs = [int(item) for item in list(dictObj.keys())] maxIdx = max(idxs) listObj = [] for idx in range(maxIdx+1): listObj.append(dictObj[idx]) return listObj else: raise NotImplementedError