Repository: SpikeInterface/spikeextractors Branch: master Commit: d24335cc2fa6 Files: 118 Total size: 697.7 KB Directory structure: gitextract_c4b26tzl/ ├── .github/ │ └── workflows/ │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── environment-dev.yml ├── full_requirements.txt ├── requirements-dev.txt ├── requirements.txt ├── setup.py ├── spikeextractors/ │ ├── __init__.py │ ├── baseextractor.py │ ├── cacheextractors.py │ ├── example_datasets/ │ │ ├── __init__.py │ │ ├── synthesize_random_firings.py │ │ ├── synthesize_random_waveforms.py │ │ ├── synthesize_single_waveform.py │ │ ├── synthesize_timeseries.py │ │ └── toy_example.py │ ├── exceptions.py │ ├── extraction_tools.py │ ├── extractorlist.py │ ├── extractors/ │ │ ├── __init__.py │ │ ├── alfsortingextractor/ │ │ │ ├── __init__.py │ │ │ └── alfsortingextractor.py │ │ ├── axonaunitrecordingextractor/ │ │ │ ├── __init__.py │ │ │ └── axonaunitrecordingextractor.py │ │ ├── bindatrecordingextractor/ │ │ │ ├── __init__.py │ │ │ └── bindatrecordingextractor.py │ │ ├── biocamrecordingextractor/ │ │ │ ├── __init__.py │ │ │ └── biocamrecordingextractor.py │ │ ├── cedextractors/ │ │ │ ├── __init__.py │ │ │ ├── cedrecordingextractor.py │ │ │ └── utils.py │ │ ├── cellexplorersortingextractor/ │ │ │ ├── __init__.py │ │ │ └── cellexplorersortingextractor.py │ │ ├── combinatosortingextractor/ │ │ │ ├── __init__.py │ │ │ └── combinatosortingextractor.py │ │ ├── exdirextractors/ │ │ │ ├── __init__.py │ │ │ └── exdirextractors.py │ │ ├── hdsortsortingextractor/ │ │ │ ├── __init__.py │ │ │ └── hdsortsortingextractor.py │ │ ├── hs2sortingextractor/ │ │ │ ├── __init__.py │ │ │ └── hs2sortingextractor.py │ │ ├── intanrecordingextractor/ │ │ │ ├── __init__.py │ │ │ └── intanrecordingextractor.py │ │ ├── jrcsortingextractor/ │ │ │ ├── __init__.py │ │ │ └── jrcsortingextractor.py │ │ ├── kilosortextractors/ │ │ │ ├── __init__.py │ │ │ └── kilosortextractors.py │ │ ├── 
klustaextractors/ │ │ │ ├── __init__.py │ │ │ └── klustaextractors.py │ │ ├── matsortingextractor/ │ │ │ ├── __init__.py │ │ │ └── matsortingextractor.py │ │ ├── maxwellextractors/ │ │ │ ├── __init__.py │ │ │ └── maxwellextractors.py │ │ ├── mcsh5recordingextractor/ │ │ │ ├── __init__.py │ │ │ └── mcsh5recordingextractor.py │ │ ├── mdaextractors/ │ │ │ ├── __init__.py │ │ │ ├── mdaextractors.py │ │ │ └── mdaio.py │ │ ├── mearecextractors/ │ │ │ ├── __init__.py │ │ │ └── mearecextractors.py │ │ ├── neoextractors/ │ │ │ ├── __init__.py │ │ │ ├── axonaextractor.py │ │ │ ├── blackrockextractor.py │ │ │ ├── mcsrawrecordingextractor.py │ │ │ ├── neobaseextractor.py │ │ │ ├── neuralynxextractor.py │ │ │ ├── plexonextractor.py │ │ │ └── spikegadgetsextractor.py │ │ ├── neuropixelsdatrecordingextractor/ │ │ │ ├── __init__.py │ │ │ ├── channel_positions_neuropixels.txt │ │ │ └── neuropixelsdatrecordingextractor.py │ │ ├── neuroscopeextractors/ │ │ │ ├── __init__.py │ │ │ └── neuroscopeextractors.py │ │ ├── nixioextractors/ │ │ │ ├── __init__.py │ │ │ └── nixioextractors.py │ │ ├── npzsortingextractor/ │ │ │ ├── __init__.py │ │ │ └── npzsortingextractor.py │ │ ├── numpyextractors/ │ │ │ ├── __init__.py │ │ │ └── numpyextractors.py │ │ ├── nwbextractors/ │ │ │ ├── __init__.py │ │ │ └── nwbextractors.py │ │ ├── openephysextractors/ │ │ │ ├── __init__.py │ │ │ └── openephysextractors.py │ │ ├── phyextractors/ │ │ │ ├── __init__.py │ │ │ └── phyextractors.py │ │ ├── shybridextractors/ │ │ │ ├── __init__.py │ │ │ └── shybridextractors.py │ │ ├── spikeglxrecordingextractor/ │ │ │ ├── __init__.py │ │ │ ├── readSGLX.py │ │ │ └── spikeglxrecordingextractor.py │ │ ├── spykingcircusextractors/ │ │ │ ├── __init__.py │ │ │ └── spykingcircusextractors.py │ │ ├── tridescloussortingextractor/ │ │ │ ├── __init__.py │ │ │ └── tridescloussortingextractor.py │ │ ├── waveclussortingextractor/ │ │ │ ├── __init__.py │ │ │ └── waveclussortingextractor.py │ │ └── yassextractors/ │ │ ├── __init__.py │ 
│ └── yassextractors.py │ ├── multirecordingchannelextractor.py │ ├── multirecordingtimeextractor.py │ ├── multisortingextractor.py │ ├── recordingextractor.py │ ├── save_tools.py │ ├── sortingextractor.py │ ├── subrecordingextractor.py │ ├── subsortingextractor.py │ ├── testing.py │ └── version.py └── tests/ ├── __init__.py ├── probe_test.prb ├── test_extractors.py ├── test_gin_repo.py ├── test_numpy_extractors.py └── test_tools.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/python-package.yml ================================================ name: Python Package using Conda on: push: branches: - master pull_request: branches: [master] types: [synchronize, opened, reopened, ready_for_review] jobs: build-and-test: name: Test on (${{ matrix.os }}) runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: ["ubuntu-latest", "macos-latest", "windows-latest"] steps: - uses: actions/checkout@v2 - uses: s-weigand/setup-conda@v1 with: python-version: 3.8 - name: Which python run: | conda --version which python - name: Install dependencies run: | conda install -c conda-forge datalad conda install -c conda-forge ruamel.yaml conda install flake8 conda install pytest pip install -r requirements-dev.txt pip install -r requirements.txt pip install h5py==2.10 pip install -e .[full] # needed for correct operation of git/git-annex/DataLad git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest and build coverage report run: | pytest ================================================ FILE: .github/workflows/python-publish.yml ================================================ # This workflow will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries name: Test and Upload Python Package on: push: tags: - '*' jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - uses: s-weigand/setup-conda@v1 with: python-version: 3.8 - name: Which python run: | conda --version which python - name: Install dependencies run: | conda install -c conda-forge datalad conda install -c conda-forge ruamel.yaml conda install flake8 conda install pytest pip install setuptools wheel twine pip install -r requirements-dev.txt pip install -r requirements.txt pip install h5py==2.10 pip install -e .[full] # needed for correct operation of git/git-annex/DataLad git config --global user.email "CI@example.com" git config --global user.name "CI Almighty" - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest and build coverage report run: | pytest - name: Publish on PyPI env: TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} run: | python setup.py sdist bdist_wheel twine upload dist/* ================================================ FILE: .gitignore ================================================ .eggs *.egg-info .ipynb_checkpoints __pycache__ sample_*_dataset ephy_testing_data/ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 SpikeInterface Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: MANIFEST.in ================================================ include spikeextractors/extractors/neuropixelsdatrecordingextractor/channel_positions_neuropixels.txt ================================================ FILE: README.md ================================================ # SpikeExtractors (LEGACY) The `spikeextractors` package has now been integrated into [spikeinterface](https://github.com/SpikeInterface/spikeinterface). This package will be maintained for a while for bug fixes only, then it will be deprecated. New features and improvements will only be implemented for the `spikeinterface` package. ================================================ FILE: environment-dev.yml ================================================ name: test dependencies: - python=3.8 - pip - pip: - numpy==1.22.0 - tqdm - lxml - h5py - shybrid - pynwb - nixio - pyintan - pyopenephys - neo - MEArec - hdf5storage - exdir - hdbscan - tridesclous - parametrized ================================================ FILE: full_requirements.txt ================================================ h5py #>=3.2.1 scipy>=1.6.3 pyintan>=0.3.0 pyopenephys>=1.1.4 neo>=0.9.0 MEArec<1.8 pynwb>=1.4 lxml>=4.6.3 nixio==1.5.0 shybrid>=0.4.2 pyyaml>=5.4.1 mtscomp>=1.0.1 exdir==0.4.1 hdf5storage sonpy;python_version>'3.7' ================================================ FILE: requirements-dev.txt ================================================ datalad parameterized neo==0.10 ================================================ FILE: requirements.txt ================================================ numpy==1.22.0 tqdm packaging ================================================ FILE: setup.py ================================================ import setuptools d = {} exec(open("spikeextractors/version.py").read(), None, d) version = d['version'] pkg_name = "spikeextractors" long_description = open("README.md").read() with open("full_requirements.txt", mode='r') as f: 
full_requires = f.read().split('\n') full_requires = [e for e in full_requires if len(e) > 0] extras_require = {"full": full_requires} setuptools.setup( name=pkg_name, version=version, author="Alessio Buccino, Cole Hurwitz, Samuel Garcia, Jeremy Magland, Matthias Hennig", author_email="alessio.buccino@gmail.com", description="Python module for extracting recorded and spike sorted extracellular data from different file types and formats", url="https://github.com/SpikeInterface/spikeextractors", long_description=long_description, long_description_content_type="text/markdown", packages=setuptools.find_packages(), package_data={}, include_package_data=True, install_requires=[ 'numpy', 'tqdm', 'joblib' ], extras_require=extras_require, classifiers=( "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ) ) ================================================ FILE: spikeextractors/__init__.py ================================================ from .recordingextractor import RecordingExtractor from .sortingextractor import SortingExtractor from .cacheextractors import CacheRecordingExtractor, CacheSortingExtractor from .subsortingextractor import SubSortingExtractor from .subrecordingextractor import SubRecordingExtractor from .multirecordingchannelextractor import concatenate_recordings_by_channel, MultiRecordingChannelExtractor from .multirecordingtimeextractor import concatenate_recordings_by_time, MultiRecordingTimeExtractor from .multisortingextractor import concatenate_sortings, MultiSortingExtractor from .extractorlist import * from . 
import example_datasets from .extraction_tools import load_probe_file, save_to_probe_file, read_binary, write_to_binary_dat_format,\ write_to_h5_dataset_format, get_sub_extractors_by_property, load_extractor_from_json, load_extractor_from_dict, \ load_extractor_from_pickle from .save_tools import save_si_object from .version import version as __version__ ================================================ FILE: spikeextractors/baseextractor.py ================================================ import json from pathlib import Path import importlib import numpy as np import datetime from copy import deepcopy import tempfile import pickle import shutil from .exceptions import NotDumpableExtractorError class BaseExtractor: # To be specified in concrete sub-classes # The default filename (extension to be added by corresponding method) # to be used if no file path is provided _default_filename = None def __init__(self): self._kwargs = {} self._tmp_folder = None self._key_properties = {} self._properties = {} self._annotations = {} self._memmap_files = [] self._features = {} self._epochs = {} self._times = None self.is_dumpable = True self.id = np.random.randint(low=0, high=9223372036854775807, dtype='int64') def __del__(self): # close memmap files (for Windows) for memmap_obj in self._memmap_files: self.del_memmap_file(memmap_obj) if self._tmp_folder is not None and len(self._memmap_files) > 0: try: shutil.rmtree(self._tmp_folder) except Exception as e: print('Impossible to delete temp file:', self._tmp_folder, 'Error', e) def del_memmap_file(self, memmap_file): """ Safely deletes instantiated memmap file. 
Parameters ---------- memmap_file: str or Path The memmap file to delete """ if isinstance(memmap_file, np.memmap): memmap_file = memmap_file.filename else: memmap_file = Path(memmap_file) existing_memmap_files = [Path(memmap.filename) for memmap in self._memmap_files] if memmap_file in existing_memmap_files: try: memmap_idx = existing_memmap_files.index(memmap_file) memmap_obj = self._memmap_files[memmap_idx] if not memmap_obj._mmap.closed: memmap_obj._mmap.close() del memmap_obj memmap_file.unlink() del self._memmap_files[memmap_idx] except Exception as e: raise Exception(f"Error in deleting {memmap_file.name}: Error {e}") def make_serialized_dict(self, relative_to=None): """ Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an extractor with spikeextractors.load_extractor_from_dict(dump_dict) Parameters ---------- relative_to: str, Path, or None If not None, file_paths are serialized relative to this path Returns ------- dump_dict: dict Serialized dictionary """ class_name = str(type(self)).replace("", '') module = class_name.split('.')[0] imported_module = importlib.import_module(module) try: version = imported_module.__version__ except AttributeError: version = 'unknown' if self.is_dumpable: dump_dict = {'class': class_name, 'module': module, 'kwargs': self._kwargs, 'key_properties': self._key_properties, 'annotations': self._annotations, 'version': version, 'dumpable': True} else: dump_dict = {'class': class_name, 'module': module, 'kwargs': {}, 'key_properties': self._key_properties, 'annotations': self._annotations, 'version': version, 'dumpable': False} if relative_to is not None: relative_to = Path(relative_to).absolute() assert relative_to.is_dir(), "'relative_to' must be an existing directory" dump_dict = _make_paths_relative(dump_dict, relative_to) return dump_dict def dump_to_dict(self, relative_to=None): """ Dumps recording to a dictionary. 
The dictionary be used to re-initialize an extractor with spikeextractors.load_extractor_from_dict(dump_dict) Parameters ---------- relative_to: str, Path, or None If not None, file_paths are serialized relative to this path Returns ------- dump_dict: dict Serialized dictionary """ return self.make_serialized_dict(relative_to) def _get_file_path(self, file_path, extensions): """ Helper to be used by various dump_to_file utilities. Returns default file_path (if not specified), assures that target directory exists, adds correct file extension if none, and assures that provided file extension is one of the allowed. Parameters ---------- file_path: str or None extensions: list or tuple First provided is used as an extension for the default file_path. All are tested against Returns ------- Path Path object with file path to the file Raises ------ NotDumpableExtractorError """ ext = extensions[0] if self.check_if_dumpable(): if file_path is None: file_path = self._default_filename + ext file_path = Path(file_path) file_path.parent.mkdir(parents=True, exist_ok=True) folder_path = file_path.parent if Path(file_path).suffix == '': file_path = folder_path / (str(file_path) + ext) assert file_path.suffix in extensions, \ "'file_path' should have one of the following extensions:" \ " %s" % (', '.join(extensions)) return file_path else: raise NotDumpableExtractorError( f"The extractor is not dumpable to {ext}") def dump_to_json(self, file_path=None, relative_to=None): """ Dumps recording extractor to json file. 
The extractor can be re-loaded with spikeextractors.load_extractor_from_json(json_file) Parameters ---------- file_path: str Path of the json file relative_to: str, Path, or None If not None, file_paths are serialized relative to this path """ dump_dict = self.make_serialized_dict(relative_to) self._get_file_path(file_path, ['.json'])\ .write_text( json.dumps(_check_json(dump_dict), indent=4), encoding='utf8' ) def dump_to_pickle(self, file_path=None, include_properties=True, include_features=True, relative_to=None): """ Dumps recording extractor to a pickle file. The extractor can be re-loaded with spikeextractors.load_extractor_from_json(json_file) Parameters ---------- file_path: str Path of the json file include_properties: bool If True, all properties are dumped include_features: bool If True, all features are dumped relative_to: str, Path, or None If not None, file_paths are serialized relative to this path """ file_path = self._get_file_path(file_path, ['.pkl', '.pickle']) # Dump all dump_dict = {'serialized_dict': self.make_serialized_dict(relative_to)} if include_properties: if len(self._properties.keys()) > 0: dump_dict['properties'] = self._properties if include_features: if len(self._features.keys()) > 0: dump_dict['features'] = self._features # include times dump_dict["times"] = self._times file_path.write_bytes(pickle.dumps(dump_dict)) def get_tmp_folder(self): """ Returns temporary folder associated to the extractor Returns ------- temp_folder: Path The temporary folder """ if self._tmp_folder is None: self._tmp_folder = Path(tempfile.mkdtemp()) return self._tmp_folder def set_tmp_folder(self, folder): """ Sets temporary folder of the extractor Parameters ---------- folder: str or Path The temporary folder """ self._tmp_folder = Path(folder) def allocate_array(self, memmap, shape=None, dtype=None, name=None, array=None): """ Allocates a memory or memmap array Parameters ---------- memmap: bool If True, a memmap array is created in the sorting 
temporary folder shape: tuple Shape of the array. If None array must be given dtype: dtype Dtype of the array. If None array must be given name: str or None Name (root) of the file (if memmap is True). If None, a random name is generated array: np.array If array is given, shape and dtype are initialized based on the array. If memmap is True, the array is then deleted to clear memory Returns ------- arr: np.array or np.memmap The allocated memory or memmap array """ if memmap: tmp_folder = self.get_tmp_folder() if array is not None: shape = array.shape dtype = array.dtype else: assert shape is not None and dtype is not None, "Pass 'shape' and 'dtype' arguments" if name is None: tmp_file = tempfile.NamedTemporaryFile(suffix=".raw", dir=tmp_folder).name else: if Path(name).suffix == '': tmp_file = tmp_folder / (name + '.raw') else: tmp_file = tmp_folder / name raw_tmp_file = r'{}'.format(str(tmp_file)) # make sure any open memmap files with same path are deleted self.del_memmap_file(raw_tmp_file) arr = np.memmap(raw_tmp_file, mode='w+', shape=shape, dtype=dtype) if array is not None: arr[:] = array del array else: arr[:] = 0 self._memmap_files.append(arr) else: if array is not None: arr = array else: arr = np.zeros(shape, dtype=dtype) return arr def annotate(self, annotation_key, value, overwrite=False): """This function adds an entry to the annotations dictionary. Parameters ---------- annotation_key: str An annotation stored by the Extractor value: The data associated with the given property name. Could be many formats as specified by the user overwrite: bool If True and the annotation already exists, it is overwritten """ if annotation_key not in self._annotations.keys(): self._annotations[annotation_key] = value else: if overwrite: self._annotations[annotation_key] = value else: print(f"{annotation_key} is already an annotation key. 
Use 'overwrite=True' to overwrite it") def get_annotation(self, annotation_name): """This function returns the data stored under the annotation name. Parameters ---------- annotation_name: str A property stored by the Extractor Returns ---------- annotation_data The data associated with the given property name. Could be many formats as specified by the user """ if annotation_name not in self._annotations.keys(): print(f"{annotation_name} is not an annotation") return None else: return deepcopy(self._annotations[annotation_name]) def get_annotation_keys(self): """This function returns a list of stored annotation keys Returns ---------- property_names: list List of stored annotation keys """ return list(self._annotations.keys()) def copy_annotations(self, extractor): """Copy object properties from another extractor to the current extractor. Parameters ---------- extractor: Extractor The extractor from which the annotations will be copied """ self._annotations = deepcopy(extractor._annotations) def add_epoch(self, epoch_name, start_frame, end_frame): """This function adds an epoch to your extractor that tracks a certain time period. It is stored in an internal dictionary of start and end frame tuples. Parameters ---------- epoch_name: str The name of the epoch to be added start_frame: int The start frame of the epoch to be added (inclusive) end_frame: int The end frame of the epoch to be added (exclusive). If set to None, it will include the entire sorting after the start_frame """ if isinstance(epoch_name, str): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) self._epochs[epoch_name] = {'start_frame': start_frame, 'end_frame': end_frame} else: raise TypeError("epoch_name must be a string") def remove_epoch(self, epoch_name): """This function removes an epoch from your extractor. 
Parameters ---------- epoch_name: str The name of the epoch to be removed """ if isinstance(epoch_name, str): if epoch_name in list(self._epochs.keys()): del self._epochs[epoch_name] else: raise ValueError("This epoch has not been added") else: raise ValueError("epoch_name must be a string") def get_epoch_names(self): """This function returns a list of all the epoch names in the extractor Returns ---------- epoch_names: list List of epoch names in the recording extractor """ epoch_names = list(self._epochs.keys()) if not epoch_names: pass else: epoch_start_frames = [] for epoch_name in epoch_names: epoch_info = self.get_epoch_info(epoch_name) start_frame = epoch_info['start_frame'] epoch_start_frames.append(start_frame) epoch_names = [epoch_name for _, epoch_name in sorted(zip(epoch_start_frames, epoch_names))] return epoch_names def get_epoch_info(self, epoch_name): """This function returns the start frame and end frame of the epoch in a dict. Parameters ---------- epoch_name: str The name of the epoch to be returned Returns ---------- epoch_info: dict A dict containing the start frame and end frame of the epoch """ # Default (Can add more information into each epoch in subclass) if isinstance(epoch_name, str): if epoch_name in list(self._epochs.keys()): epoch_info = self._epochs[epoch_name] return epoch_info else: raise ValueError("This epoch has not been added") else: raise ValueError("epoch_name must be a string") def copy_epochs(self, extractor): """Copy epochs from another extractor. 
Parameters ---------- extractor: BaseExtractor The extractor from which the epochs will be copied """ for epoch_name in extractor.get_epoch_names(): epoch_info = extractor.get_epoch_info(epoch_name) self.add_epoch(epoch_name, epoch_info["start_frame"], epoch_info["end_frame"]) def _cast_start_end_frame(self, start_frame, end_frame): from .extraction_tools import cast_start_end_frame return cast_start_end_frame(start_frame, end_frame) @staticmethod def load_extractor_from_json(json_file): """ Instantiates extractor from json file Parameters ---------- json_file: str or Path Path to json file Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ json_file = Path(json_file) with open(str(json_file), 'r') as f: d = json.load(f) extractor = _load_extractor_from_dict(d) return extractor @staticmethod def load_extractor_from_pickle(pkl_file): """ Instantiates extractor from pickle file. Parameters ---------- pkl_file: str or Path Path to pickle file Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ pkl_file = Path(pkl_file) with open(str(pkl_file), 'rb') as f: d = pickle.load(f) extractor = _load_extractor_from_dict(d['serialized_dict']) if 'properties' in d.keys(): extractor._properties = d['properties'] if 'features' in d.keys(): extractor._features = d['features'] if 'times' in d.keys(): extractor._times = d['times'] return extractor @staticmethod def load_extractor_from_dict(d): """ Instantiates extractor from dictionary Parameters ---------- d: dictionary Python dictionary Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ extractor = _load_extractor_from_dict(d) return extractor def check_if_dumpable(self): return _check_if_dumpable(self.make_serialized_dict()) def _make_paths_relative(d, relative): dcopy = deepcopy(d) if "kwargs" in dcopy.keys(): relative_kwargs = _make_paths_relative(dcopy["kwargs"], relative) dcopy["kwargs"] = 
relative_kwargs return dcopy else: for k in d.keys(): # in SI, all input paths have the "path" keyword if "path" in k: d[k] = str(Path(d[k]).relative_to(relative)) return d def _load_extractor_from_dict(dic): cls = None class_name = None probe_file = None kwargs = deepcopy(dic['kwargs']) if np.any([isinstance(v, dict) for v in kwargs.values()]): # nested for k in kwargs.keys(): if isinstance(kwargs[k], dict): if 'module' in kwargs[k].keys() and 'class' in kwargs[k].keys() and 'version' in kwargs[k].keys(): extractor = _load_extractor_from_dict(kwargs[k]) class_name = dic['class'] cls = _get_class_from_string(class_name) kwargs[k] = extractor break elif np.any([isinstance(v, list) and isinstance(v[0], dict) for v in kwargs.values()]): # multi for k in kwargs.keys(): if isinstance(kwargs[k], list) and isinstance(kwargs[k][0], dict): extractors = [] for kw in kwargs[k]: if 'module' in kw.keys() and 'class' in kw.keys() and 'version' in kw.keys(): extr = _load_extractor_from_dict(kw) extractors.append(extr) class_name = dic['class'] cls = _get_class_from_string(class_name) kwargs[k] = extractors break else: class_name = dic['class'] cls = _get_class_from_string(class_name) assert cls is not None and class_name is not None, "Could not load spikeinterface class" if not _check_same_version(class_name, dic['version']): print('Versions are not the same. This might lead to errors. 
Use ', class_name.split('.')[0], 'version', dic['version']) if 'probe_file' in kwargs.keys(): probe_file = kwargs.pop('probe_file') # instantiate extrator object extractor = cls(**kwargs) # load probe file if probe_file is not None: assert 'Recording' in class_name, "Only recording extractors can have probe files" extractor = extractor.load_probe_file(probe_file=probe_file) # load properties and features if 'key_properties' in dic.keys(): extractor._key_properties = dic['key_properties'] if 'annotations' in dic.keys(): extractor._annotations = dic['annotations'] return extractor def _get_class_from_string(class_string): class_name = class_string.split('.')[-1] module = '.'.join(class_string.split('.')[:-1]) imported_module = importlib.import_module(module) try: imported_class = getattr(imported_module, class_name) except: imported_class = None return imported_class def _check_same_version(class_string, version): module = class_string.split('.')[0] imported_module = importlib.import_module(module) try: return imported_module.__version__ == version except AttributeError: return 'unknown' def _check_if_dumpable(d): kwargs = d['kwargs'] if np.any([isinstance(v, dict) and 'dumpable' in v.keys() for (k, v) in kwargs.items()]): for k, v in kwargs.items(): if 'dumpable' in v.keys(): return _check_if_dumpable(v) else: return d['dumpable'] def _check_json(d): # quick hack to ensure json writable for k, v in d.items(): if isinstance(v, dict): d[k] = _check_json(v) elif isinstance(v, Path): d[k] = str(v.absolute()) elif isinstance(v, bool): d[k] = bool(v) elif isinstance(v, (int, np.integer)): d[k] = int(v) elif isinstance(v, float): d[k] = float(v) elif isinstance(v, datetime.datetime): d[k] = v.isoformat() elif isinstance(v, (np.ndarray, list)): if len(v) > 0: if isinstance(v[0], dict): # these must be extractors for multi extractors d[k] = [_check_json(v_el) for v_el in v] else: v_arr = np.array(v) if len(v_arr.shape) == 1: if 'int' in str(v_arr.dtype): v_arr = [int(v_el) 
for v_el in v_arr] d[k] = v_arr elif 'float' in str(v_arr.dtype): v_arr = [float(v_el) for v_el in v_arr] d[k] = v_arr elif isinstance(v_arr[0], str): v_arr = [str(v_el) for v_el in v_arr] d[k] = v_arr else: print(f'Skipping field {k}: only 1D arrays of int, float, or str types can be serialized') elif len(v_arr.shape) == 2: if 'int' in str(v_arr.dtype): v_arr = [[int(v_el) for v_el in v_row] for v_row in v_arr] d[k] = v_arr elif 'float' in str(v_arr.dtype): v_arr = [[float(v_el) for v_el in v_row] for v_row in v_arr] d[k] = v_arr elif 'bool' in str(v_arr.dtype): v_arr = [[bool(v_el) for v_el in v_row] for v_row in v_arr] d[k] = v_arr else: print(f'Skipping field {k}: only 2D arrays of int or float type can be serialized') else: print(f"Skipping field {k}: only 1D and 2D arrays can be serialized") else: d[k] = list(v) return d ================================================ FILE: spikeextractors/cacheextractors.py ================================================ from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor from spikeextractors.extractors.npzsortingextractor import NpzSortingExtractor from spikeextractors import RecordingExtractor, SortingExtractor import tempfile from pathlib import Path from copy import deepcopy import importlib import shutil class CacheRecordingExtractor(BinDatRecordingExtractor, RecordingExtractor): def __init__(self, recording, return_scaled=True, chunk_size=None, chunk_mb=500, save_path=None, n_jobs=1, joblib_backend='loky', verbose=False): RecordingExtractor.__init__(self) # init tmp folder before constructing BinDatRecordingExtractor tmp_folder = self.get_tmp_folder() self._recording = recording if save_path is None: self._is_tmp = True self._tmp_file = tempfile.NamedTemporaryFile(suffix=".dat", dir=tmp_folder).name else: save_path = Path(save_path) if save_path.suffix != '.dat' and save_path.suffix != '.bin': save_path = save_path.with_suffix('.dat') save_path.parent.mkdir(parents=True, 
exist_ok=True) self._is_tmp = False self._tmp_file = save_path self._return_scaled = return_scaled self._dtype = recording.get_dtype(return_scaled) recording.write_to_binary_dat_format(save_path=self._tmp_file, dtype=self._dtype, chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend, return_scaled=self._return_scaled, verbose=verbose) # keep track of filter status when dumping self.is_filtered = self._recording.is_filtered BinDatRecordingExtractor.__init__(self, self._tmp_file, numchan=recording.get_num_channels(), recording_channels=recording.get_channel_ids(), sampling_frequency=recording.get_sampling_frequency(), dtype=self._dtype, is_filtered=self.is_filtered) self.set_tmp_folder(tmp_folder) self.copy_channel_properties(recording) self.copy_times(recording) if 'gain' in recording.get_shared_channel_property_names() and not return_scaled: self.set_channel_gains(recording.get_channel_gains()) self.set_channel_offsets(recording.get_channel_offsets()) self.has_unscaled = True else: self.clear_channel_gains() self.clear_channel_offsets() # keep BinDatRecording kwargs self._bindat_kwargs = deepcopy(self._kwargs) self._kwargs = {'recording': recording, 'chunk_size': chunk_size, 'chunk_mb': chunk_mb} def __del__(self): if self._is_tmp: try: # close memmap file (for Windows) del self._timeseries Path(self._tmp_file).unlink() except Exception as e: print("Unable to remove temporary file", e) @property def filename(self): return str(self._tmp_file) def move_to(self, save_path): save_path = Path(save_path) if save_path.suffix != '.dat' and save_path.suffix != '.bin': save_path = save_path.with_suffix('.dat') save_path.parent.mkdir(parents=True, exist_ok=True) # close memmap file (for Windows) del self._timeseries shutil.move(self._tmp_file, str(save_path)) self._tmp_file = str(save_path) self._kwargs['file_path'] = str(Path(self._tmp_file).absolute()) self._bindat_kwargs['file_path'] = str(Path(self._tmp_file).absolute()) self._is_tmp = 
False tmp_folder = self.get_tmp_folder() # re-initialize with new file BinDatRecordingExtractor.__init__(self, self._tmp_file, numchan=self._recording.get_num_channels(), recording_channels=self._recording.get_channel_ids(), sampling_frequency=self._recording.get_sampling_frequency(), dtype=self._dtype, is_filtered=self.is_filtered) self.set_tmp_folder(tmp_folder) self.copy_channel_properties(self._recording) # override to make serialization avoid reloading and saving binary file def make_serialized_dict(self, include_properties=None, include_features=None): """ Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an extractor with spikeextractors.load_extractor_from_dict(dump_dict) Returns ------- include_properties: list or None List of properties to include in the dictionary include_features: list or None List of features to include in the dictionary """ class_name = str(BinDatRecordingExtractor).replace("", '') module = class_name.split('.')[0] imported_module = importlib.import_module(module) if self._is_tmp: print("Warning: dumping a CacheRecordingExtractor. The path to the tmp binary file will be lost in " "further sessions. 
To prevent this, use the 'CacheRecordingExtractor.move_to('path-to-file)' " "function") dump_dict = {'class': class_name, 'module': module, 'kwargs': self._bindat_kwargs, 'key_properties': self._key_properties, 'version': imported_module.__version__, 'dumpable': True} return dump_dict class CacheSortingExtractor(NpzSortingExtractor, SortingExtractor): def __init__(self, sorting, save_path=None): SortingExtractor.__init__(self) # init tmp folder before constructing NpzSortingExtractor tmp_folder = self.get_tmp_folder() self._sorting = sorting if save_path is None: self._is_tmp = True self._tmp_file = tempfile.NamedTemporaryFile(suffix=".npz", dir=tmp_folder).name else: save_path = Path(save_path) if save_path.suffix != '.npz': save_path = save_path.with_suffix('.npz') save_path.parent.mkdir(parents=True, exist_ok=True) self._is_tmp = False self._tmp_file = save_path NpzSortingExtractor.write_sorting(self._sorting, self._tmp_file) NpzSortingExtractor.__init__(self, self._tmp_file) # keep Npz kwargs self._npz_kwargs = deepcopy(self._kwargs) self.set_tmp_folder(tmp_folder) self.copy_unit_properties(sorting) self.copy_unit_spike_features(sorting) self._kwargs = {'sorting': sorting} def __del__(self): if self._is_tmp: try: Path(self._tmp_file).unlink() except Exception as e: print("Unable to remove temporary file", e) @property def filename(self): return str(self._tmp_file) def move_to(self, save_path): save_path = Path(save_path) if save_path.suffix != '.npz': save_path = save_path.with_suffix('.npz') save_path.parent.mkdir(parents=True, exist_ok=True) shutil.move(self._tmp_file, str(save_path)) self._tmp_file = str(save_path) self._kwargs['file_path'] = str(Path(self._tmp_file).absolute()) self._npz_kwargs['file_path'] = str(Path(self._tmp_file).absolute()) self._is_tmp = False tmp_folder = self.get_tmp_folder() # re-initialize with new file NpzSortingExtractor.__init__(self, self._tmp_file) # keep Npz kwargs self.set_tmp_folder(tmp_folder) 
self.copy_unit_properties(self._sorting) self.copy_unit_spike_features(self._sorting) # override to make serialization avoid reloading and saving npz file def make_serialized_dict(self, include_properties=None, include_features=None): """ Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an extractor with spikeextractors.load_extractor_from_dict(dump_dict) Returns ------- include_properties: list or None List of properties to include in the dictionary include_features: list or None List of features to include in the dictionary """ class_name = str(NpzSortingExtractor).replace("", '') module = class_name.split('.')[0] imported_module = importlib.import_module(module) if self._is_tmp: print("Warning: dumping a CacheSortingExtractor. The path to the tmp binary file will be lost in " "further sessions. To prevent this, use the 'CacheSortingExtractor.move_to('path-to-file)' " "function") dump_dict = {'class': class_name, 'module': module, 'kwargs': self._npz_kwargs, 'key_properties': self._key_properties, 'version': imported_module.__version__, 'dumpable': True} return dump_dict ================================================ FILE: spikeextractors/example_datasets/__init__.py ================================================ from .toy_example import toy_example ================================================ FILE: spikeextractors/example_datasets/synthesize_random_firings.py ================================================ import numpy as np def synthesize_random_firings(*, K=20, sampling_frequency=30000.0, duration=60, seed=None): if seed is not None: np.random.seed(seed) seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, K) else: seeds = np.random.randint(0, 2147483647, K) firing_rates = 3 * np.ones((K)) refr = 4 N = np.int64(duration * sampling_frequency) # events/sec * sec/timepoint * N populations = np.ceil(firing_rates / sampling_frequency * N).astype('int') times = np.zeros(0) labels = np.zeros(0) for 
i, k in enumerate(range(1, K + 1)): refr_timepoints = refr / 1000 * sampling_frequency times0 = np.random.rand(populations[k - 1]) * (N - 1) + 1 ## make an interesting autocorrelogram shape times0 = np.hstack((times0, times0 + rand_distr2(refr_timepoints, refr_timepoints * 20, times0.size, seeds[i]))) times0 = times0[np.random.RandomState(seed=seeds[i]).choice(times0.size, int(times0.size / 2))] times0 = times0[np.where((0 <= times0) & (times0 < N))] times0 = enforce_refractory_period(times0, refr_timepoints) times = np.hstack((times, times0)) labels = np.hstack((labels, k * np.ones(times0.shape))) sort_inds = np.argsort(times) times = times[sort_inds] labels = labels[sort_inds] return (times, labels) def rand_distr2(a, b, num, seed): X = np.random.RandomState(seed=seed).rand(num) X = a + (b - a) * X ** 2 return X def enforce_refractory_period(times_in, refr): if (times_in.size == 0): return times_in times0 = np.sort(times_in) done = False while not done: diffs = times0[1:] - times0[:-1] diffs = np.hstack((diffs, np.inf)) # hack to make sure we handle the last one inds0 = np.where((diffs[:-1] <= refr) & (diffs[1:] >= refr))[0] # only first violator in every group if len(inds0) > 0: times0[inds0] = -1 # kind of a hack, what's the better way? 
times0 = times0[np.where(times0 >= 0)] else: done = True return times0 ================================================ FILE: spikeextractors/example_datasets/synthesize_random_waveforms.py ================================================ import numpy as np from .synthesize_single_waveform import synthesize_single_waveform def synthesize_random_waveforms(*, M=5, T=500, K=20, upsamplefac=13, timeshift_factor=3, average_peak_amplitude=-10, seed=None): if seed is not None: np.random.seed(seed) seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, K) else: seeds = np.random.randint(0, 2147483647, K) geometry = None avg_durations = [200, 10, 30, 200] avg_amps = [0.5, 10, -1, 0] rand_durations_stdev = [10, 4, 6, 20] rand_amps_stdev = [0.2, 3, 0.5, 0] rand_amp_factor_range = [0.5, 1] geom_spread_coef1 = 0.2 geom_spread_coef2 = 1 if not geometry: geometry = np.zeros((2, M)) geometry[0, :] = np.arange(1, M + 1) geometry = np.array(geometry) avg_durations = np.array(avg_durations) avg_amps = np.array(avg_amps) rand_durations_stdev = np.array(rand_durations_stdev) rand_amps_stdev = np.array(rand_amps_stdev) rand_amp_factor_range = np.array(rand_amp_factor_range) neuron_locations = get_default_neuron_locations(M, K, geometry) ## The waveforms_out WW = np.zeros((M, T * upsamplefac, K)) for i, k in enumerate(range(1, K + 1)): for m in range(1, M + 1): diff = neuron_locations[:, k - 1] - geometry[:, m - 1] dist = np.sqrt(np.sum(diff ** 2)) durations0 = np.maximum(np.ones(avg_durations.shape), avg_durations + np.random.RandomState(seed=seeds[i]).randn(1, 4) * rand_durations_stdev) * upsamplefac amps0 = avg_amps + np.random.RandomState(seed=seeds[i]).randn(1, 4) * rand_amps_stdev waveform0 = synthesize_single_waveform(N=T * upsamplefac, durations=durations0, amps=amps0) waveform0 = np.roll(waveform0, int(timeshift_factor * dist * upsamplefac)) waveform0 = waveform0 * np.random.RandomState(seed=seeds[i]).uniform(rand_amp_factor_range[0], rand_amp_factor_range[1]) WW[m - 1, 
:, k - 1] = waveform0 / (geom_spread_coef1 + dist * geom_spread_coef2) peaks = np.max(np.abs(WW), axis=(0, 1)) WW = WW / np.mean(peaks) * average_peak_amplitude return (WW, geometry.T) def get_default_neuron_locations(M, K, geometry): num_dims = geometry.shape[0] neuron_locations = np.zeros((num_dims, K)) for k in range(1, K + 1): if K > 0: ind = (k - 1) / (K - 1) * (M - 1) + 1 ind0 = int(ind) if ind0 == M: ind0 = M - 1 p = 1 else: p = ind - ind0 if M > 0: neuron_locations[:, k - 1] = (1 - p) * geometry[:, ind0 - 1] + p * geometry[:, ind0] else: neuron_locations[:, k - 1] = geometry[:, 0] else: neuron_locations[:, k - 1] = geometry[:, 0] return neuron_locations ================================================ FILE: spikeextractors/example_datasets/synthesize_single_waveform.py ================================================ import numpy as np def exp_growth(amp1, amp2, dur1, dur2): t = np.arange(0, dur1) Y = np.exp(t / dur2) # Want Y[0]=amp1 # Want Y[-1]=amp2 Y = Y / (Y[-1] - Y[0]) * (amp2 - amp1) Y = Y - Y[0] + amp1; return Y def exp_decay(amp1, amp2, dur1, dur2): Y = exp_growth(amp2, amp1, dur1, dur2) Y = np.flipud(Y) # used to be flip, but that was not supported by older versions of numpy return Y def smooth_it(Y, t): Z = np.zeros(Y.size) for j in range(-t, t + 1): Z = Z + np.roll(Y, j) return Z def synthesize_single_waveform(*, N=800, durations=[200, 10, 30, 200], amps=[0.5, 10, -1, 0]): durations = np.array(durations).ravel() if (np.sum(durations) >= N - 2): durations[-1] = N - 2 - np.sum(durations[0:durations.size - 1]) amps = np.array(amps).ravel() timepoints = np.round(np.hstack((0, np.cumsum(durations) - 1))).astype('int'); t = np.r_[0:np.sum(durations) + 1] Y = np.zeros(len(t)) Y[timepoints[0]:timepoints[1] + 1] = exp_growth(0, amps[0], timepoints[1] + 1 - timepoints[0], durations[0] / 4) Y[timepoints[1]:timepoints[2] + 1] = exp_growth(amps[0], amps[1], timepoints[2] + 1 - timepoints[1], durations[1]) Y[timepoints[2]:timepoints[3] + 1] = 
exp_decay(amps[1], amps[2], timepoints[3] + 1 - timepoints[2], durations[2] / 4) Y[timepoints[3]:timepoints[4] + 1] = exp_decay(amps[2], amps[3], timepoints[4] + 1 - timepoints[3], durations[3] / 5) Y = smooth_it(Y, 3) Y = Y - np.linspace(Y[0], Y[-1], len(t)) Y = np.hstack((Y, np.zeros(N - len(t)))) Nmid = int(np.floor(N / 2)) peakind = np.argmax(np.abs(Y)) Y = np.roll(Y, Nmid - peakind) return Y # Y=smooth_it(Y,3); # Y=Y-linspace(Y(1),Y(end),length(Y)); # # Y=[Y,zeros(1,N-length(Y))]; # # Nmid=floor(N/2); # [~,peakind]=max(abs(Y)); # Y=circshift(Y,[0,Nmid-peakind]); # # end # # function test_synth_waveform # Y=synthesize_single_waveform(800); # figure; plot(Y); # end # # function Y=exp_growth(amp1,amp2,dur1,dur2) # t=1:dur1; # Y=exp(t/dur2); # % Want Y(1)=amp1 # % Want Y(end)=amp2 # Y=Y/(Y(end)-Y(1))*(amp2-amp1); # Y=Y-Y(1)+amp1; # end # # function Y=exp_decay(amp1,amp2,dur1,dur2) # Y=exp_growth(amp2,amp1,dur1,dur2); # Y=Y(end:-1:1); # end # # function Z=smooth_it(Y,t) # Z=Y; # Z(1+t:end-t)=0; # for j=-t:t # Z(1+t:end-t)=Z(1+t:end-t)+Y(1+t+j:end-t+j)/(2*t+1); # end; # end if __name__ == '__main__': Y = synthesize_single_waveform() import matplotlib.pyplot as plt plt.plot(Y) ================================================ FILE: spikeextractors/example_datasets/synthesize_timeseries.py ================================================ import numpy as np def synthesize_timeseries(*, sorting, waveforms, noise_level=1, sampling_frequency=30000.0, duration=60, waveform_upsamplefac=13, seed=None): num_timepoints = np.int64(sampling_frequency * duration) waveform_upsamplefac = int(waveform_upsamplefac) W = waveforms M, TT, K = W.shape[0], W.shape[1], W.shape[2] T = int(TT / waveform_upsamplefac) Tmid = int(np.ceil((T + 1) / 2 - 1)) N = num_timepoints if seed is not None: X = np.random.RandomState(seed=seed).randn(M, N) * noise_level else: X = np.random.randn(M, N) * noise_level unit_ids = sorting.get_unit_ids() for k0 in unit_ids: waveform0 = waveforms[:, :, k0 - 1] 
times0 = sorting.get_unit_spike_train(unit_id=k0) for t0 in times0: amp0 = 1 frac_offset = int(np.floor((t0 - np.floor(t0)) * waveform_upsamplefac)) tstart = np.int64(np.floor(t0)) - Tmid if (0 <= tstart) and (tstart + T <= N): X[:, tstart:tstart + T] = X[:, tstart:tstart + T] + waveform0[:, frac_offset::waveform_upsamplefac] * amp0 return X ================================================ FILE: spikeextractors/example_datasets/toy_example.py ================================================ import numpy as np from pathlib import Path from typing import Optional, Union import spikeextractors as se from .synthesize_random_waveforms import synthesize_random_waveforms from .synthesize_random_firings import synthesize_random_firings from .synthesize_timeseries import synthesize_timeseries def toy_example( duration: float = 10., num_channels: int = 4, sampling_frequency: float = 30000., K: int = 10, dumpable: bool = False, dump_folder: Optional[Union[str, Path]] = None, seed: Optional[int] = None ): """ Create toy recording and sorting extractors. Parameters ---------- duration: float Duration in s (default 10) num_channels: int Number of channels (default 4) sampling_frequency: float Sampling frequency (default 30000) K: int Number of units (default 10) dumpable: bool If True, objects are dumped to file and become 'dumpable' dump_folder: str or Path Path to dump folder (if None, 'test' is used seed: int Seed for random initialization Returns ------- recording: RecordingExtractor The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an MdaRecordingExtractor sorting: SortingExtractor The output sorting extractor. 
If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an NpzSortingExtractor """ upsamplefac = 13 waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100, upsamplefac=upsamplefac, seed=seed) times, labels = synthesize_random_firings(K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed) labels = labels.astype(np.int64) SX = se.NumpySortingExtractor() SX.set_times_labels(times, labels) X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10, sampling_frequency=sampling_frequency, duration=duration, waveform_upsamplefac=upsamplefac, seed=seed) SX.set_sampling_frequency(sampling_frequency) RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom) RX.is_filtered = True if dumpable: if dump_folder is None: dump_folder = 'toy_example' dump_folder = Path(dump_folder) se.MdaRecordingExtractor.write_recording(RX, dump_folder) RX = se.MdaRecordingExtractor(dump_folder) se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz') SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz') return RX, SX ================================================ FILE: spikeextractors/exceptions.py ================================================ class NotDumpableExtractorError(TypeError): """Raised whenever current extractor cannot be dumped""" ================================================ FILE: spikeextractors/extraction_tools.py ================================================ import numpy as np import csv import os import sys from pathlib import Path import warnings import datetime from functools import wraps from .baseextractor import BaseExtractor from tqdm import tqdm from joblib import Parallel, delayed try: import h5py HAVE_H5 = True except ImportError: HAVE_H5 = False def read_python(path): """Parses python scripts in a dictionary Parameters ---------- path: str or Path Path to file to parse Returns ------- metadata: dictionary containing 
parsed file """ from six import exec_ import re path = Path(path).absolute() assert path.is_file() with path.open('r') as f: contents = f.read() contents = re.sub(r'range\(([\d,]*)\)',r'list(range(\1))',contents) metadata = {} exec_(contents, {}, metadata) metadata = {k.lower(): v for (k, v) in metadata.items()} return metadata def write_python(path, dict): """Saves python dictionary to file Parameters ---------- path: str or Path Path to save file dict: dict dictionary to save """ with Path(path).open('w') as f: for k, v in dict.items(): if isinstance(v ,str) and not v.startswith("'"): if 'path' in k and 'win' in sys.platform: f.write(str(k) + " = r'" + str(v) + "'\n") else: f.write(str(k) + " = '" + str(v) + "'\n") else: f.write(str(k) + " = " + str(v) + "\n") def load_probe_file(recording, probe_file, channel_map=None, channel_groups=None, verbose=False): """This function returns a SubRecordingExtractor that contains information from the given probe file (channel locations, groups, etc.) If a .prb file is given, then 'location' and 'group' information for each channel is added to the SubRecordingExtractor. If a .csv file is given, then it will only add 'location' to the SubRecordingExtractor. Parameters ---------- recording: RecordingExtractor The recording extractor to load channel information from. probe_file: str Path to probe file. Either .prb or .csv channel_map : array-like A list of channel IDs to set in the loaded file. Only used if the loaded file is a .csv. channel_groups : array-like A list of groups (ints) for the channel_ids to set in the loaded file. Only used if the loaded file is a .csv. verbose: bool If True, output is verbose Returns --------- subrecording: SubRecordingExtractor The extractor containing all of the probe information. 
""" from .subrecordingextractor import SubRecordingExtractor probe_file = Path(probe_file) if probe_file.suffix == '.prb': probe_dict = read_python(probe_file) if 'channel_groups' in probe_dict.keys(): ordered_channels = np.array([], dtype=int) groups = sorted(probe_dict['channel_groups'].keys()) for cgroup_id in groups: cgroup = probe_dict['channel_groups'][cgroup_id] for key_prop, prop_val in cgroup.items(): if key_prop == 'channels': ordered_channels = np.concatenate((ordered_channels, prop_val)) if not np.all([chan in recording.get_channel_ids() for chan in ordered_channels]) and verbose: print('Some channel in PRB file are not in original recording') present_ordered_channels = [chan for chan in ordered_channels if chan in recording.get_channel_ids()] subrecording = SubRecordingExtractor(recording, channel_ids=present_ordered_channels) for cgroup_id in groups: cgroup = probe_dict['channel_groups'][cgroup_id] if 'channels' not in cgroup.keys() and len(groups) > 1: raise Exception("If more than one 'channel_group' is in the probe file, the 'channels' field" "for each channel group is required") elif 'channels' not in cgroup.keys(): channels_in_group = subrecording.get_num_channels() channels_id_in_group = subrecording.get_channel_ids() else: channels_in_group = len(cgroup['channels']) channels_id_in_group = cgroup['channels'] for key_prop, prop_val in cgroup.items(): if key_prop == 'channels': for i_ch, prop in enumerate(prop_val): if prop in subrecording.get_channel_ids(): subrecording.set_channel_groups(int(cgroup_id), channel_ids=prop) elif key_prop == 'geometry' or key_prop == 'location': if isinstance(prop_val, dict): if len(prop_val.keys()) != channels_in_group and verbose: print('geometry in PRB does not have the same length as channel in group') for (i_ch, prop) in prop_val.items(): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_locations(prop, channel_ids=i_ch) elif isinstance(prop_val, (list, np.ndarray)) and len(prop_val) == 
channels_in_group: if 'channels' not in cgroup.keys(): raise Exception("'geometry'/'location' in the .prb file can be a list only if " "'channels' field is specified.") if len(prop_val) != channels_in_group and verbose: print('geometry in PRB does not have the same length as channel in group') for (i_ch, prop) in zip(channels_id_in_group, prop_val): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_locations(prop, channel_ids=i_ch) else: if isinstance(prop_val, dict) and len(prop_val.keys()) == channels_in_group: for (i_ch, prop) in prop_val.items(): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_property(i_ch, key_prop, prop) elif isinstance(prop_val, (list, np.ndarray)) and len(prop_val) == channels_in_group: for (i_ch, prop) in zip(channels_id_in_group, prop_val): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_property(i_ch, key_prop, prop) # create dummy locations if 'geometry' not in cgroup.keys() and 'location' not in cgroup.keys(): if 'location' not in subrecording.get_shared_channel_property_names(): locs = np.zeros((subrecording.get_num_channels(), 2)) locs[:, 1] = np.arange(subrecording.get_num_channels()) subrecording.set_channel_locations(locs) else: raise AttributeError("'.prb' file should contain the 'channel_groups' field") elif probe_file.suffix == '.csv': if channel_map is not None: assert np.all([chan in channel_map for chan in recording.get_channel_ids()]), \ "all channel_ids in 'channel_map' must be in the original recording channel ids" subrecording = SubRecordingExtractor(recording, channel_ids=channel_map) else: subrecording = SubRecordingExtractor(recording, channel_ids=recording.get_channel_ids()) with probe_file.open() as csvfile: posreader = csv.reader(csvfile) row_count = 0 loaded_pos = [] for pos in (posreader): row_count += 1 loaded_pos.append(pos) assert len(subrecording.get_channel_ids()) == row_count, "The .csv file must contain as many " \ "rows as the number of channels 
in the recordings" for i_ch, pos in zip(subrecording.get_channel_ids(), loaded_pos): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_locations(list(np.array(pos).astype(float)), i_ch) if channel_groups is not None and len(channel_groups) == len(subrecording.get_channel_ids()): for i_ch, chg in zip(subrecording.get_channel_ids(), channel_groups): if i_ch in subrecording.get_channel_ids(): subrecording.set_channel_groups(chg, i_ch) else: raise NotImplementedError("Only .csv and .prb probe files can be loaded.") subrecording._kwargs['probe_file'] = str(probe_file.absolute()) return subrecording def save_to_probe_file(recording, probe_file, grouping_property=None, radius=None, graph=True, geometry=True, verbose=False): """Saves probe file from the channel information of the given recording extractor. Parameters ---------- recording: RecordingExtractor The recording extractor to save probe file from probe_file: str file name of .prb or .csv file to save probe information to grouping_property: str (default None) If grouping_property is a shared_channel_property, different groups are saved based on the property. radius: float (default None) Adjacency radius (used by some sorters). If None it is not saved to the probe file. 
graph: bool If True, the adjacency graph is saved (default=True) geometry: bool If True, the geometry is saved (default=True) verbose: bool If True, output is verbose """ probe_file = Path(probe_file) if not probe_file.parent.is_dir(): probe_file.parent.mkdir() if probe_file.suffix == '.csv': # write csv probe file with probe_file.open('w') as f: if 'location' in recording.get_shared_channel_property_names(): for chan in recording.get_channel_ids(): loc = recording.get_channel_locations(chan)[0] if len(loc) == 2: f.write(str(loc[0])) f.write(',') f.write(str(loc[1])) f.write('\n') elif len(loc) == 3: f.write(str(loc[0])) f.write(',') f.write(str(loc[1])) f.write(',') f.write(str(loc[2])) f.write('\n') else: raise AttributeError("Recording extractor needs to have " "'location' property to save .csv probe file") elif probe_file.suffix == '.prb': _export_prb_file(recording, probe_file, grouping_property=grouping_property, radius=radius, graph=graph, geometry=geometry, verbose=verbose) else: raise NotImplementedError("Only .csv and .prb probe files can be saved.") def read_binary(file, numchan, dtype, time_axis=0, offset=0): """ Reads binary .bin or .dat file. Parameters ---------- file: str File name numchan: int Number of channels dtype: dtype dtype of the file time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. 
offset: int number of offset bytes """ numchan = int(numchan) with Path(file).open() as f: nsamples = (os.fstat(f.fileno()).st_size - offset) // (numchan * np.dtype(dtype).itemsize) if time_axis == 0: samples = np.memmap(file, np.dtype(dtype), mode='r', offset=offset, shape=(nsamples, numchan)).T else: samples = np.memmap(file, np.dtype(dtype), mode='r', offset=offset, shape=(numchan, nsamples)) return samples def write_to_binary_dat_format(recording, save_path=None, file_handle=None, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, n_jobs=1, joblib_backend='loky', return_scaled=True, verbose=False): """Saves the traces of a recording extractor in binary .dat format. Parameters ---------- recording: RecordingExtractor The recording extractor object to be saved in .dat format save_path: str The path to the file. file_handle: file handle The file handle to dump data. This can be used to append data to an header. In case file_handle is given, the file is NOT closed after writing the binary data. time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. dtype: dtype Type of the saved data. Default float32. chunk_size: None or int Size of each chunk in number of frames. If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) chunk_mb: None or int Chunk size in Mb (default 500Mb) n_jobs: int Number of jobs to use (Default 1) joblib_backend: str Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing') return_scaled: bool If True, traces are written after scaling (using gain/offset). 
If False, the raw traces are written verbose: bool If True, output is verbose (when chunks are used) """ assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" if save_path is not None: save_path = Path(save_path) if save_path.suffix == '': # when suffix is already raw/bin/dat do not change it. save_path = save_path.parent / (save_path.name + '.dat') if chunk_size is not None or chunk_mb is not None: if time_axis == 1: print("Chunking disabled due to 'time_axis' == 1") chunk_size = None chunk_mb = None # set chunk size if chunk_size is not None: chunk_size = int(chunk_size) elif chunk_mb is not None: n_bytes = np.dtype(recording.get_dtype()).itemsize max_size = int(chunk_mb * 1e6) # set Mb per chunk chunk_size = max_size // (recording.get_num_channels() * n_bytes) if n_jobs is None: n_jobs = 1 if n_jobs == 0: n_jobs = 1 if n_jobs > 1: if chunk_size is not None: chunk_size /= n_jobs if not recording.check_if_dumpable(): if n_jobs > 1: n_jobs = 1 print("RecordingExtractor is not dumpable and can't be processed in parallel") rec_arg = recording else: if n_jobs > 1: rec_arg = recording.dump_to_dict() else: rec_arg = recording if chunk_size is None: traces = recording.get_traces(return_scaled=return_scaled) if dtype is not None: traces = traces.astype(dtype) if time_axis == 0: traces = traces.T if save_path is not None: with save_path.open('wb') as f: traces.tofile(f) else: traces.tofile(file_handle) else: # chunk size is not None num_frames = recording.get_num_frames() num_channels = recording.get_num_channels() # chunk_size = num_bytes_per_chunk / num_bytes_per_frame chunks = divide_recording_into_time_chunks( num_frames=num_frames, chunk_size=chunk_size, padding_size=0 ) n_chunk = len(chunks) if verbose and n_jobs == 1: chunks_loop = tqdm(range(n_chunk), ascii=True, desc="Writing to binary .dat file") else: chunks_loop = range(n_chunk) if save_path is not None: if n_jobs == 1: if time_axis == 0: shape = (num_frames, 
num_channels) else: shape = (num_channels, num_frames) rec_memmap = np.memmap(str(save_path), dtype=dtype, mode='w+', shape=shape) for i in chunks_loop: _write_dat_one_chunk(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled, verbose=False) else: if time_axis == 0: shape = (num_frames, num_channels) else: shape = (num_channels, num_frames) rec_memmap = np.memmap(str(save_path), dtype=dtype, mode='w+', shape=shape) Parallel(n_jobs=n_jobs, backend=joblib_backend)( delayed(_write_dat_one_chunk)(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled, verbose,) for i in chunks_loop) else: for i in chunks_loop: start_frame = chunks[i]['istart'] end_frame = chunks[i]['iend'] traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled) if dtype is not None: traces = traces.astype(dtype) if time_axis == 0: traces = traces.T file_handle.write(traces.tobytes()) return save_path def write_to_h5_dataset_format(recording, dataset_path, save_path=None, file_handle=None, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, verbose=False): """Saves the traces of a recording extractor in an h5 dataset. Parameters ---------- recording: RecordingExtractor The recording extractor object to be saved in .dat format dataset_path: str Path to dataset in h5 filee (e.g. '/dataset') save_path: str The path to the file. file_handle: file handle The file handle to dump data. This can be used to append data to an header. In case file_handle is given, the file is NOT closed after writing the binary data. time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. dtype: dtype Type of the saved data. Default float32. chunk_size: None or int Size of each chunk in number of frames. 
If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) chunk_mb: None or int Chunk size in Mb (default 500Mb) verbose: bool If True, output is verbose (when chunks are used) """ assert HAVE_H5, "To write to h5 you need to install h5py: pip install h5py" assert save_path is not None or file_handle is not None, "Provide 'save_path' or 'file handle'" if save_path is not None: save_path = Path(save_path) if save_path.suffix == '': # when suffix is already raw/bin/dat do not change it. save_path = save_path.parent / (save_path.name + '.h5') num_channels = recording.get_num_channels() num_frames = recording.get_num_frames() if file_handle is not None: assert isinstance(file_handle, h5py.File) else: file_handle = h5py.File(save_path, 'w') if dtype is None: dtype_file = recording.get_dtype() else: dtype_file = dtype if time_axis == 0: dset = file_handle.create_dataset(dataset_path, shape=(num_frames, num_channels), dtype=dtype_file) else: dset = file_handle.create_dataset(dataset_path, shape=(num_channels, num_frames), dtype=dtype_file) # set chunk size if chunk_size is not None: chunk_size = int(chunk_size) elif chunk_mb is not None: n_bytes = np.dtype(recording.get_dtype()).itemsize max_size = int(chunk_mb * 1e6) # set Mb per chunk chunk_size = max_size // (num_channels * n_bytes) if chunk_size is None: traces = recording.get_traces() if dtype is not None: traces = traces.astype(dtype_file) if time_axis == 0: traces = traces.T dset[:] = traces else: chunk_start = 0 # chunk size is not None n_chunk = num_frames // chunk_size if num_frames % chunk_size > 0: n_chunk += 1 if verbose: chunks = tqdm(range(n_chunk), ascii=True, desc="Writing to .h5 file") else: chunks = range(n_chunk) for i in chunks: traces = recording.get_traces(start_frame=i * chunk_size, end_frame=min((i + 1) * chunk_size, num_frames)) chunk_frames = traces.shape[1] if dtype is not None: traces = traces.astype(dtype_file) if time_axis == 0: 
dset[chunk_start:chunk_start + chunk_frames] = traces.T else: dset[:, chunk_start:chunk_start + chunk_frames] = traces chunk_start += chunk_frames if save_path is not None: file_handle.close() return save_path def get_sub_extractors_by_property(extractor, property_name, return_property_list=False): """Returns a list of SubExtractors from the Extractor based on the given property_name (e.g. group) Parameters ---------- extractor: RecordingExtractor or SortingExtractor The extractor object to access SubRecordingExtractors from. property_name: str The property used to subdivide the extractor return_property_list: bool If True the property list is returned Returns ------- sub_list: list The list of subextractors to be returned. OR sub_list, prop_list If return_property_list is True, the property list will be returned as well. """ from spikeextractors import RecordingExtractor, SortingExtractor, SubRecordingExtractor, SubSortingExtractor if isinstance(extractor, RecordingExtractor): if property_name not in extractor.get_shared_channel_property_names(): raise ValueError("'property_name' must be must be a property of the recording channels") else: sub_list = [] recording = extractor properties = np.array([recording.get_channel_property(chan, property_name) for chan in recording.get_channel_ids()]) prop_list = np.unique(properties) for prop in prop_list: prop_idx = np.where(prop == properties) chan_idx = list(np.array(recording.get_channel_ids())[prop_idx]) sub_list.append(SubRecordingExtractor(recording, channel_ids=chan_idx)) if return_property_list: return sub_list, prop_list else: return sub_list elif isinstance(extractor, SortingExtractor): if property_name not in extractor.get_shared_unit_property_names(): raise ValueError("'property_name' must be must be a property of the units") else: sub_list = [] sorting = extractor properties = np.array([sorting.get_unit_property(unit, property_name) for unit in sorting.get_unit_ids()]) prop_list = np.unique(properties) for prop 
in prop_list: prop_idx = np.where(prop == properties) unit_idx = list(np.array(sorting.get_unit_ids())[prop_idx]) sub_list.append(SubSortingExtractor(sorting, unit_ids=unit_idx)) if return_property_list: return sub_list, prop_list else: return sub_list else: raise ValueError("'extractor' must be a RecordingExtractor or a SortingExtractor") def _export_prb_file(recording, file_name, grouping_property=None, graph=True, geometry=True, radius=None, adjacency_distance=100, verbose=False): """Exports .prb file Parameters ---------- recording: RecordingExtractor The recording extractor to save probe file from file_name: str probe filename to be exported to grouping_property: str (default None) If grouping_property is a shared_channel_property, different groups are saved based on the property. graph: bool If True, the adjacency graph is saved (default=True) geometry: bool If True, the geometry is saved (default=True) radius: float (default None) Adjacency radius (used by some sorters). If None it is not saved to the probe file. adjacency_distance: float Distance to consider two channels to adjacent (if 'location' is a property). If radius is given, then adjacency_distance is set to the radius. 
verbose : bool If True, output is verbose """ file_name = Path(file_name) assert file_name is not None abspath = file_name.absolute() if radius is not None: adjacency_distance = radius if geometry: if 'location' in recording.get_shared_channel_property_names(): positions = recording.get_channel_locations() else: if verbose: print("'location' property is not available and it will not be saved.") positions = None geometry = False else: positions = None if grouping_property is not None: if grouping_property in recording.get_shared_channel_property_names(): grouping_property_groups = np.array([recording.get_channel_property(chan, grouping_property) for chan in recording.get_channel_ids()]) channel_groups = np.unique([grouping_property_groups]) else: if verbose: print(f"{grouping_property} property is not available and it will not be saved.") channel_groups = [0] grouping_property_groups = np.array([0] * recording.get_num_channels()) else: channel_groups = [0] grouping_property_groups = np.array([0] * recording.get_num_channels()) n_elec = recording.get_num_channels() # find adjacency graph if graph: if positions is not None and adjacency_distance is not None: adj_graph = [] for chg in channel_groups: group_graph = [] elecs = list(np.where(grouping_property_groups == chg)[0]) for i in range(len(elecs)): for j in range(i, len(elecs)): if elecs[i] != elecs[j]: if np.linalg.norm(positions[elecs[i]] - positions[elecs[j]]) < adjacency_distance: group_graph.append((elecs[i], elecs[j])) adj_graph.append(group_graph) else: # all connected by group adj_graph = [] for chg in channel_groups: group_graph = [] elecs = list(np.where(grouping_property_groups == chg)[0]) for i in range(len(elecs)): for j in range(i, len(elecs)): if elecs[i] != elecs[j]: group_graph.append((elecs[i], elecs[j])) adj_graph.append(group_graph) with abspath.open('w') as f: f.write('total_nb_channels = ' + str(n_elec) + '\n') if radius is not None: f.write('radius = ' + str(radius) + '\n') 
f.write('channel_groups = {\n') if len(channel_groups) > 0: for i_chg, chg in enumerate(channel_groups): f.write(" " + str(int(chg)) + ": ") elecs = list(np.where(grouping_property_groups == chg)[0]) f.write("\n {\n") f.write(" 'channels': " + str(elecs) + ',\n') if graph: if len(adj_graph) == 1: f.write(" 'graph': " + str(adj_graph[0]) + ',\n') else: f.write(" 'graph': " + str(adj_graph[i_chg]) + ',\n') if geometry: f.write(" 'geometry': {\n") for i, pos in enumerate(positions[elecs]): f.write(' ' + str(elecs[i]) + ': ' + str(list(pos)) + ',\n') f.write(' }\n') f.write(' },\n') f.write('}\n') else: for elec in range(n_elec): f.write(' ' + str(elec) + ': ') f.write("\n {\n") f.write(" 'channels': [" + str(elec) + '],\n') f.write(" 'graph': [],\n") f.write(' },\n') f.write('}\n') def _check_json(d): # quick hack to ensure json writable for k, v in d.items(): if isinstance(v, Path): d[k] = str(v) elif isinstance(v, (int, np.integer)): d[k] = int(v) elif isinstance(v, float): d[k] = float(v) elif isinstance(v, datetime.datetime): d[k] = v.isoformat() return d def load_extractor_from_json(json_file): """ Instantiates extractor from json file Parameters ---------- json_file: str or Path Path to json file Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ return BaseExtractor.load_extractor_from_json(json_file) def load_extractor_from_dict(d): """ Instantiates extractor from dictionary Parameters ---------- d: dictionary Python dictionary Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ return BaseExtractor.load_extractor_from_dict(d) def load_extractor_from_pickle(pkl_file): """ Instantiates extractor from pickle file Parameters ---------- pkl_file: str or Path Path to pickle file Returns ------- extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ return BaseExtractor.load_extractor_from_pickle(pkl_file) def check_get_unit_spike_train(func): 
@wraps(func) def check_validity(sorting, unit_id, start_frame=None, end_frame=None): # parse args and kwargs if unit_id is None: raise TypeError("get_unit_spike_train() missing 1 required positional argument: 'unit_id')") elif not (isinstance(unit_id, (int, np.integer))): raise ValueError("unit_id must be an integer") elif unit_id not in sorting.get_unit_ids(): raise ValueError(f"{unit_id} is an invalid unit id") start_frame, end_frame = cast_start_end_frame(start_frame, end_frame) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = np.Inf return func(sorting, unit_id, start_frame=start_frame, end_frame=end_frame) return check_validity def check_get_traces_args(func): @wraps(func) def corrected_args(recording, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True, **kwargs): if channel_ids is not None: if isinstance(channel_ids, (int, np.integer)): channel_ids = list([channel_ids]) else: channel_ids = channel_ids if np.any([ch not in recording.get_channel_ids() for ch in channel_ids]): print("Removing invalid 'channel_ids'", [ch for ch in channel_ids if ch not in recording.get_channel_ids()]) channel_ids = [ch for ch in channel_ids if ch in recording.get_channel_ids()] else: channel_ids = recording.get_channel_ids() if start_frame is not None: if start_frame < 0: start_frame = recording.get_num_frames() + start_frame else: start_frame = 0 if end_frame is not None: if end_frame > recording.get_num_frames(): print("'end_frame' set to", recording.get_num_frames()) end_frame = recording.get_num_frames() elif end_frame < 0: end_frame = recording.get_num_frames() + end_frame else: end_frame = recording.get_num_frames() assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!" start_frame, end_frame = cast_start_end_frame(start_frame, end_frame) if not recording.has_unscaled and not return_scaled: warnings.warn("The recording extractor does not have unscaled traces. 
Returning scaled traces") return_scaled = True traces = func(recording, channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled, **kwargs) # scaling if recording.has_unscaled and return_scaled: channel_idxs = np.array([recording.get_channel_ids().index(ch) for ch in channel_ids]) gains = recording.get_channel_gains()[channel_idxs, None] offsets = recording.get_channel_offsets()[channel_idxs, None] traces = (traces.astype("float32") * gains + offsets).astype("float32") return traces return corrected_args def check_get_ttl_args(func): @wraps(func) def corrected_args(recording, start_frame=None, end_frame=None, channel_id=0, **kwargs): if start_frame is not None: if start_frame < 0: start_frame = recording.get_num_frames() + start_frame else: start_frame = 0 if end_frame is not None: if end_frame > recording.get_num_frames(): print("'end_frame' set to", recording.get_num_frames()) end_frame = recording.get_num_frames() elif end_frame < 0: end_frame = recording.get_num_frames() + end_frame else: end_frame = recording.get_num_frames() assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!" 
assert isinstance(channel_id, (int, np.integer)), "'channel_id' must be a single int" start_frame, end_frame = cast_start_end_frame(start_frame, end_frame) # pass recording as arg and rest as kwargs get_ttl_correct_arg = func(recording, start_frame=start_frame, end_frame=end_frame, channel_id=channel_id, **kwargs) return get_ttl_correct_arg return corrected_args def cast_start_end_frame(start_frame, end_frame): if isinstance(start_frame, float): start_frame = int(start_frame) elif isinstance(start_frame, (int, np.integer, type(None))): start_frame = start_frame else: raise ValueError("start_frame must be an int, float (not infinity), or None") if isinstance(end_frame, float) and np.isfinite(end_frame): end_frame = int(end_frame) elif isinstance(end_frame, (int, np.integer, type(None))): end_frame = end_frame # else end_frame is infinity (accepted for get_unit_spike_train) if start_frame is not None: start_frame = int(start_frame) if end_frame is not None and np.isfinite(end_frame): end_frame = int(end_frame) return start_frame, end_frame def divide_recording_into_time_chunks(num_frames, chunk_size, padding_size): chunks = [] ii = 0 while ii < num_frames: ii2 = int(min(ii + chunk_size, num_frames)) chunks.append(dict( istart=ii, iend=ii2, istart_with_padding=int(max(0, ii - padding_size)), iend_with_padding=int(min(num_frames, ii2 + padding_size)) )) ii = ii2 return chunks def _write_dat_one_chunk(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled, verbose): chunk = chunks[i] if verbose: print(f"Writing chunk {i + 1} / {len(chunks)}") if isinstance(rec_arg, dict): recording = load_extractor_from_dict(rec_arg) else: recording = rec_arg start_frame = chunk['istart'] end_frame = chunk['iend'] traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled) if dtype is not None: traces = traces.astype(dtype) if time_axis == 0: traces = traces.T rec_memmap[start_frame:end_frame, :] = traces else: rec_memmap[:, 
start_frame:end_frame] = traces ================================================ FILE: spikeextractors/extractorlist.py ================================================ from .extractors.mdaextractors.mdaextractors import MdaRecordingExtractor, MdaSortingExtractor from .extractors.mearecextractors.mearecextractors import MEArecRecordingExtractor, MEArecSortingExtractor from .extractors.biocamrecordingextractor.biocamrecordingextractor import BiocamRecordingExtractor from .extractors.exdirextractors.exdirextractors import ExdirRecordingExtractor, ExdirSortingExtractor from .extractors.intanrecordingextractor.intanrecordingextractor import IntanRecordingExtractor from .extractors.hdsortsortingextractor.hdsortsortingextractor import HDSortSortingExtractor from .extractors.hs2sortingextractor.hs2sortingextractor import HS2SortingExtractor from .extractors.klustaextractors.klustaextractors import KlustaSortingExtractor, KlustaRecordingExtractor from .extractors.kilosortextractors.kilosortextractors import KiloSortSortingExtractor, KiloSortRecordingExtractor from .extractors.numpyextractors.numpyextractors import NumpyRecordingExtractor, NumpySortingExtractor from .extractors.nwbextractors.nwbextractors import NwbRecordingExtractor, NwbSortingExtractor from .extractors.openephysextractors.openephysextractors import OpenEphysRecordingExtractor, \ OpenEphysSortingExtractor, OpenEphysNPIXRecordingExtractor from .extractors.maxwellextractors import MaxOneRecordingExtractor, MaxOneSortingExtractor, MaxTwoRecordingExtractor, \ MaxTwoSortingExtractor from .extractors.phyextractors.phyextractors import PhyRecordingExtractor, PhySortingExtractor from .extractors.bindatrecordingextractor.bindatrecordingextractor import BinDatRecordingExtractor from .extractors.spykingcircusextractors.spykingcircusextractors import SpykingCircusSortingExtractor, \ SpykingCircusRecordingExtractor from .extractors.spikeglxrecordingextractor.spikeglxrecordingextractor import SpikeGLXRecordingExtractor 
from .extractors.tridescloussortingextractor.tridescloussortingextractor import TridesclousSortingExtractor from .extractors.npzsortingextractor.npzsortingextractor import NpzSortingExtractor from .extractors.mcsh5recordingextractor.mcsh5recordingextractor import MCSH5RecordingExtractor from .extractors.shybridextractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor from .extractors.nixioextractors.nixioextractors import NIXIORecordingExtractor, NIXIOSortingExtractor from .extractors.neoextractors import (AxonaRecordingExtractor, PlexonRecordingExtractor, PlexonSortingExtractor, NeuralynxRecordingExtractor, NeuralynxSortingExtractor, BlackrockRecordingExtractor, BlackrockSortingExtractor, MCSRawRecordingExtractor, SpikeGadgetsRecordingExtractor) from .extractors.neuroscopeextractors import NeuroscopeRecordingExtractor, NeuroscopeMultiRecordingTimeExtractor, \ NeuroscopeSortingExtractor, NeuroscopeMultiSortingExtractor from .extractors.waveclussortingextractor import WaveClusSortingExtractor from .extractors.yassextractors import YassSortingExtractor from .extractors.combinatosortingextractor import CombinatoSortingExtractor from .extractors.alfsortingextractor import ALFSortingExtractor from .extractors.cedextractors import CEDRecordingExtractor from .extractors.cellexplorersortingextractor import CellExplorerSortingExtractor from .extractors.neuropixelsdatrecordingextractor import NeuropixelsDatRecordingExtractor from .extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor recording_extractor_full_list = [ MdaRecordingExtractor, MEArecRecordingExtractor, BiocamRecordingExtractor, ExdirRecordingExtractor, OpenEphysRecordingExtractor, OpenEphysNPIXRecordingExtractor, IntanRecordingExtractor, BinDatRecordingExtractor, KlustaRecordingExtractor, KiloSortRecordingExtractor, SpykingCircusRecordingExtractor, SpikeGLXRecordingExtractor, PhyRecordingExtractor, MaxOneRecordingExtractor, MaxTwoRecordingExtractor, MCSH5RecordingExtractor, 
SHYBRIDRecordingExtractor, NIXIORecordingExtractor, NwbRecordingExtractor, NeuroscopeRecordingExtractor, NeuroscopeMultiRecordingTimeExtractor, CEDRecordingExtractor, NeuropixelsDatRecordingExtractor, AxonaUnitRecordingExtractor, # neo based AxonaRecordingExtractor, PlexonRecordingExtractor, NeuralynxRecordingExtractor, BlackrockRecordingExtractor, MCSRawRecordingExtractor, SpikeGadgetsRecordingExtractor, ] recording_extractor_dict = {recording_class.extractor_name: recording_class for recording_class in recording_extractor_full_list} installed_recording_extractor_list = [rx for rx in recording_extractor_full_list if rx.installed] sorting_extractor_full_list = [ MdaSortingExtractor, MEArecSortingExtractor, ExdirSortingExtractor, HDSortSortingExtractor, HS2SortingExtractor, KlustaSortingExtractor, KiloSortSortingExtractor, OpenEphysSortingExtractor, PhySortingExtractor, SpykingCircusSortingExtractor, TridesclousSortingExtractor, MaxTwoSortingExtractor, MaxOneSortingExtractor, NpzSortingExtractor, SHYBRIDSortingExtractor, NIXIOSortingExtractor, NeuroscopeSortingExtractor, NeuroscopeMultiSortingExtractor, NwbSortingExtractor, WaveClusSortingExtractor, YassSortingExtractor, CombinatoSortingExtractor, ALFSortingExtractor, # neo based PlexonSortingExtractor, NeuralynxSortingExtractor, BlackrockSortingExtractor, CellExplorerSortingExtractor ] installed_sorting_extractor_list = [sx for sx in sorting_extractor_full_list if sx.installed] sorting_extractor_dict = {sorting_class.extractor_name: sorting_class for sorting_class in sorting_extractor_full_list} writable_sorting_extractor_list = [sx for sx in installed_sorting_extractor_list if sx.is_writable] writable_sorting_extractor_dict = {sorting_class.extractor_name: sorting_class for sorting_class in writable_sorting_extractor_list} ================================================ FILE: spikeextractors/extractors/__init__.py ================================================ ================================================ 
FILE: spikeextractors/extractors/alfsortingextractor/__init__.py ================================================ from .alfsortingextractor import ALFSortingExtractor ================================================ FILE: spikeextractors/extractors/alfsortingextractor/alfsortingextractor.py ================================================ from abc import ABC from spikeextractors import SortingExtractor from pathlib import Path import numpy as np try: import pandas as pd HAVE_PANDAS = True except: HAVE_PANDAS = False class ALFSortingExtractor(SortingExtractor): extractor_name = 'ALFSorting' installed = HAVE_PANDAS # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "To use the ALFSortingExtractor run:\n\n pip install pandas\n\n" def __init__(self, folder_path, sampling_frequency=30000): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) # check correct parent folder: self.file_loc = Path(folder_path) if 'probe' not in Path(self.file_loc).name: raise ValueError('folder name should contain "probe", containing channels, clusters.* .npy datasets') # load datasets as mmap into a dict: self._required_alf_datasets = ['spikes.times', 'spikes.clusters'] self._found_alf_datasets = dict() for alf_dataset_name in self.file_loc.iterdir(): if 'spikes' in alf_dataset_name.stem or 'clusters' in alf_dataset_name.stem: if 'npy' in alf_dataset_name.suffix: self._found_alf_datasets.update({alf_dataset_name.stem: self._load_npy(alf_dataset_name)}) elif 'metrics' in alf_dataset_name.stem: self._found_alf_datasets.update({alf_dataset_name.stem: pd.read_csv(alf_dataset_name)}) # check existence of datasets: if not any([i in self._found_alf_datasets for i in self._required_alf_datasets]): raise Exception(f'could not find {self._required_alf_datasets} in folder') # setting units properties: self._total_units = 0 for alf_dataset_name, alf_dataset in self._found_alf_datasets.items(): if 'clusters' in alf_dataset_name: if 
'clusters.metrics' in alf_dataset_name: for property_name, property_values in self._found_alf_datasets[alf_dataset_name].iteritems(): self.set_units_property(unit_ids=self.get_unit_ids(), property_name=property_name, values=property_values.tolist()) else: self.set_units_property(unit_ids=self.get_unit_ids(), property_name=alf_dataset_name.split('.')[1], values=alf_dataset) if self._total_units == 0: self._total_units = alf_dataset.shape[0] self._units_map = {i: j for i, j in zip(self.get_unit_ids(), list(range(self._total_units)))} self._units_raster = [] self._sampling_frequency = sampling_frequency self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'sampling_frequency': sampling_frequency} def _load_npy(self, npy_path): return np.load(npy_path, mmap_mode='r',allow_pickle=True) def _get_clusters_spike_times(self, cluster_idx): if len(self._units_raster) == 0: spike_cluster_data = self._found_alf_datasets['spikes.clusters'] spike_times_data = self._found_alf_datasets['spikes.times'] df = pd.DataFrame({'sp_cluster': spike_cluster_data, 'sp_times': spike_times_data}) data = df.groupby(['sp_cluster'])['sp_times'].apply(np.array).reset_index(name='sp_times_group') self._max_time = 0 self._units_raster = [None]*self._total_units for index, sp_times_list in data.values: self._units_raster[index] = sp_times_list max_time = max(sp_times_list) if max_time > self._max_time: self._max_time = max_time return self._units_raster[cluster_idx] def get_unit_ids(self): if 'clusters.metrics' in self._found_alf_datasets and \ self._found_alf_datasets['clusters.metrics'].get('cluster_id') is not None: unit_ids = self._found_alf_datasets['clusters.metrics'].get('cluster_id').tolist() else: unit_ids = list(range(self._total_units)) return unit_ids def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): """Code to extract spike frames from the specified unit. 
It will return spike frames from within three ranges: [start_frame, t_start+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_unit_spike_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_unit_spike_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Spike frames are returned in the form of an array_like of spike frames. In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. """ unit_idx = self._units_map.get(unit_id) if unit_idx is None: raise ValueError(f'enter one of unit_id={self.get_unit_ids()}') cluster_sp_times = self._get_clusters_spike_times(unit_idx) if cluster_sp_times is None: return np.array([]) max_frame = np.ceil(cluster_sp_times[-1]*self.get_sampling_frequency()).astype('int64') min_frame = np.floor(cluster_sp_times[0]*self.get_sampling_frequency()).astype('int64') start_frame = min_frame if start_frame is None or start_frame < min_frame else start_frame end_frame = max_frame if end_frame is None or end_frame > max_frame else end_frame if start_frame > max_frame or end_frame < min_frame: raise ValueError(f'Use start_frame to end_frame between {min_frame} and {max_frame}') cluster_sp_frames = (cluster_sp_times * self.get_sampling_frequency()).astype('int64') frame_idx = np.where((cluster_sp_frames >= start_frame) & (cluster_sp_frames < end_frame)) return cluster_sp_frames[frame_idx] @staticmethod def write_sorting(sorting, save_path): """ This is an example of a function that is not abstract so it is optional if you want to override it. It allows other SortingExtractors to use your new SortingExtractor to convert their sorted data into your sorting file format. 
""" assert HAVE_PANDAS, ALFSortingExtractor.installation_mesg # write cluster properties as clusters..npy save_path = Path(save_path) csv_property_names = ['cluster_id', 'cluster_id.1', 'num_spikes', 'firing_rate', 'presence_ratio', 'presence_ratio_std', 'frac_isi_viol', 'contamination_est', 'contamination_est2', 'missed_spikes_est', 'cum_amp_drift', 'max_amp_drift', 'cum_depth_drift', 'max_depth_drift', 'ks2_contamination_pct', 'ks2_label','amplitude_cutoff', 'amplitude_std', 'epoch_name', 'isi_viol'] clusters_metrics_df = pd.DataFrame() for property_name in sorting.get_unit_property_names(0): data = sorting.get_units_property(property_name=property_name) if property_name not in csv_property_names: np.save(save_path/f'clusters.{property_name}', data) else: clusters_metrics_df[property_name] = data clusters_metrics_df.to_csv(save_path/'clusters.metrics.csv') # save spikes.times, spikes.clusters clusters_number = [] unit_spike_times = [] for unit_no, unit_id in enumerate(sorting.get_unit_ids()): unit_spike_train = sorting.get_unit_spike_train(unit_id=unit_id) if unit_spike_train is not None: unit_spike_times.extend(np.array(unit_spike_train)/sorting.get_sampling_frequency()) clusters_number.extend([unit_no]*len(unit_spike_train)) unit_spike_train = np.array(unit_spike_times) clusters_number = np.array(clusters_number) spike_times_ids = np.argsort(unit_spike_train) spike_times = unit_spike_train[spike_times_ids] spike_clusters = clusters_number[spike_times_ids] np.save(save_path/'spikes.times', spike_times) np.save(save_path/'spikes.clusters', spike_clusters) ================================================ FILE: spikeextractors/extractors/axonaunitrecordingextractor/__init__.py ================================================ from .axonaunitrecordingextractor import AxonaUnitRecordingExtractor ================================================ FILE: spikeextractors/extractors/axonaunitrecordingextractor/axonaunitrecordingextractor.py 
================================================ from spikeextractors.extraction_tools import check_get_traces_args from spikeextractors.extractors.neoextractors.neobaseextractor import ( _NeoBaseExtractor, NeoBaseRecordingExtractor) from spikeextractors import RecordingExtractor from pathlib import Path import numpy as np from typing import Union import warnings PathType = Union[Path, str] try: import neo from neo.rawio.baserawio import _signal_channel_dtype, _signal_stream_dtype HAVE_NEO = True except ImportError: HAVE_NEO = False class AxonaUnitRecordingExtractor(NeoBaseRecordingExtractor, RecordingExtractor, _NeoBaseExtractor): """ Instantiates a RecordingExtractor from an Axona Unit mode file. Since the unit mode format only saves waveform cutouts, the get_traces function fills in the rest of the recording with Gaussian uncorrelated noise Parameters ---------- noise_std: float Standard deviation of the Gaussian background noise (default 3) """ extractor_name = 'AxonaUnitRecording' mode = 'file' NeoRawIOClass = 'AxonaRawIO' def __init__(self, noise_std: float = 3, block_index=None, seg_index=None, **kargs): RecordingExtractor.__init__(self) _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs) # Enforce 1 signal stream (there are 0 raw streams), we will create 1 from waveforms signal_streams = self.neo_reader._get_signal_streams_header() signal_channels = self.neo_reader._get_signal_chan_header() self.neo_reader.header['signal_streams'] = np.array(signal_streams, dtype=_signal_stream_dtype) self.neo_reader.header['signal_channels'] = np.array(signal_channels, dtype=_signal_channel_dtype) if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'): # Neo >= 0.9.0 channel_indexes_list = self.neo_reader.get_group_signal_channel_indexes() num_streams = len(channel_indexes_list) assert num_streams <= 1, 'This file have several channel groups spikeextractors support only one groups' self.after_v10 = False elif 
hasattr(self.neo_reader, 'get_group_channel_indexes'): # Neo < 0.9.0 channel_indexes_list = self.neo_reader.get_group_channel_indexes() num_streams = len(channel_indexes_list) self.after_v10 = False elif hasattr(self.neo_reader, 'signal_streams_count'): # Neo >= 0.10.0 (not release yet in march 2021) num_streams = self.neo_reader.signal_streams_count() self.after_v10 = True else: raise ValueError('Strange neo version. Please upgrade your neo package: pip install --upgrade neo') assert num_streams <= 1, 'This file have several signal streams spikeextractors support only one streams' \ 'Maybe you can use option to select only one stream' # spikeextractor for units to be uV implicitly # check that units are V, mV or uV units = self.neo_reader.header['signal_channels']['units'] assert np.all(np.isin(units, ['V', 'mV', 'uV'])), 'Signal units no Volt compatible' self.additional_gain = np.ones(units.size, dtype='float') self.additional_gain[units == 'V'] = 1e6 self.additional_gain[units == 'mV'] = 1e3 self.additional_gain[units == 'uV'] = 1. self.additional_gain = self.additional_gain.reshape(1, -1) # Add channels properties header_channels = self.neo_reader.header['signal_channels'][slice(None)] self._neo_chan_ids = self.neo_reader.header['signal_channels']['id'] # In neo there is not guarantee that channel ids are unique. # for instance Blacrock can have several times the same chan_id # different sampling rate # so check it assert np.unique(self._neo_chan_ids).size == self._neo_chan_ids.size, 'In this format channel ids are not ' \ 'unique! 
Incompatible with SpikeInterface' try: channel_ids = [int(ch) for ch in self._neo_chan_ids] except Exception as e: warnings.warn("Could not parse channel ids to int: using linear channel map") channel_ids = list(np.arange(len(self._neo_chan_ids))) self._channel_ids = channel_ids gains = header_channels['gain'] * self.additional_gain[0] self.set_channel_gains(gains=gains, channel_ids=self._channel_ids) names = header_channels['name'] for i, ind in enumerate(self._channel_ids): self.set_channel_property(channel_id=ind, property_name='name', value=names[i]) self._noise_std = noise_std # Read channel groups by tetrode IDs self.set_channel_groups(groups=[ tetrode_id - 1 for tetrode_id in self.neo_reader.get_active_tetrode() for _ in range(4)]) header_channels = self.neo_reader.header['signal_channels'][slice(None)] names = header_channels['name'] channel_ids = self.get_channel_ids() for i, ind in enumerate(channel_ids): self.set_channel_property(channel_id=ind, property_name='name', value=names[i]) # Set channel gains for int8 .X Unit data gains = self.neo_reader._get_channel_gain(bytes_per_sample=1)[0:len(channel_ids)] self.set_channel_gains(gains, channel_ids=channel_ids) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): timebase_sr = int(self.neo_reader.file_parameters['unit']['timebase'].split(' ')[0]) samples_pre = int(self.neo_reader.file_parameters['set']['file_header']['pretrigSamps']) samples_post = int(self.neo_reader.file_parameters['set']['file_header']['spikeLockout']) sampling_rate = self.get_sampling_frequency() tcmap = self._get_tetrode_channel_table(channel_ids) traces = self._noise_std * np.random.randn(len(channel_ids), end_frame - start_frame) if return_scaled: traces = traces.astype(np.float32) else: traces = traces.astype(np.int8) # Loop through tetrodes and include requested channels in traces itrc = 0 for tetrode_id in np.unique(tcmap[:, 0]): channels_oi = tcmap[tcmap[:, 0] == 
tetrode_id, 2] waveforms = self.neo_reader._get_spike_raw_waveforms( block_index=0, seg_index=0, unit_index=tetrode_id - 1, # Tetrodes IDs are 1-indexed t_start=start_frame / sampling_rate, t_stop=end_frame / sampling_rate ) waveforms = waveforms[:, channels_oi, :] nch = len(channels_oi) spike_train = self.neo_reader._get_spike_timestamps( block_index=0, seg_index=0, unit_index=tetrode_id - 1, t_start=start_frame / sampling_rate, t_stop=end_frame / sampling_rate ) # Fill waveforms into traces timestamp by timestamp for t, wf in zip(spike_train, waveforms): t = int(t // (timebase_sr / sampling_rate)) # timestamps are sampled at higher frequency t = t - start_frame if (t - samples_pre < 0) and (t + samples_post > traces.shape[1]): traces[itrc:itrc + nch, :] = wf[:, samples_pre - t:traces.shape[1] - (t - samples_pre)] elif t - samples_pre < 0: traces[itrc:itrc + nch, :t + samples_post] = wf[:, samples_pre - t:] elif t + samples_post > traces.shape[1]: traces[itrc:itrc + nch, t - samples_pre:] = wf[:, :traces.shape[1] - (t - samples_pre)] else: traces[itrc:itrc + nch, t - samples_pre:t + samples_post] = wf itrc += nch return traces def get_num_frames(self): n = int(self.neo_reader.segment_t_stop(block_index=0, seg_index=0) * self.get_sampling_frequency()) if self.get_sampling_frequency() == 24000: n = n // 2 return n def get_sampling_frequency(self): return int(self.neo_reader.header['spike_channels'][0][-1]) def get_channel_ids(self): return self._channel_ids def _get_tetrode_channel_table(self, channel_ids): '''Create auxiliary np.array with the following columns: Tetrode ID, Channel ID, Channel ID within tetrode This is useful in `get_traces()` Parameters ---------- channel_ids : list List of channel ids to include in table Returns ------- np.array Rows = channels, columns = TetrodeID, ChannelID, ChannelID within Tetrode ''' active_tetrodes = self.neo_reader.get_active_tetrode() tcmap = np.zeros((len(active_tetrodes) * 4, 3), dtype=int) row_id = 0 for tetrode_id in 
[int(s[0].split(' ')[1]) for s in self.neo_reader.header['spike_channels']]: all_channel_ids = self.neo_reader._get_channel_from_tetrode(tetrode_id) for i in range(4): tcmap[row_id, 0] = int(tetrode_id) tcmap[row_id, 1] = int(all_channel_ids[i]) tcmap[row_id, 2] = int(i) row_id += 1 del_idx = [False if i in channel_ids else True for i in tcmap[:, 1]] return np.delete(tcmap, del_idx, axis=0) ================================================ FILE: spikeextractors/extractors/bindatrecordingextractor/__init__.py ================================================ from .bindatrecordingextractor import BinDatRecordingExtractor ================================================ FILE: spikeextractors/extractors/bindatrecordingextractor/bindatrecordingextractor.py ================================================ import shutil import numpy as np from pathlib import Path from typing import Union, Optional from spikeextractors import RecordingExtractor from spikeextractors.extraction_tools import read_binary, write_to_binary_dat_format, check_get_traces_args PathType = Union[str, Path] DtypeType = Union[str, np.dtype] ArrayType = Union[list, np.ndarray] OptionalDtypeType = Optional[DtypeType] OptionalArrayType = Optional[Union[np.ndarray, list]] class BinDatRecordingExtractor(RecordingExtractor): """ RecordingExtractor for a binary format Parameters ---------- file_path: str or Path Path to the binary file sampling_frequency: float The sampling frequncy numchan: int Number of channels dtype: str or dtype The dtype of the binary file time_axis: int The axis of the time dimension (default 0: F order) recording_channels: list (optional) A list of channel ids geom: array-like (optional) A list or array with channel locations file_offset: int (optional) Number of bytes in the file to offset by during memmap instantiation. 
gain: float or array-like (optional) The gain to apply to the traces channel_offset: float or array-like The offset to apply to the traces is_filtered: bool If True, the recording is assumed to be filtered """ extractor_name = 'BinDatRecording' has_default_locations = False has_unscaled = False installed = True is_writable = True mode = "file" installation_mesg = "" def __init__(self, file_path: PathType, sampling_frequency: float, numchan: int, dtype: DtypeType, time_axis: int = 0, recording_channels: Optional[list] = None, geom: Optional[ArrayType] = None, file_offset: Optional[float] = 0, gain: Optional[Union[float, ArrayType]] = None, channel_offset: Optional[Union[float, ArrayType]] = None, is_filtered: Optional[bool] = None): RecordingExtractor.__init__(self) self._datfile = Path(file_path) self._time_axis = time_axis self._dtype = np.dtype(dtype).name self._sampling_frequency = float(sampling_frequency) self._numchan = numchan self._geom = geom self._timeseries = read_binary(self._datfile, numchan, dtype, time_axis, file_offset) if is_filtered is not None: self.is_filtered = is_filtered else: self.is_filtered = False if recording_channels is not None: assert len(recording_channels) <= self._timeseries.shape[0], \ 'Provided recording channels have the wrong length' self._channels = recording_channels else: self._channels = list(range(self._timeseries.shape[0])) if len(self._channels) == self._timeseries.shape[0]: self._complete_channels = True else: assert max(self._channels) < self._timeseries.shape[0], "Channel ids exceed the number of " \ "available channels" self._complete_channels = False if geom is not None: self.set_channel_locations(self._geom) self.has_default_locations = True if 'numpy' in str(dtype): dtype_str = str(dtype).replace("", "") dtype_str = dtype_str.split('.')[1] else: dtype_str = str(dtype) if gain is not None: self.set_channel_gains(channel_ids=self.get_channel_ids(), gains=gain) self.has_unscaled = True if channel_offset is not None: 
self.set_channel_offsets(channel_offset) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'sampling_frequency': sampling_frequency, 'numchan': numchan, 'dtype': dtype_str, 'recording_channels': recording_channels, 'time_axis': time_axis, 'geom': geom, 'file_offset': file_offset, 'gain': gain, 'is_filtered': is_filtered} def get_channel_ids(self): return self._channels def get_num_frames(self): return self._timeseries.shape[1] def get_sampling_frequency(self): return self._sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): if self._complete_channels: if np.array_equal(channel_ids, self.get_channel_ids()): traces = self._timeseries[:, start_frame:end_frame] else: channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) if np.all(np.diff(channel_idxs) == 1): traces = self._timeseries[channel_idxs[0]:channel_idxs[0]+len(channel_idxs), start_frame:end_frame] else: # This block of the execution will return the data as an array, not a memmap traces = self._timeseries[channel_idxs, start_frame:end_frame] else: # in this case channel ids are actually indexes traces = self._timeseries[channel_ids, start_frame:end_frame] return traces @staticmethod def write_recording( recording: RecordingExtractor, save_path: PathType, time_axis: int = 0, dtype: OptionalDtypeType = None, **write_binary_kwargs ): """ Save the traces of a recording extractor in binary .dat format. Parameters ---------- recording : RecordingExtractor The recording extractor object to be saved in .dat format. save_path : str The path to the file. time_axis : int, optional If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. dtype : dtype Type of the saved data. Default float32. 
**write_binary_kwargs: keyword arguments for write_to_binary_dat_format() function """ write_to_binary_dat_format(recording, save_path, time_axis=time_axis, dtype=dtype, **write_binary_kwargs) ================================================ FILE: spikeextractors/extractors/biocamrecordingextractor/__init__.py ================================================ from .biocamrecordingextractor import BiocamRecordingExtractor ================================================ FILE: spikeextractors/extractors/biocamrecordingextractor/biocamrecordingextractor.py ================================================ from spikeextractors import RecordingExtractor from spikeextractors.extraction_tools import check_get_traces_args import numpy as np from pathlib import Path import ctypes try: import h5py HAVE_BIOCAM = True except ImportError: HAVE_BIOCAM = False class BiocamRecordingExtractor(RecordingExtractor): extractor_name = 'BiocamRecording' has_default_locations = True has_unscaled = False installed = HAVE_BIOCAM # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the BiocamRecordingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed def __init__(self, file_path, verbose=False, mea_pitch=42): assert self.installed, self.installation_mesg self._mea_pitch = mea_pitch self._recording_file = file_path self._rf, self._nFrames, self._samplingRate, self._nRecCh, self._chIndices, \ self._file_format, self._signalInv, self._positions, self._read_function = openBiocamFile( self._recording_file, self._mea_pitch, verbose) RecordingExtractor.__init__(self) self.set_channel_locations(self._positions) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'mea_pitch': mea_pitch, 'verbose': verbose} def __del__(self): self._rf.close() def get_channel_ids(self): return list(range(self._nRecCh)) def get_num_frames(self): return self._nFrames def get_sampling_frequency(self): return self._samplingRate 
@check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): data = self._read_function(self._rf, start_frame, end_frame, self.get_num_channels()) # transform to slice if possible if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1): channel_ids = slice(channel_ids[0], channel_ids[0]+len(channel_ids)) return data[:, channel_ids].T @staticmethod def write_recording(recording, save_path): # Convert to uV: # AnalogValue = MVOffset + DigitalValue * ADCCountsToMV # Where ADCCountsToMV is defined as: # ADCCountsToMV = SignalInversion * ((MaxVolt - MinVolt) / 2^BitDepth) # And MVOffset as: # MVOffset = SignalInversion * MinVolt # conversion back # DigitalValue = (AnalogValue - MVOffset)/ADCCountsToMV # we center at 2048 assert HAVE_BIOCAM, BiocamRecordingExtractor.installation_mesg M = recording.get_num_channels() N = recording.get_num_frames() rf = h5py.File(save_path, 'w') g = rf.create_group('3BData') dr = rf.create_dataset('3BData/Raw', (M * N,), dtype=int) dt = 50000 for i in range(N // dt): dr[M * i * dt:M * (i + 1) * dt] = recording.get_traces(range(M), i * dt, (i + 1) * dt).T.flatten() dr[M * (N // dt) * dt:] = recording.get_traces(range(M), (N // dt) * dt, N).T.flatten() g.attrs['Version'] = 101 rf.create_dataset('3BRecInfo/3BRecVars/MinVolt', data=[0]) rf.create_dataset('3BRecInfo/3BRecVars/MaxVolt', data=[1]) rf.create_dataset('3BRecInfo/3BRecVars/NRecFrames', data=[N]) rf.create_dataset('3BRecInfo/3BRecVars/SamplingRate', data=[recording.get_sampling_frequency()]) rf.create_dataset('3BRecInfo/3BRecVars/SignalInversion', data=[1]) rf.create_dataset('3BRecInfo/3BMeaChip/NCols', data=[M]) r = recording.get_channel_locations()[:, 0] c = recording.get_channel_locations()[:, 1] d = np.ndarray((1, len(r)), dtype=[('Row', ' 0, "'smrx_channel_ids' cannot be an empty list!" 
super().__init__() # Open smrx file self._recording_file_path = file_path self._recording_file = sp.SonFile(sName=str(file_path), bReadOnly=True) if self._recording_file.GetOpenError() != 0: raise ValueError(f'Error opening file:', sp.GetErrorString(self._recording_file.GetOpenError())) # Map Recording channel_id to smrx index / test for invalid indexes / # get channel info / set channel gains self._channelid_to_smrxind = dict() self._channel_smrxinfo = dict() self._channel_names = [] gains = [] for i, ind in enumerate(smrx_channel_ids): if self._recording_file.ChannelType(ind) == sp.DataType.Off: raise ValueError(f'Channel {ind} is type Off and cannot be used') self._channelid_to_smrxind[i] = ind self._channel_smrxinfo[i] = get_channel_info( f=self._recording_file, smrx_ch_ind=ind ) # Set channel gains: http://ced.co.uk/img/Spike10.pdf # from 16-bit encoded int / to ADC +-5V input / to measured Volts gain = self._channel_smrxinfo[i]['scale'] / 6553.6 gain *= 1000 # mV --> uV gains.append(gain) self._channel_names.append(self._channel_smrxinfo[i]['title']) # Set gains self.set_channel_gains(gains=gains) self.has_unscaled = True rate0 = self._channel_smrxinfo[0]['rate'] for chan, info in self._channel_smrxinfo.items(): assert info['rate'] == rate0, "Inconsistency between 'sampling_frequency' of different channels. The " \ "extractor only supports channels with the same 'rate'" # Set self._times times = (self._channel_smrxinfo[0]['frame_offset'] + np.arange(self.get_num_frames())) / self.get_sampling_frequency() self.set_times(times=times) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'smrx_channel_ids': smrx_channel_ids} @property def channel_names(self): return deepcopy(self._channel_names) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): """This function extracts and returns a trace from the recorded data from the given channels ids and the given start and end frame. 
It will return traces from within three ranges: [start_frame, start_frame+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_recording_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_recording_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Traces are returned in a 2D array that contains all of the traces from each channel with dimensions (num_channels x num_frames). In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. Parameters ---------- start_frame: int The starting frame of the trace to be returned (inclusive) end_frame: int The ending frame of the trace to be returned (exclusive) channel_ids: array_like A list or 1D array of channel ids (ints) from which each trace will be extracted return_scaled: bool If True, traces are returned after scaling (using gain/offset). If False, the traces are returned as integers Returns ---------- traces: numpy.ndarray A 2D array that contains all of the traces from each channel. Dimensions are: (num_channels x num_frames) """ recordings = np.vstack( [get_channel_data( f=self._recording_file, smrx_ch_ind=self._channelid_to_smrxind[i], start_frame=start_frame, end_frame=end_frame ) for i in channel_ids] ) return recordings def get_num_frames(self): """This function returns the number of frames in the recording Returns ------- num_frames: int Number of frames in the recording (duration of recording) """ return 1 + int(self._channel_smrxinfo[0]['max_time'] / self._channel_smrxinfo[0]['divide'] - self._channel_smrxinfo[0]['frame_offset']) def get_sampling_frequency(self): """This function returns the sampling frequency in units of Hz. Returns ------- fs: float Sampling frequency of the recordings in Hz """ return self._channel_smrxinfo[0]['rate'] def get_channel_ids(self): """Returns the list of channel ids. 
If not specified, the range from 0 to num_channels - 1 is returned. Returns ------- channel_ids: list Channel list """ return list(self._channelid_to_smrxind.keys()) @staticmethod def get_all_channels_info(file_path): """ Extract info from all channels in the smrx file. Returns a dictionary with valid smrx channel indexes as keys and the respective channel information as value. Parameters: ----------- f: str Path to .smrx file """ f = sp.SonFile(sName=str(file_path), bReadOnly=True) n_channels = f.MaxChannels() return { i: get_channel_info(f, i) for i in range(n_channels) if f.ChannelType(i) != sp.DataType.Off } ================================================ FILE: spikeextractors/extractors/cedextractors/utils.py ================================================ import numpy as np try: from sonpy import lib as sp # Data storage and function finder DataReadFunctions = { sp.DataType.Adc: sp.SonFile.ReadInts, sp.DataType.EventFall: sp.SonFile.ReadEvents, sp.DataType.EventRise: sp.SonFile.ReadEvents, sp.DataType.EventBoth: sp.SonFile.ReadEvents, sp.DataType.Marker: sp.SonFile.ReadMarkers, sp.DataType.AdcMark: sp.SonFile.ReadWaveMarks, sp.DataType.RealMark: sp.SonFile.ReadRealMarks, sp.DataType.TextMark: sp.SonFile.ReadTextMarks, sp.DataType.RealWave: sp.SonFile.ReadFloats } except: pass # Get the saved time and date # f.GetTimeDate() def get_channel_info(f, smrx_ch_ind): """ Extract info from smrx files Parameters: ----------- f: str SonFile object. smrx_ch_ind: int Index of smrx channel. Does not match necessarily with extractor id. 
""" nMax = 1 + int(f.ChannelMaxTime(smrx_ch_ind) / f.ChannelDivide(smrx_ch_ind)) frame_offset = f.FirstTime(chan=smrx_ch_ind, tFrom=0, tUpto=nMax) / f.ChannelDivide(smrx_ch_ind) ch_info = { 'type': f.ChannelType(smrx_ch_ind), # Get the channel kind 'ch_number': f.PhysicalChannel(smrx_ch_ind), # Get the physical channel number associated with this channel 'title': f.GetChannelTitle(smrx_ch_ind), # Get the channel title 'ideal_rate': f.GetIdealRate(smrx_ch_ind), # Get the requested channel ideal rate 'rate': 1 / (f.GetTimeBase() * f.ChannelDivide(smrx_ch_ind)), # Get the requested channel real rate 'max_time': f.ChannelMaxTime(smrx_ch_ind), # Get the time of the last item in the channel (in clock ticks) 'divide': f.ChannelDivide(smrx_ch_ind), # Get the waveform sample interval in file clock ticks 'time_base': f.GetTimeBase(), # Get how many seconds there are per clock tick 'frame_offset': frame_offset, # Get frame offset 'scale': f.GetChannelScale(smrx_ch_ind), # Get the channel scale 'offset': f.GetChannelOffset(smrx_ch_ind), # Get the channel offset 'unit': f.GetChannelUnits(smrx_ch_ind), # Get the channel units 'y_range': f.GetChannelYRange(smrx_ch_ind), # Get a suggested Y range for the channel 'comment': f.GetChannelComment(smrx_ch_ind), # Get the comment associated with a channel 'size_bytes:': f.ChannelBytes(smrx_ch_ind), # Get an estimate of the data bytes stored for the channel } return ch_info def get_channel_data(f, smrx_ch_ind, start_frame=0, end_frame=None): """ Extract info from smrx files Parameters: ----------- f: str SonFile object. smrx_ch_ind: int Index of smrx channel. Does not match necessarily with extractor id. start_frame: int The starting frame of the trace to be returned (inclusive). end_frame: int The ending frame of the trace to be returned (exclusive). 
""" if end_frame is None: end_frame = 1 + int(f.ChannelMaxTime(smrx_ch_ind) / f.ChannelDivide(smrx_ch_ind)) nMax = 1 + int(f.ChannelMaxTime(smrx_ch_ind) / f.ChannelDivide(smrx_ch_ind)) frame_offset = int(f.FirstTime(chan=smrx_ch_ind, tFrom=0, tUpto=nMax) / f.ChannelDivide(smrx_ch_ind)) start_frame += frame_offset end_frame += frame_offset data = DataReadFunctions[f.ChannelType(smrx_ch_ind)]( self=f, chan=smrx_ch_ind, nMax=nMax, tFrom=int(start_frame * f.ChannelDivide(smrx_ch_ind)), tUpto=int(end_frame * f.ChannelDivide(smrx_ch_ind)) ) return np.array(data) ================================================ FILE: spikeextractors/extractors/cellexplorersortingextractor/__init__.py ================================================ from .cellexplorersortingextractor import CellExplorerSortingExtractor ================================================ FILE: spikeextractors/extractors/cellexplorersortingextractor/cellexplorersortingextractor.py ================================================ from spikeextractors import SortingExtractor import numpy as np from pathlib import Path from spikeextractors.extraction_tools import check_get_unit_spike_train from typing import Union, Optional try: import scipy.io import hdf5storage HAVE_SCIPY_AND_HDF5STORAGE = True except ImportError: HAVE_SCIPY_AND_HDF5STORAGE = False PathType = Union[str, Path] OptionalPathType = Optional[PathType] class CellExplorerSortingExtractor(SortingExtractor): """ Extracts spiking information from .mat files stored in the CellExplorer format. Spike times are stored in units of seconds. Parameters ---------- spikes_matfile_path : PathType Path to the sorting_id.spikes.cellinfo.mat file. 
""" extractor_name = "CellExplorerSortingExtractor" installed = HAVE_SCIPY_AND_HDF5STORAGE is_writable = True mode = "file" installation_mesg = "To use the CellExplorerSortingExtractor install scipy and hdf5storage: \n\n pip install scipy\n\n and \n\n pip install hdf5 storage \n\n" def __init__(self, spikes_matfile_path: PathType, session_info_matfile_path: OptionalPathType=None, sampling_frequency: Optional[float] = None): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) spikes_matfile_path = Path(spikes_matfile_path) assert ( spikes_matfile_path.is_file() ), f"The spikes_matfile_path ({spikes_matfile_path}) must exist!" if sampling_frequency is None: folder_path = spikes_matfile_path.parent sorting_id = spikes_matfile_path.name.split(".")[0] if session_info_matfile_path is None: session_info_matfile_path = folder_path / f"{sorting_id}.sessionInfo.mat" assert ( session_info_matfile_path.is_file() ), f"No {sorting_id}.sessionInfo.mat file found in the folder!" try: session_info_mat = scipy.io.loadmat(file_name=str(session_info_matfile_path)) self.read_session_info_with_scipy = True except NotImplementedError: session_info_mat = hdf5storage.loadmat(file_name=str(session_info_matfile_path)) self.read_session_info_with_scipy = False assert session_info_mat["sessionInfo"]["rates"][0][0]["wideband"], ( "The sesssionInfo.mat file must contain " "a 'sessionInfo' struct with field 'rates' containing field 'wideband' to extract the sampling frequency!" 
) if self.read_session_info_with_scipy: self._sampling_frequency = float( session_info_mat["sessionInfo"]["rates"][0][0]["wideband"][0][0][0][0] ) # careful not to confuse it with the lfpsamplingrate; reported in units Hz else: self._sampling_frequency = float( session_info_mat["sessionInfo"]["rates"][0][0]["wideband"][0][0] ) # careful not to confuse it with the lfpsamplingrate; reported in units Hz else: self._sampling_frequency = sampling_frequency try: spikes_mat = scipy.io.loadmat(file_name=str(spikes_matfile_path)) self.read_spikes_info_with_scipy = True except NotImplementedError: spikes_mat = hdf5storage.loadmat(file_name=str(spikes_matfile_path)) self.read_spikes_info_with_scipy = False assert np.all( np.isin(["UID", "times"], spikes_mat["spikes"].dtype.names) ), "The spikes.cellinfo.mat file must contain a 'spikes' struct with fields 'UID' and 'times'!" # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames # Rounding is necessary to prevent data loss from int-casting floating point errors if self.read_spikes_info_with_scipy: self._unit_ids = np.asarray(spikes_mat["spikes"]["UID"][0][0][0], dtype=int) self._spiketrains = [ (np.array([y[0] for y in x]) * self._sampling_frequency).round().astype(int) for x in spikes_mat["spikes"]["times"][0][0][0] ] else: self._unit_ids = np.asarray(spikes_mat["spikes"]["UID"][0][0], dtype=int) self._spiketrains = [ (np.array([y[0] for y in x]) * self._sampling_frequency).round().astype(int) for x in spikes_mat["spikes"]["times"][0][0] ] self._kwargs = dict(spikes_matfile_path=str(spikes_matfile_path.absolute())) def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spiketrains[self.get_unit_ids().index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] @staticmethod def write_sorting(sorting: SortingExtractor, save_path: 
PathType): assert save_path.suffixes == [ ".spikes", ".cellinfo", ".mat", ], "The save_path must correspond to the CellExplorer format of sorting_id.spikes.cellinfo.mat!" base_path = save_path.parent sorting_id = save_path.name.split(".")[0] session_info_save_path = base_path / f"{sorting_id}.sessionInfo.mat" spikes_save_path = save_path base_path.mkdir(parents=True, exist_ok=True) sampling_frequency = sorting.get_sampling_frequency() session_info_mat_dict = dict( sessionInfo=dict(rates=dict(wideband=sampling_frequency)) ) scipy.io.savemat(file_name=session_info_save_path, mdict=session_info_mat_dict) spikes_mat_dict = dict( spikes=dict( UID=sorting.get_unit_ids(), times=[ [[y / sampling_frequency] for y in x] for x in sorting.get_units_spike_train() ], ) ) # If, in the future, it is ever desired to allow this to write unit properties, they must conform # to the format here: https://cellexplorer.org/datastructure/data-structure-and-format/ scipy.io.savemat(file_name=spikes_save_path, mdict=spikes_mat_dict) ================================================ FILE: spikeextractors/extractors/combinatosortingextractor/__init__.py ================================================ from .combinatosortingextractor import CombinatoSortingExtractor ================================================ FILE: spikeextractors/extractors/combinatosortingextractor/combinatosortingextractor.py ================================================ from pathlib import Path import numpy as np from spikeextractors import SortingExtractor from spikeextractors.extraction_tools import check_get_unit_spike_train from typing import Union try: import h5py HAVE_H5PY = True except ImportError: HAVE_H5PY = False PathType = Union[str, Path] class CombinatoSortingExtractor(SortingExtractor): extractor_name = 'CombinatoSorting' installation_mesg = "" # error message when not installed installed = HAVE_H5PY is_writable = False installation_mesg = "To use the CombinatoSortingExtractor install h5py: \n\n pip 
install h5py\n\n" # error message when not installed def __init__(self, datapath: PathType, sampling_frequency=None, user='simple',det_sign = 'both'): super().__init__() datapath = Path(datapath) assert datapath.is_dir(), 'Folder {} doesn\'t exist'.format(datapath) if sampling_frequency is None: h5_path = str(datapath) + '.h5' if Path(h5_path).exists(): with h5py.File(h5_path, mode='r') as f: sampling_frequency = f['sr'][0] self.set_sampling_frequency(sampling_frequency) det_file = str(datapath / Path('data_' + datapath.stem + '.h5')) sort_cat_files = [] for sign in ['neg', 'pos']: if det_sign in ['both', sign]: sort_cat_file = datapath / Path('sort_{}_{}/sort_cat.h5'.format(sign,user)) if sort_cat_file.exists(): sort_cat_files.append((sign, str(sort_cat_file))) unit_counter = 0 self._spike_trains = {} metadata = {} unsorted = [] fdet = h5py.File(det_file, mode='r') for sign, sfile in sort_cat_files: with h5py.File(sfile, mode='r') as f: sp_class = f['classes'][()] gaux = f['groups'][()] groups = {g:gaux[gaux[:, 1] == g, 0] for g in np.unique(gaux[:, 1])} #array of classes per group group_type = {group: g_type for group,g_type in f['types'][()]} sp_index = f['index'][()] times_css = fdet[sign]['times'][()] for gr, cls in groups.items(): if group_type[gr] == -1: #artifacts continue elif group_type[gr] == 0: #unsorted unsorted.append(np.rint(times_css[sp_index[np.isin(sp_class,cls)]] * (sampling_frequency/1000))) continue unit_counter = unit_counter + 1 self._spike_trains[unit_counter] = np.rint(times_css[sp_index[np.isin(sp_class, cls)]] * (sampling_frequency / 1000)) metadata[unit_counter] = {'det_sign': sign, 'group_type': 'single-unit' if group_type[gr] else 'multi-unit'} fdet.close() self._unsorted_train = np.array([]) if len(unsorted) == 1: self._unsorted_train = unsorted[0] elif len(unsorted) == 2: #unsorted in both signs self._unsorted_train = np.sort(np.concatenate(unsorted), kind='mergesort') self._unit_ids = list(range(1, unit_counter+1)) for u in 
self._unit_ids: for prop,value in metadata[u].items(): self.set_unit_property(u, prop, value) def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) start_frame = start_frame or 0 end_frame = end_frame or np.infty st = self._spike_trains[unit_id] return st[(st >= start_frame) & (st < end_frame)] def get_unsorted_spike_train(self, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) start_frame = start_frame or 0 end_frame = end_frame or np.infty u = self._unsorted_train return u[(u >= start_frame) & (u < end_frame)] ================================================ FILE: spikeextractors/extractors/exdirextractors/__init__.py ================================================ from .exdirextractors import ExdirRecordingExtractor, ExdirSortingExtractor ================================================ FILE: spikeextractors/extractors/exdirextractors/exdirextractors.py ================================================ from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor import numpy as np from pathlib import Path from copy import copy from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train try: import exdir import exdir.plugins.quantities import quantities as pq HAVE_EXDIR = True except ImportError: HAVE_EXDIR = False class ExdirRecordingExtractor(RecordingExtractor): extractor_name = 'ExdirRecording' has_default_locations = False has_unscaled = False installed = HAVE_EXDIR # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "To use the ExdirExtractors run:\n\n pip install exdir\n\n" # error message when not installed def __init__(self, folder_path): assert self.installed, self.installation_mesg self._exdir_file = folder_path 
exdir_group = exdir.File(folder_path, plugins=[exdir.plugins.quantities]) self._recordings = exdir_group['acquisition']['timeseries'] self._sampling_frequency = float(self._recordings.attrs['sample_rate'].rescale('Hz').magnitude) self._num_channels = self._recordings.shape[0] self._num_timepoints = self._recordings.shape[1] RecordingExtractor.__init__(self) self._kwargs = {'folder_path': str(Path(folder_path).absolute())} def get_channel_ids(self): return list(range(self._num_channels)) def get_num_frames(self): return self._num_timepoints def get_sampling_frequency(self): return self._sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): return self._recordings.data[np.array(channel_ids), start_frame:end_frame] @staticmethod def write_recording(recording, save_path, lfp=False, mua=False): assert HAVE_EXDIR, ExdirRecordingExtractor.installation_mesg channel_ids = recording.get_channel_ids() raw = recording.get_traces() exdir_group = exdir.File(save_path, plugins=[exdir.plugins.quantities]) if not lfp and not mua: acq = exdir_group.require_group('acquisition') timeseries = acq.require_dataset('timeseries', data=raw) timeseries.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz timeseries.attrs['electrode_identities'] = np.array(channel_ids) return elif lfp: ephys = exdir_group.require_group('processing').require_group('electrophysiology') ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz if 'group' in recording.get_shared_channel_property_names(): channel_groups = np.unique(recording.get_channel_groups()) else: channel_groups = [0] if len(channel_groups) == 1: chan = 0 ch_group = ephys.require_group('channel_group_' + str(chan)) lfp_group = ch_group.require_group('LFP') ch_group.attrs['electrode_group_id'] = chan ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids()) ch_group.attrs['electrode_idx'] = 
np.arange(len(recording.get_channel_ids())) ch_group.attrs['start_time'] = 0 * pq.s ch_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s for i_c, ch in enumerate(recording.get_channel_ids()): ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch)) ts_group.attrs['electrode_group_id'] = chan ts_group.attrs['electrode_identity'] = ch ts_group.attrs['num_samples'] = recording.get_num_frames() ts_group.attrs['electrode_idx'] = i_c ts_group.attrs['start_time'] = 0 * pq.s ts_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch])) data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data.attrs['unit'] = pq.uV else: channel_groups = np.unique(recording.get_channel_groups()) for chan in channel_groups: ch_group = ephys.require_group('channel_group_' + str(chan)) lfp_group = ch_group.require_group('LFP') ch_group.attrs['electrode_group_id'] = chan ch_group.attrs['electrode_identities'] = np.array([ch for ch in recording.get_channel_ids() if recording.get_channel_groups(ch) == chan]) ch_group.attrs['electrode_idx'] = np.array([i_c for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) ch_group.attrs['start_time'] = 0 * pq.s ch_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s for i_c, ch in enumerate(recording.get_channel_ids()): if recording.get_channel_groups(ch) == chan: ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch)) ts_group.attrs['electrode_group_id'] = chan ts_group.attrs['electrode_identity'] = ch ts_group.attrs['num_samples'] = recording.get_num_frames() ts_group.attrs['electrode_idx'] = i_c ts_group.attrs['start_time'] = 0 * pq.s ts_group.attrs['stop_time'] = 
recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch])) data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data.attrs['unit'] = pq.uV return elif mua: ephys = exdir_group.require_group('processing').require_group('electrophysiology') ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz if 'group' in recording.get_shared_channel_property_names(): channel_groups = np.unique(recording.get_channel_groups()) else: channel_groups = [0] if len(channel_groups) == 1: chan = 0 ch_group = ephys.require_group('channel_group_' + str(chan)) mua_group = ch_group.require_group('MUA') ch_group.attrs['electrode_group_id'] = chan ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids()) ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids())) ch_group.attrs['start_time'] = 0 * pq.s ch_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s for i_c, ch in enumerate(recording.get_channel_ids()): ts_group = mua_group.require_group('MUA_timeseries_' + str(ch)) ts_group.attrs['electrode_group_id'] = chan ts_group.attrs['electrode_identity'] = ch ts_group.attrs['num_samples'] = recording.get_num_frames() ts_group.attrs['electrode_idx'] = i_c ts_group.attrs['start_time'] = 0 * pq.s ts_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch])) data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data.attrs['unit'] = pq.uV else: channel_groups = np.unique(recording.get_channel_groups()) for chan in channel_groups: ch_group = ephys.require_group('channel_group_' 
+ str(chan)) mua_group = ch_group.require_group('MUA') ch_group.attrs['electrode_group_id'] = chan ch_group.attrs['electrode_identities'] = np.array([ch for ch in recording.get_channel_ids() if recording.get_channel_groups(ch) == chan]) ch_group.attrs['electrode_idx'] = np.array([i_c for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) ch_group.attrs['start_time'] = 0 * pq.s ch_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s for i_c, ch in enumerate(recording.get_channel_ids()): if recording.get_channel_groups(ch) == chan: ts_group = mua_group.require_group('MUA_timeseries_' + str(ch)) ts_group.attrs['electrode_group_id'] = chan ts_group.attrs['electrode_identity'] = ch ts_group.attrs['num_samples'] = recording.get_num_frames() ts_group.attrs['electrode_idx'] = i_c ts_group.attrs['start_time'] = 0 * pq.s ts_group.attrs['stop_time'] = recording.get_num_frames() / \ float(recording.get_sampling_frequency()) * pq.s ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch])) data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz data.attrs['unit'] = pq.uV class ExdirSortingExtractor(SortingExtractor): extractor_name = 'ExdirSorting' installed = HAVE_EXDIR # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "To use the ExdirExtractors run:\n\n pip install exdir\n\n" # error message when not installed def __init__(self, folder_path, sampling_frequency=None, channel_group=None, load_waveforms=False): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self._exdir_file = folder_path exdir_group = exdir.File(folder_path, plugins=exdir.plugins.quantities) electrophysiology = None sf = copy(sampling_frequency) if 'processing' in exdir_group.keys(): if 'electrophysiology' in 
exdir_group['processing']: electrophysiology = exdir_group['processing']['electrophysiology'] ephys_attrs = electrophysiology.attrs if 'sample_rate' in ephys_attrs: sf = ephys_attrs['sample_rate'] else: if sf is None: raise Exception("Sampling rate information not found. Please provide it with the 'sampling_frequency' " "argument") else: sf = sf * pq.Hz self._sampling_frequency = float(sf.rescale('Hz').magnitude) if electrophysiology is None: raise Exception("'electrophysiology' group not found!") self._unit_ids = [] current_unit = 1 self._spike_trains = [] for chan_name, channel in electrophysiology.items(): if 'channel' in chan_name: group = int(chan_name.split('_')[-1]) if channel_group is not None: if group != channel_group: continue if load_waveforms: if 'Clustering' in channel.keys() and 'EventWaveform' in channel.keys(): clustering = channel.require_group('Clustering') eventwaveform = channel.require_group('EventWaveform') nums = clustering['nums'].data waveforms = eventwaveform.require_group('waveform_timeseries')['data'].data if 'UnitTimes' in channel.keys(): for unit, unit_times in channel['UnitTimes'].items(): self._unit_ids.append(current_unit) self._spike_trains.append((unit_times['times'].data.rescale('s') * sf).magnitude) attrs = unit_times.attrs for k, v in attrs.items(): self.set_unit_property(current_unit, k, v) if load_waveforms: unit_idxs = np.where(nums == int(unit)) wf = waveforms[unit_idxs] self.set_unit_spike_features(current_unit, 'waveforms', wf) current_unit += 1 self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'sampling_frequency': sampling_frequency, 'channel_group': channel_group, 'load_waveforms': load_waveforms} def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spike_trains[self._unit_ids.index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return np.rint(times[inds]).astype(int) 
@staticmethod def write_sorting(sorting, save_path, recording=None, sampling_frequency=None, save_waveforms=False, verbose=False): assert HAVE_EXDIR, ExdirSortingExtractor.installation_mesg if sampling_frequency is None and recording is None: raise Exception("Provide 'sampling_frequency' argument (Hz)") else: if recording is None: sampling_frequency = sampling_frequency * pq.Hz else: sampling_frequency = recording.get_sampling_frequency() * pq.Hz exdir_group = exdir.File(save_path, plugins=exdir.plugins.quantities) ephys = exdir_group.require_group('processing').require_group('electrophysiology') ephys.attrs['sample_rate'] = sampling_frequency if 'group' in sorting.get_shared_unit_property_names(): channel_groups = np.unique([sorting.get_unit_property(unit, 'group') for unit in sorting.get_unit_ids()]) else: channel_groups = [0] if len(channel_groups) == 1 and channel_groups[0] == 0: chan = 0 if verbose: print("Single group: ", chan) ch_group = ephys.require_group('channel_group_' + str(chan)) try: del ch_group['UnitTimes'] del ch_group['EventWaveform'] del ch_group['Clustering'] except Exception as e: pass unittimes = ch_group.require_group('UnitTimes') unit_stop_time = np.max( [(np.max(sorting.get_unit_spike_train(u).astype(float) / sampling_frequency).rescale('s')) for u in sorting.get_unit_ids()]) * pq.s recording_stop_time = None if recording is not None: ch_group.attrs['electrode_group_id'] = chan ch_group.attrs['electrode_identities'] = np.array([]) ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids())) ch_group.attrs['start_time'] = 0 * pq.s recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s unittimes.attrs['electrode_group_id'] = chan unittimes.attrs['electrode_identities'] = np.array([]) unittimes.attrs['electrode_idx'] = np.array(recording.get_channel_ids()) unittimes.attrs['start_time'] = 0 * pq.s ch_group.attrs['sample_rate'] = sampling_frequency if recording_stop_time is not None: 
unittimes.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time ch_group.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time nums = np.array([]) timestamps = np.array([]) waveforms = np.array([]) for unit in sorting.get_unit_ids(): unit_group = unittimes.require_group(str(unit)) unit_group.require_dataset('times', data=(sorting.get_unit_spike_train(unit).astype(float) / sampling_frequency).rescale('s')) unit_group.attrs['cluster_group'] = 'unsorted' unit_group.attrs['group_id'] = chan unit_group.attrs['name'] = 'unit #' + str(unit) timestamps = np.concatenate((timestamps, (sorting.get_unit_spike_train(unit).astype(float) / sampling_frequency).rescale('s'))) nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit)))) if 'waveforms' in sorting.get_unit_spike_feature_names(unit): if len(waveforms) == 0: waveforms = sorting.get_unit_spike_features(unit, 'waveforms') else: waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms'))) if save_waveforms: if verbose: print("Saving EventWaveforms") if 'waveforms' in sorting.get_shared_unit_spike_feature_names(): eventwaveform = ch_group.require_group('EventWaveform') waveform_ts = eventwaveform.require_group('waveform_timeseries') data = waveform_ts.require_dataset('data', data=waveforms) waveform_ts.attrs['electrode_group_id'] = chan data.attrs['num_samples'] = len(waveforms) data.attrs['sample_rate'] = sampling_frequency data.attrs['unit'] = pq.dimensionless times = waveform_ts.require_dataset('timestamps', data=timestamps) times.attrs['num_samples'] = len(timestamps) times.attrs['unit'] = pq.s if recording is not None: waveform_ts.attrs['electrode_identities'] = np.array([]) waveform_ts.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids())) waveform_ts.attrs['start_time'] = 0 * pq.s if recording_stop_time is not None: waveform_ts.attrs['stop_time'] = 
recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time waveform_ts.attrs['sample_rate'] = sampling_frequency waveform_ts.attrs['sample_length'] = waveforms.shape[1] waveform_ts.attrs['num_samples'] = len(waveforms) if verbose: print("Saving Clustering") clustering = ch_group.require_group('Clustering') ts = clustering.require_dataset('timestamps', data=timestamps * pq.s) ts.attrs['num_samples'] = len(timestamps) ts.attrs['unit'] = pq.s ns = clustering.require_dataset('nums', data=nums) ns.attrs['num_samples'] = len(nums) cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids())) cn.attrs['num_samples'] = len(sorting.get_unit_ids()) else: # remove preexisten spike sorting data max_group = 10 for chan in np.arange(max_group): if 'channel_group_' + str(chan) in ephys.keys(): if verbose: print('Removing channel', chan, 'info') ch_group = ephys.require_group('channel_group_' + str(chan)) try: del ch_group['UnitTimes'] del ch_group['EventWaveform'] del ch_group['Clustering'] except Exception as e: pass channel_groups = np.unique([sorting.get_unit_property(unit, 'group') for unit in sorting.get_unit_ids()]) for chan in channel_groups: if verbose: print("Group: ", chan) ch_group = ephys.require_group('channel_group_' + str(chan)) unittimes = ch_group.require_group('UnitTimes') unit_stop_time = np.max( [(np.max(sorting.get_unit_spike_train(u).astype(float) / sampling_frequency).rescale('s')) for u in sorting.get_unit_ids()]) * pq.s recording_stop_time = None if recording is not None: unittimes.attrs['electrode_group_id'] = chan unittimes.attrs['electrode_identities'] = np.array([]) unittimes.attrs['electrode_idx'] = np.array( [ch for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) unittimes.attrs['start_time'] = 0 * pq.s recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s ch_group.attrs['electrode_group_id'] = chan 
ch_group.attrs['electrode_identities'] = np.array( [i_c for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) ch_group.attrs['electrode_idx'] = np.array( [i_c for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) ch_group.attrs['start_time'] = 0 * pq.s ch_group.attrs['sample_rate'] = sampling_frequency if recording_stop_time is not None: unittimes.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time ch_group.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time nums = np.array([]) timestamps = np.array([]) waveforms = np.array([]) for unit in sorting.get_unit_ids(): if sorting.get_unit_property(unit, 'group') == chan: if verbose: print("Unit: ", unit) unit_group = unittimes.require_group(str(unit)) unit_group.require_dataset('times', data=(sorting.get_unit_spike_train(unit).astype(float) / sampling_frequency).rescale('s')) unit_group.attrs['cluster_group'] = 'unsorted' unit_group.attrs['group_id'] = chan unit_group.attrs['name'] = 'unit #' + str(unit) timestamps = np.concatenate((timestamps, (sorting.get_unit_spike_train(unit).astype(float) / sampling_frequency).rescale('s'))) nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit)))) if 'waveforms' in sorting.get_unit_spike_feature_names(unit): if len(waveforms) == 0: waveforms = sorting.get_unit_spike_features(unit, 'waveforms') else: waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms'))) if save_waveforms: if verbose: print("Saving EventWaveforms") if 'waveforms' in sorting.get_shared_unit_spike_feature_names(): eventwaveform = ch_group.require_group('EventWaveform') waveform_ts = eventwaveform.require_group('waveform_timeseries') data = waveform_ts.require_dataset('data', data=waveforms) data.attrs['num_samples'] = len(waveforms) data.attrs['sample_rate'] = 
sampling_frequency data.attrs['unit'] = pq.dimensionless times = waveform_ts.require_dataset('timestamps', data=timestamps) times.attrs['num_samples'] = len(timestamps) times.attrs['unit'] = pq.s waveform_ts.attrs['electrode_group_id'] = chan if recording is not None: waveform_ts.attrs['electrode_identities'] = np.array([]) waveform_ts.attrs['electrode_idx'] = np.array([ch for i_c, ch in enumerate(recording.get_channel_ids()) if recording.get_channel_groups(ch) == chan]) waveform_ts.attrs['start_time'] = 0 * pq.s if recording_stop_time is not None: waveform_ts.attrs[ 'stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \ else unit_stop_time waveform_ts.attrs['sample_rate'] = sampling_frequency waveform_ts.attrs['sample_length'] = waveforms.shape[1] waveform_ts.attrs['num_samples'] = len(waveforms) if verbose: print("Saving Clustering") clustering = ephys.require_group('channel_group_' + str(chan)).require_group('Clustering') ts = clustering.require_dataset('timestamps', data=timestamps * pq.s) ts.attrs['num_samples'] = len(timestamps) ts.attrs['unit'] = pq.s ns = clustering.require_dataset('nums', data=nums) ns.attrs['num_samples'] = len(nums) cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids())) cn.attrs['num_samples'] = len(sorting.get_unit_ids()) ================================================ FILE: spikeextractors/extractors/hdsortsortingextractor/__init__.py ================================================ from .hdsortsortingextractor import HDSortSortingExtractor ================================================ FILE: spikeextractors/extractors/hdsortsortingextractor/hdsortsortingextractor.py ================================================ from pathlib import Path from typing import Union import numpy as np import sys import os from spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor from spikeextractors.extraction_tools import check_get_unit_spike_train PathType 
= Union[str, Path] convert_cell_array_to_struct_code = """ hdsortOutput = load(fileName); hdsortOutput.Units = [hdsortOutput.Units{:}]; Units = hdsortOutput.Units; MultiElectrode = hdsortOutput.MultiElectrode; noiseStd = hdsortOutput.noiseStd; samplingRate = hdsortOutput.samplingRate; save(fileName, 'Units', 'MultiElectrode', 'noiseStd', 'samplingRate'); """ class HDSortSortingExtractor(MATSortingExtractor): extractor_name = "HDSortSortingExtractor" def __init__(self, file_path: PathType, keep_good_only: bool = True): super().__init__(file_path) if not self._old_style_mat: _units = self._data['Units'] units = _parse_units(self._data, _units) # Extracting MutliElectrode field by field: _ME = self._data["MultiElectrode"] multi_electrode = dict((k, _ME.get(k)[()]) for k in _ME.keys()) # Extracting sampling_frequency: sr = self._data["samplingRate"] self._sampling_frequency = float(_squeeze_ds(sr)) # Remove noise units if necessary: if keep_good_only: units = [unit for unit in units if unit["ID"].flatten()[0].astype(int) % 1000 != 0] if 'sortingInfo' in self._data.keys(): info = self._data["sortingInfo"] start_frame = _squeeze_ds(info['startTimes']) self.start_frame = int(start_frame) else: self.start_frame = 0 else: _units = self._getfield('Units').squeeze() fields = _units.dtype.fields.keys() units = [] for unit in _units: unit_dict = {} for f in fields: unit_dict[f] = unit[f] units.append(unit_dict) sr = self._getfield("samplingRate") self._sampling_frequency = float(_squeeze_ds(sr)) _ME = self._data["MultiElectrode"] multi_electrode = dict((k, _ME[k][0][0].T) for k in _ME.dtype.fields.keys()) # Remove noise units if necessary: if keep_good_only: units = [unit for unit in units if unit["ID"].flatten()[0].astype(int) % 1000 != 0] if 'sortingInfo' in self._data.keys(): info = self._getfield("sortingInfo") start_frame = _squeeze_ds(info['startTimes']) self.start_frame = int(start_frame) else: self.start_frame = 0 # Parse through 'units': self._spike_trains = {} 
self._unit_ids = np.empty(0, np.int) for uc, unit in enumerate(units): uid = int(_squeeze_ds(unit["ID"])) self._unit_ids = np.append(self._unit_ids, uid) self._spike_trains[uc] = _squeeze(unit["spikeTrain"]).astype(np.int) - self.start_frame # For memory efficiency in case it's necessary: # X = self.allocate_array( "amplitudes_" + uid, array= unit["spikeAmplitudes"].flatten().T) # self.set_unit_spike_features(uid, "amplitudes", X) self.set_unit_spike_features(uid, "amplitudes", _squeeze(unit["spikeAmplitudes"])) self.set_unit_spike_features(uid, "detection_channel", _squeeze(unit["detectionChannel"]).astype(np.int)) idx = unit["detectionChannel"].astype(int) - 1 spikePositions = np.vstack((_squeeze(multi_electrode["electrodePositions"][0][idx]), _squeeze(multi_electrode["electrodePositions"][1][idx]))).T self.set_unit_spike_features(uid, "positions", spikePositions) if self._old_style_mat: template = unit["footprint"].T else: template = unit["footprint"] self.set_unit_property(uid, "template", template) self.set_unit_property(uid, "template_frames_cut_before", unit["cutLeft"].flatten()) self._units = units self._multi_electrode = multi_electrode self._kwargs['keep_good_only'] = keep_good_only @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): uidx = np.where(np.array(self.get_unit_ids()) == unit_id)[0][0] st = self._spike_trains[uidx] return st[(st >= start_frame) & (st < end_frame)] def get_unit_ids(self): return self._unit_ids.tolist() @staticmethod def write_sorting(sorting, save_path, locations=None, noise_std_by_channel=None, start_frame=0, convert_cell_to_struct=True): # First, find out how many channels there are if locations is not None: # write_locations must be a 2D numpy array with n_channels in first dim., (x,y) in second dim. 
n_channels = locations.shape[0] elif 'template' in sorting.get_shared_unit_property_names() or \ 'detection_channel' in sorting.get_shared_unit_property_names(): # Without locations, check if there is a template to get the number of channels uid = int(sorting.get_unit_ids()[0]) if "template" in sorting.get_unit_property_names(uid): template = sorting.get_unit_property(uid, "template") n_channels = template.shape[0] else: # If there is also no template, loop through all units and find max. detection_channel max_channel = 1 for uid_ in sorting.get_unit_ids(): uid = int(uid_) detection_channel = sorting.get_unit_spike_features(uid, "detection_channel") max_channel = max([max_channel], np.append(detection_channel)) n_channels = max_channel else: n_channels = 1 # Now loop through all units and extract the data that we want to save: units = [] for uid_ in sorting.get_unit_ids(): uid = int(uid_) unit = {"ID": uid, "spikeTrain": sorting.get_unit_spike_train(uid)} num_spikes = len(sorting.get_unit_spike_train(uid)) if "amplitudes" in sorting.get_unit_spike_feature_names(uid): unit["spikeAmplitudes"] = sorting.get_unit_spike_features(uid, "amplitudes") else: # Save a spikeAmplitudes = 1 unit["spikeAmplitudes"] = np.ones(num_spikes, np.double) if "detection_channel" in sorting.get_unit_spike_feature_names(uid): unit["detectionChannel"] = sorting.get_unit_spike_features(uid, "detection_channel") else: # Save a detectionChannel = 1 unit["detectionChannel"] = np.ones(num_spikes, np.double) if "template" in sorting.get_unit_property_names(uid): unit["footprint"] = sorting.get_unit_property(uid, "template").T else: # If this unit does not have a footprint, create an empty one: unit["footprint"] = np.zeros((3, n_channels), np.double) if "template_cut_left" in sorting.get_unit_property_names(uid): unit["cutLeft"] = sorting.get_unit_property(uid, "template_cut_left") else: unit["cutLeft"] = 1 units.append(unit) # Save the electrode locations: if locations is None: # Create artificial 
locations if none are provided: x = np.zeros(n_channels, np.double) y = np.array(np.arange(n_channels), np.double) locations = np.vstack((x, y)).T multi_electrode = {"electrodePositions": locations, "electrodeNumbers": np.arange(n_channels)} if noise_std_by_channel is None: noise_std_by_channel = np.ones((1, n_channels)) dict_to_save = {'Units': np.array(units), 'MultiElectrode': multi_electrode, 'noiseStd': noise_std_by_channel, "samplingRate": sorting._sampling_frequency} # Save Units and MultiElectrode to .mat file: MATSortingExtractor.write_dict_to_mat(save_path, dict_to_save, version='7.3') if convert_cell_to_struct: # read the template txt files convert_cellarray_to_structarray = f"fileName='{str(Path(save_path).absolute())}';\n" \ f"{convert_cell_array_to_struct_code}" convert_script = Path(save_path).parent / "convert_cellarray_to_structarray.m" with convert_script.open('w') as f: f.write(convert_cellarray_to_structarray) if 'win' in sys.platform and sys.platform != 'darwin': matlab_cmd = """ #!/bin/bash cd {tmpdir} matlab -nosplash -wait -log -r convert_cellarray_to_structarray """.format(tmpdir={str(convert_script.parent)}) else: matlab_cmd = """ #!/bin/bash cd {tmpdir} matlab -nosplash -nodisplay -log -r convert_cellarray_to_structarray """.format(tmpdir={str(convert_script.parent)}) try: os.system(matlab_cmd) except: print("Failed to convert cell array to struct array") convert_script.unlink() # For .mat v7.3: Function to extract all fields of a struct-array: def _parse_units(file, _units): import h5py t_units = {} if isinstance(_units, h5py.Group): for name in _units.keys(): value = _units[name] dict_val = [] for val in value: if isinstance(file[val[0]], h5py.Dataset): dict_val.append(file[val[0]][()]) t_units[name] = dict_val else: break out = [dict(zip(t_units, col)) for col in zip(*t_units.values())] else: out = [] for unit in _units: group = file[unit[()][0]] unit_dict = {} for k in group.keys(): unit_dict[k] = group[k][()] out.append(unit_dict) 
return out def _squeeze_ds(ds): while not isinstance(ds, (int, float, np.integer, np.float)): ds = ds[0] return ds def _squeeze(arr): shape = arr.shape if len(shape) == 2: if shape[0] == 1: arr = arr[0] elif shape[1] == 1: arr = arr[:, 0] return arr ================================================ FILE: spikeextractors/extractors/hs2sortingextractor/__init__.py ================================================ from .hs2sortingextractor import HS2SortingExtractor ================================================ FILE: spikeextractors/extractors/hs2sortingextractor/hs2sortingextractor.py ================================================ from spikeextractors import SortingExtractor import numpy as np from pathlib import Path from spikeextractors.extraction_tools import check_get_unit_spike_train try: import h5py HAVE_HS2SX = True except ImportError: HAVE_HS2SX = False class HS2SortingExtractor(SortingExtractor): extractor_name = 'HS2Sorting' installed = HAVE_HS2SX # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the HS2SortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed def __init__(self, file_path, load_unit_info=True): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self._recording_file = file_path self._rf = h5py.File(self._recording_file, mode='r') if 'Sampling' in self._rf: if self._rf['Sampling'][()] == 0: self._sampling_frequency = None else: self._sampling_frequency = self._rf['Sampling'][()] self._cluster_id = self._rf['cluster_id'][()] self._unit_ids = set(self._cluster_id) self._spike_times = self._rf['times'][()] if load_unit_info: self.load_unit_info() self._kwargs = {'file_path': str(Path(file_path).absolute()), 'load_unit_info': load_unit_info} def load_unit_info(self): if 'centres' in self._rf.keys() and len(self._spike_times) > 0: self._unit_locs = self._rf['centres'][()] # cache for faster access for u_i, unit_id in 
enumerate(self._unit_ids): self.set_unit_property(unit_id, property_name='unit_location', value=self._unit_locs[u_i]) inds = [] # get these only once for unit_id in self._unit_ids: inds.append(np.where(self._cluster_id == unit_id)[0]) if 'data' in self._rf.keys() and len(self._spike_times) > 0: d = self._rf['data'][()] for i, unit_id in enumerate(self._unit_ids): self.set_unit_spike_features(unit_id, 'spike_location', d[:, inds[i]].T) if 'ch' in self._rf.keys() and len(self._spike_times) > 0: d = self._rf['ch'][()] for i, unit_id in enumerate(self._unit_ids): self.set_unit_spike_features(unit_id, 'max_channel', d[inds[i]]) def get_unit_indices(self, x): return np.where(self._cluster_id == x)[0] def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spike_times[self.get_unit_indices(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] @staticmethod def write_sorting(sorting, save_path): assert HAVE_HS2SX, HS2SortingExtractor.installation_mesg unit_ids = sorting.get_unit_ids() times_list = [] labels_list = [] for i in range(len(unit_ids)): unit = unit_ids[i] times = sorting.get_unit_spike_train(unit_id=unit) times_list.append(times) labels_list.append(np.ones(times.shape, dtype=int) * unit) all_times = np.concatenate(times_list) all_labels = np.concatenate(labels_list) rf = h5py.File(save_path, mode='w') if sorting.get_sampling_frequency() is not None: rf.create_dataset("Sampling", data=sorting.get_sampling_frequency()) else: rf.create_dataset("Sampling", data=0) if 'unit_location' in sorting.get_shared_unit_property_names(): spike_centres = [sorting.get_unit_property(u, 'unit_location') for u in sorting.get_unit_ids()] spike_centres = np.array(spike_centres) rf.create_dataset("centres", data=spike_centres) if 'spike_location' in sorting.get_shared_unit_spike_feature_names(): spike_loc_x = [] spike_loc_y = [] for u in 
sorting.get_unit_ids(): l = sorting.get_unit_spike_features(u, 'spike_location') spike_loc_x.append(l[:, 0]) spike_loc_y.append(l[:, 1]) spike_loc = np.vstack((np.concatenate(spike_loc_x), np.concatenate(spike_loc_y))) rf.create_dataset("data", data=spike_loc) if 'max_channel' in sorting.get_shared_unit_spike_feature_names(): spike_max_channel = np.concatenate( [sorting.get_unit_spike_features(u, 'max_channel') for u in sorting.get_unit_ids()]) rf.create_dataset("ch", data=spike_max_channel) rf.create_dataset("times", data=all_times) rf.create_dataset("cluster_id", data=all_labels) rf.close() ================================================ FILE: spikeextractors/extractors/intanrecordingextractor/__init__.py ================================================ from .intanrecordingextractor import IntanRecordingExtractor ================================================ FILE: spikeextractors/extractors/intanrecordingextractor/intanrecordingextractor.py ================================================ import numpy as np from pathlib import Path from packaging.version import parse from typing import Union, Optional from spikeextractors import RecordingExtractor from spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args DtypeType = Union[str, np.dtype] OptionalArrayType = Optional[Union[np.ndarray, list]] try: import pyintan if parse(pyintan.__version__) >= parse('0.3.0'): HAVE_INTAN = True else: print("pyintan version requires an update (>=0.3.0). Please upgrade with 'pip install --upgrade pyintan'") HAVE_INTAN = False except ImportError: HAVE_INTAN = False class IntanRecordingExtractor(RecordingExtractor): """ Extracts raw neural recordings from the Intan file format. The recording extractor always returns channel IDs starting from 0. The recording data will always be returned in the shape of (num_channels, num_frames). Parameters ---------- file_path : str Path to the .dat file to be extracted. 
dtype : dtype The data type used in the binary file. verbose : bool, optional Print output during pyintan file read. """ extractor_name = 'IntanRecording' has_default_locations = False has_unscaled = True is_writable = False mode = "file" installed = HAVE_INTAN installation_mesg = "To use the Intan extractor, install pyintan: \n\n pip install pyintan\n\n" def __init__(self, file_path: str, verbose: bool = False): assert self.installed, self.installation_mesg RecordingExtractor.__init__(self) assert Path(file_path).suffix == '.rhs' or Path(file_path).suffix == '.rhd', \ "Only '.rhd' and '.rhs' files are supported" self._recording_file = file_path self._recording = pyintan.File(file_path, verbose) self._num_frames = len(self._recording.times) self._analog_channels = np.array([ ch for ch in self._recording._anas_chan if all([other_ch not in ch['name'] for other_ch in ['ADC', 'VDD', 'AUX']]) ]) self._num_channels = len(self._analog_channels) self._channel_ids = list(range(self._num_channels)) self._fs = float(self._recording.sample_rate.rescale('Hz').magnitude) for i, ch in enumerate(self._analog_channels): self.set_channel_gains(channel_ids=i, gains=ch['gain']) self.set_channel_offsets(channel_ids=i, offsets=ch['offset']) self._kwargs = dict(file_path=str(Path(file_path).absolute()), verbose=verbose) def get_channel_ids(self): return self._channel_ids def get_num_frames(self): return self._num_frames def get_sampling_frequency(self): return self._fs @check_get_traces_args def get_traces( self, channel_ids: OptionalArrayType = None, start_frame: Optional[int] = None, end_frame: Optional[int] = None, return_scaled: bool = True, ): """ This function extracts and returns a trace from the recorded data from the given channels ids and the given start and end frame. 
It will return traces from within three ranges: [start_frame, start_frame+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_recording_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_recording_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Traces are returned in a 2D array that contains all of the traces from each channel with dimensions (num_channels x num_frames). In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. Parameters ---------- start_frame : int, optional The starting frame of the trace to be returned (inclusive) end_frame : int, optional The ending frame of the trace to be returned (exclusive) channel_ids : ArrayType, optional A list or 1D array of channel ids (ints) from which each trace will be extracted return_scaled : bool, optional If True, traces are returned after scaling (using gain/offset). If False, the raw traces are returned. Defaults to True. Returns ---------- traces: numpy.ndarray A 2D array that contains all of the traces from each channel. Dimensions are: (num_channels x num_frames) """ channel_idxs = np.array([self._channel_ids.index(ch) for ch in channel_ids]) return self._recording._read_analog( channels=self._analog_channels[channel_idxs], i_start=start_frame, i_stop=end_frame, dtype="uint16" ).T @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): channels = [np.unique(ev.channels)[0] for ev in self._recording.digital_in_events] assert channel_id in channels, f"Specified 'channel' not found. 
Available channels are {channels}" ev = self._recording.events[channels.index(channel_id)] ttl_frames = (ev.times.rescale("s") * self.get_sampling_frequency()).magnitude.astype(int) ttl_states = np.sign(ev.channel_states) ttl_valid_idxs = np.where((ttl_frames >= start_frame) & (ttl_frames < end_frame))[0] return ttl_frames[ttl_valid_idxs], ttl_states[ttl_valid_idxs] ================================================ FILE: spikeextractors/extractors/jrcsortingextractor/__init__.py ================================================ from .jrcsortingextractor import JRCSortingExtractor ================================================ FILE: spikeextractors/extractors/jrcsortingextractor/jrcsortingextractor.py ================================================ from pathlib import Path import re from typing import Union import numpy as np from spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor, HAVE_MAT from spikeextractors.extraction_tools import check_get_unit_spike_train PathType = Union[str, Path] class JRCSortingExtractor(MATSortingExtractor): extractor_name = "JRCSortingExtractor" installation_mesg = "To use the MATSortingExtractor install h5py and scipy: \n\n pip install h5py scipy\n\n" # error message when not installed def __init__(self, file_path: PathType, keep_good_only: bool = False): super().__init__(file_path) file_path = self._kwargs["file_path"] spike_times = self._getfield("spikeTimes").ravel() - 1 # int32 spike_clusters = self._getfield("spikeClusters").ravel() # uint32 spike_amplitudes = self._getfield("spikeAmps").ravel() # int16 spike_sites = self._getfield("spikeSites").ravel() - 1 # uint32 spike_positions = self._getfield("spikePositions").T # float32 unit_centroids = self._getfield("clusterCentroids").astype(np.float).T unit_sites = self._getfield("clusterSites").astype(np.uint32).ravel() mean_waveforms = self._getfield("meanWfGlobal").T mean_waveforms_raw = self._getfield("meanWfGlobalRaw").T # try to extract 
various parameters from the .prm file self._bit_scaling = np.float32(0.30518) # conversion factor for ADC units -> µV sample_rate = 30000. filter_type = "ndiff" ndiff_order = 2 prm_file = Path(file_path.parent, file_path.name.replace("_res.mat", ".prm")) with prm_file.open("r") as fh: lines = [line.strip() for line in fh.readlines()] for line in lines: try: key, val = line.split('%', 1)[0].strip(" ;").split("=") except ValueError: continue key = key.strip() val = val.strip() if key == "sampleRate": try: sample_rate = float(val) except (IndexError, ValueError): pass elif key == "bitScaling": try: self._bit_scaling = np.float32(val) except (IndexError, ValueError): pass elif key == "filterType": filter_type = val elif key == "nDiffOrder": try: ndiff_order = int(val) except (IndexError, ValueError): pass elif key == "siteLoc": site_locs = [] str_locs = map(lambda v: v.strip(" ]["), val.split(";")) for loc in str_locs: x, y = map(float, re.split(r",?\s+", loc)) site_locs.append([x, y]) site_locs = np.array(site_locs) elif key == "shankMap": val = val.strip("][") try: shank_map = np.array(map(float, re.split(r"[,;]?\s+", val))) except: shank_map = np.array([]) self.set_sampling_frequency(sample_rate) if filter_type == "sgdiff": self._bit_scaling /= (2 * (np.arange(1, ndiff_order + 1) ** 2).sum()) elif filter_type == "ndiff": self._bit_scaling /= 2 # traces, features raw_file = Path(file_path.parent, file_path.name.replace("_res.mat", "_raw.jrc")) raw_shape = tuple(self._getfield("rawShape").ravel().astype(np.int)) self._raw_traces = np.memmap(raw_file, dtype=np.int16, mode="r", shape=raw_shape, order="F") filt_file = Path(file_path.parent, file_path.name.replace("_res.mat", "_filt.jrc")) filt_shape = tuple(self._getfield("filtShape").ravel().astype(np.int)) self._filt_traces = np.memmap(filt_file, dtype=np.int16, mode="r", shape=filt_shape, order="F") features_file = Path(file_path.parent, file_path.name.replace("_res.mat", "_features.jrc")) features_shape = 
tuple(self._getfield("featuresShape").ravel().astype(np.int)) self._cluster_features = np.memmap(features_file, dtype=np.float32, mode="r", shape=features_shape, order="F") neighbors = _find_site_neighbors(site_locs, raw_shape[1], shank_map) # get nearest neighbors for each site # nonpositive clusters are noise or deleted units if keep_good_only: good_mask = spike_clusters > 0 else: good_mask = np.ones_like(spike_clusters, dtype=np.bool) self._unit_ids = np.unique(spike_clusters[good_mask]) # load spike trains self._spike_trains = {} self._unit_masks = {} for uid in self._unit_ids: mask = (spike_clusters == uid) self._unit_masks[uid] = mask self._spike_trains[uid] = spike_times[mask] self.set_unit_spike_features(uid, "amplitudes", spike_amplitudes[mask]) self.set_unit_spike_features(uid, "max_channels", spike_sites[mask]) self.set_unit_spike_features(uid, "positions", spike_positions[mask, :]) self.set_unit_spike_features(uid, "site_neighbors", neighbors[spike_sites[mask], :]) self.set_unit_property(uid, "centroid", unit_centroids[uid - 1, :]) self.set_unit_property(uid, "max_channel", unit_sites[uid - 1]) self.set_unit_property(uid, "template", mean_waveforms[:, :, uid - 1]) self.set_unit_property(uid, "template_raw", mean_waveforms_raw[:, :, uid - 1]) self._kwargs["keep_good_only"] = keep_good_only @check_get_unit_spike_train def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None): if feature_name not in ("raw_traces", "filtered_traces", "cluster_features"): return super().get_unit_spike_features(unit_id, feature_name, start_frame, end_frame) mask = self._unit_masks[unit_id] if feature_name == "raw_traces": return self._raw_traces[:, :, mask] * self._bit_scaling elif feature_name == "filtered_traces": return self._filt_traces[:, :, mask] * self._bit_scaling else: return self._cluster_features[:, :, mask] @check_get_unit_spike_train def get_unit_spike_feature_names(self, unit_id): return 
super().get_unit_spike_feature_names(unit_id) + ["raw_traces", "filtered_traces", "cluster_features"] @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) start_frame = start_frame or 0 end_frame = end_frame or np.infty st = self._spike_trains[unit_id] return st[(st >= start_frame) & (st < end_frame)] def get_unit_ids(self): return self._unit_ids.tolist() def _find_site_neighbors(site_locs, n_neighbors, shank_map): from scipy.spatial.distance import cdist if np.unique(shank_map).size <= 1: pass n_sites = site_locs.shape[0] n_neighbors = int(min(n_neighbors, n_sites)) neighbors = np.zeros((n_sites, n_neighbors), dtype=np.int) for i in range(n_sites): i_loc = site_locs[i, :][np.newaxis, :] dists = cdist(i_loc, site_locs).ravel() neighbors[i, :] = dists.argsort()[:n_neighbors] return neighbors ================================================ FILE: spikeextractors/extractors/kilosortextractors/__init__.py ================================================ from .kilosortextractors import KiloSortSortingExtractor, KiloSortRecordingExtractor ================================================ FILE: spikeextractors/extractors/kilosortextractors/kilosortextractors.py ================================================ from spikeextractors.extractors.phyextractors import PhyRecordingExtractor, PhySortingExtractor from pathlib import Path class KiloSortRecordingExtractor(PhyRecordingExtractor): extractor_name = 'KiloSortRecording' has_default_locations = True has_unscaled = False installed = True # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "" # error message when not installed def __init__(self, folder_path): PhyRecordingExtractor.__init__(self, folder_path) class KiloSortSortingExtractor(PhySortingExtractor): extractor_name = 'KiloSortSorting' installed = True # check at class level if installed or not 
installation_mesg = "" # error message when not installed is_writable = False mode = 'folder' def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=False): PhySortingExtractor.__init__(self, folder_path, exclude_cluster_groups) self._keep_good_only = keep_good_only self._good_units = [] if keep_good_only: for u in self.get_unit_ids(): if 'KSLabel' in self.get_unit_property_names(u): if self.get_unit_property(u, 'KSLabel') == 'good': self._good_units.append(u) self._unit_ids = self._good_units self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'exclude_cluster_groups': exclude_cluster_groups, 'keep_good_only': keep_good_only} ================================================ FILE: spikeextractors/extractors/klustaextractors/__init__.py ================================================ from .klustaextractors import KlustaSortingExtractor, KlustaRecordingExtractor ================================================ FILE: spikeextractors/extractors/klustaextractors/klustaextractors.py ================================================ """ kwik structure based on: https://github.com/kwikteam/phy-doc/blob/master/docs/kwik-format.md cluster_group defaults based on: https://github.com/kwikteam/phy-doc/blob/master/docs/kwik-model.md 04/08/20 """ from spikeextractors import SortingExtractor from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor from spikeextractors.extraction_tools import read_python, check_get_unit_spike_train import numpy as np from pathlib import Path try: import h5py HAVE_KLSX = True except ImportError: HAVE_KLSX = False # noinspection SpellCheckingInspection class KlustaRecordingExtractor(BinDatRecordingExtractor): extractor_name = 'KlustaRecording' has_default_locations = False has_unscaled = False installed = HAVE_KLSX # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install 
h5py\n\n" # error message when not installed def __init__(self, folder_path): assert self.installed, self.installation_mesg klustafolder = Path(folder_path).absolute() config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0] dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0] assert config_file.is_file() and dat_file.is_file(), "Not a valid klusta folder" config = read_python(str(config_file)) sampling_frequency = config['traces']['sample_rate'] n_channels = config['traces']['n_channels'] dtype = config['traces']['dtype'] BinDatRecordingExtractor.__init__(self, file_path=dat_file, sampling_frequency=sampling_frequency, numchan=n_channels, dtype=dtype) self._kwargs = {'folder_path': str(Path(folder_path).absolute())} # noinspection SpellCheckingInspection class KlustaSortingExtractor(SortingExtractor): extractor_name = 'KlustaSorting' installed = HAVE_KLSX # check at class level if installed or not installation_mesg = "To use the KlustaSortingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed is_writable = True mode = 'file_or_folder' default_cluster_groups = {0: 'Noise', 1: 'MUA', 2: 'Good', 3: 'Unsorted'} def __init__(self, file_or_folder_path, exclude_cluster_groups=None): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) kwik_file_or_folder = Path(file_or_folder_path) kwikfile = None klustafolder = None if kwik_file_or_folder.is_file(): assert kwik_file_or_folder.suffix == '.kwik', "Not a '.kwik' file" kwikfile = Path(kwik_file_or_folder).absolute() klustafolder = kwikfile.parent elif kwik_file_or_folder.is_dir(): klustafolder = kwik_file_or_folder kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik'] if len(kwikfiles) == 1: kwikfile = kwikfiles[0] assert kwikfile is not None, "Could not load '.kwik' file" try: config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0] config = read_python(str(config_file)) 
sampling_frequency = config['traces']['sample_rate'] self._sampling_frequency = sampling_frequency except Exception as e: print("Could not load sampling frequency info") kf_reader = h5py.File(kwikfile, 'r') self._spiketrains = [] self._unit_ids = [] unique_units = [] klusta_units = [] cluster_groups_name = [] groups = [] unit = 0 cs_to_exclude = [] valid_group_names = [i[1].lower() for i in self.default_cluster_groups.items()] if exclude_cluster_groups is not None: assert isinstance(exclude_cluster_groups, list), 'exclude_cluster_groups should be a list' for ec in exclude_cluster_groups: assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}' cs_to_exclude.append(ec.lower()) for channel_group in kf_reader.get('/channel_groups'): if 'spikes' not in kf_reader.get(f'/channel_groups/{channel_group}'): print('No spikes found for this channel group') continue else: chan_cluster_id_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/clusters/main')[()] chan_cluster_times_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/time_samples')[()] chan_cluster_ids = np.unique(chan_cluster_id_arr) # if clusters were merged in gui, # the original id's are still in the kwiktree, but # in this array for cluster_id in chan_cluster_ids: cluster_frame_idx = np.nonzero(chan_cluster_id_arr == cluster_id) # the [()] is a h5py thing st = chan_cluster_times_arr[cluster_frame_idx] assert st.shape[0] > 0, 'no spikes in cluster' cluster_group = kf_reader.get(f'/channel_groups/{channel_group}/clusters/main/{cluster_id}').attrs['cluster_group'] assert cluster_group in self.default_cluster_groups.keys(), f'cluster_group not in "default_dict: {cluster_group}' cluster_group_name = self.default_cluster_groups[cluster_group] if cluster_group_name.lower() in cs_to_exclude: continue self._spiketrains.append(st) klusta_units.append(int(cluster_id)) unique_units.append(unit) unit += 1 groups.append(int(channel_group)) 
cluster_groups_name.append(cluster_group_name) if len(np.unique(klusta_units)) == len(np.unique(unique_units)): self._unit_ids = klusta_units else: print('Klusta units are not unique! Using unique unit ids') self._unit_ids = unique_units for i, u in enumerate(self._unit_ids): self.set_unit_property(u, 'group', groups[i]) self.set_unit_property(u, 'quality', cluster_groups_name[i].lower()) self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute())} def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spiketrains[self.get_unit_ids().index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] ================================================ FILE: spikeextractors/extractors/matsortingextractor/__init__.py ================================================ from .matsortingextractor import MATSortingExtractor ================================================ FILE: spikeextractors/extractors/matsortingextractor/matsortingextractor.py ================================================ from collections import deque from pathlib import Path from typing import Union import numpy as np try: import h5py HAVE_H5PY = True except ImportError: HAVE_H5PY = False try: from scipy.io.matlab import loadmat, savemat HAVE_LOADMAT = True except ImportError: HAVE_LOADMAT = False try: import hdf5storage HAVE_HDF5STORAGE = True except ImportError: HAVE_HDF5STORAGE = False HAVE_MAT = HAVE_H5PY & HAVE_LOADMAT from spikeextractors import SortingExtractor PathType = Union[str, Path] class MATSortingExtractor(SortingExtractor): extractor_name = "MATSortingExtractor" installed = HAVE_MAT # check at class level if installed or not is_writable = False mode = "file" installation_mesg = "To use the MATSortingExtractor install h5py and scipy: " \ "\n\n pip install h5py scipy\n\n" # error message when not installed def __init__(self, 
file_path: PathType): assert self.installed, self.installation_mesg super().__init__() file_path = Path(file_path) if isinstance(file_path, str) else file_path if not isinstance(file_path, Path): raise TypeError(f"Expected a str or Path file_path but got '{type(file_path).__name__}'") file_path = file_path.resolve() # get absolute path to this file if not file_path.is_file(): raise ValueError(f"Specified file path '{file_path}' is not a file.") self._kwargs = {"file_path": str(file_path.absolute())} try: # load old-style (up to 7.2) .mat file self._data = loadmat(file_path, matlab_compatible=True) self._old_style_mat = True except NameError: # loadmat not defined raise ImportError("Old-style .mat file given, but `loadmat` is not defined.") except NotImplementedError: # new style .mat file try: self._data = h5py.File(file_path, "r+") self._old_style_mat = False except NameError: raise ImportError("Version 7.2 .mat file given, but you don't have h5py installed.") def __del__(self): if not self._old_style_mat: self._data.close() def _getfield(self, fieldname: str): def _drill(d: dict, keys: deque): if len(keys) == 1: return d[keys.popleft()] else: return _drill(d[keys.popleft()], keys) if self._old_style_mat: return _drill(self._data, deque(fieldname.split("/"))) else: return self._data[fieldname][()] @staticmethod def write_dict_to_mat(mat_file_path, dict_to_write, version='7.3'): # field must be a dict assert HAVE_HDF5STORAGE, "To use the MATSortingExtractor write_dict_to_mat function install hdf5storage: " \ "\n\n pip install hdf5storage\n\n" if version == '7.3': hdf5storage.write(dict_to_write, '/', mat_file_path, matlab_compatible=True, options='w') elif version < '7.3' and version > '4': savemat(mat_file_path, dict_to_write) ================================================ FILE: spikeextractors/extractors/maxwellextractors/__init__.py ================================================ from .maxwellextractors import MaxOneRecordingExtractor, MaxOneSortingExtractor, 
\ MaxTwoRecordingExtractor, MaxTwoSortingExtractor ================================================ FILE: spikeextractors/extractors/maxwellextractors/maxwellextractors.py ================================================ from spikeextractors import RecordingExtractor, SortingExtractor from pathlib import Path import numpy as np from spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args, check_get_unit_spike_train try: import h5py HAVE_MAX = True except ImportError: HAVE_MAX = False installation_mesg = "To use the MaxOneRecordingExtractor install h5py: \n\n pip install h5py\n\n" class MaxOneRecordingExtractor(RecordingExtractor): extractor_name = 'MaxOneRecording' has_default_locations = True has_unscaled = True installed = HAVE_MAX # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = installation_mesg def __init__(self, file_path, load_spikes=True, rec_name='rec0000'): assert self.installed, self.installation_mesg RecordingExtractor.__init__(self) self._file_path = file_path self._fs = None self._positions = None self._recordings = None self._filehandle = None self._load_spikes = load_spikes self._mapping = None self._rec_name = rec_name self._initialize() self._kwargs = {'file_path': str(Path(file_path).absolute()), 'load_spikes': load_spikes} def __del__(self): self._filehandle.close() def _initialize(self): self._filehandle = h5py.File(self._file_path, 'r') self._version = self._filehandle['version'][0].decode() if int(self._version) == 20160704: # old format self._mapping = self._filehandle['mapping'] if 'lsb' in self._filehandle['settings'].keys(): self._lsb = self._filehandle['settings']['lsb'][0] * 1e6 else: print("Couldn't read lsb. Setting lsb to 1") self._lsb = 1. 
channels = np.array(self._mapping['channel']) electrodes = np.array(self._mapping['electrode']) # remove unused channels routed_idxs = np.where(electrodes > -1)[0] self._channel_ids = list(channels[routed_idxs]) self._electrode_ids = list(electrodes[routed_idxs]) self._num_channels = len(self._channel_ids) self._fs = float(20000) self._signals = self._filehandle['sig'] self._num_frames = self._signals.shape[1] elif int(self._version) > 20160704: # new format well_name = 'well000' rec_name = self._rec_name settings = self._filehandle['wells'][well_name][rec_name]['settings'] self._mapping = settings['mapping'] if 'lsb' in settings.keys(): self._lsb = settings['lsb'][()][0] * 1e6 else: self._lsb = 1. channels = np.array(self._mapping['channel']) electrodes = np.array(self._mapping['electrode']) # remove unused channels routed_idxs = np.where(electrodes > -1)[0] self._channel_ids = list(channels[routed_idxs]) self._electrode_ids = list(electrodes[routed_idxs]) self._num_channels = len(self._channel_ids) self._fs = settings['sampling'][()][0] self._signals = self._filehandle['wells'][well_name][rec_name]['groups']['routed']['raw'] self._num_frames = self._signals.shape[1] else: raise Exception("Could not parse the MaxOne file") # This happens when only spikes are recorded if self._num_frames == 0: find_max_frame = True else: find_max_frame = False for i_ch, ch, el in zip(routed_idxs, self._channel_ids, self._electrode_ids): self.set_channel_locations([self._mapping['x'][i_ch], self._mapping['y'][i_ch]], ch) self.set_channel_property(ch, 'electrode', el) # set gains self.set_channel_gains(self._lsb) if self._load_spikes: if 'proc0' in self._filehandle: if 'spikeTimes' in self._filehandle['proc0']: spikes = self._filehandle['proc0']['spikeTimes'] spike_mask = [True] * len(spikes) for i, ch in enumerate(spikes['channel']): if ch not in self._channel_ids: spike_mask[i] = False spikes_channels = np.array(spikes['channel'])[spike_mask] if find_max_frame: self._num_frames = 
np.ptp(spikes['frameno']) # load activity as property activity_channels, counts = np.unique(spikes_channels, return_counts=True) # transform to spike rate duration = float(self._num_frames) / self._fs counts = counts.astype(float) / duration activity_channels = list(activity_channels) for ch in self.get_channel_ids(): if ch in activity_channels: self.set_channel_property(ch, 'spike_rate', counts[activity_channels.index(ch)]) spike_amplitudes = spikes[np.where(spikes['channel'] == ch)]['amplitude'] self.set_channel_property(ch, 'spike_amplitude', np.median(spike_amplitudes)) else: self.set_channel_property(ch, 'spike_rate', 0) self.set_channel_property(ch, 'spike_amplitude', 0) def get_channel_ids(self): return list(self._channel_ids) def get_electrode_ids(self): return list(self._electrode_ids) def get_num_frames(self): return self._num_frames def get_sampling_frequency(self): return self._fs def correct_for_missing_frames(self, verbose=False): """ Corrects for missing frames. The correct times can be retrieved with the frame_to_time and time_to_frame functions. Parameters ---------- verbose: bool If True, output is verbose """ frame_idxs_span = self._get_frame_number(self.get_num_frames() - 1) - self._get_frame_number(0) if frame_idxs_span > self.get_num_frames(): if verbose: print(f"Found missing frames! 
Correcting for it (this might take a while)") framenos = self._get_frame_numbers() # find missing frames diff_frames = np.diff(framenos) missing_frames_idxs = np.where(diff_frames > 1)[0] delays_in_frames = [] for mf_idx in missing_frames_idxs: delays_in_frames.append(diff_frames[mf_idx]) if verbose: print(f"Found {len(delays_in_frames)} missing intervals") times = np.round(np.arange(self.get_num_frames()) / self.get_sampling_frequency(), 6) for mf_idx, duration in zip(missing_frames_idxs, delays_in_frames): times[mf_idx:] += np.round(duration / self.get_sampling_frequency(), 6) self.set_times(times) else: if verbose: print("No missing frames found") def _get_frame_numbers(self): bitvals = self._signals[-2:, :] frame_nos = np.bitwise_or(np.left_shift(bitvals[-1].astype('int64'), 16), bitvals[0]) return frame_nos def _get_frame_number(self, index): bitvals = self._signals[-2:, index] frameno = bitvals[1] << 16 | bitvals[0] return frameno @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): if np.array(channel_ids).size > 1: if np.any(np.diff(channel_ids) < 0): sorted_channel_ids = np.sort(channel_ids) sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) traces = self._signals[sorted_channel_ids, start_frame:end_frame][sorted_idx] else: traces = self._signals[np.array(channel_ids), start_frame:end_frame] else: traces = self._signals[np.array(channel_ids), start_frame:end_frame] return traces @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): bitvals = self._signals[-2:, 0] first_frame = bitvals[1] << 16 | bitvals[0] bits = self._filehandle['bits'] bit_frames = bits['frameno'] - first_frame bit_states = bits['bits'] bit_idxs = np.where((bit_frames >= start_frame) & (bit_frames < end_frame))[0] ttl_frames = bit_frames[bit_idxs] ttl_states = bit_states[bit_idxs] ttl_states[ttl_states == 0] = -1 return ttl_frames, ttl_states class 
MaxOneSortingExtractor(SortingExtractor): extractor_name = 'MaxOneSorting' installed = HAVE_MAX # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = installation_mesg def __init__(self, file_path): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self._file_path = file_path self._filehandle = None self._mapping = None self._version = None self._initialize() self._kwargs = {'file_path': str(Path(file_path).absolute())} def _initialize(self): self._filehandle = h5py.File(self._file_path, 'r') self._mapping = self._filehandle['mapping'] self._signals = self._filehandle['sig'] bitvals = self._signals[-2:, 0] self._first_frame = bitvals[1] << 16 | bitvals[0] channels = np.array(self._mapping['channel']) electrodes = np.array(self._mapping['electrode']) # remove unused channels routed_idxs = np.where(electrodes > -1)[0] self._channel_ids = list(channels[routed_idxs]) self._unit_ids = list(electrodes[routed_idxs]) self._sampling_frequency = float(20000) self._spiketrains = [] self._unit_ids = [] try: spikes = self._filehandle['proc0']['spikeTimes'] for u in self._channel_ids: spiketrain_idx = np.where(spikes['channel'] == u)[0] if len(spiketrain_idx) > 0: self._unit_ids.append(u) spiketrain = spikes['frameno'][spiketrain_idx] - self._first_frame idxs_greater_0 = np.where(spiketrain >= 0)[0] self._spiketrains.append(spiketrain[idxs_greater_0]) self.set_unit_spike_features(u, 'amplitude', spikes['amplitude'][spiketrain_idx][idxs_greater_0]) except: raise AttributeError("Spike times information are missing from the .h5 file") def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = np.Inf unit_idx = self._unit_ids.index(unit_id) spiketrain = 
self._spiketrains[unit_idx] inds = np.where((start_frame <= spiketrain) & (spiketrain < end_frame)) return spiketrain[inds] class MaxTwoRecordingExtractor(RecordingExtractor): extractor_name = 'MaxTwoRecording' has_default_locations = True has_unscaled = True installed = HAVE_MAX # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = installation_mesg def __init__(self, file_path, well_name='well000', rec_name='rec0000', load_spikes=True): assert self.installed, self.installation_mesg RecordingExtractor.__init__(self) self._file_path = file_path self._well_name = well_name self._rec_name = rec_name self._fs = None self._positions = None self._recordings = None self._filehandle = None self._mapping = None self._load_spikes = load_spikes self._initialize() self._kwargs = {'file_path': str(Path(file_path).absolute()), 'well_name': well_name, 'rec_name': rec_name, 'load_spikes': load_spikes} def _initialize(self): self._filehandle = h5py.File(self._file_path, 'r') settings = self._filehandle['wells'][self._well_name][self._rec_name]['settings'] self._mapping = settings['mapping'] if 'lsb' in settings.keys(): self._lsb = settings['lsb'][()][0] * 1e6 else: self._lsb = 1. 
channels = np.array(self._mapping['channel']) electrodes = np.array(self._mapping['electrode']) # remove unused channels routed_idxs = np.where(electrodes > -1)[0] self._channel_ids = list(channels[routed_idxs]) self._electrode_ids = list(electrodes[routed_idxs]) self._num_channels = len(self._channel_ids) self._fs = settings['sampling'][()][0] self._signals = self._filehandle['wells'][self._well_name][self._rec_name]['groups']['routed']['raw'] self._num_frames = self._signals.shape[1] # This happens when only spikes are recorded if self._num_frames == 0: find_max_frame = True else: find_max_frame = False for i_ch, ch, el in zip(routed_idxs, self._channel_ids, self._electrode_ids): self.set_channel_locations([self._mapping['x'][i_ch], self._mapping['y'][i_ch]], ch) self.set_channel_property(ch, 'electrode', el) # set gains self.set_channel_gains(self._lsb) if self._load_spikes: if "spikes" in self._filehandle["wells"][self._well_name][self._rec_name].keys(): spikes = self._filehandle["wells"][self._well_name][self._rec_name]["spikes"] spike_mask = [True] * len(spikes) for i, ch in enumerate(spikes['channel']): if ch not in self._channel_ids: spike_mask[i] = False spikes_channels = np.array(spikes['channel'])[spike_mask] if find_max_frame: self._num_frames = np.ptp(spikes['frameno']) # load activity as property activity_channels, counts = np.unique(spikes_channels, return_counts=True) # transform to spike rate duration = float(self._num_frames) / self._fs counts = counts.astype(float) / duration activity_channels = list(activity_channels) for ch in self.get_channel_ids(): if ch in activity_channels: self.set_channel_property(ch, 'spike_rate', counts[activity_channels.index(ch)]) spike_amplitudes = spikes[np.where(spikes['channel'] == ch)]['amplitude'] self.set_channel_property(ch, 'spike_amplitude', np.median(spike_amplitudes)) else: self.set_channel_property(ch, 'spike_rate', 0) self.set_channel_property(ch, 'spike_amplitude', 0) def get_channel_ids(self): return 
list(self._channel_ids) def get_num_frames(self): return self._num_frames def get_sampling_frequency(self): return self._fs @staticmethod def get_well_names(file_path): with h5py.File(file_path, 'r') as f: wells = list(f["wells"]) return wells @staticmethod def get_recording_names(file_path, well_name): with h5py.File(file_path, 'r') as f: assert well_name in f["wells"], f"Well name should be among: " \ f"{MaxTwoRecordingExtractor.get_well_names(file_path)}" rec_names = list(f["wells"][well_name].keys()) return rec_names @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) if np.array(channel_idxs).size > 1: if np.any(np.diff(channel_idxs) < 0): sorted_channel_ids = np.sort(channel_idxs) sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_idxs]) traces = self._signals[sorted_channel_ids, start_frame:end_frame][sorted_idx] else: traces = self._signals[np.array(channel_idxs), start_frame:end_frame] else: traces = self._signals[np.array(channel_idxs), start_frame:end_frame] return traces class MaxTwoSortingExtractor(SortingExtractor): extractor_name = 'MaxTwoSorting' installed = HAVE_MAX # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = installation_mesg def __init__(self, file_path, well_name='well000', rec_name='rec0000'): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self._file_path = file_path self._well_name = well_name self._rec_name = rec_name self._filehandle = None self._mapping = None self._version = None self._initialize() self._sampling_frequency = self._fs self._kwargs = {'file_path': str(Path(file_path).absolute()), 'well_name': well_name, 'rec_name': rec_name} def _initialize(self): self._filehandle = h5py.File(self._file_path, 'r') settings = 
self._filehandle['wells'][self._well_name][self._rec_name]['settings'] self._mapping = settings['mapping'] if 'lsb' in settings.keys(): self._lsb = settings['lsb'][()] * 1e6 else: self._lsb = 1. channels = np.array(self._mapping['channel']) electrodes = np.array(self._mapping['electrode']) # remove unused channels routed_idxs = np.where(electrodes > -1)[0] self._channel_ids = list(channels[routed_idxs]) self._unit_ids = list(electrodes[routed_idxs]) self._fs = settings['sampling'][()][0] self._first_frame = self._filehandle['wells'][self._well_name][self._rec_name] \ ['groups']['routed']['frame_nos'][0] self._spiketrains = [] self._unit_ids = [] try: spikes = self._filehandle["wells"][self._well_name][self._rec_name]["spikes"] for u in self._channel_ids: spiketrain_idx = np.where(spikes['channel'] == u)[0] if len(spiketrain_idx) > 0: self._unit_ids.append(u) spiketrain = spikes['frameno'][spiketrain_idx] - self._first_frame idxs_greater_0 = np.where(spiketrain >= 0)[0] self._spiketrains.append(spiketrain[idxs_greater_0]) self.set_unit_spike_features(u, 'amplitude', spikes['amplitude'][spiketrain_idx][idxs_greater_0]) except: raise AttributeError("Spike times information are missing from the .h5 file") def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = np.Inf unit_idx = self._unit_ids.index(unit_id) spiketrain = self._spiketrains[unit_idx] inds = np.where((start_frame <= spiketrain) & (spiketrain < end_frame)) return spiketrain[inds] ================================================ FILE: spikeextractors/extractors/mcsh5recordingextractor/__init__.py ================================================ from .mcsh5recordingextractor import MCSH5RecordingExtractor ================================================ FILE: 
spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py ================================================ from spikeextractors import RecordingExtractor import numpy as np from pathlib import Path from spikeextractors.extraction_tools import check_get_traces_args try: import h5py HAVE_MCSH5 = True except ImportError: HAVE_MCSH5 = False class MCSH5RecordingExtractor(RecordingExtractor): extractor_name = 'MCSH5Recording' has_default_locations = False has_unscaled = False installed = HAVE_MCSH5 # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = "To use the MCSH5RecordingExtractor install h5py: \n\n pip install h5py\n\n" # error message when not installed def __init__(self, file_path, stream_id=0, verbose=False): assert self.installed, self.installation_mesg self._recording_file = file_path self._verbose = verbose self._available_stream_ids = self.get_available_stream_ids() self.set_stream_id(stream_id) RecordingExtractor.__init__(self) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'stream_id': stream_id, 'verbose': verbose} def __del__(self): self._rf.close() def get_channel_ids(self): return list(self._channel_ids) def get_num_frames(self): return self._nFrames def get_sampling_frequency(self): return self._samplingRate def set_stream_id(self, stream_id): assert stream_id in self._available_stream_ids, "The specified stream ID is unavailable." self._stream_id = stream_id if hasattr(self, '_rf'): self._rf.close() self._rf, self._nFrames, self._samplingRate, self._nRecCh, \ self._channel_ids, self._electrodeLabels, self._exponent, self._convFact \ = openMCSH5File(self._recording_file, stream_id, self._verbose) def get_stream_id(self): assert hasattr(self, '_stream_id'), "Stream ID has not been set yet." 
return self._stream_id def get_available_stream_ids(self): if hasattr(self, '_available_stream_ids'): return self._available_stream_ids else: rf = h5py.File(self._recording_file, 'r') analog_stream_names = list(rf.require_group('/Data/Recording_0/AnalogStream').keys()) return list(range(len(analog_stream_names))) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): channel_idxs = [] for m in channel_ids: assert m in self._channel_ids, 'channel_id {} not found'.format(m) channel_idxs.append(np.where(np.array(self._channel_ids) == m)[0][0]) stream = self._rf.require_group('/Data/Recording_0/AnalogStream/Stream_' + str(self._stream_id)) conv = self._convFact.astype(float) * (10.0 ** self._exponent) if np.array(channel_idxs).size > 1: if np.any(np.diff(channel_idxs) < 0): sorted_channel_ids = np.sort(channel_idxs) sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_idxs]) signals = stream.get('ChannelData')[sorted_channel_ids, start_frame:end_frame][sorted_idx] else: signals = stream.get('ChannelData')[np.sort(channel_idxs), start_frame:end_frame] else: signals = stream.get('ChannelData')[np.array(channel_idxs), start_frame:end_frame] if return_scaled: return signals * conv else: return signals def openMCSH5File(filename, stream_id, verbose=False): """Open an MCS hdf5 file, read and return the recording info.""" rf = h5py.File(filename, 'r') stream_name = 'Stream_' + str(stream_id) analog_stream_names = list(rf.require_group('/Data/Recording_0/AnalogStream').keys()) assert stream_name in analog_stream_names, "Specified stream does not exist." 
stream = rf.require_group('/Data/Recording_0/AnalogStream/' + stream_name) data = np.array(stream.get('ChannelData'), dtype=np.int) timestamps = np.array(stream.get('ChannelDataTimeStamps')) info = np.array(stream.get('InfoChannel')) Unit = info['Unit'][0] Tick = info['Tick'][0] / 1e6 exponent = info['Exponent'][0] convFact = info['ConversionFactor'][0] nRecCh, nFrames = data.shape channel_ids = info['ChannelID'] assert len(np.unique(channel_ids)) == len(channel_ids), 'Duplicate MCS channel IDs found' electrodeLabels = info['Label'] assert timestamps[0][0] < timestamps[0][2], 'Please check the validity of \'ChannelDataTimeStamps\' in the stream.' TimeVals = np.arange(timestamps[0][0], timestamps[0][2] + 1, 1) * Tick assert Unit == b'V', 'Unexpected units found, expected volts, found {}'.format(Unit.decode('UTF-8')) data_V = data * convFact.astype(float) * (10.0 ** (exponent)) timestep_avg = np.mean(TimeVals[1:] - TimeVals[0:-1]) timestep_std = np.std(TimeVals[1:] - TimeVals[0:-1]) timestep_min = np.min(TimeVals[1:] - TimeVals[0:-1]) timestep_max = np.min(TimeVals[1:] - TimeVals[0:-1]) assert all(np.abs(np.array( (timestep_min, timestep_max)) - timestep_avg) / timestep_avg < 1e-6), 'Time steps vary by more than 1 ppm' samplingRate = 1. 
/ timestep_avg if verbose: print('# MCS H5 data format') print('#') print('# File: {}'.format(rf.filename)) print('# File size: {:.2f} MB'.format(rf.id.get_filesize() / 1024 ** 2)) print('#') for key in rf.attrs.keys(): print('# {}: {}'.format(key, rf.attrs[key])) print('#') print('# Signal range: {:.2f} to {:.2f} µV'.format(np.amin(data_V) * 1e6, np.amax(data_V) * 1e6)) print('# Number of channels: {}'.format(nRecCh)) print('# Number of frames: {}'.format(nFrames)) print('# Time step: {:.2f} µs ± {:.5f} % (range {} to {})'.format(timestep_avg * 1e6, timestep_std / timestep_avg * 100, timestep_min * 1e6, timestep_max * 1e6)) print('# Sampling rate: {:.2f} Hz'.format(samplingRate)) print('#') print('# MCSH5RecordingExtractor currently only reads /Data/Recording_0/AnalogStream/Stream_0') return rf, nFrames, samplingRate, nRecCh, channel_ids, electrodeLabels, exponent, convFact ================================================ FILE: spikeextractors/extractors/mdaextractors/__init__.py ================================================ from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor ================================================ FILE: spikeextractors/extractors/mdaextractors/mdaextractors.py ================================================ from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor from spikeextractors.extraction_tools import write_to_binary_dat_format, check_get_traces_args, \ check_get_unit_spike_train import json import numpy as np from pathlib import Path from .mdaio import DiskReadMda, readmda, writemda64, MdaHeader import shutil class MdaRecordingExtractor(RecordingExtractor): extractor_name = 'MdaRecording' has_default_locations = True has_unscaled = False installed = True # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "" # error message when not installed def __init__(self, folder_path, raw_fname='raw.mda', params_fname='params.json', 
geom_fname='geom.csv'): dataset_directory = Path(folder_path) self._dataset_directory = dataset_directory timeseries0 = dataset_directory / raw_fname self._dataset_params = read_dataset_params(dataset_directory, params_fname) self._sampling_frequency = self._dataset_params['samplerate'] * 1.0 self._timeseries_path = str(timeseries0.absolute()) geom0 = dataset_directory / geom_fname self._geom_fname = geom0 self._geom = np.loadtxt(self._geom_fname, delimiter=',', ndmin=2) X = DiskReadMda(self._timeseries_path) if self._geom.shape[0] != X.N1(): raise Exception( 'Incompatible dimensions between geom.csv and timeseries file {} <> {}'.format(self._geom.shape[0], X.N1())) self._num_channels = X.N1() self._num_timepoints = X.N2() RecordingExtractor.__init__(self) self.set_channel_locations(self._geom) self._kwargs = {'folder_path': str(Path(folder_path).absolute())} def get_channel_ids(self): return list(range(self._num_channels)) def get_num_frames(self): return self._num_timepoints def get_sampling_frequency(self): return self._sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): X = DiskReadMda(self._timeseries_path) recordings = X.readChunk(i1=0, i2=start_frame, N1=X.N1(), N2=end_frame - start_frame) recordings = recordings[channel_ids, :] return recordings def write_to_binary_dat_format(self, save_path, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, n_jobs=1, joblib_backend='loky', verbose=False): """Saves the traces of this recording extractor into binary .dat format. Parameters ---------- save_path: str The path to the file. time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. dtype: dtype Type of the saved data. Default float32 chunk_size: None or int Size of each chunk in number of frames. 
If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) chunk_mb: None or int Chunk size in Mb (default 500Mb) n_jobs: int Number of jobs to use (Default 1) joblib_backend: str Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing') verbose: bool If True, output is verbose """ X = DiskReadMda(self._timeseries_path) header_size = X._header.header_size if dtype is None or dtype == self.get_dtype(): try: with open(self._timeseries_path, 'rb') as src, open(save_path, 'wb') as dst: src.seek(header_size) shutil.copyfileobj(src, dst) except Exception as e: print('Error occurred while copying:', e) print('Writing to binary') write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype, chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend, verbose=verbose) else: write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype, chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend, verbose=verbose) @staticmethod def write_recording(recording, save_path, params=dict(), raw_fname='raw.mda', params_fname='params.json', geom_fname='geom.csv', dtype=None, chunk_size=None, n_jobs=None, chunk_mb=500, verbose=False): """ Writes recording to file in MDA format. Parameters ---------- recording: RecordingExtractor The recording extractor to be saved save_path: str or Path The folder in which the Mda files are saved params: dictionary Dictionary with optional parameters to save metadata. Sampling frequency is appended to this dictionary. raw_fname: str File name of raw file (default raw.mda) params_fname: str File name of params file (default params.json) geom_fname: str File name of geom file (default geom.csv) dtype: dtype dtype to be used. If None dtype is same as recording traces. chunk_size: None or int Size of each chunk in number of frames. 
If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) n_jobs: int Number of jobs to use (Default 1) chunk_mb: None or int Chunk size in Mb (default 500Mb) verbose: bool If True, output is verbose """ save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) save_file_path = save_path / raw_fname parent_dir = save_path channel_ids = recording.get_channel_ids() num_chan = recording.get_num_channels() num_frames = recording.get_num_frames() geom = recording.get_channel_locations() if dtype is None: dtype = recording.get_dtype() if dtype == 'float': dtype = 'float32' if dtype == 'int': dtype = 'int16' with save_file_path.open('wb') as f: header = MdaHeader(dt0=dtype, dims0=(num_chan, num_frames)) header.write(f) # takes care of the chunking write_to_binary_dat_format(recording, file_handle=f, dtype=dtype, n_jobs=n_jobs, chunk_size=chunk_size, chunk_mb=chunk_mb, verbose=verbose) params["samplerate"] = float(recording.get_sampling_frequency()) with (parent_dir / params_fname).open('w') as f: json.dump(params, f) np.savetxt(str(parent_dir / geom_fname), geom, delimiter=',') class MdaSortingExtractor(SortingExtractor): extractor_name = 'MdaSorting' installed = True # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "" # error message when not installed def __init__(self, file_path, sampling_frequency=None): SortingExtractor.__init__(self) self._firings_path = file_path self._firings = readmda(self._firings_path) self._max_channels = self._firings[0, :] self._spike_times = self._firings[1, :] self._labels = self._firings[2, :] self._unit_ids = np.unique(self._labels).astype(int) self._sampling_frequency = sampling_frequency self._kwargs = {'file_path': str(Path(file_path).absolute()), 'sampling_frequency': sampling_frequency} def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, 
end_frame=None): inds = np.where( (self._labels == unit_id) & (start_frame <= self._spike_times) & (self._spike_times < end_frame)) return np.rint(self._spike_times[inds]).astype(int) @staticmethod def write_sorting(sorting, save_path, write_primary_channels=False): unit_ids = sorting.get_unit_ids() times_list = [] labels_list = [] primary_channels_list = [] for unit_id in unit_ids: times = sorting.get_unit_spike_train(unit_id=unit_id) times_list.append(times) labels_list.append(np.ones(times.shape) * unit_id) if write_primary_channels: if 'max_channel' in sorting.get_unit_property_names(unit_id): primary_channels_list.append([sorting.get_unit_property(unit_id, 'max_channel')] * times.shape[0]) else: raise ValueError( "Unable to write primary channels because 'max_channel' spike feature not set in unit " + str( unit_id)) else: primary_channels_list.append(np.zeros(times.shape)) all_times = _concatenate(times_list) all_labels = _concatenate(labels_list) all_primary_channels = _concatenate(primary_channels_list) sort_inds = np.argsort(all_times) all_times = all_times[sort_inds] all_labels = all_labels[sort_inds] all_primary_channels = all_primary_channels[sort_inds] L = len(all_times) firings = np.zeros((3, L)) firings[0, :] = all_primary_channels firings[1, :] = all_times firings[2, :] = all_labels writemda64(firings, save_path) def _concatenate(list): if len(list) == 0: return np.array([]) return np.concatenate(list) def read_dataset_params(dsdir, params_fname): fname1 = dsdir / params_fname if not fname1.is_file(): raise Exception('Dataset parameter file does not exist: ' + fname1) with open(fname1) as f: return json.load(f) ================================================ FILE: spikeextractors/extractors/mdaextractors/mdaio.py ================================================ import numpy as np import struct import os import tempfile import traceback from pathlib import Path class MdaHeader: def __init__(self, dt0, dims0): uses64bitdims = (max(dims0) > 2e9) 
self.uses64bitdims = uses64bitdims self.dt_code = _dt_code_from_dt(dt0) self.dt = dt0 self.num_bytes_per_entry = get_num_bytes_per_entry_from_dt(dt0) self.num_dims = len(dims0) self.dimprod = np.prod(dims0) self.dims = dims0 if uses64bitdims: self.header_size = 3 * 4 + self.num_dims * 8 else: self.header_size = (3 + self.num_dims) * 4 def write(self, f): H = self _write_int32(f, H.dt_code) _write_int32(f, H.num_bytes_per_entry) if H.uses64bitdims: _write_int32(f, -H.num_dims) for j in range(0, H.num_dims): _write_int64(f, H.dims[j]) else: _write_int32(f, H.num_dims) for j in range(0, H.num_dims): _write_int32(f, H.dims[j]) def npy_dtype_to_string(dt): str = dt.str[1:] map = { "f2": 'float16', "f4": 'float32', "f8": 'float64', "i1": 'int8', "i2": 'int16', "i4": 'int32', "u2": 'uint16', "u4": 'uint32' } return map[str] class DiskReadMda: def __init__(self, path, header=None): self._npy_mode = False self._path = path if file_extension(path) == '.npy': raise Exception('DiskReadMda implementation has not been tested for npy files') self._npy_mode = True if header: raise Exception('header not allowed in npy mode for DiskReadMda') if header: self._header = header self._header.header_size = 0 else: self._header = _read_header(self._path) def dims(self): if self._npy_mode: A = np.load(self._path, mmap_mode='r') return A.shape return self._header.dims def N1(self): return self.dims()[0] def N2(self): return self.dims()[1] def N3(self): return self.dims()[2] def dt(self): if self._npy_mode: A = np.load(self._path, mmap_mode='r') return npy_dtype_to_string(A.dtype) return self._header.dt def numBytesPerEntry(self): if self._npy_mode: A = np.load(self._path, mmap_mode='r') return A.itemsize return self._header.num_bytes_per_entry def readChunk(self, i1=-1, i2=-1, i3=-1, N1=1, N2=1, N3=1): # print("Reading chunk {} {} {} {} {} {}".format(i1,i2,i3,N1,N2,N3)) if i2 < 0: if self._npy_mode: A = np.load(self._path, mmap_mode='r') return A[:, :, i1:i1 + N1] return 
self._read_chunk_1d(i1, N1) elif i3 < 0: if N1 != self.N1(): print("Unable to support N1 {} != {}".format(N1, self.N1())) return None X = self._read_chunk_1d(i1 + N1 * i2, N1 * N2) if X is None: print('Problem reading chunk from file: ' + self._path) return None if self._npy_mode: A = np.load(self._path, mmap_mode='r') return A[:, i2:i2 + N2] return np.reshape(X, (N1, N2), order='F') else: if N1 != self.N1(): print("Unable to support N1 {} != {}".format(N1, self.N1())) return None if N2 != self.N2(): print("Unable to support N2 {} != {}".format(N2, self.N2())) return None if self._npy_mode: A = np.load(self._path, mmap_mode='r') return A[:, :, i3:i3 + N3] X = self._read_chunk_1d(i1 + N1 * i2 + N1 * N2 * i3, N1 * N2 * N3) return np.reshape(X, (N1, N2, N3), order='F') def _read_chunk_1d(self, i, N): offset = self._header.header_size + self._header.num_bytes_per_entry * i if is_url(self._path): tmp_fname = _download_bytes_to_tmpfile(self._path, offset, offset + self._header.num_bytes_per_entry * N) try: ret = self._read_chunk_1d_helper(tmp_fname, N, offset=0) except: ret = None return ret return self._read_chunk_1d_helper(self._path, N, offset=offset) def _read_chunk_1d_helper(self, path0, N, *, offset): f = open(path0, "rb") try: f.seek(offset) ret = np.fromfile(f, dtype=self._header.dt, count=N) f.close() return ret except Exception as e: # catch *all* exceptions print(e) f.close() return None def is_url(path): return path.startswith('http://') or path.startswith('https://') def _download_bytes_to_tmpfile(url, start, end): try: import requests except: raise Exception('Unable to import module: requests') headers = {"Range": "bytes={}-{}".format(start, end - 1)} r = requests.get(url, headers=headers, stream=True) fd, tmp_fname = tempfile.mkstemp() with open(tmp_fname, 'wb') as f: for chunk in r.iter_content(chunk_size=1024): if chunk: f.write(chunk) return tmp_fname def _read_header(path): if is_url(path): tmp_fname = _download_bytes_to_tmpfile(path, 0, 200) if not 
tmp_fname: raise Exception('Problem downloading bytes from ' + path) try: ret = _read_header(tmp_fname) except: ret = None Path(tmp_fname).unlink() return ret f = open(path, "rb") try: dt_code = _read_int32(f) num_bytes_per_entry = _read_int32(f) num_dims = _read_int32(f) uses64bitdims = False if num_dims < 0: uses64bitdims = True num_dims = -num_dims if num_dims < 1 or num_dims > 6: # allow single dimension as of 12/6/17 print("Invalid number of dimensions: {}".format(num_dims)) f.close() return None dims = [] dimprod = 1 if uses64bitdims: for j in range(0, num_dims): tmp0 = _read_int64(f) dimprod = dimprod * tmp0 dims.append(tmp0) else: for j in range(0, num_dims): tmp0 = _read_int32(f) dimprod = dimprod * tmp0 dims.append(tmp0) dt = _dt_from_dt_code(dt_code) if dt is None: print("Invalid data type code: {}".format(dt_code)) f.close() return None H = MdaHeader(dt, dims) if uses64bitdims: H.uses64bitdims = True H.header_size = 3 * 4 + H.num_dims * 8 f.close() return H except Exception as e: # catch *all* exceptions print(e) f.close() return None def _dt_from_dt_code(dt_code): if dt_code == -2: dt = 'uint8' elif dt_code == -3: dt = 'float32' elif dt_code == -4: dt = 'int16' elif dt_code == -5: dt = 'int32' elif dt_code == -6: dt = 'uint16' elif dt_code == -7: dt = 'float64' elif dt_code == -8: dt = 'uint32' else: dt = None return dt def _dt_code_from_dt(dt): if dt == 'uint8': return -2 if dt == 'float32': return -3 if dt == 'int16': return -4 if dt == 'int32': return -5 if dt == 'uint16': return -6 if dt == 'float64': return -7 if dt == 'uint32': return -8 return None def get_num_bytes_per_entry_from_dt(dt): if dt == 'uint8': return 1 if dt == 'float32': return 4 if dt == 'int16': return 2 if dt == 'int32': return 4 if dt == 'uint16': return 2 if dt == 'float64': return 8 if dt == 'uint32': return 4 return None def readmda_header(path): if file_extension(path) == '.npy': raise Exception('Cannot read mda header for .npy file.') return _read_header(path) def 
_write_header(path, H, rewrite=False): if rewrite: f = open(path, "r+b") else: f = open(path, "wb") try: _write_int32(f, H.dt_code) _write_int32(f, H.num_bytes_per_entry) if H.uses64bitdims: _write_int32(f, -H.num_dims) for j in range(0, H.num_dims): _write_int64(f, H.dims[j]) else: _write_int32(f, H.num_dims) for j in range(0, H.num_dims): _write_int32(f, H.dims[j]) f.close() return True except Exception as e: # catch *all* exceptions print(e) f.close() return False def readmda(path): if file_extension(path) == '.npy': return readnpy(path); H = _read_header(path) if H is None: print("Problem reading header of: {}".format(path)) return None f = open(path, "rb") try: f.seek(H.header_size) # This is how I do the column-major order ret = np.fromfile(f, dtype=H.dt, count=H.dimprod) ret = np.reshape(ret, H.dims, order='F') f.close() return ret except Exception as e: # catch *all* exceptions print(e) f.close() return None def writemda32(X, fname): if file_extension(fname) == '.npy': return writenpy32(X, fname) return _writemda(X, fname, 'float32') def writemda64(X, fname): if file_extension(fname) == '.npy': return writenpy64(X, fname) return _writemda(X, fname, 'float64') def writemda8(X, fname): if file_extension(fname) == '.npy': return writenpy8(X, fname) return _writemda(X, fname, 'uint8') def writemda32i(X, fname): if file_extension(fname) == '.npy': return writenpy32i(X, fname) return _writemda(X, fname, 'int32') def writemda32ui(X, fname): if file_extension(fname) == '.npy': return writenpy32ui(X, fname) return _writemda(X, fname, 'uint32') def writemda16i(X, fname): if file_extension(fname) == '.npy': return writenpy16i(X, fname) return _writemda(X, fname, 'int16') def writemda16ui(X, fname): if file_extension(fname) == '.npy': return writenpy16ui(X, fname) return _writemda(X, fname, 'uint16') def writemda(X, fname, *, dtype): return _writemda(X, fname, dtype) def _writemda(X, fname, dt): num_bytes_per_entry = get_num_bytes_per_entry_from_dt(dt) dt_code = 
_dt_code_from_dt(dt) if dt_code is None: print("Unexpected data type: {}".format(dt)) return False if type(fname) == str: f = open(fname, 'wb') else: f = fname try: _write_int32(f, dt_code) _write_int32(f, num_bytes_per_entry) _write_int32(f, X.ndim) for j in range(0, X.ndim): _write_int32(f, X.shape[j]) # This is how I do column-major order # A=np.reshape(X,X.size,order='F').astype(dt) # A.tofile(f) bytes0 = X.astype(dt).tobytes(order='F') f.write(bytes0) if type(fname) == str: f.close() return True except Exception as e: # catch *all* exceptions traceback.print_exc() print(e) if type(fname) == str: f.close() return False def readnpy(path): return np.load(path) def writenpy8(X, path): return _writenpy(X, path, dtype='int8') def writenpy32(X, path): return _writenpy(X, path, dtype='float32') def writenpy64(X, path): return _writenpy(X, path, dtype='float64') def writenpy16i(X, path): return _writenpy(X, path, dtype='int16') def writenpy16ui(X, path): return _writenpy(X, path, dtype='uint16') def writenpy32i(X, path): return _writenpy(X, path, dtype='int32') def writenpy32ui(X, path): return _writenpy(X, path, dtype='uint32') def writenpy(X, path, *, dtype): return _writenpy(X, path, dtype=dtype) def _writenpy(X, path, *, dtype): np.save(path, X.astype(dtype=dtype, copy=False)) # astype will always create copy if dtype does not match # apparently allowing pickling is a security issue. (according to the docs) ?? 
# np.save(path,X.astype(dtype=dtype,copy=False),allow_pickle=False) # astype will always create copy if dtype does not match return True def appendmda(X, path): if file_extension(path) == '.npy': raise Exception('appendmda not yet implemented for .npy files') H = _read_header(path) if H is None: print("Problem reading header of: {}".format(path)) return None if len(H.dims) != len(X.shape): print("Incompatible number of dimensions in appendmda", H.dims, X.shape) return None num_entries_old = np.product(H.dims) num_dims = len(H.dims) for j in range(num_dims - 1): if X.shape[j] != X.shape[j]: print("Incompatible dimensions in appendmda", H.dims, X.shape) return None H.dims[num_dims - 1] = H.dims[num_dims - 1] + X.shape[num_dims - 1] try: _write_header(path, H, rewrite=True) f = open(path, "r+b") f.seek(H.header_size + H.num_bytes_per_entry * num_entries_old) A = np.reshape(X, X.size, order='F').astype(H.dt) A.tofile(f) f.close() except Exception as e: # catch *all* exceptions print(e) f.close() return False def file_extension(fname): if type(fname) == str: filename, ext = os.path.splitext(fname) return ext else: return None def _read_int32(f): return struct.unpack(' 6: # allow single dimension as of 12/6/17 print("Invalid number of dimensions: {}".format(num_dims)) return None dims = [] dimprod = 1 if uses64bitdims: for j in range(0, num_dims): tmp0 = _read_int64(f) dimprod = dimprod * tmp0 dims.append(tmp0) else: for j in range(0, num_dims): tmp0 = _read_int32(f) dimprod = dimprod * tmp0 dims.append(tmp0) dt = _dt_from_dt_code(dt_code) if dt is None: print("Invalid data type code: {}".format(dt_code)) return None H = MdaHeader(dt, dims) if uses64bitdims: H.uses64bitdims = True H.header_size = 3 * 4 + H.num_dims * 8 return H except Exception as e: # catch *all* exceptions print(e) return None ================================================ FILE: spikeextractors/extractors/mearecextractors/__init__.py ================================================ from 
.mearecextractors import MEArecRecordingExtractor, MEArecSortingExtractor ================================================ FILE: spikeextractors/extractors/mearecextractors/mearecextractors.py ================================================ from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train import numpy as np from pathlib import Path from packaging.version import parse try: import MEArec as mr import neo import quantities as pq if parse(mr.__version__) >= parse('1.5.0'): HAVE_MREX = True else: print("MEArec version requires an update (>=1.5). Please upgrade with 'pip install --upgrade MEArec'") HAVE_MREX = False except ImportError: HAVE_MREX = False class MEArecRecordingExtractor(RecordingExtractor): extractor_name = 'MEArecRecording' has_default_locations = True has_unscaled = False installed = HAVE_MREX # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the MEArec extractors, install MEArec: \n\n pip install MEArec\n\n" # error message when not installed def __init__(self, file_path, locs_2d=True): assert self.installed, self.installed self._recording_path = file_path self._fs = None self._positions = None self._recordings = None self._recgen = None self._locs_2d = locs_2d self._locations = None self._initialize() RecordingExtractor.__init__(self) if self._locations is not None: self.set_channel_locations(self._locations) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'locs_2d': locs_2d} def _initialize(self): self._recgen = mr.load_recordings(recordings=self._recording_path, return_h5_objects=True, check_suffix=False, load=['recordings', 'channel_positions']) self._fs = self._recgen.info['recordings']['fs'] self._recordings = self._recgen.recordings self._num_frames, self._num_channels = self._recordings.shape if len(np.array(self._recgen.channel_positions)) == 
self._num_channels: self._locations = np.array(self._recgen.channel_positions) if self._locs_2d: if 'electrodes' in self._recgen.info.keys(): if 'plane' in self._recgen.info['electrodes'].keys(): probe_plane = self._recgen.info['electrodes']['plane'] if probe_plane == 'xy': self._locations = self._locations[:, :2] elif probe_plane == 'yz': self._locations = self._locations[:, 1:] elif probe_plane == 'xz': self._locations = self._locations[:, [0, 2]] if self._locations.shape[1] == 3: self._locations = self._locations[:, 1:] def get_channel_ids(self): return list(range(self._num_channels)) def get_num_frames(self): return self._num_frames def get_sampling_frequency(self): return self._fs @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): if np.any(np.diff(channel_ids) < 0): sorted_channel_ids = np.sort(channel_ids) sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids]) recordings = self._recordings[start_frame:end_frame, sorted_channel_ids.tolist()] return np.array(recordings[:, sorted_idx]).T else: if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1): channel_ids = slice(channel_ids[0], channel_ids[0] + len(channel_ids)) return np.array(self._recordings[start_frame:end_frame, channel_ids]).T @staticmethod def write_recording(recording, save_path, check_suffix=True): """ Save recording extractor to MEArec format. 
Parameters ---------- recording: RecordingExtractor Recording extractor object to be saved save_path: str .h5 or .hdf5 path """ assert HAVE_MREX, MEArecRecordingExtractor.installation_mesg save_path = Path(save_path) if save_path.is_dir(): print("The file will be saved as recording.h5 in the provided folder") save_path = save_path / 'recording.h5' if (save_path.suffix == '.h5' or save_path.suffix == '.hdf5') or (not check_suffix): info = {'recordings': {'fs': recording.get_sampling_frequency()}} rec_dict = {'recordings': recording.get_traces().transpose()} if 'location' in recording.get_shared_channel_property_names(): positions = recording.get_channel_locations() rec_dict['channel_positions'] = positions recgen = mr.RecordingGenerator(rec_dict=rec_dict, info=info) mr.save_recording_generator(recgen, str(save_path), verbose=False) else: raise Exception("Provide a folder or an .h5/.hdf5 as 'save_path'") class MEArecSortingExtractor(SortingExtractor): extractor_name = 'MEArecSorting' installed = HAVE_MREX # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the MEArec extractors, install MEArec: \n\n pip install MEArec\n\n" # error message when not installed def __init__(self, file_path): assert self.installed, self.installed SortingExtractor.__init__(self) self._recording_path = file_path self._num_units = None self._spike_trains = None self._unit_ids = None self._fs = None self._initialize() self._kwargs = {'file_path': str(Path(file_path).absolute())} def _initialize(self): recgen = mr.load_recordings(recordings=self._recording_path, return_h5_objects=True, check_suffix=False, load=['spiketrains']) self._num_units = len(recgen.spiketrains) if 'unit_id' in recgen.spiketrains[0].annotations: self._unit_ids = [int(st.annotations['unit_id']) for st in recgen.spiketrains] else: self._unit_ids = list(range(self._num_units)) self._spike_trains = recgen.spiketrains self._fs = recgen.info['recordings']['fs'] * pq.Hz # fs 
is in kHz self._sampling_frequency = recgen.info['recordings']['fs'] if 'soma_position' in self._spike_trains[0].annotations: for u, st in zip(self._unit_ids, self._spike_trains): self.set_unit_property(u, 'soma_location', st.annotations['soma_position']) def get_unit_ids(self): if self._unit_ids is None: self._initialize() return self._unit_ids def get_num_units(self): if self._num_units is None: self._initialize() return self._num_units @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): if self._spike_trains is None: self._initialize() times = (self._spike_trains[self.get_unit_ids().index(unit_id)].times.rescale('s') * self._fs.rescale('Hz')).magnitude inds = np.where((start_frame <= times) & (times < end_frame)) return np.rint(times[inds]).astype(int) @staticmethod def write_sorting(sorting, save_path, sampling_frequency, check_suffix=True): """ Save sorting extractor to MEArec format. Parameters ---------- sorting: SortingExtractor Sorting extractor object to be saved save_path: str .h5 or .hdf5 path sampling_frequency: int Sampling frequency in Hz """ assert HAVE_MREX, MEArecSortingExtractor.installation_mesg save_path = Path(save_path) if save_path.is_dir(): print("The file will be saved as sorting.h5 in the provided folder") save_path = save_path / 'sorting.h5' if (save_path.suffix == '.h5' or save_path.suffix == '.hdf5') or (not check_suffix): # create neo spike trains spiketrains = [] for u in sorting.get_unit_ids(): st = neo.SpikeTrain(times=sorting.get_unit_spike_train(u) / float(sampling_frequency) * pq.s, t_start=np.min(sorting.get_unit_spike_train(u) / float(sampling_frequency)) * pq.s, t_stop=np.max(sorting.get_unit_spike_train(u) / float(sampling_frequency)) * pq.s) st.annotate(unit_id=u) spiketrains.append(st) assert len(spiketrains) > 0, """ The sorting for output contains no unit, please check the sorting. 
""" duration = np.max([st.t_stop.magnitude for st in spiketrains]) info = {'recordings': {'fs': sampling_frequency}, 'spiketrains': {'duration': duration}} rec_dict = {'spiketrains': spiketrains} recgen = mr.RecordingGenerator(rec_dict=rec_dict, info=info) mr.save_recording_generator(recgen, str(save_path), verbose=False) else: raise Exception("Provide a folder or an .h5/.hdf5 as 'save_path'") ================================================ FILE: spikeextractors/extractors/neoextractors/__init__.py ================================================ from .plexonextractor import PlexonRecordingExtractor, PlexonSortingExtractor from .neuralynxextractor import NeuralynxRecordingExtractor, NeuralynxSortingExtractor from .mcsrawrecordingextractor import MCSRawRecordingExtractor from .blackrockextractor import BlackrockRecordingExtractor, BlackrockSortingExtractor from .axonaextractor import AxonaRecordingExtractor from .spikegadgetsextractor import SpikeGadgetsRecordingExtractor ================================================ FILE: spikeextractors/extractors/neoextractors/axonaextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class AxonaRecordingExtractor(NeoBaseRecordingExtractor): extractor_name = 'AxonaRecording' mode = 'file' NeoRawIOClass = 'AxonaRawIO' def __init__(self, **kargs): super().__init__(**kargs) # Read channel groups by tetrode IDs self.set_channel_groups(groups=[x - 1 for x in self.neo_reader.raw_annotations[ 'blocks'][0]['segments'][0]['signals'][0]['__array_annotations__']['tetrode_id']]) header_channels = self.neo_reader.header['signal_channels'][slice(None)] names = header_channels['name'] for i, ind in enumerate(self.get_channel_ids()): self.set_channel_property(channel_id=ind, property_name='name', value=names[i]) ================================================ FILE: 
spikeextractors/extractors/neoextractors/blackrockextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor from pathlib import Path from typing import Union, Optional try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False PathType = Union[str, Path] class BlackrockRecordingExtractor(NeoBaseRecordingExtractor): """ The Blackrock extractor is wrapped from neo.rawio.BlackrockRawIO. Parameters ---------- filename: str The Blackrock file (.ns1, .ns2, .ns3, .ns4m .ns4, or .ns6) block_index: None or int If the underlying dataset have several blocks the index must be specified. seg_index_index: None or int If the underlying dataset have several segments the index must be specified. """ extractor_name = 'BlackrockRecording' mode = 'file' installed = HAVE_NEO NeoRawIOClass = 'BlackrockRawIO' def __init__(self, filename: PathType, nsx_to_load: Optional[int] = None, block_index: Optional[int] = None, seg_index: Optional[int] = None, **kwargs): super().__init__(filename=filename, nsx_to_load=nsx_to_load, block_index=block_index, seg_index=seg_index, **kwargs) class BlackrockSortingExtractor(NeoBaseSortingExtractor): """ The Blackrock extractor is wrapped from neo.rawio.BlackrockRawIO. Parameters ---------- filename: str The Blackrock file (.nev) block_index: None or int If the underlying dataset have several blocks the index must be specified. seg_index_index: None or int If the underlying dataset have several segments the index must be specified. 
""" extractor_name = 'BlackrockSorting' mode = 'file' installed = HAVE_NEO NeoRawIOClass = 'BlackrockRawIO' def __init__(self, filename: PathType, nsx_to_load: Optional[int] = None, block_index: Optional[int] = None, seg_index: Optional[int] = None, **kwargs): super().__init__(filename=filename, nsx_to_load=nsx_to_load, block_index=block_index, seg_index=seg_index, **kwargs) ================================================ FILE: spikeextractors/extractors/neoextractors/mcsrawrecordingextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class MCSRawRecordingExtractor(NeoBaseRecordingExtractor): extractor_name='mcsrawRecoding' mode='file' NeoRawIOClass='RawMCSRawIO' ================================================ FILE: spikeextractors/extractors/neoextractors/neobaseextractor.py ================================================ import numpy as np import warnings from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class _NeoBaseExtractor: NeoRawIOClass = None installed = HAVE_NEO is_writable = False has_default_locations = False has_unscaled = True installation_mesg = "To use the Neo extractors, install Neo: \n\n pip install neo\n\n" def __init__(self, block_index=None, seg_index=None, **kargs): """ if block_index is None then check if only one block if seg_index is None then check if only one segment """ assert self.installed, self.installation_mesg neoIOclass = eval('neo.rawio.' 
+ self.NeoRawIOClass) self.neo_reader = neoIOclass(**kargs) self.neo_reader.parse_header() if block_index is None: # auto select first block num_block = self.neo_reader.block_count() assert num_block == 1, 'This file is multi block spikeextractors support only one segment, please provide block_index=' block_index = 0 if seg_index is None: # auto select first segment num_seg = self.neo_reader.segment_count(block_index) assert num_seg == 1, 'This file is multi segment spikeextractors support only one segment, please provide seg_index=' seg_index = 0 self.block_index = block_index self.seg_index = seg_index self._kwargs = kargs self._kwargs.update({'seg_index': seg_index, 'block_index': block_index}) class NeoBaseRecordingExtractor(RecordingExtractor, _NeoBaseExtractor): def __init__(self, block_index=None, seg_index=None, **kargs): RecordingExtractor.__init__(self) _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs) if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'): # Neo >= 0.9.0 channel_indexes_list = self.neo_reader.get_group_signal_channel_indexes() num_streams = len(channel_indexes_list) assert num_streams == 1, 'This file have several channel groups spikeextractors support only one groups' self.after_v10 = False elif hasattr(self.neo_reader, 'get_group_channel_indexes'): # Neo < 0.9.0 channel_indexes_list = self.neo_reader.get_group_channel_indexes() num_streams = len(channel_indexes_list) self.after_v10 = False elif hasattr(self.neo_reader, 'signal_streams_count'): # Neo >= 0.10.0 (not release yet in march 2021) num_streams = self.neo_reader.signal_streams_count() self.after_v10 = True else: raise ValueError('Strange neo version. 
Please upgrade your neo package: pip install --upgrade neo') assert num_streams == 1, 'This file have several signal streams spikeextractors support only one streams' \ 'Maybe you can use option to select only one stream' # spikeextractor for units to be uV implicitly # check that units are V, mV or uV units = self.neo_reader.header['signal_channels']['units'] if not np.all(np.isin(units, ['V', 'mV', 'uV'])): warnings.warn('Signal units no Volt compatible, assuming scaling as uV') self.additional_gain = np.ones(units.size, dtype='float') self.additional_gain[units == 'V'] = 1e6 self.additional_gain[units == 'mV'] = 1e3 self.additional_gain[units == 'uV'] = 1. self.additional_gain[units == ''] = 1. self.additional_gain = self.additional_gain.reshape(1, -1) # Add channels properties header_channels = self.neo_reader.header['signal_channels'][slice(None)] self._neo_chan_ids = self.neo_reader.header['signal_channels']['id'] # In neo there is not guarantee that channel ids are unique. # for instance Blacrock can have several times the same chan_id # different sampling rate # so check it assert np.unique(self._neo_chan_ids).size == self._neo_chan_ids.size, 'In this format channel ids are not ' \ 'unique! 
Incompatible with SpikeInterface' try: channel_ids = [int(ch) for ch in self._neo_chan_ids] except Exception as e: warnings.warn("Could not parse channel ids to int: using linear channel map") channel_ids = list(np.arange(len(self._neo_chan_ids))) self._channel_ids = channel_ids gains = header_channels['gain'] * self.additional_gain[0] self.set_channel_gains(gains=gains, channel_ids=self._channel_ids) names = header_channels['name'] for i, ind in enumerate(self._channel_ids): self.set_channel_property(channel_id=ind, property_name='name', value=names[i]) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): # in neo rawio channel can acces by names/ids/indexes # there is no garranty that ids/names are unique on some formats channel_idxs = [self.get_channel_ids().index(ch) for ch in channel_ids] neo_chan_ids = self._neo_chan_ids[channel_idxs] if self.after_v10: raw_traces = self.neo_reader.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index, i_start=start_frame, i_stop=end_frame, channel_indexes=None, channel_names=None, stream_index=0, channel_ids=neo_chan_ids) else: raw_traces = self.neo_reader.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index, i_start=start_frame, i_stop=end_frame, channel_indexes=None, channel_names=None, channel_ids=neo_chan_ids) # neo works with (samples, channels) strides # so transpose to spikeextractors wolrd return raw_traces.transpose() def get_num_frames(self): # channel_indexes=None means all channels if self.after_v10: n = self.neo_reader.get_signal_size(self.block_index, self.seg_index, stream_index=0) else: n = self.neo_reader.get_signal_size(self.block_index, self.seg_index, channel_indexes=None) return n def get_sampling_frequency(self): # channel_indexes=None means all channels if self.after_v10: sf = self.neo_reader.get_signal_sampling_rate(stream_index=0) else: sf = 
self.neo_reader.get_signal_sampling_rate(channel_indexes=None) return sf def get_channel_ids(self): return self._channel_ids class NeoBaseSortingExtractor(SortingExtractor, _NeoBaseExtractor): def __init__(self, block_index=None, seg_index=None, **kargs): SortingExtractor.__init__(self) _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs) # the sampling frequency is quite tricky because in neo # spike are handle in s or ms # internally many format do have have the spike time stamps # at the same speed as the signal but at a higher clocks speed. # here in spikeinterface we need spike index to be at the same speed # that signal it do not make sens to have spikes at 50kHz sample # when the sig is 10kHz. # neo handle this but not spikeextractors self._handle_sampling_frequency() def _handle_sampling_frequency(self): # bacause neo handle spike in times (s or ms) but spikeextractors in frames related to signals. # In neo spikes can have diffrents sampling rate than signals so conversion from #  signals frames to times is format dependent # here the generic case #  all channels are in the same neo group so if len(self.neo_reader.header['signal_channels']['sampling_rate']) > 0: self._neo_sig_sampling_rate = self.neo_reader.header['signal_channels']['sampling_rate'][0] self.set_sampling_frequency(self._neo_sig_sampling_rate) else: warnings.warn("Sampling frequency not found: setting it to 30 kHz") self._sampling_frequency = 30000 self._neo_sig_sampling_rate = self._sampling_frequency if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'): # Neo >= 0.9.0 if len(self.neo_reader.get_group_signal_channel_indexes()) > 0: self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index, channel_indexes=[0]) else: warnings.warn("Start time not found: setting it to 0 s") self._neo_sig_time_start = 0 elif hasattr(self.neo_reader, 'get_group_channel_indexes'): # Neo < 0.9.0 if 
len(self.neo_reader.get_group_channel_indexes()) > 0: self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index, channel_indexes=[0]) else: warnings.warn("Start time not found: setting it to 0 s") self._neo_sig_time_start = 0 elif hasattr(self.neo_reader, 'signal_streams_count'): # Neo >= 0.10.0 (not release yet in march 2021) num_streams = self.neo_reader.signal_streams_count() if num_streams > 0: self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index, stream_index=0) else: warnings.warn("Start time not found: setting it to 0 s") self._neo_sig_time_start = 0 else: raise ValueError('Strange neo version. Please upgrade your neo package: pip install --upgrade neo') # For some IOs when there is no signals at inside the dataset this could not work # in that case the extractor class must overwrite this method def get_unit_ids(self): # should be this but this is strings in neo #  unit_ids = self.neo_reader.header['unit_channels']['id'] # in neo unit_ids are string so here we take unit_index if 'unit_channels' in self.neo_reader.header: unit_ids = np.arange(self.neo_reader.header['unit_channels'].size, dtype='int64') elif 'spike_channels' in self.neo_reader.header: unit_ids = np.arange(self.neo_reader.header['spike_channels'].size, dtype='int64') else: raise ValueError('Strange neo version. 
Please upgrade your neo package: pip install --upgrade neo') return unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) # this is a string #  neo_unit_id = self.neo_reader.header['unit_channels']['id'][unit_id] # this is an int unit_index = unit_id # in neo can be a sample, or hiher sample rate or even float try: # version >= 0.9.0 spike_timestamps = self.neo_reader.get_spike_timestamps(block_index=self.block_index, seg_index=self.seg_index, spike_channel_index=unit_index, t_start=None, t_stop=None) except TypeError as e: # version < 0.9.0 spike_timestamps = self.neo_reader.get_spike_timestamps(block_index=self.block_index, seg_index=self.seg_index, unit_index=unit_index, t_start=None, t_stop=None) if start_frame is not None: spike_timestamps = spike_timestamps[spike_timestamps >= start_frame] if end_frame is not None: spike_timestamps = spike_timestamps[spike_timestamps <= end_frame] # convert to second second spike_times = self.neo_reader.rescale_spike_timestamp(spike_timestamps, dtype='float64') # convert to sample related to recording signals spike_indexes = ((spike_times - self._neo_sig_time_start) * self._neo_sig_sampling_rate).astype('int64') return spike_indexes ================================================ FILE: spikeextractors/extractors/neoextractors/neuralynxextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class NeuralynxRecordingExtractor(NeoBaseRecordingExtractor): """ The neruralynx extractor is wrapped from neo.rawio.NeuralynxRawIO. Parameters ---------- dirname: str The neuralynx folder that contain all neuralynx files ('nse', 'ncs', 'nev', 'ntt') block_index: None or int If the underlying dataset have several blocks the index must be specified. 
seg_index_index: None or int If the underlying dataset have several segments the index must be specified. """ extractor_name = 'NeuralynxRecording' mode = 'folder' installed = HAVE_NEO NeoRawIOClass = 'NeuralynxRawIO' class NeuralynxSortingExtractor(NeoBaseSortingExtractor): """ The neruralynx extractor is wrapped from neo.rawio.NeuralynxRawIO. Parameters ---------- dirname: str The neuralynx folder that contain all neuralynx files ('nse', 'ncs', 'nev', 'ntt') block_index: None or int If the underlying dataset have several blocks the index must be specified. seg_index_index: None or int If the underlying dataset have several segments the index must be specified. """ extractor_name = 'NeuralynxSorting' mode = 'folder' installed = HAVE_NEO NeoRawIOClass = 'NeuralynxRawIO' ================================================ FILE: spikeextractors/extractors/neoextractors/plexonextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class PlexonRecordingExtractor(NeoBaseRecordingExtractor): """ The plxon extractor is wrapped from neo.rawio.PlexonRawIO. Parameters ---------- filename: str The plexon file ('plx') block_index: None or int If the underlying dataset have several blocks the index must be specified. seg_index_index: None or int If the underlying dataset have several segments the index must be specified. 
""" extractor_name = 'PlexonRecording' mode = 'file' installed = HAVE_NEO NeoRawIOClass = 'PlexonRawIO' class PlexonSortingExtractor(NeoBaseSortingExtractor): extractor_name = 'PlexonSorting' mode = 'file' installed = HAVE_NEO NeoRawIOClass = 'PlexonRawIO' ================================================ FILE: spikeextractors/extractors/neoextractors/spikegadgetsextractor.py ================================================ from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor try: import neo HAVE_NEO = True except ImportError: HAVE_NEO = False class SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor): """ The spikegadgets extractor is wrapped from neo.rawio.SpikegadgetsRawIO. Parameters ---------- filename: str The spike gadgets file ('rec') selected_streams: str The id of the stream to load 'trodes' is ephy channels. Can also be ECU, ... block_index: None or int If the underlying dataset have several blocks the index must be specified. seg_index_index: None or int If the underlying dataset have several segments the index must be specified. 
""" extractor_name = 'SpikeGadgetsRecording' mode = 'file' installed = HAVE_NEO NeoRawIOClass = 'SpikeGadgetsRawIO' def __init__(self, filename, selected_streams='trodes',**kwargs): super().__init__(filename=filename, selected_streams=selected_streams, **kwargs) ================================================ FILE: spikeextractors/extractors/neuropixelsdatrecordingextractor/__init__.py ================================================ from .neuropixelsdatrecordingextractor import NeuropixelsDatRecordingExtractor ================================================ FILE: spikeextractors/extractors/neuropixelsdatrecordingextractor/channel_positions_neuropixels.txt ================================================ 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 2.000000000000000000e+01 2.000000000000000000e+01 4.000000000000000000e+01 4.000000000000000000e+01 6.000000000000000000e+01 6.000000000000000000e+01 8.000000000000000000e+01 8.000000000000000000e+01 1.000000000000000000e+02 1.000000000000000000e+02 1.200000000000000000e+02 1.200000000000000000e+02 1.400000000000000000e+02 1.400000000000000000e+02 1.600000000000000000e+02 1.600000000000000000e+02 1.800000000000000000e+02 1.800000000000000000e+02 2.000000000000000000e+02 2.000000000000000000e+02 2.200000000000000000e+02 2.200000000000000000e+02 2.400000000000000000e+02 2.400000000000000000e+02 2.600000000000000000e+02 2.600000000000000000e+02 2.800000000000000000e+02 2.800000000000000000e+02 3.000000000000000000e+02 3.000000000000000000e+02 3.200000000000000000e+02 3.200000000000000000e+02 3.400000000000000000e+02 3.400000000000000000e+02 3.600000000000000000e+02 3.600000000000000000e+02 3.800000000000000000e+02 3.800000000000000000e+02 4.000000000000000000e+02 4.000000000000000000e+02 4.200000000000000000e+02 4.200000000000000000e+02 4.400000000000000000e+02 4.400000000000000000e+02 4.600000000000000000e+02 4.600000000000000000e+02 4.800000000000000000e+02 4.800000000000000000e+02 5.000000000000000000e+02 5.000000000000000000e+02 5.200000000000000000e+02 
5.200000000000000000e+02 5.400000000000000000e+02 5.400000000000000000e+02 5.600000000000000000e+02 5.600000000000000000e+02 5.800000000000000000e+02 5.800000000000000000e+02 6.000000000000000000e+02 6.000000000000000000e+02 6.200000000000000000e+02 6.200000000000000000e+02 6.400000000000000000e+02 6.400000000000000000e+02 6.600000000000000000e+02 6.600000000000000000e+02 6.800000000000000000e+02 6.800000000000000000e+02 7.000000000000000000e+02 7.000000000000000000e+02 7.200000000000000000e+02 7.200000000000000000e+02 7.400000000000000000e+02 7.400000000000000000e+02 7.600000000000000000e+02 7.600000000000000000e+02 7.800000000000000000e+02 7.800000000000000000e+02 8.000000000000000000e+02 8.000000000000000000e+02 8.200000000000000000e+02 8.200000000000000000e+02 8.400000000000000000e+02 8.400000000000000000e+02 8.600000000000000000e+02 8.600000000000000000e+02 8.800000000000000000e+02 8.800000000000000000e+02 9.000000000000000000e+02 9.000000000000000000e+02 9.200000000000000000e+02 9.200000000000000000e+02 9.400000000000000000e+02 9.400000000000000000e+02 9.600000000000000000e+02 9.600000000000000000e+02 9.800000000000000000e+02 9.800000000000000000e+02 1.000000000000000000e+03 1.000000000000000000e+03 1.020000000000000000e+03 1.020000000000000000e+03 1.040000000000000000e+03 1.040000000000000000e+03 1.060000000000000000e+03 1.060000000000000000e+03 1.080000000000000000e+03 1.080000000000000000e+03 1.100000000000000000e+03 1.100000000000000000e+03 1.120000000000000000e+03 1.120000000000000000e+03 1.140000000000000000e+03 1.140000000000000000e+03 1.160000000000000000e+03 1.160000000000000000e+03 1.180000000000000000e+03 1.180000000000000000e+03 1.200000000000000000e+03 1.200000000000000000e+03 1.220000000000000000e+03 1.220000000000000000e+03 1.240000000000000000e+03 1.240000000000000000e+03 1.260000000000000000e+03 1.260000000000000000e+03 1.280000000000000000e+03 1.280000000000000000e+03 1.300000000000000000e+03 1.300000000000000000e+03 1.320000000000000000e+03 
1.320000000000000000e+03 1.340000000000000000e+03 1.340000000000000000e+03 1.360000000000000000e+03 1.360000000000000000e+03 1.380000000000000000e+03 1.380000000000000000e+03 1.400000000000000000e+03 1.400000000000000000e+03 1.420000000000000000e+03 1.420000000000000000e+03 1.440000000000000000e+03 1.440000000000000000e+03 1.460000000000000000e+03 1.460000000000000000e+03 1.480000000000000000e+03 1.480000000000000000e+03 1.500000000000000000e+03 1.500000000000000000e+03 1.520000000000000000e+03 1.520000000000000000e+03 1.540000000000000000e+03 1.540000000000000000e+03 1.560000000000000000e+03 1.560000000000000000e+03 1.580000000000000000e+03 1.580000000000000000e+03 1.600000000000000000e+03 1.600000000000000000e+03 1.620000000000000000e+03 1.620000000000000000e+03 1.640000000000000000e+03 1.640000000000000000e+03 1.660000000000000000e+03 1.660000000000000000e+03 1.680000000000000000e+03 1.680000000000000000e+03 1.700000000000000000e+03 1.700000000000000000e+03 1.720000000000000000e+03 1.720000000000000000e+03 1.740000000000000000e+03 1.740000000000000000e+03 1.760000000000000000e+03 1.760000000000000000e+03 1.780000000000000000e+03 1.780000000000000000e+03 1.800000000000000000e+03 1.800000000000000000e+03 1.820000000000000000e+03 1.820000000000000000e+03 1.840000000000000000e+03 1.840000000000000000e+03 1.860000000000000000e+03 1.860000000000000000e+03 1.880000000000000000e+03 1.880000000000000000e+03 1.900000000000000000e+03 1.900000000000000000e+03 1.920000000000000000e+03 1.920000000000000000e+03 1.940000000000000000e+03 1.940000000000000000e+03 1.960000000000000000e+03 1.960000000000000000e+03 1.980000000000000000e+03 1.980000000000000000e+03 2.000000000000000000e+03 2.000000000000000000e+03 2.020000000000000000e+03 2.020000000000000000e+03 2.040000000000000000e+03 2.040000000000000000e+03 2.060000000000000000e+03 2.060000000000000000e+03 2.080000000000000000e+03 2.080000000000000000e+03 2.100000000000000000e+03 2.100000000000000000e+03 2.120000000000000000e+03 
2.120000000000000000e+03 2.140000000000000000e+03 2.140000000000000000e+03 2.160000000000000000e+03 2.160000000000000000e+03 2.180000000000000000e+03 2.180000000000000000e+03 2.200000000000000000e+03 2.200000000000000000e+03 2.220000000000000000e+03 2.220000000000000000e+03 2.240000000000000000e+03 2.240000000000000000e+03 2.260000000000000000e+03 2.260000000000000000e+03 2.280000000000000000e+03 2.280000000000000000e+03 2.300000000000000000e+03 2.300000000000000000e+03 2.320000000000000000e+03 2.320000000000000000e+03 2.340000000000000000e+03 2.340000000000000000e+03 2.360000000000000000e+03 2.360000000000000000e+03 2.380000000000000000e+03 2.380000000000000000e+03 2.400000000000000000e+03 2.400000000000000000e+03 2.420000000000000000e+03 2.420000000000000000e+03 2.440000000000000000e+03 2.440000000000000000e+03 2.460000000000000000e+03 2.460000000000000000e+03 2.480000000000000000e+03 2.480000000000000000e+03 2.500000000000000000e+03 2.500000000000000000e+03 2.520000000000000000e+03 2.520000000000000000e+03 2.540000000000000000e+03 2.540000000000000000e+03 2.560000000000000000e+03 2.560000000000000000e+03 2.580000000000000000e+03 2.580000000000000000e+03 2.600000000000000000e+03 2.600000000000000000e+03 2.620000000000000000e+03 2.620000000000000000e+03 2.640000000000000000e+03 2.640000000000000000e+03 2.660000000000000000e+03 2.660000000000000000e+03 2.680000000000000000e+03 2.680000000000000000e+03 2.700000000000000000e+03 2.700000000000000000e+03 2.720000000000000000e+03 2.720000000000000000e+03 2.740000000000000000e+03 2.740000000000000000e+03 2.760000000000000000e+03 2.760000000000000000e+03 2.780000000000000000e+03 2.780000000000000000e+03 2.800000000000000000e+03 2.800000000000000000e+03 2.820000000000000000e+03 2.820000000000000000e+03 2.840000000000000000e+03 2.840000000000000000e+03 2.860000000000000000e+03 2.860000000000000000e+03 2.880000000000000000e+03 2.880000000000000000e+03 2.900000000000000000e+03 2.900000000000000000e+03 2.920000000000000000e+03 
2.920000000000000000e+03 2.940000000000000000e+03 2.940000000000000000e+03 2.960000000000000000e+03 2.960000000000000000e+03 2.980000000000000000e+03 2.980000000000000000e+03 3.000000000000000000e+03 3.000000000000000000e+03 3.020000000000000000e+03 3.020000000000000000e+03 3.040000000000000000e+03 3.040000000000000000e+03 3.060000000000000000e+03 3.060000000000000000e+03 3.080000000000000000e+03 3.080000000000000000e+03 3.100000000000000000e+03 3.100000000000000000e+03 3.120000000000000000e+03 3.120000000000000000e+03 3.140000000000000000e+03 3.140000000000000000e+03 3.160000000000000000e+03 3.160000000000000000e+03 3.180000000000000000e+03 3.180000000000000000e+03 3.200000000000000000e+03 3.200000000000000000e+03 3.220000000000000000e+03 3.220000000000000000e+03 3.240000000000000000e+03 3.240000000000000000e+03 3.260000000000000000e+03 3.260000000000000000e+03 3.280000000000000000e+03 3.280000000000000000e+03 3.300000000000000000e+03 3.300000000000000000e+03 3.320000000000000000e+03 3.320000000000000000e+03 3.340000000000000000e+03 3.340000000000000000e+03 3.360000000000000000e+03 3.360000000000000000e+03 3.380000000000000000e+03 3.380000000000000000e+03 3.400000000000000000e+03 3.400000000000000000e+03 3.420000000000000000e+03 3.420000000000000000e+03 3.440000000000000000e+03 3.440000000000000000e+03 3.460000000000000000e+03 3.460000000000000000e+03 3.480000000000000000e+03 3.480000000000000000e+03 3.500000000000000000e+03 3.500000000000000000e+03 3.520000000000000000e+03 3.520000000000000000e+03 3.540000000000000000e+03 3.540000000000000000e+03 3.560000000000000000e+03 3.560000000000000000e+03 3.580000000000000000e+03 3.580000000000000000e+03 3.600000000000000000e+03 3.600000000000000000e+03 3.620000000000000000e+03 3.620000000000000000e+03 3.640000000000000000e+03 3.640000000000000000e+03 3.660000000000000000e+03 3.660000000000000000e+03 3.680000000000000000e+03 3.680000000000000000e+03 3.700000000000000000e+03 3.700000000000000000e+03 3.720000000000000000e+03 
3.720000000000000000e+03 3.740000000000000000e+03 3.740000000000000000e+03 3.760000000000000000e+03 3.760000000000000000e+03 3.780000000000000000e+03 3.780000000000000000e+03 3.800000000000000000e+03 3.800000000000000000e+03 3.820000000000000000e+03 3.820000000000000000e+03 3.840000000000000000e+03 3.840000000000000000e+03 3.860000000000000000e+03 3.860000000000000000e+03 3.880000000000000000e+03 3.880000000000000000e+03 3.900000000000000000e+03 3.900000000000000000e+03 3.920000000000000000e+03 3.920000000000000000e+03 3.940000000000000000e+03 3.940000000000000000e+03 3.960000000000000000e+03 3.960000000000000000e+03 3.980000000000000000e+03 3.980000000000000000e+03 4.000000000000000000e+03 4.000000000000000000e+03 4.020000000000000000e+03 4.020000000000000000e+03 4.040000000000000000e+03 4.040000000000000000e+03 4.060000000000000000e+03 4.060000000000000000e+03 4.080000000000000000e+03 4.080000000000000000e+03 4.100000000000000000e+03 4.100000000000000000e+03 4.120000000000000000e+03 4.120000000000000000e+03 4.140000000000000000e+03 4.140000000000000000e+03 4.160000000000000000e+03 4.160000000000000000e+03 4.180000000000000000e+03 4.180000000000000000e+03 4.200000000000000000e+03 4.200000000000000000e+03 4.220000000000000000e+03 4.220000000000000000e+03 4.240000000000000000e+03 4.240000000000000000e+03 4.260000000000000000e+03 4.260000000000000000e+03 4.280000000000000000e+03 4.280000000000000000e+03 4.300000000000000000e+03 4.300000000000000000e+03 4.320000000000000000e+03 4.320000000000000000e+03 4.340000000000000000e+03 4.340000000000000000e+03 4.360000000000000000e+03 4.360000000000000000e+03 4.380000000000000000e+03 4.380000000000000000e+03 4.400000000000000000e+03 4.400000000000000000e+03 4.420000000000000000e+03 4.420000000000000000e+03 4.440000000000000000e+03 4.440000000000000000e+03 4.460000000000000000e+03 4.460000000000000000e+03 4.480000000000000000e+03 4.480000000000000000e+03 4.500000000000000000e+03 4.500000000000000000e+03 4.520000000000000000e+03 
4.520000000000000000e+03 4.540000000000000000e+03 4.540000000000000000e+03 4.560000000000000000e+03 4.560000000000000000e+03 4.580000000000000000e+03 4.580000000000000000e+03 4.600000000000000000e+03 4.600000000000000000e+03 4.620000000000000000e+03 4.620000000000000000e+03 4.640000000000000000e+03 4.640000000000000000e+03 4.660000000000000000e+03 4.660000000000000000e+03 4.680000000000000000e+03 4.680000000000000000e+03 4.700000000000000000e+03 4.700000000000000000e+03 4.720000000000000000e+03 4.720000000000000000e+03 4.740000000000000000e+03 4.740000000000000000e+03 4.760000000000000000e+03 4.760000000000000000e+03 4.780000000000000000e+03 4.780000000000000000e+03 4.800000000000000000e+03 4.800000000000000000e+03 4.820000000000000000e+03 4.820000000000000000e+03 4.840000000000000000e+03 4.840000000000000000e+03 4.860000000000000000e+03 4.860000000000000000e+03 4.880000000000000000e+03 4.880000000000000000e+03 4.900000000000000000e+03 4.900000000000000000e+03 4.920000000000000000e+03 4.920000000000000000e+03 4.940000000000000000e+03 4.940000000000000000e+03 4.960000000000000000e+03 4.960000000000000000e+03 4.980000000000000000e+03 4.980000000000000000e+03 5.000000000000000000e+03 5.000000000000000000e+03 5.020000000000000000e+03 5.020000000000000000e+03 5.040000000000000000e+03 5.040000000000000000e+03 5.060000000000000000e+03 5.060000000000000000e+03 5.080000000000000000e+03 5.080000000000000000e+03 5.100000000000000000e+03 5.100000000000000000e+03 5.120000000000000000e+03 5.120000000000000000e+03 5.140000000000000000e+03 5.140000000000000000e+03 5.160000000000000000e+03 5.160000000000000000e+03 5.180000000000000000e+03 5.180000000000000000e+03 5.200000000000000000e+03 5.200000000000000000e+03 5.220000000000000000e+03 5.220000000000000000e+03 5.240000000000000000e+03 5.240000000000000000e+03 5.260000000000000000e+03 5.260000000000000000e+03 5.280000000000000000e+03 5.280000000000000000e+03 5.300000000000000000e+03 5.300000000000000000e+03 5.320000000000000000e+03 
5.320000000000000000e+03 5.340000000000000000e+03 5.340000000000000000e+03 5.360000000000000000e+03 5.360000000000000000e+03 5.380000000000000000e+03 5.380000000000000000e+03 5.400000000000000000e+03 5.400000000000000000e+03 5.420000000000000000e+03 5.420000000000000000e+03 5.440000000000000000e+03 5.440000000000000000e+03 5.460000000000000000e+03 5.460000000000000000e+03 5.480000000000000000e+03 5.480000000000000000e+03 5.500000000000000000e+03 5.500000000000000000e+03 5.520000000000000000e+03 5.520000000000000000e+03 5.540000000000000000e+03 5.540000000000000000e+03 5.560000000000000000e+03 5.560000000000000000e+03 5.580000000000000000e+03 5.580000000000000000e+03 5.600000000000000000e+03 5.600000000000000000e+03 5.620000000000000000e+03 5.620000000000000000e+03 5.640000000000000000e+03 5.640000000000000000e+03 5.660000000000000000e+03 5.660000000000000000e+03 5.680000000000000000e+03 5.680000000000000000e+03 5.700000000000000000e+03 5.700000000000000000e+03 5.720000000000000000e+03 5.720000000000000000e+03 5.740000000000000000e+03 5.740000000000000000e+03 5.760000000000000000e+03 5.760000000000000000e+03 5.780000000000000000e+03 5.780000000000000000e+03 5.800000000000000000e+03 5.800000000000000000e+03 5.820000000000000000e+03 5.820000000000000000e+03 5.840000000000000000e+03 5.840000000000000000e+03 5.860000000000000000e+03 5.860000000000000000e+03 5.880000000000000000e+03 5.880000000000000000e+03 5.900000000000000000e+03 5.900000000000000000e+03 5.920000000000000000e+03 5.920000000000000000e+03 5.940000000000000000e+03 5.940000000000000000e+03 5.960000000000000000e+03 5.960000000000000000e+03 5.980000000000000000e+03 5.980000000000000000e+03 6.000000000000000000e+03 6.000000000000000000e+03 6.020000000000000000e+03 6.020000000000000000e+03 6.040000000000000000e+03 6.040000000000000000e+03 6.060000000000000000e+03 6.060000000000000000e+03 6.080000000000000000e+03 6.080000000000000000e+03 6.100000000000000000e+03 6.100000000000000000e+03 6.120000000000000000e+03 
6.120000000000000000e+03 6.140000000000000000e+03 6.140000000000000000e+03 6.160000000000000000e+03 6.160000000000000000e+03 6.180000000000000000e+03 6.180000000000000000e+03 6.200000000000000000e+03 6.200000000000000000e+03 6.220000000000000000e+03 6.220000000000000000e+03 6.240000000000000000e+03 6.240000000000000000e+03 6.260000000000000000e+03 6.260000000000000000e+03 6.280000000000000000e+03 6.280000000000000000e+03 6.300000000000000000e+03 6.300000000000000000e+03 6.320000000000000000e+03 6.320000000000000000e+03 6.340000000000000000e+03 6.340000000000000000e+03 6.360000000000000000e+03 6.360000000000000000e+03 6.380000000000000000e+03 6.380000000000000000e+03 6.400000000000000000e+03 6.400000000000000000e+03 6.420000000000000000e+03 6.420000000000000000e+03 6.440000000000000000e+03 6.440000000000000000e+03 6.460000000000000000e+03 6.460000000000000000e+03 6.480000000000000000e+03 6.480000000000000000e+03 6.500000000000000000e+03 6.500000000000000000e+03 6.520000000000000000e+03 6.520000000000000000e+03 6.540000000000000000e+03 6.540000000000000000e+03 6.560000000000000000e+03 6.560000000000000000e+03 6.580000000000000000e+03 6.580000000000000000e+03 6.600000000000000000e+03 6.600000000000000000e+03 6.620000000000000000e+03 6.620000000000000000e+03 6.640000000000000000e+03 6.640000000000000000e+03 6.660000000000000000e+03 6.660000000000000000e+03 6.680000000000000000e+03 6.680000000000000000e+03 6.700000000000000000e+03 6.700000000000000000e+03 6.720000000000000000e+03 6.720000000000000000e+03 6.740000000000000000e+03 6.740000000000000000e+03 6.760000000000000000e+03 6.760000000000000000e+03 6.780000000000000000e+03 6.780000000000000000e+03 6.800000000000000000e+03 6.800000000000000000e+03 6.820000000000000000e+03 6.820000000000000000e+03 6.840000000000000000e+03 6.840000000000000000e+03 6.860000000000000000e+03 6.860000000000000000e+03 6.880000000000000000e+03 6.880000000000000000e+03 6.900000000000000000e+03 6.900000000000000000e+03 6.920000000000000000e+03 
6.920000000000000000e+03 6.940000000000000000e+03 6.940000000000000000e+03 6.960000000000000000e+03 6.960000000000000000e+03 6.980000000000000000e+03 6.980000000000000000e+03 7.000000000000000000e+03 7.000000000000000000e+03 7.020000000000000000e+03 7.020000000000000000e+03 7.040000000000000000e+03 7.040000000000000000e+03 7.060000000000000000e+03 7.060000000000000000e+03 7.080000000000000000e+03 7.080000000000000000e+03 7.100000000000000000e+03 7.100000000000000000e+03 7.120000000000000000e+03 7.120000000000000000e+03 7.140000000000000000e+03 7.140000000000000000e+03 7.160000000000000000e+03 7.160000000000000000e+03 7.180000000000000000e+03 7.180000000000000000e+03 7.200000000000000000e+03 7.200000000000000000e+03 7.220000000000000000e+03 7.220000000000000000e+03 7.240000000000000000e+03 7.240000000000000000e+03 7.260000000000000000e+03 7.260000000000000000e+03 7.280000000000000000e+03 7.280000000000000000e+03 7.300000000000000000e+03 7.300000000000000000e+03 7.320000000000000000e+03 7.320000000000000000e+03 7.340000000000000000e+03 7.340000000000000000e+03 7.360000000000000000e+03 7.360000000000000000e+03 7.380000000000000000e+03 7.380000000000000000e+03 7.400000000000000000e+03 7.400000000000000000e+03 7.420000000000000000e+03 7.420000000000000000e+03 7.440000000000000000e+03 7.440000000000000000e+03 7.460000000000000000e+03 7.460000000000000000e+03 7.480000000000000000e+03 7.480000000000000000e+03 7.500000000000000000e+03 7.500000000000000000e+03 7.520000000000000000e+03 7.520000000000000000e+03 7.540000000000000000e+03 7.540000000000000000e+03 7.560000000000000000e+03 7.560000000000000000e+03 7.580000000000000000e+03 7.580000000000000000e+03 7.600000000000000000e+03 7.600000000000000000e+03 7.620000000000000000e+03 7.620000000000000000e+03 7.640000000000000000e+03 7.640000000000000000e+03 7.660000000000000000e+03 7.660000000000000000e+03 7.680000000000000000e+03 7.680000000000000000e+03 7.700000000000000000e+03 7.700000000000000000e+03 7.720000000000000000e+03 
7.720000000000000000e+03 7.740000000000000000e+03 7.740000000000000000e+03 7.760000000000000000e+03 7.760000000000000000e+03 7.780000000000000000e+03 7.780000000000000000e+03 7.800000000000000000e+03 7.800000000000000000e+03 7.820000000000000000e+03 7.820000000000000000e+03 7.840000000000000000e+03 7.840000000000000000e+03 7.860000000000000000e+03 7.860000000000000000e+03 7.880000000000000000e+03 7.880000000000000000e+03 7.900000000000000000e+03 7.900000000000000000e+03 7.920000000000000000e+03 7.920000000000000000e+03 7.940000000000000000e+03 7.940000000000000000e+03 7.960000000000000000e+03 7.960000000000000000e+03 7.980000000000000000e+03 7.980000000000000000e+03 8.000000000000000000e+03 8.000000000000000000e+03 8.020000000000000000e+03 8.020000000000000000e+03 8.040000000000000000e+03 8.040000000000000000e+03 8.060000000000000000e+03 8.060000000000000000e+03 8.080000000000000000e+03 8.080000000000000000e+03 8.100000000000000000e+03 8.100000000000000000e+03 8.120000000000000000e+03 8.120000000000000000e+03 8.140000000000000000e+03 8.140000000000000000e+03 8.160000000000000000e+03 8.160000000000000000e+03 8.180000000000000000e+03 8.180000000000000000e+03 8.200000000000000000e+03 8.200000000000000000e+03 8.220000000000000000e+03 8.220000000000000000e+03 8.240000000000000000e+03 8.240000000000000000e+03 8.260000000000000000e+03 8.260000000000000000e+03 8.280000000000000000e+03 8.280000000000000000e+03 8.300000000000000000e+03 8.300000000000000000e+03 8.320000000000000000e+03 8.320000000000000000e+03 8.340000000000000000e+03 8.340000000000000000e+03 8.360000000000000000e+03 8.360000000000000000e+03 8.380000000000000000e+03 8.380000000000000000e+03 8.400000000000000000e+03 8.400000000000000000e+03 8.420000000000000000e+03 8.420000000000000000e+03 8.440000000000000000e+03 8.440000000000000000e+03 8.460000000000000000e+03 8.460000000000000000e+03 8.480000000000000000e+03 8.480000000000000000e+03 8.500000000000000000e+03 8.500000000000000000e+03 8.520000000000000000e+03 
8.520000000000000000e+03 8.540000000000000000e+03 8.540000000000000000e+03 8.560000000000000000e+03 8.560000000000000000e+03 8.580000000000000000e+03 8.580000000000000000e+03 8.600000000000000000e+03 8.600000000000000000e+03 8.620000000000000000e+03 8.620000000000000000e+03 8.640000000000000000e+03 8.640000000000000000e+03 8.660000000000000000e+03 8.660000000000000000e+03 8.680000000000000000e+03 8.680000000000000000e+03 8.700000000000000000e+03 8.700000000000000000e+03 8.720000000000000000e+03 8.720000000000000000e+03 8.740000000000000000e+03 8.740000000000000000e+03 8.760000000000000000e+03 8.760000000000000000e+03 8.780000000000000000e+03 8.780000000000000000e+03 8.800000000000000000e+03 8.800000000000000000e+03 8.820000000000000000e+03 8.820000000000000000e+03 8.840000000000000000e+03 8.840000000000000000e+03 8.860000000000000000e+03 8.860000000000000000e+03 8.880000000000000000e+03 8.880000000000000000e+03 8.900000000000000000e+03 8.900000000000000000e+03 8.920000000000000000e+03 8.920000000000000000e+03 8.940000000000000000e+03 8.940000000000000000e+03 8.960000000000000000e+03 8.960000000000000000e+03 8.980000000000000000e+03 8.980000000000000000e+03 9.000000000000000000e+03 9.000000000000000000e+03 9.020000000000000000e+03 9.020000000000000000e+03 9.040000000000000000e+03 9.040000000000000000e+03 9.060000000000000000e+03 9.060000000000000000e+03 9.080000000000000000e+03 9.080000000000000000e+03 9.100000000000000000e+03 9.100000000000000000e+03 9.120000000000000000e+03 9.120000000000000000e+03 9.140000000000000000e+03 9.140000000000000000e+03 9.160000000000000000e+03 9.160000000000000000e+03 9.180000000000000000e+03 9.180000000000000000e+03 9.200000000000000000e+03 9.200000000000000000e+03 9.220000000000000000e+03 9.220000000000000000e+03 9.240000000000000000e+03 9.240000000000000000e+03 9.260000000000000000e+03 9.260000000000000000e+03 9.280000000000000000e+03 9.280000000000000000e+03 9.300000000000000000e+03 9.300000000000000000e+03 9.320000000000000000e+03 
9.320000000000000000e+03 9.340000000000000000e+03 9.340000000000000000e+03 9.360000000000000000e+03 9.360000000000000000e+03 9.380000000000000000e+03 9.380000000000000000e+03 9.400000000000000000e+03 9.400000000000000000e+03 9.420000000000000000e+03 9.420000000000000000e+03 9.440000000000000000e+03 9.440000000000000000e+03 9.460000000000000000e+03 9.460000000000000000e+03 9.480000000000000000e+03 9.480000000000000000e+03 9.500000000000000000e+03 9.500000000000000000e+03 9.520000000000000000e+03 9.520000000000000000e+03 9.540000000000000000e+03 9.540000000000000000e+03 9.560000000000000000e+03 9.560000000000000000e+03 9.580000000000000000e+03 9.580000000000000000e+03 9.600000000000000000e+03 9.600000000000000000e+03 ================================================ FILE: spikeextractors/extractors/neuropixelsdatrecordingextractor/neuropixelsdatrecordingextractor.py ================================================ from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor import numpy as np from pathlib import Path import warnings try: import xmltodict HAVE_XMLTODICT = True except ImportError: HAVE_XMLTODICT = False class NeuropixelsDatRecordingExtractor(BinDatRecordingExtractor): """ Read raw Neurpoixels recordings from Open Ephys dat file and settings.xml This extractor is currently compatible with the 960 channel Neuropixels probes, where a maximum of 384 channels are recorded simulatenously. The array configuration can be specified by passing the settings.xml file created by OpenEphys (it can be found in the directory tree with teh recordings). If this is not provided, the default configuration using 384 channels at the probe tip will be used (a is warning printed). Parameters ---------- file_path: str The raw data file (usually continuous.dat) settings_file: None or str The file settings.xml generated by OpenEphys containing the array configuration. 
If not provided the default configuration using 384 channels at the probe tip will be used. verbose: bool Print probe configuration """ extractor_name = 'NeuropixelsDatRecording' has_default_locations = True has_unscaled = True installed = HAVE_XMLTODICT is_writable = False mode = 'file' installation_mesg = "To use the NeuropixelsDat extractor, install xmltodict: \n\n pip install xmltodict\n\n" def __init__(self, file_path, settings_file=None, is_filtered=None, verbose=False): assert self.installed, self.installation_mesg source_dir = Path(__file__).parent self._settings_file = settings_file datfile = Path(file_path) time_axis = 0 dtype = 'int16' sampling_frequency = float(30000) channel_locations = np.loadtxt(source_dir / 'channel_positions_neuropixels.txt') if self._settings_file is not None: with open(self._settings_file) as f: xmldata = f.read() settings = xmltodict.parse(xmldata)['SETTINGS'] channel_info = settings['SIGNALCHAIN']['PROCESSOR'][0]['CHANNEL_INFO'] channels = settings['SIGNALCHAIN']['PROCESSOR'][0]['CHANNEL'] recorded_channels = [] for c in channels: if c['SELECTIONSTATE']['@record'] == '1': recorded_channels.append(int(c['@number'])) used_channels = [] used_channel_gains = [] for c in channel_info['CHANNEL']: if 'AP' in c['@name'] and int(c['@number']) in recorded_channels: used_channels.append(int(c['@number'])) used_channel_gains.append(float(c['@gain'])) if verbose: print(f'{len(recorded_channels)} total channels found, with {len(used_channels)} recording AP') print(f'Channels used:\n{used_channels}') numchan = len(used_channels) geom = channel_locations[:, np.array(used_channels)].T gain = used_channel_gains[0] channels = used_channels else: warnings.warn("No information about this recording available," "using a default of 384 channels at the probe tip." 
"If the recording differs, use settings_file=settings.xml") numchan = 384 geom = channel_locations[:, :384].T gain = None channels = range(384) BinDatRecordingExtractor.__init__(self, file_path=datfile, numchan=numchan, dtype=dtype, sampling_frequency=sampling_frequency, gain=gain, geom=geom, recording_channels=channels, time_axis=time_axis, is_filtered=is_filtered) self._kwargs = {'filename': str(Path(file_path).absolute()), 'settings_file': settings_file, 'is_filtered': is_filtered} ================================================ FILE: spikeextractors/extractors/neuroscopeextractors/__init__.py ================================================ from .neuroscopeextractors import NeuroscopeRecordingExtractor, NeuroscopeMultiRecordingTimeExtractor from .neuroscopeextractors import NeuroscopeSortingExtractor, NeuroscopeMultiSortingExtractor ================================================ FILE: spikeextractors/extractors/neuroscopeextractors/neuroscopeextractors.py ================================================ from spikeextractors import RecordingExtractor, MultiRecordingTimeExtractor, SortingExtractor, MultiSortingExtractor from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor import numpy as np from pathlib import Path from spikeextractors.extraction_tools import check_get_unit_spike_train, get_sub_extractors_by_property from typing import Union, Optional import re import warnings try: from lxml import etree as et HAVE_LXML = True except ImportError: HAVE_LXML = False PathType = Union[str, Path] OptionalPathType = Optional[PathType] DtypeType = Union[str, np.dtype, None] def get_single_files(folder_path: Path, suffix: str): return [ f for f in folder_path.iterdir() if f.is_file() and suffix in f.suffixes and not f.name.endswith("~") and len(f.suffixes) == 1 ] def get_shank_files(folder_path: Path, suffix: str): return [ f for f in folder_path.iterdir() if f.is_file() and suffix in f.suffixes and re.search(r"\d+$", f.name) is not 
None and len(f.suffixes) == 2 ] def find_xml_file_path(folder_path: PathType): xml_files = [f for f in folder_path.iterdir() if f.is_file() if f.suffix == ".xml"] assert any(xml_files), "No .xml files found in the folder_path." assert len(xml_files) == 1, "More than one .xml file found in the folder_path! Specify xml_file_path." xml_file_path = xml_files[0] return xml_file_path def handle_xml_file_path(folder_path: PathType, initial_xml_file_path: PathType): if initial_xml_file_path is None: xml_file_path = find_xml_file_path(folder_path=folder_path) else: assert Path(initial_xml_file_path).is_file(), f".xml file ({initial_xml_file_path}) not found!" xml_file_path = initial_xml_file_path return xml_file_path class NeuroscopeRecordingExtractor(BinDatRecordingExtractor): """ Extracts raw neural recordings from binary .dat files in the neuroscope format. The recording extractor always returns channel IDs starting from 0. The recording data will always be returned in the shape of (num_channels,num_frames). Parameters ---------- file_path : str Path to the .dat file to be extracted. gain : float, optional Numerical value that converts the native int dtype to microvolts. Defaults to 1. xml_file_path : PathType, optional Path to the .xml file referenced by this recording. """ extractor_name = "NeuroscopeRecordingExtractor" installed = HAVE_LXML has_default_locations = False has_unscaled = False is_writable = True mode = "file" installation_mesg = "Please install lxml to use this extractor!" def __init__(self, file_path: PathType, gain: Optional[float] = None, xml_file_path: OptionalPathType = None): assert self.installed, self.installation_mesg file_path = Path(file_path) assert file_path.is_file() and file_path.suffix in [".dat", ".eeg", ".lfp"], \ "file_path must lead to a .dat or .eeg file!" 
RecordingExtractor.__init__(self) self._recording_file = file_path xml_file_path = handle_xml_file_path(folder_path=Path(file_path).parent, initial_xml_file_path=xml_file_path) xml_root = et.parse(str(xml_file_path)).getroot() n_bits = int(xml_root.find('acquisitionSystem').find('nBits').text) dtype = f"int{n_bits}" numchan_from_file = int(xml_root.find('acquisitionSystem').find('nChannels').text) if file_path.suffix == ".dat": sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text) else: sampling_frequency = float(xml_root.find('fieldPotentials').find('lfpSamplingRate').text) BinDatRecordingExtractor.__init__(self, file_path, sampling_frequency=sampling_frequency, dtype=dtype, numchan=numchan_from_file, gain=gain) self._kwargs = dict(file_path=str(Path(file_path).absolute()), gain=gain) @staticmethod def write_recording( recording: RecordingExtractor, save_path: PathType, dtype: DtypeType = None, **write_binary_kwargs ): """ Convert and save the recording extractor to Neuroscope format. Parameters ---------- recording: RecordingExtractor The recording extractor to be converted and saved. save_path: str Path to desired target folder. The name of the files will be the same as the final directory. dtype: dtype Optional. Data type to be used in writing; must be int16 or int32 (default). Will throw a warning if stored recording type from get_traces() does not match. 
**write_binary_kwargs: keyword arguments for write_to_binary_dat_format function - chunk_size - chunk_mb """ save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == "": recording_name = save_path.name else: recording_name = save_path.stem xml_name = recording_name save_xml_filepath = save_path / f"{xml_name}.xml" recording_filepath = save_path / recording_name # create parameters file if none exists if save_xml_filepath.is_file(): raise FileExistsError(f"{save_xml_filepath} already exists!") xml_root = et.Element('xml') et.SubElement(xml_root, 'acquisitionSystem') et.SubElement(xml_root.find('acquisitionSystem'), 'nBits') et.SubElement(xml_root.find('acquisitionSystem'), 'nChannels') et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate') recording_dtype = str(recording.get_dtype()) int_loc = recording_dtype.find('int') recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)] valid_dtype = ["16", "32"] if dtype is None: if int_loc != -1 and recording_n_bits in valid_dtype: n_bits = recording_n_bits else: print("Warning: Recording data type must be int16 or int32! Defaulting to int32.") n_bits = "32" dtype = f"int{n_bits}" # update dtype in pass to BinDatRecordingExtractor.write_recording else: dtype = str(dtype) # if user passed numpy data type int_loc = dtype.find('int') assert int_loc != -1, "Data type must be int16 or int32! Non-integer received." n_bits = dtype[(int_loc + 3):(int_loc + 5)] assert n_bits in valid_dtype, "Data type must be int16 or int32!" 
xml_root.find('acquisitionSystem').find('nBits').text = n_bits xml_root.find('acquisitionSystem').find('nChannels').text = str(recording.get_num_channels()) xml_root.find('acquisitionSystem').find('samplingRate').text = str(recording.get_sampling_frequency()) et.ElementTree(xml_root).write(str(save_xml_filepath), pretty_print=True) recording.write_to_binary_dat_format(recording_filepath, dtype=dtype, **write_binary_kwargs) class NeuroscopeMultiRecordingTimeExtractor(MultiRecordingTimeExtractor): """ Extracts raw neural recordings from several binary .dat files in the neuroscope format. The recording extractor always returns channel IDs starting from 0. The recording data will always be returned in the shape of (num_channels,num_frames). Parameters ---------- folder_path : PathType Path to the .dat files to be extracted. gain : float, optional Numerical value that converts the native int dtype to microvolts. Defaults to 1. xml_file_path : PathType, optional Path to the .xml file referenced by this recording. """ extractor_name = "NeuroscopeMultiRecordingTimeExtractor" installed = HAVE_LXML is_writable = True mode = "folder" installation_mesg = "Please install lxml to use this extractor!" def __init__(self, folder_path: PathType, gain: Optional[float] = None, xml_file_path: OptionalPathType = None): assert self.installed, self.installation_mesg folder_path = Path(folder_path) recording_files = [x for x in folder_path.iterdir() if x.is_file() and x.suffix == ".dat"] assert any(recording_files), "The folder_path must lead to at least one .dat file!" 
recordings = [NeuroscopeRecordingExtractor(file_path=x, gain=gain, xml_file_path=xml_file_path) for x in recording_files] MultiRecordingTimeExtractor.__init__(self, recordings=recordings) self._kwargs = dict(folder_path=str(folder_path.absolute()), gain=gain) @staticmethod def write_recording( recording: Union[MultiRecordingTimeExtractor, RecordingExtractor], save_path: PathType, dtype: DtypeType = None, **write_binary_kwargs ): """ Convert and save the recording extractor to Neuroscope format. Parameters ---------- recording: MultiRecordingTimeExtractor or RecordingExtractor The recording extractor to be converted and saved. save_path: str Path to desired target folder. The name of the files will be the same as the final directory. dtype: dtype Optional. Data type to be used in writing; must be int16 or int32 (default). Will throw a warning if stored recording type from get_traces() does not match. **write_binary_kwargs: keyword arguments for write_to_binary_dat_format function - chunk_size - chunk_mb """ save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == "": recording_name = save_path.name else: recording_name = save_path.stem xml_name = recording_name save_xml_filepath = save_path / f"{xml_name}.xml" if save_xml_filepath.is_file(): raise FileExistsError(f"{save_xml_filepath} already exists!") recording_dtype = str(recording.get_dtype()) int_loc = recording_dtype.find("int") recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)] valid_int_types = ["16", "32"] if dtype is None: if int_loc != -1 and recording_n_bits in valid_int_types: n_bits = recording_n_bits else: warnings.warn("Recording data type must be int16 or int32! Defaulting to int32.") n_bits = "32" dtype = f"int{n_bits}" else: dtype = str(dtype) int_loc = dtype.find('int') assert int_loc != -1, "Data type must be int16 or int32! Non-integer received." 
n_bits = dtype[(int_loc + 3):(int_loc + 5)] assert n_bits in valid_int_types, "Data type must be int16 or int32!" xml_root = et.Element('xml') et.SubElement(xml_root, 'acquisitionSystem') et.SubElement(xml_root.find('acquisitionSystem'), 'nBits') et.SubElement(xml_root.find('acquisitionSystem'), 'nChannels') et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate') xml_root.find('acquisitionSystem').find('nBits').text = n_bits xml_root.find('acquisitionSystem').find('nChannels').text = str(recording.get_num_channels()) xml_root.find('acquisitionSystem').find('samplingRate').text = str(recording.get_sampling_frequency()) et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True) if isinstance(recording, MultiRecordingTimeExtractor): for n, record in enumerate(recording.recordings): epoch_id = str(n).zfill(2) # Neuroscope seems to zero-pad length 2 record.write_to_binary_dat_format( save_path=save_path / f"{recording_name}-{epoch_id}.dat", dtype=dtype, **write_binary_kwargs ) elif isinstance(recording, RecordingExtractor): recordings = [recording.get_epoch(epoch_name=epoch_name) for epoch_name in recording.get_epoch_names()] if len(recordings) == 0: recording.write_to_binary_dat_format( save_path=save_path / f"{recording_name}.dat", dtype=dtype, **write_binary_kwargs ) else: for n, subrecording in enumerate(recordings): epoch_id = str(n).zfill(2) # Neuroscope seems to zero-pad length 2 subrecording.write_to_binary_dat_format( save_path=save_path / f"{recording_name}-{epoch_id}.dat", dtype=dtype, **write_binary_kwargs ) class NeuroscopeSortingExtractor(SortingExtractor): """ Extracts spiking information from pair of .res and .clu files. The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer '%i') units. 
The .clu file is a file with one more row than the .res with the first row corresponding to the total number of unique ids in the file (and may exclude 0 & 1 from this count) with the rest of the rows indicating which unit id the corresponding entry in the .res file refers to. In the original Neuroscope format: Unit ID 0 is the cluster of unsorted spikes (noise). Unit ID 1 is a cluster of multi-unit spikes. The function defaults to returning multi-unit activity as the first index, and ignoring unsorted noise. To return only the fully sorted units, set keep_mua_units=False. The sorting extractor always returns unit IDs from 1, ..., number of chosen clusters. Parameters ---------- resfile_path : PathType Optional. Path to a particular .res text file. clufile_path : PathType Optional. Path to a particular .clu text file. folder_path : PathType Optional. Path to the collection of .res and .clu text files. Will auto-detect format. keep_mua_units : bool Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True. spkfile_path : PathType Optional. Path to a particular .spk binary file containing waveform snippets added to the extractor as features. gain : float Optional. If passing a spkfile_path, this value converts the data type of the waveforms to units of microvolts. xml_file_path : PathType, optional Path to the .xml file referenced by this sorting. """ extractor_name = "NeuroscopeSortingExtractor" installed = HAVE_LXML is_writable = True mode = "custom" installation_mesg = "Please install lxml to use this extractor!" 
def __init__( self, resfile_path: OptionalPathType = None, clufile_path: OptionalPathType = None, folder_path: OptionalPathType = None, keep_mua_units: bool = True, spkfile_path: OptionalPathType = None, gain: Optional[float] = None, xml_file_path: OptionalPathType = None, ): assert self.installed, self.installation_mesg assert not (folder_path is None and resfile_path is None and clufile_path is None), \ "Either pass a single folder_path location, or a pair of resfile_path and clufile_path! None received." if resfile_path is not None: assert clufile_path is not None, "If passing resfile_path or clufile_path, both are required!" resfile_path = Path(resfile_path) clufile_path = Path(clufile_path) assert resfile_path.is_file() and clufile_path.is_file(), \ f"The resfile_path ({resfile_path}) and clufile_path ({clufile_path}) must be .res and .clu files!" assert folder_path is None, "Pass either a single folder_path location, " \ "or a pair of resfile_path and clufile_path! All received." folder_path_passed = False folder_path = resfile_path.parent else: assert folder_path is not None, "Either pass resfile_path and clufile_path, or folder_path!" folder_path = Path(folder_path) assert folder_path.is_dir(), "The folder_path must be a directory!" res_files = get_single_files(folder_path=folder_path, suffix=".res") clu_files = get_single_files(folder_path=folder_path, suffix=".clu") assert len(res_files) > 0 or len(clu_files) > 0, \ "No .res or .clu files found in the folder_path!" assert len(res_files) == 1 and len(clu_files) == 1, \ "NeuroscopeSortingExtractor expects a single pair of .res and .clu files in the folder_path. " \ "For multiple .res and .clu files, use the NeuroscopeMultiSortingExtractor instead." 
folder_path_passed = True # flag for setting kwargs for proper dumping resfile_path = res_files[0] clufile_path = clu_files[0] SortingExtractor.__init__(self) res_sorting_name = resfile_path.name[:resfile_path.name.find('.res')] clu_sorting_name = clufile_path.name[:clufile_path.name.find('.clu')] assert res_sorting_name == clu_sorting_name, "The .res and .clu files do not share the same name! " \ f"{res_sorting_name} -- {clu_sorting_name}" xml_file_path = handle_xml_file_path(folder_path=folder_path, initial_xml_file_path=xml_file_path) xml_root = et.parse(str(xml_file_path)).getroot() self._sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text) with open(resfile_path) as f: res = np.array([int(line) for line in f], np.int64) with open(clufile_path) as f: clu = np.array([int(line) for line in f], np.int64) n_spikes = len(res) if n_spikes > 0: # Extract the number of unique IDs from the first line of the clufile then remove it from the list n_clu = clu[0] clu = np.delete(clu, 0) unique_ids = np.unique(clu) assert len(unique_ids) == n_clu, ( "First value of .clu file ({clufile_path}) does not match number of unique IDs!" 
) unit_map = dict(zip(unique_ids, list(range(1, n_clu + 1)))) if 0 in unique_ids: unit_map.pop(0) if not keep_mua_units and 1 in unique_ids: unit_map.pop(1) self._unit_ids = unit_map.values() self._spiketrains = [] for s_id in unit_map: self._spiketrains.append(res[(clu == s_id).nonzero()]) if spkfile_path is not None and Path(spkfile_path).is_file(): n_bits = int(xml_root.find('acquisitionSystem').find('nBits').text) dtype = f"int{n_bits}" n_samples = int(xml_root.find('neuroscope').find('spikes').find('nSamples').text) wf = np.moveaxis(np.memmap(spkfile_path, dtype=dtype).reshape(n_spikes, n_samples, -1), 1, -1) for unit_id in self.get_unit_ids(): if gain is not None: self.set_unit_property(unit_id=unit_id, property_name='gain', value=gain) self.set_unit_spike_features( unit_id=unit_id, feature_name='waveforms', value=wf[clu == unit_id + 1 - int(keep_mua_units), :, :] ) if folder_path_passed: self._kwargs = dict( resfile_path=None, clufile_path=None, folder_path=str(folder_path.absolute()), keep_mua_units=keep_mua_units, gain=gain ) else: self._kwargs = dict( resfile_path=str(resfile_path.absolute()), clufile_path=str(clufile_path.absolute()), folder_path=None, keep_mua_units=keep_mua_units, gain=gain ) if spkfile_path is not None: self._kwargs.update(spkfile_path=str(spkfile_path.absolute())) else: self._kwargs.update(spkfile_path=spkfile_path) def get_unit_ids(self): return list(self._unit_ids) def get_sampling_frequency(self): return self._sampling_frequency def shift_unit_ids(self, shift): self._unit_ids = [x + shift for x in self._unit_ids] def add_unit(self, unit_id, spike_times): """This function adds a new unit with the given spike times. Parameters ---------- unit_id: int The unit_id of the unit to be added. 
""" self._unit_ids.append(unit_id) self._spiketrains.append(spike_times) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spiketrains[self.get_unit_ids().index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] @staticmethod def write_sorting(sorting: SortingExtractor, save_path: PathType): # if multiple groups, use the NeuroscopeMultiSortingExtactor write function if 'group' in sorting.get_shared_unit_property_names(): NeuroscopeMultiSortingExtractor.write_sorting(sorting, save_path) else: save_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == '': sorting_name = save_path.name else: sorting_name = save_path.stem xml_name = sorting_name save_xml_filepath = save_path / (str(xml_name) + '.xml') # create parameters file if none exists if save_xml_filepath.is_file(): raise FileExistsError(f'{save_xml_filepath} already exists!') xml_root = et.Element('xml') et.SubElement(xml_root, 'acquisitionSystem') et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate') xml_root.find('acquisitionSystem').find('samplingRate').text = str(sorting.get_sampling_frequency()) et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True) # Create and save .res and .clu files from the current sorting object save_res = save_path / f'{sorting_name}.res' save_clu = save_path / f'{sorting_name}.clu' res, clu = _extract_res_clu_arrays(sorting) np.savetxt(save_res, res, fmt='%i') np.savetxt(save_clu, clu, fmt='%i') class NeuroscopeMultiSortingExtractor(MultiSortingExtractor): """ Extracts spiking information from an arbitrary number of .res.%i and .clu.%i files in the general folder path. The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer '%i') units. 
The .clu file is a file with one more row than the .res with the first row corresponding to the total number of unique ids in the file (and may exclude 0 & 1 from this count) with the rest of the rows indicating which unit id the corresponding entry in the .res file refers to. The group id is loaded as unit property 'group'. In the original Neuroscope format: Unit ID 0 is the cluster of unsorted spikes (noise). Unit ID 1 is a cluster of multi-unit spikes. The function defaults to returning multi-unit activity as the first index, and ignoring unsorted noise. To return only the fully sorted units, set keep_mua_units=False. The sorting extractor always returns unit IDs from 1, ..., number of chosen clusters. Parameters ---------- folder_path : str Optional. Path to the collection of .res and .clu text files. Will auto-detect format. keep_mua_units : bool Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True. exclude_shanks : list Optional. List of indices to ignore. The set of all possible indices is chosen by default, extracted as the final integer of all the .res.%i and .clu.%i pairs. load_waveforms : bool Optional. If True, extracts waveform data from .spk.%i files in the path corresponding to the .res.%i and .clue.%i files and sets these as unit spike features. Defaults to False. gain : float Optional. If passing a spkfile_path, this value converts the data type of the waveforms to units of microvolts. xml_file_path : PathType, optional Path to the .xml file referenced by this sorting. """ extractor_name = "NeuroscopeMultiSortingExtractor" installed = HAVE_LXML is_writable = True mode = "folder" installation_mesg = "Please install lxml to use this extractor!" 
def __init__( self, folder_path: PathType, keep_mua_units: bool = True, exclude_shanks: Optional[list] = None, load_waveforms: bool = False, gain: Optional[float] = None, xml_file_path: OptionalPathType = None, ): assert self.installed, self.installation_mesg folder_path = Path(folder_path) if exclude_shanks is not None: # dumping checks do not like having an empty list as default assert all([isinstance(x, (int, np.integer)) and x >= 0 for x in exclude_shanks]), "Optional argument 'exclude_shanks' must contain positive integers only!" exclude_shanks_passed = True else: exclude_shanks = [] exclude_shanks_passed = False xml_file_path = handle_xml_file_path(folder_path=folder_path, initial_xml_file_path=xml_file_path) xml_root = et.parse(str(xml_file_path)).getroot() self._sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text) res_files = get_shank_files(folder_path=folder_path, suffix=".res") clu_files = get_shank_files(folder_path=folder_path, suffix=".clu") assert len(res_files) > 0 or len(clu_files) > 0, "No .res or .clu files found in the folder_path!" assert len(res_files) == len(clu_files) res_ids = [int(x.suffix[1:]) for x in res_files] clu_ids = [int(x.suffix[1:]) for x in clu_files] assert sorted(res_ids) == sorted(clu_ids), "Unmatched .clu.%i and .res.%i files detected!" if any([x not in res_ids for x in exclude_shanks]): warnings.warn("Detected indices in exclude_shanks that are not in the directory! These will be ignored.") resfile_names = [x.name[:x.name.find('.res')] for x in res_files] clufile_names = [x.name[:x.name.find('.clu')] for x in clu_files] assert np.all(r == c for (r, c) in zip(resfile_names, clufile_names)), \ "Some of the .res.%i and .clu.%i files do not share the same name!" 
sorting_name = resfile_names[0] all_shanks_list_se = [] for shank_id in list(set(res_ids) - set(exclude_shanks)): nse_args = dict( resfile_path=folder_path / f"{sorting_name}.res.{shank_id}", clufile_path=folder_path / f"{sorting_name}.clu.{shank_id}", keep_mua_units=keep_mua_units, xml_file_path=xml_file_path, ) if load_waveforms: spk_files = get_shank_files(folder_path=folder_path, suffix=".spk") assert len(spk_files) > 0, "No .spk files found in the folder_path, but 'write_waveforms' is True!" assert len(spk_files) == len(res_files), "Mismatched number of .spk and .res files!" spk_ids = [int(x.suffix[1:]) for x in spk_files] assert sorted(spk_ids) == sorted(res_ids), "Unmatched .spk.%i and .res.%i files detected!" spkfile_names = [x.name[:x.name.find('.spk')] for x in spk_files] assert np.all(s == r for (s, r) in zip(spkfile_names, resfile_names)), \ "Some of the .spk.%i and .res.%i files do not share the same name!" nse_args.update(spkfile_path=folder_path / f"{sorting_name}.spk.{shank_id}", gain=gain) all_shanks_list_se.append(NeuroscopeSortingExtractor(**nse_args)) MultiSortingExtractor.__init__(self, sortings=all_shanks_list_se) if exclude_shanks_passed: self._kwargs = dict( folder_path=str(folder_path.absolute()), keep_mua_units=keep_mua_units, exclude_shanks=exclude_shanks, load_waveforms=load_waveforms, gain=gain ) else: self._kwargs = dict( folder_path=str(folder_path.absolute()), keep_mua_units=keep_mua_units, exclude_shanks=None, load_waveforms=load_waveforms, gain=gain ) @staticmethod def write_sorting(sorting: Union[MultiSortingExtractor, SortingExtractor], save_path: PathType): save_path = Path(save_path) if save_path.suffix == '': sorting_name = save_path.name else: sorting_name = save_path.stem xml_name = sorting_name save_xml_filepath = save_path / (str(xml_name) + '.xml') assert not save_path.is_file(), "Argument 'save_path' should be a folder!" 
save_path.mkdir(parents=True, exist_ok=True) if save_xml_filepath.is_file(): raise FileExistsError(f"{save_xml_filepath} already exists!") xml_root = et.Element('xml') et.SubElement(xml_root, 'acquisitionSystem') et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate') xml_root.find('acquisitionSystem').find('samplingRate').text = str(sorting.get_sampling_frequency()) et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True) if isinstance(sorting, MultiSortingExtractor): counter = 1 for sort in sorting.sortings: # Create and save .res.%i and .clu.%i files from the current sorting object; files are numbered from 1 in the order of sorting.sortings save_res = save_path / f"{sorting_name}.res.{counter}" save_clu = save_path / f"{sorting_name}.clu.{counter}" counter += 1 res, clu = _extract_res_clu_arrays(sort) np.savetxt(save_res, res, fmt="%i") np.savetxt(save_clu, clu, fmt="%i") elif isinstance(sorting, SortingExtractor): # assert units have group property assert 'group' in sorting.get_shared_unit_property_names() sortings, groups = get_sub_extractors_by_property(sorting, 'group', return_property_list=True) for (sort, group) in zip(sortings, groups): # Create and save .res.%i and .clu.%i files from the current sorting object; files are numbered by the 'group' value save_res = save_path / f"{sorting_name}.res.{group}" save_clu = save_path / f"{sorting_name}.clu.{group}" res, clu = _extract_res_clu_arrays(sort) np.savetxt(save_res, res, fmt="%i") np.savetxt(save_clu, clu, fmt="%i") def _extract_res_clu_arrays(sorting): """ Build the Neuroscope .res and .clu arrays for one sorting. Returns (res, clu): res holds the spike frames of all units sorted in time; clu holds, for each of those spikes, the 1-based position of its unit in get_unit_ids(), preceded by the number of distinct cluster ids present. Both are empty lists when the sorting has no units. """ unit_ids = sorting.get_unit_ids() if len(unit_ids) > 0: spiketrains = [sorting.get_unit_spike_train(u) for u in unit_ids] res = np.concatenate(spiketrains).ravel() clu = np.concatenate( [np.repeat(i + 1, len(st)) for i, st in enumerate(spiketrains)]).ravel() # i counts from 0; the +1 is necessary because the written cluster ids run 1,...,nUnits res_sort = np.argsort(res) res = res[res_sort] clu = clu[res_sort] unique_ids = np.unique(clu) n_clu = len(unique_ids) clu = np.insert(clu, 0, n_clu) # the first .clu row holds the number of clusters else: res, clu =
[], [] return res, clu ================================================ FILE: spikeextractors/extractors/nixioextractors/__init__.py ================================================ from .nixioextractors import NIXIORecordingExtractor, NIXIOSortingExtractor ================================================ FILE: spikeextractors/extractors/nixioextractors/nixioextractors.py ================================================ import os import numpy as np from collections.abc import Iterable from pathlib import Path try: import nixio as nix HAVE_NIXIO = True except ImportError: HAVE_NIXIO = False from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train class NIXIORecordingExtractor(RecordingExtractor): extractor_name = 'NIXIORecording' has_default_locations = False has_unscaled = False installed = HAVE_NIXIO is_writable = True installation_mesg = "To use the NIXIORecordingExtractor install nixio: \n\n pip install nixio\n\n" mode = 'file' def __init__(self, file_path): assert self.installed, self.installation_mesg file_path = str(file_path) RecordingExtractor.__init__(self) self._file = nix.File.open(file_path, nix.FileMode.ReadOnly) self._load_properties() self._kwargs = {'file_path': str(Path(file_path).absolute())} def __del__(self): self._file.close() @property def _traces(self): blk = self._file.blocks[0] da = blk.data_arrays["traces"] return da def get_channel_ids(self): da = self._traces channel_dim = da.dimensions[0] channel_ids = [int(chid) for chid in channel_dim.labels] return channel_ids def get_num_frames(self): da = self._traces return da.shape[1] def get_sampling_frequency(self): da = self._traces timedim = da.dimensions[1] sampling_frequency = 1./timedim.sampling_interval return sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): channels = 
np.array([self._traces[cid] for cid in channel_ids]) return channels[:, start_frame:end_frame] def _load_properties(self): traces_md = self._traces.metadata if traces_md is None: # no metadata stored return for chan_md in traces_md.sections: chan_id = int(chan_md.name) for prop in chan_md.props: values = prop.values if self._file.version <= (1, 1, 0): values = [v.value for v in prop.values] if len(values) == 1: values = values[0] self.set_channel_property(chan_id, prop.name, values) @staticmethod def write_recording(recording, save_path, overwrite=False): assert HAVE_NIXIO, NIXIORecordingExtractor.installation_mesg if os.path.exists(save_path) and not overwrite: raise FileExistsError("File exists: {}".format(save_path)) nf = nix.File.open(save_path, nix.FileMode.Overwrite) # use the file name to name the top-level block fname = os.path.basename(save_path) block = nf.create_block(fname, "spikeinterface.recording") da = block.create_data_array("traces", "spikeinterface.traces", data=recording.get_traces()) da.unit = "uV" da.label = "voltage" labels = recording.get_channel_ids() if not labels: # channel IDs not specified; just number them labels = list(range(recording.get_num_channels())) chandim = da.append_set_dimension() chandim.labels = labels sfreq = recording.get_sampling_frequency() timedim = da.append_sampled_dimension(sampling_interval=1./sfreq) timedim.unit = "s" # In NIX, channel properties are stored as follows # Traces metadata (nix.Section) # | # |--- Channel 0 (nix.Section) # | | # | |---- Location (nix.Property) # | | # | |---- Other property a (nix.Property) # | | # | `---- Other property b (nix.Property) # | # `--- Channel 1 (nix.Section) # | # |---- Location (nix.Property) # | # |---- Other property a (nix.Property) # | # `---- Other property b (nix.Property) traces_md = nf.create_section("traces.metadata", "spikeinterface.properties") da.metadata = traces_md channels = recording.get_channel_ids() for chan_id in channels: chan_md = 
traces_md.create_section(str(chan_id), "spikeinterface.properties") for propname in recording.get_channel_property_names(chan_id): propvalue = recording.get_channel_property(chan_id, propname) if nf.version <= (1, 1, 0): if isinstance(propvalue, Iterable): values = list(map(nix.Value, propvalue)) else: values = nix.Value(propvalue) else: values = propvalue chan_md.create_property(propname, values) nf.close() class NIXIOSortingExtractor(SortingExtractor): extractor_name = 'NIXIOSorting' installed = HAVE_NIXIO is_writable = True installation_mesg = "To use the NIXIORecordingExtractor install nixio: \n\n pip install nixio\n\n" mode = 'file' def __init__(self, file_path): assert self.installed, self.installation_mesg file_path = str(file_path) SortingExtractor.__init__(self) self._file = nix.File.open(file_path, nix.FileMode.ReadOnly) md = self._file.sections if "sampling_frequency" in md: sfreq = md["sampling_frequency"] self._sampling_frequency = sfreq self._load_properties() self._kwargs = {'file_path': str(Path(file_path).absolute())} def __del__(self): self._file.close() @property def _spike_das(self): blk = self._file.blocks[0] return blk.data_arrays def get_unit_ids(self): return [int(da.label) for da in self._spike_das] @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): name = "spikes-{}".format(unit_id) da = self._spike_das[name] if np.isfinite(end_frame): return da[start_frame:end_frame] else: return da[start_frame:] def _load_properties(self): spikes_md = self._spike_das[0].metadata if spikes_md is None: # no metadata stored return for unit_md in spikes_md.sections: unit_id = int(unit_md.name) for prop in unit_md.props: values = prop.values if self._file.version <= (1, 1, 0): values = [v.value for v in prop.values] if len(values) == 1: values = values[0] self.set_unit_property(unit_id, prop.name, values) @staticmethod def write_sorting(sorting, save_path, overwrite=False): assert HAVE_NIXIO, 
NIXIOSortingExtractor.installation_mesg if os.path.exists(save_path) and not overwrite: raise FileExistsError("File exists: {}".format(save_path)) sfreq = sorting.get_sampling_frequency() if sfreq is None: unit = None elif sfreq == 1: unit = "s" else: unit = "{} s".format(1./sfreq) nf = nix.File.open(save_path, nix.FileMode.Overwrite) # use the file name to name the top-level block fname = os.path.basename(save_path) block = nf.create_block(fname, "spikeinterface.sorting") commonmd = nf.create_section(fname, "spikeinterface.sorting.metadata") if sfreq is not None: commonmd["sampling_frequency"] = sfreq spikes_das = list() for unit_id in sorting.get_unit_ids(): spikes = sorting.get_unit_spike_train(unit_id) name = "spikes-{}".format(unit_id) da = block.create_data_array(name, "spikeinterface.spikes", data=spikes) da.unit = unit da.label = str(unit_id) spikes_das.append(da) spikes_md = nf.create_section("spikes.metadata", "spikeinterface.properties") for da in spikes_das: da.metadata = spikes_md units = sorting.get_unit_ids() for unit_id in units: unit_md = spikes_md.create_section(str(unit_id), "spikeinterface.properties") for propname in sorting.get_unit_property_names(unit_id): propvalue = sorting.get_unit_property(unit_id, propname) if nf.version <= (1, 1, 0): if isinstance(propvalue, Iterable): values = list(map(nix.Value, propvalue)) else: values = nix.Value(propvalue) else: values = propvalue unit_md.create_property(propname, values) nf.close() ================================================ FILE: spikeextractors/extractors/npzsortingextractor/__init__.py ================================================ from .npzsortingextractor import NpzSortingExtractor ================================================ FILE: spikeextractors/extractors/npzsortingextractor/npzsortingextractor.py ================================================ from spikeextractors import SortingExtractor from pathlib import Path from spikeextractors.extraction_tools import 
check_get_unit_spike_train import numpy as np class NpzSortingExtractor(SortingExtractor): """ Dead simple and super light format based on the NPZ numpy format. https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html#numpy.savez It is in fact an arichive of several .npy format. All spike are store in two columns maner index+labels """ extractor_name = 'NpzSorting' installed = True # depend only on numpy installation_mesg = "Always installed" is_writable = True mode = 'file' def __init__(self, file_path): SortingExtractor.__init__(self) self.npz_filename = file_path npz = np.load(file_path) self.unit_ids = npz['unit_ids'] self.spike_indexes = npz['spike_indexes'] self.spike_labels = npz['spike_labels'] if 'sampling_frequency' in npz: self._sampling_frequency = float(npz['sampling_frequency'][0]) else: self._sampling_frequency = None self._kwargs = {'file_path': str(Path(file_path).absolute())} def get_unit_ids(self): return list(self.unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): spike_times = self.spike_indexes[self.spike_labels == unit_id] if start_frame is not None: spike_times = spike_times[spike_times >= start_frame] if end_frame is not None: spike_times = spike_times[spike_times < end_frame] return spike_times.astype('int64') @staticmethod def write_sorting(sorting, save_path): d = {} units_ids = np.array(sorting.get_unit_ids()) d['unit_ids'] = units_ids spike_indexes = [] spike_labels = [] for unit_id in units_ids: sp_ind = sorting.get_unit_spike_train(unit_id) spike_indexes.append(sp_ind) spike_labels.append(np.ones(sp_ind.size, dtype='int64')*unit_id) # order times if len(spike_indexes) > 0: spike_indexes = np.concatenate(spike_indexes) spike_labels = np.concatenate(spike_labels) order = np.argsort(spike_indexes) spike_indexes = spike_indexes[order] spike_labels = spike_labels[order] else: spike_indexes = np.array([], dtype='int64') spike_labels = np.array([], dtype='int64') 
d['spike_indexes'] = spike_indexes d['spike_labels'] = spike_labels if sorting.get_sampling_frequency() is not None: d['sampling_frequency'] = np.array([sorting.get_sampling_frequency()], dtype='float64') np.savez(save_path, **d) ================================================ FILE: spikeextractors/extractors/numpyextractors/__init__.py ================================================ from .numpyextractors import NumpyRecordingExtractor, NumpySortingExtractor ================================================ FILE: spikeextractors/extractors/numpyextractors/numpyextractors.py ================================================ from spikeextractors import RecordingExtractor from spikeextractors import SortingExtractor from pathlib import Path import numpy as np from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train, check_get_ttl_args """ The NumpyExtractors can be constructed and used to encapsulate custom file formats and data structures which contain information about recordings or sorting results. NumpyExtractors are instantiated in-memory and function like any other Recording/SortingExtractor. 
""" class NumpyRecordingExtractor(RecordingExtractor): extractor_name = 'NumpyRecording' is_writable = True has_default_locations = False has_unscaled = False def __init__(self, timeseries, sampling_frequency, geom=None): RecordingExtractor.__init__(self) if isinstance(timeseries, str): if Path(timeseries).is_file(): assert Path(timeseries).suffix == '.npy', "'timeseries' file is not a numpy file (.npy)" self.is_dumpable = True self._timeseries = np.load(timeseries) self._kwargs = {'timeseries': str(Path(timeseries).absolute()), 'sampling_frequency': sampling_frequency, 'geom': geom} else: raise ValueError("'timeeseries' is does not exist") elif isinstance(timeseries, np.ndarray): self.is_dumpable = False self._timeseries = timeseries self._kwargs = {'timeseries': timeseries, 'sampling_frequency': sampling_frequency, 'geom': geom} else: raise TypeError("'timeseries' can be a str or a numpy array") self._sampling_frequency = float(sampling_frequency) self._geom = geom if geom is not None: self.set_channel_locations(self._geom) self._ttl_frames = None self._ttl_states = None def set_ttls(self, ttl_frames, ttl_states=None): self._ttl_frames = ttl_frames.astype('int64') if ttl_states is not None: self._ttl_states = ttl_states.astype('int64') else: self._ttl_states = np.ones_like(ttl_frames, dtype='int64') def get_channel_ids(self): return list(range(self._timeseries.shape[0])) def get_num_frames(self): return self._timeseries.shape[1] def get_sampling_frequency(self): return self._sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): recordings = self._timeseries[:, start_frame:end_frame][channel_ids, :] return recordings @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): if self._ttl_frames is not None and self._ttl_states is not None: ttl_idxs = np.where((self._ttl_frames >= start_frame) & (self._ttl_frames < end_frame))[0] return 
self._ttl_frames[ttl_idxs], self._ttl_states[ttl_idxs] else: print("TTL frames have not been added to the extractor. You can add them with the `set_ttls()1 function") return None, None @staticmethod def write_recording(recording, save_path): save_path = Path(save_path) np.save(save_path, recording.get_traces()) class NumpySortingExtractor(SortingExtractor): extractor_name = 'NumpySorting' is_writable = False def __init__(self): SortingExtractor.__init__(self) self._units = {} self.is_dumpable = False def load_from_extractor(self, sorting, copy_unit_properties=False, copy_unit_spike_features=False): """This function loads the information from a SortingExtractor into this extractor. Parameters ---------- sorting: SortingExtractor The SortingExtractor from which this extractor will copy information. copy_unit_properties: bool If True, the unit_properties will be copied from the given SortingExtractor to this extractor. copy_unit_spike_features: bool If True, the unit_spike_features will be copied from the given SortingExtractor to this extractor. """ ids = sorting.get_unit_ids() for id in ids: self.add_unit(id, sorting.get_unit_spike_train(id)) if sorting.get_sampling_frequency() is not None: self.set_sampling_frequency(sorting.get_sampling_frequency()) if copy_unit_properties: self.copy_unit_properties(sorting) if copy_unit_spike_features: self.copy_unit_spike_features(sorting) def set_sampling_frequency(self, sampling_frequency): self._sampling_frequency = sampling_frequency def set_times_labels(self, times, labels): """This function takes in an array of spike times (in frames) and an array of spike labels and adds all the unit information in these lists into the extractor. Parameters ---------- times: np.array An array of spike times (in frames). labels: np.array An array of spike labels corresponding to the given times. 
""" units = np.sort(np.unique(labels)) for unit in units: times0 = times[np.where(labels == unit)[0]] self.add_unit(unit_id=int(unit), times=times0) def add_unit(self, unit_id, times): """This function adds a new unit with the given spike times. Parameters ---------- unit_id: int The unit_id of the unit to be added. times: np.array An array of spike times (in frames). """ self._units[unit_id] = dict(times=times) def get_unit_ids(self): return list(self._units.keys()) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._units[unit_id]['times'] inds = np.where((start_frame <= times) & (times < end_frame))[0] return np.rint(times[inds]).astype(int) ================================================ FILE: spikeextractors/extractors/nwbextractors/__init__.py ================================================ from .nwbextractors import NwbRecordingExtractor, NwbSortingExtractor ================================================ FILE: spikeextractors/extractors/nwbextractors/nwbextractors.py ================================================ import uuid from datetime import datetime from collections import abc from pathlib import Path import numpy as np from packaging.version import parse from typing import Union, List, Optional import warnings import spikeextractors as se from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train try: import pandas as pd import pynwb from pynwb import NWBHDF5IO from pynwb import NWBFile from pynwb.ecephys import ElectricalSeries, FilteredEphys, LFP from pynwb.ecephys import ElectrodeGroup from hdmf.data_utils import DataChunkIterator from hdmf.backends.hdf5.h5_utils import H5DataIO HAVE_NWB = True except ModuleNotFoundError: HAVE_NWB = False PathType = Union[str, Path, None] ArrayType = Union[list, np.ndarray] def check_nwb_install(): assert HAVE_NWB, NwbRecordingExtractor.installation_mesg def set_dynamic_table_property(dynamic_table, row_ids, 
property_name, values, index=False, default_value=np.nan, table=False, description='no description'): check_nwb_install() if not isinstance(row_ids, list) or not all(isinstance(x, (int, np.integer)) for x in row_ids): raise TypeError("'ids' must be a list of integers") ids = list(dynamic_table.id[:]) if any([i not in ids for i in row_ids]): raise ValueError("'ids' contains values outside the range of existing ids") if not isinstance(property_name, str): raise TypeError("'property_name' must be a string") if len(row_ids) != len(values) and index is False: raise ValueError("'ids' and 'values' should be lists of same size") if index is False: if property_name in dynamic_table: for (row_id, value) in zip(row_ids, values): dynamic_table[property_name].data[ids.index(row_id)] = value else: col_data = [default_value] * len(ids) # init with default val for (row_id, value) in zip(row_ids, values): col_data[ids.index(row_id)] = value dynamic_table.add_column( name=property_name, description=description, data=col_data, index=index, table=table ) else: if property_name in dynamic_table: # TODO raise NotImplementedError else: dynamic_table.add_column( name=property_name, description=description, data=values, index=index, table=table ) def get_dynamic_table_property(dynamic_table, *, row_ids=None, property_name): all_row_ids = list(dynamic_table.id[:]) if row_ids is None: row_ids = all_row_ids return [dynamic_table[property_name][all_row_ids.index(x)] for x in row_ids] def get_nspikes(units_table, unit_id): """Return the number of spikes for chosen unit.""" check_nwb_install() ids = np.array(units_table.id[:]) indexes = np.where(ids == unit_id)[0] if not len(indexes): raise ValueError(f"{unit_id} is an invalid unit_id. 
Valid ids: {ids}.") index = indexes[0] if index == 0: return units_table['spike_times_index'].data[index] else: return units_table['spike_times_index'].data[index] - units_table['spike_times_index'].data[index - 1] def most_relevant_ch(traces: ArrayType): """ Calculate the most relevant channel for a given Unit. Estimates the channel where the max-min difference of the average traces is greatest. Parameters ---------- traces : ndarray ndarray of shape (nSpikes, nChannels, nSamples) """ n_channels = traces.shape[1] avg = np.mean(traces, axis=0) max_min = np.zeros(n_channels) for ch in range(n_channels): max_min[ch] = avg[ch, :].max() - avg[ch, :].min() relevant_ch = np.argmax(max_min) return relevant_ch def update_dict(d: dict, u: dict): """Smart dictionary updates.""" if u is not None: for k, v in u.items(): if isinstance(v, abc.Mapping): d[k] = update_dict(d.get(k, {}), v) else: d[k] = v return d def list_get(li: list, idx: int, default): """Safe index retrieval from list.""" try: return li[idx] except IndexError: return default def check_module(nwbfile, name: str, description: str = None): """ Check if processing module exists. If not, create it. Then return module. 
Parameters ---------- nwbfile: pynwb.NWBFile name: str description: str | None (optional) Returns ------- pynwb.module """ assert isinstance(nwbfile, pynwb.NWBFile), "'nwbfile' should be of type pynwb.NWBFile" if name in nwbfile.modules: return nwbfile.modules[name] else: if description is None: description = name return nwbfile.create_processing_module(name, description) class NwbRecordingExtractor(se.RecordingExtractor): """Primary class for interfacing between NWBFiles and RecordingExtractors.""" extractor_name = 'NwbRecording' has_default_locations = True has_unscaled = False installed = HAVE_NWB # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" def __init__(self, file_path: PathType, electrical_series_name: str = None): """ Load an NWBFile as a RecordingExtractor. Parameters ---------- file_path: path to NWB file electrical_series_name: str, optional """ assert self.installed, self.installation_mesg se.RecordingExtractor.__init__(self) self._path = str(file_path) with NWBHDF5IO(self._path, 'r') as io: nwbfile = io.read() if electrical_series_name is not None: self._electrical_series_name = electrical_series_name else: a_names = list(nwbfile.acquisition) if len(a_names) > 1: raise ValueError("More than one acquisition found! You must specify 'electrical_series_name'.") if len(a_names) == 0: raise ValueError("No acquisitions found in the .nwb file.") self._electrical_series_name = a_names[0] es = nwbfile.acquisition[self._electrical_series_name] if hasattr(es, 'timestamps') and es.timestamps: self.sampling_frequency = 1. / np.median(np.diff(es.timestamps)) self.recording_start_time = es.timestamps[0] else: self.sampling_frequency = es.rate if hasattr(es, 'starting_time'): self.recording_start_time = es.starting_time else: self.recording_start_time = 0. 
self.num_frames = int(es.data.shape[0]) num_channels = len(es.electrodes.data) # Channels gains - for RecordingExtractor, these are values to cast traces to uV if es.channel_conversion is not None: gains = es.conversion * es.channel_conversion[:] * 1e6 else: gains = es.conversion * np.ones(num_channels) * 1e6 # Extractors channel groups must be integers, but Nwb electrodes group_name can be strings if 'group_name' in nwbfile.electrodes.colnames: unique_grp_names = list(np.unique(nwbfile.electrodes['group_name'][:])) # Fill channel properties dictionary from electrodes table self.channel_ids = [es.electrodes.table.id[x] for x in es.electrodes.data] # If gains are not 1, set has_scaled to True if np.any(gains != 1): self.set_channel_gains(gains) self.has_unscaled = True for es_ind, (channel_id, electrode_table_index) in enumerate(zip(self.channel_ids, es.electrodes.data)): this_loc = [] if 'rel_x' in nwbfile.electrodes: this_loc.append(nwbfile.electrodes['rel_x'][electrode_table_index]) if 'rel_y' in nwbfile.electrodes: this_loc.append(nwbfile.electrodes['rel_y'][electrode_table_index]) else: this_loc.append(0) self.set_channel_locations(this_loc, channel_id) for col in nwbfile.electrodes.colnames: if isinstance(nwbfile.electrodes[col][electrode_table_index], ElectrodeGroup): continue elif col == 'group_name': self.set_channel_groups( int(unique_grp_names.index(nwbfile.electrodes[col][electrode_table_index])), channel_id) elif col == 'location': self.set_channel_property(channel_id, 'brain_area', nwbfile.electrodes[col][electrode_table_index]) elif col == 'offset': self.set_channel_offsets(channel_ids=channel_id, offsets=nwbfile.electrodes[col][electrode_table_index]) elif col in ['x', 'y', 'z', 'rel_x', 'rel_y']: continue else: self.set_channel_property(channel_id, col, nwbfile.electrodes[col][electrode_table_index]) # Fill epochs dictionary self._epochs = {} if nwbfile.epochs is not None: df_epochs = nwbfile.epochs.to_dataframe() if 'tags' in df_epochs: 
tags_or_label = 'tags' # older nwb schema version else: tags_or_label = 'label' self._epochs = { row[tags_or_label][0]: { 'start_frame': self.time_to_frame(row['start_time']), 'end_frame': self.time_to_frame(row['stop_time']) } for _, row in df_epochs.iterrows() } self._kwargs = {'file_path': str(Path(file_path).absolute()), 'electrical_series_name': electrical_series_name} self.make_nwb_metadata(nwbfile=nwbfile, es=es) def make_nwb_metadata(self, nwbfile, es): # Metadata dictionary - useful for constructing a nwb file self.nwb_metadata = dict() self.nwb_metadata['NWBFile'] = { 'session_description': nwbfile.session_description, 'identifier': nwbfile.identifier, 'session_start_time': nwbfile.session_start_time, 'institution': nwbfile.institution, 'lab': nwbfile.lab } self.nwb_metadata['Ecephys'] = dict() # Update metadata with Device info self.nwb_metadata['Ecephys']['Device'] = [] for dev in nwbfile.devices: self.nwb_metadata['Ecephys']['Device'].append({'name': dev}) # Update metadata with ElectrodeGroup info self.nwb_metadata['Ecephys']['ElectrodeGroup'] = [] for k, v in nwbfile.electrode_groups.items(): self.nwb_metadata['Ecephys']['ElectrodeGroup'].append({ 'name': v.name, 'description': v.description, 'location': v.location, 'device': v.device.name }) # Update metadata with ElectricalSeries info self.nwb_metadata['Ecephys']['ElectricalSeries'] = dict( name=es.name, description=es.description ) @check_get_traces_args def get_traces( self, channel_ids: ArrayType = None, start_frame: int = None, end_frame: int = None, return_scaled: bool = True ): with NWBHDF5IO(self._path, 'r') as io: nwbfile = io.read() es = nwbfile.acquisition[self._electrical_series_name] es_channel_ids = np.array(es.electrodes.table.id[:])[es.electrodes.data[:]].tolist() channel_inds = [es_channel_ids.index(id) for id in channel_ids] if np.array(channel_inds).size > 1 and np.any(np.diff(channel_inds) < 0): # h5py constraint does not allow datasets to be indexed out of order ind_sort_order = 
np.argsort(channel_inds) sorted_channel_inds = np.array(channel_inds)[ind_sort_order] recordings = es.data[start_frame:end_frame, sorted_channel_inds] traces = recordings[:, ind_sort_order].T else: traces = es.data[start_frame:end_frame, channel_inds].T return traces def get_sampling_frequency(self): return self.sampling_frequency def get_num_frames(self): return self.num_frames def get_channel_ids(self): return self.channel_ids @staticmethod def add_devices(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None): """ Auxiliary static method for nwbextractor. Adds device information to nwbfile object. Will always ensure nwbfile has at least one device, but multiple devices within the metadata list will also be created. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added metadata: dict metadata info for constructing the nwb file (optional). Should be of the format metadata['Ecephys']['Device'] = [{'name': my_name, 'description': my_description}, ...] Missing keys in an element of metadata['Ecephys']['Device'] will be auto-populated with defaults. """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" defaults = dict( name="Device", description="Ecephys probe." ) if metadata is None: metadata = dict() if 'Ecephys' not in metadata: metadata['Ecephys'] = dict() if 'Device' not in metadata['Ecephys']: metadata['Ecephys']['Device'] = [defaults] assert all([isinstance(x, dict) for x in metadata['Ecephys']['Device']]), \ "Expected metadata['Ecephys']['Device'] to be a list of dictionaries!" for dev in metadata['Ecephys']['Device']: if dev.get('name', defaults['name']) not in nwbfile.devices: nwbfile.create_device(**dict(defaults, **dev)) @staticmethod def add_electrode_groups(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None): """ Auxiliary static method for nwbextractor. 
Adds electrode group information to nwbfile object. Will always ensure nwbfile has at least one electrode group. Will auto-generate a linked device if the specified name does not exist in the nwbfile. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added metadata: dict metadata info for constructing the nwb file (optional). Should be of the format metadata['Ecephys']['ElectrodeGroup'] = [{'name': my_name, 'description': my_description, 'location': electrode_location, 'device': my_device_name}, ...] Missing keys in an element of metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults. Group names set by RecordingExtractor channel properties will also be included with passed metadata, but will only use default description and location. """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" if len(nwbfile.devices) == 0: se.NwbRecordingExtractor.add_devices(recording=recording, nwbfile=nwbfile, metadata=metadata) if metadata is None: metadata = dict() if 'Ecephys' not in metadata: metadata['Ecephys'] = dict() defaults = [ dict( name=str(group_id), description="no description", location="unknown", device=[i.name for i in nwbfile.devices.values()][0] ) for group_id in np.unique(recording.get_channel_groups()) ] if 'ElectrodeGroup' not in metadata['Ecephys']: metadata['Ecephys']['ElectrodeGroup'] = defaults assert all([isinstance(x, dict) for x in metadata['Ecephys']['ElectrodeGroup']]), \ "Expected metadata['Ecephys']['ElectrodeGroup'] to be a list of dictionaries!" 
for grp in metadata['Ecephys']['ElectrodeGroup']: if grp.get('name', defaults[0]['name']) not in nwbfile.electrode_groups: device_name = grp.get('device', defaults[0]['device']) if device_name not in nwbfile.devices: new_device = dict( Ecephys=dict( Device=[dict( name=device_name )] ) ) se.NwbRecordingExtractor.add_devices(recording, nwbfile, metadata=new_device) warnings.warn(f"Device \'{device_name}\' not detected in " "attempted link to electrode group! Automatically generating.") electrode_group_kwargs = dict(defaults[0], **grp) # electrode_group_kwargs.pop('device') electrode_group_kwargs.update(device=nwbfile.devices[device_name]) nwbfile.create_electrode_group(**electrode_group_kwargs) if not nwbfile.electrode_groups: device_name = list(nwbfile.devices.keys())[0] device = nwbfile.devices[device_name] if len(nwbfile.devices) > 1: warnings.warn("More than one device found when adding electrode group " f"via channel properties: using device \'{device_name}\'. To use a " "different device, indicate it the metadata argument.") electrode_group_kwargs = dict(defaults[0]) electrode_group_kwargs.update(device=device) for grp_name in np.unique(recording.get_channel_groups()).tolist(): electrode_group_kwargs.update(name=str(grp_name)) nwbfile.create_electrode_group(**electrode_group_kwargs) @staticmethod def add_electrodes(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None, write_scaled: bool = True): """ Auxiliary static method for nwbextractor. Adds channels from recording object as electrodes to nwbfile object. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added write_scaled: bool (optional, defaults to True) If True, writes the scaled traces (return_scaled=True) metadata: dict metadata info for constructing the nwb file (optional). 
Should be of the format metadata['Ecephys']['Electrodes'] = [{'name': my_name, 'description': my_description, 'data': [my_electrode_data]}, ...] where each dictionary corresponds to a column in the Electrodes table and [my_electrode_data] is a list in one-to-one correspondence with the nwbfile electrode ids and RecordingExtractor channel ids. Missing keys in an element of metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults whenever possible. If 'my_name' is set to one of the required fields for nwbfile electrodes (id, x, y, z, imp, loccation, filtering, group_name), then the metadata will override their default values. Setting 'my_name' to metadata field 'group' is not supported as the linking to nwbfile.electrode_groups is handled automatically; please specify the string 'group_name' in this case. If no group information is passed via metadata, automatic linking to existing electrode groups, possibly including the default, will occur. """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" if nwbfile.electrode_groups is None: se.NwbRecordingExtractor.add_electrode_groups(recording, nwbfile, metadata) # For older versions of pynwb, we need to manually add these columns if parse(pynwb.__version__) < parse('1.3.0'): if nwbfile.electrodes is None or 'rel_x' not in nwbfile.electrodes.colnames: nwbfile.add_electrode_column('rel_x', 'x position of electrode in electrode group') if nwbfile.electrodes is None or 'rel_y' not in nwbfile.electrodes.colnames: nwbfile.add_electrode_column('rel_y', 'y position of electrode in electrode group') defaults = dict( x=np.nan, y=np.nan, z=np.nan, # There doesn't seem to be a canonical default for impedence, if missing. 
# The NwbRecordingExtractor follows the -1.0 convention, other scripts sometimes use np.nan imp=-1.0, location="unknown", filtering="none", group_name="ElectrodeGroup" ) if metadata is None: metadata = dict(Ecephys=dict()) if 'Electrodes' not in metadata['Ecephys']: metadata['Ecephys']['Electrodes'] = [] assert all([isinstance(x, dict) and set(x.keys()) == set(['name', 'description', 'data']) and isinstance(x['data'], list) for x in metadata['Ecephys']['Electrodes']]), \ "Expected metadata['Ecephys']['Electrodes'] to be a list of dictionaries!" assert all([x['name'] != 'group' for x in metadata['Ecephys']['Electrodes']]), \ "Passing metadata field 'group' is depricated; pass group_name instead!" if nwbfile.electrodes is None: nwb_elec_ids = [] else: nwb_elec_ids = nwbfile.electrodes.id.data[:] for metadata_column in metadata['Ecephys']['Electrodes']: if (nwbfile.electrodes is None or metadata_column['name'] not in nwbfile.electrodes.colnames) \ and metadata_column['name'] != 'group_name': nwbfile.add_electrode_column( name=str(metadata_column['name']), description=str(metadata_column['description']) ) for j, channel_id in enumerate(recording.get_channel_ids()): if channel_id not in nwb_elec_ids: electrode_kwargs = dict(defaults) electrode_kwargs.update(id=channel_id) # recording.get_channel_locations defaults to np.nan if there are none location = recording.get_channel_locations(channel_ids=channel_id)[0] if all([not np.isnan(loc) for loc in location]): # property 'location' of RX channels corresponds to rel_x and rel_ y of NWB electrodes electrode_kwargs.update( dict( rel_x=float(location[0]), rel_y=float(location[1]) ) ) for metadata_column in metadata['Ecephys']['Electrodes']: if metadata_column['name'] == 'group_name': group_name = list_get(metadata_column['data'], j, defaults['group_name']) if group_name not in nwbfile.electrode_groups: warnings.warn(f"Electrode group for electrode {channel_id} was not " "found in the nwbfile! 
Automatically adding.") missing_group_metadata = dict( Ecephys=dict( ElectrodeGroup=[dict( name=group_name, description="no description", location="unknown", device="Device" )] ) ) se.NwbRecordingExtractor.add_electrode_groups(recording, nwbfile, missing_group_metadata) electrode_kwargs.update( dict( group=nwbfile.electrode_groups[group_name], group_name=group_name ) ) else: if metadata_column['name'] in defaults: electrode_kwargs.update({ metadata_column['name']: list_get(metadata_column['data'], j, defaults[metadata_column['name']]) }) else: if j < len(metadata_column['data']): electrode_kwargs.update({ metadata_column['name']: metadata_column['data'][j] }) else: metadata_column_name = metadata_column['name'] warnings.warn(f"Custom column {metadata_column_name} " f"has incomplete data for channel id [{j}] and no " "set default! Electrode will not be added.") continue if not any([x.get('name', '') == 'group_name' for x in metadata['Ecephys']['Electrodes']]): group_id = recording.get_channel_groups(channel_ids=channel_id)[0] if str(group_id) in nwbfile.electrode_groups: electrode_kwargs.update( dict( group=nwbfile.electrode_groups[str(group_id)], group_name=str(group_id) ) ) else: warnings.warn("No metadata was passed specifying the electrode group for " f"electrode {channel_id}, and the internal recording channel group was " f"assigned a value (str({group_id})) not present as electrode " "groups in the NWBFile! Electrode will not be added.") continue nwbfile.add_electrode(**electrode_kwargs) assert nwbfile.electrodes is not None, \ "Unable to form electrode table! Check device, electrode group, and electrode metadata." 
# property 'gain' should not be in the NWB electrodes_table # property 'brain_area' of RX channels corresponds to 'location' of NWB electrodes # property 'offset' should not be in the NWB electrodes_table as not officially supported by schema v2.2.5 channel_prop_names = set(recording.get_shared_channel_property_names()) - set(nwbfile.electrodes.colnames) \ - {'gain', 'location', 'offset'} for channel_prop_name in channel_prop_names: for channel_id in recording.get_channel_ids(): val = recording.get_channel_property(channel_id, channel_prop_name) descr = 'no description' if channel_prop_name == 'brain_area': channel_prop_name = 'location' descr = 'brain area location' set_dynamic_table_property( dynamic_table=nwbfile.electrodes, row_ids=[int(channel_id)], property_name=channel_prop_name, values=[val], default_value=np.nan, description=descr ) @staticmethod def add_electrical_series( recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None, buffer_mb: int = 500, use_times: bool = False, write_as: str = 'raw', es_key: str = None, write_scaled: bool = False ): """ Auxiliary static method for nwbextractor. Adds traces from recording object as ElectricalSeries to nwbfile object. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added metadata: dict metadata info for constructing the nwb file (optional). Should be of the format metadata['Ecephys']['ElectricalSeries'] = {'name': my_name, 'description': my_description} buffer_mb: int (optional, defaults to 500MB) maximum amount of memory (in MB) to use per iteration of the DataChunkIterator (requires traces to be memmap objects) use_times: bool (optional, defaults to False) If True, the times are saved to the nwb file using recording.frame_to_time(). If False (defualut), the sampling rate is used. write_as: str (optional, defaults to 'raw') How to save the traces data in the nwb file. 
Options: - 'raw' will save it in acquisition - 'processed' will save it as FilteredEphys, in a processing module - 'lfp' will save it as LFP, in a processing module es_key: str (optional) Key in metadata dictionary containing metadata info for the specific electrical series write_scaled: bool (optional, defaults to True) If True, writes the scaled traces (return_scaled=True) Missing keys in an element of metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults whenever possible. """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile!" assert buffer_mb > 10, "'buffer_mb' should be at least 10MB to ensure data can be chunked!" if not nwbfile.electrodes: se.NwbRecordingExtractor.add_electrodes(recording, nwbfile, metadata) assert write_as in ['raw', 'processed', 'lfp'], \ f"'write_as' should be 'raw', 'processed' or 'lfp', but intead received value {write_as}" if write_as == 'raw': eseries_kwargs = dict( name="ElectricalSeries_raw", description="Raw acquired data", comments="Generated from SpikeInterface::NwbRecordingExtractor" ) elif write_as == 'processed': eseries_kwargs = dict( name="ElectricalSeries_processed", description="Processed data", comments="Generated from SpikeInterface::NwbRecordingExtractor" ) # Check for existing processing module and data interface ecephys_mod = check_module( nwbfile=nwbfile, name='ecephys', description="Intermediate data from extracellular electrophysiology recordings, e.g., LFP." ) if 'Processed' not in ecephys_mod.data_interfaces: ecephys_mod.add(FilteredEphys(name='Processed')) elif write_as == 'lfp': eseries_kwargs = dict( name="ElectricalSeries_lfp", description="Processed data - LFP", comments="Generated from SpikeInterface::NwbRecordingExtractor" ) # Check for existing processing module and data interface ecephys_mod = check_module( nwbfile=nwbfile, name='ecephys', description="Intermediate data from extracellular electrophysiology recordings, e.g., LFP." 
) if 'LFP' not in ecephys_mod.data_interfaces: ecephys_mod.add(LFP(name='LFP')) # If user passed metadata info, overwrite defaults if metadata is not None and 'Ecephys' in metadata and es_key is not None: assert es_key in metadata['Ecephys'], f"metadata['Ecephys'] dictionary does not contain key '{es_key}'" eseries_kwargs.update(metadata['Ecephys'][es_key]) # Check for existing names in nwbfile if write_as == 'raw': assert eseries_kwargs['name'] not in nwbfile.acquisition, \ f"Raw ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!" elif write_as == 'processed': assert eseries_kwargs['name'] not in nwbfile.processing['ecephys'].data_interfaces['Processed'].electrical_series, \ f"Processed ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!" elif write_as == 'lfp': assert eseries_kwargs['name'] not in nwbfile.processing['ecephys'].data_interfaces['LFP'].electrical_series, \ f"LFP ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!" # Electrodes table region channel_ids = recording.get_channel_ids() table_ids = [list(nwbfile.electrodes.id[:]).index(id) for id in channel_ids] electrode_table_region = nwbfile.create_electrode_table_region( region=table_ids, description="electrode_table_region" ) eseries_kwargs.update(electrodes=electrode_table_region) # channels gains - for RecordingExtractor, these are values to cast traces to uV. # For nwb, the conversions (gains) cast the data to Volts. # To get traces in Volts we take data*channel_conversion*conversion. channel_conversion = recording.get_channel_gains() channel_offset = recording.get_channel_offsets() unsigned_coercion = channel_offset / channel_conversion if not np.all([x.is_integer() for x in unsigned_coercion]): raise NotImplementedError( "Unable to coerce underlying unsigned data type to signed type, which is currently required for NWB " "Schema v2.2.5! Please specify 'write_scaled=True'." 
) elif np.any(unsigned_coercion != 0): warnings.warn( "NWB Schema v2.2.5 does not officially support channel offsets. The data will be converted to a signed " "type that does not use offsets." ) unsigned_coercion = unsigned_coercion.astype(int) if write_scaled: eseries_kwargs.update(conversion=1e-6) else: if len(np.unique(channel_conversion)) == 1: # if all gains are equal eseries_kwargs.update(conversion=channel_conversion[0] * 1e-6) else: eseries_kwargs.update(conversion=1e-6) eseries_kwargs.update(channel_conversion=channel_conversion) if isinstance(recording.get_traces(end_frame=5, return_scaled=write_scaled), np.memmap) \ and np.all(channel_offset == 0): n_bytes = np.dtype(recording.get_dtype()).itemsize buffer_size = int(buffer_mb * 1e6) // (recording.get_num_channels() * n_bytes) ephys_data = DataChunkIterator( data=recording.get_traces(return_scaled=write_scaled).T, # nwb standard is time as zero axis buffer_size=buffer_size ) else: def data_generator(recording, channels_ids, unsigned_coercion, write_scaled): for i, ch in enumerate(channels_ids): data = recording.get_traces(channel_ids=[ch], return_scaled=write_scaled) if not write_scaled: data_dtype_name = data.dtype.name if data_dtype_name.startswith("uint"): data_dtype_name = data_dtype_name[1:] # Retain memory of signed data type data = data + unsigned_coercion[i] data = data.astype(data_dtype_name) yield data.flatten() ephys_data = DataChunkIterator( data=data_generator( recording=recording, channels_ids=channel_ids, unsigned_coercion=unsigned_coercion, write_scaled=write_scaled ), iter_axis=1, # nwb standard is time as zero axis maxshape=(recording.get_num_frames(), recording.get_num_channels()) ) eseries_kwargs.update(data=H5DataIO(ephys_data, compression="gzip")) if not use_times: eseries_kwargs.update( starting_time=recording.frame_to_time(0), rate=float(recording.get_sampling_frequency()) ) else: eseries_kwargs.update( timestamps=H5DataIO( 
recording.frame_to_time(np.arange(recording.get_num_frames())), compression="gzip" ) ) # Add ElectricalSeries to nwbfile object if write_as == 'raw': nwbfile.add_acquisition(ElectricalSeries(**eseries_kwargs)) elif write_as == 'processed': ecephys_mod.data_interfaces['Processed'].add_electrical_series(ElectricalSeries(**eseries_kwargs)) elif write_as == 'lfp': ecephys_mod.data_interfaces['LFP'].add_electrical_series(ElectricalSeries(**eseries_kwargs)) @staticmethod def add_epochs( recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None ): """ Auxiliary static method for nwbextractor. Adds epochs from recording object to nwbfile object. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added metadata: dict metadata info for constructing the nwb file (optional). """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" # add/update epochs for epoch_name in recording.get_epoch_names(): epoch = recording.get_epoch_info(epoch_name) if nwbfile.epochs is None: nwbfile.add_epoch( start_time=recording.frame_to_time(epoch['start_frame']), stop_time=recording.frame_to_time(epoch['end_frame'] - 1), tags=epoch_name ) else: if [epoch_name] in nwbfile.epochs['tags'][:]: ind = nwbfile.epochs['tags'][:].index([epoch_name]) nwbfile.epochs['start_time'].data[ind] = recording.frame_to_time(epoch['start_frame']) nwbfile.epochs['stop_time'].data[ind] = recording.frame_to_time(epoch['end_frame']) else: nwbfile.add_epoch( start_time=recording.frame_to_time(epoch['start_frame']), stop_time=recording.frame_to_time(epoch['end_frame']), tags=epoch_name ) @staticmethod def add_all_to_nwbfile( recording: se.RecordingExtractor, nwbfile=None, buffer_mb: int = 500, use_times: bool = False, metadata: dict = None, write_as: str = 'raw', es_key: str = None, write_scaled: bool = False ): """ Auxiliary static method for nwbextractor. 
Adds all recording related information from recording object and metadata to the nwbfile object. Parameters ---------- recording: RecordingExtractor nwbfile: NWBFile nwb file to which the recording information is to be added buffer_mb: int (optional, defaults to 500MB) maximum amount of memory (in MB) to use per iteration of the DataChunkIterator (requires traces to be memmap objects) use_times: bool If True, the times are saved to the nwb file using recording.frame_to_time(). If False (defualut), the sampling rate is used. metadata: dict metadata info for constructing the nwb file (optional). Check the auxiliary function docstrings for more information about metadata format. write_as: str (optional, defaults to 'raw') How to save the traces data in the nwb file. Options: - 'raw' will save it in acquisition - 'processed' will save it as FilteredEphys, in a processing module - 'lfp' will save it as LFP, in a processing module es_key: str (optional) Key in metadata dictionary containing metadata info for the specific electrical series write_scaled: bool (optional, defaults to True) If True, writes the scaled traces (return_scaled=True) """ if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" se.NwbRecordingExtractor.add_devices( recording=recording, nwbfile=nwbfile, metadata=metadata ) se.NwbRecordingExtractor.add_electrode_groups( recording=recording, nwbfile=nwbfile, metadata=metadata ) se.NwbRecordingExtractor.add_electrodes( recording=recording, nwbfile=nwbfile, metadata=metadata, write_scaled=write_scaled ) se.NwbRecordingExtractor.add_electrical_series( recording=recording, nwbfile=nwbfile, buffer_mb=buffer_mb, use_times=use_times, metadata=metadata, write_as=write_as, es_key=es_key, write_scaled=write_scaled ) se.NwbRecordingExtractor.add_epochs( recording=recording, nwbfile=nwbfile, metadata=metadata ) @staticmethod def write_recording( recording: se.RecordingExtractor, save_path: PathType = None, overwrite: 
bool = False, nwbfile=None, buffer_mb: int = 500, use_times: bool = False, metadata: dict = None, write_as: str = 'raw', es_key: str = None, write_scaled: bool = False ): """ Primary method for writing a RecordingExtractor object to an NWBFile. Parameters ---------- recording: RecordingExtractor save_path: PathType Required if an nwbfile is not passed. Must be the path to the nwbfile being appended, otherwise one is created and written. overwrite: bool If using save_path, whether or not to overwrite the NWBFile if it already exists. nwbfile: NWBFile Required if a save_path is not specified. If passed, this function will fill the relevant fields within the nwbfile. E.g., calling spikeextractors.NwbRecordingExtractor.write_recording( my_recording_extractor, my_nwbfile ) will result in the appropriate changes to the my_nwbfile object. buffer_mb: int (optional, defaults to 500MB) maximum amount of memory (in MB) to use per iteration of the DataChunkIterator (requires traces to be memmap objects) use_times: bool If True, the times are saved to the nwb file using recording.frame_to_time(). If False (defualut), the sampling rate is used. metadata: dict metadata info for constructing the nwb file (optional). Should be of the format metadata['Ecephys'] = {} with keys of the forms metadata['Ecephys']['Device'] = [{'name': my_name, 'description': my_description}, ...] metadata['Ecephys']['ElectrodeGroup'] = [{'name': my_name, 'description': my_description, 'location': electrode_location, 'device': my_device_name}, ...] metadata['Ecephys']['Electrodes'] = [{'name': my_name, 'description': my_description, 'data': [my_electrode_data]}, ...] metadata['Ecephys']['ElectricalSeries'] = {'name': my_name, 'description': my_description} write_as: str (optional, defaults to 'raw') How to save the traces data in the nwb file. 
Options: - 'raw' will save it in acquisition - 'processed' will save it as FilteredEphys, in a processing module - 'lfp' will save it as LFP, in a processing module es_key: str (optional) Key in metadata dictionary containing metadata info for the specific electrical series write_scaled: bool (optional, defaults to True) If True, writes the scaled traces (return_scaled=True) """ assert HAVE_NWB, NwbRecordingExtractor.installation_mesg if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be of type pynwb.NWBFile" assert parse(pynwb.__version__) >= parse('1.3.3'), \ "'write_recording' not supported for version < 1.3.3. Run pip install --upgrade pynwb" assert save_path is None or nwbfile is None, \ "Either pass a save_path location, or nwbfile object, but not both!" # Update any previous metadata with user passed dictionary if hasattr(recording, 'nwb_metadata'): metadata = update_dict(recording.nwb_metadata, metadata) elif metadata is None: # If not NWBRecording, make metadata from information available on Recording metadata_0 = se.NwbRecordingExtractor.get_nwb_metadata(recording=recording) metadata = update_dict(metadata_0, metadata) if nwbfile is None: if Path(save_path).is_file() and not overwrite: read_mode = 'r+' else: read_mode = 'w' with NWBHDF5IO(str(save_path), mode=read_mode) as io: if read_mode == 'r+': nwbfile = io.read() else: # Default arguments will be over-written if contained in metadata nwbfile_kwargs = dict( session_description="Auto-generated by NwbRecordingExtractor without description.", identifier=str(uuid.uuid4()), session_start_time=datetime(1970, 1, 1) ) if metadata is not None and 'NWBFile' in metadata: nwbfile_kwargs.update(metadata['NWBFile']) nwbfile = NWBFile(**nwbfile_kwargs) se.NwbRecordingExtractor.add_all_to_nwbfile( recording=recording, nwbfile=nwbfile, buffer_mb=buffer_mb, metadata=metadata, use_times=use_times, write_as=write_as, es_key=es_key, write_scaled=write_scaled ) # Write to file io.write(nwbfile) 
else: se.NwbRecordingExtractor.add_all_to_nwbfile( recording=recording, nwbfile=nwbfile, buffer_mb=buffer_mb, use_times=use_times, metadata=metadata, write_as=write_as, es_key=es_key, write_scaled=write_scaled ) @staticmethod def get_nwb_metadata(recording: se.RecordingExtractor, metadata: dict = None): """ Parameters ---------- recording: RecordingExtractor metadata: dict metadata info for constructing the nwb file (optional). """ metadata = dict( NWBFile=dict( session_description="Auto-generated by NwbRecordingExtractor without description.", identifier=str(uuid.uuid4()), session_start_time=datetime(1970, 1, 1) ), Ecephys=dict( Device=[dict( name="Device", description="no description" )], ElectrodeGroup=[ dict( name=str(gn), description="no description", location="unknown", device="Device" ) for gn in np.unique(recording.get_channel_groups()) ] ) ) return metadata class NwbSortingExtractor(se.SortingExtractor): extractor_name = 'NwbSorting' installed = HAVE_NWB # check at class level if installed or not is_writable = True mode = 'file' installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" def __init__(self, file_path, electrical_series=None, sampling_frequency=None): """ Parameters ---------- path: path to NWB file electrical_series: pynwb.ecephys.ElectricalSeries object """ assert self.installed, self.installation_mesg se.SortingExtractor.__init__(self) self._path = str(file_path) with NWBHDF5IO(self._path, 'r') as io: nwbfile = io.read() if sampling_frequency is None: # defines the electrical series from where the sorting came from # important to know the sampling_frequency if electrical_series is None: if len(nwbfile.acquisition) > 1: raise Exception('More than one acquisition found. You must specify electrical_series.') if len(nwbfile.acquisition) == 0: raise Exception("No acquisitions found in the .nwb file from which to read sampling frequency. 
\ Please, specify 'sampling_frequency' parameter.") es = list(nwbfile.acquisition.values())[0] else: es = electrical_series # get rate if es.rate is not None: self._sampling_frequency = es.rate else: self._sampling_frequency = 1 / (es.timestamps[1] - es.timestamps[0]) else: self._sampling_frequency = sampling_frequency # get all units ids units_ids = nwbfile.units.id[:] # store units properties and spike features to dictionaries all_pr_ft = list(nwbfile.units.colnames) all_names = [i.name for i in nwbfile.units.columns] for item in all_pr_ft: if item == 'spike_times': continue # test if item is a unit_property or a spike_feature if item + '_index' in all_names: # if it has index, it is a spike_feature for u_id in units_ids: ind = list(units_ids).index(u_id) self.set_unit_spike_features(u_id, item, nwbfile.units[item][ind]) else: # if it is unit_property for u_id in units_ids: ind = list(units_ids).index(u_id) if isinstance(nwbfile.units[item][ind], pd.DataFrame): prop_value = nwbfile.units[item][ind].index[0] else: prop_value = nwbfile.units[item][ind] if isinstance(prop_value, (list, np.ndarray)): self.set_unit_property(u_id, item, prop_value) else: if prop_value == prop_value: # not nan self.set_unit_property(u_id, item, prop_value) # Fill epochs dictionary self._epochs = {} if nwbfile.epochs is not None: df_epochs = nwbfile.epochs.to_dataframe() self._epochs = {row['tags'][0]: { 'start_frame': self.time_to_frame(row['start_time']), 'end_frame': self.time_to_frame(row['stop_time'])} for _, row in df_epochs.iterrows()} self._kwargs = {'file_path': str(Path(file_path).absolute()), 'electrical_series': electrical_series, 'sampling_frequency': sampling_frequency} def get_unit_ids(self): """This function returns a list of ids (ints) for each unit in the sorted result. Returns ---------- unit_ids: array_like A list of the unit ids in the sorted result (ints). 
""" check_nwb_install() with NWBHDF5IO(self._path, 'r') as io: nwbfile = io.read() unit_ids = [int(i) for i in nwbfile.units.id[:]] return unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): check_nwb_install() with NWBHDF5IO(self._path, 'r') as io: nwbfile = io.read() # chosen unit and interval times = nwbfile.units['spike_times'][list(nwbfile.units.id[:]).index(unit_id)][:] # spike times are measured in samples frames = self.time_to_frame(times) return frames[(frames > start_frame) & (frames < end_frame)] @staticmethod def write_units( sorting: se.SortingExtractor, nwbfile, property_descriptions: Optional[dict] = None, skip_properties: Optional[List[str]] = None, skip_features: Optional[List[str]] = None, use_times: bool = True ): """Auxilliary function for write_sorting.""" unit_ids = sorting.get_unit_ids() fs = sorting.get_sampling_frequency() if fs is None: raise ValueError("Writing a SortingExtractor to an NWBFile requires a known sampling frequency!") all_properties = set() all_features = set() for unit_id in unit_ids: all_properties.update(sorting.get_unit_property_names(unit_id)) all_features.update(sorting.get_unit_spike_feature_names(unit_id)) default_descriptions = dict( isi_violation="Quality metric that measures the ISI violation ratio as a proxy for the purity of the unit.", firing_rate="Number of spikes per unit of time.", template="The extracellular average waveform.", max_channel="The recording channel id with the largest amplitude.", halfwidth="The full-width half maximum of the negative peak computed on the maximum channel.", peak_to_valley="The duration between the negative and the positive peaks computed on the maximum channel.", snr="The signal-to-noise ratio of the unit.", quality="Quality of the unit as defined by phy (good, mua, noise).", spike_amplitude="Average amplitude of peaks detected on the channel.", spike_rate="Average rate of peaks detected on the channel." 
) if property_descriptions is None: property_descriptions = dict(default_descriptions) else: property_descriptions = dict(default_descriptions, **property_descriptions) if skip_properties is None: skip_properties = [] if skip_features is None: skip_features = [] if nwbfile.units is None: # Check that array properties have the same shape across units property_shapes = dict() for pr in all_properties: shapes = [] for unit_id in unit_ids: if pr in sorting.get_unit_property_names(unit_id): prop_value = sorting.get_unit_property(unit_id, pr) if isinstance(prop_value, (int, np.integer, float, str, bool)): shapes.append(1) elif isinstance(prop_value, (list, np.ndarray)): if np.array(prop_value).ndim == 1: shapes.append(len(prop_value)) else: shapes.append(np.array(prop_value).shape) elif isinstance(prop_value, dict): print(f"Skipping property '{pr}' because dictionaries are not supported.") skip_properties.append(pr) break else: shapes.append(np.nan) property_shapes[pr] = shapes for pr in property_shapes.keys(): elems = [elem for elem in property_shapes[pr] if not np.any(np.isnan(elem))] if not np.all([elem == elems[0] for elem in elems]): print(f"Skipping property '{pr}' because it has variable size across units.") skip_properties.append(pr) write_properties = set(all_properties) - set(skip_properties) for pr in write_properties: if pr not in property_descriptions: warnings.warn( f"Description for property {pr} not found in property_descriptions. 
" "Setting description to 'no description'" ) for pr in write_properties: unit_col_args = dict(name=pr, description=property_descriptions.get(pr, "No description.")) if pr in ['max_channel', 'max_electrode'] and nwbfile.electrodes is not None: unit_col_args.update(table=nwbfile.electrodes) nwbfile.add_unit_column(**unit_col_args) for unit_id in unit_ids: unit_kwargs = dict() if use_times: spkt = sorting.frame_to_time(sorting.get_unit_spike_train(unit_id=unit_id)) else: spkt = sorting.get_unit_spike_train(unit_id=unit_id) / sorting.get_sampling_frequency() for pr in write_properties: if pr in sorting.get_unit_property_names(unit_id): prop_value = sorting.get_unit_property(unit_id, pr) unit_kwargs.update({pr: prop_value}) else: # Case of missing data for this unit and this property unit_kwargs.update({pr: np.nan}) nwbfile.add_unit(id=int(unit_id), spike_times=spkt, **unit_kwargs) # TODO # # Stores average and std of spike traces # This will soon be updated to the current NWB standard # if 'waveforms' in sorting.get_unit_spike_feature_names(unit_id=id): # wf = sorting.get_unit_spike_features(unit_id=id, # feature_name='waveforms') # relevant_ch = most_relevant_ch(wf) # # Spike traces on the most relevant channel # traces = wf[:, relevant_ch, :] # traces_avg = np.mean(traces, axis=0) # traces_std = np.std(traces, axis=0) # nwbfile.add_unit( # id=id, # spike_times=spkt, # waveform_mean=traces_avg, # waveform_sd=traces_std # ) # Check that multidimensional features have the same shape across units feature_shapes = dict() for ft in all_features: shapes = [] for unit_id in unit_ids: if ft in sorting.get_unit_spike_feature_names(unit_id): feat_value = sorting.get_unit_spike_features(unit_id, ft) if isinstance(feat_value[0], (int, np.integer, float, str, bool)): break elif isinstance(feat_value[0], (list, np.ndarray)): # multidimensional features if np.array(feat_value).ndim > 1: shapes.append(np.array(feat_value).shape) feature_shapes[ft] = shapes elif 
isinstance(feat_value[0], dict): print(f"Skipping feature '{ft}' because dictionaries are not supported.") skip_features.append(ft) break else: print(f"Skipping feature '{ft}' because not share across all units.") skip_features.append(ft) break nspikes = {k: get_nspikes(nwbfile.units, int(k)) for k in unit_ids} for ft in feature_shapes.keys(): # skip first dimension (num_spikes) when comparing feature shape if not np.all([elem[1:] == feature_shapes[ft][0][1:] for elem in feature_shapes[ft]]): print(f"Skipping feature '{ft}' because it has variable size across units.") skip_features.append(ft) for ft in set(all_features) - set(skip_features): values = [] if not ft.endswith('_idxs'): for unit_id in sorting.get_unit_ids(): feat_vals = sorting.get_unit_spike_features(unit_id, ft) if len(feat_vals) < nspikes[unit_id]: skip_features.append(ft) print(f"Skipping feature '{ft}' because it is not defined for all spikes.") break # this means features are available for a subset of spikes # all_feat_vals = np.array([np.nan] * nspikes[unit_id]) # feature_idxs = sorting.get_unit_spike_features(unit_id, feat_name + '_idxs') # all_feat_vals[feature_idxs] = feat_vals else: all_feat_vals = feat_vals values.append(all_feat_vals) flatten_vals = [item for sublist in values for item in sublist] nspks_list = [sp for sp in nspikes.values()] spikes_index = np.cumsum(nspks_list).astype('int64') if ft in nwbfile.units: # If property already exists, skip it warnings.warn(f'Feature {ft} already present in units table, skipping it') continue set_dynamic_table_property( dynamic_table=nwbfile.units, row_ids=[int(k) for k in unit_ids], property_name=ft, values=flatten_vals, index=spikes_index, ) else: warnings.warn("The nwbfile already contains units. 
These units will not be over-written.") @staticmethod def write_sorting( sorting: se.SortingExtractor, save_path: PathType = None, overwrite: bool = False, nwbfile=None, property_descriptions: Optional[dict] = None, skip_properties: Optional[List[str]] = None, skip_features: Optional[List[str]] = None, use_times: bool = True, **nwbfile_kwargs ): """ Primary method for writing a SortingExtractor object to an NWBFile. Parameters ---------- sorting: SortingExtractor save_path: PathType Required if an nwbfile is not passed. The location where the NWBFile either exists, or will be written. overwrite: bool If using save_path, whether or not to overwrite the NWBFile if it already exists. nwbfile: NWBFile Required if a save_path is not specified. If passed, this function will fill the relevant fields within the nwbfile. E.g., calling spikeextractors.NwbRecordingExtractor.write_recording( my_recording_extractor, my_nwbfile ) will result in the appropriate changes to the my_nwbfile object. property_descriptions: dict For each key in this dictionary which matches the name of a unit property in sorting, adds the value as a description to that custom unit column. skip_properties: list of str Each string in this list that matches a unit property will not be written to the NWBFile. skip_features: list of str Each string in this list that matches a spike feature will not be written to the NWBFile. use_times: bool (optional, defaults to False) If True, the times are saved to the nwb file using sorting.frame_to_time(). If False (defualut), the sampling rate is used. nwbfile_kwargs: dict Information for constructing the nwb file (optional). Only used if no nwbfile exists at the save_path, and no nwbfile was directly passed. """ assert HAVE_NWB, NwbSortingExtractor.installation_mesg assert save_path is None or nwbfile is None, \ "Either pass a save_path location, or nwbfile object, but not both!" 
if nwbfile is not None: assert isinstance(nwbfile, NWBFile), "'nwbfile' should be a pynwb.NWBFile object!" if nwbfile is None: if Path(save_path).is_file() and not overwrite: read_mode = 'r+' else: read_mode = 'w' with NWBHDF5IO(str(save_path), mode=read_mode) as io: if read_mode == 'r+': nwbfile = io.read() else: default_nwbfile_kwargs = dict( session_description="Auto-generated by NwbRecordingExtractor without description.", identifier=str(uuid.uuid4()), session_start_time=datetime(1970, 1, 1) ) default_nwbfile_kwargs.update(**nwbfile_kwargs) nwbfile = NWBFile(**default_nwbfile_kwargs) se.NwbSortingExtractor.write_units( sorting=sorting, nwbfile=nwbfile, property_descriptions=property_descriptions, skip_properties=skip_properties, skip_features=skip_features, use_times=use_times ) io.write(nwbfile) else: se.NwbSortingExtractor.write_units( sorting=sorting, nwbfile=nwbfile, property_descriptions=property_descriptions, skip_properties=skip_properties, skip_features=skip_features, use_times=use_times ) ================================================ FILE: spikeextractors/extractors/openephysextractors/__init__.py ================================================ from .openephysextractors import OpenEphysRecordingExtractor, OpenEphysSortingExtractor ================================================ FILE: spikeextractors/extractors/openephysextractors/openephysextractors.py ================================================ from spikeextractors import RecordingExtractor, SortingExtractor from pathlib import Path import numpy as np from spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train, check_get_ttl_args from packaging.version import parse import warnings try: import pyopenephys HAVE_OE = True if parse(pyopenephys.__version__) >= parse("1.1.2"): HAVE_OE_11 = True else: warnings.warn("pyopenephys>=1.1.2 should be installed. Support for older versions will be removed in " "future releases. 
Install with:\n\n pip install --upgrade pyopenephys\n\n") HAVE_OE_11 = False except ImportError: HAVE_OE = False HAVE_OE_11 = False extractors_dir = Path(__file__).parent.parent class OpenEphysRecordingExtractor(RecordingExtractor): extractor_name = 'OpenEphysRecording' has_default_locations = False has_unscaled = True installed = HAVE_OE # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "To use the OpenEphys extractor, install pyopenephys: \n\n pip install pyopenephys\n\n" def __init__(self, folder_path, experiment_id=0, recording_id=0): assert self.installed, self.installation_mesg RecordingExtractor.__init__(self) self._recording_file = folder_path self._fileobj = pyopenephys.File(folder_path) self._recording = self._fileobj.experiments[experiment_id].recordings[recording_id] self._set_analogsignal(self._recording.analog_signals[0]) self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id, 'recording_id': recording_id} def _set_analogsignal(self, analogsignals): self._analogsignals = analogsignals # Set gains: int16 to uV if HAVE_OE_11: self.set_channel_gains(gains=self._analogsignals.gains) else: self.set_channel_gains(gains=self._analogsignals.gain) def get_channel_ids(self): if HAVE_OE_11: return list(self._analogsignals.channel_ids) else: return list(range(self._analogsignals.signal.shape[0])) def get_num_frames(self): return self._analogsignals.signal.shape[1] def get_sampling_frequency(self): if HAVE_OE_11: return self._analogsignals.sample_rate else: return float(self._recording.sample_rate.rescale('Hz').magnitude) @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): return self._analogsignals.signal[channel_ids, start_frame:end_frame] @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): channels = [np.unique(ev.channels)[0] for ev in self._recording.events] assert 
channel_id in channels, f"Specified 'channel' not found. Available channels are {channels}" ev = self._recording.events[channels.index(channel_id)] ttl_frames = (ev.times.rescale("s") * self.get_sampling_frequency()).magnitude.astype(int) ttl_states = np.sign(ev.channel_states) ttl_valid_idxs = np.where((ttl_frames >= start_frame) & (ttl_frames < end_frame))[0] return ttl_frames[ttl_valid_idxs], ttl_states[ttl_valid_idxs] class OpenEphysNPIXRecordingExtractor(OpenEphysRecordingExtractor): extractor_name = 'OpenEphysNPIXRecording' has_default_locations = False installed = HAVE_OE_11 # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "To use the OpenEphys extractor, " \ "install pyopenephys >= 1.1: \n\n pip install pyopenephys>=1.1\n\n" def __init__(self, folder_path, experiment_id=0, recording_id=0, stream="AP"): assert self.installed, self.installation_mesg assert stream.upper() in ["AP", "LFP"] OpenEphysRecordingExtractor.__init__(self, folder_path, experiment_id, recording_id) analogsignals = self._recording.analog_signals for analog in analogsignals: channel_names = analog.channel_names if np.all([stream.upper() in chan for chan in channel_names]): self._set_analogsignal(analog) # load neuropixels locations channel_locations = np.loadtxt(extractors_dir / 'neuropixelsdatrecordingextractor' / 'channel_positions_neuropixels.txt').T # get correct channel ID from channel name (e.g. 
AP32 --> 32) channel_ids = [int(chan_name[chan_name.find(stream.upper())+len(stream):]) - 1 for chan_name in channel_names] locations = channel_locations[channel_ids] self.set_channel_locations(locations) for i, ch in enumerate(self.get_channel_ids()): self.set_channel_property(ch, "channel_name", channel_names[i]) break self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id, 'recording_id': recording_id, 'stream': stream} class OpenEphysSortingExtractor(SortingExtractor): extractor_name = 'OpenEphysSorting' installed = HAVE_OE # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = "To use the OpenEphys extractor, install pyopenephys: \n\n pip install pyopenephys\n\n" # error message when not installed def __init__(self, folder_path, experiment_id=0, recording_id=0): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self._recording_file = folder_path self._recording = pyopenephys.File(folder_path).experiments[experiment_id].recordings[recording_id] self._spiketrains = self._recording.spiketrains self._unit_ids = list([np.unique(st.clusters)[0] for st in self._spiketrains]) self._sampling_frequency = float(self._recording.sample_rate.rescale('Hz').magnitude) self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id, 'recording_id': recording_id} def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): st = self._spiketrains[unit_id] inds = np.where((start_frame <= (st.times * self._recording.sample_rate)) & ((st.times * self._recording.sample_rate) < end_frame)) return (st.times[inds] * self._recording.sample_rate).magnitude ================================================ FILE: spikeextractors/extractors/phyextractors/__init__.py ================================================ from .phyextractors import PhyRecordingExtractor, 
PhySortingExtractor ================================================ FILE: spikeextractors/extractors/phyextractors/phyextractors.py ================================================ import numpy as np from pathlib import Path import csv from typing import Union, Optional from spikeextractors import SortingExtractor, RecordingExtractor from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor from spikeextractors.extraction_tools import read_python, check_get_unit_spike_train PathType = Union[str, Path] class PhyRecordingExtractor(BinDatRecordingExtractor): """ RecordingExtractor for a Phy output folder Parameters ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py) """ extractor_name = 'PhyRecording' has_default_locations = True has_unscaled = False installed = True # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "" # error message when not installed def __init__(self, folder_path: PathType): RecordingExtractor.__init__(self) phy_folder = Path(folder_path) self.params = read_python(str(phy_folder / 'params.py')) datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin'] if (phy_folder / 'channel_map_si.npy').is_file(): channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy'))) assert max(channel_map) < self.params['n_channels_dat'], "Channel map inconsistent with dat file." elif (phy_folder / 'channel_map.npy').is_file(): channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy'))) assert max(channel_map) < self.params['n_channels_dat'], "Channel map inconsistent with dat file." 
else: channel_map = list(range(self.params['n_channels_dat'])) BinDatRecordingExtractor.__init__(self, datfile[0], sampling_frequency=float(self.params['sample_rate']), dtype=self.params['dtype'], numchan=self.params['n_channels_dat'], recording_channels=list(channel_map)) if (phy_folder / 'channel_groups.npy').is_file(): channel_groups = np.load(phy_folder / 'channel_groups.npy') assert len(channel_groups) == self.get_num_channels() self.set_channel_groups(channel_groups) if (phy_folder / 'channel_positions.npy').is_file(): channel_locations = np.load(phy_folder / 'channel_positions.npy') assert len(channel_locations) == self.get_num_channels() self.set_channel_locations(channel_locations) self._kwargs = {'folder_path': str(Path(folder_path).absolute())} class PhySortingExtractor(SortingExtractor): """ SortingExtractor for a Phy output folder Parameters ---------- folder_path: str or Path Path to the output Phy folder (containing the params.py) exclude_cluster_groups: list (optional) List of cluster groups to exclude (e.g. 
["noise", "mua"] """ extractor_name = 'PhySorting' installed = True # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "" # error message when not installed def __init__(self, folder_path: PathType, exclude_cluster_groups: Optional[list] = None): SortingExtractor.__init__(self) phy_folder = Path(folder_path) spike_times = np.load(phy_folder / 'spike_times.npy') spike_templates = np.load(phy_folder / 'spike_templates.npy') if (phy_folder / 'spike_clusters.npy').is_file(): spike_clusters = np.load(phy_folder / 'spike_clusters.npy') else: spike_clusters = spike_templates if (phy_folder / 'amplitudes.npy').is_file(): amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy')) else: amplitudes = np.ones(len(spike_times)) if (phy_folder / 'pc_features.npy').is_file(): pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy')) else: pc_features = None clust_id = np.unique(spike_clusters) self._unit_ids = list(clust_id) spike_times.astype(int) self.params = read_python(str(phy_folder / 'params.py')) self._sampling_frequency = self.params['sample_rate'] # set unit quality properties csv_tsv_files = [x for x in phy_folder.iterdir() if x.suffix == '.csv' or x.suffix == '.tsv'] for f in csv_tsv_files: if f.suffix == '.csv': with f.open() as csv_file: csv_reader = csv.reader(csv_file, delimiter=',') line_count = 0 for row in csv_reader: if line_count == 0: tokens = row[0].split("\t") property_name = tokens[1] else: tokens = row[0].split("\t") if int(tokens[0]) in self.get_unit_ids(): if 'cluster_group' in str(f): self.set_unit_property(int(tokens[0]), 'quality', tokens[1]) elif property_name == 'chan_grp' or property_name == 'ch_group': self.set_unit_property(int(tokens[0]), 'group', int(tokens[1])) else: if isinstance(tokens[1], (int, np.int, float, str)): self.set_unit_property(int(tokens[0]), property_name, tokens[1]) line_count += 1 elif f.suffix == '.tsv': with f.open() as csv_file: csv_reader = 
csv.reader(csv_file, delimiter='\t') line_count = 0 for row in csv_reader: if line_count == 0: property_name = row[1] else: if len(row) == 2: if int(row[0]) in self.get_unit_ids(): if 'cluster_group' in str(f): self.set_unit_property(int(row[0]), 'quality', row[1]) elif property_name == 'chan_grp' or property_name == 'ch_group': self.set_unit_property(int(row[0]), 'group', int(row[1])) else: if isinstance(row[1], (int, float, str)) and len(row) == 2: self.set_unit_property(int(row[0]), property_name, row[1]) line_count += 1 for unit in self.get_unit_ids(): if 'quality' not in self.get_unit_property_names(unit): self.set_unit_property(unit, 'quality', 'unsorted') if exclude_cluster_groups is not None: if len(exclude_cluster_groups) > 0: included_units = [] for u in self.get_unit_ids(): if self.get_unit_property(u, 'quality') not in exclude_cluster_groups: included_units.append(u) else: included_units = self._unit_ids else: included_units = self._unit_ids original_units = self._unit_ids self._unit_ids = included_units # set features self._spiketrains = [] for clust in self._unit_ids: idx = np.where(spike_clusters == clust)[0] self._spiketrains.append(spike_times[idx]) self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx]) if pc_features is not None: self.set_unit_spike_features(clust, 'pc_features', pc_features[idx]) self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'exclude_cluster_groups': exclude_cluster_groups} def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spiketrains[self.get_unit_ids().index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] ================================================ FILE: spikeextractors/extractors/shybridextractors/__init__.py ================================================ from .shybridextractors import SHYBRIDRecordingExtractor, 
SHYBRIDSortingExtractor ================================================ FILE: spikeextractors/extractors/shybridextractors/shybridextractors.py ================================================ import os from pathlib import Path import numpy as np from spikeextractors import RecordingExtractor, SortingExtractor from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor from spikeextractors.extraction_tools import save_to_probe_file, load_probe_file, check_get_unit_spike_train try: import hybridizer.io as sbio import hybridizer.probes as sbprb import yaml HAVE_SBEX = True except ImportError: HAVE_SBEX = False class SHYBRIDRecordingExtractor(RecordingExtractor): extractor_name = 'SHYBRIDRecording' installed = HAVE_SBEX has_default_locations = True has_unscaled = False is_writable = True mode = 'file' installation_mesg = "To use the SHYBRID extractors, install SHYBRID and pyyaml: " \ "\n\n pip install shybrid pyyaml\n\n" def __init__(self, file_path): # load params file related to the given shybrid recording assert self.installed, self.installation_mesg RecordingExtractor.__init__(self) params = sbio.get_params(file_path)['data'] # create a shybrid probe object probe = sbprb.Probe(params['probe']) nb_channels = probe.total_nb_channels # translate the byte ordering # TODO still ambiguous, shybrid should assume time_axis=1, since spike interface makes an assumption on the byte ordering byte_order = params['order'] if byte_order == 'C': time_axis = 1 elif byte_order == 'F': time_axis = 0 # piggyback on binary data recording extractor recording = BinDatRecordingExtractor( file_path, params['fs'], nb_channels, params['dtype'], time_axis=time_axis) # load probe file self._recording = load_probe_file(recording, params['probe']) self._kwargs = {'file_path': str(Path(file_path).absolute())} def get_channel_ids(self): return self._recording.get_channel_ids() def get_num_frames(self): return self._recording.get_num_frames() def 
get_sampling_frequency(self): return self._recording.get_sampling_frequency() def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): return self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled) @staticmethod def write_recording(recording, save_path, initial_sorting_fn, dtype='float32', **write_binary_kwargs): """ Convert and save the recording extractor to SHYBRID format Parameters ---------- recording: RecordingExtractor The recording extractor to be converted and saved save_path: str Full path to desired target folder initial_sorting_fn: str Full path to the initial sorting csv file (can also be generated using write_sorting static method from the SHYBRIDSortingExtractor) dtype: dtype Type of the saved data. Default float32. **write_binary_kwargs: keyword arguments for write_to_binary_dat_format() function """ assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg RECORDING_NAME = 'recording.bin' PROBE_NAME = 'probe.prb' PARAMETERS_NAME = 'recording.yml' # location information has to be present in order for shybrid to # be able to operate on the recording if 'location' not in recording.get_shared_channel_property_names(): raise GeometryNotLoadedError("Channel locations were not found") # write recording recording_fn = os.path.join(save_path, RECORDING_NAME) recording.write_to_binary_dat_format(save_path=recording_fn, time_axis=0, dtype=dtype, **write_binary_kwargs) # write probe file probe_fn = os.path.join(save_path, PROBE_NAME) save_to_probe_file(recording, probe_fn) # create parameters file parameters = dict(clusters=initial_sorting_fn, data=dict(dtype=dtype, fs=str(recording.get_sampling_frequency()), order='F', probe=probe_fn)) # write parameters file parameters_fn = os.path.join(save_path, PARAMETERS_NAME) with open(parameters_fn, 'w') as fp: yaml.dump(parameters, fp) class SHYBRIDSortingExtractor(SortingExtractor): extractor_name = 
'SHYBRIDSorting' installed = HAVE_SBEX is_writable = True installation_mesg = "To use the SHYBRID extractors, install SHYBRID: \n\n pip install shybrid\n\n" def __init__(self, file_path, delimiter=','): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) if os.path.isfile(file_path): self._spike_clusters = sbio.SpikeClusters() self._spike_clusters.fromCSV(file_path, None, delimiter=delimiter) else: raise FileNotFoundError('the ground truth file "{}" could not be found'.format(file_path)) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'delimiter': delimiter} def get_unit_ids(self): return self._spike_clusters.keys() @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): train = self._spike_clusters[unit_id].get_actual_spike_train().spikes idxs = np.where((start_frame <= train) & (train < end_frame)) return train[idxs] @staticmethod def write_sorting(sorting, save_path): """ Convert and save the sorting extractor to SHYBRID CSV format parameters ---------- sorting : SortingExtractor The sorting extractor to be converted and saved save_path : str Full path to the desired target folder """ assert HAVE_SBEX, SHYBRIDSortingExtractor.installation_mesg dump = np.empty((0, 2)) for unit_id in sorting.get_unit_ids(): spikes = sorting.get_unit_spike_train(unit_id)[:, np.newaxis] expanded_id = (np.ones(spikes.size) * unit_id)[:, np.newaxis] tmp_concat = np.concatenate((expanded_id, spikes), axis=1) dump = np.concatenate((dump, tmp_concat), axis=0) sorting_fn = os.path.join(save_path, 'initial_sorting.csv') np.savetxt(sorting_fn, dump, delimiter=',', fmt='%i') class GeometryNotLoadedError(Exception): """ Raised when the recording extractor has no associated channel locations """ pass params_template = \ """clusters: csv: {initial_sorting_fn} data: dtype: {data_type} fs: {sampling_frequency} order: {byte_ordering} probe: {probe_fn} """ ================================================ FILE: 
spikeextractors/extractors/spikeglxrecordingextractor/__init__.py ================================================ from .spikeglxrecordingextractor import SpikeGLXRecordingExtractor ================================================ FILE: spikeextractors/extractors/spikeglxrecordingextractor/readSGLX.py ================================================ # -*- coding: utf-8 -*- """ ---------------------------------------------------------------- This is an adapted version of auxiliary functions to read from SpikeGLX data files. The original code can be found at: https://billkarsh.github.io/SpikeGLX/#offline-analysis-tools ---------------------------------------------------------------- Requires python 3 The main() function at the bottom of this file can run from an interpreter, or, the helper functions can be imported into a new module or Jupyter notebook (an example is included). Simple helper functions and python dictionary demonstrating how to read and manipulate SpikeGLX meta and binary files. The most important part of the demo is readMeta(). Please read the comments for that function. Use of the 'meta' dictionary will make your data handling much easier! """ import numpy as np # import matplotlib.pyplot as plt from pathlib import Path # from tkinter import Tk # from tkinter import filedialog # Parse ini file returning a dictionary whose keys are the metadata # left-hand-side-tags, and values are string versions of the right-hand-side # metadata values. We remove any leading '~' characters in the tags to match # the MATLAB version of readMeta. # # The string values are converted to numbers using the "int" and "float" # functions. Note that python 3 has no size limit for integers. 
#
def readMeta(binFullPath):
    """Parse the SpikeGLX ``.meta`` file that sits next to ``binFullPath``.

    The meta file shares the binary file's stem (``x.imec0.ap.bin`` -> ``x.imec0.ap.meta``).
    Every ``key=value`` line becomes one entry. A leading '~' on a key is stripped, to match
    the MATLAB readMeta. Values are kept as strings.

    Parameters
    ----------
    binFullPath: str or Path
        Path to the binary (.bin) file.

    Returns
    -------
    dict
        Meta tags mapped to their string values. The dict is empty (and a message is
        printed) if no meta file exists.
    """
    binFullPath = Path(binFullPath)
    metaPath = binFullPath.parent / (binFullPath.stem + ".meta")
    metaDict = {}
    if metaPath.exists():
        with metaPath.open() as f:
            for line in f.read().splitlines():
                # split at the first '=' only, so values that themselves contain '=' stay whole
                key, sep, value = line.partition('=')
                if not sep:
                    # blank or malformed line: nothing to record (previously an IndexError)
                    continue
                metaDict[key[1:] if key.startswith('~') else key] = value
    else:
        print("no meta file")
    return metaDict


# Return sample rate as python float.
# On most systems, this will be implemented as C++ double.
# Use python command sys.float_info to get properties of float on your system.
#
def SampRate(meta):
    """Return the sampling rate (Hz) of an imec or nidq stream as a float."""
    if meta['typeThis'] == 'imec':
        srate = float(meta['imSampRate'])
    else:
        srate = float(meta['niSampRate'])
    return srate


# Return a multiplicative factor for converting 16-bit file data
# to voltage. This does not take gain into account. The full
# conversion with gain is:
#         dataVolts = dataInt * fI2V / gain
# Note that each channel may have its own gain.
#
def Int2Volts(meta):
    """Return the int16 -> volts factor (before gain) for an imec or nidq stream."""
    if meta['typeThis'] == 'imec':
        # Newer SpikeGLX meta files state the probe's maximum integer value
        # (e.g. 8192 for Neuropixels 2.0); older files omit it and imply 512.
        maxInt = int(meta['imMaxInt']) if 'imMaxInt' in meta else 512
        fI2V = float(meta['imAiRangeMax'])/maxInt
    else:
        fI2V = float(meta['niAiRangeMax'])/32768
    return fI2V


# Return array of original channel IDs. As an example, suppose we want the
# imec gain for the ith channel stored in the binary data. A gain array
# can be obtained using ChanGainsIM(), but we need an original channel
# index to do the lookup. Because you can selectively save channels, the
# ith channel in the file isn't necessarily the ith acquired channel.
# Use this function to convert from ith stored to original index.
# Note that the SpikeGLX channels are 0 based.
#
def OriginalChans(meta):
    """Return the original (acquisition) channel index of each saved channel.

    ``snsSaveChanSubset`` is either 'all' or a comma-separated list of single
    indices and inclusive ``first:last`` ranges, e.g. '0:3,10'.
    """
    subset = meta['snsSaveChanSubset']
    if subset == 'all':
        # every channel was saved: 0 .. nSavedChans - 1 (default integer dtype)
        return np.arange(0, int(meta['nSavedChans']))

    # seed with an empty range so the result keeps numpy's default integer dtype
    pieces = [np.arange(0, 0)]
    for token in subset.split(sep=','):
        bounds = token.split(sep=':')
        first = int(bounds[0])
        # 'a:b' is inclusive of b; a lone 'a' is the one-element range [a, a]
        last = int(bounds[1]) if len(bounds) > 1 else first
        pieces.append(np.arange(first, last + 1))
    return np.concatenate(pieces)


# Return counts of each nidq channel type that composes the timepoints
# stored in the binary file.
#
def ChannelCountsNI(meta):
    """Return (MN, MA, XA, DW) channel counts parsed from ``snsMnMaXaDw``."""
    counts = meta['snsMnMaXaDw'].split(sep=',')
    return tuple(int(counts[pos]) for pos in range(4))


# Return counts of each imec channel type that composes the timepoints
# stored in the binary files.
#
def ChannelCountsIM(meta):
    """Return (AP, LF, SY) channel counts parsed from ``snsApLfSy``."""
    counts = meta['snsApLfSy'].split(sep=',')
    return tuple(int(counts[pos]) for pos in range(3))


# Return gain for ith channel stored in nidq file.
# ichan is a saved channel index, rather than the original (acquired) index.
#
def ChanGainNI(ichan, savedMN, savedMA, meta):
    """Return the gain of saved nidq channel ``ichan`` (MN, then MA, then the rest)."""
    if ichan < savedMN:
        return float(meta['niMNGain'])
    if ichan < savedMN + savedMA:
        return float(meta['niMAGain'])
    return 1  # non multiplexed channels have no extra gain


# Return gain for imec channels.
# Index into these with the original (acquired) channel IDs.
#
def ChanGainsIM(meta):
    """Return (APgain, LFgain) float arrays indexed by original (acquired) channel ID.

    Fields 3 and 4 of each ``imroTbl`` entry are read as the AP and LF gains.
    """
    # imroTbl looks like '(header)(entry)(entry)...'. Splitting at ')' gives the
    # header, one piece per channel, and a trailing empty string.
    pieces = meta['imroTbl'].split(sep=')')
    nChan = len(pieces) - 2
    APgain = np.zeros(nChan)        # default type = float
    LFgain = np.zeros(nChan)
    for idx, entry in enumerate(pieces[1:nChan + 1]):
        fields = entry.split(sep=' ')
        APgain[idx] = fields[3]
        LFgain[idx] = fields[4]
    return APgain, LFgain


# Having accessed a block of raw nidq data using makeMemMapRaw, convert
# values to gain-corrected voltage. The conversion is only applied to the
# saved-channel indices in chanList. Remember, saved-channel indices are
# in the range [0:nSavedChans-1]. The dimensions of dataArray remain
# unchanged. ChanList examples:
# [0:MN-1]  all MN channels (MN from ChannelCountsNI)
# [2,6,20]  just these three channels (zero based, as they appear in SGLX).
#
def GainCorrectNI(dataArray, chanList, meta):
    """Return one int16 -> volts factor per saved nidq channel in ``chanList``.

    ``dataArray`` is not read or modified; multiply it by the returned factors.
    """
    MN, MA, XA, DW = ChannelCountsNI(meta)
    fI2V = Int2Volts(meta)

    conv = np.zeros(len(chanList), dtype=float)
    for pos, savedIdx in enumerate(chanList):
        conv[pos] = fI2V / ChanGainNI(savedIdx, MN, MA, meta)
    return conv


# Having accessed a block of raw imec data using makeMemMapRaw, convert
# values to gain corrected voltages. The conversion is only applied to
# the saved-channel indices in chanList. Remember saved-channel indices
# are in the range [0:nSavedChans-1]. The dimensions of the dataArray
# remain unchanged. ChanList examples:
# [0:AP-1]  all AP channels
# [2,6,20]  just these three channels (zero based)
# Remember that for an lf file, the saved channel indices (fetched by
# OriginalChans) will be in the range 384-767 for a standard 3A or 3B probe.
#
def GainCorrectIM(dataArray, chanList, meta):
    """Return one int16 -> volts factor per saved imec channel in ``chanList``.

    Acquired indices below nAP use the AP gain, those in [nAP, 2*nAP) use the
    LF gain, and anything beyond (e.g. the sync channel) gets a factor of 1.
    ``dataArray`` is not read or modified.
    """
    # Look up gain with acquired channel ID
    chans = OriginalChans(meta)
    APgain, LFgain = ChanGainsIM(meta)
    nAP = len(APgain)
    nNu = 2 * nAP

    # Common conversion factor
    fI2V = Int2Volts(meta)

    conv = np.zeros(len(chanList), dtype=float)
    for pos, savedIdx in enumerate(chanList):
        acqIdx = chans[savedIdx]    # acquisition index
        if acqIdx < nAP:
            conv[pos] = fI2V / APgain[acqIdx]
        elif acqIdx < nNu:
            conv[pos] = fI2V / LFgain[acqIdx - nAP]
        else:
            conv[pos] = 1
    return conv


def makeMemMapRaw(binFullPath, meta):
    """Memory-map a SpikeGLX .bin file read-only as an int16 (nChan, nSamples) array."""
    nChan = int(meta['nSavedChans'])
    # each sample holds one int16 (2 bytes) per saved channel
    nFileSamp = int(int(meta['fileSizeBytes']) / (2 * nChan))
    # order='F' because the channels of one sample are stored next to each other
    return np.memmap(binFullPath, dtype='int16', mode='r',
                     shape=(nChan, nFileSamp), offset=0, order='F')


# Return an array [lines X timepoints] of uint8 values for a
# specified set of digital lines.
#
# - dwReq is the zero-based index into the saved file of the
#    16-bit word that contains the digital lines of interest.
# - dLineList is a zero-based list of one or more lines/bits
#    to scan from word dwReq.
#
def ExtractDigital(rawData, firstSamp, lastSamp, dwReq, dLineList, meta):
    """Extract selected digital lines from one saved 16-bit digital word.

    Parameters
    ----------
    rawData : numpy.ndarray or numpy.memmap
        (nSavedChans, nSamples) int16 array, e.g. from ``makeMemMapRaw``.
    firstSamp, lastSamp : int
        Sample range [firstSamp, lastSamp) to read.
    dwReq : int
        Zero-based index of the digital word among the saved digital words.
    dLineList : sequence of int
        Zero-based line (bit) numbers to extract from that word.
    meta : dict
        Metadata dictionary as returned by ``readMeta``.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (len(dLineList), nSamples) holding 0/1 values,
        or an empty 1-D array (after printing a message) when the requested
        word is not present in the file.
    """
    # Get channel index of requested digital word dwReq
    if meta['typeThis'] == 'imec':
        AP, LF, SY = ChannelCountsIM(meta)
        if SY == 0:
            print("No imec sync channel saved.")
            digArray = np.zeros((0), 'uint8')
            return(digArray)
        if dwReq > SY - 1:
            # Same handling as the NI branch instead of an IndexError on rawData.
            print("Maximum digital word in file = %d" % (SY - 1))
            digArray = np.zeros((0), 'uint8')
            return(digArray)
        digCh = AP + LF + dwReq
    else:
        MN, MA, XA, DW = ChannelCountsNI(meta)
        if dwReq > DW - 1:
            print("Maximum digital word in file = %d" % (DW - 1))
            digArray = np.zeros((0), 'uint8')
            return(digArray)
        digCh = MN + MA + XA + dwReq

    selectData = np.ascontiguousarray(rawData[digCh, firstSamp:lastSamp], 'int16')
    # Use the number of samples actually read: a range running past the end
    # of the file is clipped by the slice above.
    nSamp = selectData.shape[0]

    # unpack bits of selectData; unpackbits works with uint8
    # original data is int16
    bitWiseData = np.unpackbits(selectData.view(dtype='uint8'))
    # output is 1-D array, nSamp*16. Reshape and transpose
    bitWiseData = np.transpose(np.reshape(bitWiseData, (nSamp, 16)))

    nLine = len(dLineList)
    digArray = np.zeros((nLine, nSamp), 'uint8')
    for i, line in enumerate(dLineList):
        byteN, bitN = np.divmod(line, 8)
        # unpackbits emits bits MSB-first within each byte; byte 0 of the
        # uint8 view is assumed to be the low byte (little-endian host)
        targI = byteN*8 + (7 - bitN)
        digArray[i, :] = bitWiseData[targI, :]
    return (digArray)


# Sample calling program to get a file from the user,
# read metadata, fetch sample rate, voltage conversion
# values for this file and channel, and plot a small range
# of voltages from a single channel.
# Note that this code merely demonstrates indexing into the
# data file, without any optimization for efficiency.
# # def main(): # # # Get file from user # root = Tk() # create the Tkinter widget # root.withdraw() # hide the Tkinter root window # # # Windows specific; forces the window to appear in front # root.attributes("-topmost", True) # # binFullPath = Path(filedialog.askopenfilename(title="Select binary file")) # root.destroy() # destroy the Tkinter widget # # # Other parameters about what data to read # tStart = 0 # tEnd = 1 # dataType = 'D' # 'A' for analog, 'D' for digital data # # # For analog channels: zero-based index of a channel to extract, # # gain correct and plot (plots first channel only) # chanList = [0] # # # For a digital channel: zero based index of the digital word in # # the saved file. For imec data there is never more than one digital word. # dw = 0 # # # Zero-based Line indicies to read from the digital word and plot. # # For 3B2 imec data: the sync pulse is stored in line 6. # dLineList = [0, 1, 6] # # # Read in metadata; returns a dictionary with string for values # meta = readMeta(binFullPath) # # # parameters common to NI and imec data # sRate = SampRate(meta) # firstSamp = int(sRate*tStart) # lastSamp = int(sRate*tEnd) # # array of times for plot # tDat = np.arange(firstSamp, lastSamp+1) # tDat = 1000*tDat/sRate # plot time axis in msec # # rawData = makeMemMapRaw(binFullPath, meta) # # if dataType == 'A': # selectData = rawData[chanList, firstSamp:lastSamp+1] # if meta['typeThis'] == 'imec': # # apply gain correction and convert to uV # convData = 1e6*GainCorrectIM(selectData, chanList, meta) # else: # MN, MA, XA, DW = ChannelCountsNI(meta) # # print("NI channel counts: %d, %d, %d, %d" % (MN, MA, XA, DW)) # # apply gain coorection and conver to mV # convData = 1e3*GainCorrectNI(selectData, chanList, meta) # # # # Plot the first of the extracted channels # # fig, ax = plt.subplots() # # ax.plot(tDat, convData[0, :]) # # plt.show() # # else: # digArray = ExtractDigital(rawData, firstSamp, lastSamp, dw, # dLineList, meta) # # # # Plot the first 
of the extracted channels # # fig, ax = plt.subplots() # # # # for i in range(0, len(dLineList)): # # ax.plot(tDat, digArray[i, :]) # # plt.show() # # # if __name__ == "__main__": # main() ================================================ FILE: spikeextractors/extractors/spikeglxrecordingextractor/spikeglxrecordingextractor.py ================================================ from .readSGLX import readMeta, SampRate, makeMemMapRaw, GainCorrectIM, GainCorrectNI, ExtractDigital import numpy as np from pathlib import Path from spikeextractors import RecordingExtractor from spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args import re class SpikeGLXRecordingExtractor(RecordingExtractor): """ RecordingExtractor from a SpikeGLX Neuropixels file Parameters ---------- file_path: str or Path Path to the ap.bin, lf.bin, or nidq.bin file dtype: str 'int16' or 'float'. If 'float' is selected, the returned traces are converted to uV x_pitch: int The x pitch of the probe (default 16) y_pitch: int The y pitch of the probe (default 20) """ extractor_name = 'SpikeGLXRecording' has_default_locations = True has_unscaled = True installed = True # check at class level if installed or not is_writable = False mode = 'file' installation_mesg = "To use the SpikeGLXRecordingExtractor run:\n\n pip install mtscomp\n\n" # error message when not installed def __init__(self, file_path: str, x_pitch: int = 32, y_pitch: int = 20): RecordingExtractor.__init__(self) self._npxfile = Path(file_path) self._basepath = self._npxfile.parents[0] # Gets file type: 'imec0.ap', 'imec0.lf' or 'nidq' assert re.search(r'imec[0-9]*.(ap|lf){1}.bin$', self._npxfile.name) or 'nidq' in self._npxfile.name, \ "'file_path' can be an imec.ap, imec.lf, imec0.ap, imec0.lf, or nidq file" if 'ap.bin' in str(self._npxfile): rec_type = "ap" self.is_filtered = True elif 'lf.bin' in str(self._npxfile): rec_type = "lf" else: rec_type = "nidq" aux = self._npxfile.stem.split('.')[-1] if aux == 'nidq': 
self._ftype = aux else: self._ftype = self._npxfile.stem.split('.')[-2] + '.' + aux # Metafile self._metafile = self._basepath.joinpath(self._npxfile.stem+'.meta') if not self._metafile.exists(): raise Exception("'meta' file for '"+self._ftype+"' traces should be in the same folder.") # Read in metadata, returns a dictionary meta = readMeta(self._npxfile) self._meta = meta # Traces in 16-bit format self._raw = makeMemMapRaw(self._npxfile, meta) # [chanList, firstSamp:lastSamp+1] # sampling rate and ap channels self._sampling_frequency = SampRate(meta) tot_chan, ap_chan, lfp_chan, locations, channel_ids, channel_names \ = _parse_spikeglx_metafile(self._metafile, x_pitch=x_pitch, y_pitch=y_pitch, rec_type=rec_type) if rec_type in ("ap", "lf"): self._channels = channel_ids # locations if len(locations) > 0: self.set_channel_locations(locations) if len(channel_names) > 0: if len(channel_names) == len(self._channels): for i, ch in enumerate(self._channels): self.set_channel_property(ch, "channel_name", channel_names[i]) if rec_type == "ap": if ap_chan < tot_chan: self._timeseries = self._raw[0:ap_chan, :] elif rec_type == "lf": if lfp_chan < tot_chan: self._timeseries = self._raw[0:lfp_chan, :] else: # nidq self._channels = list(range(int(tot_chan))) self._timeseries = self._raw # get gains if meta['typeThis'] == 'imec': gains = GainCorrectIM(self._timeseries, self._channels, meta) elif meta['typeThis'] == 'nidq': gains = GainCorrectNI(self._timeseries, self._channels, meta) # set gains - convert from int16 to uVolt self.set_channel_gains(gains=gains*1e6, channel_ids=self._channels) self._kwargs = {'file_path': str(Path(file_path).absolute()), 'x_pitch': x_pitch, 'y_pitch': y_pitch} def get_channel_ids(self): return self._channels def get_num_frames(self): return self._timeseries.shape[1] def get_sampling_frequency(self): return self._sampling_frequency @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): 
channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids]) if np.array_equal(channel_ids, self.get_channel_ids()): traces = self._timeseries[:, start_frame:end_frame] else: if np.all(np.diff(channel_idxs) == 1): traces = self._timeseries[channel_idxs[0]:channel_idxs[0]+len(channel_idxs), start_frame:end_frame] else: # This block of the execution will return the data as an array, not a memmap traces = self._timeseries[channel_idxs, start_frame:end_frame] return traces @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): channel = [channel_id] dw = 0 dig = ExtractDigital(self._raw, firstSamp=start_frame, lastSamp=end_frame, dwReq=dw, dLineList=channel, meta=self._meta) dig = np.squeeze(dig) diff_dig = np.diff(dig.astype(int)) rising = np.where(diff_dig > 0)[0] + start_frame falling = np.where(diff_dig < 0)[0] + start_frame ttl_frames = np.concatenate((rising, falling)) ttl_states = np.array([1] * len(rising) + [-1] * len(falling)) sort_idxs = np.argsort(ttl_frames) return ttl_frames[sort_idxs], ttl_states[sort_idxs] def _parse_spikeglx_metafile(metafile, x_pitch, y_pitch, rec_type): tot_channels = None ap_channels = None lfp_channels = None y_offset = 20 x_offset = 11 locations = [] channel_names = [] channel_ids = [] with Path(metafile).open() as f: for line in f.readlines(): if 'nSavedChans' in line: tot_channels = int(line.split('=')[-1]) if 'snsApLfSy' in line: ap_channels = int(line.split('=')[-1].split(',')[0].strip()) lfp_channels = int(line.split(',')[-2].strip()) if 'imSampRate' in line: fs = float(line.split('=')[-1]) if rec_type in ("ap", "lf"): if 'snsChanMap' in line: map = line.split('=')[-1] chans = map.split(')')[1:] for chan in chans: chan_name = chan[1:].split(';')[0] if rec_type == "ap": if "AP" in chan_name: channel_names.append(chan_name) chan_id = int(chan_name[2:]) channel_ids.append(chan_id) elif rec_type == "lf": if "LF" in chan_name: channel_names.append(chan_name) chan_id = 
int(chan_name[2:]) channel_ids.append(chan_id) if 'snsShankMap' in line: map = line.split('=')[-1] chans = map.split(')')[1:] for chan in chans: chan = chan[1:] if len(chan) > 0: x_idx = int(chan.split(':')[1]) y_idx = int(chan.split(':')[2]) stagger = np.mod(y_idx + 0, 2) * x_pitch / 2 x_pos = (1 - x_idx) * x_pitch + stagger + x_offset y_pos = y_idx * y_pitch + y_offset locations.append([x_pos, y_pos]) return tot_channels, ap_channels, lfp_channels, locations, channel_ids, channel_names ================================================ FILE: spikeextractors/extractors/spykingcircusextractors/__init__.py ================================================ from .spykingcircusextractors import SpykingCircusSortingExtractor, SpykingCircusRecordingExtractor ================================================ FILE: spikeextractors/extractors/spykingcircusextractors/spykingcircusextractors.py ================================================ from spikeextractors import RecordingExtractor, SortingExtractor from spikeextractors.extractors.numpyextractors import NumpyRecordingExtractor from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor import numpy as np from pathlib import Path from spikeextractors.extraction_tools import check_get_unit_spike_train try: import h5py HAVE_SCSX = True except ImportError: HAVE_SCSX = False class SpykingCircusRecordingExtractor(RecordingExtractor): """ RecordingExtractor for a SpykingCircus output folder Parameters ---------- folder_path: str or Path Path to the output Spyking Circus folder or result folder """ extractor_name = 'SpykingCircusRecording' has_default_locations = False has_unscaled = False installed = True # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "" # error message when not installed def __init__(self, folder_path): RecordingExtractor.__init__(self) spykingcircus_folder = Path(folder_path) listfiles = spykingcircus_folder.iterdir() parent_folder 
= None result_folder = None for f in listfiles: if f.is_dir(): if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]): parent_folder = spykingcircus_folder result_folder = f if parent_folder is None: parent_folder = spykingcircus_folder.parent for f in parent_folder.iterdir(): if f.is_dir(): if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]): result_folder = spykingcircus_folder assert isinstance(parent_folder, Path) and isinstance(result_folder, Path), "Not a valid spyking circus folder" params = None params_file = None for f in parent_folder.iterdir(): if f.suffix == '.params': params = _load_params(f) params_file = f break assert params is not None, "Could not find the .params file" recording_name = params_file.stem file_format = params["file_format"].lower() if file_format == "numpy": recording_file = parent_folder / f"{recording_name}.npy" self._recording = NumpyRecordingExtractor(recording_file, params["sampling_frequency"]) elif file_format == "raw_binary": recording_file = parent_folder / f"{recording_name}.dat" self._recording = BinDatRecordingExtractor(recording_file, sampling_frequency=params["sampling_frequency"], numchan=params["nb_channels"], dtype=params["dtype"], time_axis=0) else: raise Exception(f"'file_format' {params['file_format']} is not supported by the " f"SpykingCircusRecordingExtractor") if params["mapping"].is_file(): self._recording = self.load_probe_file(params["mapping"]) self.params = params self._kwargs = {'folder_path': str(Path(folder_path).absolute())} def get_channel_ids(self): return self._recording.get_channel_ids() def get_num_frames(self): return self._recording.get_num_frames() def get_sampling_frequency(self): return self._recording.get_sampling_frequency() def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): return self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled) class 
SpykingCircusSortingExtractor(SortingExtractor): """ SortingExtractor for SpykingCircus output folder or file Parameters ---------- file_or_folder_path: str or Path Path to the output Spyking Circus folder, the result folder, or a specific hdf5 file in the result folder load_templates: bool If True, templates are loaded from Spyking Circus output """ extractor_name = 'SpykingCircusSorting' installed = HAVE_SCSX # check at class level if installed or not is_writable = True mode = 'folder' installation_mesg = "To use the SpykingCircusSortingExtractor install h5py: \n\n pip install h5py\n\n" def __init__(self, file_or_folder_path, load_templates=False): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) file_or_folder_path = Path(file_or_folder_path) if file_or_folder_path.is_dir(): listfiles = file_or_folder_path.iterdir() results = None parent_folder = None result_folder = None for f in listfiles: if f.is_dir(): if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]): parent_folder = file_or_folder_path result_folder = f if parent_folder is None: parent_folder = file_or_folder_path.parent for f in parent_folder.iterdir(): if f.is_dir(): if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]): result_folder = file_or_folder_path # load files for f in result_folder.iterdir(): if 'result.hdf5' in str(f): results = f result_extension = '' base_name = f.name[:f.name.find("result")-1] if 'result-merged.hdf5' in str(f): results = f result_extension = '-merged' base_name = f.name[:f.name.find("result")-1] break else: assert file_or_folder_path.suffix in ['.h5', '.hdf5'] result_folder = file_or_folder_path.parent parent_folder = result_folder.parent results = file_or_folder_path result_extension = results.stem[results.stem.find("result") + 6:] base_name = file_or_folder_path.name[:file_or_folder_path.name.find("result") - 1] assert isinstance(parent_folder, Path) and isinstance(result_folder, Path), "Not a valid spyking circus folder" # load params 
params = {} for f in parent_folder.iterdir(): if f.suffix == '.params': params = _load_params(f) if "sampling_frequency" in params.keys(): self._sampling_frequency = params["sampling_frequency"] if results is None: raise Exception(f"{file_or_folder_path} is not a spyking circus folder") f_results = h5py.File(results, 'r') self._spiketrains = [] self._unit_ids = [] for temp in f_results['spiketimes'].keys(): self._spiketrains.append(np.array(f_results['spiketimes'][temp]).astype('int64')) self._unit_ids.append(int(temp.split('_')[-1])) if load_templates: try: import scipy except: raise ImportError("'scipy' is needed to load templates from Spyking Circus") filename = result_folder / f"{base_name}.templates{result_extension}.hdf5" with h5py.File(filename, 'r', libver='earliest') as f: temp_x = f.get('temp_x')[:].ravel() temp_y = f.get('temp_y')[:].ravel() temp_data = f.get('temp_data')[:].ravel() N_e, N_t, nb_templates = f.get('temp_shape')[:].ravel().astype(np.int32) templates = scipy.sparse.csc_matrix((temp_data, (temp_x, temp_y)), shape=(N_e * N_t, nb_templates)) templates = np.array([templates[:, i].toarray().reshape(N_e, N_t) for i in range(templates.shape[1])]) templates = templates[:len(templates)//2] for u_i, unit in enumerate(self.get_unit_ids()): self.set_unit_property(unit, 'template', templates[u_i]) self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute()), 'load_templates': load_templates} def get_unit_ids(self): return list(self._unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): times = self._spiketrains[self.get_unit_ids().index(unit_id)] inds = np.where((start_frame <= times) & (times < end_frame)) return times[inds] @staticmethod def write_sorting(sorting, save_path): assert HAVE_SCSX, SpykingCircusSortingExtractor.installation_mesg save_path = Path(save_path) if save_path.is_dir(): save_path = save_path / 'data.result.hdf5' elif save_path.suffix == '.hdf5': if not 
str(save_path).endswith('result.hdf5') or not str(save_path).endswith('result-merged.hdf5'): raise AttributeError("'save_path' is either a folder or an hdf5 file " "ending with 'result.hdf5' or 'result-merged.hdf5") else: save_path.mkdir() save_path = save_path / 'data.result.hdf5' F = h5py.File(save_path, 'w') spiketimes = F.create_group('spiketimes') for id in sorting.get_unit_ids(): spiketimes.create_dataset('tmp_' + str(id), data=sorting.get_unit_spike_train(id)) def _load_params(params_file): params = {} with params_file.open('r') as f: for r in f.readlines(): if 'sampling_rate' in r: sampling_frequency = r.split('=')[-1] if '#' in sampling_frequency: sampling_frequency = sampling_frequency[:sampling_frequency.find('#')] sampling_frequency = sampling_frequency.strip(" ").strip("\n") sampling_frequency = float(sampling_frequency) params["sampling_frequency"] = sampling_frequency if 'file_format' in r: file_format = r.split('=')[-1] if '#' in file_format: file_format = file_format[:file_format.find('#')] file_format = file_format.strip(" ").strip("\n") params["file_format"] = file_format if 'nb_channels' in r: nb_channels = r.split('=')[-1] if '#' in nb_channels: nb_channels = nb_channels[:nb_channels.find('#')] nb_channels = nb_channels.strip(" ").strip("\n") params["nb_channels"] = int(nb_channels) if 'data_dtype' in r: dtype = r.split('=')[-1] if '#' in dtype: dtype = dtype[:dtype.find('#')] dtype = dtype.strip(" ").strip("\n") params["dtype"] = dtype if 'mapping' in r: mapping = r.split('=')[-1] if '#' in mapping: mapping = mapping[:mapping.find('#')] mapping = mapping.strip(" ").strip("\n") params["mapping"] = Path(mapping) return params ================================================ FILE: spikeextractors/extractors/tridescloussortingextractor/__init__.py ================================================ from .tridescloussortingextractor import TridesclousSortingExtractor ================================================ FILE: 
spikeextractors/extractors/tridescloussortingextractor/tridescloussortingextractor.py ================================================ from spikeextractors import SortingExtractor from pathlib import Path from spikeextractors.extraction_tools import check_get_unit_spike_train try: import tridesclous as tdc HAVE_TDC = True except ImportError: HAVE_TDC = False class TridesclousSortingExtractor(SortingExtractor): extractor_name = 'TridesclousSorting' installed = HAVE_TDC # check at class level if installed or not is_writable = False mode = 'folder' installation_mesg = "To use the TridesclousSortingExtractor install tridesclous: \n\n pip install tridesclous\n\n" # error message when not installed def __init__(self, folder_path, chan_grp=None): assert self.installed, self.installation_mesg tdc_folder = Path(folder_path) SortingExtractor.__init__(self) dataio = tdc.DataIO(str(tdc_folder)) if chan_grp is None: # if chan_grp is not provided, take the first one if unique chan_grps = list(dataio.channel_groups.keys()) assert len(chan_grps) == 1, 'There are several groups in the folder, specify chan_grp=...' 
chan_grp = chan_grps[0] self.chan_grp = chan_grp catalogue = dataio.load_catalogue(name='initial', chan_grp=chan_grp) labels = catalogue['clusters']['cluster_label'] labels = labels[labels >= 0] self._unit_ids = list(labels) # load all spike in memory (this avoid to lock the folder with memmap throug dataio self._all_spikes = dataio.get_spikes(seg_num=0, chan_grp=self.chan_grp, i_start=None, i_stop=None).copy() self._sampling_frequency = dataio.sample_rate self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'chan_grp': chan_grp} def get_unit_ids(self): return self._unit_ids @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) spikes = self._all_spikes spikes = spikes[spikes['cluster_label'] == unit_id] spike_times = spikes['index'] if start_frame is not None: spike_times = spike_times[spike_times >= start_frame] if end_frame is not None: spike_times = spike_times[spike_times < end_frame] return spike_times.copy() ================================================ FILE: spikeextractors/extractors/waveclussortingextractor/__init__.py ================================================ from .waveclussortingextractor import WaveClusSortingExtractor ================================================ FILE: spikeextractors/extractors/waveclussortingextractor/waveclussortingextractor.py ================================================ from pathlib import Path from typing import Union import numpy as np from spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor from spikeextractors.extraction_tools import check_get_unit_spike_train PathType = Union[str, Path] class WaveClusSortingExtractor(MATSortingExtractor): extractor_name = "WaveClusSortingExtractor" installation_mesg = "" # error message when not installed def __init__(self, file_path: PathType): super().__init__(file_path) cluster_classes = 
self._getfield("cluster_class") classes = cluster_classes[:, 0] spike_times = cluster_classes[:, 1] par = self._getfield("par") sample_rate = par[0, 0][np.where(np.array(par.dtype.names) == 'sr')[0][0]][0][0] self.set_sampling_frequency(sample_rate) self._unit_ids = np.unique(classes[classes > 0]).astype('int') self._spike_trains = {} for uid in self._unit_ids: mask = (classes == uid) self._spike_trains[uid] = np.rint(spike_times[mask]*(sample_rate/1000)) self._unsorted_train = np.rint(spike_times[classes == 0] * (sample_rate / 1000)) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) start_frame = start_frame or 0 end_frame = end_frame or np.infty st = self._spike_trains[unit_id] return st[(st >= start_frame) & (st < end_frame)] def get_unit_ids(self): return self._unit_ids.tolist() def get_unsorted_spike_train(self, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) start_frame = start_frame or 0 end_frame = end_frame or np.infty u = self._unsorted_train return u[(u >= start_frame) & (u < end_frame)] ================================================ FILE: spikeextractors/extractors/yassextractors/__init__.py ================================================ from .yassextractors import YassSortingExtractor ================================================ FILE: spikeextractors/extractors/yassextractors/yassextractors.py ================================================ import numpy as np from pathlib import Path from spikeextractors import SortingExtractor from spikeextractors.extractors.numpyextractors import NumpyRecordingExtractor from spikeextractors.extraction_tools import check_get_unit_spike_train try: import yaml HAVE_YASS = True except: HAVE_YASS = False class YassSortingExtractor(SortingExtractor): extractor_name = 'YassSorting' mode = 'folder' installed = HAVE_YASS # check 
at class level if installed or not has_default_locations = False is_writable = False installation_mesg = "To use the Yass extractor, install pyyaml: \n\n pip install pyyaml\n\n" # error message when not installed def __init__(self, folder_path): assert self.installed, self.installation_mesg SortingExtractor.__init__(self) self.root_dir = folder_path r = Path(self.root_dir) self.fname_spike_train = r / 'tmp' / 'output' / 'spike_train.npy' self.fname_templates = r /'tmp' / 'output' / 'templates' / 'templates_0sec.npy' self.fname_config = r / 'config.yaml' # set defaults to None so they are only loaded if user requires them self.spike_train = None self.temps = None # Read CONFIG File with open(self.fname_config, 'r') as stream: self.config = yaml.safe_load(stream) self._sampling_frequency = self.config['recordings']['sampling_rate'] def get_unit_ids(self): if self.spike_train is None: self.spike_train = np.load(self.fname_spike_train) unit_ids = np.unique(self.spike_train[:,1]) return unit_ids def get_temps(self): # Electrical images/templates. if self.temps is None: self.temps = np.load(self.fname_templates) return self.temps def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): """Code to extract spike frames from the specified unit. 
""" if self.spike_train is None: self.spike_train = np.load(self.fname_spike_train) # find unit id spike times idx = np.where(self.spike_train[:,1]==unit_id) spike_times = self.spike_train[idx,0].squeeze() # find spike times if start_frame is None: start_frame = 0 if end_frame is None: end_frame = 1E50 # use large time idx2 = np.where(np.logical_and(spike_times>=start_frame, spike_times= start_frames) & (frame < end_frames))[0] if len(inds) == 0: # can only happen if frame == end_frame ind = len(self._start_frames) - 1 else: ind = inds[0] return self._recordings[ind], ind, frame - self._start_frames[ind] def _find_section_for_time(self, time): start_times = np.array(self._start_times) end_times = np.array(self._end_times) inds = np.where((time >= start_times) & (time < end_times))[0] if len(inds) == 0: # can only happen if frame == end_frame ind = len(self._start_times) - 1 else: ind = inds[0] return self._recordings[ind], ind, time - self._start_times[ind] @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): recording1, i_sec1, i_start_frame = self._find_section_for_frame(start_frame) _, i_sec2, i_end_frame = self._find_section_for_frame(end_frame) if i_sec1 == i_sec2: return recording1.get_traces(channel_ids=channel_ids, start_frame=i_start_frame, end_frame=i_end_frame, return_scaled=return_scaled) traces = [] traces.append( self._recordings[i_sec1].get_traces(channel_ids=channel_ids, start_frame=i_start_frame, end_frame=self._recordings[i_sec1].get_num_frames(), return_scaled=return_scaled) ) for i_sec in range(i_sec1 + 1, i_sec2): traces.append( self._recordings[i_sec].get_traces(channel_ids=channel_ids, return_scaled=return_scaled) ) if i_end_frame != 0: traces.append( self._recordings[i_sec2].get_traces(channel_ids=channel_ids, start_frame=0, end_frame=i_end_frame, return_scaled=return_scaled) ) return np.concatenate(traces, axis=1) @check_get_ttl_args def get_ttl_events(self, start_frame=None, 
end_frame=None, channel_id=0): recording1, i_sec1, i_start_frame = self._find_section_for_frame(start_frame) _, i_sec2, i_end_frame = self._find_section_for_frame(end_frame) if i_sec1 == i_sec2: ttl_frames, ttl_states = recording1.get_ttl_events(start_frame=i_start_frame, end_frame=i_end_frame, channel_id=channel_id) ttl_frames += self._start_frames[i_sec1] else: ttl_frames, ttl_states = [], [] ttl_frames_1, ttl_states_1 = self._recordings[i_sec1].get_ttl_events( start_frame=i_start_frame, end_frame=self._recordings[i_sec1].get_num_frames(), channel_id=channel_id) ttl_frames_1 = (ttl_frames_1 + self._start_frames[i_sec1]).astype('int64') ttl_frames.append(ttl_frames_1) ttl_states.append(ttl_states_1) for i_sec in range(i_sec1 + 1, i_sec2): ttl_frames_i, ttl_states_i = self._recordings[i_sec].get_ttl_events(channel_id=channel_id) ttl_frames_i = (ttl_frames_i + self._start_frames[i_sec]).astype('int64') ttl_frames.append(ttl_frames_i) ttl_states.append(ttl_states_i) ttl_frames_2, ttl_states_2 = self._recordings[i_sec2].get_ttl_events(start_frame=0, end_frame=i_end_frame, channel_id=channel_id) ttl_frames_2 = (ttl_frames_2 + self._start_frames[i_sec2]).astype('int64') ttl_frames.append(ttl_frames_2) ttl_states.append(ttl_states_2) ttl_frames = np.concatenate(np.array(ttl_frames)) ttl_states = np.concatenate(np.array(ttl_states)) return ttl_frames, ttl_states def get_channel_ids(self): return self._channel_ids def get_num_frames(self): return self._num_frames def get_sampling_frequency(self): return self._sampling_frequency def frame_to_time(self, frame): recording, i_epoch, rel_frame = self._find_section_for_frame(frame) return np.round(recording.frame_to_time(rel_frame) + self._start_times[i_epoch], 6) def time_to_frame(self, time): recording, i_epoch, rel_time = self._find_section_for_time(time) return (recording.time_to_frame(rel_time) + self._start_frames[i_epoch]).astype('int64') def concatenate_recordings_by_time(recordings, epoch_names=None): """ Concatenates 
recordings together by time. The order of the recordings determines the order of the time series in the concatenated recording. Parameters ---------- recordings: list The list of RecordingExtractors to be concatenated by time epoch_names: list The list of strings corresponding to the names of recording time period. Returns ------- recording: MultiRecordingTimeExtractor The concatenated recording extractors enscapsulated in the MultiRecordingTimeExtractor object (which is also a recording extractor) """ return MultiRecordingTimeExtractor( recordings=recordings, epoch_names=epoch_names, ) ================================================ FILE: spikeextractors/multisortingextractor.py ================================================ from .sortingextractor import SortingExtractor import numpy as np from .extraction_tools import check_get_unit_spike_train # Encapsulates a grouping of non-continuous sorting extractors class MultiSortingExtractor(SortingExtractor): def __init__(self, sortings): SortingExtractor.__init__(self) self._sortings = sortings self._all_unit_ids = [] self._unit_map = {} u_id = 0 for s_i, sorting in enumerate(self._sortings): unit_ids = sorting.get_unit_ids() for unit_id in unit_ids: self._all_unit_ids.append(u_id) self._unit_map[u_id] = {'sorting_id': s_i, 'unit_id': unit_id} u_id += 1 self._kwargs = {'sortings': [sort.make_serialized_dict() for sort in sortings]} @property def sortings(self): return self._sortings def get_unit_ids(self): return list(self._all_unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] return self._sortings[sorting_id].get_unit_spike_train(unit_id_sorting, start_frame, end_frame) def set_sampling_frequency(self, sampling_frequency): for sorting in self._sortings: sorting.set_sampling_frequency(sampling_frequency) def get_sampling_frequency(self): return 
self._sortings[0].get_sampling_frequency() def set_unit_property(self, unit_id, property_name, value): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] self._sortings[sorting_id].set_unit_property(unit_id_sorting, property_name, value) def get_unit_property(self, unit_id, property_name): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] return self._sortings[sorting_id].get_unit_property(unit_id_sorting, property_name) def get_unit_property_names(self, unit_id): sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] property_names = self._sortings[sorting_id].get_unit_property_names(unit_id_sorting) return property_names def clear_unit_property(self, unit_id, property_name): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] self._sortings[sorting_id].clear_unit_property(unit_id_sorting, property_name) def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] return self._sortings[sorting_id].get_unit_spike_features(unit_id_sorting, feature_name, start_frame=start_frame, end_frame=end_frame) def get_unit_spike_feature_names(self, unit_id): if isinstance(unit_id, (int, np.integer)): if unit_id in self.get_unit_ids(): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = 
self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] feature_names = sorted(self._sortings[sorting_id].get_unit_spike_feature_names(unit_id_sorting)) return feature_names else: raise ValueError("Non-valid unit_id") else: raise ValueError("unit_id must be an int") def set_unit_spike_features(self, unit_id, feature_name, value, indexes=None): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] self._sortings[sorting_id].set_unit_spike_features(unit_id_sorting, feature_name, value, indexes) def clear_unit_spike_features(self, unit_id, feature_name): if unit_id not in self._unit_map.keys(): raise ValueError("Non-valid unit_id") sorting_id = self._unit_map[unit_id]['sorting_id'] unit_id_sorting = self._unit_map[unit_id]['unit_id'] self._sortings[sorting_id].clear_unit_spike_features(unit_id_sorting, feature_name) def concatenate_sortings(sortings): """ Concatenates sortings together. 
The sortings should be non-continuous Parameters ---------- sortings: list The list of SortingExtractors to be concatenated Returns ------- recording: MultiSortingExtractor The concatenated sorting extractors enscapsulated in the MultiSortingExtractor object (which is also a sorting extractor) """ return MultiSortingExtractor( sortings=sortings, ) ================================================ FILE: spikeextractors/recordingextractor.py ================================================ from abc import ABC, abstractmethod import numpy as np from copy import deepcopy from .extraction_tools import load_probe_file, save_to_probe_file, write_to_binary_dat_format, \ write_to_h5_dataset_format, get_sub_extractors_by_property, cast_start_end_frame from .baseextractor import BaseExtractor class RecordingExtractor(ABC, BaseExtractor): """A class that contains functions for extracting important information from recorded extracellular data. It is an abstract class so all functions with the @abstractmethod tag must be implemented for the initialization to work. """ _default_filename = "spikeinterface_recording" def __init__(self): BaseExtractor.__init__(self) self._key_properties = {'group': None, 'location': None, 'gain': None, 'offset': None} self.is_filtered = False @abstractmethod def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): """This function extracts and returns a trace from the recorded data from the given channels ids and the given start and end frame. It will return traces from within three ranges: [start_frame, start_frame+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_recording_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_recording_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. 
Traces are returned in a 2D array that contains all of the traces from each channel with dimensions (num_channels x num_frames). In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. Parameters ---------- channel_ids: array_like A list or 1D array of channel ids (ints) from which each trace will be extracted. start_frame: int The starting frame of the trace to be returned (inclusive). end_frame: int The ending frame of the trace to be returned (exclusive). return_scaled: bool If True, traces are returned after scaling (using gain/offset). If False, the raw traces are returned. Returns ------- traces: numpy.ndarray A 2D array that contains all of the traces from each channel. Dimensions are: (num_channels x num_frames) """ pass @abstractmethod def get_num_frames(self): """This function returns the number of frames in the recording Returns ------- num_frames: int Number of frames in the recording (duration of recording) """ pass @abstractmethod def get_sampling_frequency(self): """This function returns the sampling frequency in units of Hz. Returns ------- fs: float Sampling frequency of the recordings in Hz """ pass @abstractmethod def get_channel_ids(self): """Returns the list of channel ids. If not specified, the range from 0 to num_channels - 1 is returned. Returns ------- channel_ids: list Channel list """ pass def get_num_channels(self): """This function returns the number of channels in the recording. Returns ------- num_channels: int Number of channels in the recording """ return len(self.get_channel_ids()) def get_dtype(self, return_scaled=True): """This function returns the traces dtype Parameters ---------- return_scaled: bool If False and the recording extractor has unscaled traces, it returns the dtype of unscaled traces. 
If True (default) it returns the dtype of the scaled traces Returns ------- dtype: np.dtype The dtype of the traces """ return self.get_traces(channel_ids=[self.get_channel_ids()[0]], start_frame=0, end_frame=1, return_scaled=return_scaled).dtype def set_times(self, times): """This function sets the recording times (in seconds) for each frame Parameters ---------- times: array-like The times in seconds for each frame """ assert len(times) == self.get_num_frames(), "'times' should have the same length of the " \ "number of frames" self._times = times.astype('float64') def copy_times(self, extractor): """This function copies times from another extractor. Parameters ---------- extractor: BaseExtractor The extractor from which the epochs will be copied """ if extractor._times is not None: self.set_times(deepcopy(extractor._times)) def frame_to_time(self, frames): """This function converts user-inputted frame indexes to times with units of seconds. Parameters ---------- frames: float or array-like The frame or frames to be converted to times Returns ------- times: float or array-like The corresponding times in seconds """ # Default implementation if self._times is None: return np.round(frames / self.get_sampling_frequency(), 6) else: return self._times[frames] def time_to_frame(self, times): """This function converts a user-inputted times (in seconds) to a frame indexes. 
Parameters ------- times: float or array-like The times (in seconds) to be converted to frame indexes Returns ------- frames: float or array-like The corresponding frame indexes """ # Default implementation if self._times is None: return np.round(times * self.get_sampling_frequency()).astype('int64') else: return np.searchsorted(self._times, times).astype('int64') def get_snippets(self, reference_frames, snippet_len, channel_ids=None, return_scaled=True): """This function returns data snippets from the given channels that are starting on the given frames and are the length of the given snippet lengths before and after. Parameters ---------- reference_frames: array_like A list or array of frames that will be used as the reference frame of each snippet. snippet_len: int or tuple If int, the snippet will be centered at the reference frame and and return half before and half after of the length. If tuple, it will return the first value of before frames and the second value of after frames around the reference frame (allows for asymmetry). channel_ids: array_like A list or array of channel ids (ints) from which each trace will be extracted return_scaled: bool If True, snippets are returned after scaling (using gain/offset). If False, the raw traces are returned. Returns ------- snippets: numpy.ndarray Returns a list of the snippets as numpy arrays. 
The length of the list is len(reference_frames) Each array has dimensions: (num_channels x snippet_len) Out-of-bounds cases should be handled by filling in zeros in the snippet """ # Default implementation if isinstance(snippet_len, (tuple, list, np.ndarray)): snippet_len_before = int(snippet_len[0]) snippet_len_after = int(snippet_len[1]) else: snippet_len_before = int((snippet_len + 1) / 2) snippet_len_after = int(snippet_len - snippet_len_before) if channel_ids is None: channel_ids = self.get_channel_ids() num_snippets = len(reference_frames) num_channels = len(channel_ids) num_frames = self.get_num_frames() snippet_len_total = int(snippet_len_before + snippet_len_after) snippets = np.zeros((num_snippets, num_channels, snippet_len_total), dtype=self.get_dtype(return_scaled)) for i in range(num_snippets): snippet_chunk = np.zeros((num_channels, snippet_len_total), dtype=self.get_dtype(return_scaled)) if 0 <= reference_frames[i] < num_frames: snippet_range = np.array([int(reference_frames[i]) - snippet_len_before, int(reference_frames[i]) + snippet_len_after]) snippet_buffer = np.array([0, snippet_len_total], dtype='int') # The following handles the out-of-bounds cases if snippet_range[0] < 0: snippet_buffer[0] -= snippet_range[0] snippet_range[0] -= snippet_range[0] if snippet_range[1] >= num_frames: snippet_buffer[1] -= snippet_range[1] - num_frames snippet_range[1] -= snippet_range[1] - num_frames snippet_chunk[:, snippet_buffer[0]:snippet_buffer[1]] = self.get_traces(channel_ids=channel_ids, start_frame=snippet_range[0], end_frame=snippet_range[1], return_scaled=return_scaled) snippets[i] = snippet_chunk return snippets def set_channel_locations(self, locations, channel_ids=None): """This function sets the location key properties of each specified channel id with the corresponding locations of the passed in locations list. 
Parameters ---------- locations: array_like A list of corresponding locations (array_like) for the given channel_ids channel_ids: array-like or int The channel ids (ints) for which the locations will be specified. If None, all channel ids are assumed. """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] locations = [locations] # Only None upon initialization if self._key_properties['location'] is None: default_locations = np.empty((self.get_num_channels(), 3), dtype='float') default_locations[:] = np.nan self._key_properties['location'] = default_locations if len(channel_ids) == len(locations): for i in range(len(channel_ids)): if isinstance(locations[i], (list, np.ndarray, tuple)): location = np.asarray(locations[i]) channel_idx = list(self.get_channel_ids()).index(channel_ids[i]) if len(location) == 2: self._key_properties['location'][channel_idx, :2] = location elif len(location) == 3: self._key_properties['location'][channel_idx] = location else: raise TypeError("'location' must be 2d ior 3d") else: raise TypeError("'location' must be an array like object") else: raise ValueError("channel_ids and locations must have same length") def get_channel_locations(self, channel_ids=None, locations_2d=True): """This function returns the location of each channel specified by channel_ids Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the locations will be returned. If None, all channel ids are assumed. 
locations_2d: bool If True (default), first two dimensions are returned Returns ------- locations: array_like Returns a list of corresponding locations (floats) for the given channel_ids """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] locations = self._key_properties['location'] # Only None upon initialization if locations is None: locations = np.empty((self.get_num_channels(), 3), dtype='float') locations[:] = np.nan self._key_properties['location'] = locations locations = np.array(locations) channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids]) if locations_2d: locations = np.array(locations)[:, :2] return locations[channel_idxs] def clear_channel_locations(self, channel_ids=None): """This function clears the location of each channel specified by channel_ids. Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the locations will be cleared. If None, all channel ids are assumed. """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] # Reset to default locations (NaN) default_locations = np.array([[np.nan, np.nan, np.nan] for i in range(len(channel_ids))]) self.set_channel_locations(default_locations, channel_ids) def set_channel_groups(self, groups, channel_ids=None): """This function sets the group key property of each specified channel id with the corresponding group of the passed in groups list. Parameters ---------- groups: array-like or int A list of groups (ints) for the channel_ids channel_ids: array_like or None The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed. 
""" if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] if isinstance(groups, (int, np.integer)): groups = [groups] # Only None upon initialization if self._key_properties['group'] is None: self._key_properties['group'] = np.zeros(self.get_num_channels(), dtype='int') if len(channel_ids) == len(groups): for i in range(len(channel_ids)): if isinstance(groups[i], (int, np.integer)): channel_idx = list(self.get_channel_ids()).index(channel_ids[i]) self._key_properties['group'][channel_idx] = int(groups[i]) else: raise TypeError("'group' must be an int") else: raise ValueError("channel_ids and groups must have same length") def get_channel_groups(self, channel_ids=None): """This function returns the group of each channel specified by channel_ids Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the groups will be returned Returns ------- groups: array_like Returns a list of corresponding groups (ints) for the given channel_ids """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] groups = self._key_properties['group'] # Only None upon initialization if groups is None: groups = np.zeros(self.get_num_channels(), dtype='int') self._key_properties['group'] = groups groups = np.array(groups) channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids]) return groups[channel_idxs] def clear_channel_groups(self, channel_ids=None): """This function clears the group of each channel specified by channel_ids Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the groups will be cleared. If None, all channel ids are assumed. 
""" if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] # Reset to default groups (0) default_groups = [0] * len(channel_ids) self.set_channel_groups(default_groups, channel_ids) def set_channel_gains(self, gains, channel_ids=None): """This function sets the gain key property of each specified channel id with the corresponding group of the passed in gains float/list. Parameters ---------- gains: float/array_like If a float, each channel will be assigned the corresponding gain. If a list, each channel will be given a gain from the list channel_ids: array_like or None The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed. """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] if isinstance(gains, (int, np.integer, float)): gains = [gains] * len(channel_ids) # Only None upon initialization if self._key_properties['gain'] is None: self._key_properties['gain'] = np.ones(self.get_num_channels(), dtype='float') if len(channel_ids) == len(gains): for i in range(len(channel_ids)): if isinstance(gains[i], (int, np.integer, float)): channel_idx = list(self.get_channel_ids()).index(channel_ids[i]) self._key_properties['gain'][channel_idx] = float(gains[i]) else: raise TypeError("'gain' must be an int or float") else: raise ValueError("channel_ids and gains must have same length") def get_channel_gains(self, channel_ids=None): """This function returns the gain of each channel specified by channel_ids. 
Parameters ---------- channel_ids: array_like The channel ids (ints) for which the gains will be returned Returns ------- gains: array_like Returns a list of corresponding gains (floats) for the given channel_ids """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] gains = self._key_properties['gain'] # Only None upon initialization if gains is None: gains = np.ones(self.get_num_channels(), dtype='float') self._key_properties['gain'] = gains gains = np.array(gains) channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids]) return gains[channel_idxs] def clear_channel_gains(self, channel_ids=None): """This function clears the gains of each channel specified by channel_ids Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the groups will be cleared. If None, all channel ids are assumed. """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] # Reset to default gains (1) default_gains = [1.] * len(channel_ids) self.set_channel_gains(default_gains, channel_ids) def set_channel_offsets(self, offsets, channel_ids=None): """This function sets the offset key property of each specified channel id with the corresponding group of the passed in gains float/list. Parameters ---------- offsets: float/array_like If a float, each channel will be assigned the corresponding offset. If a list, each channel will be given an offset from the list channel_ids: array_like or None The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed. 
""" if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] if isinstance(offsets, (int, np.integer, float)): offsets = [offsets] * len(channel_ids) # Only None upon initialization if self._key_properties['offset'] is None: self._key_properties['offset'] = np.zeros(self.get_num_channels(), dtype='float') if len(channel_ids) == len(offsets): for i in range(len(channel_ids)): if isinstance(offsets[i], (int, np.integer, float)): channel_idx = list(self.get_channel_ids()).index(channel_ids[i]) self._key_properties['offset'][channel_idx] = float(offsets[i]) else: raise TypeError("'offset' must be an int or float") else: raise ValueError("channel_ids and offsets must have same length") def get_channel_offsets(self, channel_ids=None): """This function returns the offset of each channel specified by channel_ids. Parameters ---------- channel_ids: array_like The channel ids (ints) for which the gains will be returned Returns ------- offsets: array_like Returns a list of corresponding offsets for the given channel_ids """ if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] offsets = self._key_properties['offset'] # Only None upon initialization if offsets is None: offsets = np.zeros(self.get_num_channels(), dtype='float') self._key_properties['offset'] = offsets offsets = np.array(offsets) channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids]) return offsets[channel_idxs] def clear_channel_offsets(self, channel_ids=None): """This function clears the gains of each channel specified by channel_ids. Parameters ---------- channel_ids: array-like or int The channel ids (ints) for which the groups will be cleared. If None, all channel ids are assumed. 
""" if channel_ids is None: channel_ids = list(self.get_channel_ids()) if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] # Reset to default offets (0) default_offsets = [0.] * len(channel_ids) self.set_channel_offsets(default_offsets, channel_ids) def set_channel_property(self, channel_id, property_name, value): """This function adds a property dataset to the given channel under the property name. Parameters ---------- channel_id: int The channel id for which the property will be added property_name: str A property stored by the RecordingExtractor (location, etc.) value: The data associated with the given property name. Could be many formats as specified by the user """ if isinstance(channel_id, (int, np.integer)): if channel_id in self.get_channel_ids(): if isinstance(property_name, str): if property_name == 'location': self.set_channel_locations(value, channel_id) elif property_name == 'group': self.set_channel_groups(value, channel_id) else: if channel_id not in self._properties.keys(): self._properties[channel_id] = {} self._properties[channel_id][property_name] = value else: raise TypeError(str(property_name) + " must be a string") else: raise ValueError(str(channel_id) + " is not a valid channel_id") else: raise TypeError(str(channel_id) + " must be an int") def get_channel_property(self, channel_id, property_name): """This function returns the data stored under the property name from the given channel. Parameters ---------- channel_id: int The channel id for which the property will be returned property_name: str A property stored by the RecordingExtractor (location, etc.) Returns ------- property_data The data associated with the given property name. 
Could be many formats as specified by the user """ if not isinstance(channel_id, (int, np.integer)): raise TypeError(str(channel_id) + " must be an int") if channel_id not in self.get_channel_ids(): raise ValueError(str(channel_id) + " is not a valid channel_id") if property_name == 'location': return self.get_channel_locations(channel_id)[0] if property_name == 'group': return self.get_channel_groups(channel_id)[0] if property_name == 'gain': return self.get_channel_gains(channel_id)[0] if property_name == 'offset': return self.get_channel_offsets(channel_id)[0] if channel_id not in self._properties.keys(): raise ValueError('no properties found for channel ' + str(channel_id)) if property_name not in self._properties[channel_id]: raise RuntimeError(str(property_name) + " has not been added to channel " + str(channel_id)) if not isinstance(property_name, str): raise TypeError(str(property_name) + " must be a string") return self._properties[channel_id][property_name] def get_channel_property_names(self, channel_id): """Get a list of property names for a given channel. 
Parameters ---------- channel_id: int The channel id for which the property names will be returned If None (default), will return property names for all channels Returns ------- property_names The list of property names """ if isinstance(channel_id, (int, np.integer)): if channel_id in self.get_channel_ids(): if channel_id not in self._properties.keys(): self._properties[channel_id] = {} property_names = list(self._properties[channel_id].keys()) if np.all(np.logical_not(np.isnan(self.get_channel_locations(channel_id)))): property_names.extend(['location']) property_names.extend(['group']) property_names.extend(['gain']) property_names.extend(['offset']) return sorted(property_names) else: raise ValueError(str(channel_id) + " is not a valid channel_id") else: raise TypeError(str(channel_id) + " must be an int") def get_shared_channel_property_names(self, channel_ids=None): """Get the intersection of channel property names for a given set of channels or for all channels if channel_ids is None. Parameters ---------- channel_ids: array_like The channel ids for which the shared property names will be returned. If None (default), will return shared property names for all channels Returns ------- property_names The list of shared property names """ if channel_ids is None: channel_ids = self.get_channel_ids() curr_property_name_set = set(self.get_channel_property_names(channel_id=channel_ids[0])) for channel_id in channel_ids[1:]: curr_channel_property_name_set = set(self.get_channel_property_names(channel_id=channel_id)) curr_property_name_set = curr_property_name_set.intersection(curr_channel_property_name_set) property_names = list(curr_property_name_set) return sorted(property_names) def copy_channel_properties(self, recording, channel_ids=None): """Copy channel properties from another recording extractor to the current recording extractor. 
Parameters ---------- recording: RecordingExtractor The recording extractor from which the properties will be copied channel_ids: (array_like, (int, np.integer)) The list (or single value) of channel_ids for which the properties will be copied """ if channel_ids is None: self._key_properties = deepcopy(recording._key_properties) self._properties = deepcopy(recording._properties) else: if isinstance(channel_ids, (int, np.integer)): channel_ids = [channel_ids] # copy key properties groups = recording.get_channel_groups(channel_ids=channel_ids) locations = recording.get_channel_locations(channel_ids=channel_ids) gains = recording.get_channel_gains(channel_ids=channel_ids) offsets = recording.get_channel_offsets(channel_ids=channel_ids) self.set_channel_groups(groups) self.set_channel_locations(locations) self.set_channel_gains(gains) self.set_channel_offsets(offsets) # copy normal properties for channel_id in channel_ids: curr_property_names = recording.get_channel_property_names(channel_id=channel_id) for curr_property_name in curr_property_names: if curr_property_name not in self._key_properties.keys(): # key property value = recording.get_channel_property(channel_id=channel_id, property_name=curr_property_name) self.set_channel_property(channel_id=channel_id, property_name=curr_property_name, value=value) def clear_channel_property(self, channel_id, property_name): """This function clears the channel property for the given property. 
Parameters ---------- channel_id: int The id that specifies a channel in the recording property_name: string The name of the property to be cleared """ if property_name == 'location': self.clear_channel_locations(channel_id) elif property_name == 'group': self.clear_channel_groups(channel_id) elif channel_id in self._properties.keys(): if property_name in self._properties[channel_id]: del self._properties[channel_id][property_name] def clear_channels_property(self, property_name, channel_ids=None): """This function clears the channels' properties for the given property. Parameters ---------- property_name: string The name of the property to be cleared channel_ids: list A list of ids that specifies a set of channels in the recording. If None all channels are cleared """ if channel_ids is None: channel_ids = self.get_channel_ids() for channel_id in channel_ids: self.clear_channel_property(channel_id, property_name) def get_epoch(self, epoch_name): """This function returns a SubRecordingExtractor which is a view to the given epoch Parameters ---------- epoch_name: str The name of the epoch to be returned Returns ------- epoch_extractor: SubRecordingExtractor A SubRecordingExtractor which is a view to the given epoch """ from .subrecordingextractor import SubRecordingExtractor epoch_info = self.get_epoch_info(epoch_name) start_frame = epoch_info['start_frame'] end_frame = epoch_info['end_frame'] return SubRecordingExtractor(parent_recording=self, start_frame=start_frame, end_frame=end_frame) def load_probe_file(self, probe_file, channel_map=None, channel_groups=None, verbose=False): """This function returns a SubRecordingExtractor that contains information from the given probe file (channel locations, groups, etc.) If a .prb file is given, then 'location' and 'group' information for each channel is added to the SubRecordingExtractor. If a .csv file is given, then it will only add 'location' to the SubRecordingExtractor. 
Parameters ---------- probe_file: str Path to probe file. Either .prb or .csv channel_map : array-like A list of channel IDs to set in the loaded file. Only used if the loaded file is a .csv. channel_groups : array-like A list of groups (ints) for the channel_ids to set in the loaded file. Only used if the loaded file is a .csv. verbose: bool If True, output is verbose Returns ------- subrecording = SubRecordingExtractor The extractor containing all of the probe information. """ subrecording = load_probe_file(self, probe_file, channel_map=channel_map, channel_groups=channel_groups, verbose=verbose) return subrecording def save_to_probe_file(self, probe_file, grouping_property=None, radius=None, graph=True, geometry=True, verbose=False): """Saves probe file from the channel information of this recording extractor. Parameters ---------- probe_file: str file name of .prb or .csv file to save probe information to grouping_property: str (default None) If grouping_property is a shared_channel_property, different groups are saved based on the property. radius: float (default None) Adjacency radius (used by some sorters). If None it is not saved to the probe file. graph: bool If True, the adjacency graph is saved (default=True) geometry: bool If True, the geometry is saved (default=True) verbose: bool If True, output is verbose """ save_to_probe_file(self, probe_file, grouping_property=grouping_property, radius=radius, graph=graph, geometry=geometry, verbose=verbose) def write_to_binary_dat_format(self, save_path, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, n_jobs=1, joblib_backend='loky', return_scaled=True, verbose=False): """Saves the traces of this recording extractor into binary .dat format. Parameters ---------- save_path: str The path to the file. time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. 
dtype: dtype Type of the saved data. Default float32 chunk_size: None or int Size of each chunk in number of frames. If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) chunk_mb: None or int Chunk size in Mb (default 500Mb) n_jobs: int Number of jobs to use (Default 1) joblib_backend: str Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing') return_scaled: bool If True, traces are returned after scaling (using gain/offset). If False, the raw traces are returned verbose: bool If True, output is verbose (when chunks are used) """ write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype, chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend, return_scaled=return_scaled, verbose=verbose) def write_to_h5_dataset_format(self, dataset_path, save_path=None, file_handle=None, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, verbose=False): """Saves the traces of a recording extractor in an h5 dataset. Parameters ---------- dataset_path: str Path to dataset in h5 file (e.g. '/dataset') save_path: str The path to the file. file_handle: file handle The file handle to dump data. This can be used to append data to an header. In case file_handle is given, the file is NOT closed after writing the binary data. time_axis: 0 (default) or 1 If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file. If 1, the traces shape (nb_channel, nb_sample) is kept in the file. dtype: dtype Type of the saved data. Default float32. chunk_size: None or int Size of each chunk in number of frames. 
If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb) chunk_mb: None or int Chunk size in Mb (default 500Mb) verbose: bool If True, output is verbose (when chunks are used) """ write_to_h5_dataset_format(self, dataset_path, save_path, file_handle, time_axis, dtype, chunk_size, chunk_mb, verbose) def get_sub_extractors_by_property(self, property_name, return_property_list=False): """Returns a list of SubRecordingExtractors from this RecordingExtractor based on the given property_name (e.g. group) Parameters ---------- property_name: str The property used to subdivide the extractor return_property_list: bool If True the property list is returned Returns ------- sub_list: list The list of subextractors to be returned OR sub_list, prop_list If return_property_list is True, the property list will be returned as well """ if return_property_list: sub_list, prop_list = get_sub_extractors_by_property(self, property_name=property_name, return_property_list=return_property_list) return sub_list, prop_list else: sub_list = get_sub_extractors_by_property(self, property_name=property_name, return_property_list=return_property_list) return sub_list def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): """ Returns an array with frames of TTL signals. To be implemented in sub-classes Parameters ---------- start_frame: int The starting frame of the ttl to be returned (inclusive) end_frame: int The ending frame of the ttl to be returned (exclusive) channel_id: int The TTL channel id Returns ------- ttl_frames: array-like Frames of TTL signal for the specified channel ttl_state: array-like State of the transition: 1 - rising, -1 - falling """ raise NotImplementedError @staticmethod def write_recording(recording, save_path): """This function writes out the recorded file of a given recording extractor to the file format of this current recording extractor. Allows for easy conversion between recording file formats. 
It is a static method so it can be used without instantiating this recording extractor. Parameters ---------- recording: RecordingExtractor An RecordingExtractor that can extract information from the recording file to be converted to the new format. save_path: string A path to where the converted recorded data will be saved, which may either be a file or a folder, depending on the format. """ raise NotImplementedError("The write_recording function is not \ implemented for this extractor") ================================================ FILE: spikeextractors/save_tools.py ================================================ from pathlib import Path from .cacheextractors import CacheRecordingExtractor, CacheSortingExtractor from .recordingextractor import RecordingExtractor from .sortingextractor import SortingExtractor def save_si_object(object_name: str, si_object, output_folder, cache_raw=False, include_properties=True, include_features=False): """ Save an arbitrary SI object to a temprary location. Parameters ---------- object_name: str The unique name of the SpikeInterface object. si_object: RecordingExtractor or SortingExtractor The extractor to be saved. output_folder: str or Path The folder where the object is saved. cache_raw: bool If True, the Extractor is cached to a binary file (not recommended for RecordingExtractor objects) (default False). include_properties: bool If True, properties (channel or unit) are saved (default True). 
include_features: bool If True, spike features are saved (default False) """ Path(output_folder).mkdir(parents=True, exist_ok=True) if isinstance(si_object, RecordingExtractor): if not si_object.is_dumpable: cache = CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat") elif cache_raw: # save to json before caching to keep history (in case it's needed) json_file = output_folder / f"{object_name}.json" si_object.dump_to_json(output_folder / json_file) cache = CacheRecordingExtractor(si_object, save_path=output_folder / "raw.dat") else: cache = si_object elif isinstance(si_object, SortingExtractor): if not si_object.is_dumpable: cache = CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz") elif cache_raw: # save to json before caching to keep history (in case it's needed) json_file = output_folder / f"{object_name}.json" si_object.dump_to_json(output_folder / json_file) cache = CacheSortingExtractor(si_object, save_path=output_folder / "sorting.npz") else: cache = si_object else: raise ValueError("The 'si_object' argument shoulde be a SpikeInterface Extractor!") pkl_file = output_folder / f"{object_name}.pkl" cache.dump_to_pickle( output_folder / pkl_file, include_properties=include_properties, include_features=include_features ) ================================================ FILE: spikeextractors/sortingextractor.py ================================================ from abc import ABC, abstractmethod import numpy as np from copy import deepcopy from .extraction_tools import get_sub_extractors_by_property from .baseextractor import BaseExtractor class SortingExtractor(ABC, BaseExtractor): """A class that contains functions for extracting important information from spiked sorted data given a spike sorting software. It is an abstract class so all functions with the @abstractmethod tag must be implemented for the initialization to work. 
""" _default_filename = "spikeinterface_sorting" def __init__(self): BaseExtractor.__init__(self) self._sampling_frequency = None @abstractmethod def get_unit_ids(self): """This function returns a list of ids (ints) for each unit in the sorsted result. Returns ------- unit_ids: array_like A list of the unit ids in the sorted result (ints). """ pass @abstractmethod def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): """This function extracts spike frames from the specified unit. It will return spike frames from within three ranges: [start_frame, t_start+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_unit_spike_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_unit_spike_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Spike frames are returned in the form of an array_like of spike frames. In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. Parameters ---------- unit_id: int The id that specifies a unit in the recording start_frame: int The frame above which a spike frame is returned (inclusive) end_frame: int The frame below which a spike frame is returned (exclusive) Returns ------- spike_train: numpy.ndarray An 1D array containing all the frames for each spike in the specified unit given the range of start and end frames """ pass def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None): """This function extracts spike frames from the specified units. Parameters ---------- unit_ids: array_like The unit ids from which to return spike trains. 
If None, all unit spike trains will be returned start_frame: int The frame above which a spike frame is returned (inclusive) end_frame: int The frame below which a spike frame is returned (exclusive) Returns ------- spike_train: numpy.ndarray An 2D array containing all the frames for each spike in the specified units given the range of start and end frames """ if unit_ids is None: unit_ids = self.get_unit_ids() spike_trains = [self.get_unit_spike_train(uid, start_frame, end_frame) for uid in unit_ids] return spike_trains def get_sampling_frequency(self): """ It returns the sampling frequency. Returns ------- sampling_frequency: float The sampling frequency """ return self._sampling_frequency def set_sampling_frequency(self, sampling_frequency): """ It sets the sorting extractor sampling frequency. Parameters ---------- sampling_frequency: float The sampling frequency """ self._sampling_frequency = sampling_frequency def set_unit_spike_features(self, unit_id, feature_name, value, indexes=None): """This function adds a unit features data set under the given features name to the given unit. Parameters ---------- unit_id: int The unit id for which the features will be set feature_name: str The name of the feature to be stored value: array_like The data associated with the given feature name. Could be many formats as specified by the user. indexes: array_like The indices of the specified spikes (if the number of spike features is less than the length of the unit's spike train). If None, it is assumed that value has the same length as the spike train. 
""" if isinstance(unit_id, (int, np.integer)): if unit_id in self.get_unit_ids(): if unit_id not in self._features.keys(): self._features[unit_id] = {} if indexes is None: if isinstance(feature_name, str) and len(value) == len(self.get_unit_spike_train(unit_id)): self._features[unit_id][feature_name] = value else: if not isinstance(feature_name, str): raise ValueError("feature_name must be a string") else: raise ValueError("feature values should have the same length as the spike train") else: if isinstance(feature_name, str) and len(value) == len(indexes): indexes = np.array(indexes) self._features[unit_id][feature_name] = value self._features[unit_id][feature_name + '_idxs'] = indexes else: if not isinstance(feature_name, str): raise ValueError("feature_name must be a string") else: raise ValueError("feature values should have the same length as indexes") else: raise ValueError(str(unit_id) + " is not a valid unit_id") else: raise ValueError(str(unit_id) + " must be an int") def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None): """This function extracts the specified spike features from the specified unit. It will return spike features from within three ranges: [start_frame, t_start+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_unit_spike_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_unit_spike_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Spike features are returned in the form of an array_like of spike features. In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. 
Parameters ---------- unit_id: int The id that specifies a unit in the recording feature_name: string The name of the feature to be returned start_frame: int The frame above which a spike frame is returned (inclusive) end_frame: int The frame below which a spike frame is returned (exclusive) Returns ------- spike_features: numpy.ndarray An array containing all the features for each spike in the specified unit given the range of start and end frames """ start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) if isinstance(unit_id, (int, np.integer)): if unit_id in self.get_unit_ids(): if unit_id not in self._features.keys(): self._features[unit_id] = {} if isinstance(feature_name, str): if feature_name in self._features[unit_id].keys(): spike_train = self.get_unit_spike_train(unit_id) if start_frame is None: start_frame = 0 if end_frame is None: end_frame = np.inf if start_frame == 0 and end_frame == np.inf: # keep memmap objects return self._features[unit_id][feature_name] else: if len(self._features[unit_id][feature_name]) == len(spike_train): spike_indices = np.where(np.logical_and(spike_train >= start_frame, spike_train < end_frame)) elif len(self._features[unit_id][feature_name]) < len(spike_train): if not feature_name.endswith('idxs'): # retrieve features on the correct idxs assert feature_name + '_idxs' in self.get_unit_spike_feature_names(unit_id=unit_id) feature_name_idxs = feature_name + '_idxs' value_idxs = np.array(self.get_unit_spike_features(unit_id=unit_id, feature_name=feature_name_idxs)) spike_train = spike_train[value_idxs] spike_indices = np.where(np.logical_and(spike_train >= start_frame, spike_train < end_frame)) else: # retrieve idxs features value_idxs = np.array(self.get_unit_spike_features(unit_id=unit_id, feature_name=feature_name)) spike_train = spike_train[value_idxs] spike_indices = np.where(np.logical_and(spike_train >= start_frame, spike_train < end_frame)) else: raise ValueError(str(feature_name) + " dimensions are 
inconsistent for unit " + str(unit_id)) if isinstance(self._features[unit_id][feature_name], list): return list(np.array(self._features[unit_id][feature_name])[spike_indices]) else: return np.array(self._features[unit_id][feature_name])[spike_indices] else: raise ValueError(str(feature_name) + " has not been added to unit " + str(unit_id)) else: raise ValueError(str(feature_name) + " must be a string") else: raise ValueError(str(unit_id) + " is not a valid unit_id") else: raise ValueError(str(unit_id) + " must be an int") def set_times(self, times): """This function sets the sorting times to convert spike trains to seconds Parameters ---------- times: array-like The times in seconds for each frame """ max_frames = np.array([np.max(self.get_unit_spike_train(u)) for u in self.get_unit_ids()]) assert np.all(max_frames < len(times)), "The length of 'times' should be greater than the maximum " \ "spike frame index" self._times = times.astype('float64') def copy_times(self, extractor): """This function copies times from another extractor. Parameters ---------- extractor: BaseExtractor The extractor from which the epochs will be copied """ if extractor._times is not None: self.set_times(deepcopy(extractor._times)) def frame_to_time(self, frames): """This function converts user-inputted frame indexes to times with units of seconds. Parameters ---------- frames: float or array-like The frame or frames to be converted to times Returns ------- times: float or array-like The corresponding times in seconds """ # Default implementation if self._times is None: return np.round(frames / self.get_sampling_frequency(), 6) else: return self._times[frames] def time_to_frame(self, times): """This function converts a user-inputted times (in seconds) to a frame indexes. 
Parameters
        ----------
        times: float or array-like
            The times (in seconds) to be converted to frame indexes

        Returns
        -------
        frames: float or array-like
            The corresponding frame indexes (int64)
        """
        # Default implementation
        if self._times is None:
            # No explicit time vector: scale by the sampling frequency and round to the nearest frame.
            return np.round(times * self.get_sampling_frequency()).astype('int64')
        else:
            # Explicit time vector: the frame is the insertion index of each time
            # (np.searchsorted assumes '_times' is sorted ascending).
            return np.searchsorted(self._times, times).astype('int64')

    def clear_unit_spike_features(self, unit_id, feature_name):
        """This function clears the unit spikes features for the given feature.

        Unknown units or features are silently ignored.
        NOTE(review): a companion '<feature_name>_idxs' entry, if any, is not removed.

        Parameters
        ----------
        unit_id: int
            The id that specifies a unit in the sorting
        feature_name: string
            The name of the feature to be cleared
        """
        if unit_id in self._features.keys():
            if feature_name in self._features[unit_id]:
                del self._features[unit_id][feature_name]

    def clear_units_spike_features(self, feature_name, unit_ids=None):
        """This function clears the units' spikes features for the given feature.

        Parameters
        ----------
        feature_name: string
            The name of the feature to be cleared
        unit_ids: list
            A list of ids that specifies a set of units in the sorting. If None, all units are cleared
        """
        if unit_ids is None:
            unit_ids = self.get_unit_ids()
        for unit_id in unit_ids:
            self.clear_unit_spike_features(unit_id, feature_name)

    def get_unit_spike_feature_names(self, unit_id):
        """This function returns the list of feature names for the given unit

        Parameters
        ----------
        unit_id: int
            The unit id for which the feature names will be returned

        Returns
        -------
        property_names
            The sorted list of feature names.

        Raises
        ------
        ValueError
            If unit_id is not an int or not a valid unit id
        """
        if isinstance(unit_id, (int, np.integer)):
            if unit_id in self.get_unit_ids():
                # Lazily create an empty feature dict for units seen for the first time.
                if unit_id not in self._features.keys():
                    self._features[unit_id] = {}
                feature_names = sorted(self._features[unit_id].keys())
                return feature_names
            else:
                raise ValueError(str(unit_id) + " is not a valid unit_id")
        else:
            raise ValueError(str(unit_id) + " must be an int")

    def get_shared_unit_spike_feature_names(self, unit_ids=None):
        """Get the intersection of unit feature names for a given set of units or for all units if unit_ids is None.

        Parameters
        ----------
        unit_ids: array_like
            The unit ids for which the shared feature names will be returned.
            If None (default), will return shared feature names for all units

        Returns
        -------
        property_names
            The sorted list of shared feature names (empty if no unit ids are given)
        """
        if unit_ids is None:
            unit_ids = self.get_unit_ids()
        if len(unit_ids) > 0:
            # Start from the first unit's names and intersect with every other unit's names.
            curr_feature_name_set = set(self.get_unit_spike_feature_names(unit_id=unit_ids[0]))
            for unit_id in unit_ids[1:]:
                curr_unit_feature_name_set = set(self.get_unit_spike_feature_names(unit_id=unit_id))
                curr_feature_name_set = curr_feature_name_set.intersection(curr_unit_feature_name_set)
            feature_names = sorted(list(curr_feature_name_set))
        else:
            feature_names = []
        return feature_names

    def set_unit_property(self, unit_id, property_name, value):
        """This function adds a unit property data set under the given property name to the given unit.

        Parameters
        ----------
        unit_id: int
            The unit id for which the property will be set
        property_name: str
            The name of the property to be stored
        value
            The data associated with the given property name. Could be many formats as specified by the user

        Raises
        ------
        ValueError
            If unit_id is not an int or not a valid unit id, or if property_name is not a string
        """
        if isinstance(unit_id, (int, np.integer)):
            if unit_id in self.get_unit_ids():
                if unit_id not in self._properties.keys():
                    self._properties[unit_id] = {}
                if isinstance(property_name, str):
                    self._properties[unit_id][property_name] = value
                else:
                    raise ValueError(str(property_name) + " must be a string")
            else:
                raise ValueError(str(unit_id) + " is not a valid unit_id")
        else:
            raise ValueError(str(unit_id) + " must be an int")

    def set_units_property(self, *, unit_ids=None, property_name, values):
        """Sets unit property data for a list of units

        Parameters
        ----------
        unit_ids: list
            The list of unit ids for which the property will be set
            Defaults to get_unit_ids()
        property_name: str
            The name of the property
        values: list
            The list of values to be set, aligned by position with unit_ids
            (extra values are ignored; too few values raise IndexError)
        """
        if unit_ids is None:
            unit_ids = self.get_unit_ids()
        for i, unit in enumerate(unit_ids):
            self.set_unit_property(unit_id=unit, property_name=property_name, value=values[i])

    def get_unit_property(self, unit_id, property_name):
        """This function returns the data stored under the property name given from the given unit.

        Parameters
        ----------
        unit_id: int
            The unit id for which the property will be returned
        property_name: str
            The name of the property

        Returns
        -------
        value
            The data associated with the given property name. Could be many formats as specified by the user

        Raises
        ------
        ValueError
            If unit_id/property_name are invalid or the property was never set for this unit
        """
        if isinstance(unit_id, (int, np.integer)):
            if unit_id in self.get_unit_ids():
                if unit_id not in self._properties.keys():
                    self._properties[unit_id] = {}
                if isinstance(property_name, str):
                    if property_name in list(self._properties[unit_id].keys()):
                        return self._properties[unit_id][property_name]
                    else:
                        raise ValueError(str(property_name) + " has not been added to unit " + str(unit_id))
                else:
                    raise ValueError(str(property_name) + " must be a string")
            else:
                raise ValueError(str(unit_id) + " is not a valid unit_id")
        else:
            raise ValueError(str(unit_id) + " must be an int")

    def get_units_property(self, *, unit_ids=None, property_name):
        """Returns a list of values stored under the property name corresponding to a list of units

        Parameters
        ----------
        unit_ids: list
            The unit ids for which the property will be returned
            Defaults to get_unit_ids()
        property_name: str
            The name of the property

        Returns
        -------
        values
            The list of values, one per unit id
        """
        if unit_ids is None:
            unit_ids = self.get_unit_ids()
        values = [self.get_unit_property(unit_id=unit, property_name=property_name) for unit in unit_ids]
        return values

    def get_unit_property_names(self, unit_id):
        """Get a list of property names for a given unit.

        Parameters
        ----------
        unit_id: int
            The unit id for which the property names will be returned

        Returns
        -------
        property_names
            The sorted list of property names

        Raises
        ------
        ValueError
            If unit_id is not a valid unit id
        TypeError
            If unit_id is not an int. NOTE(review): sibling getters raise ValueError in this case.
        """
        if isinstance(unit_id, (int, np.integer)):
            if unit_id in self.get_unit_ids():
                if unit_id not in self._properties.keys():
                    self._properties[unit_id] = {}
                property_names = sorted(self._properties[unit_id].keys())
                return property_names
            else:
                raise ValueError(str(unit_id) + " is not a valid unit id")
        else:
            raise TypeError(str(unit_id) + " must be an int")

    def get_shared_unit_property_names(self, unit_ids=None):
        """Get the intersection of unit property names for a given set of units or for all units if unit_ids is None.
Parameters ---------- unit_ids: array_like The unit ids for which the shared property names will be returned. If None (default), will return shared property names for all units Returns ------- property_names The list of shared property names """ if unit_ids is None: unit_ids = self.get_unit_ids() if len(unit_ids) > 0: curr_property_name_set = set(self.get_unit_property_names(unit_id=unit_ids[0])) for unit_id in unit_ids[1:]: curr_unit_property_name_set = set(self.get_unit_property_names(unit_id=unit_id)) curr_property_name_set = curr_property_name_set.intersection(curr_unit_property_name_set) property_names = sorted(list(curr_property_name_set)) else: property_names = [] return property_names def copy_unit_properties(self, sorting, unit_ids=None): """Copy unit properties from another sorting extractor to the current sorting extractor. Parameters ---------- sorting: SortingExtractor The sorting extractor from which the properties will be copied unit_ids: (array_like, (int, np.integer)) The list (or single value) of unit_ids for which the properties will be copied """ # Second condition: Ensure dictionary is not empty if unit_ids is None and len(self._properties.keys()) > 0: self._properties = deepcopy(sorting._properties) else: if unit_ids is None: unit_ids = sorting.get_unit_ids() if isinstance(unit_ids, (int, np.integer)): curr_property_names = sorting.get_unit_property_names(unit_id=unit_ids) for curr_property_name in curr_property_names: value = sorting.get_unit_property(unit_id=unit_ids, property_name=curr_property_name) self.set_unit_property(unit_id=unit_ids, property_name=curr_property_name, value=value) else: for unit_id in unit_ids: curr_property_names = sorting.get_unit_property_names(unit_id=unit_id) for curr_property_name in curr_property_names: value = sorting.get_unit_property(unit_id=unit_id, property_name=curr_property_name) self.set_unit_property(unit_id=unit_id, property_name=curr_property_name, value=value) def clear_unit_property(self, unit_id, 
property_name): """This function clears the unit property for the given property. Parameters ---------- unit_id: int The id that specifies a unit in the sorting property_name: string The name of the property to be cleared """ if unit_id in self._properties.keys(): if property_name in self._properties[unit_id]: del self._properties[unit_id][property_name] def clear_units_property(self, property_name, unit_ids=None): """This function clears the units' properties for the given property. Parameters ---------- property_name: string The name of the property to be cleared unit_ids: list A list of ids that specifies a set of units in the sorting. If None, all units are cleared """ if unit_ids is None: unit_ids = self.get_unit_ids() for unit_id in unit_ids: self.clear_unit_property(unit_id, property_name) def copy_unit_spike_features(self, sorting, unit_ids=None): """Copy unit spike features from another sorting extractor to the current sorting extractor. Parameters ---------- sorting: SortingExtractor The sorting extractor from which the spike features will be copied unit_ids: (array_like, (int, np.integer)) The list (or single value) of unit_ids for which the spike features will be copied """ if unit_ids is None: self._features = deepcopy(sorting._features) else: if isinstance(unit_ids, (int, np.integer)): unit_ids = [unit_ids] for unit_id in unit_ids: curr_feature_names = sorting.get_unit_spike_feature_names(unit_id=unit_id) for curr_feature_name in curr_feature_names: value = sorting.get_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name) if len(value) < len(sorting.get_unit_spike_train(unit_id)): if not curr_feature_name.endswith('idxs'): assert curr_feature_name + '_idxs' in \ sorting.get_unit_spike_feature_names(unit_id=unit_id) curr_feature_name_idxs = curr_feature_name + '_idxs' value_idxs = np.array(sorting.get_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name_idxs)) # find index of first spike 
self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value, indexes=value_idxs) else: self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value) def get_epoch(self, epoch_name): """This function returns a SubSortingExtractor which is a view to the given epoch. Parameters ---------- epoch_name: str The name of the epoch to be returned Returns ------- epoch_extractor: SubRecordingExtractor A SubRecordingExtractor which is a view to the given epoch """ epoch_info = self.get_epoch_info(epoch_name) start_frame = epoch_info['start_frame'] end_frame = epoch_info['end_frame'] from .subsortingextractor import SubSortingExtractor return SubSortingExtractor(parent_sorting=self, start_frame=start_frame, end_frame=end_frame) def get_sub_extractors_by_property(self, property_name, return_property_list=False): """Returns a list of SubSortingExtractors from this SortingExtractor based on the given property_name (e.g. group) Parameters ---------- property_name: str The property used to subdivide the extractor return_property_list: bool If True the property list is returned Returns ------- sub_list: list The list of subextractors to be returned """ if return_property_list: sub_list, prop_list = get_sub_extractors_by_property(self, property_name=property_name, return_property_list=return_property_list) return sub_list, prop_list else: sub_list = get_sub_extractors_by_property(self, property_name=property_name, return_property_list=return_property_list) return sub_list @staticmethod def write_sorting(sorting, save_path): """This function writes out the spike sorted data file of a given sorting extractor to the file format of this current sorting extractor. Allows for easy conversion between spike sorting file formats. It is a static method so it can be used without instantiating this sorting extractor. 
Parameters ---------- sorting: SortingExtractor A SortingExtractor that can extract information from the sorted data file to be converted to the new format save_path: string A path to where the converted sorted data will be saved, which may either be a file or a folder, depending on the format """ raise NotImplementedError("The write_sorting function is not \ implemented for this extractor") def get_unsorted_spike_train(self, start_frame=None, end_frame=None): """This function extracts spike frames from the unsorted events. It will return spike frames from within three ranges: [start_frame, t_start+1, ..., end_frame-1] [start_frame, start_frame+1, ..., final_unit_spike_frame - 1] [0, 1, ..., end_frame-1] [0, 1, ..., final_unit_spike_frame - 1] if both start_frame and end_frame are given, if only start_frame is given, if only end_frame is given, or if neither start_frame or end_frame are given, respectively. Spike frames are returned in the form of an array_like of spike frames. In this implementation, start_frame is inclusive and end_frame is exclusive conforming to numpy standards. 
Parameters ---------- start_frame: int The frame above which a spike frame is returned (inclusive) end_frame: int The frame below which a spike frame is returned (exclusive) Returns ---------- spike_train: numpy.ndarray An 1D array containing all the frames for each spike in the specified unit given the range of start and end frames """ raise NotImplementedError ================================================ FILE: spikeextractors/subrecordingextractor.py ================================================ from .recordingextractor import RecordingExtractor from .extraction_tools import check_get_traces_args, cast_start_end_frame, check_get_ttl_args import numpy as np # Encapsulates a sub-dataset class SubRecordingExtractor(RecordingExtractor): def __init__(self, parent_recording, *, channel_ids=None, renamed_channel_ids=None, start_frame=None, end_frame=None): start_frame, end_frame = cast_start_end_frame(start_frame, end_frame) self._parent_recording = parent_recording self._channel_ids = channel_ids self._renamed_channel_ids = renamed_channel_ids self._start_frame = start_frame self._end_frame = end_frame if self._channel_ids is None: self._channel_ids = self._parent_recording.get_channel_ids() if self._renamed_channel_ids is None: self._renamed_channel_ids = self._channel_ids if self._start_frame is None: self._start_frame = 0 if self._end_frame is None: self._end_frame = self._parent_recording.get_num_frames() if self._end_frame > self._parent_recording.get_num_frames(): self._end_frame = self._parent_recording.get_num_frames() self._original_channel_id_lookup = {} for i in range(len(self._channel_ids)): self._original_channel_id_lookup[self._renamed_channel_ids[i]] = self._channel_ids[i] RecordingExtractor.__init__(self) self.copy_channel_properties(parent_recording, channel_ids=self._renamed_channel_ids) # avoid rescaling twice self.clear_channel_gains() self.clear_channel_offsets() self.is_filtered = self._parent_recording.is_filtered self.has_unscaled = 
self._parent_recording.has_unscaled # update dump dict self._kwargs = {'parent_recording': parent_recording.make_serialized_dict(), 'channel_ids': channel_ids, 'renamed_channel_ids': renamed_channel_ids, 'start_frame': start_frame, 'end_frame': end_frame} @check_get_traces_args def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True): sf = self._start_frame + start_frame ef = self._start_frame + end_frame original_ch_ids = self.get_original_channel_ids(channel_ids) return self._parent_recording.get_traces(channel_ids=original_ch_ids, start_frame=sf, end_frame=ef, return_scaled=return_scaled) @check_get_ttl_args def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0): sf = self._start_frame + start_frame ef = self._start_frame + end_frame sf, ef = cast_start_end_frame(sf, ef) try: ttl_frames, ttl_states = self._parent_recording.get_ttl_events(start_frame=sf, end_frame=ef, channel_id=channel_id) ttl_frames -= self._start_frame return ttl_frames, ttl_states except NotImplementedError: raise NotImplementedError("The parent recording does implement the 'get_ttl_events method'") def get_channel_ids(self): return list(self._renamed_channel_ids) def get_num_frames(self): return self._end_frame - self._start_frame def get_sampling_frequency(self): return self._parent_recording.get_sampling_frequency() def frame_to_time(self, frame): frame2 = frame + self._start_frame time1 = self._parent_recording.frame_to_time(frame2) start_time = self._parent_recording.frame_to_time(self._start_frame) return np.round(time1 - start_time, 6) def time_to_frame(self, time): time2 = time + self._parent_recording.frame_to_time(self._start_frame) frame1 = self._parent_recording.time_to_frame(time2) frame2 = frame1 - self._start_frame return frame2.astype('int64') def get_snippets(self, reference_frames, snippet_len, channel_ids=None, return_scaled=True): if channel_ids is None: channel_ids = self.get_channel_ids() reference_frames_shift = 
self._start_frame + np.array(reference_frames) original_ch_ids = self.get_original_channel_ids(channel_ids) return self._parent_recording.get_snippets(reference_frames=reference_frames_shift, snippet_len=snippet_len, channel_ids=original_ch_ids, return_scaled=return_scaled) def copy_channel_properties(self, recording, channel_ids=None): if channel_ids is None: channel_ids = self.get_channel_ids() if isinstance(channel_ids, (int, np.integer)): recording_ch_id = channel_ids if recording is self._parent_recording: recording_ch_id = self.get_original_channel_ids(channel_ids) curr_property_names = recording.get_channel_property_names(channel_id=recording_ch_id) for curr_property_name in curr_property_names: if curr_property_name not in self._key_properties.keys(): # key property value = recording.get_channel_property(channel_id=recording_ch_id, property_name=curr_property_name) self.set_channel_property(channel_id=channel_ids, property_name=curr_property_name, value=value) else: if curr_property_name == 'group': group = recording.get_channel_groups(channel_ids=recording_ch_id) self.set_channel_groups(groups=group, channel_ids=channel_ids) elif curr_property_name == 'location': location = recording.get_channel_locations(channel_ids=recording_ch_id) self.set_channel_locations(locations=location, channel_ids=channel_ids) else: # copy key properties original_channel_ids = self.get_original_channel_ids(channel_ids) groups = recording.get_channel_groups(channel_ids=original_channel_ids) locations = recording.get_channel_locations(channel_ids=original_channel_ids) gains = recording.get_channel_gains(channel_ids=original_channel_ids) offsets = recording.get_channel_offsets(channel_ids=original_channel_ids) self.set_channel_groups(groups=groups, channel_ids=channel_ids) self.set_channel_locations(locations=locations, channel_ids=channel_ids) self.set_channel_gains(gains=gains, channel_ids=channel_ids) self.set_channel_offsets(offsets=offsets, channel_ids=channel_ids) # copy 
normal properties for channel_id in channel_ids: recording_ch_id = channel_id if recording is self._parent_recording: recording_ch_id = self.get_original_channel_ids(channel_id) curr_property_names = recording.get_channel_property_names(channel_id=recording_ch_id) for curr_property_name in curr_property_names: if curr_property_name not in self._key_properties.keys(): # key property value = recording.get_channel_property(channel_id=recording_ch_id, property_name=curr_property_name) self.set_channel_property(channel_id=channel_id, property_name=curr_property_name, value=value) def get_original_channel_ids(self, channel_ids): if isinstance(channel_ids, (int, np.integer)): if channel_ids in self.get_channel_ids(): original_ch_ids = self._original_channel_id_lookup[channel_ids] else: raise ValueError("Non-valid channel_id") else: original_ch_ids = [] for channel_id in channel_ids: if isinstance(channel_id, (int, np.integer)): if channel_id in self.get_channel_ids(): original_ch_id = self._original_channel_id_lookup[channel_id] original_ch_ids.append(original_ch_id) else: raise ValueError("Non-valid channel_id") else: raise ValueError("channel_id must be an int") return original_ch_ids ================================================ FILE: spikeextractors/subsortingextractor.py ================================================ from .sortingextractor import SortingExtractor import numpy as np from .extraction_tools import check_get_unit_spike_train # Encapsulates a subset of a spike sorted data file class SubSortingExtractor(SortingExtractor): def __init__(self, parent_sorting, *, unit_ids=None, renamed_unit_ids=None, start_frame=None, end_frame=None): SortingExtractor.__init__(self) start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) self._parent_sorting = parent_sorting self._unit_ids = unit_ids self._renamed_unit_ids = renamed_unit_ids self._start_frame = start_frame self._end_frame = end_frame if self._unit_ids is None: self._unit_ids = 
self._parent_sorting.get_unit_ids() if self._renamed_unit_ids is None: self._renamed_unit_ids = self._unit_ids if self._start_frame is None: self._start_frame = 0 if self._end_frame is None: self._end_frame = np.Inf self._original_unit_id_lookup = {} for i in range(len(self._unit_ids)): self._original_unit_id_lookup[self._renamed_unit_ids[i]] = self._unit_ids[i] self.copy_unit_properties(parent_sorting, unit_ids=self._renamed_unit_ids) self.copy_unit_spike_features(parent_sorting, unit_ids=self._renamed_unit_ids, start_frame=start_frame, end_frame=end_frame) self._kwargs = {'parent_sorting': parent_sorting.make_serialized_dict(), 'unit_ids': unit_ids, 'renamed_unit_ids': renamed_unit_ids, 'start_frame': start_frame, 'end_frame': end_frame} def get_unit_ids(self): return list(self._renamed_unit_ids) @check_get_unit_spike_train def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None): original_unit_id = self._original_unit_id_lookup[unit_id] sf = self._start_frame + start_frame ef = self._start_frame + end_frame if sf < self._start_frame: sf = self._start_frame if ef > self._end_frame: ef = self._end_frame if ef == np.Inf: ef = None return self._parent_sorting.get_unit_spike_train(unit_id=original_unit_id, start_frame=sf, end_frame=ef) - self._start_frame def get_sampling_frequency(self): return self._parent_sorting.get_sampling_frequency() def frame_to_time(self, frame): frame2 = frame + self._start_frame time1 = self._parent_sorting.frame_to_time(frame2) start_time = self._parent_sorting.frame_to_time(self._start_frame) return np.round(time1 - start_time, 6) def time_to_frame(self, time): time2 = time + self._parent_sorting.frame_to_time(self._start_frame) frame1 = self._parent_sorting.time_to_frame(time2) frame2 = frame1 - self._start_frame return frame2.astype('int64') def copy_unit_properties(self, sorting, unit_ids=None): if unit_ids is None: unit_ids = self.get_unit_ids() if isinstance(unit_ids, (int, np.integer)): sorting_unit_id = unit_ids 
if sorting is self._parent_sorting: sorting_unit_id = self.get_original_unit_ids(unit_ids) curr_property_names = sorting.get_unit_property_names(unit_id=sorting_unit_id) for curr_property_name in curr_property_names: value = sorting.get_unit_property(unit_id=sorting_unit_id, property_name=curr_property_name) self.set_unit_property(unit_id=unit_ids, property_name=curr_property_name, value=value) else: for unit_id in unit_ids: sorting_unit_id = unit_id if sorting is self._parent_sorting: sorting_unit_id = self.get_original_unit_ids(unit_id) curr_property_names = sorting.get_unit_property_names(unit_id=sorting_unit_id) for curr_property_name in curr_property_names: value = sorting.get_unit_property(unit_id=sorting_unit_id, property_name=curr_property_name) self.set_unit_property(unit_id=unit_id, property_name=curr_property_name, value=value) def copy_unit_spike_features(self, sorting, unit_ids=None, start_frame=None, end_frame=None): start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame) if unit_ids is None: unit_ids = self.get_unit_ids() if isinstance(unit_ids, (int, np.integer)): unit_ids = [unit_ids] for unit_id in unit_ids: sorting_unit_id = unit_id if sorting is self._parent_sorting: sorting_unit_id = self.get_original_unit_ids(unit_id) curr_feature_names = sorting.get_unit_spike_feature_names(unit_id=sorting_unit_id) for curr_feature_name in curr_feature_names: value = sorting.get_unit_spike_features(unit_id=sorting_unit_id, feature_name=curr_feature_name, start_frame=start_frame, end_frame=end_frame) if len(value) < len(sorting.get_unit_spike_train(sorting_unit_id, start_frame=start_frame, end_frame=end_frame)): if not curr_feature_name.endswith('idxs'): assert curr_feature_name + '_idxs' in \ sorting.get_unit_spike_feature_names(unit_id=sorting_unit_id) curr_feature_name_idxs = curr_feature_name + '_idxs' value_idxs = np.array(sorting.get_unit_spike_features(unit_id=sorting_unit_id, feature_name=curr_feature_name_idxs, 
start_frame=start_frame, end_frame=end_frame)) # find index of first spike if start_frame is not None: discarded_spikes_idxs = np.where(sorting.get_unit_spike_train(sorting_unit_id) < start_frame) if len(discarded_spikes_idxs) > 0: n_discarded = len(discarded_spikes_idxs[0]) value_idxs = value_idxs - n_discarded self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value, indexes=value_idxs) else: self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value) def get_original_unit_ids(self, unit_ids): if isinstance(unit_ids, (int, np.integer)): if unit_ids in self.get_unit_ids(): original_unit_ids = self._original_unit_id_lookup[unit_ids] else: raise ValueError("Non-valid unit_id") else: original_unit_ids = [] for unit_id in unit_ids: if isinstance(unit_id, (int, np.integer)): if unit_id in self.get_unit_ids(): original_unit_id = self._original_unit_id_lookup[unit_id] original_unit_ids.append(original_unit_id) else: raise ValueError("Non-valid unit_id") else: raise ValueError("unit_id must be an int") return original_unit_ids ================================================ FILE: spikeextractors/testing.py ================================================ import os import shutil from pathlib import Path import uuid from datetime import datetime import numpy as np from .extraction_tools import load_extractor_from_pickle, load_extractor_from_dict, \ load_extractor_from_json def check_recordings_equal(RX1, RX2, return_scaled=True, force_dtype=None, check_times=True): N = RX1.get_num_frames() # get_channel_ids assert np.allclose(RX1.get_channel_ids(), RX2.get_channel_ids()) # get_num_channels assert np.allclose(RX1.get_num_channels(), RX2.get_num_channels()) # get_num_frames assert np.allclose(RX1.get_num_frames(), RX2.get_num_frames()) # get_sampling_frequency assert np.allclose(RX1.get_sampling_frequency(), RX2.get_sampling_frequency()) # get_traces if force_dtype is None: assert 
np.allclose(RX1.get_traces(return_scaled=return_scaled), RX2.get_traces(return_scaled=return_scaled)) else: assert np.allclose(RX1.get_traces(return_scaled=return_scaled).astype(force_dtype), RX2.get_traces(return_scaled=return_scaled).astype(force_dtype)) sf = 0 ef = N if RX1.get_num_channels() > 1: ch = [RX1.get_channel_ids()[0], RX1.get_channel_ids()[-1]] else: ch = RX1.get_channel_ids() if force_dtype is None: assert np.allclose(RX1.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled), RX2.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled)) else: assert np.allclose(RX1.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled).astype(force_dtype), RX2.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled).astype(force_dtype)) if check_times: for f in range(0, RX1.get_num_frames(), 10): assert np.isclose(RX1.frame_to_time(f), RX2.frame_to_time(f)) assert np.isclose(RX1.time_to_frame(RX1.frame_to_time(f)), RX2.time_to_frame(RX2.frame_to_time(f))) # get_snippets frames = [30, 50, 80] snippets1 = RX1.get_snippets(reference_frames=frames, snippet_len=20, return_scaled=return_scaled) snippets2 = RX2.get_snippets(reference_frames=frames, snippet_len=(10, 10), return_scaled=return_scaled) if force_dtype is None: for ii in range(len(frames)): assert np.allclose(snippets1[ii], snippets2[ii]) else: for ii in range(len(frames)): assert np.allclose(snippets1[ii].astype(force_dtype), snippets2[ii].astype(force_dtype)) def check_recording_properties(RX1, RX2): # check properties assert sorted(RX1.get_shared_channel_property_names()) == sorted(RX2.get_shared_channel_property_names()) for prop in RX1.get_shared_channel_property_names(): for ch in RX1.get_channel_ids(): if not isinstance(RX1.get_channel_property(ch, prop), str): assert np.allclose(np.array(RX1.get_channel_property(ch, prop)), np.array(RX2.get_channel_property(ch, prop))) else: assert 
RX1.get_channel_property(ch, prop) == RX2.get_channel_property(ch, prop) def check_recording_return_types(RX): channel_ids = RX.get_channel_ids() assert isinstance(RX.get_num_channels(), (int, np.integer)) assert isinstance(RX.get_num_frames(), (int, np.integer)) assert isinstance(RX.get_sampling_frequency(), float) assert isinstance(RX.get_traces(start_frame=0, end_frame=10), (np.ndarray, np.memmap)) for channel_id in channel_ids: assert isinstance(channel_id, (int, np.integer)) def check_sorting_return_types(SX): unit_ids = SX.get_unit_ids() assert (all(isinstance(id, (int, np.integer)) or isinstance(id, np.integer) for id in unit_ids)) for id in unit_ids: train = SX.get_unit_spike_train(id) # print(train) assert (all(isinstance(x, (int, np.integer)) or isinstance(x, np.integer) for x in train)) def check_sortings_equal(SX1, SX2): # get_unit_ids ids1 = np.sort(np.array(SX1.get_unit_ids())) ids2 = np.sort(np.array(SX2.get_unit_ids())) assert (np.allclose(ids1, ids2)) for id in ids1: train1 = np.sort(SX1.get_unit_spike_train(id)) train2 = np.sort(SX2.get_unit_spike_train(id)) assert np.array_equal(train1, train2) def check_sorting_properties_features(SX1, SX2): # check properties print(SX1.__class__) print('Properties', sorted(SX1.get_shared_unit_property_names()), sorted(SX2.get_shared_unit_property_names())) assert sorted(SX1.get_shared_unit_property_names()) == sorted(SX2.get_shared_unit_property_names()) for prop in SX1.get_shared_unit_property_names(): for u in SX1.get_unit_ids(): if not isinstance(SX1.get_unit_property(u, prop), str): assert np.allclose(np.array(SX1.get_unit_property(u, prop)), np.array(SX2.get_unit_property(u, prop))) else: assert SX1.get_unit_property(u, prop) == SX2.get_unit_property(u, prop) # check features print('Features', sorted(SX1.get_shared_unit_spike_feature_names()), sorted(SX2.get_shared_unit_spike_feature_names())) assert sorted(SX1.get_shared_unit_spike_feature_names()) == sorted(SX2.get_shared_unit_spike_feature_names()) for 
feat in SX1.get_shared_unit_spike_feature_names(): for u in SX1.get_unit_ids(): assert np.allclose(np.array(SX1.get_unit_spike_features(u, feat)), np.array(SX2.get_unit_spike_features(u, feat))) def check_dumping(extractor, test_relative=False): # dump to dict d = extractor.dump_to_dict() extractor_loaded = load_extractor_from_dict(d) if 'Recording' in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=False) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) # dump to json # without file_name extractor.dump_to_json() if 'Recording' in str(type(extractor)): extractor_loaded = load_extractor_from_json('spikeinterface_recording.json') check_recordings_equal(extractor, extractor_loaded, check_times=False) elif 'Sorting' in str(type(extractor)): extractor_loaded = load_extractor_from_json('spikeinterface_sorting.json') check_sortings_equal(extractor, extractor_loaded) # with file_name extractor.dump_to_json(file_path='test_dumping/test.json') extractor_loaded = load_extractor_from_json('test_dumping/test.json') if 'Recording' in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=False) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) # dump to pickle # without file_name extractor.dump_to_pickle() if 'Recording' in str(type(extractor)): extractor_loaded = load_extractor_from_pickle('spikeinterface_recording.pkl') check_recordings_equal(extractor, extractor_loaded, check_times=True) check_recording_properties(extractor, extractor_loaded) elif 'Sorting' in str(type(extractor)): extractor_loaded = load_extractor_from_pickle('spikeinterface_sorting.pkl') check_sortings_equal(extractor, extractor_loaded) check_sorting_properties_features(extractor, extractor_loaded) # with file_name extractor.dump_to_pickle(file_path='test_dumping/test.pkl') extractor_loaded = load_extractor_from_pickle('test_dumping/test.pkl') if 'Recording' 
in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=True) check_recording_properties(extractor, extractor_loaded) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) check_sorting_properties_features(extractor, extractor_loaded) if test_relative: # dump to dict with relative path d = extractor.dump_to_dict(relative_to=".") extractor_loaded = load_extractor_from_dict(d) if 'Recording' in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=False) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) # dump to json with relative path extractor.dump_to_json(file_path='test_dumping/test_rel.json', relative_to=".") extractor_loaded = load_extractor_from_json('test_dumping/test_rel.json') if 'Recording' in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=False) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) # dump to pickle with relative path extractor.dump_to_pickle(file_path='test_dumping/test_rel.pkl', relative_to=".") extractor_loaded = load_extractor_from_pickle('test_dumping/test_rel.pkl') if 'Recording' in str(type(extractor)): check_recordings_equal(extractor, extractor_loaded, check_times=True) elif 'Sorting' in str(type(extractor)): check_sortings_equal(extractor, extractor_loaded) shutil.rmtree('test_dumping') if Path('spikeinterface_recording.json').is_file(): os.remove('spikeinterface_recording.json') if Path('spikeinterface_sorting.json').is_file(): os.remove('spikeinterface_sorting.json') if Path('spikeinterface_recording.pkl').is_file(): os.remove('spikeinterface_recording.pkl') if Path('spikeinterface_sorting.pkl').is_file(): os.remove('spikeinterface_sorting.pkl') def get_default_nwbfile_metadata(): """ Returns structure with defaulted metadata values required for a NWBFile. 
""" metadata = dict( NWBFile=dict( session_description="no description", session_start_time=datetime(1970, 1, 1), identifier=str(uuid.uuid4()) ), Ecephys=dict( Device=[dict( name='Device_ecephys', description='no description' )], ElectrodeGroup=[], ElectricalSeries_raw=dict( name='raw_traces', description='those are the raw traces' ), ElectricalSeries_processed=dict( name='processed_traces', description='those are the processed traces' ), ElectricalSeries_lfp=dict( name='lfp_traces', description='those are the lfp traces' ) ) ) return metadata ================================================ FILE: spikeextractors/version.py ================================================ version = '0.9.11' ================================================ FILE: tests/__init__.py ================================================ ================================================ FILE: tests/probe_test.prb ================================================ channel_groups = { 1: { 'channels': list(range(16)), 'graph' : [], 'geometry': { 0: [ 0.0 , 0.0], 1: [ 0.0 , 50.0], 2: [+21.65, 262.5], 3: [+21.65, 237.5], 4: [+21.65, 187.5], 5: [+21.65, 137.5], 6: [+21.65, 87.5], 7: [+21.65, 37.5], 8: [ 0.0 , 200.0], 9: [ 0.0 , 250.0], 10: [+21.65, 62.5], 11: [+21.65, 112.5], 12: [+21.65, 162.5], 13: [+21.65, 212.5], 14: [ 0.0 , 150.0], 15: [ 0.0 , 100.0], } }, 2: { 'channels': list(range(16,32)), 'graph' : [], 'geometry': { 16: [ 0.0 , 125.0], 17: [ 0.0 , 175.0], 18: [-21.65, 212.5], 19: [-21.65, 162.5], 20: [-21.65, 112.5], 21: [-21.65, 62.5], 22: [ 0.0 , 275.0], 23: [ 0.0 , 225.0], 24: [-21.65, 37.5], 25: [-21.65, 87.5], 26: [-21.65, 137.5], 27: [-21.65, 187.5], 28: [-21.65, 237.5], 29: [-21.65, 262.5], 30: [ 0.0 , 75.0], 31: [ 0.0 , 25.0], } } } ================================================ FILE: tests/test_extractors.py ================================================ import os import shutil import tempfile import unittest from pathlib import Path import numpy as np import spikeextractors as 
se from spikeextractors.exceptions import NotDumpableExtractorError from spikeextractors.testing import (check_sortings_equal, check_recordings_equal, check_dumping, check_recording_return_types, check_sorting_return_types, get_default_nwbfile_metadata) class TestExtractors(unittest.TestCase): def setUp(self): self.RX, self.RX2, self.RX3, self.SX, self.SX2, self.SX3, self.example_info = self._create_example(seed=0) self.test_dir = tempfile.mkdtemp() # self.test_dir = '.' def tearDown(self): # Remove the directory after the test del self.RX, self.RX2, self.RX3, self.SX, self.SX2, self.SX3 shutil.rmtree(self.test_dir) # pass def _create_example(self, seed): channel_ids = [0, 1, 2, 3] num_channels = 4 num_frames = 10000 num_ttls = 30 sampling_frequency = 30000 X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames)) geom = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, 2)) X = (X * 100).astype(int) ttls = np.sort(np.random.permutation(num_frames)[:num_ttls]) RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom) RX.set_ttls(ttls) RX.set_channel_locations([0, 0], channel_ids=0) RX.add_epoch("epoch1", 0, 10) RX.add_epoch("epoch2", 10, 20) for i, channel_id in enumerate(RX.get_channel_ids()): RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i) RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom) RX2.copy_epochs(RX) times = np.arange(RX.get_num_frames()) / RX.get_sampling_frequency() + 5 RX2.set_times(times) RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom) SX = se.NumpySortingExtractor() SX.set_sampling_frequency(sampling_frequency) spike_times = [200, 300, 400] train1 = np.sort(np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[0])).astype(int)) SX.add_unit(unit_id=1, times=train1) SX.add_unit(unit_id=2, 
times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[1]))) SX.add_unit(unit_id=3, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[2]))) SX.set_unit_property(unit_id=1, property_name='stability', value=80) SX.add_epoch("epoch1", 0, 10) SX.add_epoch("epoch2", 10, 20) SX2 = se.NumpySortingExtractor() SX2.set_sampling_frequency(sampling_frequency) spike_times2 = [100, 150, 450] train2 = np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[0])).astype(int) SX2.add_unit(unit_id=3, times=train2) SX2.add_unit(unit_id=4, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[1])) SX2.add_unit(unit_id=5, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[2])) SX2.set_unit_property(unit_id=4, property_name='stability', value=80) SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0])) SX2.copy_epochs(SX) SX2.copy_times(RX2) for i, unit_id in enumerate(SX2.get_unit_ids()): SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i) SX2.set_unit_spike_features( unit_id=unit_id, feature_name='shared_unit_feature', value=np.asarray([i] * spike_times2[i]) ) SX3 = se.NumpySortingExtractor() train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47]) SX3.add_unit(unit_id=0, times=train3) features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35]) features4 = np.asarray([0, 10, 20, 30]) feature4_idx = np.asarray([0, 2, 4, 6]) SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3) SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4, indexes=feature4_idx) example_info = dict( channel_ids=channel_ids, num_channels=num_channels, num_frames=num_frames, sampling_frequency=sampling_frequency, unit_ids=[1, 2, 3], train1=train1, train2=train2, train3=train3, features3=features3, unit_prop=80, channel_prop=(0, 0), ttls=ttls, epochs_info=((0, 10), (10, 20)), 
geom=geom, times=times ) return (RX, RX2, RX3, SX, SX2, SX3, example_info) def test_example(self): self.assertEqual(self.RX.get_channel_ids(), self.example_info['channel_ids']) self.assertEqual(self.RX.get_num_channels(), self.example_info['num_channels']) self.assertEqual(self.RX.get_num_frames(), self.example_info['num_frames']) self.assertEqual(self.RX.get_sampling_frequency(), self.example_info['sampling_frequency']) self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids']) self.assertEqual(self.RX.get_channel_locations(0)[0][0], self.example_info['channel_prop'][0]) self.assertEqual(self.RX.get_channel_locations(0)[0][1], self.example_info['channel_prop'][1]) self.assertTrue(np.array_equal(self.RX.get_ttl_events()[0], self.example_info['ttls'])) self.assertEqual(self.SX.get_unit_property(unit_id=1, property_name='stability'), self.example_info['unit_prop']) self.assertTrue(np.array_equal(self.SX.get_unit_spike_train(1), self.example_info['train1'])) self.assertTrue(issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer)) self.assertTrue(self.RX.get_shared_channel_property_names(), ['group', 'location', 'shared_channel_prop']) self.assertTrue(self.RX.get_channel_property_names(0), ['group', 'location', 'shared_channel_prop']) self.assertTrue(self.SX2.get_shared_unit_property_names(), ['shared_unit_prop']) self.assertTrue(self.SX2.get_unit_property_names(4), ['shared_unit_prop', 'stability']) self.assertTrue(self.SX2.get_shared_unit_spike_feature_names(), ['shared_unit_feature']) self.assertTrue(self.SX2.get_unit_spike_feature_names(3), ['shared_channel_prop', 'widths']) self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'), self.example_info['features3'])) self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4), self.example_info['features3'][1:])) self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4), self.example_info['features3'][:1])) 
self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46), self.example_info['features3'][1:6])) self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0)) self.assertTrue('dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0)) sub_extractor_full = se.SubSortingExtractor(self.SX3) sub_extractor_partial = se.SubSortingExtractor(self.SX3, start_frame=20, end_frame=46) self.assertTrue(np.array_equal(sub_extractor_full.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy'))) self.assertTrue(np.array_equal(sub_extractor_partial.get_unit_spike_features(0, 'dummy'), self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46))) self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()), self.example_info['epochs_info'][0]) self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()), self.example_info['epochs_info'][1]) self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()), self.example_info['epochs_info'][0]) self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()), self.example_info['epochs_info'][1]) self.assertEqual(tuple(self.RX.get_epoch_info("epoch1").values()), tuple(self.RX2.get_epoch_info("epoch1").values())) self.assertEqual(tuple(self.RX.get_epoch_info("epoch2").values()), tuple(self.RX2.get_epoch_info("epoch2").values())) self.assertEqual(tuple(self.SX.get_epoch_info("epoch1").values()), tuple(self.SX2.get_epoch_info("epoch1").values())) self.assertEqual(tuple(self.SX.get_epoch_info("epoch2").values()), tuple(self.SX2.get_epoch_info("epoch2").values())) self.assertTrue(np.array_equal(self.RX2.frame_to_time(np.arange(self.RX2.get_num_frames())), self.example_info['times'])) self.assertTrue(np.array_equal(self.SX2.get_unit_spike_train(3) / self.SX2.get_sampling_frequency() + 5, self.SX2.frame_to_time(self.SX2.get_unit_spike_train(3)))) self.RX3.clear_channel_locations() self.assertTrue('location' not in 
self.RX3.get_shared_channel_property_names()) self.RX3.set_channel_locations(self.example_info['geom']) self.assertTrue(np.array_equal(self.RX3.get_channel_locations(), self.RX2.get_channel_locations())) self.RX3.set_channel_groups(groups=[1], channel_ids=[1]) self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 1) self.RX3.clear_channel_groups() self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 0) self.RX3.set_channel_locations(locations=[[np.nan, np.nan, np.nan]], channel_ids=[1]) self.assertTrue('location' not in self.RX3.get_shared_channel_property_names()) self.RX3.set_channel_locations(locations=[[0, 0, 0]], channel_ids=[1]) self.assertTrue('location' in self.RX3.get_shared_channel_property_names()) check_recording_return_types(self.RX) def test_allocate_arrays(self): shape = (30, 1000) dtype = 'int16' arr_in_memory = self.RX.allocate_array(shape=shape, dtype=dtype, memmap=False) arr_memmap = self.RX.allocate_array(shape=shape, dtype=dtype, memmap=True) assert isinstance(arr_in_memory, np.ndarray) assert isinstance(arr_memmap, np.memmap) assert arr_in_memory.shape == shape assert arr_memmap.shape == shape assert arr_in_memory.dtype == dtype assert arr_memmap.dtype == dtype arr_in_memory = self.SX.allocate_array(shape=shape, dtype=dtype, memmap=False) arr_memmap = self.SX.allocate_array(shape=shape, dtype=dtype, memmap=True) assert isinstance(arr_in_memory, np.ndarray) assert isinstance(arr_memmap, np.memmap) assert arr_in_memory.shape == shape assert arr_memmap.shape == shape assert arr_in_memory.dtype == dtype assert arr_memmap.dtype == dtype def test_cache_extractor(self): cache_rec = se.CacheRecordingExtractor(self.RX) check_recording_return_types(cache_rec) check_recordings_equal(self.RX, cache_rec) cache_rec.move_to('cache_rec') assert cache_rec.filename == 'cache_rec.dat' check_dumping(cache_rec, test_relative=True) cache_rec = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2') check_recording_return_types(cache_rec) 
check_recordings_equal(self.RX, cache_rec) assert cache_rec.filename == 'cache_rec2.dat' check_dumping(cache_rec, test_relative=True) # test saving to file del cache_rec assert Path('cache_rec2.dat').is_file() # test tmp cache_rec = se.CacheRecordingExtractor(self.RX) tmp_file = cache_rec.filename del cache_rec assert not Path(tmp_file).is_file() cache_sort = se.CacheSortingExtractor(self.SX) check_sorting_return_types(cache_sort) check_sortings_equal(self.SX, cache_sort) cache_sort.move_to('cache_sort') assert cache_sort.filename == 'cache_sort.npz' check_dumping(cache_sort, test_relative=True) # test saving to file del cache_sort assert Path('cache_sort.npz').is_file() cache_sort = se.CacheSortingExtractor(self.SX, save_path='cache_sort2') check_sorting_return_types(cache_sort) check_sortings_equal(self.SX, cache_sort) assert cache_sort.filename == 'cache_sort2.npz' check_dumping(cache_sort, test_relative=True) # test saving to file del cache_sort assert Path('cache_sort2.npz').is_file() # test tmp cache_sort = se.CacheSortingExtractor(self.SX) tmp_file = cache_sort.filename del cache_sort assert not Path(tmp_file).is_file() # cleanup os.remove('cache_rec.dat') os.remove('cache_rec2.dat') os.remove('cache_sort.npz') os.remove('cache_sort2.npz') def test_not_dumpable_exception(self): try: self.RX.dump_to_json() except Exception as e: assert isinstance(e, NotDumpableExtractorError) try: self.RX.dump_to_pickle() except Exception as e: assert isinstance(e, NotDumpableExtractorError) def test_mda_extractor(self): path1 = self.test_dir + '/mda' path2 = path1 + '/firings_true.mda' se.MdaRecordingExtractor.write_recording(self.RX, path1) se.MdaSortingExtractor.write_sorting(self.SX, path2) RX_mda = se.MdaRecordingExtractor(path1) SX_mda = se.MdaSortingExtractor(path2) check_recording_return_types(RX_mda) check_recordings_equal(self.RX, RX_mda) check_sorting_return_types(SX_mda) check_sortings_equal(self.SX, SX_mda) check_dumping(RX_mda) check_dumping(SX_mda) def 
test_hdsort_extractor(self): path = self.test_dir + '/results_test_hdsort_extractor.mat' locations = np.ones((10, 2)) se.HDSortSortingExtractor.write_sorting(self.SX, path, locations=locations, noise_std_by_channel=None) SX_hd = se.HDSortSortingExtractor(path) check_sorting_return_types(SX_hd) check_sortings_equal(self.SX, SX_hd) check_dumping(SX_hd) def test_npz_extractor(self): path = self.test_dir + '/sorting.npz' se.NpzSortingExtractor.write_sorting(self.SX, path) SX_npz = se.NpzSortingExtractor(path) # empty write sorting_empty = se.NumpySortingExtractor() path_empty = self.test_dir + '/sorting_empty.npz' se.NpzSortingExtractor.write_sorting(sorting_empty, path_empty) check_sorting_return_types(SX_npz) check_sortings_equal(self.SX, SX_npz) check_dumping(SX_npz) def test_biocam_extractor(self): path1 = self.test_dir + '/raw.brw' se.BiocamRecordingExtractor.write_recording(self.RX, path1) RX_biocam = se.BiocamRecordingExtractor(path1) check_recording_return_types(RX_biocam) check_recordings_equal(self.RX, RX_biocam) check_dumping(RX_biocam) def test_mearec_extractors(self): path1 = self.test_dir + '/raw.h5' se.MEArecRecordingExtractor.write_recording(self.RX, path1) RX_mearec = se.MEArecRecordingExtractor(path1) tr = RX_mearec.get_traces(channel_ids=[0, 1], end_frame=1000) check_recording_return_types(RX_mearec) check_recordings_equal(self.RX, RX_mearec) check_dumping(RX_mearec) path2 = self.test_dir + '/firings_true.h5' se.MEArecSortingExtractor.write_sorting(self.SX, path2, self.RX.get_sampling_frequency()) SX_mearec = se.MEArecSortingExtractor(path2) check_sorting_return_types(SX_mearec) check_sortings_equal(self.SX, SX_mearec) check_dumping(SX_mearec) def test_hs2_extractor(self): path1 = self.test_dir + '/firings_true.hdf5' se.HS2SortingExtractor.write_sorting(self.SX, path1) SX_hs2 = se.HS2SortingExtractor(path1) check_sorting_return_types(SX_hs2) check_sortings_equal(self.SX, SX_hs2) self.assertEqual(SX_hs2.get_sampling_frequency(), 
self.SX.get_sampling_frequency()) check_dumping(SX_hs2) def test_exdir_extractors(self): path1 = self.test_dir + '/raw.exdir' se.ExdirRecordingExtractor.write_recording(self.RX, path1) RX_exdir = se.ExdirRecordingExtractor(path1) check_recording_return_types(RX_exdir) check_recordings_equal(self.RX, RX_exdir) check_dumping(RX_exdir) path2 = self.test_dir + '/firings.exdir' se.ExdirSortingExtractor.write_sorting(self.SX, path2, self.RX) SX_exdir = se.ExdirSortingExtractor(path2) check_sorting_return_types(SX_exdir) check_sortings_equal(self.SX, SX_exdir) check_dumping(SX_exdir) def test_spykingcircus_extractor(self): path1 = self.test_dir + '/sc' se.SpykingCircusSortingExtractor.write_sorting(self.SX, path1) SX_spy = se.SpykingCircusSortingExtractor(path1) check_sorting_return_types(SX_spy) check_sortings_equal(self.SX, SX_spy) check_dumping(SX_spy) def test_multi_sub_recording_extractor(self): RX_multi = se.MultiRecordingTimeExtractor( recordings=[self.RX, self.RX, self.RX], epoch_names=['A', 'B', 'C'] ) RX_sub = RX_multi.get_epoch('C') check_recordings_equal(self.RX, RX_sub) check_recordings_equal(self.RX, RX_multi.recordings[0]) check_recordings_equal(self.RX, RX_multi.recordings[1]) check_recordings_equal(self.RX, RX_multi.recordings[2]) self.assertEqual(4, len(RX_sub.get_channel_ids())) RX_multi = se.MultiRecordingChannelExtractor( recordings=[self.RX, self.RX2, self.RX3], groups=[1, 2, 3] ) RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3]) # RX2 has times check_recordings_equal(self.RX2, RX_sub, check_times=False) check_recordings_equal(self.RX, RX_multi.recordings[0]) check_recordings_equal(self.RX2, RX_multi.recordings[1], check_times=False) check_recordings_equal(self.RX3, RX_multi.recordings[2]) self.assertEqual([2, 2, 2, 2], list(RX_sub.get_channel_groups())) self.assertEqual(12, len(RX_multi.get_channel_ids())) RX_multi = se.MultiRecordingChannelExtractor( recordings=[self.RX2, self.RX2, self.RX2], 
groups=[1, 2, 3] ) RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3]) check_recordings_equal(self.RX2, RX_sub, check_times=False) check_recordings_equal(self.RX2, RX_multi.recordings[0]) check_recordings_equal(self.RX2, RX_multi.recordings[1], check_times=False) check_recordings_equal(self.RX2, RX_multi.recordings[2]) self.assertTrue(np.array_equal([2, 2, 2, 2], list(RX_sub.get_channel_groups()))) self.assertTrue(12 == len(RX_multi.get_channel_ids())) self.assertTrue(np.array_equal(RX_multi.frame_to_time(np.arange(RX_multi.get_num_frames())), np.arange(RX_multi.get_num_frames()) / RX_multi.get_sampling_frequency() + 5)) rx1 = self.RX rx2 = self.RX2 rx3 = self.RX3 rx2.set_channel_property(0, "foo", 100) rx3.set_channel_locations([11, 11], channel_ids=0) RX_multi_c = se.MultiRecordingChannelExtractor( recordings=[rx1, rx2, rx3], groups=[0, 0, 1] ) self.assertTrue(np.array_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], RX_multi_c.get_channel_ids())) self.assertTrue(np.array_equal([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], RX_multi_c.get_channel_groups())) self.assertEqual(rx2.get_channel_property(0, "foo"), RX_multi_c.get_channel_property(4, "foo")) self.assertTrue(np.array_equal(rx3.get_channel_locations([0])[0], RX_multi_c.get_channel_locations([8])[0])) def test_ttl_frames_in_sub_multi(self): # sub recording start_frame = self.example_info['num_frames'] // 3 end_frame = 2 * self.example_info['num_frames'] // 3 RX_sub = se.SubRecordingExtractor(self.RX, start_frame=start_frame, end_frame=end_frame) original_ttls = self.RX.get_ttl_events()[0] ttls_in_sub = original_ttls[np.where((original_ttls >= start_frame) & (original_ttls < end_frame))[0]] self.assertTrue(np.array_equal(RX_sub.get_ttl_events()[0], ttls_in_sub - start_frame)) # multirecording RX_multi = se.MultiRecordingTimeExtractor(recordings=[self.RX, self.RX, self.RX]) ttls_originals = self.RX.get_ttl_events()[0] num_ttls = len(ttls_originals) 
self.assertEqual(len(RX_multi.get_ttl_events()[0]), 3 * num_ttls) self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][:num_ttls], ttls_originals)) self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][num_ttls:2 * num_ttls], ttls_originals + self.RX.get_num_frames())) self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][2 * num_ttls:], ttls_originals + 2 * self.RX.get_num_frames())) def test_multi_sub_sorting_extractor(self): N = self.RX.get_num_frames() SX_multi = se.MultiSortingExtractor( sortings=[self.SX, self.SX, self.SX], ) SX_multi.set_unit_property(unit_id=1, property_name='dummy', value=5) SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0) check_sortings_equal(SX_multi, SX_sub) self.assertEqual(SX_multi.get_unit_property(1, 'dummy'), SX_sub.get_unit_property(1, 'dummy')) N = self.RX.get_num_frames() SX_multi = se.MultiSortingExtractor( sortings=[self.SX, self.SX2], ) SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0, end_frame=N) check_sortings_equal(SX_multi, SX_sub1) check_sortings_equal(self.SX, SX_multi.sortings[0]) check_sortings_equal(self.SX2, SX_multi.sortings[1]) def test_dump_load_multi_sub_extractor(self): # generate dumpable formats path1 = self.test_dir + '/mda' path2 = path1 + '/firings_true.mda' se.MdaRecordingExtractor.write_recording(self.RX, path1) se.MdaSortingExtractor.write_sorting(self.SX, path2) RX_mda = se.MdaRecordingExtractor(path1) SX_mda = se.MdaSortingExtractor(path2) RX_multi_chan = se.MultiRecordingChannelExtractor(recordings=[RX_mda, RX_mda, RX_mda]) check_dumping(RX_multi_chan) RX_multi_time = se.MultiRecordingTimeExtractor(recordings=[RX_mda, RX_mda, RX_mda], ) check_dumping(RX_multi_time) RX_multi_chan = se.SubRecordingExtractor(RX_mda, channel_ids=[0, 1]) check_dumping(RX_multi_chan) SX_sub = se.SubSortingExtractor(SX_mda, unit_ids=[1, 2]) check_dumping(SX_sub) SX_multi = se.MultiSortingExtractor(sortings=[SX_mda, SX_mda, SX_mda]) check_dumping(SX_multi) def 
test_nwb_extractor(self): path1 = self.test_dir + '/test.nwb' se.NwbRecordingExtractor.write_recording(self.RX, path1) RX_nwb = se.NwbRecordingExtractor(path1) check_recording_return_types(RX_nwb) check_recordings_equal(self.RX, RX_nwb) check_dumping(RX_nwb) del RX_nwb se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True) RX_nwb = se.NwbRecordingExtractor(path1) check_recording_return_types(RX_nwb) check_recordings_equal(self.RX, RX_nwb) check_dumping(RX_nwb) # append sorting to existing file se.NwbSortingExtractor.write_sorting(sorting=self.SX, save_path=path1, overwrite=False) path2 = self.test_dir + "/firings_true.nwb" se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path2) se.NwbSortingExtractor.write_sorting(sorting=self.SX, save_path=path2) SX_nwb = se.NwbSortingExtractor(path2) check_sortings_equal(self.SX, SX_nwb) check_dumping(SX_nwb) # Test for handling unit property descriptions argument property_descriptions = dict(stability="This is a description of stability.") se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True) se.NwbSortingExtractor.write_sorting( sorting=self.SX, save_path=path1, property_descriptions=property_descriptions ) SX_nwb = se.NwbSortingExtractor(path1) check_sortings_equal(self.SX, SX_nwb) check_dumping(SX_nwb) # Test for handling skip_properties argument se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True) se.NwbSortingExtractor.write_sorting( sorting=self.SX, save_path=path1, skip_properties=['stability'] ) SX_nwb = se.NwbSortingExtractor(path1) assert 'stability' not in SX_nwb.get_shared_unit_property_names() check_sortings_equal(self.SX, SX_nwb) check_dumping(SX_nwb) # Test for handling skip_features argument se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True) # SX2 has timestamps, so loading it back from Nwb will not recover the same spike frames. 
USe use_times=False se.NwbSortingExtractor.write_sorting( sorting=self.SX2, save_path=path1, skip_features=['widths'], use_times=False ) SX_nwb = se.NwbSortingExtractor(path1) assert 'widths' not in SX_nwb.get_shared_unit_spike_feature_names() check_sortings_equal(self.SX2, SX_nwb) check_dumping(SX_nwb) # Test writting multiple recordings using metadata metadata = get_default_nwbfile_metadata() path_nwb = self.test_dir + '/test_multiple.nwb' se.NwbRecordingExtractor.write_recording( recording=self.RX, save_path=path_nwb, metadata=metadata, write_as='raw', es_key='ElectricalSeries_raw', ) se.NwbRecordingExtractor.write_recording( recording=self.RX2, save_path=path_nwb, metadata=metadata, write_as='processed', es_key='ElectricalSeries_processed', ) se.NwbRecordingExtractor.write_recording( recording=self.RX3, save_path=path_nwb, metadata=metadata, write_as='lfp', es_key='ElectricalSeries_lfp', ) RX_nwb = se.NwbRecordingExtractor( file_path=path_nwb, electrical_series_name='raw_traces' ) check_recording_return_types(RX_nwb) check_recordings_equal(self.RX, RX_nwb) check_dumping(RX_nwb) del RX_nwb def test_nixio_extractor(self): path1 = os.path.join(self.test_dir, 'raw.nix') se.NIXIORecordingExtractor.write_recording(self.RX, path1) RX_nixio = se.NIXIORecordingExtractor(path1) check_recording_return_types(RX_nixio) check_recordings_equal(self.RX, RX_nixio) check_dumping(RX_nixio) del RX_nixio # test force overwrite se.NIXIORecordingExtractor.write_recording(self.RX, path1, overwrite=True) path2 = self.test_dir + '/firings_true.nix' se.NIXIOSortingExtractor.write_sorting(self.SX, path2) SX_nixio = se.NIXIOSortingExtractor(path2) check_sorting_return_types(SX_nixio) check_sortings_equal(self.SX, SX_nixio) check_dumping(SX_nixio) def test_shybrid_extractors(self): # test sorting extractor se.SHYBRIDSortingExtractor.write_sorting(self.SX, self.test_dir) initial_sorting_file = os.path.join(self.test_dir, 'initial_sorting.csv') SX_shybrid = 
se.SHYBRIDSortingExtractor(initial_sorting_file) check_sorting_return_types(SX_shybrid) check_sortings_equal(self.SX, SX_shybrid) check_dumping(SX_shybrid) # test recording extractor se.SHYBRIDRecordingExtractor.write_recording(self.RX, self.test_dir, initial_sorting_file) RX_shybrid = se.SHYBRIDRecordingExtractor(os.path.join(self.test_dir, 'recording.bin')) check_recording_return_types(RX_shybrid) check_recordings_equal(self.RX, RX_shybrid) check_dumping(RX_shybrid) def test_neuroscope_extractors(self): # NeuroscopeRecordingExtractor tests nscope_dir = Path(self.test_dir) / 'neuroscope_rec0' dat_file = nscope_dir / 'neuroscope_rec0.dat' se.NeuroscopeRecordingExtractor.write_recording(self.RX, nscope_dir) RX_ns = se.NeuroscopeRecordingExtractor(dat_file) check_recording_return_types(RX_ns) check_recordings_equal(self.RX, RX_ns, force_dtype='int32') check_dumping(RX_ns) check_recording_return_types(RX_ns) check_recordings_equal(self.RX, RX_ns, force_dtype='int32') check_dumping(RX_ns) del RX_ns # overwrite nscope_dir = Path(self.test_dir) / 'neuroscope_rec1' dat_file = nscope_dir / 'neuroscope_rec1.dat' se.NeuroscopeRecordingExtractor.write_recording(recording=self.RX, save_path=nscope_dir) RX_ns = se.NeuroscopeRecordingExtractor(dat_file) check_recording_return_types(RX_ns) check_recordings_equal(self.RX, RX_ns) check_dumping(RX_ns) # NeuroscopeMultiRecordingTimeExtractor tests nscope_dir = Path(self.test_dir) / "neuroscope_rec2" dat_file = nscope_dir / "neuroscope_rec2.dat" RX_multirecording = se.MultiRecordingTimeExtractor(recordings=[self.RX, self.RX]) se.NeuroscopeMultiRecordingTimeExtractor.write_recording(recording=RX_multirecording, save_path=nscope_dir) RX_mre = se.NeuroscopeMultiRecordingTimeExtractor(folder_path=nscope_dir) check_recording_return_types(RX_mre) check_recordings_equal(RX_multirecording, RX_mre) check_dumping(RX_mre) # NeuroscopeSortingExtractor tests nscope_dir = Path(self.test_dir) / 'neuroscope_sort0' sort_name = 'neuroscope_sort0' 
initial_sorting_resfile = Path(self.test_dir) / sort_name / f'{sort_name}.res' initial_sorting_clufile = Path(self.test_dir) / sort_name / f'{sort_name}.clu' se.NeuroscopeSortingExtractor.write_sorting(self.SX, nscope_dir) SX_neuroscope = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile, clufile_path=initial_sorting_clufile) check_sorting_return_types(SX_neuroscope) check_sortings_equal(self.SX, SX_neuroscope) check_dumping(SX_neuroscope) SX_neuroscope_no_mua = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile, clufile_path=initial_sorting_clufile, keep_mua_units=False) check_sorting_return_types(SX_neuroscope_no_mua) check_dumping(SX_neuroscope_no_mua) # Test for extra argument 'keep_mua_units' resulted in the right output SX_neuroscope_no_mua = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile, clufile_path=initial_sorting_clufile, keep_mua_units=False) check_sorting_return_types(SX_neuroscope_no_mua) check_dumping(SX_neuroscope_no_mua) num_original_units = len(SX_neuroscope.get_unit_ids()) self.assertEqual(list(SX_neuroscope.get_unit_ids()), list(range(1, num_original_units + 1))) self.assertEqual(list(SX_neuroscope_no_mua.get_unit_ids()), list(range(2, num_original_units + 1))) # Tests for the auto-detection of format for NeuroscopeSortingExtractor SX_neuroscope_from_fp = se.NeuroscopeSortingExtractor(folder_path=nscope_dir) check_sorting_return_types(SX_neuroscope_from_fp) check_sortings_equal(self.SX, SX_neuroscope_from_fp) check_dumping(SX_neuroscope_from_fp) # Tests for the NeuroscopeMultiSortingExtractor nscope_dir = Path(self.test_dir) / 'neuroscope_sort1' SX_multisorting = se.MultiSortingExtractor(sortings=[self.SX, self.SX]) se.NeuroscopeMultiSortingExtractor.write_sorting(SX_multisorting, nscope_dir) SX_neuroscope_mse = se.NeuroscopeMultiSortingExtractor(nscope_dir) check_sorting_return_types(SX_neuroscope_mse) check_sortings_equal(SX_multisorting, SX_neuroscope_mse) check_dumping(SX_neuroscope_mse) 
def test_cell_explorer_extractor(self): sorter_id = "cell_explorer_sorter" cell_explorer_dir = Path(self.test_dir) / sorter_id spikes_matfile_path = cell_explorer_dir / f"{sorter_id}.spikes.cellinfo.mat" se.CellExplorerSortingExtractor.write_sorting(sorting=self.SX, save_path=spikes_matfile_path) SX_cell_explorer = se.CellExplorerSortingExtractor(spikes_matfile_path=spikes_matfile_path) check_sorting_return_types(SX_cell_explorer) check_sortings_equal(self.SX, SX_cell_explorer) check_dumping(SX_cell_explorer) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_gin_repo.py ================================================ import tempfile import unittest from pathlib import Path import numpy as np import sys from datalad.api import install, Dataset from parameterized import parameterized import spikeextractors as se from spikeextractors.testing import check_recordings_equal, check_sortings_equal run_local = False test_nwb = True test_caching = True if sys.platform == "linux" or run_local: class TestNwbConversions(unittest.TestCase): def setUp(self): pt = Path.cwd() / 'ephy_testing_data' if pt.exists(): self.dataset = Dataset(pt) else: self.dataset = install('https://gin.g-node.org/NeuralEnsemble/ephy_testing_data') # Must pin to previous dataset version # See https://github.com/SpikeInterface/spikeextractors/pull/675 self.dataset.repo.call_git(['checkout', '17e8f37674d70af84cdba6acd83df964a8e09f0c']) self.savedir = Path(tempfile.mkdtemp()) @parameterized.expand([ ( se.AxonaRecordingExtractor, "axona", dict(filename=str(Path.cwd() / "ephy_testing_data" / "axona" / "axona_raw.set")) ), ( se.BlackrockRecordingExtractor, "blackrock/blackrock_2_1", dict( filename=str(Path.cwd() / "ephy_testing_data" / "blackrock" / "blackrock_2_1" / "l101210-001"), seg_index=0, nsx_to_load=5 ) ), ( se.IntanRecordingExtractor, "intan", dict(file_path=Path.cwd() / "ephy_testing_data" / "intan" / "intan_rhd_test_1.rhd") ), ( 
se.IntanRecordingExtractor, "intan", dict(file_path=Path.cwd() / "ephy_testing_data" / "intan" / "intan_rhs_test_1.rhs") ), # Klusta - no .prm config file in ephy_testing # ( # se.KlustaRecordingExtractor, # "kwik", # dict(folder_path=Path.cwd() / "ephy_testing_data" / "kwik") # ), ( se.MEArecRecordingExtractor, "mearec/mearec_test_10s.h5", dict(file_path=Path.cwd() / "ephy_testing_data" / "mearec" / "mearec_test_10s.h5") ), ( se.NeuralynxRecordingExtractor, "neuralynx/Cheetah_v5.7.4/original_data", dict( dirname=Path.cwd() / "ephy_testing_data" / "neuralynx" / "Cheetah_v5.7.4" / "original_data", seg_index=0 ) ), ( se.NeuroscopeRecordingExtractor, "neuroscope/test1", dict(file_path=Path.cwd() / "ephy_testing_data" / "neuroscope" / "test1" / "test1.dat") ), # Nixio - RuntimeError: Cannot open non-existent file in ReadOnly mode! # ( # se.NIXIORecordingExtractor, # "nix", # dict(file_path=str(Path.cwd() / "ephy_testing_data" / "neoraw.nix")) # ), ( se.OpenEphysRecordingExtractor, "openephys/OpenEphys_SampleData_1", dict(folder_path=Path.cwd() / "ephy_testing_data" / "openephys" / "OpenEphys_SampleData_1") ), ( se.OpenEphysRecordingExtractor, "openephysbinary/v0.4.4.1_with_video_tracking", dict(folder_path=Path.cwd() / "ephy_testing_data" / "openephysbinary" / "v0.4.4.1_with_video_tracking") ), ( se.OpenEphysNPIXRecordingExtractor, "openephysbinary/v0.5.3_two_neuropixels_stream", dict( folder_path=Path.cwd() / "ephy_testing_data" / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107") ), ( se.NeuropixelsDatRecordingExtractor, "openephysbinary/v0.5.3_two_neuropixels_stream", dict( file_path=Path.cwd() / "ephy_testing_data" / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107" / "experiment1" / "recording1" / "continuous" / "Neuropix-PXI-116.0" / "continuous.dat", settings_file=Path.cwd() / "ephy_testing_data" / "openephysbinary" / "v0.5.3_two_neuropixels_stream" / "Record_Node_107" / "settings.xml") ), ( se.PhyRecordingExtractor, 
"phy/phy_example_0", dict(folder_path=Path.cwd() / "ephy_testing_data" / "phy" / "phy_example_0") ), # Plexon - AssertionError: This file have several channel groups spikeextractors support only one groups # ( # se.PlexonRecordingExtractor, # "plexon", # dict(filename=Path.cwd() / "ephy_testing_data" / "plexon" / "File_plexon_2.plx") # ), ( se.CEDRecordingExtractor, "spike2/m365_1sec.smrx", dict( file_path=Path.cwd() / "ephy_testing_data" / "spike2" / "m365_1sec.smrx", smrx_channel_ids=range(10) ) ), ( se.SpikeGLXRecordingExtractor, "spikeglx/Noise4Sam_g0", dict( file_path=Path.cwd() / "ephy_testing_data" / "spikeglx" / "Noise4Sam_g0" / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin" ) ) ]) def test_convert_recording_extractor_to_nwb(self, se_class, dataset_path, se_kwargs): print(f"\n\n\n TESTING {se_class.extractor_name}...") dataset_stem = Path(dataset_path).stem self.dataset.get(dataset_path) recording = se_class(**se_kwargs) # # test writing to NWB if test_nwb: nwb_save_path = self.savedir / f"{se_class.__name__}_test_{dataset_stem}.nwb" se.NwbRecordingExtractor.write_recording(recording, nwb_save_path, write_scaled=True) nwb_recording = se.NwbRecordingExtractor(nwb_save_path) check_recordings_equal(recording, nwb_recording, check_times=False) if recording.has_unscaled: nwb_save_path_unscaled = self.savedir / f"{se_class.__name__}_test_{dataset_stem}_unscaled.nwb" if np.all(recording.get_channel_offsets() == 0): se.NwbRecordingExtractor.write_recording(recording, nwb_save_path_unscaled, write_scaled=False) nwb_recording = se.NwbRecordingExtractor(nwb_save_path_unscaled) check_recordings_equal(recording, nwb_recording, return_scaled=False, check_times=False) # Skip check when NWB converts uint to int if recording.get_dtype(return_scaled=False) == nwb_recording.get_dtype(return_scaled=False): check_recordings_equal(recording, nwb_recording, return_scaled=True, check_times=False) # test caching if test_caching: rec_cache = 
se.CacheRecordingExtractor(recording) check_recordings_equal(recording, rec_cache) if recording.has_unscaled: rec_cache_unscaled = se.CacheRecordingExtractor(recording, return_scaled=False) check_recordings_equal(recording, rec_cache_unscaled, return_scaled=False) check_recordings_equal(recording, rec_cache_unscaled, return_scaled=True) @parameterized.expand([ ( se.BlackrockSortingExtractor, "blackrock/blackrock_2_1", dict( filename=str(Path.cwd() / "ephy_testing_data" / "blackrock" / "blackrock_2_1" / "l101210-001"), seg_index=0, nsx_to_load=5 ) ), ( se.KlustaSortingExtractor, "kwik", dict(file_or_folder_path=Path.cwd() / "ephy_testing_data" / "kwik" / "neo.kwik") ), # Neuralynx - units_ids = nwbfile.units.id[:] - AttributeError: 'NoneType' object has no attribute 'id' # Is the GIN data OK? Or are there no units? # ( # se.NeuralynxSortingExtractor, # "neuralynx/Cheetah_v5.7.4/original_data", # dict( # dirname=Path.cwd() / "ephy_testing_data" / "neuralynx" / "Cheetah_v5.7.4" / "original_data", # seg_index=0 # ) # ), # NIXIO - return [int(da.label) for da in self._spike_das] # TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType' # ( # se.NIXIOSortingExtractor, # "nix/nixio_fr.nix", # dict(file_path=str(Path.cwd() / "ephy_testing_data" / "nix" / "nixio_fr.nix")) # ), ( se.MEArecSortingExtractor, "mearec/mearec_test_10s.h5", dict(file_path=Path.cwd() / "ephy_testing_data" / "mearec" / "mearec_test_10s.h5") ), ( se.PhySortingExtractor, "phy/phy_example_0", dict(folder_path=Path.cwd() / "ephy_testing_data" / "phy" / "phy_example_0") ), ( se.PlexonSortingExtractor, "plexon", dict(filename=Path.cwd() / "ephy_testing_data" / "plexon" / "File_plexon_2.plx") ), ( se.SpykingCircusSortingExtractor, "spykingcircus/spykingcircus_example0", dict( file_or_folder_path=Path.cwd() / "ephy_testing_data" / "spykingcircus" / "spykingcircus_example0" / "recording" ) ), # # Tridesclous - dataio error, GIN data is not correct? 
# ( # se.TridesclousSortingExtractor, # "tridesclous/tdc_example0", # dict(folder_path=Path.cwd() / "ephy_testing_data" / "tridesclous" / "tdc_example0") # ) ]) def test_convert_sorting_extractor_to_nwb(self, se_class, dataset_path, se_kwargs): print(f"\n\n\n TESTING {se_class.extractor_name}...") dataset_stem = Path(dataset_path).stem self.dataset.get(dataset_path) sorting = se_class(**se_kwargs) sf = sorting.get_sampling_frequency() if sf is None: # need to set dummy sampling frequency since no associated acquisition in file sf = 30000 sorting.set_sampling_frequency(sf) if test_nwb: nwb_save_path = self.savedir / f"{se_class.__name__}_test_{dataset_stem}.nwb" se.NwbSortingExtractor.write_sorting(sorting, nwb_save_path) nwb_sorting = se.NwbSortingExtractor(nwb_save_path, sampling_frequency=sf) check_sortings_equal(sorting, nwb_sorting) if test_caching: sort_cache = se.CacheSortingExtractor(sorting) check_sortings_equal(sorting, sort_cache) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_numpy_extractors.py ================================================ import numpy as np import unittest import spikeextractors as se class TestNumpyExtractors(unittest.TestCase): def setUp(self): M = 4 N = 10000 N_ttl = 50 seed = 0 sampling_frequency = 30000 X = np.random.RandomState(seed=seed).normal(0, 1, (M, N)) geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2)) self._X = X self._geom = geom self._sampling_frequency = sampling_frequency self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom) self._ttl_frames = np.sort(np.random.permutation(N)[:N_ttl]) self.RX.set_ttls(self._ttl_frames) self.SX = se.NumpySortingExtractor() L = 200 self._train1 = np.rint(np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int) self.SX.add_unit(unit_id=1, times=self._train1) self.SX.add_unit(unit_id=2, times=np.random.RandomState(seed=seed).uniform(0, N, L)) 
self.SX.add_unit(unit_id=3, times=np.random.RandomState(seed=seed).uniform(0, N, L)) def tearDown(self): pass def test_recording_extractor(self): # get_channel_ids self.assertEqual(self.RX.get_channel_ids(), [i for i in range(self._X.shape[0])]) # get_num_channels self.assertEqual(self.RX.get_num_channels(), self._X.shape[0]) # get_num_frames self.assertEqual(self.RX.get_num_frames(), self._X.shape[1]) # get_sampling_frequency self.assertEqual(self.RX.get_sampling_frequency(), self._sampling_frequency) # get_traces self.assertTrue(np.allclose(self.RX.get_traces(), self._X)) self.assertTrue( np.allclose(self.RX.get_traces(channel_ids=[0, 3], start_frame=0, end_frame=12), self._X[[0, 3], 0:12])) # get_channel_property - location self.assertTrue(np.allclose(np.array(self.RX.get_channel_locations(1)), self._geom[1, :])) # time_to_frame / frame_to_time self.assertEqual(self.RX.time_to_frame(12), 12 * self.RX.get_sampling_frequency()) self.assertEqual(self.RX.frame_to_time(12), 12 / self.RX.get_sampling_frequency()) # get_snippets snippets = self.RX.get_snippets(reference_frames=[0, 30, 50], snippet_len=20) self.assertTrue(np.allclose(snippets[1], self._X[:, 20:40])) # get_ttl_events self.assertTrue(np.array_equal(self.RX.get_ttl_events()[0], self._ttl_frames)) def test_sorting_extractor(self): unit_ids = [1, 2, 3] # get_unit_ids self.assertEqual(self.SX.get_unit_ids(), unit_ids) # get_unit_spike_train st = self.SX.get_unit_spike_train(unit_id=1) self.assertTrue(np.allclose(st, self._train1)) if __name__ == '__main__': unittest.main() ================================================ FILE: tests/test_tools.py ================================================ import numpy as np import unittest import tempfile import shutil import spikeextractors as se import os from copy import copy from pathlib import Path this_file = Path(__file__).parent class TestTools(unittest.TestCase): def setUp(self): M = 32 N = 10000 seed = 0 sampling_frequency = 30000 X = 
np.random.RandomState(seed=seed).normal(0, 1, (M, N)) self._X = X self._sampling_frequency = sampling_frequency self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency) self.RX.set_channel_locations(np.random.randn(32, 3)) self.test_dir = Path(tempfile.mkdtemp()) def tearDown(self): shutil.rmtree(self.test_dir) def test_load_save_probes(self): sub_RX = se.load_probe_file(self.RX, this_file / 'probe_test.prb') # print(SX.get_channel_property_names()) assert 'location' in sub_RX.get_shared_channel_property_names() assert 'group' in sub_RX.get_shared_channel_property_names() positions = [sub_RX.get_channel_locations(chan)[0] for chan in range(self.RX.get_num_channels())] # save in csv sub_RX.save_to_probe_file(self.test_dir / 'geom.csv') # load csv locations sub_RX_load = sub_RX.load_probe_file(self.test_dir / 'geom.csv') position_loaded = [sub_RX_load.get_channel_locations(chan)[0] for chan in range(sub_RX_load.get_num_channels())] self.assertTrue(np.allclose(positions[10], position_loaded[10])) # prb file RX = copy(self.RX) channel_groups = [] n_group = 4 for i in RX.get_channel_ids(): channel_groups.append(i // n_group) RX.set_channel_groups(channel_groups) RX.save_to_probe_file(this_file / 'probe_test_no_groups.prb') RX.save_to_probe_file(this_file / 'probe_test_groups.prb', grouping_property='group') # load RX_loaded_no_groups = se.load_probe_file(RX, this_file / 'probe_test_no_groups.prb') RX_loaded_groups = se.load_probe_file(RX, this_file / 'probe_test_groups.prb') assert len(np.unique(RX_loaded_no_groups.get_channel_groups())) == 1 assert len(np.unique(RX_loaded_groups.get_channel_groups())) == RX.get_num_channels() // n_group # cleanup (this_file / 'probe_test_no_groups.prb').unlink() (this_file / 'probe_test_groups.prb').unlink() def test_write_dat_file(self): nb_sample = self.RX.get_num_frames() nb_chan = self.RX.get_num_channels() # time_axis=0 chunk_size=None self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', 
time_axis=0, dtype='float32', chunk_size=None) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T assert np.allclose(data, self.RX.get_traces()) del data # this close the file # time_axis=1 chunk_size=None self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1, dtype='float32', chunk_size=None) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_chan, nb_sample)) assert np.allclose(data, self.RX.get_traces()) del data # this close the file # time_axis=0 chunk_size=99 self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0, dtype='float32', chunk_size=99) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T assert np.allclose(data, self.RX.get_traces()) del(data) # this close the file # time_axis=0 chunk_mb=2 self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0, dtype='float32', chunk_mb=2) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T assert np.allclose(data, self.RX.get_traces()) del data # this close the file # time_axis=1 chunk_mb=2 self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1, dtype='float32', chunk_mb=2) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_chan, nb_sample)) assert np.allclose(data, self.RX.get_traces()) del data # this close the file # time_axis=0 chunk_mb=10, n_jobs=2 self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0, dtype='float32', chunk_mb=10, n_jobs=2) data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T assert np.allclose(data, self.RX.get_traces()) del data # this close the file # time_axis=1 chunk_mb=10 n_jobs=2 self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1, dtype='float32', chunk_mb=2, n_jobs=2) data = np.memmap(self.test_dir / 'rec.dat', 
dtype='float32', mode='r', shape=(nb_chan, nb_sample)) assert np.allclose(data, self.RX.get_traces()) del data # this close the file if __name__ == '__main__': unittest.main()