# Source code for pyrfu.mms.fft_bandpass
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Built-in imports
from typing import Union
# 3rd party imports
import numpy as np
import xarray as xr
from numpy.typing import NDArray
from xarray.core.dataarray import DataArray
# Local imports
from pyrfu.pyrf.calc_fs import calc_fs
from pyrfu.pyrf.ts_scalar import ts_scalar
from pyrfu.pyrf.ts_vec_xyz import ts_vec_xyz
# Module metadata
__author__ = "Louis Richard"
__email__ = "louisr@irfu.se"
__copyright__ = "Copyright 2020-2024"
__license__ = "MIT"
__version__ = "2.4.13"
__status__ = "Prototype"
# Type alias for NumPy arrays of single- or double-precision floats
NDArrayFloats = NDArray[Union[np.float32, np.float64]]
def fft_bandpass(inp: DataArray, f_min: float, f_max: float) -> DataArray:
    r"""Perform simple bandpass using FFT - returns fields with
    ``f_min`` < f < ``f_max``.

    The input time series is left untouched: the filtering is performed on
    a double precision copy of the data. If the number of samples is odd,
    the last sample is dropped so that the output has an even number of
    samples. NaNs are replaced by zeros for the FFT and restored at the same
    locations in the output.

    Parameters
    ----------
    inp : DataArray
        Time series to be bandpass filtered (scalar or vector).
    f_min : float
        Minimum frequency of filter, f < ``f_min`` are removed.
    f_max : float
        Maximum frequency of filter, f > ``f_max`` are removed.

    Returns
    -------
    DataArray
        Time series of the bandpass filtered data, with the same precision
        and attributes as the input.

    Raises
    ------
    TypeError
        * If input is not a xarray.DataArray.
        * If f_min is not a float.
        * If f_max is not a float.
    ValueError
        * If f_min is larger than (or equal to) f_max.
        * If input has more than two dimensions.
        * If input contains less than two samples.

    Notes
    -----
    Can be some spurious effects near boundary. Can take longer interval then
    use pyrfu.pyrf.time_clip to remove.

    Examples
    --------
    >>> from pyrfu import mms

    Define time interval

    >>> tint = ["2017-07-23T16:54:24.000", "2017-07-23T17:00:00.000"]

    Spacecraft index

    >>> mms_id = 1

    Load Electric Field

    >>> e_xyz = mms.get_data("e_gse_edp_brst_l2", tint, mms_id)

    Bandpass filter

    >>> e_xyz_bp = mms.fft_bandpass(e_xyz, 1e1, 1e2)

    """

    # Make sure input is a DataArray
    if not isinstance(inp, xr.DataArray):
        raise TypeError("Input must be a xarray.DataArray")

    # Check f_min and f_max
    if not isinstance(f_min, float):
        raise TypeError("f_min must be a float")

    if not isinstance(f_max, float):
        raise TypeError("f_max must be a float")

    # Check that f_min < f_max
    if f_min >= f_max:
        raise ValueError("f_min must be smaller than f_max")

    # Get time and data
    inp_time: NDArray[np.datetime64] = inp.time.data
    precision = inp.data.dtype

    if inp.data.ndim > 2:
        raise ValueError("Input must be a scalar or a vector")

    # Work on a float64 copy (astype copies) so that zeroing the NaNs below
    # never leaks into the caller's DataArray.
    inp_data: NDArray[np.float64] = inp.data.astype(np.float64)

    # If scalar, reshape to column vector
    if inp_data.ndim == 1:
        inp_data = inp_data[:, np.newaxis]

    # Make sure number of elements is an even number, if odd remove last
    # element to make an even number
    if len(inp_time) % 2:
        inp_time = inp_time[:-1]
        inp_data = inp_data[:-1, :]

    n_samples = len(inp_time)

    if n_samples < 2:
        raise ValueError("Input must contain at least two samples")

    # Set NaN values to zero so FFT works
    idx_nans = np.isnan(inp_data)
    inp_data[idx_nans] = 0.0

    # Frequency (absolute value) of each FFT bin, in the native FFT ordering.
    # fftfreq gives the exact bin frequencies (spacing f_sam / n_samples), so
    # the rejection mask is symmetric between positive and negative
    # frequencies and the filtered signal stays real.
    f_sam = calc_fs(inp)
    frequencies = np.abs(np.fft.fftfreq(n_samples, d=1.0 / f_sam))
    idx_reject = np.logical_or(frequencies < f_min, frequencies > f_max)

    # FFT along time for all components at once and remove frequencies
    inp_fft: NDArray[np.complex128] = np.fft.fft(inp_data, axis=0)
    inp_fft[idx_reject, :] = 0.0 + 0.0j
    out_data_64: NDArray[np.float64] = np.real(np.fft.ifft(inp_fft, axis=0))

    # Put back original NaNs and back to original shape and precision
    out_data_64[idx_nans] = np.nan
    out_data_64 = np.squeeze(out_data_64)
    out_data: NDArrayFloats = out_data_64.astype(precision)

    # Return data in the same format as input
    if out_data.ndim == 1:
        out: DataArray = ts_scalar(inp_time, out_data, attrs=inp.attrs)
    else:
        out = ts_vec_xyz(inp_time, out_data, attrs=inp.attrs)

    return out