Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions behavelet/morlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _morlet_conj_ft(omegas, omega0=5.0, gpu=False):
return ft_wavelet


def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, gpu=False):
def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, return_complex=False, gpu=False):
"""
Calculates a Morlet continuous wavelet transform
for a given signal across a range of frequencies
Expand Down Expand Up @@ -163,7 +163,8 @@ def _morlet_fft_convolution(X, freqs, scales, dtime, omega0=5.0, gpu=False):
convolved *= backend.sqrt(scale)

convolved = convolved[idx0:idx1] # remove zero padding
convolved = backend.abs(convolved) # use the norm of the complex values
if not return_complex:
convolved = backend.abs(convolved) # use the norm of the complex values

# scale power to account for disproportionally
# large wavelet response at low frequencies
Expand All @@ -183,7 +184,7 @@ def _morlet_fft_convolution_parallel(feed_dict):

def wavelet_transform(X, n_freqs, fsample, fmin, fmax,
prob=True, omega0=5.0, log_scale=True,
n_jobs=1, gpu=False):
return_complex=False, n_jobs=1, gpu=False):
"""
Applies a Morlet continuous wavelet transform to a data set
across a range of frequencies.
Expand Down Expand Up @@ -215,6 +216,8 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax,
Whether to sample the frequencies on a log scale.
omega0 : float (default = 5.0)
Dimensionless omega0 parameter for wavelet transform.
return_complex: bool (default = False)
Whether to return complex wavelet transform.
n_jobs : int (default = 1)
Number of jobs to use for performing the wavelet transform.
If -1, all CPUs are used. If 1 is given, no parallel computing is
Expand Down Expand Up @@ -284,6 +287,7 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax,
"scales": scales,
"dtime": dtime,
"omega0": omega0,
"return_complex": return_complex,
"gpu": gpu}
for feature in X.T]

Expand All @@ -299,7 +303,7 @@ def wavelet_transform(X, n_freqs, fsample, fmin, fmax,
# for idx, conv in enumerate(convolved):
# X_new[:, (n_freqs * idx):(n_freqs * (idx + 1))] = conv.T

power = X_new.sum(axis=1, keepdims=True)
power = np.abs(X_new).sum(axis=1, keepdims=True)

if prob:
X_new /= power
Expand Down