Skip to content

More cleanup of imports focused mostly on testing and utils around testing #3841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import subprocess
import os
from packaging import version
import importlib.util

import pytest

from spikeinterface.core.testing import check_recordings_equal
from spikeinterface import get_global_dataset_folder
from spikeinterface.extractors import *

Expand Down Expand Up @@ -40,11 +40,10 @@ def has_plexon2_dependencies():
return False

# Check for 'zugbruecke' using pip
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Check for 'zugbruecke' using pip

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm cool with committing this. Will do next.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that comment is ... false, right?

try:
import zugbruecke

zugbruecke_spec = importlib.util.find_spec("zugbruecke")
if zugbruecke_spec is not None:
return True
except ImportError:
else:
return False
else:
raise ValueError(f"Unsupported OS: {os_type}")
Expand Down
5 changes: 0 additions & 5 deletions src/spikeinterface/generation/tests/test_drifing_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
import pytest
import numpy as np
from pathlib import Path
import shutil

import probeinterface

from spikeinterface.generation import (
make_one_displacement_vector,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ComputePrincipalComponents(AnalyzerExtension):
If True, waveforms are pre-whitened
dtype : dtype, default: "float32"
Dtype of the pc scores
{}

Examples
--------
Expand Down Expand Up @@ -522,8 +523,6 @@ def _transform_waveforms(self, spikes, waveforms, pca_model, progress_bar):
# transform a waveforms buffer
# used by _run() and project_new()

from sklearn.exceptions import NotFittedError

mode = self.params["mode"]

# prepare buffer
Expand Down Expand Up @@ -682,6 +681,7 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte
return worker_ctx


ComputePrincipalComponents.__doc__.format(_shared_job_kwargs_doc)
register_result_extension(ComputePrincipalComponents)
compute_principal_components = ComputePrincipalComponents.function_factory()

Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import numpy as np
import warnings

from spikeinterface.core.job_tools import fix_job_kwargs

Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def _set_params(
**other_kwargs,
):

import pandas as pd
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import should throw an error at the init of anything that will need it downstream. Is this the case, here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's the case we can check for the spec and raise an error. I can add that in a new commit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am just wondering if this is a place where some things are initialized but I don't know template metrics.

Copy link
Member Author

@zm711 zm711 Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are NOT currently using the package here. I think because we use it later Sam/Alessio just import it everywhere. Tests are passing so we are not directly using it at init. But you raise a good point that we should probably check that it is around at init for users!


# TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory()
if include_multi_channel_metrics or (
metric_names is not None and any([m in get_multi_channel_template_metric_names() for m in metric_names])
Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/postprocessing/tests/test_correlograms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np

try:
import numba
import importlib.util

numba_spec = importlib.util.find_spec("numba")
if numba_spec is not None:
HAVE_NUMBA = True
except ModuleNotFoundError as err:
else:
HAVE_NUMBA = False

from spikeinterface import NumpySorting, generate_sorting
Expand Down
3 changes: 0 additions & 3 deletions src/spikeinterface/preprocessing/scale.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

import numpy as np

from spikeinterface.core import BaseRecording
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor


Expand Down
3 changes: 0 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_align_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
but check only for BaseRecording general methods.
"""

import pytest
import numpy as np

from spikeinterface.core import generate_snippets
from spikeinterface.preprocessing.align_snippets import AlignSnippets

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import pytest
from pathlib import Path
from spikeinterface.core import load, set_global_tmp_folder
from spikeinterface.core import load
from spikeinterface.core.testing import check_recordings_equal
from spikeinterface.core.generate import generate_recording
from spikeinterface.preprocessing import gaussian_filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import spikeinterface.preprocessing as spre
import spikeinterface.extractors as se
from spikeinterface.core import generate_recording
import spikeinterface.widgets as sw
import importlib.util

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
Expand All @@ -25,7 +24,7 @@


@pytest.mark.skipif(
importlib.util.find_spec("neurodsp") is not None or importlib.util.find_spec("spikeglx") or ON_GITHUB,
importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB,
reason="Only local. Requires ibl-neuropixel install",
)
@pytest.mark.parametrize("lagc", [False, 1, 300])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import pytest
from pathlib import Path
import numpy as np

from spikeinterface import set_global_tmp_folder
from spikeinterface.core import generate_recording

from spikeinterface.preprocessing import normalize_by_quantile, scale, center, zscore

import numpy as np


def test_normalize_by_quantile():
rec = generate_recording()
Expand Down
7 changes: 0 additions & 7 deletions src/spikeinterface/preprocessing/tests/test_rectify.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
import pytest
from pathlib import Path

from spikeinterface import set_global_tmp_folder
from spikeinterface.core import generate_recording

from spikeinterface.preprocessing import rectify

import numpy as np


def test_rectify():
rec = generate_recording()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import pytest
from pathlib import Path
import shutil

from spikeinterface import set_global_tmp_folder, NumpyRecording
from spikeinterface.core import generate_recording
import numpy as np

from spikeinterface import NumpyRecording
from spikeinterface.preprocessing import unsigned_to_signed

import numpy as np


def test_unsigned_to_signed():
rng = np.random.RandomState(0)
Expand Down
8 changes: 5 additions & 3 deletions src/spikeinterface/preprocessing/tests/test_whiten.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import importlib.util

import pytest
import numpy as np

from spikeinterface.core import generate_recording
from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.core.numpyextractors import NumpyRecording
from spikeinterface.preprocessing import whiten, scale, compute_whitening_matrix
from spikeinterface.preprocessing.whiten import compute_sklearn_covariance_matrix

try:
sklearn_spec = importlib.util.find_spec("sklearn")
if sklearn_spec is not None:
from sklearn import covariance as sklearn_covariance

HAS_SKLEARN = True
except ImportError:
else:
HAS_SKLEARN = False


Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/preprocessing/tests/test_zero_padding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
from pathlib import Path
import numpy as np

from spikeinterface import set_global_tmp_folder
from spikeinterface.core import generate_recording
from spikeinterface.core.numpyextractors import NumpyRecording

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from spikeinterface import generate_ground_truth_recording
from spikeinterface.core.core_tools import is_editable_mode
import spikeinterface.extractors as se
import spikeinterface.sorters as ss


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/external/tests/test_kilosort.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from spikeinterface.sorters import KilosortSorter

import os, getpass
import getpass

if getpass.getuser() == "samuel":
# kilosort_path = '/home/samuel/Documents/SpikeInterface/Kilosort1/'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from spikeinterface.sorters import Kilosort2Sorter

import os, getpass
import getpass

if getpass.getuser() == "samuel":
kilosort2_path = "/home/samuel/Documents/SpikeInterface/Kilosort2"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from spikeinterface.sorters import Kilosort3Sorter

import os, getpass
import getpass

if getpass.getuser() == "samuel":
kilosort3_path = "/home/samuel/Documents/SpikeInterface/Kilosort3"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest
import pytest
from pathlib import Path

from spikeinterface import load, generate_ground_truth_recording
from spikeinterface.sorters import Kilosort4Sorter, run_sorter
Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/sorters/external/waveclus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import shutil
import sys
import json
import importlib.util


from spikeinterface.sorters.basesorter import BaseSorter
Expand All @@ -14,15 +15,15 @@
from spikeinterface.core import write_to_h5_dataset_format
from spikeinterface.extractors import WaveClusSortingExtractor
from spikeinterface.core.channelslice import ChannelSliceRecording
from spikeinterface.preprocessing import ScaleRecording

PathType = Union[str, Path]

try:
h5py_spec = importlib.util.find_spec("h5py")
if h5py_spec is not None:
import h5py

HAVE_H5PY = True
except ImportError:
else:
HAVE_H5PY = False


Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/sorters/runsorter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
has_singularity,
has_spython,
has_docker_nvidia_installed,
get_nvidia_docker_dependecies,
)
from .container_tools import (
find_recording_folders,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import os
import pytest
from pathlib import Path
import shutil
import platform
from spikeinterface import generate_ground_truth_recording
from spikeinterface.sorters.utils import has_spython, has_docker_python, has_docker, has_singularity
from spikeinterface.sorters import run_sorter
import subprocess
import sys
import copy


def _monkeypatch_return_false():
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
has_singularity,
has_spython,
has_docker_nvidia_installed,
get_nvidia_docker_dependecies,
get_nvidia_docker_dependencies,
)
31 changes: 13 additions & 18 deletions src/spikeinterface/sorters/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

from pathlib import Path
import subprocess # TODO: decide best format for this
from subprocess import check_output, CalledProcessError
from typing import List, Union

import numpy as np
import importlib.util


class SpikeSortingError(RuntimeError):
Expand Down Expand Up @@ -66,13 +63,13 @@ def has_nvidia():
"""
Checks if the machine has nvidia capability.
"""

try:
cuda_spec = importlib.util.find_spec("cuda")
if cuda_spec is not None:
from cuda import cuda
except ModuleNotFoundError as err:
else:
raise Exception(
"This sorter requires cuda, but the package 'cuda-python' is not installed. You can install it with:\npip install cuda-python"
) from err
)

try:
(cu_result_init,) = cuda.cuInit(0)
Expand Down Expand Up @@ -118,14 +115,14 @@ def has_docker_nvidia_installed():
Whether at least one of the dependencies listed in
`get_nvidia_docker_dependencies()` is installed.
"""
all_dependencies = get_nvidia_docker_dependecies()
all_dependencies = get_nvidia_docker_dependencies()
has_dep = []
for dep in all_dependencies:
has_dep.append(_run_subprocess_silently(f"{dep} --version").returncode == 0)
return any(has_dep)


def get_nvidia_docker_dependecies():
def get_nvidia_docker_dependencies():
"""
See `has_docker_nvidia_installed()`
"""
Expand All @@ -137,18 +134,16 @@ def get_nvidia_docker_dependecies():


def has_docker_python():
try:
import docker

docker_spec = importlib.util.find_spec("docker")
if docker_spec is not None:
return True
except ImportError:
else:
return False


def has_spython():
try:
import spython

spython_spec = importlib.util.find_spec("spython")
if spython_spec is not None:
return True
except ImportError:
else:
return False
Loading