Skip to content

Commit

Permalink
Removed numba, made all pyfftw imports conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
bbfrederick committed Apr 25, 2024
1 parent 8b72960 commit 68ac03d
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 200 deletions.
4 changes: 1 addition & 3 deletions Dockerfile.safe
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ RUN mamba install -y python \
h5py \
"tensorflow>=2.4.0" \
pyqtgraph \
pyfftw \
pandas \
versioneer \
numba; sync && \
versioneer; sync && \
chmod -R a+rX /usr/local/miniconda; sync && \
chmod +x /usr/local/miniconda/bin/*; sync && \
conda-build purge-all; sync && \
Expand Down
58 changes: 17 additions & 41 deletions capcalc/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,27 @@
"""
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pyfftw

# from numba import jit
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
import pyfftw
except ImportError:
pyfftwpresent = False
else:
pyfftwpresent = True

from scipy import fftpack, ndimage, signal
from scipy.signal import savgol_filter

# import warnings

if pyfftwpresent:
fftpack = pyfftw.interfaces.scipy_fftpack
pyfftw.interfaces.cache.enable()

# fftpack = pyfftw.interfaces.scipy_fftpack
# pyfftw.interfaces.cache.enable()

# ---------------------------------------- Global constants -------------------------------------------
donotusenumba = True

# ----------------------------------------- Conditional imports ---------------------------------------
try:
Expand All @@ -49,23 +53,6 @@
except ImportError:
memprofilerexists = False


# ----------------------------------------- Conditional jit handling ----------------------------------
def conditionaljit():
def resdec(f):
global donotusenumba
if donotusenumba:
return f
return jit(f, nopython=False)

return resdec


def disablenumba():
global donotusenumba
donotusenumba = True


# --------------------------- Filtering functions -------------------------------------------------
# NB: No automatic padding for precalculated filters

Expand Down Expand Up @@ -177,7 +164,8 @@ def ssmooth(xsize, ysize, zsize, sigma, inputdata):


# - butterworth filters
@conditionaljit()


def dolpfiltfilt(Fs, upperpass, inputdata, order, padlen=20, cyclic=False, debug=False):
r"""Performs a bidirectional (zero phase) Butterworth lowpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -235,7 +223,6 @@ def dolpfiltfilt(Fs, upperpass, inputdata, order, padlen=20, cyclic=False, debug
).astype(np.float64)


@conditionaljit()
def dohpfiltfilt(Fs, lowerpass, inputdata, order, padlen=20, cyclic=False, debug=False):
r"""Performs a bidirectional (zero phase) Butterworth highpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -292,7 +279,6 @@ def dohpfiltfilt(Fs, lowerpass, inputdata, order, padlen=20, cyclic=False, debug
)


@conditionaljit()
def dobpfiltfilt(Fs, lowerpass, upperpass, inputdata, order, padlen=20, cyclic=False, debug=False):
r"""Performs a bidirectional (zero phase) Butterworth bandpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -419,7 +405,6 @@ def getlpfftfunc(Fs, upperpass, inputdata, debug=False):
return transferfunc


@conditionaljit()
def dolpfftfilt(Fs, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
r"""Performs an FFT brickwall lowpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -462,7 +447,6 @@ def dolpfftfilt(Fs, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dohpfftfilt(Fs, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
r"""Performs an FFT brickwall highpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -505,7 +489,6 @@ def dohpfftfilt(Fs, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dobpfftfilt(Fs, lowerpass, upperpass, inputdata, padlen=20, cyclic=False, debug=False):
r"""Performs an FFT brickwall bandpass filter on an input vector
and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -555,7 +538,8 @@ def dobpfftfilt(Fs, lowerpass, upperpass, inputdata, padlen=20, cyclic=False, de


# - fft trapezoidal filters
@conditionaljit()


def getlptrapfftfunc(Fs, upperpass, upperstop, inputdata, debug=False):
r"""Generates a trapezoidal lowpass transfer function.
Expand Down Expand Up @@ -608,7 +592,6 @@ def getlptrapfftfunc(Fs, upperpass, upperstop, inputdata, debug=False):
return transferfunc


@conditionaljit()
def getlptransfunc(Fs, inputdata, upperpass=None, upperstop=None, type="brickwall", debug=False):
if upperpass is None:
print("getlptransfunc: upperpass must be specified")
Expand Down Expand Up @@ -693,7 +676,6 @@ def gethptransfunc(Fs, inputdata, lowerstop=None, lowerpass=None, type="brickwal
return transferfunc


@conditionaljit()
def dolptransfuncfilt(
Fs,
inputdata,
Expand Down Expand Up @@ -757,7 +739,6 @@ def dolptransfuncfilt(
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dohptransfuncfilt(
Fs,
inputdata,
Expand Down Expand Up @@ -827,7 +808,6 @@ def dohptransfuncfilt(
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dobptransfuncfilt(
Fs,
inputdata,
Expand Down Expand Up @@ -908,7 +888,6 @@ def dobptransfuncfilt(
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dolptrapfftfilt(Fs, upperpass, upperstop, inputdata, padlen=20, cyclic=False, debug=False):
r"""Performs an FFT filter with a trapezoidal lowpass transfer
function on an input vector and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -955,7 +934,6 @@ def dolptrapfftfilt(Fs, upperpass, upperstop, inputdata, padlen=20, cyclic=False
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dohptrapfftfilt(Fs, lowerstop, lowerpass, inputdata, padlen=20, cyclic=False, debug=False):
r"""Performs an FFT filter with a trapezoidal highpass transfer
function on an input vector and returns the result. Ends are padded to reduce transients.
Expand Down Expand Up @@ -1002,7 +980,6 @@ def dohptrapfftfilt(Fs, lowerstop, lowerpass, inputdata, padlen=20, cyclic=False
return unpadvec(fftpack.ifft(inputdata_trans).real, padlen=padlen)


@conditionaljit()
def dobptrapfftfilt(
Fs,
lowerstop,
Expand Down Expand Up @@ -1293,7 +1270,6 @@ def csdfilter(obsdata, commondata, padlen=20, cyclic=False, debug=False):
return unpadvec(fftpack.ifft(obsdata_trans).real, padlen=padlen)


@conditionaljit()
def arb_pass(
Fs,
inputdata,
Expand Down
47 changes: 2 additions & 45 deletions capcalc/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,15 @@

import matplotlib.pyplot as plt
import numpy as np
import pyfftw
import scipy as sp
import scipy.special as sps

# from numba import jit
from scipy.signal import find_peaks, hilbert

import capcalc.util as ccalc_util

fftpack = pyfftw.interfaces.scipy_fftpack
pyfftw.interfaces.cache.enable()

# ---------------------------------------- Global constants -------------------------------------------
defaultbutterorder = 6
MAXLINES = 10000000
donotbeaggressive = True

# ----------------------------------------- Conditional imports ---------------------------------------
try:
Expand All @@ -48,31 +41,6 @@
except ImportError:
memprofilerexists = False

donotusenumba = True


def conditionaljit():
def resdec(f):
if donotusenumba:
return f
return jit(f, nopython=False)

return resdec


def conditionaljit2():
def resdec(f):
if donotusenumba or donotbeaggressive:
return f
return jit(f, nopython=False)

return resdec


def disablenumba():
global donotusenumba
donotusenumba = True


# --------------------------- Fitting functions -------------------------------------------------
def gaussresidualssk(p, y, x):
Expand Down Expand Up @@ -108,7 +76,6 @@ def gaussskresiduals(p, y, x):
return y - gausssk_eval(x, p)


@conditionaljit()
def gaussresiduals(p, y, x):
"""
Expand Down Expand Up @@ -174,7 +141,6 @@ def gausssk_eval(x, p):
return p[0] * sp.stats.norm.pdf(t) * sp.stats.norm.cdf(p[3] * t)


@conditionaljit()
def kaiserbessel_eval(x, p):
"""
Expand All @@ -201,7 +167,6 @@ def kaiserbessel_eval(x, p):
)


@conditionaljit()
def gauss_eval(x, p):
"""
Expand Down Expand Up @@ -254,7 +219,6 @@ def risetime_eval_loop(x, p):
return r


@conditionaljit()
def trapezoid_eval(x, toplength, p):
"""
Expand All @@ -277,7 +241,6 @@ def trapezoid_eval(x, toplength, p):
return p[1] * (np.exp(-(corrx - toplength) / p[3]))


@conditionaljit()
def risetime_eval(x, p):
"""
Expand Down Expand Up @@ -367,7 +330,8 @@ def locpeak(data, samplerate, lastpeaktime, winsizeinsecs=5.0, thresh=0.75, hyst


# generate the polynomial fit timecourse from the coefficients
@conditionaljit()


def trendgen(thexvals, thefitcoffs, demean):
"""
Expand Down Expand Up @@ -416,7 +380,6 @@ def detrend(inputdata, order=1, demean=False):
return inputdata - thefittc


@conditionaljit()
def findfirstabove(theyvals, thevalue):
"""
Expand Down Expand Up @@ -640,7 +603,6 @@ def territorydecomp(
return fitmap, thecoffs, theRs


@conditionaljit()
def refinepeak_quad(x, y, peakindex, stride=1):
# first make sure this actually is a peak
ismax = None
Expand Down Expand Up @@ -670,7 +632,6 @@ def refinepeak_quad(x, y, peakindex, stride=1):
return peakloc, peakval, peakwidth, ismax, badfit


@conditionaljit2()
def findmaxlag_gauss(
thexcorr_x,
thexcorr_y,
Expand Down Expand Up @@ -924,7 +885,6 @@ def findmaxlag_gauss(
return maxindex, maxlag, maxval, maxsigma, maskval, failreason, fitstart, fitend


@conditionaljit2()
def maxindex_noedge(thexcorr_x, thexcorr_y, bipolar=False):
"""
Expand Down Expand Up @@ -965,8 +925,6 @@ def maxindex_noedge(thexcorr_x, thexcorr_y, bipolar=False):
return maxindex, flipfac


# disabled conditionaljit on 11/8/16. This causes crashes on some machines (but not mine, strangely enough)
@conditionaljit2()
def findmaxlag_gauss_rev(
thexcorr_x,
thexcorr_y,
Expand Down Expand Up @@ -1266,7 +1224,6 @@ def findmaxlag_gauss_rev(
)


@conditionaljit2()
def findmaxlag_quad(
thexcorr_x,
thexcorr_y,
Expand Down
Loading

0 comments on commit 68ac03d

Please sign in to comment.