Source code for pyrfu.pyrf.wavelet

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Built-in imports
import logging
import os
from typing import Dict, Optional, Union

# 3rd party imports
import numba
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from scipy import fft
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset

# Local imports
from pyrfu.pyrf.calc_fs import calc_fs

# Module metadata.
__author__ = "Louis Richard"
__email__ = "louisr@irfu.se"
__copyright__ = "Copyright 2020-2024"
__license__ = "MIT"
__version__ = "2.4.13"
__status__ = "Prototype"

# Import-time side effect: redirect messages issued through the ``warnings``
# module to the logging system, and configure the root logger (timestamped
# records, INFO level and above).
logging.captureWarnings(True)
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s: %(message)s",
    datefmt="%d-%b-%y %H:%M:%S",
    level=logging.INFO,
)


@numba.jit(nopython=True, fastmath=True)  # type: ignore
def _ww(
    s_ww: NDArray[np.complex128],
    scales_mat: NDArray[np.float64],
    sigma: float,
    frequencies_mat: NDArray[np.float64],
    f_nyq: float,
) -> NDArray[np.complex128]:
    """Weight a spectrum with a Gaussian window in the frequency domain.

    The window is ``exp(-sigma**2 * (scale * frequency - f_nyq)**2 / 2)``, i.e.
    it peaks where ``scale * frequency`` equals ``f_nyq``.

    Parameters
    ----------
    s_ww : numpy.ndarray
        Complex spectrum to be weighted.
    scales_mat : numpy.ndarray
        Scales, broadcast against ``frequencies_mat``.
    sigma : float
        Width parameter of the Gaussian window.
    frequencies_mat : numpy.ndarray
        Frequencies, broadcast against ``scales_mat``.
    f_nyq : float
        Value of ``scale * frequency`` at which the window is centred.

    Returns
    -------
    numpy.ndarray
        Windowed spectrum.

    """
    # TODO : use nested for loop and math instead of numpy and test speed!!
    # Normalisation factor of the window (currently unity).
    unit_norm = np.sqrt(1.0)

    windowed = s_ww * np.exp(
        -sigma * sigma * ((scales_mat * frequencies_mat - f_nyq) ** 2) / 2,
    )
    return windowed * unit_norm


@numba.jit(nopython=True, parallel=True, fastmath=True)  # type: ignore
def _power_r(
    power: NDArray[np.complex128], new_freq_mat: NDArray[np.float64]
) -> NDArray[np.float64]:
    """Convert complex coefficients to a real-valued power.

    Parameters
    ----------
    power : numpy.ndarray
        Complex coefficients.
    new_freq_mat : numpy.ndarray
        Frequencies used to normalise the coefficients, element by element.

    Returns
    -------
    numpy.ndarray
        ``|2 * pi * conj(power) * power / new_freq_mat|``.

    """
    return np.absolute((2 * np.pi) * np.conj(power) * power / new_freq_mat)


@numba.jit(nopython=True, parallel=True, fastmath=True)  # type: ignore
def _power_c(
    power: NDArray[np.complex128], new_freq_mat: NDArray[np.float64]
) -> NDArray[np.complex128]:
    """Normalise complex coefficients while keeping their phase.

    Parameters
    ----------
    power : numpy.ndarray
        Complex coefficients.
    new_freq_mat : numpy.ndarray
        Frequencies used to normalise the coefficients, element by element.

    Returns
    -------
    numpy.ndarray
        ``sqrt(|2 * pi / new_freq_mat|) * power``.

    """
    return np.sqrt(np.absolute((2 * np.pi) / new_freq_mat)) * power


def wavelet(
    inp: DataArray,
    f_s: Optional[float] = None,
    f: Optional[list[float]] = None,
    n_freqs: Optional[int] = None,
    linear: Optional[Union[float, bool]] = None,
    wavelet_width: Optional[float] = None,
    cut_edge: Optional[bool] = True,
    return_power: Optional[bool] = True,
) -> Union[DataArray, Dataset]:
    """Computes wavelet spectrogram based on fast FFT algorithm.

    Parameters
    ----------
    inp : DataArray
        Input quantity. Must be 1D (time) or 2D (time, components). If the
        number of samples is odd, the last sample is dropped.
    f_s : float, Optional
        Sampling frequency of the input time series. Default is computed
        from the input with ``calc_fs``.
    f : list, Optional
        Vector [f_min f_max], calculate spectra between frequencies f_min
        and f_max. Only used with logarithmic spacing. Default is
        [f_nyq / 100, 100 * f_nyq], where f_nyq = f_s / 2.
    n_freqs : int, Optional
        Number of frequency bins (logarithmic spacing only). Default is 200.
    linear : float or bool, Optional
        If a float, use linearly spaced frequencies with spacing ``linear``.
        If True, use linearly spaced frequencies with a spacing of 100 (a
        warning is logged). If False or None, use logarithmically spaced
        frequencies. Default is None.
    wavelet_width : float, Optional
        Width of the Morlet wavelet. Default 5.36.
    cut_edge : bool, Optional
        Set to True to set points affected by edge effects to NaN, False to
        keep edge affect points. Default True
    return_power : bool, Optional
        Set to True to return the power, False for complex wavelet
        transform. Default True.

    Returns
    -------
    DataArray or Dataset
        Wavelet transform of the input with dimensions (time, frequency).
        A DataArray for 1D input; a Dataset with one variable per entry of
        ``inp.comp`` for 2D input.

    Raises
    ------
    TypeError
        If input is not a DataArray, or if linear keyword argument is not
        bool or float.
    ValueError
        If input is not 1D or 2D.

    """

    # Check input
    if not isinstance(inp, xr.DataArray):
        raise TypeError("Input must be a DataArray")

    # Validate the dimensionality up front, before any heavy computation.
    if inp.ndim not in (1, 2):
        raise ValueError("Input data must be 1D or 2D")

    is_scalar: bool = inp.ndim == 1

    if f_s is None:
        f_s = calc_fs(inp)

    if n_freqs is None:
        n_freqs = 200

    if wavelet_width is None:
        wavelet_width = 5.36

    # Frequency spacing. bool is handled explicitly so that ``linear=False``
    # selects the logarithmic spacing instead of being rejected.
    if linear is None:
        delta_f: float = 100.0
        linear_df: bool = False
    elif isinstance(linear, bool):
        delta_f = 100.0
        linear_df = linear
        if linear_df:
            logging.warning("Unknown input for linear delta_f set to 100")
    elif isinstance(linear, float):
        delta_f = linear
        linear_df = True
    else:
        raise TypeError("linear keyword argument must be bool or float")

    # Nyquist frequency and wavelet width
    f_nyq: float = f_s / 2
    sigma: float = wavelet_width / f_nyq

    # Frequency range
    # NOTE(review): the default upper bound evaluates to 100 * f_nyq, i.e.
    # above the Nyquist frequency -- confirm that this is intended.
    if f is None:
        f_min: float = f_nyq / 10**2
        f_max: float = f_nyq / 10**-2
    else:
        f_min, f_max = sorted(f)

    # Scales range
    if linear_df:
        scale_number: int = int(np.floor(f_nyq / delta_f))
        scale_min: float = delta_f
        scale_max: float = scale_number * delta_f
        scales: NDArray[np.float64] = f_nyq / (
            np.linspace(scale_max, scale_min, scale_number, dtype=np.float64)
        )
    else:
        scale_number = n_freqs
        scale_min = np.log10(f_nyq / f_max)
        scale_max = np.log10(f_nyq / f_min)
        scales = np.logspace(scale_min, scale_max, scale_number, dtype=np.float64)

    # Unpack time and data.
    # Remove the last sample if the total number of samples is odd.
    if len(inp.time.data) % 2:
        time: NDArray[np.datetime64] = inp.time.data[:-1]
        data: NDArray[np.float64] = inp.data[:-1, ...].astype(np.float64)
    else:
        time = inp.time.data
        data = inp.data.astype(np.float64)

    # Check for NaNs
    scales[np.isnan(scales)] = 0.0

    # Find the frequencies for an FFT of all data
    freq: NDArray[np.float64] = (
        f_nyq * np.arange(1, 1 + len(data) / 2) / (len(data) / 2)
    )

    # The frequencies corresponding to FFT
    freqs_fft: NDArray[np.float64] = np.hstack([0, freq, -np.flip(freq[:-1])])
    _, freqs_fft_mat = np.meshgrid(scales, freqs_fft, sparse=True)

    # Get the correct frequencies for the wavelet transform
    freqs_cwt: NDArray[np.float64] = f_nyq / scales
    freqs_cwt_mat, _ = np.meshgrid(freqs_cwt, freqs_fft, sparse=True)

    # Quantities that do not depend on the data column are computed once.
    # One row of wavelet frequencies per time sample.
    freqs_cwt_tiled: NDArray[np.float64] = np.tile(freqs_cwt_mat, (len(data), 1))
    # Number of samples to blank at each edge, for every scale.
    censure = np.floor(2 * scales).astype(int)
    n_workers = os.cpu_count()
    n_samples: int = len(data)

    # if scalar add virtual axis
    if is_scalar:
        data = data[:, np.newaxis]

    out_dict: Dict[str, object] = {}

    # go through all the data columns
    for i in range(data.shape[1]):
        data_col: NDArray[np.float64] = data[:, i]

        # Wavelet transform of the data
        # Forward FFT
        s_w: NDArray[np.complex128] = fft.fft(data_col, workers=n_workers)
        scales_mat, s_w_mat = np.meshgrid(scales, s_w, sparse=True)

        # Calculate the FFT of the wavelet transform
        w_w: NDArray[np.complex128] = _ww(
            s_w_mat, scales_mat, sigma, freqs_fft_mat, f_nyq
        )

        # Backward FFT
        power: NDArray[np.complex128] = fft.ifft(w_w, axis=0, workers=n_workers)

        # Calculate the power spectrum
        if return_power:
            power2 = _power_r(power, freqs_cwt_tiled)
        else:
            power2 = _power_c(power, freqs_cwt_tiled)

        # Blank the samples affected by edge effects at both ends
        if cut_edge:
            for j in range(scale_number):
                power2[: censure[j], j] = np.nan
                power2[n_samples - censure[j] : n_samples, j] = np.nan

        if not is_scalar:
            # One variable per component, frequencies flipped to match the
            # "frequency" coordinate of the output.
            out_dict[str(inp.comp.data[i])] = (
                ["time", "frequency"],
                np.fliplr(power2),
            )

    if is_scalar:
        out: Union[DataArray, Dataset] = xr.DataArray(
            np.fliplr(power2),
            coords=[time, np.flip(freqs_cwt)],
            dims=["time", "frequency"],
        )
    else:
        out = xr.Dataset(
            out_dict,
            coords={"time": time, "frequency": np.flip(freqs_cwt)},
        )

    return out