[
  {
    "path": ".github/workflows/python-package.yml",
    "content": "name: Python Package using Conda\n\non: \n  push: \n    branches:\n      - master\n  pull_request:\n    branches: [master]\n    types: [synchronize, opened, reopened, ready_for_review]\n\njobs:\n  build-and-test:\n    name: Test on (${{ matrix.os }})\n    runs-on: ${{ matrix.os }}\n    strategy:\n      fail-fast: false\n      matrix:\n        os: [\"ubuntu-latest\", \"macos-latest\", \"windows-latest\"]\n    steps:\n      - uses: actions/checkout@v2\n      - uses: s-weigand/setup-conda@v1\n        with:\n          python-version: 3.8\n      - name: Which python\n        run: |\n          conda --version\n          which python\n      - name: Install dependencies\n        run: |\n          conda install -c conda-forge datalad\n          conda install -c conda-forge ruamel.yaml\n          conda install flake8\n          conda install pytest\n          pip install -r requirements-dev.txt\n          pip install -r requirements.txt\n          pip install h5py==2.10\n          pip install -e .[full]\n          # needed for correct operation of git/git-annex/DataLad\n          git config --global user.email \"CI@example.com\"\n          git config --global user.name \"CI Almighty\"\n      - name: Lint with flake8\n        run: |\n          # stop the build if there are Python syntax errors or undefined names\n          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics\n          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide\n          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics\n      - name: Test with pytest\n        run: |\n          pytest\n"
  },
  {
    "path": ".github/workflows/python-publish.yml",
    "content": "# This workflow will upload a Python Package using Twine when a tag is pushed\n# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries\n\nname: Test and Upload Python Package\n\non:\n  push:\n    tags:\n       - '*'\n\njobs:\n  deploy:\n\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v2\n      - uses: s-weigand/setup-conda@v1\n        with:\n          python-version: 3.8\n      - name: Which python\n        run: |\n          conda --version\n          which python\n      - name: Install dependencies\n        run: |\n          conda install -c conda-forge datalad\n          conda install -c conda-forge ruamel.yaml\n          conda install flake8\n          conda install pytest\n          pip install setuptools wheel twine\n          pip install -r requirements-dev.txt\n          pip install -r requirements.txt\n          pip install h5py==2.10\n          pip install -e .[full]\n          # needed for correct operation of git/git-annex/DataLad\n          git config --global user.email \"CI@example.com\"\n          git config --global user.name \"CI Almighty\"\n      - name: Lint with flake8\n        run: |\n          # stop the build if there are Python syntax errors or undefined names\n          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics\n          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide\n          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics\n      - name: Test with pytest\n        run: |\n          pytest\n      - name: Publish on PyPI\n        env:\n          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}\n          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}\n        run: |\n          python setup.py sdist bdist_wheel\n          twine upload dist/*\n"
  },
  {
    "path": ".gitignore",
    "content": ".eggs\n*.egg-info\n.ipynb_checkpoints\n__pycache__\nsample_*_dataset\nephy_testing_data/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 SpikeInterface\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include spikeextractors/extractors/neuropixelsdatrecordingextractor/channel_positions_neuropixels.txt\n"
  },
  {
    "path": "README.md",
    "content": "# SpikeExtractors (LEGACY)\nThe `spikeextractors` package has now been integrated into [spikeinterface](https://github.com/SpikeInterface/spikeinterface).\n\nThis package will be maintained for a while for bug fixes only, then it will be deprecated.\n\nNew features and improvements will only be implemented for the `spikeinterface` package.\n"
  },
  {
    "path": "environment-dev.yml",
    "content": "name: test\ndependencies:\n  - python=3.8\n  - pip\n  - pip:\n    - numpy==1.22.0\n    - tqdm\n    - lxml\n    - h5py\n    - shybrid\n    - pynwb\n    - nixio\n    - pyintan\n    - pyopenephys\n    - neo\n    - MEArec\n    - hdf5storage\n    - exdir\n    - hdbscan\n    - tridesclous\n    - parameterized\n"
  },
  {
    "path": "full_requirements.txt",
    "content": "h5py #>=3.2.1\nscipy>=1.6.3\npyintan>=0.3.0\npyopenephys>=1.1.4\nneo>=0.9.0\nMEArec<1.8\npynwb>=1.4\nlxml>=4.6.3\nnixio==1.5.0\nshybrid>=0.4.2\npyyaml>=5.4.1\nmtscomp>=1.0.1\nexdir==0.4.1\nhdf5storage\nsonpy;python_version>'3.7'"
  },
  {
    "path": "requirements-dev.txt",
    "content": "datalad\nparameterized\nneo==0.10"
  },
  {
    "path": "requirements.txt",
    "content": "numpy==1.22.0\ntqdm\npackaging"
  },
  {
    "path": "setup.py",
    "content": "import setuptools\n\nd = {}\nexec(open(\"spikeextractors/version.py\").read(), None, d)\nversion = d['version']\npkg_name = \"spikeextractors\"\nlong_description = open(\"README.md\").read()\n\nwith open(\"full_requirements.txt\", mode='r') as f:\n    full_requires = f.read().split('\\n')\nfull_requires = [e for e in full_requires if len(e) > 0]\n\nextras_require = {\"full\": full_requires}\n\nsetuptools.setup(\n    name=pkg_name,\n    version=version,\n    author=\"Alessio Buccino, Cole Hurwitz, Samuel Garcia, Jeremy Magland, Matthias Hennig\",\n    author_email=\"alessio.buccino@gmail.com\",\n    description=\"Python module for extracting recorded and spike sorted extracellular data from different file types and formats\",\n    url=\"https://github.com/SpikeInterface/spikeextractors\",\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    packages=setuptools.find_packages(),\n    package_data={},\n    include_package_data=True,\n    install_requires=[\n        'numpy',\n        'tqdm',\n        'joblib'\n    ],\n    extras_require=extras_require,\n    classifiers=(\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: MIT License\",\n        \"Operating System :: OS Independent\",\n    )\n)\n"
  },
  {
    "path": "spikeextractors/__init__.py",
    "content": "from .recordingextractor import RecordingExtractor\nfrom .sortingextractor import SortingExtractor\nfrom .cacheextractors import CacheRecordingExtractor, CacheSortingExtractor\nfrom .subsortingextractor import SubSortingExtractor\nfrom .subrecordingextractor import SubRecordingExtractor\nfrom .multirecordingchannelextractor import concatenate_recordings_by_channel, MultiRecordingChannelExtractor\nfrom .multirecordingtimeextractor import concatenate_recordings_by_time, MultiRecordingTimeExtractor\nfrom .multisortingextractor import concatenate_sortings, MultiSortingExtractor\n\nfrom .extractorlist import *\n\nfrom . import example_datasets\nfrom .extraction_tools import load_probe_file, save_to_probe_file, read_binary, write_to_binary_dat_format,\\\n    write_to_h5_dataset_format, get_sub_extractors_by_property, load_extractor_from_json, load_extractor_from_dict, \\\n    load_extractor_from_pickle\n\nfrom .save_tools import save_si_object\n\nfrom .version import version as __version__\n"
  },
  {
    "path": "spikeextractors/baseextractor.py",
    "content": "import json\nfrom pathlib import Path\nimport importlib\nimport numpy as np\nimport datetime\nfrom copy import deepcopy\nimport tempfile\nimport pickle\nimport shutil\n\nfrom .exceptions import NotDumpableExtractorError\n\n\nclass BaseExtractor:\n\n    # To be specified in concrete sub-classes\n    # The default filename (extension to be added by corresponding method)\n    # to be used if no file path is provided\n    _default_filename = None\n\n    def __init__(self):\n        self._kwargs = {}\n        self._tmp_folder = None\n        self._key_properties = {}\n        self._properties = {}\n        self._annotations = {}\n        self._memmap_files = []\n        self._features = {}\n        self._epochs = {}\n        self._times = None\n        self.is_dumpable = True\n        self.id = np.random.randint(low=0, high=9223372036854775807, dtype='int64')\n\n    def __del__(self):\n        # close memmap files (for Windows)\n        for memmap_obj in self._memmap_files:\n            self.del_memmap_file(memmap_obj)\n        if self._tmp_folder is not None and len(self._memmap_files) > 0:\n            try:\n                shutil.rmtree(self._tmp_folder)\n            except Exception as e:\n                print('Impossible to delete temp file:', self._tmp_folder, 'Error', e)\n\n    def del_memmap_file(self, memmap_file):\n        \"\"\"\n        Safely deletes instantiated memmap file.\n\n        Parameters\n        ----------\n        memmap_file: str or Path\n            The memmap file to delete\n        \"\"\"\n        if isinstance(memmap_file, np.memmap):\n            memmap_file = memmap_file.filename\n        else:\n            memmap_file = Path(memmap_file)\n\n        existing_memmap_files = [Path(memmap.filename) for memmap in self._memmap_files]\n        if memmap_file in existing_memmap_files:\n            try:\n                memmap_idx = existing_memmap_files.index(memmap_file)\n                memmap_obj = 
self._memmap_files[memmap_idx]\n                if not memmap_obj._mmap.closed:\n                    memmap_obj._mmap.close()\n                    del memmap_obj\n                memmap_file.unlink()\n                del self._memmap_files[memmap_idx]\n            except Exception as e:\n                raise Exception(f\"Error in deleting {memmap_file.name}: Error {e}\")\n\n    def make_serialized_dict(self, relative_to=None):\n        \"\"\"\n        Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an\n        extractor with spikeextractors.load_extractor_from_dict(dump_dict)\n\n        Parameters\n        ----------\n        relative_to: str, Path, or None\n            If not None, file_paths are serialized relative to this path\n\n        Returns\n        -------\n        dump_dict: dict\n            Serialized dictionary\n        \"\"\"\n        class_name = str(type(self)).replace(\"<class '\", \"\").replace(\"'>\", '')\n        module = class_name.split('.')[0]\n        imported_module = importlib.import_module(module)\n\n        try:\n            version = imported_module.__version__\n        except AttributeError:\n            version = 'unknown'\n\n        if self.is_dumpable:\n            dump_dict = {'class': class_name, 'module': module, 'kwargs': self._kwargs,\n                         'key_properties': self._key_properties, 'annotations': self._annotations,\n                         'version': version, 'dumpable': True}\n        else:\n            dump_dict = {'class': class_name, 'module': module, 'kwargs': {}, 'key_properties': self._key_properties,\n                         'annotations': self._annotations, 'version': version,\n                         'dumpable': False}\n\n        if relative_to is not None:\n            relative_to = Path(relative_to).absolute()\n            assert relative_to.is_dir(), \"'relative_to' must be an existing directory\"\n\n            dump_dict = 
_make_paths_relative(dump_dict, relative_to)\n\n        return dump_dict\n\n    def dump_to_dict(self, relative_to=None):\n        \"\"\"\n        Dumps recording to a dictionary.\n        The dictionary be used to re-initialize an\n        extractor with spikeextractors.load_extractor_from_dict(dump_dict)\n\n        Parameters\n        ----------\n        relative_to: str, Path, or None\n            If not None, file_paths are serialized relative to this path\n\n        Returns\n        -------\n        dump_dict: dict\n            Serialized dictionary\n        \"\"\"\n        return self.make_serialized_dict(relative_to)\n\n    def _get_file_path(self, file_path, extensions):\n        \"\"\"\n        Helper to be used by various dump_to_file utilities.\n\n        Returns default file_path (if not specified), assures that target\n        directory exists, adds correct file extension if none, and assures\n        that provided file extension is one of the allowed.\n\n        Parameters\n        ----------\n        file_path: str or None\n        extensions: list or tuple\n            First provided is used as an extension for the default file_path.\n            All are tested against\n\n        Returns\n        -------\n        Path\n            Path object with file path to the file\n\n        Raises\n        ------\n        NotDumpableExtractorError\n        \"\"\"\n        ext = extensions[0]\n        if self.check_if_dumpable():\n            if file_path is None:\n                file_path = self._default_filename + ext\n            file_path = Path(file_path)\n            file_path.parent.mkdir(parents=True, exist_ok=True)\n            folder_path = file_path.parent\n            if Path(file_path).suffix == '':\n                file_path = folder_path / (str(file_path) + ext)\n            assert file_path.suffix in extensions, \\\n                \"'file_path' should have one of the following extensions:\" \\\n                \" %s\" % (', 
'.join(extensions))\n            return file_path\n        else:\n            raise NotDumpableExtractorError(\n                f\"The extractor is not dumpable to {ext}\")\n\n    def dump_to_json(self, file_path=None, relative_to=None):\n        \"\"\"\n        Dumps recording extractor to json file.\n        The extractor can be re-loaded with spikeextractors.load_extractor_from_json(json_file)\n\n        Parameters\n        ----------\n        file_path: str\n            Path of the json file\n        relative_to: str, Path, or None\n            If not None, file_paths are serialized relative to this path\n\n        \"\"\"\n        dump_dict = self.make_serialized_dict(relative_to)\n        self._get_file_path(file_path, ['.json'])\\\n            .write_text(\n                json.dumps(_check_json(dump_dict), indent=4),\n                encoding='utf8'\n            )\n\n    def dump_to_pickle(self, file_path=None, include_properties=True, include_features=True,\n                       relative_to=None):\n        \"\"\"\n        Dumps recording extractor to a pickle file.\n        The extractor can be re-loaded with spikeextractors.load_extractor_from_json(json_file)\n\n        Parameters\n        ----------\n        file_path: str\n            Path of the json file\n        include_properties: bool\n            If True, all properties are dumped\n        include_features: bool\n            If True, all features are dumped\n        relative_to: str, Path, or None\n            If not None, file_paths are serialized relative to this path\n        \"\"\"\n        file_path = self._get_file_path(file_path, ['.pkl', '.pickle'])\n\n        # Dump all\n        dump_dict = {'serialized_dict': self.make_serialized_dict(relative_to)}\n        if include_properties:\n            if len(self._properties.keys()) > 0:\n                dump_dict['properties'] = self._properties\n        if include_features:\n            if len(self._features.keys()) > 0:\n                
dump_dict['features'] = self._features\n        # include times\n        dump_dict[\"times\"] = self._times\n\n        file_path.write_bytes(pickle.dumps(dump_dict))\n\n    def get_tmp_folder(self):\n        \"\"\"\n        Returns temporary folder associated to the extractor\n\n        Returns\n        -------\n        temp_folder: Path\n            The temporary folder\n        \"\"\"\n        if self._tmp_folder is None:\n            self._tmp_folder = Path(tempfile.mkdtemp())\n        return self._tmp_folder\n\n    def set_tmp_folder(self, folder):\n        \"\"\"\n        Sets temporary folder of the extractor\n\n        Parameters\n        ----------\n        folder: str or Path\n            The temporary folder\n        \"\"\"\n        self._tmp_folder = Path(folder)\n\n    def allocate_array(self, memmap, shape=None, dtype=None, name=None, array=None):\n        \"\"\"\n        Allocates a memory or memmap array\n\n        Parameters\n        ----------\n        memmap: bool\n            If True, a memmap array is created in the sorting temporary folder\n        shape: tuple\n            Shape of the array. If None array must be given\n        dtype: dtype\n            Dtype of the array. If None array must be given\n        name: str or None\n            Name (root) of the file (if memmap is True). If None, a random name is generated\n        array: np.array\n            If array is given, shape and dtype are initialized based on the array. 
If memmap is True, the array is then\n            deleted to clear memory\n\n        Returns\n        -------\n        arr: np.array or np.memmap\n            The allocated memory or memmap array\n        \"\"\"\n        if memmap:\n            tmp_folder = self.get_tmp_folder()\n            if array is not None:\n                shape = array.shape\n                dtype = array.dtype\n            else:\n                assert shape is not None and dtype is not None, \"Pass 'shape' and 'dtype' arguments\"\n            if name is None:\n                tmp_file = tempfile.NamedTemporaryFile(suffix=\".raw\", dir=tmp_folder).name\n            else:\n                if Path(name).suffix == '':\n                    tmp_file = tmp_folder / (name + '.raw')\n                else:\n                    tmp_file = tmp_folder / name\n            raw_tmp_file = r'{}'.format(str(tmp_file))\n\n            # make sure any open memmap files with same path are deleted\n            self.del_memmap_file(raw_tmp_file)\n            arr = np.memmap(raw_tmp_file, mode='w+', shape=shape, dtype=dtype)\n            if array is not None:\n                arr[:] = array\n                del array\n            else:\n                arr[:] = 0\n            self._memmap_files.append(arr)\n        else:\n            if array is not None:\n                arr = array\n            else:\n                arr = np.zeros(shape, dtype=dtype)\n        return arr\n\n    def annotate(self, annotation_key, value, overwrite=False):\n        \"\"\"This function adds an entry to the annotations dictionary.\n\n        Parameters\n        ----------\n        annotation_key: str\n            An annotation stored by the Extractor\n        value:\n            The data associated with the given property name. 
Could be many\n            formats as specified by the user\n        overwrite: bool\n            If True and the annotation already exists, it is overwritten\n        \"\"\"\n        if annotation_key not in self._annotations.keys():\n            self._annotations[annotation_key] = value\n        else:\n            if overwrite:\n                self._annotations[annotation_key] = value\n            else:\n                print(f\"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it\")\n\n    def get_annotation(self, annotation_name):\n        \"\"\"This function returns the data stored under the annotation name.\n\n        Parameters\n        ----------\n        annotation_name: str\n            A property stored by the Extractor\n\n        Returns\n        ----------\n        annotation_data\n            The data associated with the given property name. Could be many\n            formats as specified by the user\n        \"\"\"\n        if annotation_name not in self._annotations.keys():\n            print(f\"{annotation_name} is not an annotation\")\n            return None\n        else:\n            return deepcopy(self._annotations[annotation_name])\n\n    def get_annotation_keys(self):\n        \"\"\"This function returns a list of stored annotation keys\n\n        Returns\n        ----------\n        property_names: list\n            List of stored annotation keys\n        \"\"\"\n        return list(self._annotations.keys())\n\n    def copy_annotations(self, extractor):\n        \"\"\"Copy object properties from another extractor to the current extractor.\n\n        Parameters\n        ----------\n        extractor: Extractor\n            The extractor from which the annotations will be copied\n        \"\"\"\n        self._annotations = deepcopy(extractor._annotations)\n\n    def add_epoch(self, epoch_name, start_frame, end_frame):\n        \"\"\"This function adds an epoch to your extractor that tracks\n        a certain 
time period. It is stored in an internal\n        dictionary of start and end frame tuples.\n\n        Parameters\n        ----------\n        epoch_name: str\n            The name of the epoch to be added\n        start_frame: int\n            The start frame of the epoch to be added (inclusive)\n        end_frame: int\n            The end frame of the epoch to be added (exclusive). If set to None, it will include the entire\n            sorting after the start_frame\n        \"\"\"\n        if isinstance(epoch_name, str):\n            start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n            self._epochs[epoch_name] = {'start_frame': start_frame, 'end_frame': end_frame}\n        else:\n            raise TypeError(\"epoch_name must be a string\")\n\n    def remove_epoch(self, epoch_name):\n        \"\"\"This function removes an epoch from your extractor.\n\n        Parameters\n        ----------\n        epoch_name: str\n            The name of the epoch to be removed\n        \"\"\"\n        if isinstance(epoch_name, str):\n            if epoch_name in list(self._epochs.keys()):\n                del self._epochs[epoch_name]\n            else:\n                raise ValueError(\"This epoch has not been added\")\n        else:\n            raise ValueError(\"epoch_name must be a string\")\n\n    def get_epoch_names(self):\n        \"\"\"This function returns a list of all the epoch names in the extractor\n\n        Returns\n        ----------\n        epoch_names: list\n            List of epoch names in the recording extractor\n        \"\"\"\n        epoch_names = list(self._epochs.keys())\n        if not epoch_names:\n            pass\n        else:\n            epoch_start_frames = []\n            for epoch_name in epoch_names:\n                epoch_info = self.get_epoch_info(epoch_name)\n                start_frame = epoch_info['start_frame']\n                epoch_start_frames.append(start_frame)\n            epoch_names = 
[epoch_name for _, epoch_name in sorted(zip(epoch_start_frames, epoch_names))]\n        return epoch_names\n\n    def get_epoch_info(self, epoch_name):\n        \"\"\"This function returns the start frame and end frame of the epoch\n        in a dict.\n\n        Parameters\n        ----------\n        epoch_name: str\n            The name of the epoch to be returned\n\n        Returns\n        ----------\n        epoch_info: dict\n            A dict containing the start frame and end frame of the epoch\n        \"\"\"\n        # Default (Can add more information into each epoch in subclass)\n        if isinstance(epoch_name, str):\n            if epoch_name in list(self._epochs.keys()):\n                epoch_info = self._epochs[epoch_name]\n                return epoch_info\n            else:\n                raise ValueError(\"This epoch has not been added\")\n        else:\n            raise ValueError(\"epoch_name must be a string\")\n\n    def copy_epochs(self, extractor):\n        \"\"\"Copy epochs from another extractor.\n\n        Parameters\n        ----------\n        extractor: BaseExtractor\n            The extractor from which the epochs will be copied\n        \"\"\"\n        for epoch_name in extractor.get_epoch_names():\n            epoch_info = extractor.get_epoch_info(epoch_name)\n            self.add_epoch(epoch_name, epoch_info[\"start_frame\"], epoch_info[\"end_frame\"])\n\n    def _cast_start_end_frame(self, start_frame, end_frame):\n        from .extraction_tools import cast_start_end_frame\n        return cast_start_end_frame(start_frame, end_frame)\n\n    @staticmethod\n    def load_extractor_from_json(json_file):\n        \"\"\"\n        Instantiates extractor from json file\n\n        Parameters\n        ----------\n        json_file: str or Path\n            Path to json file\n\n        Returns\n        -------\n        extractor: RecordingExtractor or SortingExtractor\n            The loaded extractor object\n        \"\"\"\n        
json_file = Path(json_file)\n        with open(str(json_file), 'r') as f:\n            d = json.load(f)\n            extractor = _load_extractor_from_dict(d)\n        return extractor\n\n    @staticmethod\n    def load_extractor_from_pickle(pkl_file):\n        \"\"\"\n        Instantiates extractor from pickle file.\n\n        Parameters\n        ----------\n        pkl_file: str or Path\n            Path to pickle file\n\n        Returns\n        -------\n        extractor: RecordingExtractor or SortingExtractor\n            The loaded extractor object\n        \"\"\"\n        pkl_file = Path(pkl_file)\n        with open(str(pkl_file), 'rb') as f:\n            d = pickle.load(f)\n        extractor = _load_extractor_from_dict(d['serialized_dict'])\n        if 'properties' in d.keys():\n            extractor._properties = d['properties']\n        if 'features' in d.keys():\n            extractor._features = d['features']\n        if 'times' in d.keys():\n            extractor._times = d['times']\n        return extractor\n\n    @staticmethod\n    def load_extractor_from_dict(d):\n        \"\"\"\n        Instantiates extractor from dictionary\n\n        Parameters\n        ----------\n        d: dictionary\n            Python dictionary\n\n        Returns\n        -------\n        extractor: RecordingExtractor or SortingExtractor\n            The loaded extractor object\n        \"\"\"\n        extractor = _load_extractor_from_dict(d)\n        return extractor\n\n    def check_if_dumpable(self):\n        return _check_if_dumpable(self.make_serialized_dict())\n\n\ndef _make_paths_relative(d, relative):\n    dcopy = deepcopy(d)\n    if \"kwargs\" in dcopy.keys():\n        relative_kwargs = _make_paths_relative(dcopy[\"kwargs\"], relative)\n        dcopy[\"kwargs\"] = relative_kwargs\n        return dcopy\n    else:\n        for k in d.keys():\n            # in SI, all input paths have the \"path\" keyword\n            if \"path\" in k:\n                d[k] = 
str(Path(d[k]).relative_to(relative))\n        return d\n\n\ndef _load_extractor_from_dict(dic):\n    cls = None\n    class_name = None\n    probe_file = None\n    kwargs = deepcopy(dic['kwargs'])\n    if np.any([isinstance(v, dict) for v in kwargs.values()]):\n        # nested\n        for k in kwargs.keys():\n            if isinstance(kwargs[k], dict):\n                if 'module' in kwargs[k].keys() and 'class' in kwargs[k].keys() and 'version' in kwargs[k].keys():\n                    extractor = _load_extractor_from_dict(kwargs[k])\n                    class_name = dic['class']\n                    cls = _get_class_from_string(class_name)\n                    kwargs[k] = extractor\n                    break\n    elif np.any([isinstance(v, list) and isinstance(v[0], dict) for v in kwargs.values()]):\n        # multi\n        for k in kwargs.keys():\n            if isinstance(kwargs[k], list) and isinstance(kwargs[k][0], dict):\n                extractors = []\n                for kw in kwargs[k]:\n                    if 'module' in kw.keys() and 'class' in kw.keys() and 'version' in kw.keys():\n                        extr = _load_extractor_from_dict(kw)\n                        extractors.append(extr)\n                class_name = dic['class']\n                cls = _get_class_from_string(class_name)\n                kwargs[k] = extractors\n                break\n    else:\n        class_name = dic['class']\n        cls = _get_class_from_string(class_name)\n\n    assert cls is not None and class_name is not None, \"Could not load spikeinterface class\"\n    if not _check_same_version(class_name, dic['version']):\n        print('Versions are not the same. This might lead to errors. 
Use ', class_name.split('.')[0],\n              'version', dic['version'])\n\n    if 'probe_file' in kwargs.keys():\n        probe_file = kwargs.pop('probe_file')\n\n    # instantiate extrator object\n    extractor = cls(**kwargs)\n\n    # load probe file\n    if probe_file is not None:\n        assert 'Recording' in class_name, \"Only recording extractors can have probe files\"\n        extractor = extractor.load_probe_file(probe_file=probe_file)\n\n    # load properties and features\n    if 'key_properties' in dic.keys():\n        extractor._key_properties = dic['key_properties']\n\n    if 'annotations' in dic.keys():\n        extractor._annotations = dic['annotations']\n\n    return extractor\n\n\ndef _get_class_from_string(class_string):\n    class_name = class_string.split('.')[-1]\n    module = '.'.join(class_string.split('.')[:-1])\n    imported_module = importlib.import_module(module)\n\n    try:\n        imported_class = getattr(imported_module, class_name)\n    except:\n        imported_class = None\n\n    return imported_class\n\n\ndef _check_same_version(class_string, version):\n    module = class_string.split('.')[0]\n    imported_module = importlib.import_module(module)\n\n    try:\n        return imported_module.__version__ == version\n    except AttributeError:\n        return 'unknown'\n\n\ndef _check_if_dumpable(d):\n    kwargs = d['kwargs']\n    if np.any([isinstance(v, dict) and 'dumpable' in v.keys() for (k, v) in kwargs.items()]):\n        for k, v in kwargs.items():\n            if 'dumpable' in v.keys():\n                return _check_if_dumpable(v)\n    else:\n        return d['dumpable']\n\n\ndef _check_json(d):\n    # quick hack to ensure json writable\n    for k, v in d.items():\n        if isinstance(v, dict):\n            d[k] = _check_json(v)\n        elif isinstance(v, Path):\n            d[k] = str(v.absolute())\n        elif isinstance(v, bool):\n            d[k] = bool(v)\n        elif isinstance(v, (int, np.integer)):\n           
 d[k] = int(v)\n        elif isinstance(v, float):\n            d[k] = float(v)\n        elif isinstance(v, datetime.datetime):\n            d[k] = v.isoformat()\n        elif isinstance(v, (np.ndarray, list)):\n            if len(v) > 0:\n                if isinstance(v[0], dict):\n                    # these must be extractors for multi extractors\n                    d[k] = [_check_json(v_el) for v_el in v]\n                else:\n                    v_arr = np.array(v)\n                    if len(v_arr.shape) == 1:\n                        if 'int' in str(v_arr.dtype):\n                            v_arr = [int(v_el) for v_el in v_arr]\n                            d[k] = v_arr\n                        elif 'float' in str(v_arr.dtype):\n                            v_arr = [float(v_el) for v_el in v_arr]\n                            d[k] = v_arr\n                        elif isinstance(v_arr[0], str):\n                            v_arr = [str(v_el) for v_el in v_arr]\n                            d[k] = v_arr\n                        else:\n                            print(f'Skipping field {k}: only 1D arrays of int, float, or str types can be serialized')\n                    elif len(v_arr.shape) == 2:\n                        if 'int' in str(v_arr.dtype):\n                            v_arr = [[int(v_el) for v_el in v_row] for v_row in v_arr]\n                            d[k] = v_arr\n                        elif 'float' in str(v_arr.dtype):\n                            v_arr = [[float(v_el) for v_el in v_row] for v_row in v_arr]\n                            d[k] = v_arr\n                        elif 'bool' in str(v_arr.dtype):\n                            v_arr = [[bool(v_el) for v_el in v_row] for v_row in v_arr]\n                            d[k] = v_arr\n                        else:\n                            print(f'Skipping field {k}: only 2D arrays of int or float type can be serialized')\n                    else:\n                        
print(f\"Skipping field {k}: only 1D and 2D arrays can be serialized\")\n            else:\n                d[k] = list(v)\n    return d\n"
  },
  {
    "path": "spikeextractors/cacheextractors.py",
    "content": "from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nfrom spikeextractors.extractors.npzsortingextractor import NpzSortingExtractor\nfrom spikeextractors import RecordingExtractor, SortingExtractor\nimport tempfile\nfrom pathlib import Path\nfrom copy import deepcopy\nimport importlib\nimport shutil\n\n\nclass CacheRecordingExtractor(BinDatRecordingExtractor, RecordingExtractor):\n    def __init__(self, recording, return_scaled=True,\n                 chunk_size=None, chunk_mb=500, save_path=None, n_jobs=1, joblib_backend='loky',\n                 verbose=False):\n        RecordingExtractor.__init__(self)  # init tmp folder before constructing BinDatRecordingExtractor\n        tmp_folder = self.get_tmp_folder()\n        self._recording = recording\n        if save_path is None:\n            self._is_tmp = True\n            self._tmp_file = tempfile.NamedTemporaryFile(suffix=\".dat\", dir=tmp_folder).name\n        else:\n            save_path = Path(save_path)\n            if save_path.suffix != '.dat' and save_path.suffix != '.bin':\n                save_path = save_path.with_suffix('.dat')\n            save_path.parent.mkdir(parents=True, exist_ok=True)\n            self._is_tmp = False\n            self._tmp_file = save_path\n        self._return_scaled = return_scaled\n        self._dtype = recording.get_dtype(return_scaled)\n        recording.write_to_binary_dat_format(save_path=self._tmp_file, dtype=self._dtype, chunk_size=chunk_size,\n                                             chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend,\n                                             return_scaled=self._return_scaled, verbose=verbose)\n        # keep track of filter status when dumping\n        self.is_filtered = self._recording.is_filtered\n        BinDatRecordingExtractor.__init__(self, self._tmp_file, numchan=recording.get_num_channels(),\n                                          
recording_channels=recording.get_channel_ids(),\n                                          sampling_frequency=recording.get_sampling_frequency(),\n                                          dtype=self._dtype, is_filtered=self.is_filtered)\n\n        self.set_tmp_folder(tmp_folder)\n        self.copy_channel_properties(recording)\n        self.copy_times(recording)\n\n        if 'gain' in recording.get_shared_channel_property_names() and not return_scaled:\n            self.set_channel_gains(recording.get_channel_gains())\n            self.set_channel_offsets(recording.get_channel_offsets())\n            self.has_unscaled = True\n        else:\n            self.clear_channel_gains()\n            self.clear_channel_offsets()\n\n        # keep BinDatRecording kwargs\n        self._bindat_kwargs = deepcopy(self._kwargs)\n        self._kwargs = {'recording': recording, 'chunk_size': chunk_size, 'chunk_mb': chunk_mb}\n\n    def __del__(self):\n        if self._is_tmp:\n            try:\n                # close memmap file (for Windows)\n                del self._timeseries\n                Path(self._tmp_file).unlink()\n            except Exception as e:\n                print(\"Unable to remove temporary file\", e)\n\n    @property\n    def filename(self):\n        return str(self._tmp_file)\n\n    def move_to(self, save_path):\n        save_path = Path(save_path)\n        if save_path.suffix != '.dat' and save_path.suffix != '.bin':\n            save_path = save_path.with_suffix('.dat')\n        save_path.parent.mkdir(parents=True, exist_ok=True)\n        # close memmap file (for Windows)\n        del self._timeseries\n        shutil.move(self._tmp_file, str(save_path))\n        self._tmp_file = str(save_path)\n        self._kwargs['file_path'] = str(Path(self._tmp_file).absolute())\n        self._bindat_kwargs['file_path'] = str(Path(self._tmp_file).absolute())\n        self._is_tmp = False\n        tmp_folder = self.get_tmp_folder()\n        # re-initialize with new 
file\n        BinDatRecordingExtractor.__init__(self, self._tmp_file, numchan=self._recording.get_num_channels(),\n                                          recording_channels=self._recording.get_channel_ids(),\n                                          sampling_frequency=self._recording.get_sampling_frequency(),\n                                          dtype=self._dtype, is_filtered=self.is_filtered)\n        self.set_tmp_folder(tmp_folder)\n        self.copy_channel_properties(self._recording)\n\n    # override to make serialization avoid reloading and saving binary file\n    def make_serialized_dict(self, include_properties=None, include_features=None):\n        \"\"\"\n        Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an\n        extractor with spikeextractors.load_extractor_from_dict(dump_dict)\n\n        Returns\n        -------\n        include_properties: list or None\n            List of properties to include in the dictionary\n        include_features: list or None\n            List of features to include in the dictionary\n        \"\"\"\n        class_name = str(BinDatRecordingExtractor).replace(\"<class '\", \"\").replace(\"'>\", '')\n        module = class_name.split('.')[0]\n        imported_module = importlib.import_module(module)\n\n        if self._is_tmp:\n            print(\"Warning: dumping a CacheRecordingExtractor. The path to the tmp binary file will be lost in \"\n                  \"further sessions. 
To prevent this, use the 'CacheRecordingExtractor.move_to('path-to-file)' \"\n                  \"function\")\n\n        dump_dict = {'class': class_name, 'module': module, 'kwargs': self._bindat_kwargs,\n                     'key_properties': self._key_properties, 'version': imported_module.__version__, 'dumpable': True}\n        return dump_dict\n\n\nclass CacheSortingExtractor(NpzSortingExtractor, SortingExtractor):\n    def __init__(self, sorting, save_path=None):\n        SortingExtractor.__init__(self)  # init tmp folder before constructing NpzSortingExtractor\n        tmp_folder = self.get_tmp_folder()\n        self._sorting = sorting\n        if save_path is None:\n            self._is_tmp = True\n            self._tmp_file = tempfile.NamedTemporaryFile(suffix=\".npz\", dir=tmp_folder).name\n        else:\n            save_path = Path(save_path)\n            if save_path.suffix != '.npz':\n                save_path = save_path.with_suffix('.npz')\n            save_path.parent.mkdir(parents=True, exist_ok=True)\n            self._is_tmp = False\n            self._tmp_file = save_path\n        NpzSortingExtractor.write_sorting(self._sorting, self._tmp_file)\n        NpzSortingExtractor.__init__(self, self._tmp_file)\n        # keep Npz kwargs\n        self._npz_kwargs = deepcopy(self._kwargs)\n        self.set_tmp_folder(tmp_folder)\n        self.copy_unit_properties(sorting)\n        self.copy_unit_spike_features(sorting)\n        self._kwargs = {'sorting': sorting}\n\n    def __del__(self):\n        if self._is_tmp:\n            try:\n                Path(self._tmp_file).unlink()\n            except Exception as e:\n                print(\"Unable to remove temporary file\", e)\n\n    @property\n    def filename(self):\n        return str(self._tmp_file)\n\n    def move_to(self, save_path):\n        save_path = Path(save_path)\n        if save_path.suffix != '.npz':\n            save_path = save_path.with_suffix('.npz')\n        
save_path.parent.mkdir(parents=True, exist_ok=True)\n        shutil.move(self._tmp_file, str(save_path))\n        self._tmp_file = str(save_path)\n        self._kwargs['file_path'] = str(Path(self._tmp_file).absolute())\n        self._npz_kwargs['file_path'] = str(Path(self._tmp_file).absolute())\n        self._is_tmp = False\n        tmp_folder = self.get_tmp_folder()\n        # re-initialize with new file\n        NpzSortingExtractor.__init__(self, self._tmp_file)\n        # keep Npz kwargs\n        self.set_tmp_folder(tmp_folder)\n        self.copy_unit_properties(self._sorting)\n        self.copy_unit_spike_features(self._sorting)\n\n    # override to make serialization avoid reloading and saving npz file\n    def make_serialized_dict(self, include_properties=None, include_features=None):\n        \"\"\"\n        Makes a nested serialized dictionary out of the extractor. The dictionary be used to re-initialize an\n        extractor with spikeextractors.load_extractor_from_dict(dump_dict)\n\n        Returns\n        -------\n        include_properties: list or None\n            List of properties to include in the dictionary\n        include_features: list or None\n            List of features to include in the dictionary\n        \"\"\"\n        class_name = str(NpzSortingExtractor).replace(\"<class '\", \"\").replace(\"'>\", '')\n        module = class_name.split('.')[0]\n        imported_module = importlib.import_module(module)\n\n        if self._is_tmp:\n            print(\"Warning: dumping a CacheSortingExtractor. The path to the tmp binary file will be lost in \"\n                  \"further sessions. To prevent this, use the 'CacheSortingExtractor.move_to('path-to-file)' \"\n                  \"function\")\n\n        dump_dict = {'class': class_name, 'module': module, 'kwargs': self._npz_kwargs,\n                     'key_properties': self._key_properties, 'version': imported_module.__version__, 'dumpable': True}\n        return dump_dict\n"
  },
  {
    "path": "spikeextractors/example_datasets/__init__.py",
    "content": "from .toy_example import toy_example\n"
  },
  {
    "path": "spikeextractors/example_datasets/synthesize_random_firings.py",
    "content": "import numpy as np\n\n\ndef synthesize_random_firings(*, K=20, sampling_frequency=30000.0, duration=60, seed=None):\n    if seed is not None:\n        np.random.seed(seed)\n        seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, K)\n    else:\n        seeds = np.random.randint(0, 2147483647, K)\n\n    firing_rates = 3 * np.ones((K))\n    refr = 4\n\n    N = np.int64(duration * sampling_frequency)\n\n    # events/sec * sec/timepoint * N\n    populations = np.ceil(firing_rates / sampling_frequency * N).astype('int')\n    times = np.zeros(0)\n    labels = np.zeros(0)\n\n    for i, k in enumerate(range(1, K + 1)):\n        refr_timepoints = refr / 1000 * sampling_frequency\n\n        times0 = np.random.rand(populations[k - 1]) * (N - 1) + 1\n\n        ## make an interesting autocorrelogram shape\n        times0 = np.hstack((times0, times0 + rand_distr2(refr_timepoints, refr_timepoints * 20, times0.size, seeds[i])))\n        times0 = times0[np.random.RandomState(seed=seeds[i]).choice(times0.size, int(times0.size / 2))]\n        times0 = times0[np.where((0 <= times0) & (times0 < N))]\n\n        times0 = enforce_refractory_period(times0, refr_timepoints)\n        times = np.hstack((times, times0))\n        labels = np.hstack((labels, k * np.ones(times0.shape)))\n\n    sort_inds = np.argsort(times)\n    times = times[sort_inds]\n    labels = labels[sort_inds]\n\n    return (times, labels)\n\n\ndef rand_distr2(a, b, num, seed):\n    X = np.random.RandomState(seed=seed).rand(num)\n    X = a + (b - a) * X ** 2\n    return X\n\n\ndef enforce_refractory_period(times_in, refr):\n    if (times_in.size == 0): return times_in\n\n    times0 = np.sort(times_in)\n    done = False\n    while not done:\n        diffs = times0[1:] - times0[:-1]\n        diffs = np.hstack((diffs, np.inf))  # hack to make sure we handle the last one\n        inds0 = np.where((diffs[:-1] <= refr) & (diffs[1:] >= refr))[0]  # only first violator in every group\n        if 
len(inds0) > 0:\n            times0[inds0] = -1  # kind of a hack, what's the better way?\n            times0 = times0[np.where(times0 >= 0)]\n        else:\n            done = True\n\n    return times0\n"
  },
  {
    "path": "spikeextractors/example_datasets/synthesize_random_waveforms.py",
    "content": "import numpy as np\nfrom .synthesize_single_waveform import synthesize_single_waveform\n\n\ndef synthesize_random_waveforms(*, M=5, T=500, K=20, upsamplefac=13, timeshift_factor=3, average_peak_amplitude=-10,\n                                seed=None):\n    if seed is not None:\n        np.random.seed(seed)\n        seeds = np.random.RandomState(seed=seed).randint(0, 2147483647, K)\n    else:\n        seeds = np.random.randint(0, 2147483647, K)\n    geometry = None\n    avg_durations = [200, 10, 30, 200]\n    avg_amps = [0.5, 10, -1, 0]\n    rand_durations_stdev = [10, 4, 6, 20]\n    rand_amps_stdev = [0.2, 3, 0.5, 0]\n    rand_amp_factor_range = [0.5, 1]\n    geom_spread_coef1 = 0.2\n    geom_spread_coef2 = 1\n\n    if not geometry:\n        geometry = np.zeros((2, M))\n        geometry[0, :] = np.arange(1, M + 1)\n\n    geometry = np.array(geometry)\n    avg_durations = np.array(avg_durations)\n    avg_amps = np.array(avg_amps)\n    rand_durations_stdev = np.array(rand_durations_stdev)\n    rand_amps_stdev = np.array(rand_amps_stdev)\n    rand_amp_factor_range = np.array(rand_amp_factor_range)\n\n    neuron_locations = get_default_neuron_locations(M, K, geometry)\n\n    ## The waveforms_out\n    WW = np.zeros((M, T * upsamplefac, K))\n\n    for i, k in enumerate(range(1, K + 1)):\n        for m in range(1, M + 1):\n            diff = neuron_locations[:, k - 1] - geometry[:, m - 1]\n            dist = np.sqrt(np.sum(diff ** 2))\n            durations0 = np.maximum(np.ones(avg_durations.shape),\n                                    avg_durations + np.random.RandomState(seed=seeds[i]).randn(1, 4) * rand_durations_stdev) * upsamplefac\n            amps0 = avg_amps + np.random.RandomState(seed=seeds[i]).randn(1, 4) * rand_amps_stdev\n            waveform0 = synthesize_single_waveform(N=T * upsamplefac, durations=durations0, amps=amps0)\n            waveform0 = np.roll(waveform0, int(timeshift_factor * dist * upsamplefac))\n            waveform0 = 
waveform0 * np.random.RandomState(seed=seeds[i]).uniform(rand_amp_factor_range[0], rand_amp_factor_range[1])\n            WW[m - 1, :, k - 1] = waveform0 / (geom_spread_coef1 + dist * geom_spread_coef2)\n\n    peaks = np.max(np.abs(WW), axis=(0, 1))\n    WW = WW / np.mean(peaks) * average_peak_amplitude\n\n    return (WW, geometry.T)\n\n\ndef get_default_neuron_locations(M, K, geometry):\n    num_dims = geometry.shape[0]\n    neuron_locations = np.zeros((num_dims, K))\n    for k in range(1, K + 1):\n        if K > 0:\n            ind = (k - 1) / (K - 1) * (M - 1) + 1\n            ind0 = int(ind)\n            if ind0 == M:\n                ind0 = M - 1\n                p = 1\n            else:\n                p = ind - ind0\n            if M > 0:\n                neuron_locations[:, k - 1] = (1 - p) * geometry[:, ind0 - 1] + p * geometry[:, ind0]\n            else:\n                neuron_locations[:, k - 1] = geometry[:, 0]\n        else:\n            neuron_locations[:, k - 1] = geometry[:, 0]\n\n    return neuron_locations\n"
  },
  {
    "path": "spikeextractors/example_datasets/synthesize_single_waveform.py",
    "content": "import numpy as np\n\n\ndef exp_growth(amp1, amp2, dur1, dur2):\n    t = np.arange(0, dur1)\n    Y = np.exp(t / dur2)\n    # Want Y[0]=amp1\n    # Want Y[-1]=amp2\n    Y = Y / (Y[-1] - Y[0]) * (amp2 - amp1)\n    Y = Y - Y[0] + amp1;\n    return Y\n\n\ndef exp_decay(amp1, amp2, dur1, dur2):\n    Y = exp_growth(amp2, amp1, dur1, dur2)\n    Y = np.flipud(Y)  # used to be flip, but that was not supported by older versions of numpy\n    return Y\n\n\ndef smooth_it(Y, t):\n    Z = np.zeros(Y.size)\n    for j in range(-t, t + 1):\n        Z = Z + np.roll(Y, j)\n    return Z\n\n\ndef synthesize_single_waveform(*, N=800, durations=[200, 10, 30, 200], amps=[0.5, 10, -1, 0]):\n    durations = np.array(durations).ravel()\n    if (np.sum(durations) >= N - 2):\n        durations[-1] = N - 2 - np.sum(durations[0:durations.size - 1])\n\n    amps = np.array(amps).ravel()\n\n    timepoints = np.round(np.hstack((0, np.cumsum(durations) - 1))).astype('int');\n\n    t = np.r_[0:np.sum(durations) + 1]\n\n    Y = np.zeros(len(t))\n    Y[timepoints[0]:timepoints[1] + 1] = exp_growth(0, amps[0], timepoints[1] + 1 - timepoints[0], durations[0] / 4)\n    Y[timepoints[1]:timepoints[2] + 1] = exp_growth(amps[0], amps[1], timepoints[2] + 1 - timepoints[1], durations[1])\n    Y[timepoints[2]:timepoints[3] + 1] = exp_decay(amps[1], amps[2], timepoints[3] + 1 - timepoints[2],\n                                                   durations[2] / 4)\n    Y[timepoints[3]:timepoints[4] + 1] = exp_decay(amps[2], amps[3], timepoints[4] + 1 - timepoints[3],\n                                                   durations[3] / 5)\n    Y = smooth_it(Y, 3)\n    Y = Y - np.linspace(Y[0], Y[-1], len(t))\n    Y = np.hstack((Y, np.zeros(N - len(t))))\n    Nmid = int(np.floor(N / 2))\n    peakind = np.argmax(np.abs(Y))\n    Y = np.roll(Y, Nmid - peakind)\n\n    return Y\n\n\n# Y=smooth_it(Y,3);\n# Y=Y-linspace(Y(1),Y(end),length(Y));\n#\n# Y=[Y,zeros(1,N-length(Y))];\n#\n# Nmid=floor(N/2);\n# 
[~,peakind]=max(abs(Y));\n# Y=circshift(Y,[0,Nmid-peakind]);\n#\n# end\n#\n# function test_synth_waveform\n# Y=synthesize_single_waveform(800);\n# figure; plot(Y);\n# end\n#\n# function Y=exp_growth(amp1,amp2,dur1,dur2)\n# t=1:dur1;\n# Y=exp(t/dur2);\n# % Want Y(1)=amp1\n# % Want Y(end)=amp2\n# Y=Y/(Y(end)-Y(1))*(amp2-amp1);\n# Y=Y-Y(1)+amp1;\n# end\n#\n# function Y=exp_decay(amp1,amp2,dur1,dur2)\n# Y=exp_growth(amp2,amp1,dur1,dur2);\n# Y=Y(end:-1:1);\n# end\n#\n# function Z=smooth_it(Y,t)\n# Z=Y;\n# Z(1+t:end-t)=0;\n# for j=-t:t\n#    Z(1+t:end-t)=Z(1+t:end-t)+Y(1+t+j:end-t+j)/(2*t+1);\n# end;\n# end\n\nif __name__ == '__main__':\n    Y = synthesize_single_waveform()\n    import matplotlib.pyplot as plt\n\n    plt.plot(Y)\n"
  },
  {
    "path": "spikeextractors/example_datasets/synthesize_timeseries.py",
    "content": "import numpy as np\n\n\ndef synthesize_timeseries(*, sorting, waveforms, noise_level=1, sampling_frequency=30000.0, duration=60, waveform_upsamplefac=13, seed=None):\n    num_timepoints = np.int64(sampling_frequency * duration)\n    waveform_upsamplefac = int(waveform_upsamplefac)\n    W = waveforms\n\n    M, TT, K = W.shape[0], W.shape[1], W.shape[2]\n    T = int(TT / waveform_upsamplefac)\n    Tmid = int(np.ceil((T + 1) / 2 - 1))\n\n    N = num_timepoints\n\n    if seed is not None:\n        X = np.random.RandomState(seed=seed).randn(M, N) * noise_level\n    else:\n        X = np.random.randn(M, N) * noise_level\n\n    unit_ids = sorting.get_unit_ids()\n    for k0 in unit_ids:\n        waveform0 = waveforms[:, :, k0 - 1]\n        times0 = sorting.get_unit_spike_train(unit_id=k0)\n        for t0 in times0:\n            amp0 = 1\n            frac_offset = int(np.floor((t0 - np.floor(t0)) * waveform_upsamplefac))\n            tstart = np.int64(np.floor(t0)) - Tmid\n            if (0 <= tstart) and (tstart + T <= N):\n                X[:, tstart:tstart + T] = X[:, tstart:tstart + T] + waveform0[:,\n                                                                    frac_offset::waveform_upsamplefac] * amp0\n\n    return X\n"
  },
  {
    "path": "spikeextractors/example_datasets/toy_example.py",
    "content": "import numpy as np\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nimport spikeextractors as se\nfrom .synthesize_random_waveforms import synthesize_random_waveforms\nfrom .synthesize_random_firings import synthesize_random_firings\nfrom .synthesize_timeseries import synthesize_timeseries\n\n\ndef toy_example(\n    duration: float = 10.,\n    num_channels: int = 4,\n    sampling_frequency: float = 30000.,\n    K: int = 10,\n    dumpable: bool = False,\n    dump_folder: Optional[Union[str, Path]] = None,\n    seed: Optional[int] = None\n):\n    \"\"\"\n    Create toy recording and sorting extractors.\n\n    Parameters\n    ----------\n    duration: float\n        Duration in s (default 10)\n    num_channels: int\n        Number of channels (default 4)\n    sampling_frequency: float\n        Sampling frequency (default 30000)\n    K: int\n        Number of units (default 10)\n    dumpable: bool\n        If True, objects are dumped to file and become 'dumpable'\n    dump_folder: str or Path\n        Path to dump folder (if None, 'test' is used\n    seed: int\n        Seed for random initialization\n\n    Returns\n    -------\n    recording: RecordingExtractor\n        The output recording extractor. If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an\n        MdaRecordingExtractor\n    sorting: SortingExtractor\n        The output sorting extractor. 
If dumpable is False it's a NumpyRecordingExtractor, otherwise it's an\n        NpzSortingExtractor\n    \"\"\"\n    upsamplefac = 13\n    waveforms, geom = synthesize_random_waveforms(K=K, M=num_channels, average_peak_amplitude=-100,\n                                                  upsamplefac=upsamplefac, seed=seed)\n    times, labels = synthesize_random_firings(K=K, duration=duration, sampling_frequency=sampling_frequency, seed=seed)\n    labels = labels.astype(np.int64)\n    SX = se.NumpySortingExtractor()\n    SX.set_times_labels(times, labels)\n    X = synthesize_timeseries(sorting=SX, waveforms=waveforms, noise_level=10, sampling_frequency=sampling_frequency,\n                              duration=duration,\n                              waveform_upsamplefac=upsamplefac, seed=seed)\n    SX.set_sampling_frequency(sampling_frequency)\n\n    RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)\n    RX.is_filtered = True\n\n    if dumpable:\n        if dump_folder is None:\n            dump_folder = 'toy_example'\n        dump_folder = Path(dump_folder)\n\n        se.MdaRecordingExtractor.write_recording(RX, dump_folder)\n        RX = se.MdaRecordingExtractor(dump_folder)\n        se.NpzSortingExtractor.write_sorting(SX, dump_folder / 'sorting.npz')\n        SX = se.NpzSortingExtractor(dump_folder / 'sorting.npz')\n\n    return RX, SX\n"
  },
  {
    "path": "spikeextractors/exceptions.py",
    "content": "class NotDumpableExtractorError(TypeError):\n    \"\"\"Raised whenever current extractor cannot be dumped\"\"\"\n"
  },
  {
    "path": "spikeextractors/extraction_tools.py",
    "content": "import numpy as np\nimport csv\nimport os\nimport sys\nfrom pathlib import Path\nimport warnings\nimport datetime\nfrom functools import wraps\nfrom .baseextractor import BaseExtractor\nfrom tqdm import tqdm\nfrom joblib import Parallel, delayed\n\ntry:\n    import h5py\n    HAVE_H5 = True\nexcept ImportError:\n    HAVE_H5 = False\n\n\ndef read_python(path):\n    \"\"\"Parses python scripts in a dictionary\n\n    Parameters\n    ----------\n    path: str or Path\n        Path to file to parse\n\n    Returns\n    -------\n    metadata:\n        dictionary containing parsed file\n\n    \"\"\"\n    from six import exec_\n    import re\n    path = Path(path).absolute()\n    assert path.is_file()\n    with path.open('r') as f:\n        contents = f.read()\n    contents = re.sub(r'range\\(([\\d,]*)\\)',r'list(range(\\1))',contents)\n    metadata = {}\n    exec_(contents, {}, metadata)\n    metadata = {k.lower(): v for (k, v) in metadata.items()}\n    return metadata\n\n\ndef write_python(path, dict):\n    \"\"\"Saves python dictionary to file\n\n    Parameters\n    ----------\n    path: str or Path\n        Path to save file\n    dict: dict\n        dictionary to save\n    \"\"\"\n    with Path(path).open('w') as f:\n        for k, v in dict.items():\n            if isinstance(v ,str) and not v.startswith(\"'\"):\n                if 'path' in k and 'win' in sys.platform:\n                    f.write(str(k) + \" = r'\" + str(v) + \"'\\n\")\n                else:\n                    f.write(str(k) + \" = '\" + str(v) + \"'\\n\")\n            else:\n                f.write(str(k) + \" = \" + str(v) + \"\\n\")\n\n\ndef load_probe_file(recording, probe_file, channel_map=None, channel_groups=None, verbose=False):\n    \"\"\"This function returns a SubRecordingExtractor that contains information from the given\n    probe file (channel locations, groups, etc.) 
If a .prb file is given, then 'location' and 'group'\n    information for each channel is added to the SubRecordingExtractor. If a .csv file is given, then\n    it will only add 'location' to the SubRecordingExtractor.\n\n    Parameters\n    ----------\n    recording: RecordingExtractor\n        The recording extractor to load channel information from.\n    probe_file: str\n        Path to probe file. Either .prb or .csv\n    channel_map : array-like\n        A list of channel IDs to set in the loaded file.\n        Only used if the loaded file is a .csv.\n    channel_groups : array-like\n        A list of groups (ints) for the channel_ids to set in the loaded file.\n        Only used if the loaded file is a .csv.\n    verbose: bool\n        If True, output is verbose\n\n    Returns\n    ---------\n    subrecording: SubRecordingExtractor\n        The extractor containing all of the probe information.\n    \"\"\"\n    from .subrecordingextractor import SubRecordingExtractor\n    probe_file = Path(probe_file)\n    if probe_file.suffix == '.prb':\n        probe_dict = read_python(probe_file)\n        if 'channel_groups' in probe_dict.keys():\n            ordered_channels = np.array([], dtype=int)\n            groups = sorted(probe_dict['channel_groups'].keys())\n            for cgroup_id in groups:\n                cgroup = probe_dict['channel_groups'][cgroup_id]\n                for key_prop, prop_val in cgroup.items():\n                    if key_prop == 'channels':\n                        ordered_channels = np.concatenate((ordered_channels, prop_val))\n            if not np.all([chan in recording.get_channel_ids() for chan in ordered_channels]) and verbose:\n                print('Some channel in PRB file are not in original recording')\n            present_ordered_channels = [chan for chan in ordered_channels if chan in recording.get_channel_ids()]\n            subrecording = SubRecordingExtractor(recording, channel_ids=present_ordered_channels)\n            for 
cgroup_id in groups:\n                cgroup = probe_dict['channel_groups'][cgroup_id]\n                if 'channels' not in cgroup.keys() and len(groups) > 1:\n                    raise Exception(\"If more than one 'channel_group' is in the probe file, the 'channels' field\"\n                                    \"for each channel group is required\")\n                elif 'channels' not in cgroup.keys():\n                    channels_in_group = subrecording.get_num_channels()\n                    channels_id_in_group = subrecording.get_channel_ids()\n                else:\n                    channels_in_group = len(cgroup['channels'])\n                    channels_id_in_group = cgroup['channels']\n                for key_prop, prop_val in cgroup.items():\n                    if key_prop == 'channels':\n                        for i_ch, prop in enumerate(prop_val):\n                            if prop in subrecording.get_channel_ids():\n                                subrecording.set_channel_groups(int(cgroup_id), channel_ids=prop)\n                    elif key_prop == 'geometry' or key_prop == 'location':\n                        if isinstance(prop_val, dict):\n                            if len(prop_val.keys()) != channels_in_group and verbose:\n                                print('geometry in PRB does not have the same length as channel in group')\n                            for (i_ch, prop) in prop_val.items():\n                                if i_ch in subrecording.get_channel_ids():\n                                    subrecording.set_channel_locations(prop, channel_ids=i_ch)\n                        elif isinstance(prop_val, (list, np.ndarray)) and len(prop_val) == channels_in_group:\n                            if 'channels' not in cgroup.keys():\n                                raise Exception(\"'geometry'/'location' in the .prb file can be a list only if \"\n                                                \"'channels' field is specified.\")\n       
                     if len(prop_val) != channels_in_group and verbose:\n                                print('geometry in PRB does not have the same length as channel in group')\n                            for (i_ch, prop) in zip(channels_id_in_group, prop_val):\n                                if i_ch in subrecording.get_channel_ids():\n                                    subrecording.set_channel_locations(prop, channel_ids=i_ch)\n                    else:\n                        if isinstance(prop_val, dict) and len(prop_val.keys()) == channels_in_group:\n                            for (i_ch, prop) in prop_val.items():\n                                if i_ch in subrecording.get_channel_ids():\n                                    subrecording.set_channel_property(i_ch, key_prop, prop)\n                        elif isinstance(prop_val, (list, np.ndarray)) and len(prop_val) == channels_in_group:\n                            for (i_ch, prop) in zip(channels_id_in_group, prop_val):\n                                if i_ch in subrecording.get_channel_ids():\n                                    subrecording.set_channel_property(i_ch, key_prop, prop)\n                # create dummy locations\n                if 'geometry' not in cgroup.keys() and 'location' not in cgroup.keys():\n                    if 'location' not in subrecording.get_shared_channel_property_names():\n                        locs = np.zeros((subrecording.get_num_channels(), 2))\n                        locs[:, 1] = np.arange(subrecording.get_num_channels())\n                        subrecording.set_channel_locations(locs)\n        else:\n            raise AttributeError(\"'.prb' file should contain the 'channel_groups' field\")\n\n    elif probe_file.suffix == '.csv':\n        if channel_map is not None:\n            assert np.all([chan in channel_map for chan in recording.get_channel_ids()]), \\\n                \"all channel_ids in 'channel_map' must be in the original recording channel ids\"\n 
           subrecording = SubRecordingExtractor(recording, channel_ids=channel_map)\n        else:\n            subrecording = SubRecordingExtractor(recording, channel_ids=recording.get_channel_ids())\n        with probe_file.open() as csvfile:\n            posreader = csv.reader(csvfile)\n            row_count = 0\n            loaded_pos = []\n            for pos in (posreader):\n                row_count += 1\n                loaded_pos.append(pos)\n            assert len(subrecording.get_channel_ids()) == row_count, \"The .csv file must contain as many \" \\\n                                                                     \"rows as the number of channels in the recordings\"\n            for i_ch, pos in zip(subrecording.get_channel_ids(), loaded_pos):\n                if i_ch in subrecording.get_channel_ids():\n                    subrecording.set_channel_locations(list(np.array(pos).astype(float)), i_ch)\n            if channel_groups is not None and len(channel_groups) == len(subrecording.get_channel_ids()):\n                for i_ch, chg in zip(subrecording.get_channel_ids(), channel_groups):\n                    if i_ch in subrecording.get_channel_ids():\n                        subrecording.set_channel_groups(chg, i_ch)\n    else:\n        raise NotImplementedError(\"Only .csv and .prb probe files can be loaded.\")\n\n    subrecording._kwargs['probe_file'] = str(probe_file.absolute())\n    return subrecording\n\n\ndef save_to_probe_file(recording, probe_file, grouping_property=None, radius=None,\n                       graph=True, geometry=True, verbose=False):\n    \"\"\"Saves probe file from the channel information of the given recording\n    extractor.\n\n    Parameters\n    ----------\n    recording: RecordingExtractor\n        The recording extractor to save probe file from\n    probe_file: str\n        file name of .prb or .csv file to save probe information to\n    grouping_property: str (default None)\n        If grouping_property is a 
shared_channel_property, different groups are saved based on the property.\n    radius: float (default None)\n        Adjacency radius (used by some sorters). If None it is not saved to the probe file.\n    graph: bool\n        If True, the adjacency graph is saved (default=True)\n    geometry: bool\n        If True, the geometry is saved (default=True)\n    verbose: bool\n        If True, output is verbose\n    \"\"\"\n    probe_file = Path(probe_file)\n    if not probe_file.parent.is_dir():\n        probe_file.parent.mkdir()\n\n    if probe_file.suffix == '.csv':\n        # write csv probe file\n        with probe_file.open('w') as f:\n            if 'location' in recording.get_shared_channel_property_names():\n                for chan in recording.get_channel_ids():\n                    loc = recording.get_channel_locations(chan)[0]\n                    if len(loc) == 2:\n                        f.write(str(loc[0]))\n                        f.write(',')\n                        f.write(str(loc[1]))\n                        f.write('\\n')\n                    elif len(loc) == 3:\n                        f.write(str(loc[0]))\n                        f.write(',')\n                        f.write(str(loc[1]))\n                        f.write(',')\n                        f.write(str(loc[2]))\n                        f.write('\\n')\n            else:\n                raise AttributeError(\"Recording extractor needs to have \"\n                                     \"'location' property to save .csv probe file\")\n    elif probe_file.suffix == '.prb':\n        _export_prb_file(recording, probe_file, grouping_property=grouping_property, radius=radius, graph=graph,\n                         geometry=geometry, verbose=verbose)\n    else:\n        raise NotImplementedError(\"Only .csv and .prb probe files can be saved.\")\n\n\ndef read_binary(file, numchan, dtype, time_axis=0, offset=0):\n    \"\"\"\n    Reads binary .bin or .dat file.\n\n    Parameters\n    ----------\n   
 file: str\n        File name\n    numchan: int\n        Number of channels\n    dtype: dtype\n        dtype of the file\n    time_axis: 0 (default) or 1\n        If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n        If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n    offset: int\n        number of offset bytes\n    \"\"\"\n    numchan = int(numchan)\n    with Path(file).open() as f:\n        nsamples = (os.fstat(f.fileno()).st_size - offset) // (numchan * np.dtype(dtype).itemsize)\n    if time_axis == 0:\n        samples = np.memmap(file, np.dtype(dtype), mode='r', offset=offset, shape=(nsamples, numchan)).T\n    else:\n        samples = np.memmap(file, np.dtype(dtype), mode='r', offset=offset, shape=(numchan, nsamples))\n    return samples\n\n\ndef write_to_binary_dat_format(recording, save_path=None, file_handle=None,\n                               time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, n_jobs=1, joblib_backend='loky',\n                               return_scaled=True, verbose=False):\n    \"\"\"Saves the traces of a recording extractor in binary .dat format.\n\n    Parameters\n    ----------\n    recording: RecordingExtractor\n        The recording extractor object to be saved in .dat format\n    save_path: str\n        The path to the file.\n    file_handle: file handle\n        The file handle to dump data. This can be used to append data to an header. In case file_handle is given,\n        the file is NOT closed after writing the binary data.\n    time_axis: 0 (default) or 1\n        If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n        If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n    dtype: dtype\n        Type of the saved data. 
Default float32.\n    chunk_size: None or int\n        Size of each chunk in number of frames.\n        If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n    chunk_mb: None or int\n        Chunk size in Mb (default 500Mb)\n    n_jobs: int\n        Number of jobs to use (Default 1)\n    joblib_backend: str\n        Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing')\n    return_scaled: bool\n        If True, traces are written after scaling (using gain/offset). If False, the raw traces are written\n    verbose: bool\n        If True, output is verbose (when chunks are used)\n    \"\"\"\n    assert save_path is not None or file_handle is not None, \"Provide 'save_path' or 'file handle'\"\n\n    if save_path is not None:\n        save_path = Path(save_path)\n        if save_path.suffix == '':\n            # when suffix is already raw/bin/dat do not change it.\n            save_path = save_path.parent / (save_path.name + '.dat')\n\n    if chunk_size is not None or chunk_mb is not None:\n        if time_axis == 1:\n            print(\"Chunking disabled due to 'time_axis' == 1\")\n            chunk_size = None\n            chunk_mb = None\n\n    # set chunk size\n    if chunk_size is not None:\n        chunk_size = int(chunk_size)\n    elif chunk_mb is not None:\n        n_bytes = np.dtype(recording.get_dtype()).itemsize\n        max_size = int(chunk_mb * 1e6)  # set Mb per chunk\n        chunk_size = max_size // (recording.get_num_channels() * n_bytes)\n\n    if n_jobs is None:\n        n_jobs = 1\n    if n_jobs == 0:\n        n_jobs = 1\n\n    if n_jobs > 1:\n        if chunk_size is not None:\n            chunk_size /= n_jobs\n\n    if not recording.check_if_dumpable():\n        if n_jobs > 1:\n            n_jobs = 1\n            print(\"RecordingExtractor is not dumpable and can't be processed in parallel\")\n        rec_arg = recording\n    else:\n        if n_jobs > 1:\n            
rec_arg = recording.dump_to_dict()\n        else:\n            rec_arg = recording\n\n    if chunk_size is None:\n        traces = recording.get_traces(return_scaled=return_scaled)\n        if dtype is not None:\n            traces = traces.astype(dtype)\n        if time_axis == 0:\n            traces = traces.T\n        if save_path is not None:\n            with save_path.open('wb') as f:\n                traces.tofile(f)\n        else:\n            traces.tofile(file_handle)\n    else:\n        # chunk size is not None\n        num_frames = recording.get_num_frames()\n        num_channels = recording.get_num_channels()\n\n        # chunk_size = num_bytes_per_chunk / num_bytes_per_frame\n        chunks = divide_recording_into_time_chunks(\n            num_frames=num_frames,\n            chunk_size=chunk_size,\n            padding_size=0\n        )\n        n_chunk = len(chunks)\n\n        if verbose and n_jobs == 1:\n            chunks_loop = tqdm(range(n_chunk), ascii=True, desc=\"Writing to binary .dat file\")\n        else:\n            chunks_loop = range(n_chunk)\n        if save_path is not None:\n            if n_jobs == 1:\n                if time_axis == 0:\n                    shape = (num_frames, num_channels)\n                else:\n                    shape = (num_channels, num_frames)\n                rec_memmap = np.memmap(str(save_path), dtype=dtype, mode='w+', shape=shape)\n                for i in chunks_loop:\n                    _write_dat_one_chunk(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled,\n                                         verbose=False)\n            else:\n                if time_axis == 0:\n                    shape = (num_frames, num_channels)\n                else:\n                    shape = (num_channels, num_frames)\n                rec_memmap = np.memmap(str(save_path), dtype=dtype, mode='w+', shape=shape)\n\n                Parallel(n_jobs=n_jobs, backend=joblib_backend)(\n                    
delayed(_write_dat_one_chunk)(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled,\n                                                  verbose,)\n                    for i in chunks_loop)\n        else:\n            for i in chunks_loop:\n                start_frame = chunks[i]['istart']\n                end_frame = chunks[i]['iend']\n                traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame,\n                                              return_scaled=return_scaled)\n\n                if dtype is not None:\n                    traces = traces.astype(dtype)\n                if time_axis == 0:\n                    traces = traces.T\n                file_handle.write(traces.tobytes())\n\n    return save_path\n\n\ndef write_to_h5_dataset_format(recording, dataset_path, save_path=None, file_handle=None,\n                               time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, verbose=False):\n    \"\"\"Saves the traces of a recording extractor in an h5 dataset.\n\n    Parameters\n    ----------\n    recording: RecordingExtractor\n        The recording extractor object to be saved in .dat format\n    dataset_path: str\n        Path to dataset in h5 filee (e.g. '/dataset')\n    save_path: str\n        The path to the file.\n    file_handle: file handle\n        The file handle to dump data. This can be used to append data to an header. In case file_handle is given,\n        the file is NOT closed after writing the binary data.\n    time_axis: 0 (default) or 1\n        If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n        If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n    dtype: dtype\n        Type of the saved data. 
Default float32.\n    chunk_size: None or int\n        Size of each chunk in number of frames.\n        If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n    chunk_mb: None or int\n        Chunk size in Mb (default 500Mb)\n    verbose: bool\n        If True, output is verbose (when chunks are used)\n    \"\"\"\n    assert HAVE_H5, \"To write to h5 you need to install h5py: pip install h5py\"\n    assert save_path is not None or file_handle is not None, \"Provide 'save_path' or 'file handle'\"\n\n    if save_path is not None:\n        save_path = Path(save_path)\n        if save_path.suffix == '':\n            # when suffix is already raw/bin/dat do not change it.\n            save_path = save_path.parent / (save_path.name + '.h5')\n\n    num_channels = recording.get_num_channels()\n    num_frames = recording.get_num_frames()\n\n    if file_handle is not None:\n        assert isinstance(file_handle, h5py.File)\n    else:\n        file_handle = h5py.File(save_path, 'w')\n\n    if dtype is None:\n        dtype_file = recording.get_dtype()\n    else:\n        dtype_file = dtype\n\n    if time_axis == 0:\n        dset = file_handle.create_dataset(dataset_path, shape=(num_frames, num_channels), dtype=dtype_file)\n    else:\n        dset = file_handle.create_dataset(dataset_path, shape=(num_channels, num_frames), dtype=dtype_file)\n\n    # set chunk size\n    if chunk_size is not None:\n        chunk_size = int(chunk_size)\n    elif chunk_mb is not None:\n        n_bytes = np.dtype(recording.get_dtype()).itemsize\n        max_size = int(chunk_mb * 1e6)  # set Mb per chunk\n        chunk_size = max_size // (num_channels * n_bytes)\n\n    if chunk_size is None:\n        traces = recording.get_traces()\n        if dtype is not None:\n            traces = traces.astype(dtype_file)\n        if time_axis == 0:\n            traces = traces.T\n        dset[:] = traces\n    else:\n        chunk_start = 0\n        # chunk size 
is not None\n        n_chunk = num_frames // chunk_size\n        if num_frames % chunk_size > 0:\n            n_chunk += 1\n        if verbose:\n            chunks = tqdm(range(n_chunk), ascii=True, desc=\"Writing to .h5 file\")\n        else:\n            chunks = range(n_chunk)\n        for i in chunks:\n            traces = recording.get_traces(start_frame=i * chunk_size,\n                                          end_frame=min((i + 1) * chunk_size, num_frames))\n            chunk_frames = traces.shape[1]\n            if dtype is not None:\n                traces = traces.astype(dtype_file)\n            if time_axis == 0:\n                dset[chunk_start:chunk_start + chunk_frames] = traces.T\n            else:\n                dset[:, chunk_start:chunk_start + chunk_frames] = traces\n            chunk_start += chunk_frames\n\n    if save_path is not None:\n        file_handle.close()\n    return save_path\n\n\ndef get_sub_extractors_by_property(extractor, property_name, return_property_list=False):\n    \"\"\"Returns a list of SubExtractors from the Extractor based on the given\n    property_name (e.g. 
group)\n\n    Parameters\n    ----------\n    extractor: RecordingExtractor or SortingExtractor\n        The extractor object to access SubRecordingExtractors from.\n    property_name: str\n        The property used to subdivide the extractor\n    return_property_list: bool\n        If True the property list is returned\n\n    Returns\n    -------\n    sub_list: list\n        The list of subextractors to be returned.\n    OR\n    sub_list, prop_list\n        If return_property_list is True, the property list will be returned as well.\n    \"\"\"\n    from spikeextractors import RecordingExtractor, SortingExtractor, SubRecordingExtractor, SubSortingExtractor\n\n    if isinstance(extractor, RecordingExtractor):\n        if property_name not in extractor.get_shared_channel_property_names():\n            raise ValueError(\"'property_name' must be must be a property of the recording channels\")\n        else:\n            sub_list = []\n            recording = extractor\n            properties = np.array([recording.get_channel_property(chan, property_name)\n                                   for chan in recording.get_channel_ids()])\n            prop_list = np.unique(properties)\n            for prop in prop_list:\n                prop_idx = np.where(prop == properties)\n                chan_idx = list(np.array(recording.get_channel_ids())[prop_idx])\n                sub_list.append(SubRecordingExtractor(recording, channel_ids=chan_idx))\n            if return_property_list:\n                return sub_list, prop_list\n            else:\n                return sub_list\n    elif isinstance(extractor, SortingExtractor):\n        if property_name not in extractor.get_shared_unit_property_names():\n            raise ValueError(\"'property_name' must be must be a property of the units\")\n        else:\n            sub_list = []\n            sorting = extractor\n            properties = np.array([sorting.get_unit_property(unit, property_name)\n                               
    for unit in sorting.get_unit_ids()])\n            prop_list = np.unique(properties)\n            for prop in prop_list:\n                prop_idx = np.where(prop == properties)\n                unit_idx = list(np.array(sorting.get_unit_ids())[prop_idx])\n                sub_list.append(SubSortingExtractor(sorting, unit_ids=unit_idx))\n            if return_property_list:\n                return sub_list, prop_list\n            else:\n                return sub_list\n    else:\n        raise ValueError(\"'extractor' must be a RecordingExtractor or a SortingExtractor\")\n\n\ndef _export_prb_file(recording, file_name, grouping_property=None, graph=True, geometry=True,\n                     radius=None, adjacency_distance=100, verbose=False):\n    \"\"\"Exports .prb file\n\n    Parameters\n    ----------\n    recording: RecordingExtractor\n        The recording extractor to save probe file from\n    file_name: str\n        probe filename to be exported to\n    grouping_property: str (default None)\n        If grouping_property is a shared_channel_property, different groups are saved based on the property.\n    graph: bool\n        If True, the adjacency graph is saved (default=True)\n    geometry: bool\n        If True, the geometry is saved (default=True)\n    radius: float (default None)\n        Adjacency radius (used by some sorters). If None it is not saved to the probe file.\n    adjacency_distance: float\n        Distance to consider two channels to adjacent (if 'location' is a property). 
If radius is given,\n        then adjacency_distance is set to the radius.\n    verbose : bool\n        If True, output is verbose\n    \"\"\"\n    file_name = Path(file_name)\n    assert file_name is not None\n    abspath = file_name.absolute()\n\n    if radius is not None:\n        adjacency_distance = radius\n\n    if geometry:\n        if 'location' in recording.get_shared_channel_property_names():\n            positions = recording.get_channel_locations()\n        else:\n            if verbose:\n                print(\"'location' property is not available and it will not be saved.\")\n            positions = None\n            geometry = False\n    else:\n        positions = None\n\n    if grouping_property is not None:\n        if grouping_property in recording.get_shared_channel_property_names():\n            grouping_property_groups = np.array([recording.get_channel_property(chan, grouping_property)\n                                                 for chan in recording.get_channel_ids()])\n            channel_groups = np.unique([grouping_property_groups])\n        else:\n            if verbose:\n                print(f\"{grouping_property} property is not available and it will not be saved.\")\n            channel_groups = [0]\n            grouping_property_groups = np.array([0] * recording.get_num_channels())\n    else:\n        channel_groups = [0]\n        grouping_property_groups = np.array([0] * recording.get_num_channels())\n\n    n_elec = recording.get_num_channels()\n\n    # find adjacency graph\n    if graph:\n        if positions is not None and adjacency_distance is not None:\n            adj_graph = []\n            for chg in channel_groups:\n                group_graph = []\n                elecs = list(np.where(grouping_property_groups == chg)[0])\n                for i in range(len(elecs)):\n                    for j in range(i, len(elecs)):\n                        if elecs[i] != elecs[j]:\n                            if 
np.linalg.norm(positions[elecs[i]] - positions[elecs[j]]) < adjacency_distance:\n                                group_graph.append((elecs[i], elecs[j]))\n                adj_graph.append(group_graph)\n        else:\n            # all connected by group\n            adj_graph = []\n            for chg in channel_groups:\n                group_graph = []\n                elecs = list(np.where(grouping_property_groups == chg)[0])\n                for i in range(len(elecs)):\n                    for j in range(i, len(elecs)):\n                        if elecs[i] != elecs[j]:\n                            group_graph.append((elecs[i], elecs[j]))\n                adj_graph.append(group_graph)\n\n    with abspath.open('w') as f:\n        f.write('total_nb_channels = ' + str(n_elec) + '\\n')\n        if radius is not None:\n            f.write('radius = ' + str(radius) + '\\n')\n        f.write('channel_groups = {\\n')\n        if len(channel_groups) > 0:\n            for i_chg, chg in enumerate(channel_groups):\n                f.write(\"     \" + str(int(chg)) + \": \")\n                elecs = list(np.where(grouping_property_groups == chg)[0])\n                f.write(\"\\n        {\\n\")\n                f.write(\"           'channels': \" + str(elecs) + ',\\n')\n                if graph:\n                    if len(adj_graph) == 1:\n                        f.write(\"           'graph':  \" + str(adj_graph[0]) + ',\\n')\n                    else:\n                        f.write(\"           'graph':  \" + str(adj_graph[i_chg]) + ',\\n')\n                if geometry:\n                    f.write(\"           'geometry':  {\\n\")\n                    for i, pos in enumerate(positions[elecs]):\n                        f.write('               ' + str(elecs[i]) + ': ' + str(list(pos)) + ',\\n')\n                    f.write('           }\\n')\n                f.write('       },\\n')\n            f.write('}\\n')\n        else:\n            for elec in range(n_elec):\n        
        f.write('    ' + str(elec) + ': ')\n                f.write(\"\\n        {\\n\")\n                f.write(\"           'channels': [\" + str(elec) + '],\\n')\n                f.write(\"           'graph':  [],\\n\")\n                f.write('        },\\n')\n            f.write('}\\n')\n\n\ndef _check_json(d):\n    # quick hack to ensure json writable\n    for k, v in d.items():\n        if isinstance(v, Path):\n            d[k] = str(v)\n        elif isinstance(v, (int, np.integer)):\n            d[k] = int(v)\n        elif isinstance(v, float):\n            d[k] = float(v)\n        elif isinstance(v, datetime.datetime):\n            d[k] = v.isoformat()\n\n    return d\n\n\ndef load_extractor_from_json(json_file):\n    \"\"\"\n    Instantiates extractor from json file\n\n    Parameters\n    ----------\n    json_file: str or Path\n        Path to json file\n\n    Returns\n    -------\n    extractor: RecordingExtractor or SortingExtractor\n        The loaded extractor object\n    \"\"\"\n    return BaseExtractor.load_extractor_from_json(json_file)\n\n\ndef load_extractor_from_dict(d):\n    \"\"\"\n    Instantiates extractor from dictionary\n\n    Parameters\n    ----------\n    d: dictionary\n        Python dictionary\n\n    Returns\n    -------\n    extractor: RecordingExtractor or SortingExtractor\n        The loaded extractor object\n    \"\"\"\n    return BaseExtractor.load_extractor_from_dict(d)\n\n\ndef load_extractor_from_pickle(pkl_file):\n    \"\"\"\n    Instantiates extractor from pickle file\n\n    Parameters\n    ----------\n    pkl_file: str or Path\n        Path to pickle file\n\n    Returns\n    -------\n    extractor: RecordingExtractor or SortingExtractor\n        The loaded extractor object\n    \"\"\"\n    return BaseExtractor.load_extractor_from_pickle(pkl_file)\n\n\ndef check_get_unit_spike_train(func):\n    @wraps(func)\n    def check_validity(sorting, unit_id, start_frame=None, end_frame=None):\n        # parse args and kwargs\n       
 if unit_id is None:\n            raise TypeError(\"get_unit_spike_train() missing 1 required positional argument: 'unit_id')\")\n        elif not (isinstance(unit_id, (int, np.integer))):\n            raise ValueError(\"unit_id must be an integer\")\n        elif unit_id not in sorting.get_unit_ids():\n            raise ValueError(f\"{unit_id} is an invalid unit id\")\n        start_frame, end_frame = cast_start_end_frame(start_frame, end_frame)\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.Inf\n        return func(sorting, unit_id, start_frame=start_frame, end_frame=end_frame)\n    return check_validity\n\n\ndef check_get_traces_args(func):\n    @wraps(func)\n    def corrected_args(recording, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True, **kwargs):\n        if channel_ids is not None:\n            if isinstance(channel_ids, (int, np.integer)):\n                channel_ids = list([channel_ids])\n            else:\n                channel_ids = channel_ids\n            if np.any([ch not in recording.get_channel_ids() for ch in channel_ids]):\n                print(\"Removing invalid 'channel_ids'\",\n                      [ch for ch in channel_ids if ch not in recording.get_channel_ids()])\n                channel_ids = [ch for ch in channel_ids if ch in recording.get_channel_ids()]\n        else:\n            channel_ids = recording.get_channel_ids()\n        if start_frame is not None:\n            if start_frame < 0:\n                start_frame = recording.get_num_frames() + start_frame\n        else:\n            start_frame = 0\n        if end_frame is not None:\n            if end_frame > recording.get_num_frames():\n                print(\"'end_frame' set to\", recording.get_num_frames())\n                end_frame = recording.get_num_frames()\n            elif end_frame < 0:\n                end_frame = recording.get_num_frames() + end_frame\n        
else:\n            end_frame = recording.get_num_frames()\n        assert end_frame - start_frame > 0, \"'start_frame' must be less than 'end_frame'!\"\n        start_frame, end_frame = cast_start_end_frame(start_frame, end_frame)\n\n        if not recording.has_unscaled and not return_scaled:\n            warnings.warn(\"The recording extractor does not have unscaled traces. Returning scaled traces\")\n            return_scaled = True\n\n        traces = func(recording, channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame,\n                      return_scaled=return_scaled, **kwargs)\n        # scaling\n        if recording.has_unscaled and return_scaled:\n            channel_idxs = np.array([recording.get_channel_ids().index(ch) for ch in channel_ids])\n            gains = recording.get_channel_gains()[channel_idxs, None]\n            offsets = recording.get_channel_offsets()[channel_idxs, None]\n            traces = (traces.astype(\"float32\") * gains + offsets).astype(\"float32\")\n\n        return traces\n    return corrected_args\n\n\ndef check_get_ttl_args(func):\n    @wraps(func)\n    def corrected_args(recording, start_frame=None, end_frame=None, channel_id=0, **kwargs):\n        if start_frame is not None:\n            if start_frame < 0:\n                start_frame = recording.get_num_frames() + start_frame\n        else:\n            start_frame = 0\n        if end_frame is not None:\n            if end_frame > recording.get_num_frames():\n                print(\"'end_frame' set to\", recording.get_num_frames())\n                end_frame = recording.get_num_frames()\n            elif end_frame < 0:\n                end_frame = recording.get_num_frames() + end_frame\n        else:\n            end_frame = recording.get_num_frames()\n        assert end_frame - start_frame > 0, \"'start_frame' must be less than 'end_frame'!\"\n        assert isinstance(channel_id, (int, np.integer)), \"'channel_id' must be a single int\"\n\n        
start_frame, end_frame = cast_start_end_frame(start_frame, end_frame)\n        # pass recording as arg and rest as kwargs\n        get_ttl_correct_arg = func(recording, start_frame=start_frame, end_frame=end_frame, channel_id=channel_id,\n                                   **kwargs)\n        return get_ttl_correct_arg\n    return corrected_args\n\n\ndef cast_start_end_frame(start_frame, end_frame):\n    if isinstance(start_frame, float):\n        start_frame = int(start_frame)\n    elif isinstance(start_frame, (int, np.integer, type(None))):\n        start_frame = start_frame\n    else:\n        raise ValueError(\"start_frame must be an int, float (not infinity), or None\")\n    if isinstance(end_frame, float) and np.isfinite(end_frame):\n        end_frame = int(end_frame)\n    elif isinstance(end_frame, (int, np.integer, type(None))):\n        end_frame = end_frame\n    # else end_frame is infinity (accepted for get_unit_spike_train)\n    if start_frame is not None:\n        start_frame = int(start_frame)\n    if end_frame is not None and np.isfinite(end_frame):\n        end_frame = int(end_frame)\n    return start_frame, end_frame\n\n\ndef divide_recording_into_time_chunks(num_frames, chunk_size, padding_size):\n    chunks = []\n    ii = 0\n    while ii < num_frames:\n        ii2 = int(min(ii + chunk_size, num_frames))\n        chunks.append(dict(\n            istart=ii,\n            iend=ii2,\n            istart_with_padding=int(max(0, ii - padding_size)),\n            iend_with_padding=int(min(num_frames, ii2 + padding_size))\n        ))\n        ii = ii2\n    return chunks\n\n\ndef _write_dat_one_chunk(i, rec_arg, chunks, rec_memmap, dtype, time_axis, return_scaled, verbose):\n    chunk = chunks[i]\n\n    if verbose:\n        print(f\"Writing chunk {i + 1} / {len(chunks)}\")\n    if isinstance(rec_arg, dict):\n        recording = load_extractor_from_dict(rec_arg)\n    else:\n        recording = rec_arg\n\n    start_frame = chunk['istart']\n    end_frame = 
chunk['iend']\n    traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=return_scaled)\n    if dtype is not None:\n        traces = traces.astype(dtype)\n    if time_axis == 0:\n        traces = traces.T\n        rec_memmap[start_frame:end_frame, :] = traces\n    else:\n        rec_memmap[:, start_frame:end_frame] = traces\n"
  },
  {
    "path": "spikeextractors/extractorlist.py",
    "content": "from .extractors.mdaextractors.mdaextractors import MdaRecordingExtractor, MdaSortingExtractor\nfrom .extractors.mearecextractors.mearecextractors import MEArecRecordingExtractor, MEArecSortingExtractor\nfrom .extractors.biocamrecordingextractor.biocamrecordingextractor import BiocamRecordingExtractor\nfrom .extractors.exdirextractors.exdirextractors import ExdirRecordingExtractor, ExdirSortingExtractor\nfrom .extractors.intanrecordingextractor.intanrecordingextractor import IntanRecordingExtractor\nfrom .extractors.hdsortsortingextractor.hdsortsortingextractor import HDSortSortingExtractor\nfrom .extractors.hs2sortingextractor.hs2sortingextractor import HS2SortingExtractor\nfrom .extractors.klustaextractors.klustaextractors import KlustaSortingExtractor, KlustaRecordingExtractor\nfrom .extractors.kilosortextractors.kilosortextractors import KiloSortSortingExtractor, KiloSortRecordingExtractor\nfrom .extractors.numpyextractors.numpyextractors import NumpyRecordingExtractor, NumpySortingExtractor\nfrom .extractors.nwbextractors.nwbextractors import NwbRecordingExtractor, NwbSortingExtractor\nfrom .extractors.openephysextractors.openephysextractors import OpenEphysRecordingExtractor, \\\n    OpenEphysSortingExtractor, OpenEphysNPIXRecordingExtractor\nfrom .extractors.maxwellextractors import MaxOneRecordingExtractor, MaxOneSortingExtractor, MaxTwoRecordingExtractor, \\\n    MaxTwoSortingExtractor\nfrom .extractors.phyextractors.phyextractors import PhyRecordingExtractor, PhySortingExtractor\nfrom .extractors.bindatrecordingextractor.bindatrecordingextractor import BinDatRecordingExtractor\nfrom .extractors.spykingcircusextractors.spykingcircusextractors import SpykingCircusSortingExtractor, \\\n    SpykingCircusRecordingExtractor\nfrom .extractors.spikeglxrecordingextractor.spikeglxrecordingextractor import SpikeGLXRecordingExtractor\nfrom .extractors.tridescloussortingextractor.tridescloussortingextractor import TridesclousSortingExtractor\nfrom .extractors.npzsortingextractor.npzsortingextractor import NpzSortingExtractor\nfrom .extractors.mcsh5recordingextractor.mcsh5recordingextractor import MCSH5RecordingExtractor\nfrom .extractors.shybridextractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor\nfrom .extractors.nixioextractors.nixioextractors import NIXIORecordingExtractor, NIXIOSortingExtractor\nfrom .extractors.neoextractors import (AxonaRecordingExtractor, PlexonRecordingExtractor, PlexonSortingExtractor,\n                                       NeuralynxRecordingExtractor, NeuralynxSortingExtractor,\n                                       BlackrockRecordingExtractor, BlackrockSortingExtractor,\n                                       MCSRawRecordingExtractor, SpikeGadgetsRecordingExtractor)\nfrom .extractors.neuroscopeextractors import NeuroscopeRecordingExtractor, NeuroscopeMultiRecordingTimeExtractor, \\\n    NeuroscopeSortingExtractor, NeuroscopeMultiSortingExtractor\nfrom .extractors.waveclussortingextractor import WaveClusSortingExtractor\nfrom .extractors.yassextractors import YassSortingExtractor\nfrom .extractors.combinatosortingextractor import CombinatoSortingExtractor\nfrom .extractors.alfsortingextractor import ALFSortingExtractor\nfrom .extractors.cedextractors import CEDRecordingExtractor\nfrom .extractors.cellexplorersortingextractor import CellExplorerSortingExtractor\nfrom .extractors.neuropixelsdatrecordingextractor import NeuropixelsDatRecordingExtractor\nfrom .extractors.axonaunitrecordingextractor import AxonaUnitRecordingExtractor\n\n# All recording extractor classes, whether or not their dependencies are installed\nrecording_extractor_full_list = [\n    MdaRecordingExtractor,\n    MEArecRecordingExtractor,\n    BiocamRecordingExtractor,\n    ExdirRecordingExtractor,\n    OpenEphysRecordingExtractor,\n    OpenEphysNPIXRecordingExtractor,\n    IntanRecordingExtractor,\n    BinDatRecordingExtractor,\n    KlustaRecordingExtractor,\n    KiloSortRecordingExtractor,\n    SpykingCircusRecordingExtractor,\n    SpikeGLXRecordingExtractor,\n    PhyRecordingExtractor,\n    MaxOneRecordingExtractor,\n    MaxTwoRecordingExtractor,\n    MCSH5RecordingExtractor,\n    SHYBRIDRecordingExtractor,\n    NIXIORecordingExtractor,\n    NwbRecordingExtractor,\n    NeuroscopeRecordingExtractor,\n    NeuroscopeMultiRecordingTimeExtractor,\n    CEDRecordingExtractor,\n    NeuropixelsDatRecordingExtractor,\n    AxonaUnitRecordingExtractor,\n\n    # neo based\n    AxonaRecordingExtractor,\n    PlexonRecordingExtractor,\n    NeuralynxRecordingExtractor,\n    BlackrockRecordingExtractor,\n    MCSRawRecordingExtractor,\n    SpikeGadgetsRecordingExtractor,\n]\n\n# extractor_name -> class lookup, and the subset whose dependencies are installed\nrecording_extractor_dict = {recording_class.extractor_name: recording_class\n                            for recording_class in recording_extractor_full_list}\ninstalled_recording_extractor_list = [rx for rx in recording_extractor_full_list if rx.installed]\n\n# All sorting extractor classes, whether or not their dependencies are installed\nsorting_extractor_full_list = [\n    MdaSortingExtractor,\n    MEArecSortingExtractor,\n    ExdirSortingExtractor,\n    HDSortSortingExtractor,\n    HS2SortingExtractor,\n    KlustaSortingExtractor,\n    KiloSortSortingExtractor,\n    OpenEphysSortingExtractor,\n    PhySortingExtractor,\n    SpykingCircusSortingExtractor,\n    TridesclousSortingExtractor,\n    MaxTwoSortingExtractor,\n    MaxOneSortingExtractor,\n    NpzSortingExtractor,\n    SHYBRIDSortingExtractor,\n    NIXIOSortingExtractor,\n    NeuroscopeSortingExtractor,\n    NeuroscopeMultiSortingExtractor,\n    NwbSortingExtractor,\n    WaveClusSortingExtractor,\n    YassSortingExtractor,\n    CombinatoSortingExtractor,\n    ALFSortingExtractor,\n    CellExplorerSortingExtractor,\n\n    # neo based\n    PlexonSortingExtractor,\n    NeuralynxSortingExtractor,\n    BlackrockSortingExtractor,\n]\n\ninstalled_sorting_extractor_list = [sx for sx in sorting_extractor_full_list if sx.installed]\nsorting_extractor_dict = {sorting_class.extractor_name: sorting_class for sorting_class in sorting_extractor_full_list}\n\n# installed sorting extractors flagged as writable (is_writable)\nwritable_sorting_extractor_list = [sx for sx in installed_sorting_extractor_list if sx.is_writable]\nwritable_sorting_extractor_dict = {sorting_class.extractor_name: sorting_class\n                                   for sorting_class in writable_sorting_extractor_list}\n"
  },
  {
    "path": "spikeextractors/extractors/__init__.py",
    "content": ""
  },
  {
    "path": "spikeextractors/extractors/alfsortingextractor/__init__.py",
    "content": "from .alfsortingextractor import ALFSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/alfsortingextractor/alfsortingextractor.py",
    "content": "from abc import ABC\n\nfrom spikeextractors import SortingExtractor\nfrom pathlib import Path\nimport numpy as np\n\ntry:\n    import pandas as pd\n\n    HAVE_PANDAS = True\nexcept ImportError:\n    HAVE_PANDAS = False\n\n\nclass ALFSortingExtractor(SortingExtractor):\n    \"\"\"\n    SortingExtractor for a folder of ALF datasets.\n\n    The folder name must contain 'probe'. The 'spikes.times' (spike times in seconds) and 'spikes.clusters'\n    (unit index of each spike) .npy datasets are required. 'clusters.<property>' .npy datasets and a\n    'clusters.metrics' table (read with pandas.read_csv), when present, are loaded as unit properties.\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the ALF probe folder\n    sampling_frequency: float\n        Sampling frequency used to convert spike times in seconds to frames (default 30000)\n    \"\"\"\n    extractor_name = 'ALFSorting'\n    installed = HAVE_PANDAS  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"To use the ALFSortingExtractor run:\\n\\n pip install pandas\\n\\n\"\n\n    def __init__(self, folder_path, sampling_frequency=30000):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        # check correct parent folder:\n        self.file_loc = Path(folder_path)\n        if 'probe' not in Path(self.file_loc).name:\n            raise ValueError('folder name should contain \"probe\", containing channels, clusters.* .npy datasets')\n        # load datasets as mmap into a dict:\n        self._required_alf_datasets = ['spikes.times', 'spikes.clusters']\n        self._found_alf_datasets = dict()\n        for alf_dataset_name in self.file_loc.iterdir():\n            if 'spikes' in alf_dataset_name.stem or 'clusters' in alf_dataset_name.stem:\n                if 'npy' in alf_dataset_name.suffix:\n                    self._found_alf_datasets.update({alf_dataset_name.stem: self._load_npy(alf_dataset_name)})\n                elif 'metrics' in alf_dataset_name.stem:\n                    self._found_alf_datasets.update({alf_dataset_name.stem: pd.read_csv(alf_dataset_name)})\n        # check existence of datasets: every required dataset must be present\n        missing_alf_datasets = [i for i in self._required_alf_datasets if i not in self._found_alf_datasets]\n        if len(missing_alf_datasets) > 0:\n            raise Exception(f'could not find {missing_alf_datasets} in folder')\n        # per-unit datasets are those whose ALF object (the text before the first '.') is 'clusters';\n        # 'spikes.clusters' also contains 'clusters' but holds one value per spike, not per unit\n        cluster_dataset_names = [name for name in self._found_alf_datasets\n                                 if name.split('.')[0].endswith('clusters')]\n        # get_unit_ids() depends on the number of units, so it must be known before any property is set:\n        # take it from the first non-empty clusters.<property> .npy dataset\n        self._total_units = 0\n        for alf_dataset_name in cluster_dataset_names:\n            if 'clusters.metrics' not in alf_dataset_name:\n                self._total_units = self._found_alf_datasets[alf_dataset_name].shape[0]\n                if self._total_units > 0:\n                    break\n        # setting units properties:\n        for alf_dataset_name in cluster_dataset_names:\n            alf_dataset = self._found_alf_datasets[alf_dataset_name]\n            if 'clusters.metrics' in alf_dataset_name:\n                # one unit property per column of the metrics table\n                for property_name, property_values in alf_dataset.items():\n                    self.set_units_property(unit_ids=self.get_unit_ids(),\n                                            property_name=property_name,\n                                            values=property_values.tolist())\n            else:\n                self.set_units_property(unit_ids=self.get_unit_ids(),\n                                        property_name=alf_dataset_name.split('.')[1],\n                                        values=alf_dataset)\n        self._units_map = {i: j for i, j in zip(self.get_unit_ids(), list(range(self._total_units)))}\n        self._units_raster = []\n        self._sampling_frequency = sampling_frequency\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'sampling_frequency': sampling_frequency}\n\n    def _load_npy(self, npy_path):\n        # memory-map the array rather than reading it fully into memory\n        return np.load(npy_path, mmap_mode='r', allow_pickle=True)\n\n    def _get_clusters_spike_times(self, cluster_idx):\n        # on first use, group all spike times by unit index and cache them in self._units_raster\n        if len(self._units_raster) == 0:\n            spike_cluster_data = self._found_alf_datasets['spikes.clusters']\n            spike_times_data = self._found_alf_datasets['spikes.times']\n            df = pd.DataFrame({'sp_cluster': spike_cluster_data, 'sp_times': spike_times_data})\n            data = df.groupby(['sp_cluster'])['sp_times'].apply(np.array).reset_index(name='sp_times_group')\n            self._max_time = 0\n            self._units_raster = [None]*self._total_units\n            for index, sp_times_list in data.values:\n                self._units_raster[index] = sp_times_list\n                max_time = max(sp_times_list)\n                if max_time > self._max_time:\n                    self._max_time = max_time\n        return self._units_raster[cluster_idx]\n\n    def get_unit_ids(self):\n        # use the 'cluster_id' column of clusters.metrics when available, otherwise 0..total_units-1\n        if 'clusters.metrics' in self._found_alf_datasets and \\\n                self._found_alf_datasets['clusters.metrics'].get('cluster_id') is not None:\n            unit_ids = self._found_alf_datasets['clusters.metrics'].get('cluster_id').tolist()\n        else:\n            unit_ids = list(range(self._total_units))\n        return unit_ids\n\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        \"\"\"Code to extract spike frames from the specified unit.\n        It will return spike frames from within four ranges:\n            [start_frame, start_frame+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_unit_spike_frame]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_unit_spike_frame]\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Spike frames are returned in the form of an\n        array_like of spike frames. In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n        Raises ValueError for an unknown unit_id or for a frame range that does not\n        overlap the unit's spikes.\n        \"\"\"\n        unit_idx = self._units_map.get(unit_id)\n        if unit_idx is None:\n            raise ValueError(f'enter one of unit_id={self.get_unit_ids()}')\n        cluster_sp_times = self._get_clusters_spike_times(unit_idx)\n        if cluster_sp_times is None:\n            return np.array([])\n        max_frame = np.ceil(cluster_sp_times[-1]*self.get_sampling_frequency()).astype('int64')\n        min_frame = np.floor(cluster_sp_times[0]*self.get_sampling_frequency()).astype('int64')\n        start_frame = min_frame if start_frame is None or start_frame < min_frame else start_frame\n        # end_frame is exclusive: the default must lie one past the last spike frame, otherwise the last\n        # spike is dropped whenever its time times the sampling frequency is an exact integer\n        end_frame = max_frame + 1 if end_frame is None or end_frame > max_frame else end_frame\n        if start_frame > max_frame or end_frame < min_frame:\n            raise ValueError(f'Use start_frame to end_frame between {min_frame} and {max_frame}')\n        cluster_sp_frames = (cluster_sp_times * self.get_sampling_frequency()).astype('int64')\n        frame_idx = np.where((cluster_sp_frames >= start_frame) &\n                             (cluster_sp_frames < end_frame))\n        return cluster_sp_frames[frame_idx]\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        \"\"\"\n        Write a SortingExtractor to a folder in the ALF layout read by ALFSortingExtractor.\n\n        Unit properties whose names are in the metrics list below are written as columns of\n        clusters.metrics.csv; every other unit property is written as clusters.<property_name>.npy.\n        Spike times (in seconds) and the index of the unit of each spike are written, sorted by time,\n        as spikes.times.npy and spikes.clusters.npy.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            The sorting extractor to write\n        save_path: str or Path\n            Destination folder (created if it does not exist)\n        \"\"\"\n        assert HAVE_PANDAS, ALFSortingExtractor.installation_mesg\n        save_path = Path(save_path)\n        save_path.mkdir(parents=True, exist_ok=True)\n        csv_property_names = ['cluster_id', 'cluster_id.1', 'num_spikes', 'firing_rate',\n                              'presence_ratio', 'presence_ratio_std', 'frac_isi_viol',\n                              'contamination_est', 'contamination_est2', 'missed_spikes_est',\n                              'cum_amp_drift', 'max_amp_drift', 'cum_depth_drift', 'max_depth_drift',\n                              'ks2_contamination_pct', 'ks2_label', 'amplitude_cutoff', 'amplitude_std',\n                              'epoch_name', 'isi_viol']\n        unit_ids = sorting.get_unit_ids()\n        # property names are read from the first unit: a unit with id 0 does not necessarily exist\n        property_names = sorting.get_unit_property_names(unit_ids[0]) if len(unit_ids) > 0 else []\n        # write cluster properties as clusters.<property_name>.npy or as columns of clusters.metrics.csv\n        clusters_metrics_df = pd.DataFrame()\n        for property_name in property_names:\n            data = sorting.get_units_property(property_name=property_name)\n            if property_name not in csv_property_names:\n                np.save(save_path/f'clusters.{property_name}', data)\n            else:\n                clusters_metrics_df[property_name] = data\n        clusters_metrics_df.to_csv(save_path/'clusters.metrics.csv')\n        # save spikes.times, spikes.clusters\n        clusters_number = []\n        unit_spike_times = []\n        for unit_no, unit_id in enumerate(unit_ids):\n            unit_spike_train = sorting.get_unit_spike_train(unit_id=unit_id)\n            if unit_spike_train is not None:\n                unit_spike_times.extend(np.array(unit_spike_train)/sorting.get_sampling_frequency())\n                clusters_number.extend([unit_no]*len(unit_spike_train))\n        unit_spike_times = np.array(unit_spike_times)\n        clusters_number = np.array(clusters_number)\n        spike_times_ids = np.argsort(unit_spike_times)\n        spike_times = unit_spike_times[spike_times_ids]\n        spike_clusters = clusters_number[spike_times_ids]\n        np.save(save_path/'spikes.times', spike_times)\n        np.save(save_path/'spikes.clusters', spike_clusters)\n"
  },
  {
    "path": "spikeextractors/extractors/axonaunitrecordingextractor/__init__.py",
    "content": "from .axonaunitrecordingextractor import AxonaUnitRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/axonaunitrecordingextractor/axonaunitrecordingextractor.py",
    "content": "from spikeextractors.extraction_tools import check_get_traces_args\nfrom spikeextractors.extractors.neoextractors.neobaseextractor import (\n    _NeoBaseExtractor, NeoBaseRecordingExtractor)\nfrom spikeextractors import RecordingExtractor\nfrom pathlib import Path\nimport numpy as np\nfrom typing import Union\nimport warnings\n\nPathType = Union[Path, str]\n\ntry:\n    import neo\n    from neo.rawio.baserawio import _signal_channel_dtype, _signal_stream_dtype\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\n\nclass AxonaUnitRecordingExtractor(NeoBaseRecordingExtractor, RecordingExtractor, _NeoBaseExtractor):\n    \"\"\"\n    Instantiates a RecordingExtractor from an Axona Unit mode file.\n\n    Since the unit mode format only saves waveform cutouts, the get_traces\n    function fills in the rest of the recording with Gaussian uncorrelated\n    noise\n\n    Parameters\n    ----------\n\n    noise_std: float\n        Standard deviation of the Gaussian background noise (default 3)\n    block_index: int or None\n        Forwarded to _NeoBaseExtractor\n    seg_index: int or None\n        Forwarded to _NeoBaseExtractor\n    **kargs:\n        Additional keyword arguments forwarded to _NeoBaseExtractor\n    \"\"\"\n    extractor_name = 'AxonaUnitRecording'\n    mode = 'file'\n    NeoRawIOClass = 'AxonaRawIO'\n\n    def __init__(self, noise_std: float = 3, block_index=None, seg_index=None, **kargs):\n        RecordingExtractor.__init__(self)\n        _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs)\n\n        # Enforce 1 signal stream (there are 0 raw streams), we will create 1 from waveforms\n        signal_streams = self.neo_reader._get_signal_streams_header()\n        signal_channels = self.neo_reader._get_signal_chan_header()\n        self.neo_reader.header['signal_streams'] = np.array(signal_streams,\n                                                            dtype=_signal_stream_dtype)\n        self.neo_reader.header['signal_channels'] = np.array(signal_channels,\n                                                             dtype=_signal_channel_dtype)\n\n        if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'):\n            # Neo >= 0.9.0\n            channel_indexes_list = self.neo_reader.get_group_signal_channel_indexes()\n            num_streams = len(channel_indexes_list)\n            assert num_streams <= 1, 'This file have several channel groups spikeextractors support only one groups'\n            self.after_v10 = False\n        elif hasattr(self.neo_reader, 'get_group_channel_indexes'):\n            # Neo < 0.9.0\n            channel_indexes_list = self.neo_reader.get_group_channel_indexes()\n            num_streams = len(channel_indexes_list)\n            self.after_v10 = False\n        elif hasattr(self.neo_reader, 'signal_streams_count'):\n            # Neo >= 0.10.0 (not release yet in march 2021)\n            num_streams = self.neo_reader.signal_streams_count()\n            self.after_v10 = True\n        else:\n            raise ValueError('Strange neo version. Please upgrade your neo package: pip install --upgrade neo')\n\n        assert num_streams <= 1, ('This file have several signal streams spikeextractors support only one streams. '\n                                  'Maybe you can use option to select only one stream')\n\n        # spikeextractor for units to be uV implicitly\n        # check that units are V, mV or uV\n        units = self.neo_reader.header['signal_channels']['units']\n        assert np.all(np.isin(units, ['V', 'mV', 'uV'])), 'Signal units no Volt compatible'\n        self.additional_gain = np.ones(units.size, dtype='float')\n        self.additional_gain[units == 'V'] = 1e6\n        self.additional_gain[units == 'mV'] = 1e3\n        self.additional_gain[units == 'uV'] = 1.\n        self.additional_gain = self.additional_gain.reshape(1, -1)\n\n        # Add channels properties\n        header_channels = self.neo_reader.header['signal_channels'][slice(None)]\n        self._neo_chan_ids = self.neo_reader.header['signal_channels']['id']\n\n        # In neo there is no guarantee that channel ids are unique\n        # (for instance Blackrock can have the same chan_id several times with different sampling rates),\n        # so check it\n        assert np.unique(self._neo_chan_ids).size == self._neo_chan_ids.size, \\\n            'In this format channel ids are not unique! Incompatible with SpikeInterface'\n\n        try:\n            channel_ids = [int(ch) for ch in self._neo_chan_ids]\n        except Exception:\n            warnings.warn(\"Could not parse channel ids to int: using linear channel map\")\n            channel_ids = list(np.arange(len(self._neo_chan_ids)))\n        self._channel_ids = channel_ids\n\n        names = header_channels['name']\n        for i, ind in enumerate(self._channel_ids):\n            self.set_channel_property(channel_id=ind, property_name='name', value=names[i])\n\n        self._noise_std = noise_std\n\n        # Read channel groups by tetrode IDs (4 channels per active tetrode, tetrode IDs are 1-indexed)\n        self.set_channel_groups(groups=[\n            tetrode_id - 1 for tetrode_id in self.neo_reader.get_active_tetrode() for _ in range(4)])\n\n        # Set channel gains for int8 .X Unit data (the gains of the signal channel header are not used here)\n        gains = self.neo_reader._get_channel_gain(bytes_per_sample=1)[0:len(self._channel_ids)]\n        self.set_channel_gains(gains, channel_ids=self._channel_ids)\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        \"\"\"\n        Return Gaussian noise (std noise_std) for the requested channels and frames, with the recorded\n        spike waveform cutouts written in at each spike time.\n        \"\"\"\n        timebase_sr = int(self.neo_reader.file_parameters['unit']['timebase'].split(' ')[0])\n        samples_pre = int(self.neo_reader.file_parameters['set']['file_header']['pretrigSamps'])\n        samples_post = int(self.neo_reader.file_parameters['set']['file_header']['spikeLockout'])\n        sampling_rate = self.get_sampling_frequency()\n\n        tcmap = self._get_tetrode_channel_table(channel_ids)\n\n        traces = self._noise_std * np.random.randn(len(channel_ids), end_frame - start_frame)\n        if return_scaled:\n            traces = traces.astype(np.float32)\n        else:\n            traces = traces.astype(np.int8)\n\n        # Loop through tetrodes and include requested channels in traces\n        itrc = 0\n        for tetrode_id in np.unique(tcmap[:, 0]):\n\n            channels_oi = tcmap[tcmap[:, 0] == tetrode_id, 2]\n\n            waveforms = self.neo_reader._get_spike_raw_waveforms(\n                block_index=0, seg_index=0,\n                unit_index=tetrode_id - 1,  # Tetrodes IDs are 1-indexed\n                t_start=start_frame / sampling_rate,\n                t_stop=end_frame / sampling_rate\n            )\n            waveforms = waveforms[:, channels_oi, :]\n            nch = len(channels_oi)\n\n            spike_train = self.neo_reader._get_spike_timestamps(\n                block_index=0, seg_index=0,\n                unit_index=tetrode_id - 1,\n                t_start=start_frame / sampling_rate,\n                t_stop=end_frame / sampling_rate\n            )\n\n            # Fill waveforms into traces timestamp by timestamp, clipping cutouts at both edges\n            for t, wf in zip(spike_train, waveforms):\n\n                t = int(t // (timebase_sr / sampling_rate))  # timestamps are sampled at higher frequency\n                t = t - start_frame\n                if (t - samples_pre < 0) and (t + samples_post > traces.shape[1]):\n                    traces[itrc:itrc + nch, :] = wf[:, samples_pre - t:traces.shape[1] - (t - samples_pre)]\n                elif t - samples_pre < 0:\n                    traces[itrc:itrc + nch, :t + samples_post] = wf[:, samples_pre - t:]\n                elif t + samples_post > traces.shape[1]:\n                    traces[itrc:itrc + nch, t - samples_pre:] = wf[:, :traces.shape[1] - (t - samples_pre)]\n                else:\n                    traces[itrc:itrc + nch, t - samples_pre:t + samples_post] = wf\n\n            itrc += nch\n\n        return traces\n\n    def get_num_frames(self):\n        n = int(self.neo_reader.segment_t_stop(block_index=0, seg_index=0) * self.get_sampling_frequency())\n        # NOTE(review): halved for 24 kHz files; the reason is not visible here, confirm\n        if self.get_sampling_frequency() == 24000:\n            n = n // 2\n        return n\n\n    def get_sampling_frequency(self):\n        return int(self.neo_reader.header['spike_channels'][0][-1])\n\n    def get_channel_ids(self):\n        return self._channel_ids\n\n    def _get_tetrode_channel_table(self, channel_ids):\n        '''Create auxiliary np.array with the following columns:\n        Tetrode ID, Channel ID, Channel ID within tetrode\n        This is useful in `get_traces()`\n\n        Parameters\n        ----------\n        channel_ids : list\n            List of channel ids to include in table\n\n        Returns\n        -------\n        np.array\n            Rows = channels,\n            columns = TetrodeID, ChannelID, ChannelID within Tetrode\n        '''\n        active_tetrodes = self.neo_reader.get_active_tetrode()\n\n        tcmap = np.zeros((len(active_tetrodes) * 4, 3), dtype=int)\n        row_id = 0\n        for tetrode_id in [int(s[0].split(' ')[1]) for s in self.neo_reader.header['spike_channels']]:\n\n            all_channel_ids = self.neo_reader._get_channel_from_tetrode(tetrode_id)\n\n            for i in range(4):\n                tcmap[row_id, 0] = int(tetrode_id)\n                tcmap[row_id, 1] = int(all_channel_ids[i])\n                tcmap[row_id, 2] = int(i)\n                row_id += 1\n\n        # keep only the rows of the requested channels (boolean mask; np.delete with a boolean list is\n        # interpreted as integer indices by older NumPy versions)\n        return tcmap[np.isin(tcmap[:, 1], channel_ids)]\n"
  },
  {
    "path": "spikeextractors/extractors/bindatrecordingextractor/__init__.py",
    "content": "from .bindatrecordingextractor import BinDatRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/bindatrecordingextractor/bindatrecordingextractor.py",
    "content": "import shutil\nimport numpy as np\nfrom pathlib import Path\nfrom typing import Union, Optional\n\nfrom spikeextractors import RecordingExtractor\nfrom spikeextractors.extraction_tools import read_binary, write_to_binary_dat_format, check_get_traces_args\n\nPathType = Union[str, Path]\nDtypeType = Union[str, np.dtype]\nArrayType = Union[list, np.ndarray]\nOptionalDtypeType = Optional[DtypeType]\nOptionalArrayType = Optional[Union[np.ndarray, list]]\n\n\nclass BinDatRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    RecordingExtractor for a binary format\n\n    Parameters\n    ----------\n    file_path: str or Path\n        Path to the binary file\n    sampling_frequency: float\n        The sampling frequency\n    numchan: int\n        Number of channels\n    dtype: str or dtype\n        The dtype of the binary file\n    time_axis: int\n        The axis of the time dimension (default 0: F order)\n    recording_channels: list (optional)\n        A list of channel ids\n    geom: array-like (optional)\n        A list or array with channel locations\n    file_offset: int (optional)\n        Number of bytes in the file to offset by during memmap instantiation.\n    gain: float or array-like (optional)\n        The gain to apply to the traces\n    channel_offset: float or array-like\n        The offset to apply to the traces\n    is_filtered: bool\n        If True, the recording is assumed to be filtered\n    \"\"\"\n    extractor_name = 'BinDatRecording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = True\n    is_writable = True\n    mode = \"file\"\n    installation_mesg = \"\"\n\n    def __init__(self, file_path: PathType, sampling_frequency: float, numchan: int, dtype: DtypeType,\n                 time_axis: int = 0, recording_channels: Optional[list] = None, geom: Optional[ArrayType] = None,\n                 file_offset: Optional[int] = 0,\n                 gain: Optional[Union[float, ArrayType]] = None,\n                 channel_offset: Optional[Union[float, ArrayType]] = None,\n                 is_filtered: Optional[bool] = None):\n        RecordingExtractor.__init__(self)\n        self._datfile = Path(file_path)\n        self._time_axis = time_axis\n        self._dtype = np.dtype(dtype).name\n        self._sampling_frequency = float(sampling_frequency)\n        self._numchan = numchan\n        self._geom = geom\n        self._timeseries = read_binary(self._datfile, numchan, dtype, time_axis, file_offset)\n\n        if is_filtered is not None:\n            self.is_filtered = is_filtered\n        else:\n            self.is_filtered = False\n\n        if recording_channels is not None:\n            assert len(recording_channels) <= self._timeseries.shape[0], \\\n                'Provided recording channels have the wrong length'\n            self._channels = recording_channels\n        else:\n            self._channels = list(range(self._timeseries.shape[0]))\n\n        # complete: one id per channel in the file (ids may be arbitrary); otherwise the ids are channel indexes\n        if len(self._channels) == self._timeseries.shape[0]:\n            self._complete_channels = True\n        else:\n            assert max(self._channels) < self._timeseries.shape[0], \\\n                \"Channel ids exceed the number of available channels\"\n            self._complete_channels = False\n\n        if geom is not None:\n            self.set_channel_locations(self._geom)\n            self.has_default_locations = True\n\n        # string form of the dtype stored in _kwargs (e.g. numpy.int16 -> 'int16')\n        if 'numpy' in str(dtype):\n            dtype_str = str(dtype).replace(\"<class '\", \"\").replace(\"'>\", \"\")\n            dtype_str = dtype_str.split('.')[1]\n        else:\n            dtype_str = str(dtype)\n\n        if gain is not None:\n            self.set_channel_gains(channel_ids=self.get_channel_ids(), gains=gain)\n            self.has_unscaled = True\n\n        if channel_offset is not None:\n            self.set_channel_offsets(channel_offset)\n\n        # _kwargs mirrors every constructor argument (channel_offset was previously missing)\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'sampling_frequency': sampling_frequency,\n                        'numchan': numchan, 'dtype': dtype_str, 'recording_channels': recording_channels,\n                        'time_axis': time_axis, 'geom': geom, 'file_offset': file_offset, 'gain': gain,\n                        'channel_offset': channel_offset, 'is_filtered': is_filtered}\n\n    def get_channel_ids(self):\n        return self._channels\n\n    def get_num_frames(self):\n        return self._timeseries.shape[1]\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        if self._complete_channels:\n            if np.array_equal(channel_ids, self.get_channel_ids()):\n                traces = self._timeseries[:, start_frame:end_frame]\n            else:\n                # list() so that recording_channels given as a numpy array also supports .index()\n                all_channel_ids = list(self.get_channel_ids())\n                channel_idxs = np.array([all_channel_ids.index(ch) for ch in channel_ids])\n                if np.all(np.diff(channel_idxs) == 1):\n                    # contiguous channels: a basic slice keeps the memmap instead of copying\n                    traces = self._timeseries[channel_idxs[0]:channel_idxs[0]+len(channel_idxs),\n                                              start_frame:end_frame]\n                else:\n                    # This block of the execution will return the data as an array, not a memmap\n                    traces = self._timeseries[channel_idxs, start_frame:end_frame]\n        else:\n            # in this case channel ids are actually indexes\n            traces = self._timeseries[channel_ids, start_frame:end_frame]\n        return traces\n\n    @staticmethod\n    def write_recording(\n        recording: RecordingExtractor,\n        save_path: PathType,\n        time_axis: int = 0,\n        dtype: OptionalDtypeType = None,\n        **write_binary_kwargs\n    ):\n        \"\"\"\n        Save the traces of a recording extractor in binary .dat format.\n\n        Parameters\n        ----------\n        recording : RecordingExtractor\n            The recording extractor object to be saved in .dat format.\n        save_path : str\n            The path to the file.\n        time_axis : int, optional\n            If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n            If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n        dtype : dtype\n            Type of the saved data. If None (default), write_to_binary_dat_format() decides the dtype.\n        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format() function\n        \"\"\"\n        write_to_binary_dat_format(recording, save_path, time_axis=time_axis, dtype=dtype,\n                                   **write_binary_kwargs)\n"
  },
  {
    "path": "spikeextractors/extractors/biocamrecordingextractor/__init__.py",
    "content": "from .biocamrecordingextractor import BiocamRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/biocamrecordingextractor/biocamrecordingextractor.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args\nimport numpy as np\nfrom pathlib import Path\nimport ctypes\n\ntry:\n    import h5py\n    HAVE_BIOCAM = True\nexcept ImportError:\n    HAVE_BIOCAM = False\n\n\nclass BiocamRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    RecordingExtractor for 3Brain Biocam hdf5 files.\n\n    Parameters\n    ----------\n    file_path: str or Path\n        Path to the Biocam hdf5 file\n    verbose: bool\n        If True, print information about the file while opening it\n    mea_pitch: float\n        Electrode pitch used to turn the row/column of each channel into a location (default 42)\n    \"\"\"\n    extractor_name = 'BiocamRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = HAVE_BIOCAM  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the BiocamRecordingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path, verbose=False, mea_pitch=42):\n        assert self.installed, self.installation_mesg\n        self._mea_pitch = mea_pitch\n        self._recording_file = file_path\n        (self._rf, self._nFrames, self._samplingRate, self._nRecCh, self._chIndices,\n         self._file_format, self._signalInv, self._positions, self._read_function) = openBiocamFile(\n            self._recording_file, self._mea_pitch, verbose)\n        RecordingExtractor.__init__(self)\n        self.set_channel_locations(self._positions)\n\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'mea_pitch': mea_pitch,\n                        'verbose': verbose}\n\n    def __del__(self):\n        # _rf only exists if openBiocamFile succeeded in __init__\n        rf = getattr(self, '_rf', None)\n        if rf is not None:\n            rf.close()\n\n    def get_channel_ids(self):\n        return list(range(self._nRecCh))\n\n    def get_num_frames(self):\n        return self._nFrames\n\n    def get_sampling_frequency(self):\n        return self._samplingRate\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        # the read function returns (frames, channels) for all channels; select and transpose afterwards\n        data = self._read_function(self._rf, start_frame, end_frame, self.get_num_channels())\n        # transform to slice if possible\n        if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1):\n            channel_ids = slice(channel_ids[0], channel_ids[0]+len(channel_ids))\n        return data[:, channel_ids].T\n\n    @staticmethod\n    def write_recording(recording, save_path):\n        # Convert to uV:\n        # AnalogValue = MVOffset + DigitalValue * ADCCountsToMV\n        # Where ADCCountsToMV is defined as:\n        # ADCCountsToMV = SignalInversion * ((MaxVolt - MinVolt) / 2^BitDepth)\n        # And MVOffset as:\n        # MVOffset = SignalInversion * MinVolt\n        # conversion back\n        # DigitalValue = (AnalogValue - MVOffset)/ADCCountsToMV\n        # we center at 2048\n\n        assert HAVE_BIOCAM, BiocamRecordingExtractor.installation_mesg\n        M = recording.get_num_channels()\n        N = recording.get_num_frames()\n        # use the recording's own channel ids: they are not necessarily 0..M-1\n        channel_ids = recording.get_channel_ids()\n        # the context manager closes the file even if writing fails\n        with h5py.File(save_path, 'w') as rf:\n            g = rf.create_group('3BData')\n            # raw data is one flat array written frame by frame, as read back by readHDF5t_101\n            dr = rf.create_dataset('3BData/Raw', (M * N,), dtype=int)\n            dt = 50000\n            for i in range(N // dt):\n                dr[M * i * dt:M * (i + 1) * dt] = recording.get_traces(channel_ids, i * dt, (i + 1) * dt).T.flatten()\n            if N % dt > 0:\n                dr[M * (N // dt) * dt:] = recording.get_traces(channel_ids, (N // dt) * dt, N).T.flatten()\n            g.attrs['Version'] = 101\n            rf.create_dataset('3BRecInfo/3BRecVars/MinVolt', data=[0])\n            rf.create_dataset('3BRecInfo/3BRecVars/MaxVolt', data=[1])\n            rf.create_dataset('3BRecInfo/3BRecVars/NRecFrames', data=[N])\n            rf.create_dataset('3BRecInfo/3BRecVars/SamplingRate', data=[recording.get_sampling_frequency()])\n            rf.create_dataset('3BRecInfo/3BRecVars/SignalInversion', data=[1])\n            rf.create_dataset('3BRecInfo/3BMeaChip/NCols', data=[M])\n            r = recording.get_channel_locations()[:, 0]\n            c = recording.get_channel_locations()[:, 1]\n            d = np.ndarray((1, len(r)), dtype=[('Row', '<i2'), ('Col', '<i2')])\n            d['Row'] = r\n            d['Col'] = c\n            rf.create_dataset('3BRecInfo/3BMeaStreams/Raw/Chs', data=d)\n\n\ndef openBiocamFile(filename, mea_pitch, verbose=False):\n    \"\"\"Open a Biocam hdf5 file, read and return the recording info, pick the correct method to access raw data, and return this to the caller.\"\"\"\n    rf = h5py.File(filename, 'r')\n    # Read recording variables\n    recVars = rf.require_group('3BRecInfo/3BRecVars/')\n    # bitDepth = recVars['BitDepth'].value[0]\n    # maxV = recVars['MaxVolt'].value[0]\n    # minV = recVars['MinVolt'].value[0]\n    nFrames = recVars['NRecFrames'][0]\n    samplingRate = recVars['SamplingRate'][0]\n    signalInv = recVars['SignalInversion'][0]\n    # Read chip variables\n    chipVars = rf.require_group('3BRecInfo/3BMeaChip/')\n    nCols = chipVars['NCols'][0]\n    # Get the actual number of channels used in the recording\n    file_format = rf['3BData'].attrs.get('Version')\n    if file_format == 100:\n        nRecCh = len(rf['3BData/Raw'][0])\n    elif (file_format == 101) or (file_format == 102):\n        nRecCh = int(1. * rf['3BData/Raw'].shape[0] / nFrames)\n    else:\n        rf.close()\n        raise Exception('Unknown data file format.')\n\n    if verbose:\n        print('# 3Brain data format:', file_format, 'signal inversion', signalInv)\n        print('#       signal range: ', recVars['MinVolt'][0], '- ', recVars['MaxVolt'][0])\n        print('# channels: ', nRecCh)\n        print('# frames: ', nFrames)\n        print('# sampling rate: ', samplingRate)\n    # get channel locations\n    r = (rf['3BRecInfo/3BMeaStreams/Raw/Chs'][()]['Row'] - 1) * mea_pitch\n    c = (rf['3BRecInfo/3BMeaStreams/Raw/Chs'][()]['Col'] - 1) * mea_pitch\n    rawIndices = np.vstack((r, c)).T\n    # assign channel numbers\n    chIndices = np.array([(x - 1) + (y - 1) * nCols for (y, x) in rawIndices])\n    # determine correct function to read data\n    if verbose:\n        print(\"# Signal inversion is \" + str(signalInv) + \".\")\n        print(\"# If your spike sorting results look wrong, invert the signal.\")\n    # a single if/elif chain: previously the 101/102 checks started a second chain whose else branch\n    # rejected every format 100 file\n    if (file_format == 100) & (signalInv == 1):\n        read_function = readHDF5t_100\n    elif (file_format == 100) & (signalInv == -1):\n        read_function = readHDF5t_100_i\n    elif ((file_format == 101) | (file_format == 102)) & (signalInv == 1):\n        read_function = readHDF5t_101\n    elif ((file_format == 101) | (file_format == 102)) & (signalInv == -1):\n        read_function = readHDF5t_101_i\n    else:\n        rf.close()\n        raise RuntimeError(\"File format unknown.\")\n    return rf, nFrames, samplingRate, nRecCh, chIndices, file_format, signalInv, rawIndices, read_function\n\n\ndef readHDF5t_100(rf, t0, t1, nch):\n    # format 100: raw data is already 2D (frames, channels)\n    if t0 > t1:  # Reversed read\n        raise Exception('Reading backwards? Not sure about this.')\n    return rf['3BData/Raw'][t0:t1]\n\n\ndef readHDF5t_100_i(rf, t0, t1, nch):\n    # format 100 with inverted signal\n    if t0 > t1:  # Reversed read\n        raise Exception('Reading backwards? Not sure about this.')\n    return 4096 - rf['3BData/Raw'][t0:t1]\n\n\ndef readHDF5t_101(rf, t0, t1, nch):\n    # formats 101/102: raw data is flat, nch consecutive values per frame\n    if t0 > t1:  # Reversed read\n        raise Exception('Reading backwards? Not sure about this.')\n    return rf['3BData/Raw'][nch * t0:nch * t1].reshape((t1 - t0, nch), order='C')\n\n\ndef readHDF5t_101_i(rf, t0, t1, nch):\n    # formats 101/102 with inverted signal\n    if t0 > t1:  # Reversed read\n        raise Exception('Reading backwards? Not sure about this.')\n    return 4096 - rf['3BData/Raw'][nch * t0:nch * t1].reshape((t1 - t0, nch), order='C')\n"
  },
  {
    "path": "spikeextractors/extractors/cedextractors/__init__.py",
    "content": "from .cedrecordingextractor import CEDRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/cedextractors/cedrecordingextractor.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom .utils import get_channel_info, get_channel_data\nfrom spikeextractors.extraction_tools import check_get_traces_args\n\nimport numpy as np\nfrom pathlib import Path\nfrom typing import Union\nfrom copy import deepcopy\n\ntry:\n    from sonpy import lib as sp\n\n    HAVE_SONPY = True\nexcept ImportError:\n    HAVE_SONPY = False\n\nPathType = Union[str, Path, None]\nDtypeType = Union[str, np.dtype, None]\n\n\nclass CEDRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    Extracts electrophysiology recordings from .smrx files.\n    The recording extractor always returns channel IDs starting from 0.\n    The recording data will always be returned in the shape of (num_channels,num_frames).\n\n    Parameters\n    ----------\n    file_path: str\n        Path to the .smrx file to be extracted\n    smrx_channel_ids: list of int\n        List with indexes of valid smrx channels. Does not match necessarily\n        with extractor id.\n    \"\"\"\n\n    extractor_name = 'CEDRecording'\n    installed = HAVE_SONPY  # check at class level if installed or not\n    is_writable = False\n    has_default_locations = False\n    has_unscaled = False\n    mode = 'file'\n    installation_mesg = \"To use the CED extractor, install sonpy: \\n\\n pip install sonpy\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path: PathType, smrx_channel_ids: list):\n        assert self.installed, self.installation_mesg\n        file_path = Path(file_path)\n        assert file_path.is_file() and file_path.suffix == '.smrx', 'file_path must lead to a .smrx file!'\n        assert len(smrx_channel_ids) > 0, \"'smrx_channel_ids' cannot be an empty list!\"\n\n        super().__init__()\n\n        # Open smrx file\n        self._recording_file_path = file_path\n        self._recording_file = sp.SonFile(sName=str(file_path), bReadOnly=True)\n        open_error = self._recording_file.GetOpenError()\n        if open_error != 0:\n            # build one message; ValueError('...', msg) would only show a tuple of both args\n            raise ValueError(f'Error opening file: {sp.GetErrorString(open_error)}')\n\n        # Map Recording channel_id to smrx index / test for invalid indexes /\n        # get channel info / set channel gains\n        self._channelid_to_smrxind = dict()\n        self._channel_smrxinfo = dict()\n        self._channel_names = []\n\n        gains = []\n        for i, ind in enumerate(smrx_channel_ids):\n            if self._recording_file.ChannelType(ind) == sp.DataType.Off:\n                raise ValueError(f'Channel {ind} is type Off and cannot be used')\n            self._channelid_to_smrxind[i] = ind\n            self._channel_smrxinfo[i] = get_channel_info(\n                f=self._recording_file,\n                smrx_ch_ind=ind\n            )\n            # Set channel gains: http://ced.co.uk/img/Spike10.pdf\n            # from 16-bit encoded int / to ADC +-5V input / to measured Volts\n            gain = self._channel_smrxinfo[i]['scale'] / 6553.6\n            gain *= 1000  # mV --> uV\n            gains.append(gain)\n            self._channel_names.append(self._channel_smrxinfo[i]['title'])\n\n        # Set gains\n        self.set_channel_gains(gains=gains)\n        self.has_unscaled = True\n\n        rate0 = self._channel_smrxinfo[0]['rate']\n        for chan, info in self._channel_smrxinfo.items():\n            assert info['rate'] == rate0, \"Inconsistency between 'sampling_frequency' of different channels. The \" \\\n                                          \"extractor only supports channels with the same 'rate'\"\n\n        # Set self._times\n        times = (self._channel_smrxinfo[0]['frame_offset'] + np.arange(self.get_num_frames())) / self.get_sampling_frequency()\n        self.set_times(times=times)\n\n        self._kwargs = {'file_path': str(Path(file_path).absolute()),\n                        'smrx_channel_ids': smrx_channel_ids}\n\n    @property\n    def channel_names(self):\n        return deepcopy(self._channel_names)\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        \"\"\"This function extracts and returns a trace from the recorded data from the\n        given channels ids and the given start and end frame. It will return\n        traces from within four ranges:\n\n            [start_frame, start_frame+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_recording_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_recording_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Traces are returned in a 2D array that\n        contains all of the traces from each channel with dimensions\n        (num_channels x num_frames). In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        start_frame: int\n            The starting frame of the trace to be returned (inclusive)\n        end_frame: int\n            The ending frame of the trace to be returned (exclusive)\n        channel_ids: array_like\n            A list or 1D array of channel ids (ints) from which each trace will be\n            extracted\n        return_scaled: bool\n            If True, traces are returned after scaling (using gain/offset). If False, the traces are returned as integers\n\n        Returns\n        ----------\n        traces: numpy.ndarray\n            A 2D array that contains all of the traces from each channel.\n            Dimensions are: (num_channels x num_frames)\n        \"\"\"\n\n        recordings = np.vstack(\n            [get_channel_data(\n                f=self._recording_file,\n                smrx_ch_ind=self._channelid_to_smrxind[i],\n                start_frame=start_frame,\n                end_frame=end_frame\n            ) for i in channel_ids]\n        )\n\n        return recordings\n\n    def get_num_frames(self):\n        \"\"\"This function returns the number of frames in the recording\n\n        Returns\n        -------\n        num_frames: int\n            Number of frames in the recording (duration of recording)\n        \"\"\"\n        return 1 + int(self._channel_smrxinfo[0]['max_time'] / self._channel_smrxinfo[0]['divide'] - self._channel_smrxinfo[0]['frame_offset'])\n\n    def get_sampling_frequency(self):\n        \"\"\"This function returns the sampling frequency in units of Hz.\n\n        Returns\n        -------\n        fs: float\n            Sampling frequency of the recordings in Hz\n        \"\"\"\n        return self._channel_smrxinfo[0]['rate']\n\n    def get_channel_ids(self):\n        \"\"\"Returns the list of channel ids. If not specified, the range from 0 to num_channels - 1 is returned.\n\n        Returns\n        -------\n        channel_ids: list\n            Channel list\n\n        \"\"\"\n        return list(self._channelid_to_smrxind.keys())\n\n    @staticmethod\n    def get_all_channels_info(file_path):\n        \"\"\"\n        Extract info from all channels in the smrx file. Returns a dictionary with\n        valid smrx channel indexes as keys and the respective channel information as\n        value.\n\n        Parameters:\n        -----------\n        file_path: str or Path\n            Path to .smrx file\n        \"\"\"\n        f = sp.SonFile(sName=str(file_path), bReadOnly=True)\n        n_channels = f.MaxChannels()\n        return {\n            i: get_channel_info(f, i) for i in range(n_channels)\n            if f.ChannelType(i) != sp.DataType.Off\n        }\n"
  },
  {
    "path": "spikeextractors/extractors/cedextractors/utils.py",
    "content": "import numpy as np\n\ntry:\n    from sonpy import lib as sp\n\n    # Data storage and function finder\n    DataReadFunctions = {\n        sp.DataType.Adc: sp.SonFile.ReadInts,\n        sp.DataType.EventFall: sp.SonFile.ReadEvents,\n        sp.DataType.EventRise: sp.SonFile.ReadEvents,\n        sp.DataType.EventBoth: sp.SonFile.ReadEvents,\n        sp.DataType.Marker: sp.SonFile.ReadMarkers,\n        sp.DataType.AdcMark: sp.SonFile.ReadWaveMarks,\n        sp.DataType.RealMark: sp.SonFile.ReadRealMarks,\n        sp.DataType.TextMark: sp.SonFile.ReadTextMarks,\n        sp.DataType.RealWave: sp.SonFile.ReadFloats\n    }\nexcept Exception:\n    # sonpy is optional: stay importable without it, but do not swallow\n    # KeyboardInterrupt/SystemExit like a bare except would. No readers are available then.\n    DataReadFunctions = dict()\n\n# Get the saved time and date\n# f.GetTimeDate()\n\n\ndef get_channel_info(f, smrx_ch_ind):\n    \"\"\"\n    Extract info from smrx files\n\n    Parameters:\n    -----------\n    f: sonpy.lib.SonFile\n        SonFile object.\n    smrx_ch_ind: int\n        Index of smrx channel. Does not match necessarily with extractor id.\n\n    Returns:\n    --------\n    ch_info: dict\n        Channel metadata (type, title, rates, timing, scaling, units, ...).\n    \"\"\"\n\n    nMax = 1 + int(f.ChannelMaxTime(smrx_ch_ind) / f.ChannelDivide(smrx_ch_ind))\n    frame_offset = f.FirstTime(chan=smrx_ch_ind, tFrom=0, tUpto=nMax) / f.ChannelDivide(smrx_ch_ind)\n    ch_info = {\n        'type': f.ChannelType(smrx_ch_ind),           # Get the channel kind\n        'ch_number': f.PhysicalChannel(smrx_ch_ind),  # Get the physical channel number associated with this channel\n        'title': f.GetChannelTitle(smrx_ch_ind),      # Get the channel title\n        'ideal_rate': f.GetIdealRate(smrx_ch_ind),    # Get the requested channel ideal rate\n        'rate': 1 / (f.GetTimeBase() * f.ChannelDivide(smrx_ch_ind)),    # Get the requested channel real rate\n        'max_time': f.ChannelMaxTime(smrx_ch_ind),    # Get the time of the last item in the channel (in clock ticks)\n        'divide': f.ChannelDivide(smrx_ch_ind),       # Get the waveform sample interval in file clock ticks\n        'time_base': f.GetTimeBase(),                 # Get how many seconds there are per clock tick\n        'frame_offset': frame_offset,                 # Get frame offset\n        'scale': f.GetChannelScale(smrx_ch_ind),      # Get the channel scale\n        'offset': f.GetChannelOffset(smrx_ch_ind),    # Get the channel offset\n        'unit': f.GetChannelUnits(smrx_ch_ind),       # Get the channel units\n        'y_range': f.GetChannelYRange(smrx_ch_ind),   # Get a suggested Y range for the channel\n        'comment': f.GetChannelComment(smrx_ch_ind),  # Get the comment associated with a channel\n        # NOTE(review): the trailing ':' in this key looks accidental; kept for compatibility\n        'size_bytes:': f.ChannelBytes(smrx_ch_ind),   # Get an estimate of the data bytes stored for the channel\n    }\n\n    return ch_info\n\n\ndef get_channel_data(f, smrx_ch_ind, start_frame=0, end_frame=None):\n    \"\"\"\n    Extract data from a smrx channel\n\n    Parameters:\n    -----------\n    f: sonpy.lib.SonFile\n        SonFile object.\n    smrx_ch_ind: int\n        Index of smrx channel. Does not match necessarily with extractor id.\n    start_frame: int\n        The starting frame of the trace to be returned (inclusive). None is treated as 0.\n    end_frame: int\n        The ending frame of the trace to be returned (exclusive). None reads up to the\n        end of the channel.\n\n    Returns:\n    --------\n    data: numpy.ndarray\n        The values returned by the sonpy reader for this channel type.\n    \"\"\"\n\n    divide = f.ChannelDivide(smrx_ch_ind)\n    nMax = 1 + int(f.ChannelMaxTime(smrx_ch_ind) / divide)\n    if start_frame is None:\n        start_frame = 0\n    if end_frame is None:\n        end_frame = nMax\n\n    # frames are counted from the channel's first sample; shift them by that offset\n    frame_offset = int(f.FirstTime(chan=smrx_ch_ind, tFrom=0, tUpto=nMax) / divide)\n    start_frame += frame_offset\n    end_frame += frame_offset\n\n    data = DataReadFunctions[f.ChannelType(smrx_ch_ind)](\n        self=f,\n        chan=smrx_ch_ind,\n        nMax=nMax,\n        tFrom=int(start_frame * divide),\n        tUpto=int(end_frame * divide)\n    )\n\n    return np.array(data)\n"
  },
  {
    "path": "spikeextractors/extractors/cellexplorersortingextractor/__init__.py",
    "content": "from .cellexplorersortingextractor import CellExplorerSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/cellexplorersortingextractor/cellexplorersortingextractor.py",
    "content": "from spikeextractors import SortingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\nfrom typing import Union, Optional\n\ntry:\n    import scipy.io\n    import hdf5storage\n    HAVE_SCIPY_AND_HDF5STORAGE = True\nexcept ImportError:\n    HAVE_SCIPY_AND_HDF5STORAGE = False\n\n\nPathType = Union[str, Path]\nOptionalPathType = Optional[PathType]\n\n\nclass CellExplorerSortingExtractor(SortingExtractor):\n    \"\"\"\n    Extracts spiking information from .mat files stored in the CellExplorer format.\n\n    Spike times are stored in units of seconds.\n\n    Parameters\n    ----------\n    spikes_matfile_path : PathType\n        Path to the sorting_id.spikes.cellinfo.mat file.\n    session_info_matfile_path : PathType, optional\n        Path to the sorting_id.sessionInfo.mat file. Only read when sampling_frequency is None;\n        defaults to the sessionInfo file next to spikes_matfile_path.\n    sampling_frequency : float, optional\n        Sampling frequency in Hz. If None, it is read from the sessionInfo file.\n    \"\"\"\n\n    extractor_name = \"CellExplorerSortingExtractor\"\n    installed = HAVE_SCIPY_AND_HDF5STORAGE\n    is_writable = True\n    mode = \"file\"\n    installation_mesg = \"To use the CellExplorerSortingExtractor install scipy and hdf5storage: \\n\\n pip install scipy\\n\\n and  \\n\\n pip install hdf5storage \\n\\n\"\n\n    def __init__(self, spikes_matfile_path: PathType, session_info_matfile_path: OptionalPathType = None,\n                 sampling_frequency: Optional[float] = None):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n\n        spikes_matfile_path = Path(spikes_matfile_path)\n        assert (\n            spikes_matfile_path.is_file()\n        ), f\"The spikes_matfile_path ({spikes_matfile_path}) must exist!\"\n\n        if sampling_frequency is None:\n            folder_path = spikes_matfile_path.parent\n            sorting_id = spikes_matfile_path.name.split(\".\")[0]\n            if session_info_matfile_path is None:\n                session_info_matfile_path = folder_path / f\"{sorting_id}.sessionInfo.mat\"\n            # the argument may be a str (see OptionalPathType); .is_file() needs a Path\n            session_info_matfile_path = Path(session_info_matfile_path)\n\n            assert (\n                session_info_matfile_path.is_file()\n            ), f\"No {sorting_id}.sessionInfo.mat file found in the folder!\"\n\n            try:\n                session_info_mat = scipy.io.loadmat(file_name=str(session_info_matfile_path))\n                self.read_session_info_with_scipy = True\n            except NotImplementedError:\n                # scipy cannot read MATLAB v7.3 (HDF5-based) files; fall back to hdf5storage\n                session_info_mat = hdf5storage.loadmat(file_name=str(session_info_matfile_path))\n                self.read_session_info_with_scipy = False\n\n            assert session_info_mat[\"sessionInfo\"][\"rates\"][0][0][\"wideband\"], (\n                \"The sessionInfo.mat file must contain \"\n                \"a 'sessionInfo' struct with field 'rates' containing field 'wideband' to extract the sampling frequency!\"\n            )\n            if self.read_session_info_with_scipy:\n                self._sampling_frequency = float(\n                    session_info_mat[\"sessionInfo\"][\"rates\"][0][0][\"wideband\"][0][0][0][0]\n                )  # careful not to confuse it with the lfpsamplingrate; reported in units Hz\n            else:\n                self._sampling_frequency = float(\n                    session_info_mat[\"sessionInfo\"][\"rates\"][0][0][\"wideband\"][0][0]\n                )  # careful not to confuse it with the lfpsamplingrate; reported in units Hz\n        else:\n            self._sampling_frequency = sampling_frequency\n\n        try:\n            spikes_mat = scipy.io.loadmat(file_name=str(spikes_matfile_path))\n            self.read_spikes_info_with_scipy = True\n        except NotImplementedError:\n            spikes_mat = hdf5storage.loadmat(file_name=str(spikes_matfile_path))\n            self.read_spikes_info_with_scipy = False\n\n        assert np.all(\n            np.isin([\"UID\", \"times\"], spikes_mat[\"spikes\"].dtype.names)\n        ), \"The spikes.cellinfo.mat file must contain a 'spikes' struct with fields 'UID' and 'times'!\"\n\n        # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames\n        # Rounding is necessary to prevent data loss from int-casting floating point errors\n        if self.read_spikes_info_with_scipy:\n            self._unit_ids = np.asarray(spikes_mat[\"spikes\"][\"UID\"][0][0][0], dtype=int)\n            self._spiketrains = [\n                (np.array([y[0] for y in x]) * self._sampling_frequency).round().astype(int)\n                for x in spikes_mat[\"spikes\"][\"times\"][0][0][0]\n            ]\n        else:\n            self._unit_ids = np.asarray(spikes_mat[\"spikes\"][\"UID\"][0][0], dtype=int)\n            self._spiketrains = [\n                (np.array([y[0] for y in x]) * self._sampling_frequency).round().astype(int)\n                for x in spikes_mat[\"spikes\"][\"times\"][0][0]\n            ]\n\n        # record every constructor argument, like the other extractors' _kwargs\n        self._kwargs = dict(\n            spikes_matfile_path=str(spikes_matfile_path.absolute()),\n            session_info_matfile_path=None if session_info_matfile_path is None\n            else str(Path(session_info_matfile_path).absolute()),\n            sampling_frequency=sampling_frequency,\n        )\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        times = self._spiketrains[self.get_unit_ids().index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n\n    @staticmethod\n    def write_sorting(sorting: SortingExtractor, save_path: PathType):\n        assert HAVE_SCIPY_AND_HDF5STORAGE, CellExplorerSortingExtractor.installation_mesg\n        # the annotation allows str, which has no .suffixes/.parent\n        save_path = Path(save_path)\n        assert save_path.suffixes == [\n            \".spikes\",\n            \".cellinfo\",\n            \".mat\",\n        ], \"The save_path must correspond to the CellExplorer format of sorting_id.spikes.cellinfo.mat!\"\n\n        base_path = save_path.parent\n        sorting_id = save_path.name.split(\".\")[0]\n        session_info_save_path = base_path / f\"{sorting_id}.sessionInfo.mat\"\n        spikes_save_path = save_path\n        base_path.mkdir(parents=True, exist_ok=True)\n\n        sampling_frequency = sorting.get_sampling_frequency()\n        session_info_mat_dict = dict(\n            sessionInfo=dict(rates=dict(wideband=sampling_frequency))\n        )\n\n        scipy.io.savemat(file_name=str(session_info_save_path), mdict=session_info_mat_dict)\n\n        spikes_mat_dict = dict(\n            spikes=dict(\n                UID=sorting.get_unit_ids(),\n                times=[\n                    [[y / sampling_frequency] for y in x]\n                    for x in sorting.get_units_spike_train()\n                ],\n            )\n        )\n        # If, in the future, it is ever desired to allow this to write unit properties, they must conform\n        # to the format here: https://cellexplorer.org/datastructure/data-structure-and-format/\n        scipy.io.savemat(file_name=str(spikes_save_path), mdict=spikes_mat_dict)\n"
  },
  {
    "path": "spikeextractors/extractors/combinatosortingextractor/__init__.py",
    "content": "from .combinatosortingextractor import CombinatoSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/combinatosortingextractor/combinatosortingextractor.py",
    "content": "from pathlib import Path\nimport numpy as np\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\nfrom typing import Union\n\n\ntry:\n    import h5py\n    HAVE_H5PY = True\nexcept ImportError:\n    HAVE_H5PY = False\n\nPathType = Union[str, Path]\n\n# Combinato group types, as stored in the 'types' table of sort_cat.h5\nGROUP_TYPE_ARTIFACT = -1\nGROUP_TYPE_UNSORTED = 0\nGROUP_TYPE_MULTI_UNIT = 1\nGROUP_TYPE_SINGLE_UNIT = 2\n\n\nclass CombinatoSortingExtractor(SortingExtractor):\n    \"\"\"\n    Extracts sorted units from a Combinato channel folder.\n\n    Parameters\n    ----------\n    datapath: PathType\n        Combinato channel folder (containing data_<folder name>.h5 and sort_<sign>_<user>/sort_cat.h5).\n    sampling_frequency: float, optional\n        Sampling frequency in Hz. If None, it is read from the 'sr' dataset of <datapath>.h5.\n    user: str\n        Sorting label used in the sort_<sign>_<user> folder names.\n    det_sign: str\n        Detection sign(s) to load: 'neg', 'pos' or 'both'.\n    \"\"\"\n    extractor_name = 'CombinatoSorting'\n    installed = HAVE_H5PY\n    is_writable = False\n    installation_mesg = \"To use the CombinatoSortingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n\n    def __init__(self, datapath: PathType, sampling_frequency=None, user='simple', det_sign='both'):\n        assert self.installed, self.installation_mesg\n        super().__init__()\n        datapath = Path(datapath)\n        assert datapath.is_dir(), 'Folder {} doesn\\'t exist'.format(datapath)\n        if sampling_frequency is None:\n            h5_path = str(datapath) + '.h5'\n            if Path(h5_path).exists():\n                with h5py.File(h5_path, mode='r') as f:\n                    sampling_frequency = f['sr'][0]\n        assert sampling_frequency is not None, \\\n            'sampling_frequency was not given and could not be read from {}.h5'.format(datapath)\n        self.set_sampling_frequency(sampling_frequency)\n        det_file = str(datapath / Path('data_' + datapath.stem + '.h5'))\n        sort_cat_files = []\n        for sign in ['neg', 'pos']:\n            if det_sign in ['both', sign]:\n                sort_cat_file = datapath / Path('sort_{}_{}/sort_cat.h5'.format(sign, user))\n                if sort_cat_file.exists():\n                    sort_cat_files.append((sign, str(sort_cat_file)))\n        unit_counter = 0\n        self._spike_trains = {}\n        metadata = {}\n        unsorted = []\n        # detection times are scaled by sampling_frequency / 1000, i.e. they are in ms\n        ms_to_frames = sampling_frequency / 1000\n        with h5py.File(det_file, mode='r') as fdet:\n            for sign, sfile in sort_cat_files:\n                with h5py.File(sfile, mode='r') as f:\n                    sp_class = f['classes'][()]\n                    gaux = f['groups'][()]\n                    groups = {g: gaux[gaux[:, 1] == g, 0] for g in np.unique(gaux[:, 1])}  # array of classes per group\n                    group_type = {group: g_type for group, g_type in f['types'][()]}\n                    sp_index = f['index'][()]\n\n                times_css = fdet[sign]['times'][()]\n                for gr, cls in groups.items():\n                    if group_type[gr] == GROUP_TYPE_ARTIFACT:\n                        continue\n                    spike_frames = np.rint(times_css[sp_index[np.isin(sp_class, cls)]] * ms_to_frames)\n                    if group_type[gr] == GROUP_TYPE_UNSORTED:\n                        unsorted.append(spike_frames)\n                        continue\n\n                    unit_counter = unit_counter + 1\n                    self._spike_trains[unit_counter] = spike_frames\n                    # a plain truthiness test would label MU (1) groups as single units too\n                    is_single_unit = group_type[gr] == GROUP_TYPE_SINGLE_UNIT\n                    metadata[unit_counter] = {'det_sign': sign,\n                                              'group_type': 'single-unit' if is_single_unit else 'multi-unit'}\n\n        self._unsorted_train = np.array([])\n        if len(unsorted) == 1:\n            self._unsorted_train = unsorted[0]\n        elif len(unsorted) > 1:  # unsorted spikes from several groups/signs\n            self._unsorted_train = np.sort(np.concatenate(unsorted), kind='mergesort')\n\n        self._unit_ids = list(range(1, unit_counter + 1))\n        for u in self._unit_ids:\n            for prop, value in metadata[u].items():\n                self.set_unit_property(u, prop, value)\n\n        self._kwargs = {'datapath': str(datapath.absolute()), 'sampling_frequency': float(sampling_frequency),\n                        'user': user, 'det_sign': det_sign}\n\n    def get_unit_ids(self):\n        return self._unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        st = self._spike_trains[unit_id]\n        return st[self._in_frame_range(st, start_frame, end_frame)]\n\n    def get_unsorted_spike_train(self, start_frame=None, end_frame=None):\n        \"\"\"Return the frames of spikes from 'unsorted' groups within [start_frame, end_frame).\"\"\"\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        u = self._unsorted_train\n        return u[self._in_frame_range(u, start_frame, end_frame)]\n\n    @staticmethod\n    def _in_frame_range(frames, start_frame, end_frame):\n        \"\"\"Boolean mask of frames in [start_frame, end_frame); None means unbounded.\"\"\"\n        # explicit None checks: `end_frame or inf` would turn end_frame=0 into 'no limit'\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.inf\n        return (frames >= start_frame) & (frames < end_frame)\n"
  },
  {
    "path": "spikeextractors/extractors/exdirextractors/__init__.py",
    "content": "from .exdirextractors import ExdirRecordingExtractor, ExdirSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/exdirextractors/exdirextractors.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom copy import copy\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train\n\ntry:\n    import exdir\n    import exdir.plugins.quantities\n    import quantities as pq\n\n    HAVE_EXDIR = True\nexcept ImportError:\n    HAVE_EXDIR = False\n\n\nclass ExdirRecordingExtractor(RecordingExtractor):\n    extractor_name = 'ExdirRecording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = HAVE_EXDIR  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"To use the ExdirExtractors run:\\n\\n pip install exdir\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path):\n        assert self.installed, self.installation_mesg\n        self._exdir_file = folder_path\n        exdir_group = exdir.File(folder_path, plugins=[exdir.plugins.quantities])\n\n        self._recordings = exdir_group['acquisition']['timeseries']\n        self._sampling_frequency = float(self._recordings.attrs['sample_rate'].rescale('Hz').magnitude)\n\n        self._num_channels = self._recordings.shape[0]\n        self._num_timepoints = self._recordings.shape[1]\n        RecordingExtractor.__init__(self)\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n    def get_channel_ids(self):\n        return list(range(self._num_channels))\n\n    def get_num_frames(self):\n        return self._num_timepoints\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        return self._recordings.data[np.array(channel_ids), start_frame:end_frame]\n\n    @staticmethod\n    def write_recording(recording, save_path, lfp=False, 
mua=False):\n        assert HAVE_EXDIR, ExdirRecordingExtractor.installation_mesg\n        channel_ids = recording.get_channel_ids()\n        raw = recording.get_traces()\n        exdir_group = exdir.File(save_path, plugins=[exdir.plugins.quantities])\n\n        if not lfp and not mua:\n            acq = exdir_group.require_group('acquisition')\n            timeseries = acq.require_dataset('timeseries', data=raw)\n            timeseries.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n            timeseries.attrs['electrode_identities'] = np.array(channel_ids)\n            return\n        elif lfp:\n            ephys = exdir_group.require_group('processing').require_group('electrophysiology')\n            ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n            if 'group' in recording.get_shared_channel_property_names():\n                channel_groups = np.unique(recording.get_channel_groups())\n            else:\n                channel_groups = [0]\n\n            if len(channel_groups) == 1:\n                chan = 0\n                ch_group = ephys.require_group('channel_group_' + str(chan))\n                lfp_group = ch_group.require_group('LFP')\n                ch_group.attrs['electrode_group_id'] = chan\n                ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids())\n                ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))\n                ch_group.attrs['start_time'] = 0 * pq.s\n                ch_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                              float(recording.get_sampling_frequency()) * pq.s\n                for i_c, ch in enumerate(recording.get_channel_ids()):\n                    ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))\n                    ts_group.attrs['electrode_group_id'] = chan\n                    ts_group.attrs['electrode_identity'] = ch\n           
         ts_group.attrs['num_samples'] = recording.get_num_frames()\n                    ts_group.attrs['electrode_idx'] = i_c\n                    ts_group.attrs['start_time'] = 0 * pq.s\n                    ts_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                  float(recording.get_sampling_frequency()) * pq.s\n                    ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                    data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch]))\n                    data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                    data.attrs['unit'] = pq.uV\n            else:\n                channel_groups = np.unique(recording.get_channel_groups())\n                for chan in channel_groups:\n                    ch_group = ephys.require_group('channel_group_' + str(chan))\n                    lfp_group = ch_group.require_group('LFP')\n                    ch_group.attrs['electrode_group_id'] = chan\n                    ch_group.attrs['electrode_identities'] = np.array([ch for ch in recording.get_channel_ids()\n                                                                       if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['electrode_idx'] = np.array([i_c for i_c, ch in\n                                                                enumerate(recording.get_channel_ids())\n                                                                if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['start_time'] = 0 * pq.s\n                    ch_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                  float(recording.get_sampling_frequency()) * pq.s\n                    for i_c, ch in enumerate(recording.get_channel_ids()):\n                        if recording.get_channel_groups(ch) == chan:\n                
            ts_group = lfp_group.require_group('LFP_timeseries_' + str(ch))\n                            ts_group.attrs['electrode_group_id'] = chan\n                            ts_group.attrs['electrode_identity'] = ch\n                            ts_group.attrs['num_samples'] = recording.get_num_frames()\n                            ts_group.attrs['electrode_idx'] = i_c\n                            ts_group.attrs['start_time'] = 0 * pq.s\n                            ts_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                          float(recording.get_sampling_frequency()) * pq.s\n                            ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                            data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch]))\n                            data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                            data.attrs['unit'] = pq.uV\n            return\n        elif mua:\n            ephys = exdir_group.require_group('processing').require_group('electrophysiology')\n            ephys.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n            if 'group' in recording.get_shared_channel_property_names():\n                channel_groups = np.unique(recording.get_channel_groups())\n            else:\n                channel_groups = [0]\n\n            if len(channel_groups) == 1:\n                chan = 0\n                ch_group = ephys.require_group('channel_group_' + str(chan))\n                mua_group = ch_group.require_group('MUA')\n                ch_group.attrs['electrode_group_id'] = chan\n                ch_group.attrs['electrode_identities'] = np.array(recording.get_channel_ids())\n                ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))\n                ch_group.attrs['start_time'] = 0 * pq.s\n                
ch_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                              float(recording.get_sampling_frequency()) * pq.s\n                for i_c, ch in enumerate(recording.get_channel_ids()):\n                    ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))\n                    ts_group.attrs['electrode_group_id'] = chan\n                    ts_group.attrs['electrode_identity'] = ch\n                    ts_group.attrs['num_samples'] = recording.get_num_frames()\n                    ts_group.attrs['electrode_idx'] = i_c\n                    ts_group.attrs['start_time'] = 0 * pq.s\n                    ts_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                  float(recording.get_sampling_frequency()) * pq.s\n                    ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                    data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch]))\n                    data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                    data.attrs['unit'] = pq.uV\n            else:\n                channel_groups = np.unique(recording.get_channel_groups())\n                for chan in channel_groups:\n                    ch_group = ephys.require_group('channel_group_' + str(chan))\n                    mua_group = ch_group.require_group('MUA')\n                    ch_group.attrs['electrode_group_id'] = chan\n                    ch_group.attrs['electrode_identities'] = np.array([ch for ch in recording.get_channel_ids()\n                                                                       if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['electrode_idx'] = np.array([i_c for i_c, ch in\n                                                                enumerate(recording.get_channel_ids())\n                                                              
  if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['start_time'] = 0 * pq.s\n                    ch_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                  float(recording.get_sampling_frequency()) * pq.s\n                    for i_c, ch in enumerate(recording.get_channel_ids()):\n                        if recording.get_channel_groups(ch) == chan:\n                            ts_group = mua_group.require_group('MUA_timeseries_' + str(ch))\n                            ts_group.attrs['electrode_group_id'] = chan\n                            ts_group.attrs['electrode_identity'] = ch\n                            ts_group.attrs['num_samples'] = recording.get_num_frames()\n                            ts_group.attrs['electrode_idx'] = i_c\n                            ts_group.attrs['start_time'] = 0 * pq.s\n                            ts_group.attrs['stop_time'] = recording.get_num_frames() / \\\n                                                          float(recording.get_sampling_frequency()) * pq.s\n                            ts_group.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                            data = ts_group.require_dataset('data', data=recording.get_traces(channel_ids=[ch]))\n                            data.attrs['sample_rate'] = recording.get_sampling_frequency() * pq.Hz\n                            data.attrs['unit'] = pq.uV\n\n\nclass ExdirSortingExtractor(SortingExtractor):\n    extractor_name = 'ExdirSorting'\n    installed = HAVE_EXDIR  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"To use the ExdirExtractors run:\\n\\n pip install exdir\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path, sampling_frequency=None, channel_group=None, load_waveforms=False):\n        assert self.installed, self.installation_mesg\n        
SortingExtractor.__init__(self)\n        self._exdir_file = folder_path\n        exdir_group = exdir.File(folder_path, plugins=exdir.plugins.quantities)\n\n        electrophysiology = None\n        sf = copy(sampling_frequency)\n        if 'processing' in exdir_group.keys():\n            if 'electrophysiology' in exdir_group['processing']:\n                electrophysiology = exdir_group['processing']['electrophysiology']\n                ephys_attrs = electrophysiology.attrs\n                if 'sample_rate' in ephys_attrs:\n                    sf = ephys_attrs['sample_rate']\n        else:\n            if sf is None:\n                raise Exception(\"Sampling rate information not found. Please provide it with the 'sampling_frequency' \"\n                                \"argument\")\n            else:\n                sf = sf * pq.Hz\n        self._sampling_frequency = float(sf.rescale('Hz').magnitude)\n\n        if electrophysiology is None:\n            raise Exception(\"'electrophysiology' group not found!\")\n\n        self._unit_ids = []\n        current_unit = 1\n        self._spike_trains = []\n        for chan_name, channel in electrophysiology.items():\n            if 'channel' in chan_name:\n                group = int(chan_name.split('_')[-1])\n                if channel_group is not None:\n                    if group != channel_group:\n                        continue\n                if load_waveforms:\n                    if 'Clustering' in channel.keys() and 'EventWaveform' in channel.keys():\n                        clustering = channel.require_group('Clustering')\n                        eventwaveform = channel.require_group('EventWaveform')\n                        nums = clustering['nums'].data\n                        waveforms = eventwaveform.require_group('waveform_timeseries')['data'].data\n                if 'UnitTimes' in channel.keys():\n                    for unit, unit_times in channel['UnitTimes'].items():\n                        
self._unit_ids.append(current_unit)\n                        self._spike_trains.append((unit_times['times'].data.rescale('s') * sf).magnitude)\n                        attrs = unit_times.attrs\n                        for k, v in attrs.items():\n                            self.set_unit_property(current_unit, k, v)\n                        if load_waveforms:\n                            unit_idxs = np.where(nums == int(unit))\n                            wf = waveforms[unit_idxs]\n                            self.set_unit_spike_features(current_unit, 'waveforms', wf)\n                        current_unit += 1\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'sampling_frequency': sampling_frequency,\n                        'channel_group': channel_group, 'load_waveforms': load_waveforms}\n\n    def get_unit_ids(self):\n        return self._unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        times = self._spike_trains[self._unit_ids.index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return np.rint(times[inds]).astype(int)\n\n    @staticmethod\n    def write_sorting(sorting, save_path, recording=None, sampling_frequency=None, save_waveforms=False, verbose=False):\n        assert HAVE_EXDIR, ExdirSortingExtractor.installation_mesg\n        if sampling_frequency is None and recording is None:\n            raise Exception(\"Provide 'sampling_frequency' argument (Hz)\")\n        else:\n            if recording is None:\n                sampling_frequency = sampling_frequency * pq.Hz\n            else:\n                sampling_frequency = recording.get_sampling_frequency() * pq.Hz\n\n        exdir_group = exdir.File(save_path, plugins=exdir.plugins.quantities)\n        ephys = exdir_group.require_group('processing').require_group('electrophysiology')\n        ephys.attrs['sample_rate'] = sampling_frequency\n\n        if 
'group' in sorting.get_shared_unit_property_names():\n            channel_groups = np.unique([sorting.get_unit_property(unit, 'group') for unit in sorting.get_unit_ids()])\n        else:\n            channel_groups = [0]\n\n        if len(channel_groups) == 1 and channel_groups[0] == 0:\n            chan = 0\n            if verbose:\n                print(\"Single group: \", chan)\n            ch_group = ephys.require_group('channel_group_' + str(chan))\n            try:\n                del ch_group['UnitTimes']\n                del ch_group['EventWaveform']\n                del ch_group['Clustering']\n            except Exception as e:\n                pass\n            unittimes = ch_group.require_group('UnitTimes')\n            unit_stop_time = np.max(\n                [(np.max(sorting.get_unit_spike_train(u).astype(float) / sampling_frequency).rescale('s'))\n                 for u in sorting.get_unit_ids()]) * pq.s\n            recording_stop_time = None\n            if recording is not None:\n                ch_group.attrs['electrode_group_id'] = chan\n                ch_group.attrs['electrode_identities'] = np.array([])\n                ch_group.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))\n                ch_group.attrs['start_time'] = 0 * pq.s\n                recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s\n\n                unittimes.attrs['electrode_group_id'] = chan\n                unittimes.attrs['electrode_identities'] = np.array([])\n                unittimes.attrs['electrode_idx'] = np.array(recording.get_channel_ids())\n                unittimes.attrs['start_time'] = 0 * pq.s\n            ch_group.attrs['sample_rate'] = sampling_frequency\n\n            if recording_stop_time is not None:\n                unittimes.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                    else unit_stop_time\n                
ch_group.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                    else unit_stop_time\n\n            nums = np.array([])\n            timestamps = np.array([])\n            waveforms = np.array([])\n            for unit in sorting.get_unit_ids():\n                unit_group = unittimes.require_group(str(unit))\n                unit_group.require_dataset('times',\n                                           data=(sorting.get_unit_spike_train(unit).astype(float)\n                                                 / sampling_frequency).rescale('s'))\n                unit_group.attrs['cluster_group'] = 'unsorted'\n                unit_group.attrs['group_id'] = chan\n                unit_group.attrs['name'] = 'unit #' + str(unit)\n\n                timestamps = np.concatenate((timestamps, (sorting.get_unit_spike_train(unit).astype(float)\n                                                          / sampling_frequency).rescale('s')))\n                nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit))))\n\n                if 'waveforms' in sorting.get_unit_spike_feature_names(unit):\n                    if len(waveforms) == 0:\n                        waveforms = sorting.get_unit_spike_features(unit, 'waveforms')\n                    else:\n                        waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms')))\n\n            if save_waveforms:\n                if verbose:\n                    print(\"Saving EventWaveforms\")\n                if 'waveforms' in sorting.get_shared_unit_spike_feature_names():\n                    eventwaveform = ch_group.require_group('EventWaveform')\n                    waveform_ts = eventwaveform.require_group('waveform_timeseries')\n                    data = waveform_ts.require_dataset('data', data=waveforms)\n                    waveform_ts.attrs['electrode_group_id'] = chan\n                    data.attrs['num_samples'] = 
len(waveforms)\n                    data.attrs['sample_rate'] = sampling_frequency\n                    data.attrs['unit'] = pq.dimensionless\n                    times = waveform_ts.require_dataset('timestamps', data=timestamps)\n                    times.attrs['num_samples'] = len(timestamps)\n                    times.attrs['unit'] = pq.s\n                    if recording is not None:\n                        waveform_ts.attrs['electrode_identities'] = np.array([])\n                        waveform_ts.attrs['electrode_idx'] = np.arange(len(recording.get_channel_ids()))\n                        waveform_ts.attrs['start_time'] = 0 * pq.s\n                        if recording_stop_time is not None:\n                            waveform_ts.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                                else unit_stop_time\n                        waveform_ts.attrs['sample_rate'] = sampling_frequency\n                        waveform_ts.attrs['sample_length'] = waveforms.shape[1]\n                        waveform_ts.attrs['num_samples'] = len(waveforms)\n                if verbose:\n                    print(\"Saving Clustering\")\n                clustering = ch_group.require_group('Clustering')\n                ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)\n                ts.attrs['num_samples'] = len(timestamps)\n                ts.attrs['unit'] = pq.s\n                ns = clustering.require_dataset('nums', data=nums)\n                ns.attrs['num_samples'] = len(nums)\n                cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids()))\n                cn.attrs['num_samples'] = len(sorting.get_unit_ids())\n        else:\n            # remove preexisten spike sorting data\n            max_group = 10\n            for chan in np.arange(max_group):\n                if 'channel_group_' + str(chan) in ephys.keys():\n                    if verbose:\n        
                print('Removing channel', chan, 'info')\n                    ch_group = ephys.require_group('channel_group_' + str(chan))\n                    try:\n                        del ch_group['UnitTimes']\n                        del ch_group['EventWaveform']\n                        del ch_group['Clustering']\n                    except Exception as e:\n                        pass\n            channel_groups = np.unique([sorting.get_unit_property(unit, 'group') for unit in sorting.get_unit_ids()])\n            for chan in channel_groups:\n                if verbose:\n                    print(\"Group: \", chan)\n                ch_group = ephys.require_group('channel_group_' + str(chan))\n                unittimes = ch_group.require_group('UnitTimes')\n                unit_stop_time = np.max(\n                    [(np.max(sorting.get_unit_spike_train(u).astype(float) / sampling_frequency).rescale('s'))\n                     for u in sorting.get_unit_ids()]) * pq.s\n                recording_stop_time = None\n                if recording is not None:\n                    unittimes.attrs['electrode_group_id'] = chan\n                    unittimes.attrs['electrode_identities'] = np.array([])\n                    unittimes.attrs['electrode_idx'] = np.array(\n                        [ch for i_c, ch in enumerate(recording.get_channel_ids())\n                         if recording.get_channel_groups(ch) == chan])\n                    unittimes.attrs['start_time'] = 0 * pq.s\n                    recording_stop_time = recording.get_num_frames() / float(recording.get_sampling_frequency()) * pq.s\n\n                    ch_group.attrs['electrode_group_id'] = chan\n                    ch_group.attrs['electrode_identities'] = np.array(\n                        [i_c for i_c, ch in enumerate(recording.get_channel_ids())\n                         if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['electrode_idx'] = np.array(\n               
         [i_c for i_c, ch in enumerate(recording.get_channel_ids())\n                         if recording.get_channel_groups(ch) == chan])\n                    ch_group.attrs['start_time'] = 0 * pq.s\n                ch_group.attrs['sample_rate'] = sampling_frequency\n\n                if recording_stop_time is not None:\n                    unittimes.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                        else unit_stop_time\n                    ch_group.attrs['stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                        else unit_stop_time\n                nums = np.array([])\n                timestamps = np.array([])\n                waveforms = np.array([])\n                for unit in sorting.get_unit_ids():\n                    if sorting.get_unit_property(unit, 'group') == chan:\n                        if verbose:\n                            print(\"Unit: \", unit)\n                        unit_group = unittimes.require_group(str(unit))\n                        unit_group.require_dataset('times',\n                                                   data=(sorting.get_unit_spike_train(unit).astype(float)\n                                                         / sampling_frequency).rescale('s'))\n                        unit_group.attrs['cluster_group'] = 'unsorted'\n                        unit_group.attrs['group_id'] = chan\n                        unit_group.attrs['name'] = 'unit #' + str(unit)\n\n                        timestamps = np.concatenate((timestamps, (sorting.get_unit_spike_train(unit).astype(float)\n                                                                  / sampling_frequency).rescale('s')))\n                        nums = np.concatenate((nums, [unit] * len(sorting.get_unit_spike_train(unit))))\n\n                        if 'waveforms' in sorting.get_unit_spike_feature_names(unit):\n                            if len(waveforms) == 0:\n       
                         waveforms = sorting.get_unit_spike_features(unit, 'waveforms')\n                            else:\n                                waveforms = np.vstack((waveforms, sorting.get_unit_spike_features(unit, 'waveforms')))\n                if save_waveforms:\n                    if verbose:\n                        print(\"Saving EventWaveforms\")\n                    if 'waveforms' in sorting.get_shared_unit_spike_feature_names():\n                        eventwaveform = ch_group.require_group('EventWaveform')\n                        waveform_ts = eventwaveform.require_group('waveform_timeseries')\n                        data = waveform_ts.require_dataset('data', data=waveforms)\n                        data.attrs['num_samples'] = len(waveforms)\n                        data.attrs['sample_rate'] = sampling_frequency\n                        data.attrs['unit'] = pq.dimensionless\n                        times = waveform_ts.require_dataset('timestamps', data=timestamps)\n                        times.attrs['num_samples'] = len(timestamps)\n                        times.attrs['unit'] = pq.s\n                        waveform_ts.attrs['electrode_group_id'] = chan\n                        if recording is not None:\n                            waveform_ts.attrs['electrode_identities'] = np.array([])\n                            waveform_ts.attrs['electrode_idx'] = np.array([ch for i_c, ch in\n                                                                           enumerate(recording.get_channel_ids())\n                                                                           if recording.get_channel_groups(ch) == chan])\n                            waveform_ts.attrs['start_time'] = 0 * pq.s\n                            if recording_stop_time is not None:\n                                waveform_ts.attrs[\n                                    'stop_time'] = recording_stop_time if recording_stop_time > unit_stop_time \\\n                           
         else unit_stop_time\n                            waveform_ts.attrs['sample_rate'] = sampling_frequency\n                            waveform_ts.attrs['sample_length'] = waveforms.shape[1]\n                            waveform_ts.attrs['num_samples'] = len(waveforms)\n                if verbose:\n                    print(\"Saving Clustering\")\n                clustering = ephys.require_group('channel_group_' + str(chan)).require_group('Clustering')\n                ts = clustering.require_dataset('timestamps', data=timestamps * pq.s)\n                ts.attrs['num_samples'] = len(timestamps)\n                ts.attrs['unit'] = pq.s\n                ns = clustering.require_dataset('nums', data=nums)\n                ns.attrs['num_samples'] = len(nums)\n                cn = clustering.require_dataset('cluster_nums', data=np.array(sorting.get_unit_ids()))\n                cn.attrs['num_samples'] = len(sorting.get_unit_ids())\n"
  },
  {
    "path": "spikeextractors/extractors/hdsortsortingextractor/__init__.py",
    "content": "from .hdsortsortingextractor import HDSortSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/hdsortsortingextractor/hdsortsortingextractor.py",
    "content": "from pathlib import Path\r\nfrom typing import Union\r\nimport subprocess\r\nimport sys\r\n\r\nimport numpy as np\r\n\r\nfrom spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor\r\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\r\n\r\nPathType = Union[str, Path]\r\n\r\n# MATLAB code used by write_sorting(): it reloads the written file and converts its 'Units'\r\n# cell array into a struct array. A line defining 'fileName' is prepended before it is run.\r\nconvert_cell_array_to_struct_code = \"\"\"\r\nhdsortOutput = load(fileName);\r\nhdsortOutput.Units = [hdsortOutput.Units{:}];\r\nUnits = hdsortOutput.Units;\r\nMultiElectrode = hdsortOutput.MultiElectrode;\r\nnoiseStd = hdsortOutput.noiseStd;\r\nsamplingRate = hdsortOutput.samplingRate;\r\nsave(fileName, 'Units', 'MultiElectrode', 'noiseStd', 'samplingRate');\r\n\"\"\"\r\n\r\n\r\nclass HDSortSortingExtractor(MATSortingExtractor):\r\n    \"\"\"Sorting extractor for .mat files written by the HDSort spike sorter.\r\n\r\n    Old-style (pre-v7.3) and v7.3 .mat files are both handled; which kind was loaded is read from\r\n    ``self._old_style_mat`` (set by MATSortingExtractor).\r\n\r\n    Parameters\r\n    ----------\r\n    file_path : str or Path\r\n        Path to the HDSort .mat file.\r\n    keep_good_only : bool\r\n        If True (default), units whose ``ID`` is a multiple of 1000 are treated as noise units\r\n        and dropped.\r\n    \"\"\"\r\n\r\n    extractor_name = \"HDSortSortingExtractor\"\r\n\r\n    def __init__(self, file_path: PathType, keep_good_only: bool = True):\r\n        super().__init__(file_path)\r\n\r\n        if not self._old_style_mat:\r\n            units = _parse_units(self._data, self._data['Units'])\r\n\r\n            # Extracting MultiElectrode field by field:\r\n            _ME = self._data[\"MultiElectrode\"]\r\n            multi_electrode = dict((k, _ME.get(k)[()]) for k in _ME.keys())\r\n\r\n            self._sampling_frequency = float(_squeeze_ds(self._data[\"samplingRate\"]))\r\n\r\n            if 'sortingInfo' in self._data.keys():\r\n                info = self._data[\"sortingInfo\"]\r\n                self.start_frame = int(_squeeze_ds(info['startTimes']))\r\n            else:\r\n                self.start_frame = 0\r\n        else:\r\n            # squeeze() leaves a 0-d (non-iterable) array when the file holds a single unit\r\n            _units = np.atleast_1d(self._getfield('Units').squeeze())\r\n            fields = _units.dtype.fields.keys()\r\n            units = [{f: unit[f] for f in fields} for unit in _units]\r\n\r\n            self._sampling_frequency = float(_squeeze_ds(self._getfield(\"samplingRate\")))\r\n\r\n            _ME = self._data[\"MultiElectrode\"]\r\n            multi_electrode = dict((k, _ME[k][0][0].T) for k in _ME.dtype.fields.keys())\r\n\r\n            if 'sortingInfo' in self._data.keys():\r\n                info = self._getfield(\"sortingInfo\")\r\n                self.start_frame = int(_squeeze_ds(info['startTimes']))\r\n            else:\r\n                self.start_frame = 0\r\n\r\n        # Remove noise units if necessary:\r\n        if keep_good_only:\r\n            units = [unit for unit in units if unit[\"ID\"].flatten()[0].astype(int) % 1000 != 0]\r\n\r\n        # Register all ids and spike trains before any feature/property is set (instead of growing\r\n        # the id array with np.append per unit), so every unit is already known at that point.\r\n        self._unit_ids = np.array([int(_squeeze_ds(unit[\"ID\"])) for unit in units], dtype=int)\r\n        self._spike_trains = {uc: _squeeze(unit[\"spikeTrain\"]).astype(int) - self.start_frame\r\n                              for uc, unit in enumerate(units)}\r\n\r\n        for uid, unit in zip(self._unit_ids.tolist(), units):\r\n            self.set_unit_spike_features(uid, \"amplitudes\", _squeeze(unit[\"spikeAmplitudes\"]))\r\n            self.set_unit_spike_features(uid, \"detection_channel\",\r\n                                         _squeeze(unit[\"detectionChannel\"]).astype(int))\r\n\r\n            # detectionChannel holds 1-based channel numbers: shift them to 0-based indexes\r\n            idx = unit[\"detectionChannel\"].astype(int) - 1\r\n            electrode_positions = multi_electrode[\"electrodePositions\"]\r\n            spike_positions = np.vstack((_squeeze(electrode_positions[0][idx]),\r\n                                         _squeeze(electrode_positions[1][idx]))).T\r\n            self.set_unit_spike_features(uid, \"positions\", spike_positions)\r\n\r\n            # old-style files hold the footprint transposed with respect to v7.3 files\r\n            template = unit[\"footprint\"].T if self._old_style_mat else unit[\"footprint\"]\r\n            self.set_unit_property(uid, \"template\", template)\r\n            self.set_unit_property(uid, \"template_frames_cut_before\", unit[\"cutLeft\"].flatten())\r\n\r\n        self._units = units\r\n        self._multi_electrode = multi_electrode\r\n        self._kwargs['keep_good_only'] = keep_good_only\r\n\r\n    @check_get_unit_spike_train\r\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\r\n        uidx = np.where(self._unit_ids == unit_id)[0][0]\r\n        st = self._spike_trains[uidx]\r\n        return st[(st >= start_frame) & (st < end_frame)]\r\n\r\n    def get_unit_ids(self):\r\n        return self._unit_ids.tolist()\r\n\r\n    @staticmethod\r\n    def write_sorting(sorting, save_path, locations=None, noise_std_by_channel=None, start_frame=0,\r\n                      convert_cell_to_struct=True):\r\n        \"\"\"Save ``sorting`` as a v7.3 .mat file with the fields read by HDSortSortingExtractor.\r\n\r\n        Parameters\r\n        ----------\r\n        sorting : SortingExtractor\r\n            Sorting to save.\r\n        save_path : str or Path\r\n            Destination .mat file.\r\n        locations : array-like or None\r\n            Electrode positions, one row per channel with (x, y) columns. If None, the channels\r\n            are placed at x = 0, y = 0, 1, 2, ...\r\n        noise_std_by_channel : array-like or None\r\n            Noise standard deviation of each channel; defaults to ones of shape (1, n_channels).\r\n        start_frame : int\r\n            Currently unused: accepted for API compatibility but not written to the file.\r\n        convert_cell_to_struct : bool\r\n            If True (default), MATLAB is run on the saved file (a ``matlab`` executable must be\r\n            on the PATH) to turn the ``Units`` cell array into a struct array.\r\n        \"\"\"\r\n        # First, find out how many channels there are\r\n        if locations is not None:\r\n            locations = np.asarray(locations)\r\n            n_channels = locations.shape[0]\r\n        elif 'template' in sorting.get_shared_unit_property_names():\r\n            # the first template dimension is the channel count\r\n            first_uid = int(sorting.get_unit_ids()[0])\r\n            n_channels = sorting.get_unit_property(first_uid, \"template\").shape[0]\r\n        elif 'detection_channel' in sorting.get_shared_unit_spike_feature_names():\r\n            # detection channels are 1-based, so the largest one is the channel count\r\n            n_channels = 1\r\n            for uid_ in sorting.get_unit_ids():\r\n                detection_channel = sorting.get_unit_spike_features(int(uid_), \"detection_channel\")\r\n                if len(detection_channel) > 0:\r\n                    n_channels = max(n_channels, int(np.max(detection_channel)))\r\n        else:\r\n            n_channels = 1\r\n\r\n        # Now loop through all units and extract the data that we want to save:\r\n        units = []\r\n        for uid_ in sorting.get_unit_ids():\r\n            uid = int(uid_)\r\n            spike_train = sorting.get_unit_spike_train(uid)\r\n            num_spikes = len(spike_train)\r\n            feature_names = sorting.get_unit_spike_feature_names(uid)\r\n            property_names = sorting.get_unit_property_names(uid)\r\n\r\n            unit = {\"ID\": uid, \"spikeTrain\": spike_train}\r\n\r\n            if \"amplitudes\" in feature_names:\r\n                unit[\"spikeAmplitudes\"] = sorting.get_unit_spike_features(uid, \"amplitudes\")\r\n            else:\r\n                # Save a spikeAmplitudes = 1\r\n                unit[\"spikeAmplitudes\"] = np.ones(num_spikes, np.double)\r\n\r\n            if \"detection_channel\" in feature_names:\r\n                unit[\"detectionChannel\"] = sorting.get_unit_spike_features(uid, \"detection_channel\")\r\n            else:\r\n                # Save a detectionChannel = 1\r\n                unit[\"detectionChannel\"] = np.ones(num_spikes, np.double)\r\n\r\n            if \"template\" in property_names:\r\n                unit[\"footprint\"] = sorting.get_unit_property(uid, \"template\").T\r\n            else:\r\n                # If this unit does not have a footprint, create an empty one:\r\n                unit[\"footprint\"] = np.zeros((3, n_channels), np.double)\r\n\r\n            # The reader stores cutLeft as 'template_frames_cut_before'; 'template_cut_left' is\r\n            # accepted as well.\r\n            if \"template_frames_cut_before\" in property_names:\r\n                unit[\"cutLeft\"] = sorting.get_unit_property(uid, \"template_frames_cut_before\")\r\n            elif \"template_cut_left\" in property_names:\r\n                unit[\"cutLeft\"] = sorting.get_unit_property(uid, \"template_cut_left\")\r\n            else:\r\n                unit[\"cutLeft\"] = 1\r\n\r\n            units.append(unit)\r\n\r\n        # Save the electrode locations:\r\n        if locations is None:\r\n            # Create artificial locations if none are provided:\r\n            x = np.zeros(n_channels, np.double)\r\n            y = np.arange(n_channels, dtype=np.double)\r\n            locations = np.vstack((x, y)).T\r\n\r\n        multi_electrode = {\"electrodePositions\": locations, \"electrodeNumbers\": np.arange(n_channels)}\r\n\r\n        if noise_std_by_channel is None:\r\n            noise_std_by_channel = np.ones((1, n_channels))\r\n\r\n        dict_to_save = {'Units': np.array(units),\r\n                        'MultiElectrode': multi_electrode,\r\n                        'noiseStd': noise_std_by_channel,\r\n                        'samplingRate': sorting.get_sampling_frequency()}\r\n\r\n        # Save Units and MultiElectrode to .mat file:\r\n        MATSortingExtractor.write_dict_to_mat(save_path, dict_to_save, version='7.3')\r\n\r\n        if convert_cell_to_struct:\r\n            _convert_cell_array_to_struct_array(save_path)\r\n\r\n\r\ndef _convert_cell_array_to_struct_array(save_path):\r\n    \"\"\"Run MATLAB on ``save_path`` to turn its 'Units' cell array into a struct array.\r\n\r\n    The MATLAB code is written to a temporary script next to ``save_path``, run by a MATLAB\r\n    process that exits when the script is done, and deleted afterwards. If MATLAB cannot be\r\n    started or the script fails, a message is printed and the .mat file is left as written.\r\n    \"\"\"\r\n    save_path = Path(save_path).absolute()\r\n    convert_script = save_path.parent / \"convert_cellarray_to_structarray.m\"\r\n    # single quotes inside MATLAB char arrays are escaped by doubling them\r\n    mat_file = str(save_path).replace(\"'\", \"''\")\r\n    script_file = str(convert_script).replace(\"'\", \"''\")\r\n\r\n    with convert_script.open('w') as f:\r\n        f.write(f\"fileName='{mat_file}';\\n{convert_cell_array_to_struct_code}\")\r\n\r\n    # Quit MATLAB once the script is done (exit code 1 if it raised an error) so that this call\r\n    # returns instead of leaving an interactive MATLAB session open.\r\n    matlab_statement = f\"try, run('{script_file}'); catch, exit(1); end; exit(0);\"\r\n    # 'darwin' also contains 'win', hence the explicit exclusion\r\n    if 'win' in sys.platform and sys.platform != 'darwin':\r\n        matlab_args = [\"matlab\", \"-nosplash\", \"-wait\", \"-log\", \"-r\", matlab_statement]\r\n    else:\r\n        matlab_args = [\"matlab\", \"-nosplash\", \"-nodisplay\", \"-log\", \"-r\", matlab_statement]\r\n\r\n    try:\r\n        completed = subprocess.run(matlab_args, cwd=str(convert_script.parent))\r\n        if completed.returncode != 0:\r\n            print(\"Failed to convert cell array to struct array\")\r\n    except OSError:\r\n        print(\"Failed to convert cell array to struct array\")\r\n    finally:\r\n        convert_script.unlink()\r\n\r\n\r\ndef _parse_units(file, _units):\r\n    \"\"\"Read the units of a v7.3 .mat file into a list of ``{field_name: value}`` dicts.\r\n\r\n    ``_units`` is either an h5py Group whose members hold one object reference per unit, or an\r\n    iterable of references to one group per unit.\r\n    \"\"\"\r\n    import h5py\r\n\r\n    if isinstance(_units, h5py.Group):\r\n        t_units = {}\r\n        for name in _units.keys():\r\n            values = []\r\n            for ref in _units[name]:\r\n                target = file[ref[0]]\r\n                if not isinstance(target, h5py.Dataset):\r\n                    break\r\n                values.append(target[()])\r\n            if values:\r\n                t_units[name] = values\r\n        # transpose {field: [value per unit]} into [{field: value} per unit]\r\n        return [dict(zip(t_units, col)) for col in zip(*t_units.values())]\r\n\r\n    out = []\r\n    for unit in _units:\r\n        group = file[unit[()][0]]\r\n        out.append({k: group[k][()] for k in group.keys()})\r\n    return out\r\n\r\n\r\ndef _squeeze_ds(ds):\r\n    \"\"\"Unwrap nested indexable containers (arrays, h5py datasets) down to their first scalar.\"\"\"\r\n    # np.floating (rather than the removed np.float alias) also matches float32 scalars\r\n    while not isinstance(ds, (int, float, np.integer, np.floating)):\r\n        if getattr(ds, \"ndim\", None) == 0:\r\n            # 0-d arrays and scalar datasets cannot be indexed with [0]\r\n            return ds[()]\r\n        ds = ds[0]\r\n    return ds\r\n\r\n\r\ndef _squeeze(arr):\r\n    \"\"\"Turn a 2-D array with a singleton row or column into 1-D; other arrays are returned as is.\"\"\"\r\n    shape = arr.shape\r\n    if len(shape) == 2:\r\n        if shape[0] == 1:\r\n            arr = arr[0]\r\n        elif shape[1] == 1:\r\n            arr = arr[:, 0]\r\n    return arr\r\n"
  },
  {
    "path": "spikeextractors/extractors/hs2sortingextractor/__init__.py",
    "content": "from .hs2sortingextractor import HS2SortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/hs2sortingextractor/hs2sortingextractor.py",
    "content": "from spikeextractors import SortingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\n\ntry:\n    import h5py\n\n    HAVE_HS2SX = True\nexcept ImportError:\n    HAVE_HS2SX = False\n\n\nclass HS2SortingExtractor(SortingExtractor):\n    extractor_name = 'HS2Sorting'\n    installed = HAVE_HS2SX  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the HS2SortingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path, load_unit_info=True):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        self._recording_file = file_path\n        self._rf = h5py.File(self._recording_file, mode='r')\n        if 'Sampling' in self._rf:\n            if self._rf['Sampling'][()] == 0:\n                self._sampling_frequency = None\n            else:\n                self._sampling_frequency = self._rf['Sampling'][()]\n\n        self._cluster_id = self._rf['cluster_id'][()]\n        self._unit_ids = set(self._cluster_id)\n        self._spike_times = self._rf['times'][()]\n\n        if load_unit_info:\n            self.load_unit_info()\n\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'load_unit_info': load_unit_info}\n\n    def load_unit_info(self):\n        if 'centres' in self._rf.keys() and len(self._spike_times) > 0:\n            self._unit_locs = self._rf['centres'][()]  # cache for faster access\n            for u_i, unit_id in enumerate(self._unit_ids):\n                self.set_unit_property(unit_id, property_name='unit_location', value=self._unit_locs[u_i])\n        inds = []  # get these only once\n        for unit_id in self._unit_ids:\n            inds.append(np.where(self._cluster_id == unit_id)[0])\n        if 'data' in self._rf.keys() and len(self._spike_times) > 0:\n  
          d = self._rf['data'][()]\n            for i, unit_id in enumerate(self._unit_ids):\n                self.set_unit_spike_features(unit_id, 'spike_location', d[:, inds[i]].T)\n        if 'ch' in self._rf.keys() and len(self._spike_times) > 0:\n            d = self._rf['ch'][()]\n            for i, unit_id in enumerate(self._unit_ids):\n                self.set_unit_spike_features(unit_id, 'max_channel', d[inds[i]])\n\n    def get_unit_indices(self, x):\n        return np.where(self._cluster_id == x)[0]\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        times = self._spike_times[self.get_unit_indices(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        assert HAVE_HS2SX, HS2SortingExtractor.installation_mesg\n        unit_ids = sorting.get_unit_ids()\n        times_list = []\n        labels_list = []\n        for i in range(len(unit_ids)):\n            unit = unit_ids[i]\n            times = sorting.get_unit_spike_train(unit_id=unit)\n            times_list.append(times)\n            labels_list.append(np.ones(times.shape, dtype=int) * unit)\n        all_times = np.concatenate(times_list)\n        all_labels = np.concatenate(labels_list)\n\n        rf = h5py.File(save_path, mode='w')\n        if sorting.get_sampling_frequency() is not None:\n            rf.create_dataset(\"Sampling\", data=sorting.get_sampling_frequency())\n        else:\n            rf.create_dataset(\"Sampling\", data=0)\n        if 'unit_location' in sorting.get_shared_unit_property_names():\n            spike_centres = [sorting.get_unit_property(u, 'unit_location') for u in sorting.get_unit_ids()]\n            spike_centres = np.array(spike_centres)\n            rf.create_dataset(\"centres\", data=spike_centres)\n        
if 'spike_location' in sorting.get_shared_unit_spike_feature_names():\n            spike_loc_x = []\n            spike_loc_y = []\n            for u in sorting.get_unit_ids():\n                l = sorting.get_unit_spike_features(u, 'spike_location')\n                spike_loc_x.append(l[:, 0])\n                spike_loc_y.append(l[:, 1])\n            spike_loc = np.vstack((np.concatenate(spike_loc_x), np.concatenate(spike_loc_y)))\n            rf.create_dataset(\"data\", data=spike_loc)\n        if 'max_channel' in sorting.get_shared_unit_spike_feature_names():\n            spike_max_channel = np.concatenate(\n                [sorting.get_unit_spike_features(u, 'max_channel') for u in sorting.get_unit_ids()])\n            rf.create_dataset(\"ch\", data=spike_max_channel)\n\n        rf.create_dataset(\"times\", data=all_times)\n        rf.create_dataset(\"cluster_id\", data=all_labels)\n        rf.close()\n"
  },
  {
    "path": "spikeextractors/extractors/intanrecordingextractor/__init__.py",
    "content": "from .intanrecordingextractor import IntanRecordingExtractor"
  },
  {
    "path": "spikeextractors/extractors/intanrecordingextractor/intanrecordingextractor.py",
    "content": "import numpy as np\nfrom pathlib import Path\nfrom packaging.version import parse\nfrom typing import Union, Optional\n\nfrom spikeextractors import RecordingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args\n\nDtypeType = Union[str, np.dtype]\nOptionalArrayType = Optional[Union[np.ndarray, list]]\n\ntry:\n    import pyintan\n    if parse(pyintan.__version__) >= parse('0.3.0'):\n        HAVE_INTAN = True\n    else:\n        print(\"pyintan version requires an update (>=0.3.0). Please upgrade with 'pip install --upgrade pyintan'\")\n        HAVE_INTAN = False\nexcept ImportError:\n    HAVE_INTAN = False\n\n\nclass IntanRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    Extracts raw neural recordings from the Intan file format.\n\n    The recording extractor always returns channel IDs starting from 0.\n\n    The recording data will always be returned in the shape of (num_channels, num_frames).\n\n    Parameters\n    ----------\n    file_path : str\n        Path to the .dat file to be extracted.\n    dtype : dtype\n        The data type used in the binary file.\n    verbose : bool, optional\n        Print output during pyintan file read.\n    \"\"\"\n\n    extractor_name = 'IntanRecording'\n    has_default_locations = False\n    has_unscaled = True\n    is_writable = False\n    mode = \"file\"\n    installed = HAVE_INTAN\n    installation_mesg = \"To use the Intan extractor, install pyintan: \\n\\n pip install pyintan\\n\\n\"\n\n    def __init__(self, file_path: str, verbose: bool = False):\n        assert self.installed, self.installation_mesg\n        RecordingExtractor.__init__(self)\n        assert Path(file_path).suffix == '.rhs' or Path(file_path).suffix == '.rhd', \\\n            \"Only '.rhd' and '.rhs' files are supported\"\n        self._recording_file = file_path\n        self._recording = pyintan.File(file_path, verbose)\n        self._num_frames = len(self._recording.times)\n        
self._analog_channels = np.array([\n            ch for ch in self._recording._anas_chan\n            if all([other_ch not in ch['name'] for other_ch in ['ADC', 'VDD', 'AUX']])\n        ])\n        self._num_channels = len(self._analog_channels)\n        self._channel_ids = list(range(self._num_channels))\n        self._fs = float(self._recording.sample_rate.rescale('Hz').magnitude)\n\n        for i, ch in enumerate(self._analog_channels):\n            self.set_channel_gains(channel_ids=i, gains=ch['gain'])\n            self.set_channel_offsets(channel_ids=i, offsets=ch['offset'])\n\n        self._kwargs = dict(file_path=str(Path(file_path).absolute()), verbose=verbose)\n\n    def get_channel_ids(self):\n        return self._channel_ids\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._fs\n\n    @check_get_traces_args\n    def get_traces(\n        self,\n        channel_ids: OptionalArrayType = None,\n        start_frame: Optional[int] = None,\n        end_frame: Optional[int] = None,\n        return_scaled: bool = True,\n    ):\n        \"\"\"\n        This function extracts and returns a trace from the recorded data from the\n        given channels ids and the given start and end frame. It will return\n        traces from within three ranges:\n\n            [start_frame, start_frame+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_recording_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_recording_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Traces are returned in a 2D array that\n        contains all of the traces from each channel with dimensions\n        (num_channels x num_frames). 
In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        start_frame : int, optional\n            The starting frame of the trace to be returned (inclusive)\n        end_frame : int, optional\n            The ending frame of the trace to be returned (exclusive)\n        channel_ids : ArrayType, optional\n            A list or 1D array of channel ids (ints) from which each trace will be\n            extracted\n        return_scaled : bool, optional\n            If True, traces are returned after scaling (using gain/offset). If False, the raw traces are returned.\n            Defaults to True.\n\n        Returns\n        ----------\n        traces: numpy.ndarray\n            A 2D array that contains all of the traces from each channel.\n            Dimensions are: (num_channels x num_frames)\n        \"\"\"\n        channel_idxs = np.array([self._channel_ids.index(ch) for ch in channel_ids])\n        return self._recording._read_analog(\n            channels=self._analog_channels[channel_idxs],\n            i_start=start_frame,\n            i_stop=end_frame,\n            dtype=\"uint16\"\n        ).T\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        channels = [np.unique(ev.channels)[0] for ev in self._recording.digital_in_events]\n        assert channel_id in channels, f\"Specified 'channel' not found. Available channels are {channels}\"\n        ev = self._recording.events[channels.index(channel_id)]\n\n        ttl_frames = (ev.times.rescale(\"s\") * self.get_sampling_frequency()).magnitude.astype(int)\n        ttl_states = np.sign(ev.channel_states)\n        ttl_valid_idxs = np.where((ttl_frames >= start_frame) & (ttl_frames < end_frame))[0]\n        return ttl_frames[ttl_valid_idxs], ttl_states[ttl_valid_idxs]\n"
  },
  {
    "path": "spikeextractors/extractors/jrcsortingextractor/__init__.py",
    "content": "from .jrcsortingextractor import JRCSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/jrcsortingextractor/jrcsortingextractor.py",
    "content": "from pathlib import Path\r\nimport re\r\nfrom typing import Union\r\nimport numpy as np\r\n\r\nfrom spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor, HAVE_MAT\r\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\r\n\r\nPathType = Union[str, Path]\r\n\r\n\r\nclass JRCSortingExtractor(MATSortingExtractor):\r\n    extractor_name = \"JRCSortingExtractor\"\r\n    installation_mesg = \"To use the MATSortingExtractor install h5py and scipy: \\n\\n pip install h5py scipy\\n\\n\"  # error message when not installed\r\n\r\n    def __init__(self, file_path: PathType, keep_good_only: bool = False):\r\n        super().__init__(file_path)\r\n        file_path = self._kwargs[\"file_path\"]\r\n\r\n        spike_times = self._getfield(\"spikeTimes\").ravel() - 1  # int32\r\n        spike_clusters = self._getfield(\"spikeClusters\").ravel()  # uint32\r\n        spike_amplitudes = self._getfield(\"spikeAmps\").ravel()  # int16\r\n        spike_sites = self._getfield(\"spikeSites\").ravel() - 1  # uint32\r\n        spike_positions = self._getfield(\"spikePositions\").T  # float32\r\n\r\n        unit_centroids = self._getfield(\"clusterCentroids\").astype(np.float).T\r\n        unit_sites = self._getfield(\"clusterSites\").astype(np.uint32).ravel()\r\n        mean_waveforms = self._getfield(\"meanWfGlobal\").T\r\n        mean_waveforms_raw = self._getfield(\"meanWfGlobalRaw\").T\r\n\r\n        # try to extract various parameters from the .prm file\r\n        self._bit_scaling = np.float32(0.30518)  # conversion factor for ADC units -> µV\r\n        sample_rate = 30000.\r\n        filter_type = \"ndiff\"\r\n        ndiff_order = 2\r\n\r\n        prm_file = Path(file_path.parent, file_path.name.replace(\"_res.mat\", \".prm\"))\r\n        with prm_file.open(\"r\") as fh:\r\n            lines = [line.strip() for line in fh.readlines()]\r\n\r\n        for line in lines:\r\n            try:\r\n                
key, val = line.split('%', 1)[0].strip(\" ;\").split(\"=\")\r\n            except ValueError:\r\n                continue\r\n\r\n            key = key.strip()\r\n            val = val.strip()\r\n\r\n            if key == \"sampleRate\":\r\n                try:\r\n                    sample_rate = float(val)\r\n                except (IndexError, ValueError):\r\n                    pass\r\n            elif key == \"bitScaling\":\r\n                try:\r\n                    self._bit_scaling = np.float32(val)\r\n                except (IndexError, ValueError):\r\n                    pass\r\n            elif key == \"filterType\":\r\n                filter_type = val\r\n            elif key == \"nDiffOrder\":\r\n                try:\r\n                    ndiff_order = int(val)\r\n                except (IndexError, ValueError):\r\n                    pass\r\n            elif key == \"siteLoc\":\r\n                site_locs = []\r\n                str_locs = map(lambda v: v.strip(\" ][\"), val.split(\";\"))\r\n                for loc in str_locs:\r\n                    x, y = map(float, re.split(r\",?\\s+\", loc))\r\n                    site_locs.append([x, y])\r\n\r\n                site_locs = np.array(site_locs)\r\n            elif key == \"shankMap\":\r\n                val = val.strip(\"][\")\r\n                try:\r\n                    shank_map = np.array(map(float, re.split(r\"[,;]?\\s+\", val)))\r\n                except:\r\n                    shank_map = np.array([])\r\n\r\n        self.set_sampling_frequency(sample_rate)\r\n        if filter_type == \"sgdiff\":\r\n            self._bit_scaling /= (2 * (np.arange(1, ndiff_order + 1) ** 2).sum())\r\n        elif filter_type == \"ndiff\":\r\n            self._bit_scaling /= 2\r\n\r\n        # traces, features\r\n        raw_file = Path(file_path.parent, file_path.name.replace(\"_res.mat\", \"_raw.jrc\"))\r\n        raw_shape = tuple(self._getfield(\"rawShape\").ravel().astype(np.int))\r\n        
self._raw_traces = np.memmap(raw_file, dtype=np.int16, mode=\"r\",\r\n                                     shape=raw_shape, order=\"F\")\r\n\r\n        filt_file = Path(file_path.parent, file_path.name.replace(\"_res.mat\", \"_filt.jrc\"))\r\n        filt_shape = tuple(self._getfield(\"filtShape\").ravel().astype(np.int))\r\n        self._filt_traces = np.memmap(filt_file, dtype=np.int16, mode=\"r\",\r\n                                      shape=filt_shape, order=\"F\")\r\n\r\n        features_file = Path(file_path.parent, file_path.name.replace(\"_res.mat\", \"_features.jrc\"))\r\n        features_shape = tuple(self._getfield(\"featuresShape\").ravel().astype(np.int))\r\n        self._cluster_features = np.memmap(features_file, dtype=np.float32, mode=\"r\",\r\n                                           shape=features_shape, order=\"F\")\r\n\r\n        neighbors = _find_site_neighbors(site_locs, raw_shape[1], shank_map)  # get nearest neighbors for each site\r\n\r\n        # nonpositive clusters are noise or deleted units\r\n        if keep_good_only:\r\n            good_mask = spike_clusters > 0\r\n        else:\r\n            good_mask = np.ones_like(spike_clusters, dtype=np.bool)\r\n\r\n        self._unit_ids = np.unique(spike_clusters[good_mask])\r\n\r\n        # load spike trains\r\n        self._spike_trains = {}\r\n        self._unit_masks = {}\r\n        for uid in self._unit_ids:\r\n            mask = (spike_clusters == uid)\r\n            self._unit_masks[uid] = mask\r\n\r\n            self._spike_trains[uid] = spike_times[mask]\r\n\r\n            self.set_unit_spike_features(uid, \"amplitudes\", spike_amplitudes[mask])\r\n            self.set_unit_spike_features(uid, \"max_channels\", spike_sites[mask])\r\n            self.set_unit_spike_features(uid, \"positions\", spike_positions[mask, :])\r\n            self.set_unit_spike_features(uid, \"site_neighbors\", neighbors[spike_sites[mask], :])\r\n\r\n            self.set_unit_property(uid, \"centroid\", 
unit_centroids[uid - 1, :])\r\n            self.set_unit_property(uid, \"max_channel\", unit_sites[uid - 1])\r\n            self.set_unit_property(uid, \"template\", mean_waveforms[:, :, uid - 1])\r\n            self.set_unit_property(uid, \"template_raw\", mean_waveforms_raw[:, :, uid - 1])\r\n\r\n        self._kwargs[\"keep_good_only\"] = keep_good_only\r\n\r\n\r\n    @check_get_unit_spike_train\r\n    def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None):\r\n        if feature_name not in (\"raw_traces\", \"filtered_traces\", \"cluster_features\"):\r\n            return super().get_unit_spike_features(unit_id, feature_name, start_frame, end_frame)\r\n\r\n        mask = self._unit_masks[unit_id]\r\n        if feature_name == \"raw_traces\":\r\n            return self._raw_traces[:, :, mask] * self._bit_scaling\r\n        elif feature_name == \"filtered_traces\":\r\n            return self._filt_traces[:, :, mask] * self._bit_scaling\r\n        else:\r\n            return self._cluster_features[:, :, mask]\r\n\r\n    @check_get_unit_spike_train\r\n    def get_unit_spike_feature_names(self, unit_id):\r\n        return super().get_unit_spike_feature_names(unit_id) + [\"raw_traces\", \"filtered_traces\", \"cluster_features\"]\r\n\r\n    @check_get_unit_spike_train\r\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\r\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\r\n\r\n        start_frame = start_frame or 0\r\n        end_frame = end_frame or np.infty\r\n\r\n        st = self._spike_trains[unit_id]\r\n        return st[(st >= start_frame) & (st < end_frame)]\r\n\r\n    def get_unit_ids(self):\r\n        return self._unit_ids.tolist()\r\n\r\n\r\ndef _find_site_neighbors(site_locs, n_neighbors, shank_map):\r\n    from scipy.spatial.distance import cdist\r\n\r\n    if np.unique(shank_map).size <= 1:\r\n        pass\r\n\r\n    n_sites = site_locs.shape[0]\r\n    
n_neighbors = int(min(n_neighbors, n_sites))\r\n\r\n    neighbors = np.zeros((n_sites, n_neighbors), dtype=np.int)\r\n    for i in range(n_sites):\r\n        i_loc = site_locs[i, :][np.newaxis, :]\r\n        dists = cdist(i_loc, site_locs).ravel()\r\n        neighbors[i, :] = dists.argsort()[:n_neighbors]\r\n\r\n    return neighbors\r\n"
  },
  {
    "path": "spikeextractors/extractors/kilosortextractors/__init__.py",
    "content": "from .kilosortextractors import KiloSortSortingExtractor, KiloSortRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/kilosortextractors/kilosortextractors.py",
    "content": "from spikeextractors.extractors.phyextractors import PhyRecordingExtractor, PhySortingExtractor\nfrom pathlib import Path\n\n\nclass KiloSortRecordingExtractor(PhyRecordingExtractor):\n    extractor_name = 'KiloSortRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = True  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, folder_path):\n        PhyRecordingExtractor.__init__(self, folder_path)\n\n\nclass KiloSortSortingExtractor(PhySortingExtractor):\n    extractor_name = 'KiloSortSorting'\n    installed = True  # check at class level if installed or not\n    installation_mesg = \"\"  # error message when not installed\n    is_writable = False\n    mode = 'folder'\n\n    def __init__(self, folder_path, exclude_cluster_groups=None, keep_good_only=False):\n        PhySortingExtractor.__init__(self, folder_path, exclude_cluster_groups)\n        self._keep_good_only = keep_good_only\n        self._good_units = []\n\n        if keep_good_only:\n            for u in self.get_unit_ids():\n                if 'KSLabel' in self.get_unit_property_names(u):\n                    if self.get_unit_property(u, 'KSLabel') == 'good':\n                        self._good_units.append(u)\n            self._unit_ids = self._good_units\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()),\n                        'exclude_cluster_groups': exclude_cluster_groups, 'keep_good_only': keep_good_only}\n"
  },
  {
    "path": "spikeextractors/extractors/klustaextractors/__init__.py",
    "content": "from .klustaextractors import KlustaSortingExtractor, KlustaRecordingExtractor"
  },
  {
    "path": "spikeextractors/extractors/klustaextractors/klustaextractors.py",
    "content": "\"\"\"\nkwik structure based on:\nhttps://github.com/kwikteam/phy-doc/blob/master/docs/kwik-format.md\n\ncluster_group defaults based on:\nhttps://github.com/kwikteam/phy-doc/blob/master/docs/kwik-model.md\n\n04/08/20\n\"\"\"\n\n\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nfrom spikeextractors.extraction_tools import read_python, check_get_unit_spike_train\nimport numpy as np\nfrom pathlib import Path\n\n\ntry:\n    import h5py\n    HAVE_KLSX = True\nexcept ImportError:\n    HAVE_KLSX = False\n\n\n# noinspection SpellCheckingInspection\nclass KlustaRecordingExtractor(BinDatRecordingExtractor):\n    extractor_name = 'KlustaRecording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = HAVE_KLSX  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"To use the KlustaSortingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path):\n        assert self.installed, self.installation_mesg\n        klustafolder = Path(folder_path).absolute()\n        config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]\n        dat_file = [f for f in klustafolder.iterdir() if f.suffix == '.dat'][0]\n        assert config_file.is_file() and dat_file.is_file(), \"Not a valid klusta folder\"\n        config = read_python(str(config_file))\n        sampling_frequency = config['traces']['sample_rate']\n        n_channels = config['traces']['n_channels']\n        dtype = config['traces']['dtype']\n\n        BinDatRecordingExtractor.__init__(self, file_path=dat_file, sampling_frequency=sampling_frequency, numchan=n_channels,\n                                          dtype=dtype)\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n\n# noinspection SpellCheckingInspection\nclass 
KlustaSortingExtractor(SortingExtractor):\n    extractor_name = 'KlustaSorting'\n    installed = HAVE_KLSX  # check at class level if installed or not\n    installation_mesg = \"To use the KlustaSortingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n    is_writable = True\n    mode = 'file_or_folder'\n\n    default_cluster_groups = {0: 'Noise', 1: 'MUA', 2: 'Good', 3: 'Unsorted'}\n\n    def __init__(self, file_or_folder_path, exclude_cluster_groups=None):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        kwik_file_or_folder = Path(file_or_folder_path)\n        kwikfile = None\n        klustafolder = None\n        if kwik_file_or_folder.is_file():\n            assert kwik_file_or_folder.suffix == '.kwik', \"Not a '.kwik' file\"\n            kwikfile = Path(kwik_file_or_folder).absolute()\n            klustafolder = kwikfile.parent\n        elif kwik_file_or_folder.is_dir():\n            klustafolder = kwik_file_or_folder\n            kwikfiles = [f for f in kwik_file_or_folder.iterdir() if f.suffix == '.kwik']\n            if len(kwikfiles) == 1:\n                kwikfile = kwikfiles[0]\n        assert kwikfile is not None, \"Could not load '.kwik' file\"\n\n        try:\n            config_file = [f for f in klustafolder.iterdir() if f.suffix == '.prm'][0]\n            config = read_python(str(config_file))\n            sampling_frequency = config['traces']['sample_rate']\n            self._sampling_frequency = sampling_frequency\n        except Exception as e:\n            print(\"Could not load sampling frequency info\")\n\n        kf_reader = h5py.File(kwikfile, 'r')\n        self._spiketrains = []\n        self._unit_ids = []\n        unique_units = []\n        klusta_units = []\n        cluster_groups_name = []\n        groups = []\n        unit = 0\n\n        cs_to_exclude = []\n        valid_group_names = [i[1].lower() for i in 
self.default_cluster_groups.items()]\n        if exclude_cluster_groups is not None:\n            assert isinstance(exclude_cluster_groups, list), 'exclude_cluster_groups should be a list'\n            for ec in exclude_cluster_groups:\n                assert ec in valid_group_names, f'select exclude names out of: {valid_group_names}'\n                cs_to_exclude.append(ec.lower())\n\n        for channel_group in kf_reader.get('/channel_groups'):\n            if 'spikes' not in kf_reader.get(f'/channel_groups/{channel_group}'):\n                print('No spikes found for this channel group')\n                continue\n            else:\n                chan_cluster_id_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/clusters/main')[()]\n                chan_cluster_times_arr = kf_reader.get(f'/channel_groups/{channel_group}/spikes/time_samples')[()]\n                chan_cluster_ids = np.unique(chan_cluster_id_arr)  # if clusters were merged in gui,\n                                                                    # the original id's are still in the kwiktree, but\n                                                                    # in this array\n\n                for cluster_id in chan_cluster_ids:\n                    cluster_frame_idx = np.nonzero(chan_cluster_id_arr == cluster_id)  # the [()] is a h5py thing\n                    st = chan_cluster_times_arr[cluster_frame_idx]\n                    assert st.shape[0] > 0, 'no spikes in cluster'\n                    cluster_group = kf_reader.get(f'/channel_groups/{channel_group}/clusters/main/{cluster_id}').attrs['cluster_group']\n\n                    assert cluster_group in self.default_cluster_groups.keys(), f'cluster_group not in \"default_dict: {cluster_group}'\n                    cluster_group_name = self.default_cluster_groups[cluster_group]\n\n                    if cluster_group_name.lower() in cs_to_exclude:\n                        continue\n\n                    
self._spiketrains.append(st)\n                    klusta_units.append(int(cluster_id))\n                    unique_units.append(unit)\n                    unit += 1\n                    groups.append(int(channel_group))\n                    cluster_groups_name.append(cluster_group_name)\n\n        if len(np.unique(klusta_units)) == len(np.unique(unique_units)):\n            self._unit_ids = klusta_units\n        else:\n            print('Klusta units are not unique! Using unique unit ids')\n            self._unit_ids = unique_units\n        for i, u in enumerate(self._unit_ids):\n            self.set_unit_property(u, 'group', groups[i])\n            self.set_unit_property(u, 'quality', cluster_groups_name[i].lower())\n\n        self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute())}\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        times = self._spiketrains[self.get_unit_ids().index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n"
  },
  {
    "path": "spikeextractors/extractors/matsortingextractor/__init__.py",
    "content": "from .matsortingextractor import MATSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/matsortingextractor/matsortingextractor.py",
    "content": "from collections import deque\nfrom pathlib import Path\nfrom typing import Union\nimport numpy as np\n\ntry:\n    import h5py\n    HAVE_H5PY = True\nexcept ImportError:\n    HAVE_H5PY = False\n\ntry:\n    from scipy.io.matlab import loadmat, savemat\n\n    HAVE_LOADMAT = True\nexcept ImportError:\n    HAVE_LOADMAT = False\n\n\ntry:\n    import hdf5storage\n    HAVE_HDF5STORAGE = True\nexcept ImportError:\n    HAVE_HDF5STORAGE = False\n\nHAVE_MAT = HAVE_H5PY & HAVE_LOADMAT\n\nfrom spikeextractors import SortingExtractor\n\nPathType = Union[str, Path]\n\n\nclass MATSortingExtractor(SortingExtractor):\n    extractor_name = \"MATSortingExtractor\"\n    installed = HAVE_MAT  # check at class level if installed or not\n    is_writable = False\n    mode = \"file\"\n    installation_mesg = \"To use the MATSortingExtractor install h5py and scipy: \" \\\n                        \"\\n\\n pip install h5py scipy\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path: PathType):\n        assert self.installed, self.installation_mesg\n        super().__init__()\n\n        file_path = Path(file_path) if isinstance(file_path, str) else file_path\n        if not isinstance(file_path, Path):\n            raise TypeError(f\"Expected a str or Path file_path but got '{type(file_path).__name__}'\")\n\n        file_path = file_path.resolve()  # get absolute path to this file\n        if not file_path.is_file():\n            raise ValueError(f\"Specified file path '{file_path}' is not a file.\")\n\n        self._kwargs = {\"file_path\": str(file_path.absolute())}\n\n        try:  # load old-style (up to 7.2) .mat file\n            self._data = loadmat(file_path, matlab_compatible=True)\n            self._old_style_mat = True\n        except NameError:  # loadmat not defined\n            raise ImportError(\"Old-style .mat file given, but `loadmat` is not defined.\")\n        except NotImplementedError:  # new style .mat file\n            try:\n      
          self._data = h5py.File(file_path, \"r+\")\n                self._old_style_mat = False\n            except NameError:\n                raise ImportError(\"Version 7.2 .mat file given, but you don't have h5py installed.\")\n\n    def __del__(self):\n        if not self._old_style_mat:\n            self._data.close()\n                \n    def _getfield(self, fieldname: str):\n        def _drill(d: dict, keys: deque):\n            if len(keys) == 1:\n                return d[keys.popleft()]\n            else:\n                return _drill(d[keys.popleft()], keys)\n\n        if self._old_style_mat:\n            return _drill(self._data, deque(fieldname.split(\"/\")))\n        else:\n            return self._data[fieldname][()]\n\n    @staticmethod\n    def write_dict_to_mat(mat_file_path, dict_to_write, version='7.3'):  # field must be a dict\n        assert HAVE_HDF5STORAGE, \"To use the MATSortingExtractor write_dict_to_mat function install hdf5storage: \" \\\n                                 \"\\n\\n pip install hdf5storage\\n\\n\"\n        if version == '7.3':\n            hdf5storage.write(dict_to_write, '/', mat_file_path, matlab_compatible=True, options='w')\n        elif version < '7.3' and version > '4':\n            savemat(mat_file_path, dict_to_write)\n"
  },
  {
    "path": "spikeextractors/extractors/maxwellextractors/__init__.py",
    "content": "from .maxwellextractors import MaxOneRecordingExtractor, MaxOneSortingExtractor, \\\n    MaxTwoRecordingExtractor, MaxTwoSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/maxwellextractors/maxwellextractors.py",
    "content": "from spikeextractors import RecordingExtractor, SortingExtractor\nfrom pathlib import Path\nimport numpy as np\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args, check_get_unit_spike_train\n\ntry:\n    import h5py\n    HAVE_MAX = True\nexcept ImportError:\n    HAVE_MAX = False\n\ninstallation_mesg = \"To use the MaxOneRecordingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"\n\n\nclass MaxOneRecordingExtractor(RecordingExtractor):\n    extractor_name = 'MaxOneRecording'\n    has_default_locations = True\n    has_unscaled = True\n    installed = HAVE_MAX  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = installation_mesg\n\n    def __init__(self, file_path, load_spikes=True, rec_name='rec0000'):\n        assert self.installed, self.installation_mesg\n        RecordingExtractor.__init__(self)\n        self._file_path = file_path\n        self._fs = None\n        self._positions = None\n        self._recordings = None\n        self._filehandle = None\n        self._load_spikes = load_spikes\n        self._mapping = None\n        self._rec_name = rec_name\n        self._initialize()\n        self._kwargs = {'file_path': str(Path(file_path).absolute()),\n                        'load_spikes': load_spikes}\n\n    def __del__(self):\n        self._filehandle.close()\n\n    def _initialize(self):\n        self._filehandle = h5py.File(self._file_path, 'r')\n        self._version = self._filehandle['version'][0].decode()\n\n        if int(self._version) == 20160704:\n            # old format\n            self._mapping = self._filehandle['mapping']\n            if 'lsb' in self._filehandle['settings'].keys():\n                self._lsb = self._filehandle['settings']['lsb'][0] * 1e6\n            else:\n                print(\"Couldn't read lsb. 
Setting lsb to 1\")\n                self._lsb = 1.\n            channels = np.array(self._mapping['channel'])\n            electrodes = np.array(self._mapping['electrode'])\n            # remove unused channels\n            routed_idxs = np.where(electrodes > -1)[0]\n            self._channel_ids = list(channels[routed_idxs])\n            self._electrode_ids = list(electrodes[routed_idxs])\n            self._num_channels = len(self._channel_ids)\n            self._fs = float(20000)\n            self._signals = self._filehandle['sig']\n            self._num_frames = self._signals.shape[1]\n        elif int(self._version) > 20160704:\n            # new format\n            well_name = 'well000'\n            rec_name = self._rec_name\n            settings = self._filehandle['wells'][well_name][rec_name]['settings']\n            self._mapping = settings['mapping']\n            if 'lsb' in settings.keys():\n                self._lsb = settings['lsb'][()][0] * 1e6\n            else:\n                self._lsb = 1.\n            channels = np.array(self._mapping['channel'])\n            electrodes = np.array(self._mapping['electrode'])\n            # remove unused channels\n            routed_idxs = np.where(electrodes > -1)[0]\n            self._channel_ids = list(channels[routed_idxs])\n            self._electrode_ids = list(electrodes[routed_idxs])\n            self._num_channels = len(self._channel_ids)\n            self._fs = settings['sampling'][()][0]\n            self._signals = self._filehandle['wells'][well_name][rec_name]['groups']['routed']['raw']\n            self._num_frames = self._signals.shape[1]\n        else:\n            raise Exception(\"Could not parse the MaxOne file\")\n\n        # This happens when only spikes are recorded\n        if self._num_frames == 0:\n            find_max_frame = True\n        else:\n            find_max_frame = False\n\n        for i_ch, ch, el in zip(routed_idxs, self._channel_ids, self._electrode_ids):\n            
self.set_channel_locations([self._mapping['x'][i_ch], self._mapping['y'][i_ch]], ch)\n            self.set_channel_property(ch, 'electrode', el)\n\n        # set gains\n        self.set_channel_gains(self._lsb)\n\n        if self._load_spikes:\n            if 'proc0' in self._filehandle:\n                if 'spikeTimes' in self._filehandle['proc0']:\n                    spikes = self._filehandle['proc0']['spikeTimes']\n\n                    spike_mask = [True] * len(spikes)\n                    for i, ch in enumerate(spikes['channel']):\n                        if ch not in self._channel_ids:\n                            spike_mask[i] = False\n                    spikes_channels = np.array(spikes['channel'])[spike_mask]\n\n                    if find_max_frame:\n                        self._num_frames = np.ptp(spikes['frameno'])\n\n                    # load activity as property\n                    activity_channels, counts = np.unique(spikes_channels, return_counts=True)\n                    # transform to spike rate\n                    duration = float(self._num_frames) / self._fs\n                    counts = counts.astype(float) / duration\n                    activity_channels = list(activity_channels)\n                    for ch in self.get_channel_ids():\n                        if ch in activity_channels:\n                            self.set_channel_property(ch, 'spike_rate', counts[activity_channels.index(ch)])\n                            spike_amplitudes = spikes[np.where(spikes['channel'] == ch)]['amplitude']\n                            self.set_channel_property(ch, 'spike_amplitude', np.median(spike_amplitudes))\n                        else:\n                            self.set_channel_property(ch, 'spike_rate', 0)\n                            self.set_channel_property(ch, 'spike_amplitude', 0)\n\n    def get_channel_ids(self):\n        return list(self._channel_ids)\n\n    def get_electrode_ids(self):\n        return 
list(self._electrode_ids)\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._fs\n\n    def correct_for_missing_frames(self, verbose=False):\n        \"\"\"\n        Corrects for missing frames. The correct times can be retrieved with the frame_to_time and time_to_frame\n        functions.\n        Parameters\n        ----------\n        verbose: bool\n            If True, output is verbose\n        \"\"\"\n        frame_idxs_span = self._get_frame_number(self.get_num_frames() - 1) - self._get_frame_number(0)\n        if frame_idxs_span > self.get_num_frames():\n            if verbose:\n                print(f\"Found missing frames! Correcting for it (this might take a while)\")\n\n            framenos = self._get_frame_numbers()\n            # find missing frames\n            diff_frames = np.diff(framenos)\n            missing_frames_idxs = np.where(diff_frames > 1)[0]\n\n            delays_in_frames = []\n            for mf_idx in missing_frames_idxs:\n                delays_in_frames.append(diff_frames[mf_idx])\n\n            if verbose:\n                print(f\"Found {len(delays_in_frames)} missing intervals\")\n\n            times = np.round(np.arange(self.get_num_frames()) / self.get_sampling_frequency(), 6)\n\n            for mf_idx, duration in zip(missing_frames_idxs, delays_in_frames):\n                times[mf_idx:] += np.round(duration / self.get_sampling_frequency(), 6)\n            self.set_times(times)\n        else:\n            if verbose:\n                print(\"No missing frames found\")\n\n    def _get_frame_numbers(self):\n        bitvals = self._signals[-2:, :]\n        frame_nos = np.bitwise_or(np.left_shift(bitvals[-1].astype('int64'), 16), bitvals[0])\n        return frame_nos\n\n    def _get_frame_number(self, index):\n        bitvals = self._signals[-2:, index]\n        frameno = bitvals[1] << 16 | bitvals[0]\n        return frameno\n\n    
@check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        if np.array(channel_ids).size > 1:\n            if np.any(np.diff(channel_ids) < 0):\n                sorted_channel_ids = np.sort(channel_ids)\n                sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids])\n                traces = self._signals[sorted_channel_ids, start_frame:end_frame][sorted_idx]\n            else:\n                traces = self._signals[np.array(channel_ids), start_frame:end_frame]\n        else:\n            traces = self._signals[np.array(channel_ids), start_frame:end_frame]\n        return traces\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        bitvals = self._signals[-2:, 0]\n        first_frame = bitvals[1] << 16 | bitvals[0]\n        bits = self._filehandle['bits']\n        bit_frames = bits['frameno'] - first_frame\n        bit_states = bits['bits']\n        bit_idxs = np.where((bit_frames >= start_frame) & (bit_frames < end_frame))[0]\n        ttl_frames = bit_frames[bit_idxs]\n        ttl_states = bit_states[bit_idxs]\n        ttl_states[ttl_states == 0] = -1\n        return ttl_frames, ttl_states\n\n\nclass MaxOneSortingExtractor(SortingExtractor):\n    extractor_name = 'MaxOneSorting'\n    installed = HAVE_MAX  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = installation_mesg\n\n    def __init__(self, file_path):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        self._file_path = file_path\n        self._filehandle = None\n        self._mapping = None\n        self._version = None\n        self._initialize()\n        self._kwargs = {'file_path': str(Path(file_path).absolute())}\n\n    def _initialize(self):\n        self._filehandle = h5py.File(self._file_path, 'r')\n        
self._mapping = self._filehandle['mapping']\n        self._signals = self._filehandle['sig']\n\n        bitvals = self._signals[-2:, 0]\n        self._first_frame = bitvals[1] << 16 | bitvals[0]\n\n        channels = np.array(self._mapping['channel'])\n        electrodes = np.array(self._mapping['electrode'])\n        # remove unused channels\n        routed_idxs = np.where(electrodes > -1)[0]\n        self._channel_ids = list(channels[routed_idxs])\n        self._unit_ids = list(electrodes[routed_idxs])\n        self._sampling_frequency = float(20000)\n\n        self._spiketrains = []\n        self._unit_ids = []\n\n        try:\n            spikes = self._filehandle['proc0']['spikeTimes']\n            for u in self._channel_ids:\n                spiketrain_idx = np.where(spikes['channel'] == u)[0]\n                if len(spiketrain_idx) > 0:\n                    self._unit_ids.append(u)\n                    spiketrain = spikes['frameno'][spiketrain_idx] - self._first_frame\n                    idxs_greater_0 = np.where(spiketrain >= 0)[0]\n                    self._spiketrains.append(spiketrain[idxs_greater_0])\n                    self.set_unit_spike_features(u, 'amplitude', spikes['amplitude'][spiketrain_idx][idxs_greater_0])\n        except:\n            raise AttributeError(\"Spike times information are missing from the .h5 file\")\n\n    def get_unit_ids(self):\n        return self._unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.Inf\n        unit_idx = self._unit_ids.index(unit_id)\n        spiketrain = self._spiketrains[unit_idx]\n        inds = np.where((start_frame <= spiketrain) & (spiketrain < end_frame))\n        return spiketrain[inds]\n\n\nclass 
MaxTwoRecordingExtractor(RecordingExtractor):\n    extractor_name = 'MaxTwoRecording'\n    has_default_locations = True\n    has_unscaled = True\n    installed = HAVE_MAX  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = installation_mesg\n\n    def __init__(self, file_path, well_name='well000', rec_name='rec0000', load_spikes=True):\n        assert self.installed, self.installation_mesg\n        RecordingExtractor.__init__(self)\n        self._file_path = file_path\n        self._well_name = well_name\n        self._rec_name = rec_name\n        self._fs = None\n        self._positions = None\n        self._recordings = None\n        self._filehandle = None\n        self._mapping = None\n        self._load_spikes = load_spikes\n        self._initialize()\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'well_name': well_name, 'rec_name': rec_name,\n                        'load_spikes': load_spikes}\n\n    def _initialize(self):\n        self._filehandle = h5py.File(self._file_path, 'r')\n        settings = self._filehandle['wells'][self._well_name][self._rec_name]['settings']\n        self._mapping = settings['mapping']\n        if 'lsb' in settings.keys():\n            self._lsb = settings['lsb'][()][0] * 1e6\n        else:\n            self._lsb = 1.\n        channels = np.array(self._mapping['channel'])\n        electrodes = np.array(self._mapping['electrode'])\n        # remove unused channels\n        routed_idxs = np.where(electrodes > -1)[0]\n        self._channel_ids = list(channels[routed_idxs])\n        self._electrode_ids = list(electrodes[routed_idxs])\n        self._num_channels = len(self._channel_ids)\n        self._fs = settings['sampling'][()][0]\n        self._signals = self._filehandle['wells'][self._well_name][self._rec_name]['groups']['routed']['raw']\n        self._num_frames = self._signals.shape[1]\n\n        # This happens when only spikes are recorded\n   
     if self._num_frames == 0:\n            find_max_frame = True\n        else:\n            find_max_frame = False\n\n        for i_ch, ch, el in zip(routed_idxs, self._channel_ids, self._electrode_ids):\n            self.set_channel_locations([self._mapping['x'][i_ch], self._mapping['y'][i_ch]], ch)\n            self.set_channel_property(ch, 'electrode', el)\n        # set gains\n        self.set_channel_gains(self._lsb)\n\n        if self._load_spikes:\n            if \"spikes\" in self._filehandle[\"wells\"][self._well_name][self._rec_name].keys():\n                spikes = self._filehandle[\"wells\"][self._well_name][self._rec_name][\"spikes\"]\n\n                spike_mask = [True] * len(spikes)\n                for i, ch in enumerate(spikes['channel']):\n                    if ch not in self._channel_ids:\n                        spike_mask[i] = False\n                spikes_channels = np.array(spikes['channel'])[spike_mask]\n\n                if find_max_frame:\n                    self._num_frames = np.ptp(spikes['frameno'])\n\n                # load activity as property\n                activity_channels, counts = np.unique(spikes_channels, return_counts=True)\n                # transform to spike rate\n                duration = float(self._num_frames) / self._fs\n                counts = counts.astype(float) / duration\n                activity_channels = list(activity_channels)\n                for ch in self.get_channel_ids():\n                    if ch in activity_channels:\n                        self.set_channel_property(ch, 'spike_rate', counts[activity_channels.index(ch)])\n                        spike_amplitudes = spikes[np.where(spikes['channel'] == ch)]['amplitude']\n                        self.set_channel_property(ch, 'spike_amplitude', np.median(spike_amplitudes))\n                    else:\n                        self.set_channel_property(ch, 'spike_rate', 0)\n                        self.set_channel_property(ch, 'spike_amplitude', 
0)\n\n    def get_channel_ids(self):\n        return list(self._channel_ids)\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._fs\n\n    @staticmethod\n    def get_well_names(file_path):\n        with h5py.File(file_path, 'r') as f:\n            wells = list(f[\"wells\"])\n        return wells\n\n    @staticmethod\n    def get_recording_names(file_path, well_name):\n        with h5py.File(file_path, 'r') as f:\n            assert well_name in f[\"wells\"], f\"Well name should be among: \" \\\n                                            f\"{MaxTwoRecordingExtractor.get_well_names(file_path)}\"\n            rec_names = list(f[\"wells\"][well_name].keys())\n        return rec_names\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])\n        if np.array(channel_idxs).size > 1:\n            if np.any(np.diff(channel_idxs) < 0):\n                sorted_channel_ids = np.sort(channel_idxs)\n                sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_idxs])\n                traces = self._signals[sorted_channel_ids, start_frame:end_frame][sorted_idx]\n            else:\n                traces = self._signals[np.array(channel_idxs), start_frame:end_frame]\n        else:\n            traces = self._signals[np.array(channel_idxs), start_frame:end_frame]\n        return traces\n\n\nclass MaxTwoSortingExtractor(SortingExtractor):\n    extractor_name = 'MaxTwoSorting'\n    installed = HAVE_MAX  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = installation_mesg\n\n    def __init__(self, file_path, well_name='well000', rec_name='rec0000'):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        
self._file_path = file_path\n        self._well_name = well_name\n        self._rec_name = rec_name\n        self._filehandle = None\n        self._mapping = None\n        self._version = None\n        self._initialize()\n        self._sampling_frequency = self._fs\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'well_name': well_name, 'rec_name': rec_name}\n\n    def _initialize(self):\n        self._filehandle = h5py.File(self._file_path, 'r')\n        settings = self._filehandle['wells'][self._well_name][self._rec_name]['settings']\n        self._mapping = settings['mapping']\n        if 'lsb' in settings.keys():\n            self._lsb = settings['lsb'][()] * 1e6\n        else:\n            self._lsb = 1.\n        channels = np.array(self._mapping['channel'])\n        electrodes = np.array(self._mapping['electrode'])\n        # remove unused channels\n        routed_idxs = np.where(electrodes > -1)[0]\n        self._channel_ids = list(channels[routed_idxs])\n        self._unit_ids = list(electrodes[routed_idxs])\n        self._fs = settings['sampling'][()][0]\n        self._first_frame = self._filehandle['wells'][self._well_name][self._rec_name] \\\n            ['groups']['routed']['frame_nos'][0]\n\n        self._spiketrains = []\n        self._unit_ids = []\n        try:\n            spikes = self._filehandle[\"wells\"][self._well_name][self._rec_name][\"spikes\"]\n            for u in self._channel_ids:\n                spiketrain_idx = np.where(spikes['channel'] == u)[0]\n                if len(spiketrain_idx) > 0:\n                    self._unit_ids.append(u)\n                    spiketrain = spikes['frameno'][spiketrain_idx] - self._first_frame\n                    idxs_greater_0 = np.where(spiketrain >= 0)[0]\n                    self._spiketrains.append(spiketrain[idxs_greater_0])\n                    self.set_unit_spike_features(u, 'amplitude', spikes['amplitude'][spiketrain_idx][idxs_greater_0])\n        except:\n            
raise AttributeError(\"Spike times information are missing from the .h5 file\")\n\n    def get_unit_ids(self):\n        return self._unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.Inf\n        unit_idx = self._unit_ids.index(unit_id)\n        spiketrain = self._spiketrains[unit_idx]\n        inds = np.where((start_frame <= spiketrain) & (spiketrain < end_frame))\n        return spiketrain[inds]\n\n"
  },
  {
    "path": "spikeextractors/extractors/mcsh5recordingextractor/__init__.py",
    "content": "from .mcsh5recordingextractor import MCSH5RecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/mcsh5recordingextractor/mcsh5recordingextractor.py",
    "content": "from spikeextractors import RecordingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_traces_args\n\ntry:\n    import h5py\n\n    HAVE_MCSH5 = True\nexcept ImportError:\n    HAVE_MCSH5 = False\n\n\nclass MCSH5RecordingExtractor(RecordingExtractor):\n    extractor_name = 'MCSH5Recording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = HAVE_MCSH5  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = \"To use the MCSH5RecordingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path, stream_id=0, verbose=False):\n        assert self.installed, self.installation_mesg\n        self._recording_file = file_path\n        self._verbose = verbose\n        self._available_stream_ids = self.get_available_stream_ids()\n        self.set_stream_id(stream_id)\n\n        RecordingExtractor.__init__(self)\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'stream_id': stream_id,\n                        'verbose': verbose}\n\n    def __del__(self):\n        self._rf.close()\n\n    def get_channel_ids(self):\n        return list(self._channel_ids)\n\n    def get_num_frames(self):\n        return self._nFrames\n\n    def get_sampling_frequency(self):\n        return self._samplingRate\n\n    def set_stream_id(self, stream_id):\n        assert stream_id in self._available_stream_ids, \"The specified stream ID is unavailable.\"\n        self._stream_id = stream_id\n\n        if hasattr(self, '_rf'):\n            self._rf.close()\n\n        self._rf, self._nFrames, self._samplingRate, self._nRecCh, \\\n        self._channel_ids, self._electrodeLabels, self._exponent, self._convFact \\\n            = openMCSH5File(self._recording_file, stream_id, self._verbose)\n\n    def get_stream_id(self):\n        assert hasattr(self, 
'_stream_id'), \"Stream ID has not been set yet.\"\n        return self._stream_id\n\n    def get_available_stream_ids(self):\n        if hasattr(self, '_available_stream_ids'):\n            return self._available_stream_ids\n        else:\n            rf = h5py.File(self._recording_file, 'r')\n            analog_stream_names = list(rf.require_group('/Data/Recording_0/AnalogStream').keys())\n            return list(range(len(analog_stream_names)))\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        channel_idxs = []\n        for m in channel_ids:\n            assert m in self._channel_ids, 'channel_id {} not found'.format(m)\n            channel_idxs.append(np.where(np.array(self._channel_ids) == m)[0][0])\n\n        stream = self._rf.require_group('/Data/Recording_0/AnalogStream/Stream_' + str(self._stream_id))\n        conv = self._convFact.astype(float) * (10.0 ** self._exponent)\n\n        if np.array(channel_idxs).size > 1:\n            if np.any(np.diff(channel_idxs) < 0):\n                sorted_channel_ids = np.sort(channel_idxs)\n                sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_idxs])\n                signals = stream.get('ChannelData')[sorted_channel_ids, start_frame:end_frame][sorted_idx]\n            else:\n                signals =  stream.get('ChannelData')[np.sort(channel_idxs), start_frame:end_frame]\n        else:\n            signals = stream.get('ChannelData')[np.array(channel_idxs), start_frame:end_frame]\n        if return_scaled:\n            return signals * conv\n        else:\n            return signals\n\n\ndef openMCSH5File(filename, stream_id, verbose=False):\n    \"\"\"Open an MCS hdf5 file, read and return the recording info.\"\"\"\n    rf = h5py.File(filename, 'r')\n\n    stream_name = 'Stream_' + str(stream_id)\n    analog_stream_names = list(rf.require_group('/Data/Recording_0/AnalogStream').keys())\n    
assert stream_name in analog_stream_names, \"Specified stream does not exist.\"\n\n    stream = rf.require_group('/Data/Recording_0/AnalogStream/' + stream_name)\n    data = np.array(stream.get('ChannelData'), dtype=np.int)\n    timestamps = np.array(stream.get('ChannelDataTimeStamps'))\n    info = np.array(stream.get('InfoChannel'))\n\n    Unit = info['Unit'][0]\n    Tick = info['Tick'][0] / 1e6\n    exponent = info['Exponent'][0]\n    convFact = info['ConversionFactor'][0]\n\n    nRecCh, nFrames = data.shape\n    channel_ids = info['ChannelID']\n    assert len(np.unique(channel_ids)) == len(channel_ids), 'Duplicate MCS channel IDs found'\n    electrodeLabels = info['Label']\n\n    assert timestamps[0][0] < timestamps[0][2], 'Please check the validity of \\'ChannelDataTimeStamps\\' in the stream.'\n    TimeVals = np.arange(timestamps[0][0], timestamps[0][2] + 1, 1) * Tick\n\n    assert Unit == b'V', 'Unexpected units found, expected volts, found {}'.format(Unit.decode('UTF-8'))\n    data_V = data * convFact.astype(float) * (10.0 ** (exponent))\n\n    timestep_avg = np.mean(TimeVals[1:] - TimeVals[0:-1])\n    timestep_std = np.std(TimeVals[1:] - TimeVals[0:-1])\n    timestep_min = np.min(TimeVals[1:] - TimeVals[0:-1])\n    timestep_max = np.min(TimeVals[1:] - TimeVals[0:-1])\n    assert all(np.abs(np.array(\n        (timestep_min, timestep_max)) - timestep_avg) / timestep_avg < 1e-6), 'Time steps vary by more than 1 ppm'\n    samplingRate = 1. 
/ timestep_avg\n\n    if verbose:\n        print('# MCS H5 data format')\n        print('#')\n        print('# File: {}'.format(rf.filename))\n        print('# File size: {:.2f} MB'.format(rf.id.get_filesize() / 1024 ** 2))\n        print('#')\n        for key in rf.attrs.keys():\n            print('# {}: {}'.format(key, rf.attrs[key]))\n        print('#')\n        print('# Signal range: {:.2f} to {:.2f} µV'.format(np.amin(data_V) * 1e6, np.amax(data_V) * 1e6))\n        print('# Number of channels: {}'.format(nRecCh))\n        print('# Number of frames: {}'.format(nFrames))\n        print('# Time step: {:.2f} µs ± {:.5f} % (range {} to {})'.format(timestep_avg * 1e6,\n                                                                          timestep_std / timestep_avg * 100,\n                                                                          timestep_min * 1e6, timestep_max * 1e6))\n        print('# Sampling rate: {:.2f} Hz'.format(samplingRate))\n        print('#')\n        print('# MCSH5RecordingExtractor currently only reads /Data/Recording_0/AnalogStream/Stream_0')\n\n    return rf, nFrames, samplingRate, nRecCh, channel_ids, electrodeLabels, exponent, convFact\n"
  },
  {
    "path": "spikeextractors/extractors/mdaextractors/__init__.py",
    "content": "from .mdaextractors import MdaRecordingExtractor, MdaSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/mdaextractors/mdaextractors.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extraction_tools import write_to_binary_dat_format, check_get_traces_args, \\\n    check_get_unit_spike_train\n\nimport json\nimport numpy as np\nfrom pathlib import Path\nfrom .mdaio import DiskReadMda, readmda, writemda64, MdaHeader\nimport shutil\n\n\nclass MdaRecordingExtractor(RecordingExtractor):\n    extractor_name = 'MdaRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = True  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, folder_path, raw_fname='raw.mda', params_fname='params.json', geom_fname='geom.csv'):\n        dataset_directory = Path(folder_path)\n        self._dataset_directory = dataset_directory\n        timeseries0 = dataset_directory / raw_fname\n        self._dataset_params = read_dataset_params(dataset_directory, params_fname)\n        self._sampling_frequency = self._dataset_params['samplerate'] * 1.0\n        self._timeseries_path = str(timeseries0.absolute())\n        geom0 = dataset_directory / geom_fname\n        self._geom_fname = geom0\n        self._geom = np.loadtxt(self._geom_fname, delimiter=',', ndmin=2)\n        X = DiskReadMda(self._timeseries_path)\n        if self._geom.shape[0] != X.N1():\n            raise Exception(\n                'Incompatible dimensions between geom.csv and timeseries file {} <> {}'.format(self._geom.shape[0],\n                                                                                               X.N1()))\n        self._num_channels = X.N1()\n        self._num_timepoints = X.N2()\n        RecordingExtractor.__init__(self)\n        self.set_channel_locations(self._geom)\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n    def get_channel_ids(self):\n        return 
list(range(self._num_channels))\n\n    def get_num_frames(self):\n        return self._num_timepoints\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        X = DiskReadMda(self._timeseries_path)\n        recordings = X.readChunk(i1=0, i2=start_frame, N1=X.N1(), N2=end_frame - start_frame)\n        recordings = recordings[channel_ids, :]\n        return recordings\n\n    def write_to_binary_dat_format(self, save_path, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500,\n                                   n_jobs=1, joblib_backend='loky', verbose=False):\n        \"\"\"Saves the traces of this recording extractor into binary .dat format.\n\n        Parameters\n        ----------\n        save_path: str\n            The path to the file.\n        time_axis: 0 (default) or 1\n            If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n            If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n        dtype: dtype\n            Type of the saved data. 
Default float32\n        chunk_size: None or int\n            Size of each chunk in number of frames.\n            If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n        chunk_mb: None or int\n            Chunk size in Mb (default 500Mb)\n        n_jobs: int\n            Number of jobs to use (Default 1)\n        joblib_backend: str\n            Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing')\n        verbose: bool\n            If True, output is verbose\n        \"\"\"\n        X = DiskReadMda(self._timeseries_path)\n        header_size = X._header.header_size\n        # The MDA payload is stored channel-fastest, i.e. (nb_sample, nb_channel), so a raw byte copy\n        # is only equivalent to time_axis=0; any other layout must go through write_to_binary_dat_format\n        if time_axis == 0 and (dtype is None or dtype == self.get_dtype()):\n            try:\n                with open(self._timeseries_path, 'rb') as src, open(save_path, 'wb') as dst:\n                    src.seek(header_size)\n                    shutil.copyfileobj(src, dst)\n            except Exception as e:\n                print('Error occurred while copying:', e)\n                print('Writing to binary')\n                write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype,\n                                           chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs,\n                                           joblib_backend=joblib_backend, verbose=verbose)\n        else:\n            write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype,\n                                       chunk_size=chunk_size, chunk_mb=chunk_mb, n_jobs=n_jobs,\n                                       joblib_backend=joblib_backend, verbose=verbose)\n\n    @staticmethod\n    def write_recording(recording, save_path, params=dict(), raw_fname='raw.mda', params_fname='params.json',\n                        geom_fname='geom.csv', dtype=None, chunk_size=None, n_jobs=None, chunk_mb=500, verbose=False):\n        \"\"\"\n        Writes recording to file in MDA format.\n\n        Parameters\n   
     ----------\n        recording: RecordingExtractor\n            The recording extractor to be saved\n        save_path: str or Path\n            The folder in which the Mda files are saved\n        params: dictionary\n            Dictionary with optional parameters to save metadata. Sampling frequency is appended to this dictionary.\n        raw_fname: str\n            File name of raw file (default raw.mda)\n        params_fname: str\n            File name of params file (default params.json)\n        geom_fname: str\n            File name of geom file (default geom.csv)\n        dtype: dtype\n            dtype to be used. If None dtype is same as recording traces.\n        chunk_size: None or int\n            Size of each chunk in number of frames.\n            If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n        n_jobs: int\n            Number of jobs to use (Default 1)\n        chunk_mb: None or int\n            Chunk size in Mb (default 500Mb)\n        verbose: bool\n            If True, output is verbose\n        \"\"\"\n        save_path = Path(save_path)\n        save_path.mkdir(parents=True, exist_ok=True)\n        save_file_path = save_path / raw_fname\n        parent_dir = save_path\n        channel_ids = recording.get_channel_ids()\n        num_chan = recording.get_num_channels()\n        num_frames = recording.get_num_frames()\n\n        geom = recording.get_channel_locations()\n\n        if dtype is None:\n            dtype = recording.get_dtype()\n\n        if dtype == 'float':\n            dtype = 'float32'\n        if dtype == 'int':\n            dtype = 'int16'\n\n        with save_file_path.open('wb') as f:\n            header = MdaHeader(dt0=dtype, dims0=(num_chan, num_frames))\n            header.write(f)\n            # takes care of the chunking\n            write_to_binary_dat_format(recording, file_handle=f, dtype=dtype, n_jobs=n_jobs, chunk_size=chunk_size,\n                    
                   chunk_mb=chunk_mb, verbose=verbose)\n\n        params[\"samplerate\"] = float(recording.get_sampling_frequency())\n        with (parent_dir / params_fname).open('w') as f:\n            json.dump(params, f)\n        np.savetxt(str(parent_dir / geom_fname), geom, delimiter=',')\n\n\nclass MdaSortingExtractor(SortingExtractor):\n    extractor_name = 'MdaSorting'\n    installed = True  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, file_path, sampling_frequency=None):\n\n        SortingExtractor.__init__(self)\n        self._firings_path = file_path\n        self._firings = readmda(self._firings_path)\n        self._max_channels = self._firings[0, :]\n        self._spike_times = self._firings[1, :]\n        self._labels = self._firings[2, :]\n        self._unit_ids = np.unique(self._labels).astype(int)\n        self._sampling_frequency = sampling_frequency\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'sampling_frequency': sampling_frequency}\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        inds = np.where(\n            (self._labels == unit_id) & (start_frame <= self._spike_times) & (self._spike_times < end_frame))\n        return np.rint(self._spike_times[inds]).astype(int)\n\n    @staticmethod\n    def write_sorting(sorting, save_path, write_primary_channels=False):\n        unit_ids = sorting.get_unit_ids()\n        times_list = []\n        labels_list = []\n        primary_channels_list = []\n        for unit_id in unit_ids:\n            times = sorting.get_unit_spike_train(unit_id=unit_id)\n            times_list.append(times)\n            labels_list.append(np.ones(times.shape) * unit_id)\n            if write_primary_channels:\n                if 
'max_channel' in sorting.get_unit_property_names(unit_id):\n                    primary_channels_list.append([sorting.get_unit_property(unit_id, 'max_channel')] * times.shape[0])\n                else:\n                    raise ValueError(\n                        \"Unable to write primary channels because 'max_channel' spike feature not set in unit \" + str(\n                            unit_id))\n            else:\n                primary_channels_list.append(np.zeros(times.shape))\n        all_times = _concatenate(times_list)\n        all_labels = _concatenate(labels_list)\n        all_primary_channels = _concatenate(primary_channels_list)\n        sort_inds = np.argsort(all_times)\n        all_times = all_times[sort_inds]\n        all_labels = all_labels[sort_inds]\n        all_primary_channels = all_primary_channels[sort_inds]\n        L = len(all_times)\n        firings = np.zeros((3, L))\n        firings[0, :] = all_primary_channels\n        firings[1, :] = all_times\n        firings[2, :] = all_labels\n\n        writemda64(firings, save_path)\n\n\ndef _concatenate(list):\n    if len(list) == 0:\n        return np.array([])\n    return np.concatenate(list)\n\n\ndef read_dataset_params(dsdir, params_fname):\n    fname1 = dsdir / params_fname\n    if not fname1.is_file():\n        # fname1 is a pathlib.Path: convert explicitly, str + Path raises TypeError\n        raise Exception('Dataset parameter file does not exist: ' + str(fname1))\n    with open(fname1) as f:\n        return json.load(f)\n"
  },
  {
    "path": "spikeextractors/extractors/mdaextractors/mdaio.py",
    "content": "import numpy as np\nimport struct\nimport os\nimport tempfile\nimport traceback\nfrom pathlib import Path\n\n\nclass MdaHeader:\n    def __init__(self, dt0, dims0):\n        uses64bitdims = (max(dims0) > 2e9)\n        self.uses64bitdims = uses64bitdims\n        self.dt_code = _dt_code_from_dt(dt0)\n        self.dt = dt0\n        self.num_bytes_per_entry = get_num_bytes_per_entry_from_dt(dt0)\n        self.num_dims = len(dims0)\n        self.dimprod = np.prod(dims0)\n        self.dims = dims0\n        if uses64bitdims:\n            self.header_size = 3 * 4 + self.num_dims * 8\n        else:\n            self.header_size = (3 + self.num_dims) * 4\n\n    def write(self, f):\n        H = self\n        _write_int32(f, H.dt_code)\n        _write_int32(f, H.num_bytes_per_entry)\n        if H.uses64bitdims:\n            _write_int32(f, -H.num_dims)\n            for j in range(0, H.num_dims):\n                _write_int64(f, H.dims[j])\n        else:\n            _write_int32(f, H.num_dims)\n            for j in range(0, H.num_dims):\n                _write_int32(f, H.dims[j])\n\n\ndef npy_dtype_to_string(dt):\n    str = dt.str[1:]\n    map = {\n        \"f2\": 'float16',\n        \"f4\": 'float32',\n        \"f8\": 'float64',\n        \"i1\": 'int8',\n        \"i2\": 'int16',\n        \"i4\": 'int32',\n        \"u2\": 'uint16',\n        \"u4\": 'uint32'\n    }\n    return map[str]\n\n\nclass DiskReadMda:\n    def __init__(self, path, header=None):\n        self._npy_mode = False\n        self._path = path\n        if file_extension(path) == '.npy':\n            raise Exception('DiskReadMda implementation has not been tested for npy files')\n            self._npy_mode = True\n            if header:\n                raise Exception('header not allowed in npy mode for DiskReadMda')\n        if header:\n            self._header = header\n            self._header.header_size = 0\n        else:\n            self._header = _read_header(self._path)\n\n    def 
dims(self):\n        if self._npy_mode:\n            A = np.load(self._path, mmap_mode='r')\n            return A.shape\n        return self._header.dims\n\n    def N1(self):\n        return self.dims()[0]\n\n    def N2(self):\n        return self.dims()[1]\n\n    def N3(self):\n        return self.dims()[2]\n\n    def dt(self):\n        if self._npy_mode:\n            A = np.load(self._path, mmap_mode='r')\n            return npy_dtype_to_string(A.dtype)\n        return self._header.dt\n\n    def numBytesPerEntry(self):\n        if self._npy_mode:\n            A = np.load(self._path, mmap_mode='r')\n            return A.itemsize\n        return self._header.num_bytes_per_entry\n\n    def readChunk(self, i1=-1, i2=-1, i3=-1, N1=1, N2=1, N3=1):\n        # print(\"Reading chunk {} {} {} {} {} {}\".format(i1,i2,i3,N1,N2,N3))\n        if i2 < 0:\n            if self._npy_mode:\n                A = np.load(self._path, mmap_mode='r')\n                return A[:, :, i1:i1 + N1]\n            return self._read_chunk_1d(i1, N1)\n        elif i3 < 0:\n            if N1 != self.N1():\n                print(\"Unable to support N1 {} != {}\".format(N1, self.N1()))\n                return None\n            X = self._read_chunk_1d(i1 + N1 * i2, N1 * N2)\n            if X is None:\n                print('Problem reading chunk from file: ' + self._path)\n                return None\n            if self._npy_mode:\n                A = np.load(self._path, mmap_mode='r')\n                return A[:, i2:i2 + N2]\n            return np.reshape(X, (N1, N2), order='F')\n        else:\n            if N1 != self.N1():\n                print(\"Unable to support N1 {} != {}\".format(N1, self.N1()))\n                return None\n            if N2 != self.N2():\n                print(\"Unable to support N2 {} != {}\".format(N2, self.N2()))\n                return None\n            if self._npy_mode:\n                A = np.load(self._path, mmap_mode='r')\n                return A[:, :, i3:i3 + 
N3]\n            X = self._read_chunk_1d(i1 + N1 * i2 + N1 * N2 * i3, N1 * N2 * N3)\n            return np.reshape(X, (N1, N2, N3), order='F')\n\n    def _read_chunk_1d(self, i, N):\n        offset = self._header.header_size + self._header.num_bytes_per_entry * i\n        if is_url(self._path):\n            tmp_fname = _download_bytes_to_tmpfile(self._path, offset, offset + self._header.num_bytes_per_entry * N)\n            try:\n                ret = self._read_chunk_1d_helper(tmp_fname, N, offset=0)\n            except:\n                ret = None\n            return ret\n        return self._read_chunk_1d_helper(self._path, N, offset=offset)\n\n    def _read_chunk_1d_helper(self, path0, N, *, offset):\n        f = open(path0, \"rb\")\n        try:\n            f.seek(offset)\n            ret = np.fromfile(f, dtype=self._header.dt, count=N)\n            f.close()\n            return ret\n        except Exception as e:  # catch *all* exceptions\n            print(e)\n            f.close()\n            return None\n\n\ndef is_url(path):\n    return path.startswith('http://') or path.startswith('https://')\n\n\ndef _download_bytes_to_tmpfile(url, start, end):\n    try:\n        import requests\n    except:\n        raise Exception('Unable to import module: requests')\n    headers = {\"Range\": \"bytes={}-{}\".format(start, end - 1)}\n    r = requests.get(url, headers=headers, stream=True)\n    fd, tmp_fname = tempfile.mkstemp()\n    with open(tmp_fname, 'wb') as f:\n        for chunk in r.iter_content(chunk_size=1024):\n            if chunk:\n                f.write(chunk)\n    return tmp_fname\n\n\ndef _read_header(path):\n    if is_url(path):\n        tmp_fname = _download_bytes_to_tmpfile(path, 0, 200)\n        if not tmp_fname:\n            raise Exception('Problem downloading bytes from ' + path)\n        try:\n            ret = _read_header(tmp_fname)\n        except:\n            ret = None\n        Path(tmp_fname).unlink()\n        return ret\n\n    f = 
open(path, \"rb\")\n    try:\n        dt_code = _read_int32(f)\n        num_bytes_per_entry = _read_int32(f)\n        num_dims = _read_int32(f)\n        uses64bitdims = False\n        if num_dims < 0:\n            uses64bitdims = True\n            num_dims = -num_dims\n        if num_dims < 1 or num_dims > 6:  # allow single dimension as of 12/6/17\n            print(\"Invalid number of dimensions: {}\".format(num_dims))\n            f.close()\n            return None\n        dims = []\n        dimprod = 1\n        if uses64bitdims:\n            for j in range(0, num_dims):\n                tmp0 = _read_int64(f)\n                dimprod = dimprod * tmp0\n                dims.append(tmp0)\n        else:\n            for j in range(0, num_dims):\n                tmp0 = _read_int32(f)\n                dimprod = dimprod * tmp0\n                dims.append(tmp0)\n        dt = _dt_from_dt_code(dt_code)\n        if dt is None:\n            print(\"Invalid data type code: {}\".format(dt_code))\n            f.close()\n            return None\n        H = MdaHeader(dt, dims)\n        if uses64bitdims:\n            H.uses64bitdims = True\n            H.header_size = 3 * 4 + H.num_dims * 8\n        f.close()\n        return H\n    except Exception as e:  # catch *all* exceptions\n        print(e)\n        f.close()\n        return None\n\n\ndef _dt_from_dt_code(dt_code):\n    if dt_code == -2:\n        dt = 'uint8'\n    elif dt_code == -3:\n        dt = 'float32'\n    elif dt_code == -4:\n        dt = 'int16'\n    elif dt_code == -5:\n        dt = 'int32'\n    elif dt_code == -6:\n        dt = 'uint16'\n    elif dt_code == -7:\n        dt = 'float64'\n    elif dt_code == -8:\n        dt = 'uint32'\n    else:\n        dt = None\n    return dt\n\n\ndef _dt_code_from_dt(dt):\n    if dt == 'uint8':\n        return -2\n    if dt == 'float32':\n        return -3\n    if dt == 'int16':\n        return -4\n    if dt == 'int32':\n        return -5\n    if dt == 'uint16':\n        
return -6\n    if dt == 'float64':\n        return -7\n    if dt == 'uint32':\n        return -8\n    return None\n\n\ndef get_num_bytes_per_entry_from_dt(dt):\n    if dt == 'uint8':\n        return 1\n    if dt == 'float32':\n        return 4\n    if dt == 'int16':\n        return 2\n    if dt == 'int32':\n        return 4\n    if dt == 'uint16':\n        return 2\n    if dt == 'float64':\n        return 8\n    if dt == 'uint32':\n        return 4\n    return None\n\n\ndef readmda_header(path):\n    if file_extension(path) == '.npy':\n        raise Exception('Cannot read mda header for .npy file.')\n    return _read_header(path)\n\n\ndef _write_header(path, H, rewrite=False):\n    if rewrite:\n        f = open(path, \"r+b\")\n    else:\n        f = open(path, \"wb\")\n    try:\n        _write_int32(f, H.dt_code)\n        _write_int32(f, H.num_bytes_per_entry)\n        if H.uses64bitdims:\n            _write_int32(f, -H.num_dims)\n            for j in range(0, H.num_dims):\n                _write_int64(f, H.dims[j])\n        else:\n            _write_int32(f, H.num_dims)\n            for j in range(0, H.num_dims):\n                _write_int32(f, H.dims[j])\n        f.close()\n        return True\n    except Exception as e:  # catch *all* exceptions\n        print(e)\n        f.close()\n        return False\n\n\ndef readmda(path):\n    if file_extension(path) == '.npy':\n        return readnpy(path);\n    H = _read_header(path)\n    if H is None:\n        print(\"Problem reading header of: {}\".format(path))\n        return None\n    f = open(path, \"rb\")\n    try:\n        f.seek(H.header_size)\n        # This is how I do the column-major order\n        ret = np.fromfile(f, dtype=H.dt, count=H.dimprod)\n        ret = np.reshape(ret, H.dims, order='F')\n        f.close()\n        return ret\n    except Exception as e:  # catch *all* exceptions\n        print(e)\n        f.close()\n        return None\n\n\ndef writemda32(X, fname):\n    if file_extension(fname) == 
'.npy':\n        return writenpy32(X, fname)\n    return _writemda(X, fname, 'float32')\n\n\ndef writemda64(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy64(X, fname)\n    return _writemda(X, fname, 'float64')\n\n\ndef writemda8(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy8(X, fname)\n    return _writemda(X, fname, 'uint8')\n\n\ndef writemda32i(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy32i(X, fname)\n    return _writemda(X, fname, 'int32')\n\n\ndef writemda32ui(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy32ui(X, fname)\n    return _writemda(X, fname, 'uint32')\n\n\ndef writemda16i(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy16i(X, fname)\n    return _writemda(X, fname, 'int16')\n\n\ndef writemda16ui(X, fname):\n    if file_extension(fname) == '.npy':\n        return writenpy16ui(X, fname)\n    return _writemda(X, fname, 'uint16')\n\n\ndef writemda(X, fname, *, dtype):\n    return _writemda(X, fname, dtype)\n\n\ndef _writemda(X, fname, dt):\n    num_bytes_per_entry = get_num_bytes_per_entry_from_dt(dt)\n    dt_code = _dt_code_from_dt(dt)\n    if dt_code is None:\n        print(\"Unexpected data type: {}\".format(dt))\n        return False\n\n    if type(fname) == str:\n        f = open(fname, 'wb')\n    else:\n        f = fname\n    try:\n        _write_int32(f, dt_code)\n        _write_int32(f, num_bytes_per_entry)\n        _write_int32(f, X.ndim)\n        for j in range(0, X.ndim):\n            _write_int32(f, X.shape[j])\n        # This is how I do column-major order\n        # A=np.reshape(X,X.size,order='F').astype(dt)\n        # A.tofile(f)\n\n        bytes0 = X.astype(dt).tobytes(order='F')\n        f.write(bytes0)\n\n        if type(fname) == str:\n            f.close()\n        return True\n    except Exception as e:  # catch *all* exceptions\n        traceback.print_exc()\n        print(e)\n        
if type(fname) == str:\n            f.close()\n        return False\n\n\ndef readnpy(path):\n    return np.load(path)\n\n\ndef writenpy8(X, path):\n    return _writenpy(X, path, dtype='int8')\n\n\ndef writenpy32(X, path):\n    return _writenpy(X, path, dtype='float32')\n\n\ndef writenpy64(X, path):\n    return _writenpy(X, path, dtype='float64')\n\n\ndef writenpy16i(X, path):\n    return _writenpy(X, path, dtype='int16')\n\n\ndef writenpy16ui(X, path):\n    return _writenpy(X, path, dtype='uint16')\n\n\ndef writenpy32i(X, path):\n    return _writenpy(X, path, dtype='int32')\n\n\ndef writenpy32ui(X, path):\n    return _writenpy(X, path, dtype='uint32')\n\n\ndef writenpy(X, path, *, dtype):\n    return _writenpy(X, path, dtype=dtype)\n\n\ndef _writenpy(X, path, *, dtype):\n    np.save(path, X.astype(dtype=dtype, copy=False))  # astype will always create copy if dtype does not match\n    # apparently allowing pickling is a security issue. (according to the docs) ??\n    # np.save(path,X.astype(dtype=dtype,copy=False),allow_pickle=False) # astype will always create copy if dtype does not match\n    return True\n\n\ndef appendmda(X, path):\n    if file_extension(path) == '.npy':\n        raise Exception('appendmda not yet implemented for .npy files')\n    H = _read_header(path)\n    if H is None:\n        print(\"Problem reading header of: {}\".format(path))\n        return None\n    if len(H.dims) != len(X.shape):\n        print(\"Incompatible number of dimensions in appendmda\", H.dims, X.shape)\n        return None\n    num_entries_old = np.prod(H.dims)\n    num_dims = len(H.dims)\n    # all dimensions except the last (the one being appended to) must match the existing file\n    for j in range(num_dims - 1):\n        if X.shape[j] != H.dims[j]:\n            print(\"Incompatible dimensions in appendmda\", H.dims, X.shape)\n            return None\n    H.dims[num_dims - 1] = H.dims[num_dims - 1] + X.shape[num_dims - 1]\n    try:\n        _write_header(path, H, rewrite=True)\n        f = open(path, \"r+b\")\n        f.seek(H.header_size + 
H.num_bytes_per_entry * num_entries_old)\n        A = np.reshape(X, X.size, order='F').astype(H.dt)\n        A.tofile(f)\n        f.close()\n    except Exception as e:  # catch *all* exceptions\n        print(e)\n        f.close()\n        return False\n\n\ndef file_extension(fname):\n    if type(fname) == str:\n        filename, ext = os.path.splitext(fname)\n        return ext\n    else:\n        return None\n\n\ndef _read_int32(f):\n    return struct.unpack('<i', f.read(4))[0]\n\n\ndef _read_int64(f):\n    return struct.unpack('<q', f.read(8))[0]\n\n\ndef _write_int32(f, val):\n    f.write(struct.pack('<i', val))\n\n\ndef _write_int64(f, val):\n    f.write(struct.pack('<q', val))\n\n\ndef _header_from_file(f):\n    try:\n        dt_code = _read_int32(f)\n        num_bytes_per_entry = _read_int32(f)\n        num_dims = _read_int32(f)\n        uses64bitdims = False\n        if num_dims < 0:\n            uses64bitdims = True\n            num_dims = -num_dims\n        if num_dims < 1 or num_dims > 6:  # allow single dimension as of 12/6/17\n            print(\"Invalid number of dimensions: {}\".format(num_dims))\n            return None\n        dims = []\n        dimprod = 1\n        if uses64bitdims:\n            for j in range(0, num_dims):\n                tmp0 = _read_int64(f)\n                dimprod = dimprod * tmp0\n                dims.append(tmp0)\n        else:\n            for j in range(0, num_dims):\n                tmp0 = _read_int32(f)\n                dimprod = dimprod * tmp0\n                dims.append(tmp0)\n        dt = _dt_from_dt_code(dt_code)\n        if dt is None:\n            print(\"Invalid data type code: {}\".format(dt_code))\n            return None\n        H = MdaHeader(dt, dims)\n        if uses64bitdims:\n            H.uses64bitdims = True\n            H.header_size = 3 * 4 + H.num_dims * 8\n        return H\n    except Exception as e:  # catch *all* exceptions\n        print(e)\n        return None\n"
  },
  {
    "path": "spikeextractors/extractors/mearecextractors/__init__.py",
    "content": "from .mearecextractors import MEArecRecordingExtractor, MEArecSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/mearecextractors/mearecextractors.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train\n\nimport numpy as np\nfrom pathlib import Path\nfrom packaging.version import parse\n\ntry:\n    import MEArec as mr\n    import neo\n    import quantities as pq\n    if parse(mr.__version__) >= parse('1.5.0'):\n        HAVE_MREX = True\n    else:\n        print(\"MEArec version requires an update (>=1.5). Please upgrade with 'pip install --upgrade MEArec'\")\n        HAVE_MREX = False\nexcept ImportError:\n    HAVE_MREX = False\n\n\nclass MEArecRecordingExtractor(RecordingExtractor):\n    extractor_name = 'MEArecRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = HAVE_MREX  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the MEArec extractors, install MEArec: \\n\\n pip install MEArec\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path, locs_2d=True):\n        assert self.installed, self.installed\n        self._recording_path = file_path\n        self._fs = None\n        self._positions = None\n        self._recordings = None\n        self._recgen = None\n        self._locs_2d = locs_2d\n        self._locations = None\n        self._initialize()\n        RecordingExtractor.__init__(self)\n\n        if self._locations is not None:\n            self.set_channel_locations(self._locations)\n\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'locs_2d': locs_2d}\n\n    def _initialize(self):\n        self._recgen = mr.load_recordings(recordings=self._recording_path, return_h5_objects=True, check_suffix=False,\n                                          load=['recordings', 'channel_positions'])\n        self._fs = self._recgen.info['recordings']['fs']\n        self._recordings = self._recgen.recordings\n        
self._num_frames, self._num_channels = self._recordings.shape\n        if len(np.array(self._recgen.channel_positions)) == self._num_channels:\n            self._locations = np.array(self._recgen.channel_positions)\n            if self._locs_2d:\n                if 'electrodes' in self._recgen.info.keys():\n                    if 'plane' in self._recgen.info['electrodes'].keys():\n                        probe_plane = self._recgen.info['electrodes']['plane']\n                        if probe_plane == 'xy':\n                            self._locations = self._locations[:, :2]\n                        elif probe_plane == 'yz':\n                            self._locations = self._locations[:, 1:]\n                        elif probe_plane == 'xz':\n                            self._locations = self._locations[:, [0, 2]]\n                if self._locations.shape[1] == 3:\n                    self._locations = self._locations[:, 1:]\n\n    def get_channel_ids(self):\n        return list(range(self._num_channels))\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._fs\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        if np.any(np.diff(channel_ids) < 0):\n            sorted_channel_ids = np.sort(channel_ids)\n            sorted_idx = np.array([list(sorted_channel_ids).index(ch) for ch in channel_ids])\n            recordings = self._recordings[start_frame:end_frame, sorted_channel_ids.tolist()]\n            return np.array(recordings[:, sorted_idx]).T\n        else:\n            if sorted(channel_ids) == channel_ids and np.all(np.diff(channel_ids) == 1):\n                channel_ids = slice(channel_ids[0], channel_ids[0] + len(channel_ids))\n            return np.array(self._recordings[start_frame:end_frame, channel_ids]).T\n        \n    @staticmethod\n    def write_recording(recording, save_path, 
check_suffix=True):\n        \"\"\"\n        Save recording extractor to MEArec format.\n        Parameters\n        ----------\n        recording: RecordingExtractor\n            Recording extractor object to be saved\n        save_path: str\n            .h5 or .hdf5 path\n        \"\"\"\n        assert HAVE_MREX, MEArecRecordingExtractor.installation_mesg\n        save_path = Path(save_path)\n        if save_path.is_dir():\n            print(\"The file will be saved as recording.h5 in the provided folder\")\n            save_path = save_path / 'recording.h5'\n        if (save_path.suffix == '.h5' or save_path.suffix == '.hdf5') or (not check_suffix):\n            info = {'recordings': {'fs': recording.get_sampling_frequency()}}\n            rec_dict = {'recordings': recording.get_traces().transpose()}\n            if 'location' in recording.get_shared_channel_property_names():\n                positions = recording.get_channel_locations()\n                rec_dict['channel_positions'] = positions\n            recgen = mr.RecordingGenerator(rec_dict=rec_dict, info=info)\n            mr.save_recording_generator(recgen, str(save_path), verbose=False)\n        else:\n            raise Exception(\"Provide a folder or an .h5/.hdf5 as 'save_path'\")\n\n\nclass MEArecSortingExtractor(SortingExtractor):\n    extractor_name = 'MEArecSorting'\n    installed = HAVE_MREX  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the MEArec extractors, install MEArec: \\n\\n pip install MEArec\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path):\n        assert self.installed, self.installed\n        SortingExtractor.__init__(self)\n        self._recording_path = file_path\n        self._num_units = None\n        self._spike_trains = None\n        self._unit_ids = None\n        self._fs = None\n        self._initialize()\n        self._kwargs = {'file_path': 
str(Path(file_path).absolute())}\n\n    def _initialize(self):\n        recgen = mr.load_recordings(recordings=self._recording_path, return_h5_objects=True, check_suffix=False,\n                                    load=['spiketrains'])\n        self._num_units = len(recgen.spiketrains)\n        if 'unit_id' in recgen.spiketrains[0].annotations:\n            self._unit_ids = [int(st.annotations['unit_id']) for st in recgen.spiketrains]\n        else:\n            self._unit_ids = list(range(self._num_units))\n        self._spike_trains = recgen.spiketrains\n        self._fs = recgen.info['recordings']['fs'] * pq.Hz  # fs is in Hz\n        self._sampling_frequency = recgen.info['recordings']['fs']\n\n        if 'soma_position' in self._spike_trains[0].annotations:\n            for u, st in zip(self._unit_ids, self._spike_trains):\n                self.set_unit_property(u, 'soma_location', st.annotations['soma_position'])\n\n    def get_unit_ids(self):\n        if self._unit_ids is None:\n            self._initialize()\n        return self._unit_ids\n\n    def get_num_units(self):\n        if self._num_units is None:\n            self._initialize()\n        return self._num_units\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        if self._spike_trains is None:\n            self._initialize()\n        times = (self._spike_trains[self.get_unit_ids().index(unit_id)].times.rescale('s') *\n                 self._fs.rescale('Hz')).magnitude\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return np.rint(times[inds]).astype(int)\n\n    @staticmethod\n    def write_sorting(sorting, save_path, sampling_frequency, check_suffix=True):\n        \"\"\"\n        Save sorting extractor to MEArec format.\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            Sorting extractor object to be saved\n        save_path: str\n            .h5 or .hdf5 
path\n        sampling_frequency: int\n            Sampling frequency in Hz\n\n        \"\"\"\n        assert HAVE_MREX, MEArecSortingExtractor.installation_mesg\n        save_path = Path(save_path)\n        if save_path.is_dir():\n            print(\"The file will be saved as sorting.h5 in the provided folder\")\n            save_path = save_path / 'sorting.h5'\n        if (save_path.suffix == '.h5' or save_path.suffix == '.hdf5') or (not check_suffix):\n            # create neo spike trains\n            spiketrains = []\n            for u in sorting.get_unit_ids():\n                st = neo.SpikeTrain(times=sorting.get_unit_spike_train(u) / float(sampling_frequency) * pq.s,\n                                    t_start=np.min(sorting.get_unit_spike_train(u) / float(sampling_frequency)) * pq.s,\n                                    t_stop=np.max(sorting.get_unit_spike_train(u) / float(sampling_frequency)) * pq.s)\n                st.annotate(unit_id=u)\n                spiketrains.append(st)\n\n            assert len(spiketrains) > 0, \"\"\"\n                The sorting for output contains no unit, please check the sorting.\n            \"\"\"\n\n            duration = np.max([st.t_stop.magnitude for st in spiketrains])\n            info = {'recordings': {'fs': sampling_frequency}, 'spiketrains': {'duration': duration}}\n            rec_dict = {'spiketrains': spiketrains}\n            recgen = mr.RecordingGenerator(rec_dict=rec_dict, info=info)\n            mr.save_recording_generator(recgen, str(save_path), verbose=False)\n        else:\n            raise Exception(\"Provide a folder or an .h5/.hdf5 as 'save_path'\")\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/__init__.py",
    "content": "from .plexonextractor import PlexonRecordingExtractor, PlexonSortingExtractor\nfrom .neuralynxextractor import NeuralynxRecordingExtractor, NeuralynxSortingExtractor\nfrom .mcsrawrecordingextractor import MCSRawRecordingExtractor\nfrom .blackrockextractor import BlackrockRecordingExtractor, BlackrockSortingExtractor\nfrom .axonaextractor import AxonaRecordingExtractor\nfrom .spikegadgetsextractor import SpikeGadgetsRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/axonaextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor\n\ntry:\n    import neo\n\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\n\nclass AxonaRecordingExtractor(NeoBaseRecordingExtractor):\n    extractor_name = 'AxonaRecording'\n    mode = 'file'\n    NeoRawIOClass = 'AxonaRawIO'\n\n    def __init__(self, **kargs):\n        super().__init__(**kargs)\n\n        # Read channel groups by tetrode IDs\n        self.set_channel_groups(groups=[x - 1 for x in self.neo_reader.raw_annotations[\n            'blocks'][0]['segments'][0]['signals'][0]['__array_annotations__']['tetrode_id']])\n\n        header_channels = self.neo_reader.header['signal_channels'][slice(None)]\n\n        names = header_channels['name']\n        for i, ind in enumerate(self.get_channel_ids()):\n            self.set_channel_property(channel_id=ind, property_name='name', value=names[i])\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/blackrockextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor\nfrom pathlib import Path\nfrom typing import Union, Optional\n\ntry:\n    import neo\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\n\nPathType = Union[str, Path]\n\n\nclass BlackrockRecordingExtractor(NeoBaseRecordingExtractor):\n    \"\"\"\n    The Blackrock extractor is wrapped from neo.rawio.BlackrockRawIO.\n    \n    Parameters\n    ----------\n    filename: str\n        The Blackrock file (.ns1, .ns2, .ns3, .ns4m .ns4, or .ns6)\n    block_index: None or int\n        If the underlying dataset have several blocks the index must be specified.\n    seg_index_index: None or int\n        If the underlying dataset have several segments the index must be specified.\n    \n    \"\"\"\n    extractor_name = 'BlackrockRecording'\n    mode = 'file'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'BlackrockRawIO'\n\n    def __init__(self, filename: PathType, nsx_to_load: Optional[int] = None, block_index: Optional[int] = None, \n                 seg_index: Optional[int] = None, **kwargs):\n        super().__init__(filename=filename, nsx_to_load=nsx_to_load, \n                         block_index=block_index, seg_index=seg_index, **kwargs)\n\n\nclass BlackrockSortingExtractor(NeoBaseSortingExtractor):\n    \"\"\"\n    The Blackrock extractor is wrapped from neo.rawio.BlackrockRawIO.\n\n    Parameters\n    ----------\n    filename: str\n        The Blackrock file (.nev)\n    block_index: None or int\n        If the underlying dataset have several blocks the index must be specified.\n    seg_index_index: None or int\n        If the underlying dataset have several segments the index must be specified.\n\n    \"\"\"\n    extractor_name = 'BlackrockSorting'\n    mode = 'file'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'BlackrockRawIO'\n\n    def __init__(self, filename: PathType, nsx_to_load: Optional[int] = None,\n                 block_index: Optional[int] = 
None, seg_index: Optional[int] = None, **kwargs):\n        super().__init__(filename=filename, nsx_to_load=nsx_to_load,\n                         block_index=block_index, seg_index=seg_index, **kwargs)\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/mcsrawrecordingextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor\n\ntry:\n    import neo\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n \n\nclass MCSRawRecordingExtractor(NeoBaseRecordingExtractor):\n    extractor_name='mcsrawRecoding'\n    mode='file'\n    NeoRawIOClass='RawMCSRawIO'\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/neobaseextractor.py",
    "content": "import numpy as np\nimport warnings\n\nfrom spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train\n\ntry:\n    import neo\n\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\n\nclass _NeoBaseExtractor:\n    NeoRawIOClass = None\n    installed = HAVE_NEO\n    is_writable = False\n    has_default_locations = False\n    has_unscaled = True\n    installation_mesg = \"To use the Neo extractors, install Neo: \\n\\n pip install neo\\n\\n\"\n\n    def __init__(self, block_index=None, seg_index=None, **kargs):\n        \"\"\"\n        if block_index is None then check if only one block\n        if seg_index is None then check if only one segment\n\n        \"\"\"\n        assert self.installed, self.installation_mesg\n        neoIOclass = eval('neo.rawio.' + self.NeoRawIOClass)\n\n        self.neo_reader = neoIOclass(**kargs)\n        self.neo_reader.parse_header()\n\n        if block_index is None:\n            # auto select first block\n            num_block = self.neo_reader.block_count()\n            assert num_block == 1, 'This file is multi block spikeextractors support only one segment, please provide block_index='\n            block_index = 0\n\n        if seg_index is None:\n            # auto select first segment\n            num_seg = self.neo_reader.segment_count(block_index)\n            assert num_seg == 1, 'This file is multi segment spikeextractors support only one segment, please provide seg_index='\n            seg_index = 0\n\n        self.block_index = block_index\n        self.seg_index = seg_index\n        self._kwargs = kargs\n        self._kwargs.update({'seg_index': seg_index, 'block_index': block_index})\n\n\nclass NeoBaseRecordingExtractor(RecordingExtractor, _NeoBaseExtractor):\n\n    def __init__(self, block_index=None, seg_index=None, **kargs):\n        
RecordingExtractor.__init__(self)\n        _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs)\n\n        if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'):\n            # Neo >= 0.9.0\n            channel_indexes_list = self.neo_reader.get_group_signal_channel_indexes()\n            num_streams = len(channel_indexes_list)\n            assert num_streams == 1, 'This file have several channel groups spikeextractors support only one groups'\n            self.after_v10 = False\n        elif hasattr(self.neo_reader, 'get_group_channel_indexes'):\n            # Neo < 0.9.0\n            channel_indexes_list = self.neo_reader.get_group_channel_indexes()\n            num_streams = len(channel_indexes_list)\n            self.after_v10 = False\n        elif hasattr(self.neo_reader, 'signal_streams_count'):\n            # Neo >= 0.10.0 (not release yet in march 2021)\n            num_streams = self.neo_reader.signal_streams_count()\n            self.after_v10 = True\n        else:\n            raise ValueError('Strange neo version. 
Please upgrade your neo package: pip install --upgrade neo')\n\n        assert num_streams == 1, 'This file have several signal streams spikeextractors support only one streams' \\\n                                 'Maybe you can use option to select only one stream'\n\n        # spikeextractor for units to be uV implicitly\n        # check that units are V, mV or uV\n        units = self.neo_reader.header['signal_channels']['units']\n        if not np.all(np.isin(units, ['V', 'mV', 'uV'])):\n            warnings.warn('Signal units no Volt compatible, assuming scaling as uV')\n        self.additional_gain = np.ones(units.size, dtype='float')\n        self.additional_gain[units == 'V'] = 1e6\n        self.additional_gain[units == 'mV'] = 1e3\n        self.additional_gain[units == 'uV'] = 1.\n        self.additional_gain[units == ''] = 1.\n        self.additional_gain = self.additional_gain.reshape(1, -1)\n\n        # Add channels properties\n        header_channels = self.neo_reader.header['signal_channels'][slice(None)]\n        self._neo_chan_ids = self.neo_reader.header['signal_channels']['id']\n\n        # In neo there is not guarantee that channel ids are unique.\n        # for instance Blacrock can have several times the same chan_id\n        # different sampling rate\n        # so check it\n        assert np.unique(self._neo_chan_ids).size == self._neo_chan_ids.size, 'In this format channel ids are not ' \\\n                                                                              'unique! 
Incompatible with SpikeInterface'\n\n        try:\n            channel_ids = [int(ch) for ch in self._neo_chan_ids]\n        except Exception as e:\n            warnings.warn(\"Could not parse channel ids to int: using linear channel map\")\n            channel_ids = list(np.arange(len(self._neo_chan_ids)))\n        self._channel_ids = channel_ids\n\n        gains = header_channels['gain'] * self.additional_gain[0]\n        self.set_channel_gains(gains=gains, channel_ids=self._channel_ids)\n\n        names = header_channels['name']\n        for i, ind in enumerate(self._channel_ids):\n            self.set_channel_property(channel_id=ind, property_name='name', value=names[i])\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        # in neo rawio channel can acces by names/ids/indexes\n        # there is no garranty that ids/names are unique on some formats\n        channel_idxs = [self.get_channel_ids().index(ch) for ch in channel_ids]\n        neo_chan_ids = self._neo_chan_ids[channel_idxs]\n        if self.after_v10:\n            raw_traces = self.neo_reader.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,\n                                                                i_start=start_frame, i_stop=end_frame,\n                                                                channel_indexes=None, channel_names=None,\n                                                                stream_index=0, channel_ids=neo_chan_ids)\n        else:\n            raw_traces = self.neo_reader.get_analogsignal_chunk(block_index=self.block_index, seg_index=self.seg_index,\n                                                                i_start=start_frame, i_stop=end_frame,\n                                                                channel_indexes=None, channel_names=None,\n                                                                channel_ids=neo_chan_ids)\n      
  # neo works with (samples, channels) strides\n        # so transpose to spikeextractors wolrd\n        return raw_traces.transpose()\n\n    def get_num_frames(self):\n        # channel_indexes=None means all channels\n        if self.after_v10:\n            n = self.neo_reader.get_signal_size(self.block_index, self.seg_index, stream_index=0)\n        else:\n            n = self.neo_reader.get_signal_size(self.block_index, self.seg_index, channel_indexes=None)\n        return n\n\n    def get_sampling_frequency(self):\n        # channel_indexes=None means all channels\n        if self.after_v10:\n            sf = self.neo_reader.get_signal_sampling_rate(stream_index=0)\n        else:\n            sf = self.neo_reader.get_signal_sampling_rate(channel_indexes=None)\n        return sf\n\n    def get_channel_ids(self):\n        return self._channel_ids\n\n\nclass NeoBaseSortingExtractor(SortingExtractor, _NeoBaseExtractor):\n    def __init__(self, block_index=None, seg_index=None, **kargs):\n        SortingExtractor.__init__(self)\n        _NeoBaseExtractor.__init__(self, block_index=block_index, seg_index=seg_index, **kargs)\n\n        # the sampling frequency is quite tricky because in neo\n        # spike are handle in s or ms\n        # internally many format do have have the spike time stamps\n        # at the same speed as the signal but at a higher clocks speed.\n        # here in spikeinterface we need spike index to be at the same speed\n        # that signal it do not make sens to have spikes at 50kHz sample\n        # when the sig is 10kHz.\n        # neo handle this but not spikeextractors\n\n        self._handle_sampling_frequency()\n\n    def _handle_sampling_frequency(self):\n        # bacause neo handle spike in times (s or ms) but spikeextractors in frames related to signals.\n        # In neo spikes can have diffrents sampling rate than signals so conversion from\n        #  signals frames to times is format dependent\n\n        # here the generic 
case\n        #  all channels are in the same neo group so\n        if len(self.neo_reader.header['signal_channels']['sampling_rate']) > 0:\n            self._neo_sig_sampling_rate = self.neo_reader.header['signal_channels']['sampling_rate'][0]\n            self.set_sampling_frequency(self._neo_sig_sampling_rate)\n        else:\n            warnings.warn(\"Sampling frequency not found: setting it to 30 kHz\")\n            self._sampling_frequency = 30000\n            self._neo_sig_sampling_rate = self._sampling_frequency\n\n        if hasattr(self.neo_reader, 'get_group_signal_channel_indexes'):\n            # Neo >= 0.9.0\n            if len(self.neo_reader.get_group_signal_channel_indexes()) > 0:\n                self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index,\n                                                                              channel_indexes=[0])\n            else:\n                warnings.warn(\"Start time not found: setting it to 0 s\")\n                self._neo_sig_time_start = 0\n        elif hasattr(self.neo_reader, 'get_group_channel_indexes'):\n            # Neo < 0.9.0\n            if len(self.neo_reader.get_group_channel_indexes()) > 0:\n                self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index,\n                                                                              channel_indexes=[0])\n            else:\n                warnings.warn(\"Start time not found: setting it to 0 s\")\n                self._neo_sig_time_start = 0\n        elif hasattr(self.neo_reader, 'signal_streams_count'):\n            # Neo >= 0.10.0 (not release yet in march 2021)\n            num_streams = self.neo_reader.signal_streams_count()\n            if num_streams > 0:\n                self._neo_sig_time_start = self.neo_reader.get_signal_t_start(self.block_index, self.seg_index,\n                                                                              
stream_index=0)\n            else:\n                warnings.warn(\"Start time not found: setting it to 0 s\")\n                self._neo_sig_time_start = 0\n        else:\n            raise ValueError('Strange neo version. Please upgrade your neo package: pip install --upgrade neo')\n\n        # For some IOs when there is no signals at inside the dataset this could not work\n        # in that case the extractor class must overwrite this method\n\n    def get_unit_ids(self):\n        # should be this but this is strings in neo\n        #  unit_ids = self.neo_reader.header['unit_channels']['id']\n\n        # in neo unit_ids are string so here we take unit_index\n        if 'unit_channels' in self.neo_reader.header:\n            unit_ids = np.arange(self.neo_reader.header['unit_channels'].size, dtype='int64')\n        elif 'spike_channels' in self.neo_reader.header:\n            unit_ids = np.arange(self.neo_reader.header['spike_channels'].size, dtype='int64')\n        else:\n            raise ValueError('Strange neo version. 
Please upgrade your neo package: pip install --upgrade neo')\n        return unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        # this is a string\n        #  neo_unit_id = self.neo_reader.header['unit_channels']['id'][unit_id]\n\n        # this is an int\n        unit_index = unit_id\n\n        # in neo can be a sample, or hiher sample rate or even float\n        try:\n            # version >= 0.9.0\n            spike_timestamps = self.neo_reader.get_spike_timestamps(block_index=self.block_index,\n                                                                    seg_index=self.seg_index,\n                                                                    spike_channel_index=unit_index,\n                                                                    t_start=None, t_stop=None)\n        except TypeError as e:\n            # version < 0.9.0\n            spike_timestamps = self.neo_reader.get_spike_timestamps(block_index=self.block_index,\n                                                                    seg_index=self.seg_index,\n                                                                    unit_index=unit_index, t_start=None, t_stop=None)\n\n        if start_frame is not None:\n            spike_timestamps = spike_timestamps[spike_timestamps >= start_frame]\n\n        if end_frame is not None:\n            spike_timestamps = spike_timestamps[spike_timestamps <= end_frame]\n\n        # convert to second second\n        spike_times = self.neo_reader.rescale_spike_timestamp(spike_timestamps, dtype='float64')\n\n        # convert to sample related to recording signals\n        spike_indexes = ((spike_times - self._neo_sig_time_start) * self._neo_sig_sampling_rate).astype('int64')\n        return spike_indexes\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/neuralynxextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor\n\ntry:\n    import neo\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\n\nclass NeuralynxRecordingExtractor(NeoBaseRecordingExtractor):\n    \"\"\"\n    The Neuralynx recording extractor wraps neo.rawio.NeuralynxRawIO.\n    \n    Parameters\n    ----------\n    dirname: str\n        The Neuralynx folder that contains all Neuralynx files ('nse', 'ncs', 'nev', 'ntt')\n    block_index: None or int\n        If the underlying dataset has several blocks, the index must be specified.\n    seg_index: None or int\n        If the underlying dataset has several segments, the index must be specified.\n    \n    \"\"\"\n    extractor_name = 'NeuralynxRecording'\n    mode = 'folder'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'NeuralynxRawIO'\n\n\nclass NeuralynxSortingExtractor(NeoBaseSortingExtractor):\n    \"\"\"\n    The Neuralynx sorting extractor wraps neo.rawio.NeuralynxRawIO.\n\n    Parameters\n    ----------\n    dirname: str\n        The Neuralynx folder that contains all Neuralynx files ('nse', 'ncs', 'nev', 'ntt')\n    block_index: None or int\n        If the underlying dataset has several blocks, the index must be specified.\n    seg_index: None or int\n        If the underlying dataset has several segments, the index must be specified.\n\n    \"\"\"\n    extractor_name = 'NeuralynxSorting'\n    mode = 'folder'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'NeuralynxRawIO'\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/plexonextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor\n\ntry:\n    import neo\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\nclass PlexonRecordingExtractor(NeoBaseRecordingExtractor):\n    \"\"\"\n    The Plexon recording extractor wraps neo.rawio.PlexonRawIO.\n    \n    Parameters\n    ----------\n    filename: str\n        The Plexon file ('plx')\n    block_index: None or int\n        If the underlying dataset has several blocks, the index must be specified.\n    seg_index: None or int\n        If the underlying dataset has several segments, the index must be specified.\n    \n    \"\"\"    \n    extractor_name = 'PlexonRecording'\n    mode = 'file'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'PlexonRawIO'\n\nclass PlexonSortingExtractor(NeoBaseSortingExtractor):\n    extractor_name = 'PlexonSorting'\n    mode = 'file'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'PlexonRawIO'\n"
  },
  {
    "path": "spikeextractors/extractors/neoextractors/spikegadgetsextractor.py",
    "content": "from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor\n\ntry:\n    import neo\n    HAVE_NEO = True\nexcept ImportError:\n    HAVE_NEO = False\n\nclass SpikeGadgetsRecordingExtractor(NeoBaseRecordingExtractor):\n    \"\"\"\n    The SpikeGadgets recording extractor wraps neo.rawio.SpikeGadgetsRawIO.\n    \n    Parameters\n    ----------\n    filename: str\n        The SpikeGadgets file ('rec')\n    selected_streams: str\n        The id of the stream to load. The default, 'trodes', selects the ephys channels;\n        other stream ids (e.g. ECU, ...) can also be used.\n    block_index: None or int\n        If the underlying dataset has several blocks, the index must be specified.\n    seg_index: None or int\n        If the underlying dataset has several segments, the index must be specified.\n    \"\"\"    \n    extractor_name = 'SpikeGadgetsRecording'\n    mode = 'file'\n    installed = HAVE_NEO\n    NeoRawIOClass = 'SpikeGadgetsRawIO'\n    def __init__(self, filename, selected_streams='trodes',**kwargs):\n        super().__init__(filename=filename, selected_streams=selected_streams, **kwargs)\n"
  },
  {
    "path": "spikeextractors/extractors/neuropixelsdatrecordingextractor/__init__.py",
    "content": "from .neuropixelsdatrecordingextractor import NeuropixelsDatRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/neuropixelsdatrecordingextractor/channel_positions_neuropixels.txt",
    "content": "4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 2.700000000000000000e+01 4.300000000000000000e+01 1.100000000000000000e+01 5.900000000000000000e+01 
2.700000000000000000e+01\n2.000000000000000000e+01 2.000000000000000000e+01 4.000000000000000000e+01 4.000000000000000000e+01 6.000000000000000000e+01 6.000000000000000000e+01 8.000000000000000000e+01 8.000000000000000000e+01 1.000000000000000000e+02 1.000000000000000000e+02 1.200000000000000000e+02 1.200000000000000000e+02 1.400000000000000000e+02 1.400000000000000000e+02 1.600000000000000000e+02 1.600000000000000000e+02 1.800000000000000000e+02 1.800000000000000000e+02 2.000000000000000000e+02 2.000000000000000000e+02 2.200000000000000000e+02 2.200000000000000000e+02 2.400000000000000000e+02 2.400000000000000000e+02 2.600000000000000000e+02 2.600000000000000000e+02 2.800000000000000000e+02 2.800000000000000000e+02 3.000000000000000000e+02 3.000000000000000000e+02 3.200000000000000000e+02 3.200000000000000000e+02 3.400000000000000000e+02 3.400000000000000000e+02 3.600000000000000000e+02 3.600000000000000000e+02 3.800000000000000000e+02 3.800000000000000000e+02 4.000000000000000000e+02 4.000000000000000000e+02 4.200000000000000000e+02 4.200000000000000000e+02 4.400000000000000000e+02 4.400000000000000000e+02 4.600000000000000000e+02 4.600000000000000000e+02 4.800000000000000000e+02 4.800000000000000000e+02 5.000000000000000000e+02 5.000000000000000000e+02 5.200000000000000000e+02 5.200000000000000000e+02 5.400000000000000000e+02 5.400000000000000000e+02 5.600000000000000000e+02 5.600000000000000000e+02 5.800000000000000000e+02 5.800000000000000000e+02 6.000000000000000000e+02 6.000000000000000000e+02 6.200000000000000000e+02 6.200000000000000000e+02 6.400000000000000000e+02 6.400000000000000000e+02 6.600000000000000000e+02 6.600000000000000000e+02 6.800000000000000000e+02 6.800000000000000000e+02 7.000000000000000000e+02 7.000000000000000000e+02 7.200000000000000000e+02 7.200000000000000000e+02 7.400000000000000000e+02 7.400000000000000000e+02 7.600000000000000000e+02 7.600000000000000000e+02 7.800000000000000000e+02 7.800000000000000000e+02 
8.000000000000000000e+02 8.000000000000000000e+02 8.200000000000000000e+02 8.200000000000000000e+02 8.400000000000000000e+02 8.400000000000000000e+02 8.600000000000000000e+02 8.600000000000000000e+02 8.800000000000000000e+02 8.800000000000000000e+02 9.000000000000000000e+02 9.000000000000000000e+02 9.200000000000000000e+02 9.200000000000000000e+02 9.400000000000000000e+02 9.400000000000000000e+02 9.600000000000000000e+02 9.600000000000000000e+02 9.800000000000000000e+02 9.800000000000000000e+02 1.000000000000000000e+03 1.000000000000000000e+03 1.020000000000000000e+03 1.020000000000000000e+03 1.040000000000000000e+03 1.040000000000000000e+03 1.060000000000000000e+03 1.060000000000000000e+03 1.080000000000000000e+03 1.080000000000000000e+03 1.100000000000000000e+03 1.100000000000000000e+03 1.120000000000000000e+03 1.120000000000000000e+03 1.140000000000000000e+03 1.140000000000000000e+03 1.160000000000000000e+03 1.160000000000000000e+03 1.180000000000000000e+03 1.180000000000000000e+03 1.200000000000000000e+03 1.200000000000000000e+03 1.220000000000000000e+03 1.220000000000000000e+03 1.240000000000000000e+03 1.240000000000000000e+03 1.260000000000000000e+03 1.260000000000000000e+03 1.280000000000000000e+03 1.280000000000000000e+03 1.300000000000000000e+03 1.300000000000000000e+03 1.320000000000000000e+03 1.320000000000000000e+03 1.340000000000000000e+03 1.340000000000000000e+03 1.360000000000000000e+03 1.360000000000000000e+03 1.380000000000000000e+03 1.380000000000000000e+03 1.400000000000000000e+03 1.400000000000000000e+03 1.420000000000000000e+03 1.420000000000000000e+03 1.440000000000000000e+03 1.440000000000000000e+03 1.460000000000000000e+03 1.460000000000000000e+03 1.480000000000000000e+03 1.480000000000000000e+03 1.500000000000000000e+03 1.500000000000000000e+03 1.520000000000000000e+03 1.520000000000000000e+03 1.540000000000000000e+03 1.540000000000000000e+03 1.560000000000000000e+03 1.560000000000000000e+03 1.580000000000000000e+03 1.580000000000000000e+03 
1.600000000000000000e+03 1.600000000000000000e+03 1.620000000000000000e+03 1.620000000000000000e+03 1.640000000000000000e+03 1.640000000000000000e+03 1.660000000000000000e+03 1.660000000000000000e+03 1.680000000000000000e+03 1.680000000000000000e+03 1.700000000000000000e+03 1.700000000000000000e+03 1.720000000000000000e+03 1.720000000000000000e+03 1.740000000000000000e+03 1.740000000000000000e+03 1.760000000000000000e+03 1.760000000000000000e+03 1.780000000000000000e+03 1.780000000000000000e+03 1.800000000000000000e+03 1.800000000000000000e+03 1.820000000000000000e+03 1.820000000000000000e+03 1.840000000000000000e+03 1.840000000000000000e+03 1.860000000000000000e+03 1.860000000000000000e+03 1.880000000000000000e+03 1.880000000000000000e+03 1.900000000000000000e+03 1.900000000000000000e+03 1.920000000000000000e+03 1.920000000000000000e+03 1.940000000000000000e+03 1.940000000000000000e+03 1.960000000000000000e+03 1.960000000000000000e+03 1.980000000000000000e+03 1.980000000000000000e+03 2.000000000000000000e+03 2.000000000000000000e+03 2.020000000000000000e+03 2.020000000000000000e+03 2.040000000000000000e+03 2.040000000000000000e+03 2.060000000000000000e+03 2.060000000000000000e+03 2.080000000000000000e+03 2.080000000000000000e+03 2.100000000000000000e+03 2.100000000000000000e+03 2.120000000000000000e+03 2.120000000000000000e+03 2.140000000000000000e+03 2.140000000000000000e+03 2.160000000000000000e+03 2.160000000000000000e+03 2.180000000000000000e+03 2.180000000000000000e+03 2.200000000000000000e+03 2.200000000000000000e+03 2.220000000000000000e+03 2.220000000000000000e+03 2.240000000000000000e+03 2.240000000000000000e+03 2.260000000000000000e+03 2.260000000000000000e+03 2.280000000000000000e+03 2.280000000000000000e+03 2.300000000000000000e+03 2.300000000000000000e+03 2.320000000000000000e+03 2.320000000000000000e+03 2.340000000000000000e+03 2.340000000000000000e+03 2.360000000000000000e+03 2.360000000000000000e+03 2.380000000000000000e+03 2.380000000000000000e+03 
2.400000000000000000e+03 2.400000000000000000e+03 2.420000000000000000e+03 2.420000000000000000e+03 2.440000000000000000e+03 2.440000000000000000e+03 2.460000000000000000e+03 2.460000000000000000e+03 2.480000000000000000e+03 2.480000000000000000e+03 2.500000000000000000e+03 2.500000000000000000e+03 2.520000000000000000e+03 2.520000000000000000e+03 2.540000000000000000e+03 2.540000000000000000e+03 2.560000000000000000e+03 2.560000000000000000e+03 2.580000000000000000e+03 2.580000000000000000e+03 2.600000000000000000e+03 2.600000000000000000e+03 2.620000000000000000e+03 2.620000000000000000e+03 2.640000000000000000e+03 2.640000000000000000e+03 2.660000000000000000e+03 2.660000000000000000e+03 2.680000000000000000e+03 2.680000000000000000e+03 2.700000000000000000e+03 2.700000000000000000e+03 2.720000000000000000e+03 2.720000000000000000e+03 2.740000000000000000e+03 2.740000000000000000e+03 2.760000000000000000e+03 2.760000000000000000e+03 2.780000000000000000e+03 2.780000000000000000e+03 2.800000000000000000e+03 2.800000000000000000e+03 2.820000000000000000e+03 2.820000000000000000e+03 2.840000000000000000e+03 2.840000000000000000e+03 2.860000000000000000e+03 2.860000000000000000e+03 2.880000000000000000e+03 2.880000000000000000e+03 2.900000000000000000e+03 2.900000000000000000e+03 2.920000000000000000e+03 2.920000000000000000e+03 2.940000000000000000e+03 2.940000000000000000e+03 2.960000000000000000e+03 2.960000000000000000e+03 2.980000000000000000e+03 2.980000000000000000e+03 3.000000000000000000e+03 3.000000000000000000e+03 3.020000000000000000e+03 3.020000000000000000e+03 3.040000000000000000e+03 3.040000000000000000e+03 3.060000000000000000e+03 3.060000000000000000e+03 3.080000000000000000e+03 3.080000000000000000e+03 3.100000000000000000e+03 3.100000000000000000e+03 3.120000000000000000e+03 3.120000000000000000e+03 3.140000000000000000e+03 3.140000000000000000e+03 3.160000000000000000e+03 3.160000000000000000e+03 3.180000000000000000e+03 3.180000000000000000e+03 
3.200000000000000000e+03 3.200000000000000000e+03 3.220000000000000000e+03 3.220000000000000000e+03 3.240000000000000000e+03 3.240000000000000000e+03 3.260000000000000000e+03 3.260000000000000000e+03 3.280000000000000000e+03 3.280000000000000000e+03 3.300000000000000000e+03 3.300000000000000000e+03 3.320000000000000000e+03 3.320000000000000000e+03 3.340000000000000000e+03 3.340000000000000000e+03 3.360000000000000000e+03 3.360000000000000000e+03 3.380000000000000000e+03 3.380000000000000000e+03 3.400000000000000000e+03 3.400000000000000000e+03 3.420000000000000000e+03 3.420000000000000000e+03 3.440000000000000000e+03 3.440000000000000000e+03 3.460000000000000000e+03 3.460000000000000000e+03 3.480000000000000000e+03 3.480000000000000000e+03 3.500000000000000000e+03 3.500000000000000000e+03 3.520000000000000000e+03 3.520000000000000000e+03 3.540000000000000000e+03 3.540000000000000000e+03 3.560000000000000000e+03 3.560000000000000000e+03 3.580000000000000000e+03 3.580000000000000000e+03 3.600000000000000000e+03 3.600000000000000000e+03 3.620000000000000000e+03 3.620000000000000000e+03 3.640000000000000000e+03 3.640000000000000000e+03 3.660000000000000000e+03 3.660000000000000000e+03 3.680000000000000000e+03 3.680000000000000000e+03 3.700000000000000000e+03 3.700000000000000000e+03 3.720000000000000000e+03 3.720000000000000000e+03 3.740000000000000000e+03 3.740000000000000000e+03 3.760000000000000000e+03 3.760000000000000000e+03 3.780000000000000000e+03 3.780000000000000000e+03 3.800000000000000000e+03 3.800000000000000000e+03 3.820000000000000000e+03 3.820000000000000000e+03 3.840000000000000000e+03 3.840000000000000000e+03 3.860000000000000000e+03 3.860000000000000000e+03 3.880000000000000000e+03 3.880000000000000000e+03 3.900000000000000000e+03 3.900000000000000000e+03 3.920000000000000000e+03 3.920000000000000000e+03 3.940000000000000000e+03 3.940000000000000000e+03 3.960000000000000000e+03 3.960000000000000000e+03 3.980000000000000000e+03 3.980000000000000000e+03 
4.000000000000000000e+03 4.000000000000000000e+03 4.020000000000000000e+03 4.020000000000000000e+03 4.040000000000000000e+03 4.040000000000000000e+03 4.060000000000000000e+03 4.060000000000000000e+03 4.080000000000000000e+03 4.080000000000000000e+03 4.100000000000000000e+03 4.100000000000000000e+03 4.120000000000000000e+03 4.120000000000000000e+03 4.140000000000000000e+03 4.140000000000000000e+03 4.160000000000000000e+03 4.160000000000000000e+03 4.180000000000000000e+03 4.180000000000000000e+03 4.200000000000000000e+03 4.200000000000000000e+03 4.220000000000000000e+03 4.220000000000000000e+03 4.240000000000000000e+03 4.240000000000000000e+03 4.260000000000000000e+03 4.260000000000000000e+03 4.280000000000000000e+03 4.280000000000000000e+03 4.300000000000000000e+03 4.300000000000000000e+03 4.320000000000000000e+03 4.320000000000000000e+03 4.340000000000000000e+03 4.340000000000000000e+03 4.360000000000000000e+03 4.360000000000000000e+03 4.380000000000000000e+03 4.380000000000000000e+03 4.400000000000000000e+03 4.400000000000000000e+03 4.420000000000000000e+03 4.420000000000000000e+03 4.440000000000000000e+03 4.440000000000000000e+03 4.460000000000000000e+03 4.460000000000000000e+03 4.480000000000000000e+03 4.480000000000000000e+03 4.500000000000000000e+03 4.500000000000000000e+03 4.520000000000000000e+03 4.520000000000000000e+03 4.540000000000000000e+03 4.540000000000000000e+03 4.560000000000000000e+03 4.560000000000000000e+03 4.580000000000000000e+03 4.580000000000000000e+03 4.600000000000000000e+03 4.600000000000000000e+03 4.620000000000000000e+03 4.620000000000000000e+03 4.640000000000000000e+03 4.640000000000000000e+03 4.660000000000000000e+03 4.660000000000000000e+03 4.680000000000000000e+03 4.680000000000000000e+03 4.700000000000000000e+03 4.700000000000000000e+03 4.720000000000000000e+03 4.720000000000000000e+03 4.740000000000000000e+03 4.740000000000000000e+03 4.760000000000000000e+03 4.760000000000000000e+03 4.780000000000000000e+03 4.780000000000000000e+03 
4.800000000000000000e+03 4.800000000000000000e+03 4.820000000000000000e+03 4.820000000000000000e+03 4.840000000000000000e+03 4.840000000000000000e+03 4.860000000000000000e+03 4.860000000000000000e+03 4.880000000000000000e+03 4.880000000000000000e+03 4.900000000000000000e+03 4.900000000000000000e+03 4.920000000000000000e+03 4.920000000000000000e+03 4.940000000000000000e+03 4.940000000000000000e+03 4.960000000000000000e+03 4.960000000000000000e+03 4.980000000000000000e+03 4.980000000000000000e+03 5.000000000000000000e+03 5.000000000000000000e+03 5.020000000000000000e+03 5.020000000000000000e+03 5.040000000000000000e+03 5.040000000000000000e+03 5.060000000000000000e+03 5.060000000000000000e+03 5.080000000000000000e+03 5.080000000000000000e+03 5.100000000000000000e+03 5.100000000000000000e+03 5.120000000000000000e+03 5.120000000000000000e+03 5.140000000000000000e+03 5.140000000000000000e+03 5.160000000000000000e+03 5.160000000000000000e+03 5.180000000000000000e+03 5.180000000000000000e+03 5.200000000000000000e+03 5.200000000000000000e+03 5.220000000000000000e+03 5.220000000000000000e+03 5.240000000000000000e+03 5.240000000000000000e+03 5.260000000000000000e+03 5.260000000000000000e+03 5.280000000000000000e+03 5.280000000000000000e+03 5.300000000000000000e+03 5.300000000000000000e+03 5.320000000000000000e+03 5.320000000000000000e+03 5.340000000000000000e+03 5.340000000000000000e+03 5.360000000000000000e+03 5.360000000000000000e+03 5.380000000000000000e+03 5.380000000000000000e+03 5.400000000000000000e+03 5.400000000000000000e+03 5.420000000000000000e+03 5.420000000000000000e+03 5.440000000000000000e+03 5.440000000000000000e+03 5.460000000000000000e+03 5.460000000000000000e+03 5.480000000000000000e+03 5.480000000000000000e+03 5.500000000000000000e+03 5.500000000000000000e+03 5.520000000000000000e+03 5.520000000000000000e+03 5.540000000000000000e+03 5.540000000000000000e+03 5.560000000000000000e+03 5.560000000000000000e+03 5.580000000000000000e+03 5.580000000000000000e+03 
5.600000000000000000e+03 5.600000000000000000e+03 5.620000000000000000e+03 5.620000000000000000e+03 5.640000000000000000e+03 5.640000000000000000e+03 5.660000000000000000e+03 5.660000000000000000e+03 5.680000000000000000e+03 5.680000000000000000e+03 5.700000000000000000e+03 5.700000000000000000e+03 5.720000000000000000e+03 5.720000000000000000e+03 5.740000000000000000e+03 5.740000000000000000e+03 5.760000000000000000e+03 5.760000000000000000e+03 5.780000000000000000e+03 5.780000000000000000e+03 5.800000000000000000e+03 5.800000000000000000e+03 5.820000000000000000e+03 5.820000000000000000e+03 5.840000000000000000e+03 5.840000000000000000e+03 5.860000000000000000e+03 5.860000000000000000e+03 5.880000000000000000e+03 5.880000000000000000e+03 5.900000000000000000e+03 5.900000000000000000e+03 5.920000000000000000e+03 5.920000000000000000e+03 5.940000000000000000e+03 5.940000000000000000e+03 5.960000000000000000e+03 5.960000000000000000e+03 5.980000000000000000e+03 5.980000000000000000e+03 6.000000000000000000e+03 6.000000000000000000e+03 6.020000000000000000e+03 6.020000000000000000e+03 6.040000000000000000e+03 6.040000000000000000e+03 6.060000000000000000e+03 6.060000000000000000e+03 6.080000000000000000e+03 6.080000000000000000e+03 6.100000000000000000e+03 6.100000000000000000e+03 6.120000000000000000e+03 6.120000000000000000e+03 6.140000000000000000e+03 6.140000000000000000e+03 6.160000000000000000e+03 6.160000000000000000e+03 6.180000000000000000e+03 6.180000000000000000e+03 6.200000000000000000e+03 6.200000000000000000e+03 6.220000000000000000e+03 6.220000000000000000e+03 6.240000000000000000e+03 6.240000000000000000e+03 6.260000000000000000e+03 6.260000000000000000e+03 6.280000000000000000e+03 6.280000000000000000e+03 6.300000000000000000e+03 6.300000000000000000e+03 6.320000000000000000e+03 6.320000000000000000e+03 6.340000000000000000e+03 6.340000000000000000e+03 6.360000000000000000e+03 6.360000000000000000e+03 6.380000000000000000e+03 6.380000000000000000e+03 
6.400000000000000000e+03 6.400000000000000000e+03 6.420000000000000000e+03 6.420000000000000000e+03 6.440000000000000000e+03 6.440000000000000000e+03 6.460000000000000000e+03 6.460000000000000000e+03 6.480000000000000000e+03 6.480000000000000000e+03 6.500000000000000000e+03 6.500000000000000000e+03 6.520000000000000000e+03 6.520000000000000000e+03 6.540000000000000000e+03 6.540000000000000000e+03 6.560000000000000000e+03 6.560000000000000000e+03 6.580000000000000000e+03 6.580000000000000000e+03 6.600000000000000000e+03 6.600000000000000000e+03 6.620000000000000000e+03 6.620000000000000000e+03 6.640000000000000000e+03 6.640000000000000000e+03 6.660000000000000000e+03 6.660000000000000000e+03 6.680000000000000000e+03 6.680000000000000000e+03 6.700000000000000000e+03 6.700000000000000000e+03 6.720000000000000000e+03 6.720000000000000000e+03 6.740000000000000000e+03 6.740000000000000000e+03 6.760000000000000000e+03 6.760000000000000000e+03 6.780000000000000000e+03 6.780000000000000000e+03 6.800000000000000000e+03 6.800000000000000000e+03 6.820000000000000000e+03 6.820000000000000000e+03 6.840000000000000000e+03 6.840000000000000000e+03 6.860000000000000000e+03 6.860000000000000000e+03 6.880000000000000000e+03 6.880000000000000000e+03 6.900000000000000000e+03 6.900000000000000000e+03 6.920000000000000000e+03 6.920000000000000000e+03 6.940000000000000000e+03 6.940000000000000000e+03 6.960000000000000000e+03 6.960000000000000000e+03 6.980000000000000000e+03 6.980000000000000000e+03 7.000000000000000000e+03 7.000000000000000000e+03 7.020000000000000000e+03 7.020000000000000000e+03 7.040000000000000000e+03 7.040000000000000000e+03 7.060000000000000000e+03 7.060000000000000000e+03 7.080000000000000000e+03 7.080000000000000000e+03 7.100000000000000000e+03 7.100000000000000000e+03 7.120000000000000000e+03 7.120000000000000000e+03 7.140000000000000000e+03 7.140000000000000000e+03 7.160000000000000000e+03 7.160000000000000000e+03 7.180000000000000000e+03 7.180000000000000000e+03 
7.200000000000000000e+03 7.200000000000000000e+03 7.220000000000000000e+03 7.220000000000000000e+03 7.240000000000000000e+03 7.240000000000000000e+03 7.260000000000000000e+03 7.260000000000000000e+03 7.280000000000000000e+03 7.280000000000000000e+03 7.300000000000000000e+03 7.300000000000000000e+03 7.320000000000000000e+03 7.320000000000000000e+03 7.340000000000000000e+03 7.340000000000000000e+03 7.360000000000000000e+03 7.360000000000000000e+03 7.380000000000000000e+03 7.380000000000000000e+03 7.400000000000000000e+03 7.400000000000000000e+03 7.420000000000000000e+03 7.420000000000000000e+03 7.440000000000000000e+03 7.440000000000000000e+03 7.460000000000000000e+03 7.460000000000000000e+03 7.480000000000000000e+03 7.480000000000000000e+03 7.500000000000000000e+03 7.500000000000000000e+03 7.520000000000000000e+03 7.520000000000000000e+03 7.540000000000000000e+03 7.540000000000000000e+03 7.560000000000000000e+03 7.560000000000000000e+03 7.580000000000000000e+03 7.580000000000000000e+03 7.600000000000000000e+03 7.600000000000000000e+03 7.620000000000000000e+03 7.620000000000000000e+03 7.640000000000000000e+03 7.640000000000000000e+03 7.660000000000000000e+03 7.660000000000000000e+03 7.680000000000000000e+03 7.680000000000000000e+03 7.700000000000000000e+03 7.700000000000000000e+03 7.720000000000000000e+03 7.720000000000000000e+03 7.740000000000000000e+03 7.740000000000000000e+03 7.760000000000000000e+03 7.760000000000000000e+03 7.780000000000000000e+03 7.780000000000000000e+03 7.800000000000000000e+03 7.800000000000000000e+03 7.820000000000000000e+03 7.820000000000000000e+03 7.840000000000000000e+03 7.840000000000000000e+03 7.860000000000000000e+03 7.860000000000000000e+03 7.880000000000000000e+03 7.880000000000000000e+03 7.900000000000000000e+03 7.900000000000000000e+03 7.920000000000000000e+03 7.920000000000000000e+03 7.940000000000000000e+03 7.940000000000000000e+03 7.960000000000000000e+03 7.960000000000000000e+03 7.980000000000000000e+03 7.980000000000000000e+03 
8.000000000000000000e+03 8.000000000000000000e+03 8.020000000000000000e+03 8.020000000000000000e+03 8.040000000000000000e+03 8.040000000000000000e+03 8.060000000000000000e+03 8.060000000000000000e+03 8.080000000000000000e+03 8.080000000000000000e+03 8.100000000000000000e+03 8.100000000000000000e+03 8.120000000000000000e+03 8.120000000000000000e+03 8.140000000000000000e+03 8.140000000000000000e+03 8.160000000000000000e+03 8.160000000000000000e+03 8.180000000000000000e+03 8.180000000000000000e+03 8.200000000000000000e+03 8.200000000000000000e+03 8.220000000000000000e+03 8.220000000000000000e+03 8.240000000000000000e+03 8.240000000000000000e+03 8.260000000000000000e+03 8.260000000000000000e+03 8.280000000000000000e+03 8.280000000000000000e+03 8.300000000000000000e+03 8.300000000000000000e+03 8.320000000000000000e+03 8.320000000000000000e+03 8.340000000000000000e+03 8.340000000000000000e+03 8.360000000000000000e+03 8.360000000000000000e+03 8.380000000000000000e+03 8.380000000000000000e+03 8.400000000000000000e+03 8.400000000000000000e+03 8.420000000000000000e+03 8.420000000000000000e+03 8.440000000000000000e+03 8.440000000000000000e+03 8.460000000000000000e+03 8.460000000000000000e+03 8.480000000000000000e+03 8.480000000000000000e+03 8.500000000000000000e+03 8.500000000000000000e+03 8.520000000000000000e+03 8.520000000000000000e+03 8.540000000000000000e+03 8.540000000000000000e+03 8.560000000000000000e+03 8.560000000000000000e+03 8.580000000000000000e+03 8.580000000000000000e+03 8.600000000000000000e+03 8.600000000000000000e+03 8.620000000000000000e+03 8.620000000000000000e+03 8.640000000000000000e+03 8.640000000000000000e+03 8.660000000000000000e+03 8.660000000000000000e+03 8.680000000000000000e+03 8.680000000000000000e+03 8.700000000000000000e+03 8.700000000000000000e+03 8.720000000000000000e+03 8.720000000000000000e+03 8.740000000000000000e+03 8.740000000000000000e+03 8.760000000000000000e+03 8.760000000000000000e+03 8.780000000000000000e+03 8.780000000000000000e+03 
8.800000000000000000e+03 8.800000000000000000e+03 8.820000000000000000e+03 8.820000000000000000e+03 8.840000000000000000e+03 8.840000000000000000e+03 8.860000000000000000e+03 8.860000000000000000e+03 8.880000000000000000e+03 8.880000000000000000e+03 8.900000000000000000e+03 8.900000000000000000e+03 8.920000000000000000e+03 8.920000000000000000e+03 8.940000000000000000e+03 8.940000000000000000e+03 8.960000000000000000e+03 8.960000000000000000e+03 8.980000000000000000e+03 8.980000000000000000e+03 9.000000000000000000e+03 9.000000000000000000e+03 9.020000000000000000e+03 9.020000000000000000e+03 9.040000000000000000e+03 9.040000000000000000e+03 9.060000000000000000e+03 9.060000000000000000e+03 9.080000000000000000e+03 9.080000000000000000e+03 9.100000000000000000e+03 9.100000000000000000e+03 9.120000000000000000e+03 9.120000000000000000e+03 9.140000000000000000e+03 9.140000000000000000e+03 9.160000000000000000e+03 9.160000000000000000e+03 9.180000000000000000e+03 9.180000000000000000e+03 9.200000000000000000e+03 9.200000000000000000e+03 9.220000000000000000e+03 9.220000000000000000e+03 9.240000000000000000e+03 9.240000000000000000e+03 9.260000000000000000e+03 9.260000000000000000e+03 9.280000000000000000e+03 9.280000000000000000e+03 9.300000000000000000e+03 9.300000000000000000e+03 9.320000000000000000e+03 9.320000000000000000e+03 9.340000000000000000e+03 9.340000000000000000e+03 9.360000000000000000e+03 9.360000000000000000e+03 9.380000000000000000e+03 9.380000000000000000e+03 9.400000000000000000e+03 9.400000000000000000e+03 9.420000000000000000e+03 9.420000000000000000e+03 9.440000000000000000e+03 9.440000000000000000e+03 9.460000000000000000e+03 9.460000000000000000e+03 9.480000000000000000e+03 9.480000000000000000e+03 9.500000000000000000e+03 9.500000000000000000e+03 9.520000000000000000e+03 9.520000000000000000e+03 9.540000000000000000e+03 9.540000000000000000e+03 9.560000000000000000e+03 9.560000000000000000e+03 9.580000000000000000e+03 9.580000000000000000e+03 
9.600000000000000000e+03 9.600000000000000000e+03\n"
  },
  {
    "path": "spikeextractors/extractors/neuropixelsdatrecordingextractor/neuropixelsdatrecordingextractor.py",
    "content": "from spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nimport numpy as np\nfrom pathlib import Path\nimport warnings\n\ntry:\n    import xmltodict\n    HAVE_XMLTODICT = True\nexcept ImportError:\n    HAVE_XMLTODICT = False\n\n\nclass NeuropixelsDatRecordingExtractor(BinDatRecordingExtractor):\n    \"\"\"\n    Read raw Neuropixels recordings from an Open Ephys .dat file and settings.xml.\n\n    This extractor is currently compatible with the 960-channel Neuropixels probes,\n    where a maximum of 384 channels are recorded simultaneously. The array\n    configuration can be specified by passing the settings.xml file created by\n    Open Ephys (it can be found in the directory tree with the recordings). If this\n    is not provided, the default configuration using 384 channels at the probe tip\n    will be used (and a warning is issued).\n\n    Parameters\n    ----------\n    file_path: str\n        The raw data file (usually continuous.dat).\n    settings_file: None or str\n        The settings.xml file generated by Open Ephys containing the array\n        configuration. If not provided, the default configuration using 384\n        channels at the probe tip will be used.\n    is_filtered: bool or None\n        Forwarded unchanged to BinDatRecordingExtractor.\n    verbose: bool\n        If True, print the channel configuration read from settings_file.\n\n    Raises\n    ------\n    AssertionError\n        If xmltodict is not installed, or settings_file lists no recorded AP channels.\n    \"\"\"\n    extractor_name = 'NeuropixelsDatRecording'\n    has_default_locations = True\n    has_unscaled = True\n    installed = HAVE_XMLTODICT\n    is_writable = False\n    mode = 'file'\n    installation_mesg = \"To use the NeuropixelsDat extractor, install xmltodict: \\n\\n pip install xmltodict\\n\\n\"\n\n    def __init__(self, file_path, settings_file=None, is_filtered=None, verbose=False):\n        assert self.installed, self.installation_mesg\n        source_dir = Path(__file__).parent\n        self._settings_file = settings_file\n        datfile = Path(file_path)\n        time_axis = 0\n        dtype = 'int16'\n        sampling_frequency = float(30000)\n\n        # Probe site positions shipped next to this module; sites are indexed along axis 1.\n        channel_locations = np.loadtxt(source_dir / 'channel_positions_neuropixels.txt')\n        if self._settings_file is not None:\n            with open(self._settings_file) as f:\n                settings = xmltodict.parse(f.read())['SETTINGS']\n            # NOTE(review): assumes the first PROCESSOR in the signal chain is the Neuropixels source.\n            processor = settings['SIGNALCHAIN']['PROCESSOR'][0]\n            channel_info = processor['CHANNEL_INFO']\n            # Numbers of the channels whose record flag is '1'.\n            recorded_channels = [int(c['@number']) for c in processor['CHANNEL']\n                                 if c['SELECTIONSTATE']['@record'] == '1']\n            recorded_set = set(recorded_channels)  # O(1) membership tests below\n            used_channels = []\n            used_channel_gains = []\n            for c in channel_info['CHANNEL']:\n                # Keep recorded channels whose name contains 'AP' (presumably the action-potential band).\n                if 'AP' in c['@name'] and int(c['@number']) in recorded_set:\n                    used_channels.append(int(c['@number']))\n                    used_channel_gains.append(float(c['@gain']))\n            if verbose:\n                print(f'{len(recorded_channels)} total channels found, with {len(used_channels)} recording AP')\n                print(f'Channels used:\\n{used_channels}')\n            # Without this check an empty selection fails below with an obscure IndexError.\n            assert len(used_channels) > 0, f\"No recorded AP channels found in {self._settings_file}!\"\n            numchan = len(used_channels)\n            geom = channel_locations[:, np.array(used_channels)].T\n            # NOTE(review): only the first used channel's gain is passed on; assumes all AP channels share it.\n            gain = used_channel_gains[0]\n            channels = used_channels\n        else:\n            warnings.warn(\"No information about this recording available, \"\n                          \"using a default of 384 channels at the probe tip. \"\n                          \"If the recording differs, use settings_file=settings.xml\")\n            numchan = 384\n            geom = channel_locations[:, :384].T\n            gain = None\n            channels = range(384)\n\n        BinDatRecordingExtractor.__init__(self, file_path=datfile, numchan=numchan, dtype=dtype,\n                                          sampling_frequency=sampling_frequency, gain=gain, geom=geom,\n                                          recording_channels=channels, time_axis=time_axis, is_filtered=is_filtered)\n\n        # Keys mirror the __init__ parameter names, presumably so the extractor can be re-created from _kwargs.\n        self._kwargs = {'file_path': str(Path(file_path).absolute()),\n                        'settings_file': str(Path(settings_file).absolute()) if settings_file is not None else None,\n                        'is_filtered': is_filtered}\n"
  },
  {
    "path": "spikeextractors/extractors/neuroscopeextractors/__init__.py",
    "content": "from .neuroscopeextractors import (\n    NeuroscopeRecordingExtractor,\n    NeuroscopeMultiRecordingTimeExtractor,\n    NeuroscopeSortingExtractor,\n    NeuroscopeMultiSortingExtractor,\n)\n"
  },
  {
    "path": "spikeextractors/extractors/neuroscopeextractors/neuroscopeextractors.py",
    "content": "from spikeextractors import RecordingExtractor, MultiRecordingTimeExtractor, SortingExtractor, MultiSortingExtractor\nfrom spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train, get_sub_extractors_by_property\nfrom typing import Union, Optional\nimport re\nimport warnings\n\ntry:\n    from lxml import etree as et\n\n    HAVE_LXML = True\nexcept ImportError:\n    HAVE_LXML = False\n\nPathType = Union[str, Path]\nOptionalPathType = Optional[PathType]\nDtypeType = Union[str, np.dtype, None]\n\n\ndef get_single_files(folder_path: Path, suffix: str):\n    return [\n        f for f in folder_path.iterdir() if f.is_file() and suffix in f.suffixes and not f.name.endswith(\"~\")\n        and len(f.suffixes) == 1\n    ]\n\n\ndef get_shank_files(folder_path: Path, suffix: str):\n    return [\n        f for f in folder_path.iterdir() if f.is_file() and suffix in f.suffixes\n        and re.search(r\"\\d+$\", f.name) is not None and len(f.suffixes) == 2\n    ]\n\ndef find_xml_file_path(folder_path: PathType):\n    xml_files = [f for f in folder_path.iterdir() if f.is_file() if f.suffix == \".xml\"]\n    assert any(xml_files), \"No .xml files found in the folder_path.\"\n    assert len(xml_files) == 1, \"More than one .xml file found in the folder_path! 
Specify xml_file_path.\"\n    xml_file_path = xml_files[0]\n    return xml_file_path\n\ndef handle_xml_file_path(folder_path: PathType, initial_xml_file_path: PathType):\n    if initial_xml_file_path is None:\n        xml_file_path = find_xml_file_path(folder_path=folder_path)\n    else:\n        assert Path(initial_xml_file_path).is_file(), f\".xml file ({initial_xml_file_path}) not found!\"\n        xml_file_path = initial_xml_file_path\n    return xml_file_path\n\n\nclass NeuroscopeRecordingExtractor(BinDatRecordingExtractor):\n    \"\"\"\n    Extracts raw neural recordings from binary .dat files in the neuroscope format.\n\n    The recording extractor always returns channel IDs starting from 0.\n\n    The recording data will always be returned in the shape of (num_channels,num_frames).\n\n    Parameters\n    ----------\n    file_path : str\n        Path to the .dat file to be extracted.\n    gain : float, optional\n        Numerical value that converts the native int dtype to microvolts. 
Defaults to 1.\n    xml_file_path : PathType, optional\n        Path to the .xml file referenced by this recording.\n    \"\"\"\n\n    extractor_name = \"NeuroscopeRecordingExtractor\"\n    installed = HAVE_LXML\n    has_default_locations = False\n    has_unscaled = False\n    is_writable = True\n    mode = \"file\"\n    installation_mesg = \"Please install lxml to use this extractor!\"\n\n    def __init__(self, file_path: PathType, gain: Optional[float] = None, xml_file_path: OptionalPathType = None):\n        assert self.installed, self.installation_mesg\n        file_path = Path(file_path)\n        assert file_path.is_file() and file_path.suffix in [\".dat\", \".eeg\", \".lfp\"], \\\n            \"file_path must lead to a .dat or .eeg file!\"\n\n        RecordingExtractor.__init__(self)\n        self._recording_file = file_path\n        xml_file_path = handle_xml_file_path(folder_path=Path(file_path).parent, initial_xml_file_path=xml_file_path)\n        xml_root = et.parse(str(xml_file_path)).getroot()\n        n_bits = int(xml_root.find('acquisitionSystem').find('nBits').text)\n        dtype = f\"int{n_bits}\"\n        numchan_from_file = int(xml_root.find('acquisitionSystem').find('nChannels').text)\n\n        if file_path.suffix == \".dat\":\n            sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text)\n        else:\n            sampling_frequency = float(xml_root.find('fieldPotentials').find('lfpSamplingRate').text)\n\n        BinDatRecordingExtractor.__init__(self, file_path, sampling_frequency=sampling_frequency,\n                                          dtype=dtype, numchan=numchan_from_file, gain=gain)\n        self._kwargs = dict(file_path=str(Path(file_path).absolute()), gain=gain)\n\n    @staticmethod\n    def write_recording(\n        recording: RecordingExtractor,\n        save_path: PathType,\n        dtype: DtypeType = None,\n        **write_binary_kwargs\n    ):\n        \"\"\"\n        Convert and save 
the recording extractor to Neuroscope format.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n            The recording extractor to be converted and saved.\n        save_path: str\n            Path to desired target folder. The name of the files will be the same as the final directory.\n        dtype: dtype\n            Optional. Data type to be used in writing; must be int16 or int32 (default).\n                      Will throw a warning if stored recording type from get_traces() does not match.\n        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format function\n            - chunk_size\n            - chunk_mb\n        \"\"\"\n        save_path = Path(save_path)\n        save_path.mkdir(parents=True, exist_ok=True)\n\n        if save_path.suffix == \"\":\n            recording_name = save_path.name\n        else:\n            recording_name = save_path.stem\n        xml_name = recording_name\n\n        save_xml_filepath = save_path / f\"{xml_name}.xml\"\n        recording_filepath = save_path / recording_name\n\n        # create parameters file if none exists\n        if save_xml_filepath.is_file():\n            raise FileExistsError(f\"{save_xml_filepath} already exists!\")\n\n        xml_root = et.Element('xml')\n        et.SubElement(xml_root, 'acquisitionSystem')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'nBits')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'nChannels')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate')\n\n        recording_dtype = str(recording.get_dtype())\n        int_loc = recording_dtype.find('int')\n        recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)]\n\n        valid_dtype = [\"16\", \"32\"]\n        if dtype is None:\n            if int_loc != -1 and recording_n_bits in valid_dtype:\n                n_bits = recording_n_bits\n            else:\n                print(\"Warning: Recording data type must be 
int16 or int32! Defaulting to int32.\")\n                n_bits = \"32\"\n            dtype = f\"int{n_bits}\"  # update dtype in pass to BinDatRecordingExtractor.write_recording\n        else:\n            dtype = str(dtype)  # if user passed numpy data type\n            int_loc = dtype.find('int')\n            assert int_loc != -1, \"Data type must be int16 or int32! Non-integer received.\"\n            n_bits = dtype[(int_loc + 3):(int_loc + 5)]\n            assert n_bits in valid_dtype, \"Data type must be int16 or int32!\"\n\n        xml_root.find('acquisitionSystem').find('nBits').text = n_bits\n        xml_root.find('acquisitionSystem').find('nChannels').text = str(recording.get_num_channels())\n        xml_root.find('acquisitionSystem').find('samplingRate').text = str(recording.get_sampling_frequency())\n\n        et.ElementTree(xml_root).write(str(save_xml_filepath), pretty_print=True)\n\n        recording.write_to_binary_dat_format(recording_filepath, dtype=dtype, **write_binary_kwargs)\n\n\nclass NeuroscopeMultiRecordingTimeExtractor(MultiRecordingTimeExtractor):\n    \"\"\"\n    Extracts raw neural recordings from several binary .dat files in the neuroscope format.\n\n    The recording extractor always returns channel IDs starting from 0.\n\n    The recording data will always be returned in the shape of (num_channels,num_frames).\n\n    Parameters\n    ----------\n    folder_path : PathType\n        Path to the .dat files to be extracted.\n    gain : float, optional\n        Numerical value that converts the native int dtype to microvolts. 
Defaults to 1.\n    xml_file_path : PathType, optional\n        Path to the .xml file referenced by this recording.\n    \"\"\"\n\n    extractor_name = \"NeuroscopeMultiRecordingTimeExtractor\"\n    installed = HAVE_LXML\n    is_writable = True\n    mode = \"folder\"\n    installation_mesg = \"Please install lxml to use this extractor!\"\n\n    def __init__(self, folder_path: PathType, gain: Optional[float] = None, xml_file_path: OptionalPathType = None):\n        assert self.installed, self.installation_mesg\n\n        folder_path = Path(folder_path)\n        recording_files = [x for x in folder_path.iterdir() if x.is_file() and x.suffix == \".dat\"]\n        assert any(recording_files), \"The folder_path must lead to at least one .dat file!\"\n\n        recordings = [NeuroscopeRecordingExtractor(file_path=x, gain=gain, xml_file_path=xml_file_path) for x in recording_files]\n        MultiRecordingTimeExtractor.__init__(self, recordings=recordings)\n\n        self._kwargs = dict(folder_path=str(folder_path.absolute()), gain=gain)\n\n    @staticmethod\n    def write_recording(\n        recording: Union[MultiRecordingTimeExtractor, RecordingExtractor],\n        save_path: PathType,\n        dtype: DtypeType = None,\n        **write_binary_kwargs\n    ):\n        \"\"\"\n        Convert and save the recording extractor to Neuroscope format.\n\n        Parameters\n        ----------\n        recording: MultiRecordingTimeExtractor or RecordingExtractor\n            The recording extractor to be converted and saved.\n        save_path: str\n            Path to desired target folder. The name of the files will be the same as the final directory.\n        dtype: dtype\n            Optional. 
Data type to be used in writing; must be int16 or int32 (default).\n                      Will throw a warning if stored recording type from get_traces() does not match.\n        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format function\n            - chunk_size\n            - chunk_mb\n        \"\"\"\n        save_path = Path(save_path)\n        save_path.mkdir(parents=True, exist_ok=True)\n\n        if save_path.suffix == \"\":\n            recording_name = save_path.name\n        else:\n            recording_name = save_path.stem\n\n        xml_name = recording_name\n        save_xml_filepath = save_path / f\"{xml_name}.xml\"\n        if save_xml_filepath.is_file():\n            raise FileExistsError(f\"{save_xml_filepath} already exists!\")\n\n        recording_dtype = str(recording.get_dtype())\n        int_loc = recording_dtype.find(\"int\")\n        recording_n_bits = recording_dtype[(int_loc + 3):(int_loc + 5)]\n\n        valid_int_types = [\"16\", \"32\"]\n        if dtype is None:\n            if int_loc != -1 and recording_n_bits in valid_int_types:\n                n_bits = recording_n_bits\n            else:\n                warnings.warn(\"Recording data type must be int16 or int32! Defaulting to int32.\")\n                n_bits = \"32\"\n            dtype = f\"int{n_bits}\"\n        else:\n            dtype = str(dtype)\n            int_loc = dtype.find('int')\n            assert int_loc != -1, \"Data type must be int16 or int32! 
Non-integer received.\"\n            n_bits = dtype[(int_loc + 3):(int_loc + 5)]\n            assert n_bits in valid_int_types, \"Data type must be int16 or int32!\"\n\n        xml_root = et.Element('xml')\n        et.SubElement(xml_root, 'acquisitionSystem')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'nBits')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'nChannels')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate')\n        xml_root.find('acquisitionSystem').find('nBits').text = n_bits\n        xml_root.find('acquisitionSystem').find('nChannels').text = str(recording.get_num_channels())\n        xml_root.find('acquisitionSystem').find('samplingRate').text = str(recording.get_sampling_frequency())\n        et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True)\n\n        if isinstance(recording, MultiRecordingTimeExtractor):\n            for n, record in enumerate(recording.recordings):\n                epoch_id = str(n).zfill(2)  # Neuroscope seems to zero-pad length 2\n                record.write_to_binary_dat_format(\n                    save_path=save_path / f\"{recording_name}-{epoch_id}.dat\",\n                    dtype=dtype,\n                    **write_binary_kwargs\n                )\n\n        elif isinstance(recording, RecordingExtractor):\n            recordings = [recording.get_epoch(epoch_name=epoch_name) for epoch_name in recording.get_epoch_names()]\n\n            if len(recordings) == 0:\n                recording.write_to_binary_dat_format(\n                    save_path=save_path / f\"{recording_name}.dat\",\n                    dtype=dtype,\n                    **write_binary_kwargs\n                )\n            else:\n                for n, subrecording in enumerate(recordings):\n                    epoch_id = str(n).zfill(2)  # Neuroscope seems to zero-pad length 2\n                    subrecording.write_to_binary_dat_format(\n                        
save_path=save_path / f\"{recording_name}-{epoch_id}.dat\",\n                        dtype=dtype,\n                        **write_binary_kwargs\n                    )\n\n\nclass NeuroscopeSortingExtractor(SortingExtractor):\n    \"\"\"\n    Extracts spiking information from pair of .res and .clu files.\n\n    The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer '%i') units.\n    The .clu file is a file with one more row than the .res with the first row corresponding to\n    the total number of unique ids in the file (and may exclude 0 & 1 from this count)\n    with the rest of the rows indicating which unit id the corresponding entry in the\n    .res file refers to.\n\n    In the original Neuroscope format:\n        Unit ID 0 is the cluster of unsorted spikes (noise).\n        Unit ID 1 is a cluster of multi-unit spikes.\n\n    The function defaults to returning multi-unit activity as the first index, and ignoring unsorted noise.\n    To return only the fully sorted units, set keep_mua_units=False.\n\n    The sorting extractor always returns unit IDs from 1, ..., number of chosen clusters.\n\n    Parameters\n    ----------\n    resfile_path : PathType\n        Optional. Path to a particular .res text file.\n    clufile_path : PathType\n        Optional. Path to a particular .clu text file.\n    folder_path : PathType\n        Optional. Path to the collection of .res and .clu text files. Will auto-detect format.\n    keep_mua_units : bool\n        Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True.\n    spkfile_path : PathType\n        Optional. Path to a particular .spk binary file containing waveform snippets added to the extractor as features.\n    gain : float\n        Optional. 
If passing a spkfile_path, this value converts the data type of the waveforms to units of microvolts.\n    xml_file_path : PathType, optional\n        Path to the .xml file referenced by this sorting.\n    \"\"\"\n\n    extractor_name = \"NeuroscopeSortingExtractor\"\n    installed = HAVE_LXML\n    is_writable = True\n    mode = \"custom\"\n    installation_mesg = \"Please install lxml to use this extractor!\"\n\n    def __init__(\n        self,\n        resfile_path: OptionalPathType = None,\n        clufile_path: OptionalPathType = None,\n        folder_path: OptionalPathType = None,\n        keep_mua_units: bool = True,\n        spkfile_path: OptionalPathType = None,\n        gain: Optional[float] = None,\n        xml_file_path: OptionalPathType = None,\n    ):\n        assert self.installed, self.installation_mesg\n        assert not (folder_path is None and resfile_path is None and clufile_path is None), \\\n            \"Either pass a single folder_path location, or a pair of resfile_path and clufile_path! None received.\"\n\n        if resfile_path is not None:\n            assert clufile_path is not None, \"If passing resfile_path or clufile_path, both are required!\"\n            resfile_path = Path(resfile_path)\n            clufile_path = Path(clufile_path)\n            assert resfile_path.is_file() and clufile_path.is_file(), \\\n                f\"The resfile_path ({resfile_path}) and clufile_path ({clufile_path}) must be .res and .clu files!\"\n\n            assert folder_path is None, \"Pass either a single folder_path location, \" \\\n                                        \"or a pair of resfile_path and clufile_path! 
All received.\"\n            folder_path_passed = False\n            folder_path = resfile_path.parent\n        else:\n            assert folder_path is not None, \"Either pass resfile_path and clufile_path, or folder_path!\"\n            folder_path = Path(folder_path)\n            assert folder_path.is_dir(), \"The folder_path must be a directory!\"\n\n            res_files = get_single_files(folder_path=folder_path, suffix=\".res\")\n            clu_files = get_single_files(folder_path=folder_path, suffix=\".clu\")\n\n            assert len(res_files) > 0 or len(clu_files) > 0, \\\n                \"No .res or .clu files found in the folder_path!\"\n            assert len(res_files) == 1 and len(clu_files) == 1, \\\n                \"NeuroscopeSortingExtractor expects a single pair of .res and .clu files in the folder_path. \" \\\n                \"For multiple .res and .clu files, use the NeuroscopeMultiSortingExtractor instead.\"\n\n            folder_path_passed = True  # flag for setting kwargs for proper dumping\n            resfile_path = res_files[0]\n            clufile_path = clu_files[0]\n\n        SortingExtractor.__init__(self)\n\n        res_sorting_name = resfile_path.name[:resfile_path.name.find('.res')]\n        clu_sorting_name = clufile_path.name[:clufile_path.name.find('.clu')]\n\n        assert res_sorting_name == clu_sorting_name, \"The .res and .clu files do not share the same name! 
\" \\\n                                                     f\"{res_sorting_name}  -- {clu_sorting_name}\"\n\n        xml_file_path = handle_xml_file_path(folder_path=folder_path, initial_xml_file_path=xml_file_path)\n        xml_root = et.parse(str(xml_file_path)).getroot()\n        self._sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text)\n\n        with open(resfile_path) as f:\n            res = np.array([int(line) for line in f], np.int64)\n        with open(clufile_path) as f:\n            clu = np.array([int(line) for line in f], np.int64)\n\n        n_spikes = len(res)\n        if n_spikes > 0:\n            # Extract the number of unique IDs from the first line of the clufile then remove it from the list\n            n_clu = clu[0]\n            clu = np.delete(clu, 0)\n            unique_ids = np.unique(clu)\n            assert len(unique_ids) == n_clu, (\n                \"First value of .clu file ({clufile_path}) does not match number of unique IDs!\"\n            )\n            unit_map = dict(zip(unique_ids, list(range(1, n_clu + 1))))\n\n            if 0 in unique_ids:\n                unit_map.pop(0)\n            if not keep_mua_units and 1 in unique_ids:\n                unit_map.pop(1)\n            self._unit_ids = unit_map.values()\n            self._spiketrains = []\n            for s_id in unit_map:\n                self._spiketrains.append(res[(clu == s_id).nonzero()])\n\n        if spkfile_path is not None and Path(spkfile_path).is_file():\n            n_bits = int(xml_root.find('acquisitionSystem').find('nBits').text)\n            dtype = f\"int{n_bits}\"\n            n_samples = int(xml_root.find('neuroscope').find('spikes').find('nSamples').text)\n            wf = np.moveaxis(np.memmap(spkfile_path, dtype=dtype).reshape(n_spikes, n_samples, -1), 1, -1)\n\n            for unit_id in self.get_unit_ids():\n                if gain is not None:\n                    self.set_unit_property(unit_id=unit_id, 
property_name='gain', value=gain)\n                self.set_unit_spike_features(\n                    unit_id=unit_id,\n                    feature_name='waveforms',\n                    value=wf[clu == unit_id + 1 - int(keep_mua_units), :, :]\n                )\n\n        if folder_path_passed:\n            self._kwargs = dict(\n                resfile_path=None,\n                clufile_path=None,\n                folder_path=str(folder_path.absolute()),\n                keep_mua_units=keep_mua_units,\n                gain=gain\n            )\n        else:\n            self._kwargs = dict(\n                resfile_path=str(resfile_path.absolute()),\n                clufile_path=str(clufile_path.absolute()),\n                folder_path=None,\n                keep_mua_units=keep_mua_units,\n                gain=gain\n            )\n        if spkfile_path is not None:\n            self._kwargs.update(spkfile_path=str(spkfile_path.absolute()))\n        else:\n            self._kwargs.update(spkfile_path=spkfile_path)\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    def shift_unit_ids(self, shift):\n        self._unit_ids = [x + shift for x in self._unit_ids]\n\n    def add_unit(self, unit_id, spike_times):\n        \"\"\"This function adds a new unit with the given spike times.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit_id of the unit to be added.\n        \"\"\"\n        self._unit_ids.append(unit_id)\n        self._spiketrains.append(spike_times)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        times = self._spiketrains[self.get_unit_ids().index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n\n    @staticmethod\n    def write_sorting(sorting: SortingExtractor, save_path: 
PathType):\n        # if multiple groups, use the NeuroscopeMultiSortingExtactor write function\n        if 'group' in sorting.get_shared_unit_property_names():\n            NeuroscopeMultiSortingExtractor.write_sorting(sorting, save_path)\n        else:\n            save_path.mkdir(parents=True, exist_ok=True)\n\n            if save_path.suffix == '':\n                sorting_name = save_path.name\n            else:\n                sorting_name = save_path.stem\n            xml_name = sorting_name\n            save_xml_filepath = save_path / (str(xml_name) + '.xml')\n\n            # create parameters file if none exists\n            if save_xml_filepath.is_file():\n                raise FileExistsError(f'{save_xml_filepath} already exists!')\n\n            xml_root = et.Element('xml')\n            et.SubElement(xml_root, 'acquisitionSystem')\n            et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate')\n            xml_root.find('acquisitionSystem').find('samplingRate').text = str(sorting.get_sampling_frequency())\n            et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True)\n\n            # Create and save .res and .clu files from the current sorting object\n            save_res = save_path / f'{sorting_name}.res'\n            save_clu = save_path / f'{sorting_name}.clu'\n\n            res, clu = _extract_res_clu_arrays(sorting)\n\n            np.savetxt(save_res, res, fmt='%i')\n            np.savetxt(save_clu, clu, fmt='%i')\n\n\nclass NeuroscopeMultiSortingExtractor(MultiSortingExtractor):\n    \"\"\"\n    Extracts spiking information from an arbitrary number of .res.%i and .clu.%i files in the general folder path.\n\n    The .res is a text file with a sorted list of spiketimes from all units displayed in sample (integer '%i') units.\n    The .clu file is a file with one more row than the .res with the first row corresponding to the total number of\n    unique ids in the file (and may exclude 0 & 1 from 
this count)\n    with the rest of the rows indicating which unit id the corresponding entry in the .res file refers to.\n    The group id is loaded as unit property 'group'.\n\n    In the original Neuroscope format:\n        Unit ID 0 is the cluster of unsorted spikes (noise).\n        Unit ID 1 is a cluster of multi-unit spikes.\n\n    The function defaults to returning multi-unit activity as the first index, and ignoring unsorted noise.\n    To return only the fully sorted units, set keep_mua_units=False.\n\n    The sorting extractor always returns unit IDs from 1, ..., number of chosen clusters.\n\n    Parameters\n    ----------\n    folder_path : str\n        Optional. Path to the collection of .res and .clu text files. Will auto-detect format.\n    keep_mua_units : bool\n        Optional. Whether or not to return sorted spikes from multi-unit activity. Defaults to True.\n    exclude_shanks : list\n        Optional. List of indices to ignore. The set of all possible indices is chosen by default, extracted as the\n        final integer of all the .res.%i and .clu.%i pairs.\n    load_waveforms : bool\n        Optional. If True, extracts waveform data from .spk.%i files in the path corresponding to\n        the .res.%i and .clue.%i files and sets these as unit spike features. Defaults to False.\n    gain : float\n        Optional. 
If passing a spkfile_path, this value converts the data type of the waveforms to units of microvolts.\n    xml_file_path : PathType, optional\n        Path to the .xml file referenced by this sorting.\n    \"\"\"\n\n    extractor_name = \"NeuroscopeMultiSortingExtractor\"\n    installed = HAVE_LXML\n    is_writable = True\n    mode = \"folder\"\n    installation_mesg = \"Please install lxml to use this extractor!\"\n\n    def __init__(\n        self,\n        folder_path: PathType,\n        keep_mua_units: bool = True,\n        exclude_shanks: Optional[list] = None,\n        load_waveforms: bool = False,\n        gain: Optional[float] = None,\n        xml_file_path: OptionalPathType = None,\n    ):\n        assert self.installed, self.installation_mesg\n\n        folder_path = Path(folder_path)\n\n        if exclude_shanks is not None:  # dumping checks do not like having an empty list as default\n            assert all([isinstance(x, (int, np.integer)) and x >= 0 for x in\n                        exclude_shanks]), \"Optional argument 'exclude_shanks' must contain positive integers only!\"\n            exclude_shanks_passed = True\n        else:\n            exclude_shanks = []\n            exclude_shanks_passed = False\n        xml_file_path = handle_xml_file_path(folder_path=folder_path, initial_xml_file_path=xml_file_path)\n        xml_root = et.parse(str(xml_file_path)).getroot()\n        self._sampling_frequency = float(xml_root.find('acquisitionSystem').find('samplingRate').text)\n\n        res_files = get_shank_files(folder_path=folder_path, suffix=\".res\")\n        clu_files = get_shank_files(folder_path=folder_path, suffix=\".clu\")\n\n        assert len(res_files) > 0 or len(clu_files) > 0, \"No .res or .clu files found in the folder_path!\"\n        assert len(res_files) == len(clu_files)\n\n        res_ids = [int(x.suffix[1:]) for x in res_files]\n        clu_ids = [int(x.suffix[1:]) for x in clu_files]\n        assert sorted(res_ids) == 
sorted(clu_ids), \"Unmatched .clu.%i and .res.%i files detected!\"\n        if any([x not in res_ids for x in exclude_shanks]):\n            warnings.warn(\"Detected indices in exclude_shanks that are not in the directory! These will be ignored.\")\n\n        resfile_names = [x.name[:x.name.find('.res')] for x in res_files]\n        clufile_names = [x.name[:x.name.find('.clu')] for x in clu_files]\n        assert np.all(r == c for (r, c) in zip(resfile_names, clufile_names)), \\\n            \"Some of the .res.%i and .clu.%i files do not share the same name!\"\n        sorting_name = resfile_names[0]\n\n        all_shanks_list_se = []\n        for shank_id in list(set(res_ids) - set(exclude_shanks)):\n            nse_args = dict(\n                resfile_path=folder_path / f\"{sorting_name}.res.{shank_id}\",\n                clufile_path=folder_path / f\"{sorting_name}.clu.{shank_id}\",\n                keep_mua_units=keep_mua_units,\n                xml_file_path=xml_file_path,\n            )\n\n            if load_waveforms:\n                spk_files = get_shank_files(folder_path=folder_path, suffix=\".spk\")\n                assert len(spk_files) > 0, \"No .spk files found in the folder_path, but 'write_waveforms' is True!\"\n                assert len(spk_files) == len(res_files), \"Mismatched number of .spk and .res files!\"\n\n                spk_ids = [int(x.suffix[1:]) for x in spk_files]\n                assert sorted(spk_ids) == sorted(res_ids), \"Unmatched .spk.%i and .res.%i files detected!\"\n\n                spkfile_names = [x.name[:x.name.find('.spk')] for x in spk_files]\n                assert np.all(s == r for (s, r) in zip(spkfile_names, resfile_names)), \\\n                    \"Some of the .spk.%i and .res.%i files do not share the same name!\"\n\n                nse_args.update(spkfile_path=folder_path / f\"{sorting_name}.spk.{shank_id}\", gain=gain)\n\n            all_shanks_list_se.append(NeuroscopeSortingExtractor(**nse_args))\n\n        
MultiSortingExtractor.__init__(self, sortings=all_shanks_list_se)\n\n        if exclude_shanks_passed:\n            self._kwargs = dict(\n                folder_path=str(folder_path.absolute()),\n                keep_mua_units=keep_mua_units,\n                exclude_shanks=exclude_shanks,\n                load_waveforms=load_waveforms,\n                gain=gain\n            )\n        else:\n            self._kwargs = dict(\n                folder_path=str(folder_path.absolute()),\n                keep_mua_units=keep_mua_units,\n                exclude_shanks=None,\n                load_waveforms=load_waveforms,\n                gain=gain\n            )\n\n    @staticmethod\n    def write_sorting(sorting: Union[MultiSortingExtractor, SortingExtractor], save_path: PathType):\n        save_path = Path(save_path)\n        if save_path.suffix == '':\n            sorting_name = save_path.name\n        else:\n            sorting_name = save_path.stem\n        xml_name = sorting_name\n        save_xml_filepath = save_path / (str(xml_name) + '.xml')\n\n        assert not save_path.is_file(), \"Argument 'save_path' should be a folder!\"\n        save_path.mkdir(parents=True, exist_ok=True)\n\n        if save_xml_filepath.is_file():\n            raise FileExistsError(f\"{save_xml_filepath} already exists!\")\n\n        xml_root = et.Element('xml')\n        et.SubElement(xml_root, 'acquisitionSystem')\n        et.SubElement(xml_root.find('acquisitionSystem'), 'samplingRate')\n        xml_root.find('acquisitionSystem').find('samplingRate').text = str(sorting.get_sampling_frequency())\n        et.ElementTree(xml_root).write(str(save_xml_filepath.absolute()), pretty_print=True)\n\n        if isinstance(sorting, MultiSortingExtractor):\n            counter = 1\n            for sort in sorting.sortings:\n                # Create and save .res.%i and .clu.%i files from the current sorting object\n                save_res = save_path / f\"{sorting_name}.res.{counter}\"\n           
     save_clu = save_path / f\"{sorting_name}.clu.{counter}\"\n                counter += 1\n\n                res, clu = _extract_res_clu_arrays(sort)\n\n                np.savetxt(save_res, res, fmt=\"%i\")\n                np.savetxt(save_clu, clu, fmt=\"%i\")\n\n        elif isinstance(sorting, SortingExtractor):\n            # assert units have group property\n            assert 'group' in sorting.get_shared_unit_property_names()\n            sortings, groups = get_sub_extractors_by_property(sorting, 'group', return_property_list=True)\n\n            for (sort, group) in zip(sortings, groups):\n                # Create and save .res.%i and .clu.%i files from the current sorting object\n                save_res = save_path / f\"{sorting_name}.res.{group}\"\n                save_clu = save_path / f\"{sorting_name}.clu.{group}\"\n\n                res, clu = _extract_res_clu_arrays(sort)\n\n                np.savetxt(save_res, res, fmt=\"%i\")\n                np.savetxt(save_clu, clu, fmt=\"%i\")\n\n\ndef _extract_res_clu_arrays(sorting):\n    unit_ids = sorting.get_unit_ids()\n    if len(unit_ids) > 0:\n        spiketrains = [sorting.get_unit_spike_train(u) for u in unit_ids]\n        res = np.concatenate(spiketrains).ravel()\n        clu = np.concatenate(\n            [np.repeat(i + 1, len(st)) for i, st in enumerate(spiketrains)]).ravel()  # i here counts from 0\n        res_sort = np.argsort(res)\n        res = res[res_sort]\n        clu = clu[res_sort]\n\n        unique_ids = np.unique(clu)\n        n_clu = len(unique_ids)\n        clu = np.insert(clu, 0, n_clu)  # The +1 is necessary becuase the base sorting object is from 1,...,nUnits\n    else:\n        res, clu = [], []\n\n    return res, clu\n"
  },
  {
    "path": "spikeextractors/extractors/nixioextractors/__init__.py",
    "content": "from .nixioextractors import NIXIORecordingExtractor, NIXIOSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/nixioextractors/nixioextractors.py",
    "content": "import os\nimport numpy as np\nfrom collections.abc import Iterable\nfrom pathlib import Path\ntry:\n    import nixio as nix\n    HAVE_NIXIO = True\nexcept ImportError:\n    HAVE_NIXIO = False\n\nfrom spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train\n\n\nclass NIXIORecordingExtractor(RecordingExtractor):\n    extractor_name = 'NIXIORecording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = HAVE_NIXIO\n    is_writable = True\n    installation_mesg = \"To use the NIXIORecordingExtractor install nixio: \\n\\n pip install nixio\\n\\n\"\n    mode = 'file'\n\n    def __init__(self, file_path):\n        assert self.installed, self.installation_mesg\n        file_path = str(file_path)\n        RecordingExtractor.__init__(self)\n        self._file = nix.File.open(file_path, nix.FileMode.ReadOnly)\n        self._load_properties()\n        self._kwargs = {'file_path': str(Path(file_path).absolute())}\n\n    def __del__(self):\n        self._file.close()\n\n    @property\n    def _traces(self):\n        blk = self._file.blocks[0]\n        da = blk.data_arrays[\"traces\"]\n        return da\n\n    def get_channel_ids(self):\n        da = self._traces\n        channel_dim = da.dimensions[0]\n        channel_ids = [int(chid) for chid in channel_dim.labels]\n        return channel_ids\n\n    def get_num_frames(self):\n        da = self._traces\n        return da.shape[1]\n\n    def get_sampling_frequency(self):\n        da = self._traces\n        timedim = da.dimensions[1]\n        sampling_frequency = 1./timedim.sampling_interval\n        return sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        channels = np.array([self._traces[cid] for cid in channel_ids])\n        return channels[:, 
start_frame:end_frame]\n\n    def _load_properties(self):\n        traces_md = self._traces.metadata\n        if traces_md is None:\n            # no metadata stored\n            return\n\n        for chan_md in traces_md.sections:\n            chan_id = int(chan_md.name)\n            for prop in chan_md.props:\n                values = prop.values\n                if self._file.version <= (1, 1, 0):\n                    values = [v.value for v in prop.values]\n                if len(values) == 1:\n                    values = values[0]\n                self.set_channel_property(chan_id, prop.name, values)\n\n    @staticmethod\n    def write_recording(recording, save_path, overwrite=False):\n        assert HAVE_NIXIO, NIXIORecordingExtractor.installation_mesg\n        if os.path.exists(save_path) and not overwrite:\n            raise FileExistsError(\"File exists: {}\".format(save_path))\n\n        nf = nix.File.open(save_path, nix.FileMode.Overwrite)\n        # use the file name to name the top-level block\n        fname = os.path.basename(save_path)\n        block = nf.create_block(fname, \"spikeinterface.recording\")\n        da = block.create_data_array(\"traces\", \"spikeinterface.traces\",\n                                     data=recording.get_traces())\n        da.unit = \"uV\"\n        da.label = \"voltage\"\n        labels = recording.get_channel_ids()\n        if not labels:  # channel IDs not specified; just number them\n            labels = list(range(recording.get_num_channels()))\n        chandim = da.append_set_dimension()\n        chandim.labels = labels\n        sfreq = recording.get_sampling_frequency()\n        timedim = da.append_sampled_dimension(sampling_interval=1./sfreq)\n        timedim.unit = \"s\"\n\n        # In NIX, channel properties are stored as follows\n        # Traces metadata (nix.Section)\n        #     |\n        #     |--- Channel 0 (nix.Section)\n        #     |       |\n        #     |       |---- Location (nix.Property)\n 
       #     |       |\n        #     |       |---- Other property a (nix.Property)\n        #     |       |\n        #     |       `---- Other property b (nix.Property)\n        #     |\n        #     `--- Channel 1 (nix.Section)\n        #             |\n        #             |---- Location (nix.Property)\n        #             |\n        #             |---- Other property a (nix.Property)\n        #             |\n        #             `---- Other property b (nix.Property)\n        traces_md = nf.create_section(\"traces.metadata\",\n                                      \"spikeinterface.properties\")\n        da.metadata = traces_md\n        channels = recording.get_channel_ids()\n        for chan_id in channels:\n            chan_md = traces_md.create_section(str(chan_id),\n                                               \"spikeinterface.properties\")\n            for propname in recording.get_channel_property_names(chan_id):\n                propvalue = recording.get_channel_property(chan_id, propname)\n                if nf.version <= (1, 1, 0):\n                    if isinstance(propvalue, Iterable):\n                        values = list(map(nix.Value, propvalue))\n                    else:\n                        values = nix.Value(propvalue)\n                else:\n                    values = propvalue\n                chan_md.create_property(propname, values)\n\n        nf.close()\n\n\nclass NIXIOSortingExtractor(SortingExtractor):\n    extractor_name = 'NIXIOSorting'\n    installed = HAVE_NIXIO\n    is_writable = True\n    installation_mesg = \"To use the NIXIORecordingExtractor install nixio: \\n\\n pip install nixio\\n\\n\"\n    mode = 'file'\n\n    def __init__(self, file_path):\n        assert self.installed, self.installation_mesg\n        file_path = str(file_path)\n        SortingExtractor.__init__(self)\n        self._file = nix.File.open(file_path, nix.FileMode.ReadOnly)\n        md = self._file.sections\n        if \"sampling_frequency\" in 
md:\n            sfreq = md[\"sampling_frequency\"]\n            self._sampling_frequency = sfreq\n        self._load_properties()\n        self._kwargs = {'file_path': str(Path(file_path).absolute())}\n\n    def __del__(self):\n        self._file.close()\n\n    @property\n    def _spike_das(self):\n        blk = self._file.blocks[0]\n        return blk.data_arrays\n\n    def get_unit_ids(self):\n        return [int(da.label) for da in self._spike_das]\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        name = \"spikes-{}\".format(unit_id)\n        da = self._spike_das[name]\n        if np.isfinite(end_frame):\n            return da[start_frame:end_frame]\n        else:\n            return da[start_frame:]\n\n    def _load_properties(self):\n        spikes_md = self._spike_das[0].metadata\n        if spikes_md is None:\n            # no metadata stored\n            return\n\n        for unit_md in spikes_md.sections:\n            unit_id = int(unit_md.name)\n            for prop in unit_md.props:\n                values = prop.values\n                if self._file.version <= (1, 1, 0):\n                    values = [v.value for v in prop.values]\n                if len(values) == 1:\n                    values = values[0]\n                self.set_unit_property(unit_id, prop.name, values)\n\n    @staticmethod\n    def write_sorting(sorting, save_path, overwrite=False):\n        assert HAVE_NIXIO, NIXIOSortingExtractor.installation_mesg\n\n        if os.path.exists(save_path) and not overwrite:\n            raise FileExistsError(\"File exists: {}\".format(save_path))\n\n        sfreq = sorting.get_sampling_frequency()\n        if sfreq is None:\n            unit = None\n        elif sfreq == 1:\n            unit = \"s\"\n        else:\n            unit = \"{} s\".format(1./sfreq)\n\n        nf = nix.File.open(save_path, nix.FileMode.Overwrite)\n        # use the file name to name the top-level 
block\n        fname = os.path.basename(save_path)\n        block = nf.create_block(fname, \"spikeinterface.sorting\")\n        commonmd = nf.create_section(fname, \"spikeinterface.sorting.metadata\")\n        if sfreq is not None:\n            commonmd[\"sampling_frequency\"] = sfreq\n\n        spikes_das = list()\n        for unit_id in sorting.get_unit_ids():\n            spikes = sorting.get_unit_spike_train(unit_id)\n            name = \"spikes-{}\".format(unit_id)\n            da = block.create_data_array(name, \"spikeinterface.spikes\",\n                                         data=spikes)\n            da.unit = unit\n            da.label = str(unit_id)\n            spikes_das.append(da)\n\n        spikes_md = nf.create_section(\"spikes.metadata\",\n                                      \"spikeinterface.properties\")\n        for da in spikes_das:\n            da.metadata = spikes_md\n\n        units = sorting.get_unit_ids()\n        for unit_id in units:\n            unit_md = spikes_md.create_section(str(unit_id),\n                                               \"spikeinterface.properties\")\n            for propname in sorting.get_unit_property_names(unit_id):\n                propvalue = sorting.get_unit_property(unit_id, propname)\n                if nf.version <= (1, 1, 0):\n                    if isinstance(propvalue, Iterable):\n                        values = list(map(nix.Value, propvalue))\n                    else:\n                        values = nix.Value(propvalue)\n                else:\n                    values = propvalue\n                unit_md.create_property(propname, values)\n\n        nf.close()\n"
  },
  {
    "path": "spikeextractors/extractors/npzsortingextractor/__init__.py",
    "content": "from .npzsortingextractor import NpzSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/npzsortingextractor/npzsortingextractor.py",
    "content": "from spikeextractors import SortingExtractor\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\nimport numpy as np\n\n\nclass NpzSortingExtractor(SortingExtractor):\n    \"\"\"\n    Dead simple and super light format based on the NPZ numpy format.\n    https://docs.scipy.org/doc/numpy/reference/generated/numpy.savez.html#numpy.savez\n\n    It is in fact an arichive of several .npy format.\n    All spike are store in two columns maner index+labels\n\n\n    \"\"\"\n    extractor_name = 'NpzSorting'\n    installed = True # depend only on numpy\n    installation_mesg = \"Always installed\"\n    is_writable = True\n    mode = 'file'\n\n    def __init__(self, file_path):\n        SortingExtractor.__init__(self)\n        self.npz_filename = file_path\n\n        npz = np.load(file_path)\n\n        self.unit_ids = npz['unit_ids']\n        self.spike_indexes = npz['spike_indexes']\n        self.spike_labels = npz['spike_labels']\n\n        if 'sampling_frequency' in npz:\n            self._sampling_frequency = float(npz['sampling_frequency'][0])\n        else:\n            self._sampling_frequency = None\n        self._kwargs = {'file_path': str(Path(file_path).absolute())}\n\n    def get_unit_ids(self):\n        return list(self.unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        spike_times = self.spike_indexes[self.spike_labels == unit_id]\n        if start_frame is not None:\n            spike_times = spike_times[spike_times >= start_frame]\n        if end_frame is not None:\n            spike_times = spike_times[spike_times < end_frame]\n        return spike_times.astype('int64')\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        d = {}\n        units_ids = np.array(sorting.get_unit_ids())\n        d['unit_ids'] = units_ids\n        spike_indexes = []\n        spike_labels = []\n        for unit_id in 
units_ids:\n            sp_ind = sorting.get_unit_spike_train(unit_id)\n            spike_indexes.append(sp_ind)\n            spike_labels.append(np.ones(sp_ind.size, dtype='int64')*unit_id)\n\n        # order times\n        if len(spike_indexes) > 0:\n            spike_indexes = np.concatenate(spike_indexes)\n            spike_labels = np.concatenate(spike_labels)\n            order = np.argsort(spike_indexes)\n            spike_indexes = spike_indexes[order]\n            spike_labels = spike_labels[order]\n        else:\n            spike_indexes = np.array([], dtype='int64')\n            spike_labels = np.array([], dtype='int64')\n\n        d['spike_indexes'] = spike_indexes\n        d['spike_labels'] = spike_labels\n\n        if sorting.get_sampling_frequency() is not None:\n            d['sampling_frequency'] = np.array([sorting.get_sampling_frequency()], dtype='float64')\n\n        np.savez(save_path, **d)\n"
  },
  {
    "path": "spikeextractors/extractors/numpyextractors/__init__.py",
    "content": "from .numpyextractors import NumpyRecordingExtractor, NumpySortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/numpyextractors/numpyextractors.py",
    "content": "from spikeextractors import RecordingExtractor\nfrom spikeextractors import SortingExtractor\nfrom pathlib import Path\nimport numpy as np\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train, check_get_ttl_args\n\n\"\"\"\nThe NumpyExtractors can be constructed and used to encapsulate custom file formats and data structures which\ncontain information about recordings or sorting results. NumpyExtractors are instantiated in-memory and function\nlike any other Recording/SortingExtractor.\n\"\"\"\n\nclass NumpyRecordingExtractor(RecordingExtractor):\n    extractor_name = 'NumpyRecording'\n    is_writable = True\n    has_default_locations = False\n    has_unscaled = False\n\n    def __init__(self, timeseries, sampling_frequency, geom=None):\n        RecordingExtractor.__init__(self)\n        if isinstance(timeseries, str):\n            if Path(timeseries).is_file():\n                assert Path(timeseries).suffix == '.npy', \"'timeseries' file is not a numpy file (.npy)\"\n                self.is_dumpable = True\n                self._timeseries = np.load(timeseries)\n                self._kwargs = {'timeseries': str(Path(timeseries).absolute()),\n                                'sampling_frequency': sampling_frequency, 'geom': geom}\n            else:\n                raise ValueError(\"'timeeseries' is does not exist\")\n        elif isinstance(timeseries, np.ndarray):\n            self.is_dumpable = False\n            self._timeseries = timeseries\n            self._kwargs = {'timeseries': timeseries,\n                            'sampling_frequency': sampling_frequency, 'geom': geom}\n        else:\n            raise TypeError(\"'timeseries' can be a str or a numpy array\")\n        self._sampling_frequency = float(sampling_frequency)\n        self._geom = geom\n        if geom is not None:\n            self.set_channel_locations(self._geom)\n\n        self._ttl_frames = None\n        self._ttl_states = 
None\n\n    def set_ttls(self, ttl_frames, ttl_states=None):\n        self._ttl_frames = ttl_frames.astype('int64')\n        if ttl_states is not None:\n            self._ttl_states = ttl_states.astype('int64')\n        else:\n            self._ttl_states = np.ones_like(ttl_frames, dtype='int64')\n\n    def get_channel_ids(self):\n        return list(range(self._timeseries.shape[0]))\n\n    def get_num_frames(self):\n        return self._timeseries.shape[1]\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        recordings = self._timeseries[:, start_frame:end_frame][channel_ids, :]\n        return recordings\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        if self._ttl_frames is not None and self._ttl_states is not None:\n            ttl_idxs = np.where((self._ttl_frames >= start_frame) & (self._ttl_frames < end_frame))[0]\n            return self._ttl_frames[ttl_idxs], self._ttl_states[ttl_idxs]\n        else:\n            print(\"TTL frames have not been added to the extractor. 
You can add them with the `set_ttls()1 function\")\n            return None, None\n\n    @staticmethod\n    def write_recording(recording, save_path):\n        save_path = Path(save_path)\n        np.save(save_path, recording.get_traces())\n\n\nclass NumpySortingExtractor(SortingExtractor):\n    extractor_name = 'NumpySorting'\n    is_writable = False\n\n    def __init__(self):\n        SortingExtractor.__init__(self)\n        self._units = {}\n        self.is_dumpable = False\n\n    def load_from_extractor(self, sorting, copy_unit_properties=False, copy_unit_spike_features=False):\n        \"\"\"This function loads the information from a SortingExtractor into this extractor.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            The SortingExtractor from which this extractor will copy information.\n        copy_unit_properties: bool\n            If True, the unit_properties will be copied from the given SortingExtractor to this extractor.\n        copy_unit_spike_features: bool\n            If True, the unit_spike_features will be copied from the given SortingExtractor to this extractor.\n        \"\"\"\n        ids = sorting.get_unit_ids()\n        for id in ids:\n            self.add_unit(id, sorting.get_unit_spike_train(id))\n        if sorting.get_sampling_frequency() is not None:\n            self.set_sampling_frequency(sorting.get_sampling_frequency())\n        if copy_unit_properties:\n            self.copy_unit_properties(sorting)\n        if copy_unit_spike_features:\n            self.copy_unit_spike_features(sorting)\n\n    def set_sampling_frequency(self, sampling_frequency):\n        self._sampling_frequency = sampling_frequency\n\n    def set_times_labels(self, times, labels):\n        \"\"\"This function takes in an array of spike times (in frames) and an array of spike labels and adds all the\n        unit information in these lists into the extractor.\n\n        Parameters\n        ----------\n        times: 
np.array\n            An array of spike times (in frames).\n        labels: np.array\n            An array of spike labels corresponding to the given times.\n        \"\"\"\n        units = np.sort(np.unique(labels))\n        for unit in units:\n            times0 = times[np.where(labels == unit)[0]]\n            self.add_unit(unit_id=int(unit), times=times0)\n\n    def add_unit(self, unit_id, times):\n        \"\"\"This function adds a new unit with the given spike times.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit_id of the unit to be added.\n        times: np.array\n            An array of spike times (in frames).\n        \"\"\"\n        self._units[unit_id] = dict(times=times)\n\n    def get_unit_ids(self):\n        return list(self._units.keys())\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        times = self._units[unit_id]['times']\n        inds = np.where((start_frame <= times) & (times < end_frame))[0]\n        return np.rint(times[inds]).astype(int)\n"
  },
  {
    "path": "spikeextractors/extractors/nwbextractors/__init__.py",
    "content": "from .nwbextractors import NwbRecordingExtractor, NwbSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/nwbextractors/nwbextractors.py",
    "content": "import uuid\nfrom datetime import datetime\nfrom collections import abc\nfrom pathlib import Path\nimport numpy as np\nfrom packaging.version import parse\nfrom typing import Union, List, Optional\nimport warnings\n\nimport spikeextractors as se\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train\n\ntry:\n    import pandas as pd\n    import pynwb\n    from pynwb import NWBHDF5IO\n    from pynwb import NWBFile\n    from pynwb.ecephys import ElectricalSeries, FilteredEphys, LFP\n    from pynwb.ecephys import ElectrodeGroup\n    from hdmf.data_utils import DataChunkIterator\n    from hdmf.backends.hdf5.h5_utils import H5DataIO\n\n    HAVE_NWB = True\nexcept ModuleNotFoundError:\n    HAVE_NWB = False\n\nPathType = Union[str, Path, None]\nArrayType = Union[list, np.ndarray]\n\n\ndef check_nwb_install():\n    assert HAVE_NWB, NwbRecordingExtractor.installation_mesg\n\n\ndef set_dynamic_table_property(dynamic_table, row_ids, property_name, values, index=False,\n                               default_value=np.nan, table=False, description='no description'):\n    check_nwb_install()\n    if not isinstance(row_ids, list) or not all(isinstance(x, (int, np.integer)) for x in row_ids):\n        raise TypeError(\"'ids' must be a list of integers\")\n    ids = list(dynamic_table.id[:])\n    if any([i not in ids for i in row_ids]):\n        raise ValueError(\"'ids' contains values outside the range of existing ids\")\n    if not isinstance(property_name, str):\n        raise TypeError(\"'property_name' must be a string\")\n    if len(row_ids) != len(values) and index is False:\n        raise ValueError(\"'ids' and 'values' should be lists of same size\")\n\n    if index is False:\n        if property_name in dynamic_table:\n            for (row_id, value) in zip(row_ids, values):\n                dynamic_table[property_name].data[ids.index(row_id)] = value\n        else:\n            col_data = [default_value] * len(ids)  
# init with default val\n            for (row_id, value) in zip(row_ids, values):\n                col_data[ids.index(row_id)] = value\n            dynamic_table.add_column(\n                name=property_name,\n                description=description,\n                data=col_data,\n                index=index,\n                table=table\n            )\n    else:\n        if property_name in dynamic_table:\n            # TODO\n            raise NotImplementedError\n        else:\n            dynamic_table.add_column(\n                name=property_name,\n                description=description,\n                data=values,\n                index=index,\n                table=table\n            )\n\n\ndef get_dynamic_table_property(dynamic_table, *, row_ids=None, property_name):\n    all_row_ids = list(dynamic_table.id[:])\n    if row_ids is None:\n        row_ids = all_row_ids\n    return [dynamic_table[property_name][all_row_ids.index(x)] for x in row_ids]\n\n\ndef get_nspikes(units_table, unit_id):\n    \"\"\"Return the number of spikes for chosen unit.\"\"\"\n    check_nwb_install()\n    ids = np.array(units_table.id[:])\n    indexes = np.where(ids == unit_id)[0]\n    if not len(indexes):\n        raise ValueError(f\"{unit_id} is an invalid unit_id. 
Valid ids: {ids}.\")\n    index = indexes[0]\n    if index == 0:\n        return units_table['spike_times_index'].data[index]\n    else:\n        return units_table['spike_times_index'].data[index] - units_table['spike_times_index'].data[index - 1]\n\n\ndef most_relevant_ch(traces: ArrayType):\n    \"\"\"\n    Calculate the most relevant channel for a given Unit.\n\n    Estimates the channel where the max-min difference of the average traces is greatest.\n\n    Parameters\n    ----------\n    traces : ndarray\n        ndarray of shape (nSpikes, nChannels, nSamples)\n    \"\"\"\n    n_channels = traces.shape[1]\n    avg = np.mean(traces, axis=0)\n\n    max_min = np.zeros(n_channels)\n    for ch in range(n_channels):\n        max_min[ch] = avg[ch, :].max() - avg[ch, :].min()\n\n    relevant_ch = np.argmax(max_min)\n    return relevant_ch\n\n\ndef update_dict(d: dict, u: dict):\n    \"\"\"Smart dictionary updates.\"\"\"\n    if u is not None:\n        for k, v in u.items():\n            if isinstance(v, abc.Mapping):\n                d[k] = update_dict(d.get(k, {}), v)\n            else:\n                d[k] = v\n    return d\n\n\ndef list_get(li: list, idx: int, default):\n    \"\"\"Safe index retrieval from list.\"\"\"\n    try:\n        return li[idx]\n    except IndexError:\n        return default\n\n\ndef check_module(nwbfile, name: str, description: str = None):\n    \"\"\"\n    Check if processing module exists. If not, create it. 
Then return module.\n\n    Parameters\n    ----------\n    nwbfile: pynwb.NWBFile\n    name: str\n    description: str | None (optional)\n\n    Returns\n    -------\n    pynwb.module\n    \"\"\"\n    assert isinstance(nwbfile, pynwb.NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n    if name in nwbfile.modules:\n        return nwbfile.modules[name]\n    else:\n        if description is None:\n            description = name\n        return nwbfile.create_processing_module(name, description)\n\n\nclass NwbRecordingExtractor(se.RecordingExtractor):\n    \"\"\"Primary class for interfacing between NWBFiles and RecordingExtractors.\"\"\"\n\n    extractor_name = 'NwbRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = HAVE_NWB  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the Nwb extractors, install pynwb: \\n\\n pip install pynwb\\n\\n\"\n\n    def __init__(self, file_path: PathType, electrical_series_name: str = None):\n        \"\"\"\n        Load an NWBFile as a RecordingExtractor.\n\n        Parameters\n        ----------\n        file_path: path to NWB file\n        electrical_series_name: str, optional\n        \"\"\"\n        assert self.installed, self.installation_mesg\n        se.RecordingExtractor.__init__(self)\n        self._path = str(file_path)\n        with NWBHDF5IO(self._path, 'r') as io:\n            nwbfile = io.read()\n            if electrical_series_name is not None:\n                self._electrical_series_name = electrical_series_name\n            else:\n                a_names = list(nwbfile.acquisition)\n                if len(a_names) > 1:\n                    raise ValueError(\"More than one acquisition found! 
You must specify 'electrical_series_name'.\")\n                if len(a_names) == 0:\n                    raise ValueError(\"No acquisitions found in the .nwb file.\")\n                self._electrical_series_name = a_names[0]\n            es = nwbfile.acquisition[self._electrical_series_name]\n            if hasattr(es, 'timestamps') and es.timestamps:\n                self.sampling_frequency = 1. / np.median(np.diff(es.timestamps))\n                self.recording_start_time = es.timestamps[0]\n            else:\n                self.sampling_frequency = es.rate\n                if hasattr(es, 'starting_time'):\n                    self.recording_start_time = es.starting_time\n                else:\n                    self.recording_start_time = 0.\n\n            self.num_frames = int(es.data.shape[0])\n            num_channels = len(es.electrodes.data)\n\n            # Channels gains - for RecordingExtractor, these are values to cast traces to uV\n            if es.channel_conversion is not None:\n                gains = es.conversion * es.channel_conversion[:] * 1e6\n            else:\n                gains = es.conversion * np.ones(num_channels) * 1e6\n            # Extractors channel groups must be integers, but Nwb electrodes group_name can be strings\n            if 'group_name' in nwbfile.electrodes.colnames:\n                unique_grp_names = list(np.unique(nwbfile.electrodes['group_name'][:]))\n\n            # Fill channel properties dictionary from electrodes table\n            self.channel_ids = [es.electrodes.table.id[x] for x in es.electrodes.data]\n\n            # If gains are not 1, set has_scaled to True\n            if np.any(gains != 1):\n                self.set_channel_gains(gains)\n                self.has_unscaled = True\n\n            for es_ind, (channel_id, electrode_table_index) in enumerate(zip(self.channel_ids, es.electrodes.data)):\n                this_loc = []\n                if 'rel_x' in nwbfile.electrodes:\n                    
this_loc.append(nwbfile.electrodes['rel_x'][electrode_table_index])\n                    if 'rel_y' in nwbfile.electrodes:\n                        this_loc.append(nwbfile.electrodes['rel_y'][electrode_table_index])\n                    else:\n                        this_loc.append(0)\n                    self.set_channel_locations(this_loc, channel_id)\n\n                for col in nwbfile.electrodes.colnames:\n                    if isinstance(nwbfile.electrodes[col][electrode_table_index], ElectrodeGroup):\n                        continue\n                    elif col == 'group_name':\n                        self.set_channel_groups(\n                            int(unique_grp_names.index(nwbfile.electrodes[col][electrode_table_index])), channel_id)\n                    elif col == 'location':\n                        self.set_channel_property(channel_id, 'brain_area',\n                                                  nwbfile.electrodes[col][electrode_table_index])\n                    elif col == 'offset':\n                        self.set_channel_offsets(channel_ids=channel_id,\n                                                 offsets=nwbfile.electrodes[col][electrode_table_index])\n                    elif col in ['x', 'y', 'z', 'rel_x', 'rel_y']:\n                        continue\n                    else:\n                        self.set_channel_property(channel_id, col, nwbfile.electrodes[col][electrode_table_index])\n\n            # Fill epochs dictionary\n            self._epochs = {}\n            if nwbfile.epochs is not None:\n                df_epochs = nwbfile.epochs.to_dataframe()\n\n                if 'tags' in df_epochs:\n                    tags_or_label = 'tags'  # older nwb schema version\n                else:\n                    tags_or_label = 'label'\n\n                self._epochs = {\n                    row[tags_or_label][0]: {\n                        'start_frame': self.time_to_frame(row['start_time']),\n                        
'end_frame': self.time_to_frame(row['stop_time'])\n                    }\n                    for _, row in df_epochs.iterrows()\n                }\n\n            self._kwargs = {'file_path': str(Path(file_path).absolute()),\n                            'electrical_series_name': electrical_series_name}\n            self.make_nwb_metadata(nwbfile=nwbfile, es=es)\n\n    def make_nwb_metadata(self, nwbfile, es):\n        # Metadata dictionary - useful for constructing a nwb file\n        self.nwb_metadata = dict()\n        self.nwb_metadata['NWBFile'] = {\n            'session_description': nwbfile.session_description,\n            'identifier': nwbfile.identifier,\n            'session_start_time': nwbfile.session_start_time,\n            'institution': nwbfile.institution,\n            'lab': nwbfile.lab\n        }\n        self.nwb_metadata['Ecephys'] = dict()\n        # Update metadata with Device info\n        self.nwb_metadata['Ecephys']['Device'] = []\n        for dev in nwbfile.devices:\n            self.nwb_metadata['Ecephys']['Device'].append({'name': dev})\n        # Update metadata with ElectrodeGroup info\n        self.nwb_metadata['Ecephys']['ElectrodeGroup'] = []\n        for k, v in nwbfile.electrode_groups.items():\n            self.nwb_metadata['Ecephys']['ElectrodeGroup'].append({\n                'name': v.name,\n                'description': v.description,\n                'location': v.location,\n                'device': v.device.name\n            })\n        # Update metadata with ElectricalSeries info\n        self.nwb_metadata['Ecephys']['ElectricalSeries'] = dict(\n            name=es.name,\n            description=es.description\n        )\n\n    @check_get_traces_args\n    def get_traces(\n        self,\n        channel_ids: ArrayType = None,\n        start_frame: int = None,\n        end_frame: int = None,\n        return_scaled: bool = True\n    ):\n        with NWBHDF5IO(self._path, 'r') as io:\n            nwbfile = io.read()\n        
    es = nwbfile.acquisition[self._electrical_series_name]\n            es_channel_ids = np.array(es.electrodes.table.id[:])[es.electrodes.data[:]].tolist()\n            channel_inds = [es_channel_ids.index(id) for id in channel_ids]\n            if np.array(channel_inds).size > 1 and np.any(np.diff(channel_inds) < 0):\n                # h5py constraint does not allow datasets to be indexed out of order\n                ind_sort_order = np.argsort(channel_inds)\n                sorted_channel_inds = np.array(channel_inds)[ind_sort_order]\n                recordings = es.data[start_frame:end_frame, sorted_channel_inds]\n                traces = recordings[:, ind_sort_order].T\n            else:\n                traces = es.data[start_frame:end_frame, channel_inds].T\n        return traces\n\n    def get_sampling_frequency(self):\n        return self.sampling_frequency\n\n    def get_num_frames(self):\n        return self.num_frames\n\n    def get_channel_ids(self):\n        return self.channel_ids\n\n    @staticmethod\n    def add_devices(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n\n        Adds device information to nwbfile object.\n        Will always ensure nwbfile has at least one device, but multiple\n        devices within the metadata list will also be created.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n            Should be of the format\n                metadata['Ecephys']['Device'] = [{'name': my_name,\n                                                  'description': my_description}, ...]\n\n        Missing keys in an element of metadata['Ecephys']['Device'] will be auto-populated with defaults.\n        \"\"\"\n        if nwbfile is 
not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n        \n        defaults = dict(\n            name=\"Device\",\n            description=\"Ecephys probe.\"\n        )\n\n        if metadata is None:\n            metadata = dict() \n\n        if 'Ecephys' not in metadata:\n            metadata['Ecephys'] = dict()\n\n        if 'Device' not in metadata['Ecephys']:\n            metadata['Ecephys']['Device'] = [defaults]\n\n        assert all([isinstance(x, dict) for x in metadata['Ecephys']['Device']]), \\\n            \"Expected metadata['Ecephys']['Device'] to be a list of dictionaries!\"\n\n        for dev in metadata['Ecephys']['Device']:\n            if dev.get('name', defaults['name']) not in nwbfile.devices:\n                nwbfile.create_device(**dict(defaults, **dev))\n\n    @staticmethod\n    def add_electrode_groups(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n\n        Adds electrode group information to nwbfile object.\n        Will always ensure nwbfile has at least one electrode group.\n        Will auto-generate a linked device if the specified name does not exist in the nwbfile.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n            Should be of the format\n                metadata['Ecephys']['ElectrodeGroup'] = [{'name': my_name,\n                                                          'description': my_description,\n                                                          'location': electrode_location,\n                                                          'device': my_device_name}, ...]\n\n        Missing keys in an element of 
metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults.\n\n        Group names set by RecordingExtractor channel properties will also be included with passed metadata,\n        but will only use default description and location.\n        \"\"\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n        if len(nwbfile.devices) == 0:\n            se.NwbRecordingExtractor.add_devices(recording=recording, nwbfile=nwbfile, metadata=metadata)\n\n        if metadata is None:\n            metadata = dict()\n\n        if 'Ecephys' not in metadata:\n            metadata['Ecephys'] = dict()\n\n        defaults = [\n            dict(\n                name=str(group_id),\n                description=\"no description\",\n                location=\"unknown\",\n                device=[i.name for i in nwbfile.devices.values()][0]\n            )\n            for group_id in np.unique(recording.get_channel_groups())\n        ]\n\n        if 'ElectrodeGroup' not in metadata['Ecephys']:\n            metadata['Ecephys']['ElectrodeGroup'] = defaults\n\n        assert all([isinstance(x, dict) for x in metadata['Ecephys']['ElectrodeGroup']]), \\\n            \"Expected metadata['Ecephys']['ElectrodeGroup'] to be a list of dictionaries!\"\n\n        for grp in metadata['Ecephys']['ElectrodeGroup']:\n            if grp.get('name', defaults[0]['name']) not in nwbfile.electrode_groups:\n                device_name = grp.get('device', defaults[0]['device'])\n                if device_name not in nwbfile.devices:\n                    new_device = dict(\n                        Ecephys=dict(\n                            Device=[dict(\n                                name=device_name\n                            )]\n                        )\n                    )\n                    se.NwbRecordingExtractor.add_devices(recording, nwbfile, metadata=new_device)\n                    
warnings.warn(f\"Device \\'{device_name}\\' not detected in \"\n                                  \"attempted link to electrode group! Automatically generating.\")\n                electrode_group_kwargs = dict(defaults[0], **grp)\n                # electrode_group_kwargs.pop('device')\n                electrode_group_kwargs.update(device=nwbfile.devices[device_name])\n                nwbfile.create_electrode_group(**electrode_group_kwargs)\n\n        if not nwbfile.electrode_groups:\n            device_name = list(nwbfile.devices.keys())[0]\n            device = nwbfile.devices[device_name]\n            if len(nwbfile.devices) > 1:\n                warnings.warn(\"More than one device found when adding electrode group \"\n                              f\"via channel properties: using device \\'{device_name}\\'. To use a \"\n                              \"different device, indicate it the metadata argument.\")\n\n            electrode_group_kwargs = dict(defaults[0])\n            electrode_group_kwargs.update(device=device)\n            for grp_name in np.unique(recording.get_channel_groups()).tolist():\n                electrode_group_kwargs.update(name=str(grp_name))\n                nwbfile.create_electrode_group(**electrode_group_kwargs)\n\n    @staticmethod\n    def add_electrodes(recording: se.RecordingExtractor, nwbfile=None, metadata: dict = None,\n                       write_scaled: bool = True):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n\n        Adds channels from recording object as electrodes to nwbfile object.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        write_scaled: bool (optional, defaults to True)\n            If True, writes the scaled traces (return_scaled=True)\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n            Should be 
of the format\n                metadata['Ecephys']['Electrodes'] = [{'name': my_name,\n                                                      'description': my_description,\n                                                      'data': [my_electrode_data]}, ...]\n            where each dictionary corresponds to a column in the Electrodes table and [my_electrode_data] is a list in\n            one-to-one correspondence with the nwbfile electrode ids and RecordingExtractor channel ids.\n\n        Missing keys in an element of metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults\n        whenever possible.\n\n        If 'my_name' is set to one of the required fields for nwbfile\n        electrodes (id, x, y, z, imp, loccation, filtering, group_name),\n        then the metadata will override their default values.\n\n        Setting 'my_name' to metadata field 'group' is not supported as the linking to\n        nwbfile.electrode_groups is handled automatically; please specify the string 'group_name' in this case.\n\n        If no group information is passed via metadata, automatic linking to existing electrode groups,\n        possibly including the default, will occur.\n        \"\"\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n        if nwbfile.electrode_groups is None:\n            se.NwbRecordingExtractor.add_electrode_groups(recording, nwbfile, metadata)\n        # For older versions of pynwb, we need to manually add these columns\n        if parse(pynwb.__version__) < parse('1.3.0'):\n            if nwbfile.electrodes is None or 'rel_x' not in nwbfile.electrodes.colnames:\n                nwbfile.add_electrode_column('rel_x', 'x position of electrode in electrode group')\n            if nwbfile.electrodes is None or 'rel_y' not in nwbfile.electrodes.colnames:\n                nwbfile.add_electrode_column('rel_y', 'y position of electrode in electrode 
group')\n\n        defaults = dict(\n            x=np.nan,\n            y=np.nan,\n            z=np.nan,\n            # There doesn't seem to be a canonical default for impedence, if missing.\n            # The NwbRecordingExtractor follows the -1.0 convention, other scripts sometimes use np.nan\n            imp=-1.0,\n            location=\"unknown\",\n            filtering=\"none\",\n            group_name=\"ElectrodeGroup\"\n        )\n        if metadata is None:\n            metadata = dict(Ecephys=dict())\n\n        if 'Electrodes' not in metadata['Ecephys']:\n            metadata['Ecephys']['Electrodes'] = []\n\n        assert all([isinstance(x, dict) and set(x.keys()) == set(['name', 'description', 'data'])\n                    and isinstance(x['data'], list) for x in metadata['Ecephys']['Electrodes']]), \\\n            \"Expected metadata['Ecephys']['Electrodes'] to be a list of dictionaries!\"\n        assert all([x['name'] != 'group' for x in metadata['Ecephys']['Electrodes']]), \\\n            \"Passing metadata field 'group' is depricated; pass group_name instead!\"\n\n        if nwbfile.electrodes is None:\n            nwb_elec_ids = []\n        else:\n            nwb_elec_ids = nwbfile.electrodes.id.data[:]\n\n        for metadata_column in metadata['Ecephys']['Electrodes']:\n            if (nwbfile.electrodes is None or metadata_column['name'] not in nwbfile.electrodes.colnames) \\\n                    and metadata_column['name'] != 'group_name':\n                nwbfile.add_electrode_column(\n                    name=str(metadata_column['name']),\n                    description=str(metadata_column['description'])\n                )\n\n        for j, channel_id in enumerate(recording.get_channel_ids()):\n            if channel_id not in nwb_elec_ids:\n                electrode_kwargs = dict(defaults)\n                electrode_kwargs.update(id=channel_id)\n\n                # recording.get_channel_locations defaults to np.nan if there are none\n    
            location = recording.get_channel_locations(channel_ids=channel_id)[0]\n                if all([not np.isnan(loc) for loc in location]):\n                    # property 'location' of RX channels corresponds to rel_x and rel_ y of NWB electrodes\n                    electrode_kwargs.update(\n                        dict(\n                            rel_x=float(location[0]),\n                            rel_y=float(location[1])\n                        )\n                    )\n\n                for metadata_column in metadata['Ecephys']['Electrodes']:\n                    if metadata_column['name'] == 'group_name':\n                        group_name = list_get(metadata_column['data'], j, defaults['group_name'])\n                        if group_name not in nwbfile.electrode_groups:\n                            warnings.warn(f\"Electrode group for electrode {channel_id} was not \"\n                                          \"found in the nwbfile! Automatically adding.\")\n                            missing_group_metadata = dict(\n                                Ecephys=dict(\n                                    ElectrodeGroup=[dict(\n                                        name=group_name,\n                                        description=\"no description\",\n                                        location=\"unknown\",\n                                        device=\"Device\"\n                                    )]\n                                )\n                            )\n                            se.NwbRecordingExtractor.add_electrode_groups(recording, nwbfile, missing_group_metadata)\n                        electrode_kwargs.update(\n                            dict(\n                                group=nwbfile.electrode_groups[group_name],\n                                group_name=group_name\n                            )\n                        )\n                    else:\n                        if metadata_column['name'] in 
defaults:\n                            electrode_kwargs.update({\n                                metadata_column['name']: list_get(metadata_column['data'], j,\n                                                                  defaults[metadata_column['name']])\n                            })\n                        else:\n                            if j < len(metadata_column['data']):\n                                electrode_kwargs.update({\n                                    metadata_column['name']: metadata_column['data'][j]\n                                })\n                            else:\n                                metadata_column_name = metadata_column['name']\n                                warnings.warn(f\"Custom column {metadata_column_name} \"\n                                              f\"has incomplete data for channel id [{j}] and no \"\n                                              \"set default! Electrode will not be added.\")\n                                continue\n\n                if not any([x.get('name', '') == 'group_name' for x in metadata['Ecephys']['Electrodes']]):\n                    group_id = recording.get_channel_groups(channel_ids=channel_id)[0]\n                    if str(group_id) in nwbfile.electrode_groups:\n                        electrode_kwargs.update(\n                            dict(\n                                group=nwbfile.electrode_groups[str(group_id)],\n                                group_name=str(group_id)\n                            )\n                        )\n                    else:\n                        warnings.warn(\"No metadata was passed specifying the electrode group for \"\n                                      f\"electrode {channel_id}, and the internal recording channel group was \"\n                                      f\"assigned a value (str({group_id})) not present as electrode \"\n                                      \"groups in the NWBFile! 
Electrode will not be added.\")\n                        continue\n\n                nwbfile.add_electrode(**electrode_kwargs)\n        assert nwbfile.electrodes is not None, \\\n            \"Unable to form electrode table! Check device, electrode group, and electrode metadata.\"\n\n        # property 'gain' should not be in the NWB electrodes_table\n        # property 'brain_area' of RX channels corresponds to 'location' of NWB electrodes\n        # property 'offset' should not be in the NWB electrodes_table as not officially supported by schema v2.2.5\n        channel_prop_names = set(recording.get_shared_channel_property_names()) - set(nwbfile.electrodes.colnames) \\\n                             - {'gain', 'location', 'offset'}\n        for channel_prop_name in channel_prop_names:\n            for channel_id in recording.get_channel_ids():\n                val = recording.get_channel_property(channel_id, channel_prop_name)\n                descr = 'no description'\n                if channel_prop_name == 'brain_area':\n                    channel_prop_name = 'location'\n                    descr = 'brain area location'\n                set_dynamic_table_property(\n                    dynamic_table=nwbfile.electrodes,\n                    row_ids=[int(channel_id)],\n                    property_name=channel_prop_name,\n                    values=[val],\n                    default_value=np.nan,\n                    description=descr\n                )\n\n    @staticmethod\n    def add_electrical_series(\n        recording: se.RecordingExtractor,\n        nwbfile=None,\n        metadata: dict = None,\n        buffer_mb: int = 500,\n        use_times: bool = False,\n        write_as: str = 'raw',\n        es_key: str = None,\n        write_scaled: bool = False\n    ):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n\n        Adds traces from recording object as ElectricalSeries to nwbfile object.\n\n        Parameters\n        ----------\n     
   recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n            Should be of the format\n                metadata['Ecephys']['ElectricalSeries'] = {'name': my_name,\n                                                           'description': my_description}\n        buffer_mb: int (optional, defaults to 500MB)\n            maximum amount of memory (in MB) to use per iteration of the\n            DataChunkIterator (requires traces to be memmap objects)\n        use_times: bool (optional, defaults to False)\n            If True, the times are saved to the nwb file using recording.frame_to_time(). If False (defualut),\n            the sampling rate is used.\n        write_as: str (optional, defaults to 'raw')\n            How to save the traces data in the nwb file. Options: \n            - 'raw' will save it in acquisition\n            - 'processed' will save it as FilteredEphys, in a processing module\n            - 'lfp' will save it as LFP, in a processing module\n        es_key: str (optional)\n            Key in metadata dictionary containing metadata info for the specific electrical series\n        write_scaled: bool (optional, defaults to True)\n            If True, writes the scaled traces (return_scaled=True)\n\n        Missing keys in an element of metadata['Ecephys']['ElectrodeGroup'] will be auto-populated with defaults\n        whenever possible.\n        \"\"\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile!\"\n\n        assert buffer_mb > 10, \"'buffer_mb' should be at least 10MB to ensure data can be chunked!\"\n\n        if not nwbfile.electrodes:\n            se.NwbRecordingExtractor.add_electrodes(recording, nwbfile, metadata)\n\n        assert write_as in ['raw', 'processed', 'lfp'], \\\n            
f\"'write_as' should be 'raw', 'processed' or 'lfp', but intead received value {write_as}\"\n\n        if write_as == 'raw':\n            eseries_kwargs = dict(\n                name=\"ElectricalSeries_raw\",\n                description=\"Raw acquired data\",\n                comments=\"Generated from SpikeInterface::NwbRecordingExtractor\"\n            )\n        elif write_as == 'processed':\n            eseries_kwargs = dict(\n                name=\"ElectricalSeries_processed\",\n                description=\"Processed data\",\n                comments=\"Generated from SpikeInterface::NwbRecordingExtractor\"\n            )\n            # Check for existing processing module and data interface\n            ecephys_mod = check_module(\n                nwbfile=nwbfile,\n                name='ecephys',\n                description=\"Intermediate data from extracellular electrophysiology recordings, e.g., LFP.\"\n            )\n            if 'Processed' not in ecephys_mod.data_interfaces:\n                ecephys_mod.add(FilteredEphys(name='Processed'))\n        elif write_as == 'lfp':\n            eseries_kwargs = dict(\n                name=\"ElectricalSeries_lfp\",\n                description=\"Processed data - LFP\",\n                comments=\"Generated from SpikeInterface::NwbRecordingExtractor\"\n            )\n            # Check for existing processing module and data interface\n            ecephys_mod = check_module(\n                nwbfile=nwbfile,\n                name='ecephys',\n                description=\"Intermediate data from extracellular electrophysiology recordings, e.g., LFP.\"\n            )\n            if 'LFP' not in ecephys_mod.data_interfaces:\n                ecephys_mod.add(LFP(name='LFP'))\n\n        # If user passed metadata info, overwrite defaults\n        if metadata is not None and 'Ecephys' in metadata and es_key is not None:\n            assert es_key in metadata['Ecephys'], f\"metadata['Ecephys'] dictionary does not contain 
key '{es_key}'\"\n            eseries_kwargs.update(metadata['Ecephys'][es_key])\n\n        # Check for existing names in nwbfile\n        if write_as == 'raw':\n            assert eseries_kwargs['name'] not in nwbfile.acquisition, \\\n                f\"Raw ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!\"\n        elif write_as == 'processed':\n            assert eseries_kwargs['name'] not in nwbfile.processing['ecephys'].data_interfaces['Processed'].electrical_series, \\\n                f\"Processed ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!\"\n        elif write_as == 'lfp':\n            assert eseries_kwargs['name'] not in nwbfile.processing['ecephys'].data_interfaces['LFP'].electrical_series, \\\n                f\"LFP ElectricalSeries '{eseries_kwargs['name']}' is already written in the NWBFile!\"\n\n        # Electrodes table region\n        channel_ids = recording.get_channel_ids()\n        table_ids = [list(nwbfile.electrodes.id[:]).index(id) for id in channel_ids]\n        electrode_table_region = nwbfile.create_electrode_table_region(\n            region=table_ids,\n            description=\"electrode_table_region\"\n        )\n        eseries_kwargs.update(electrodes=electrode_table_region)\n\n        # channels gains - for RecordingExtractor, these are values to cast traces to uV.\n        # For nwb, the conversions (gains) cast the data to Volts.\n        # To get traces in Volts we take data*channel_conversion*conversion.\n        channel_conversion = recording.get_channel_gains()\n        channel_offset = recording.get_channel_offsets()\n        unsigned_coercion = channel_offset / channel_conversion\n        if not np.all([x.is_integer() for x in unsigned_coercion]):\n            raise NotImplementedError(\n                \"Unable to coerce underlying unsigned data type to signed type, which is currently required for NWB \"\n                \"Schema v2.2.5! 
Please specify 'write_scaled=True'.\"\n            )\n        elif np.any(unsigned_coercion != 0):\n            warnings.warn(\n                \"NWB Schema v2.2.5 does not officially support channel offsets. The data will be converted to a signed \"\n                \"type that does not use offsets.\"\n            )\n            unsigned_coercion = unsigned_coercion.astype(int)\n        if write_scaled:\n            eseries_kwargs.update(conversion=1e-6)\n        else:\n            if len(np.unique(channel_conversion)) == 1:  # if all gains are equal\n                eseries_kwargs.update(conversion=channel_conversion[0] * 1e-6)\n            else:\n                eseries_kwargs.update(conversion=1e-6)\n                eseries_kwargs.update(channel_conversion=channel_conversion)\n\n        if isinstance(recording.get_traces(end_frame=5, return_scaled=write_scaled), np.memmap) \\\n                and np.all(channel_offset == 0):\n            n_bytes = np.dtype(recording.get_dtype()).itemsize\n            buffer_size = int(buffer_mb * 1e6) // (recording.get_num_channels() * n_bytes)\n            ephys_data = DataChunkIterator(\n                data=recording.get_traces(return_scaled=write_scaled).T,  # nwb standard is time as zero axis\n                buffer_size=buffer_size\n            )\n        else:\n            def data_generator(recording, channels_ids, unsigned_coercion, write_scaled):\n                for i, ch in enumerate(channels_ids):\n                    data = recording.get_traces(channel_ids=[ch], return_scaled=write_scaled)\n                    if not write_scaled:\n                        data_dtype_name = data.dtype.name\n                        if data_dtype_name.startswith(\"uint\"):\n                            data_dtype_name = data_dtype_name[1:]  # Retain memory of signed data type\n                        data = data + unsigned_coercion[i]\n                        data = data.astype(data_dtype_name)\n                    yield 
data.flatten()\n            ephys_data = DataChunkIterator(\n                data=data_generator(\n                    recording=recording,\n                    channels_ids=channel_ids,\n                    unsigned_coercion=unsigned_coercion,\n                    write_scaled=write_scaled\n                ),\n                iter_axis=1,  # nwb standard is time as zero axis\n                maxshape=(recording.get_num_frames(), recording.get_num_channels())\n            )\n\n        eseries_kwargs.update(data=H5DataIO(ephys_data, compression=\"gzip\"))\n        if not use_times:\n            eseries_kwargs.update(\n                starting_time=recording.frame_to_time(0),\n                rate=float(recording.get_sampling_frequency())\n            )\n        else:\n            eseries_kwargs.update(\n                timestamps=H5DataIO(\n                    recording.frame_to_time(np.arange(recording.get_num_frames())),\n                    compression=\"gzip\"\n                )\n            )\n\n        # Add ElectricalSeries to nwbfile object\n        if write_as == 'raw':\n            nwbfile.add_acquisition(ElectricalSeries(**eseries_kwargs))\n        elif write_as == 'processed':\n            ecephys_mod.data_interfaces['Processed'].add_electrical_series(ElectricalSeries(**eseries_kwargs))\n        elif write_as == 'lfp':\n            ecephys_mod.data_interfaces['LFP'].add_electrical_series(ElectricalSeries(**eseries_kwargs))\n            \n\n    @staticmethod\n    def add_epochs(\n        recording: se.RecordingExtractor, \n        nwbfile=None,\n        metadata: dict = None\n    ):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n        Adds epochs from recording object to nwbfile object.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        metadata: dict\n            metadata info for constructing the 
nwb file (optional).\n        \"\"\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n\n        # add/update epochs\n        for epoch_name in recording.get_epoch_names():\n            epoch = recording.get_epoch_info(epoch_name)\n            if nwbfile.epochs is None:\n                nwbfile.add_epoch(\n                    start_time=recording.frame_to_time(epoch['start_frame']),\n                    stop_time=recording.frame_to_time(epoch['end_frame'] - 1),\n                    tags=epoch_name\n                )\n            else:\n                if [epoch_name] in nwbfile.epochs['tags'][:]:\n                    ind = nwbfile.epochs['tags'][:].index([epoch_name])\n                    nwbfile.epochs['start_time'].data[ind] = recording.frame_to_time(epoch['start_frame'])\n                    nwbfile.epochs['stop_time'].data[ind] = recording.frame_to_time(epoch['end_frame'])\n                else:\n                    nwbfile.add_epoch(\n                        start_time=recording.frame_to_time(epoch['start_frame']),\n                        stop_time=recording.frame_to_time(epoch['end_frame']),\n                        tags=epoch_name\n                    )\n\n    @staticmethod\n    def add_all_to_nwbfile(\n        recording: se.RecordingExtractor,\n        nwbfile=None,\n        buffer_mb: int = 500,\n        use_times: bool = False,\n        metadata: dict = None,\n        write_as: str = 'raw',\n        es_key: str = None,\n        write_scaled: bool = False\n    ):\n        \"\"\"\n        Auxiliary static method for nwbextractor.\n\n        Adds all recording related information from recording object and metadata to the nwbfile object.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        nwbfile: NWBFile\n            nwb file to which the recording information is to be added\n        buffer_mb: int (optional, defaults to 500MB)\n            
maximum amount of memory (in MB) to use per iteration of the\n            DataChunkIterator (requires traces to be memmap objects)\n        use_times: bool (optional, defaults to False)\n            If True, the times are saved to the nwb file using recording.frame_to_time(). If False (default),\n            the sampling rate is used.\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n            Check the auxiliary function docstrings for more information\n            about metadata format.\n        write_as: str (optional, defaults to 'raw')\n            How to save the traces data in the nwb file. Options: \n            - 'raw' will save it in acquisition\n            - 'processed' will save it as FilteredEphys, in a processing module\n            - 'lfp' will save it as LFP, in a processing module\n        es_key: str (optional)\n            Key in metadata dictionary containing metadata info for the specific electrical series\n        write_scaled: bool (optional, defaults to False)\n            If True, writes the scaled traces (return_scaled=True)\n        \"\"\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n\n        se.NwbRecordingExtractor.add_devices(\n            recording=recording,\n            nwbfile=nwbfile,\n            metadata=metadata\n        )\n\n        se.NwbRecordingExtractor.add_electrode_groups(\n            recording=recording,\n            nwbfile=nwbfile,\n            metadata=metadata\n        )\n        se.NwbRecordingExtractor.add_electrodes(\n            recording=recording,\n            nwbfile=nwbfile,\n            metadata=metadata,\n            write_scaled=write_scaled\n        )\n        se.NwbRecordingExtractor.add_electrical_series(\n            recording=recording,\n            nwbfile=nwbfile,\n            buffer_mb=buffer_mb,\n            use_times=use_times,\n            metadata=metadata,\n            write_as=write_as,\n  
          es_key=es_key,\n            write_scaled=write_scaled\n        )\n        se.NwbRecordingExtractor.add_epochs(\n            recording=recording,\n            nwbfile=nwbfile,\n            metadata=metadata\n        )\n\n    @staticmethod\n    def write_recording(\n        recording: se.RecordingExtractor,\n        save_path: PathType = None,\n        overwrite: bool = False,\n        nwbfile=None,\n        buffer_mb: int = 500,\n        use_times: bool = False,\n        metadata: dict = None,\n        write_as: str = 'raw',\n        es_key: str = None,\n        write_scaled: bool = False\n    ):\n        \"\"\"\n        Primary method for writing a RecordingExtractor object to an NWBFile.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        save_path: PathType\n            Required if an nwbfile is not passed. Must be the path to the nwbfile\n            being appended, otherwise one is created and written.\n        overwrite: bool\n            If using save_path, whether or not to overwrite the NWBFile if it already exists.\n        nwbfile: NWBFile\n            Required if a save_path is not specified. If passed, this function\n            will fill the relevant fields within the nwbfile. E.g., calling\n            spikeextractors.NwbRecordingExtractor.write_recording(\n                my_recording_extractor, nwbfile=my_nwbfile\n            )\n            will result in the appropriate changes to the my_nwbfile object.\n        buffer_mb: int (optional, defaults to 500MB)\n            maximum amount of memory (in MB) to use per iteration of the\n            DataChunkIterator (requires traces to be memmap objects)\n        use_times: bool (optional, defaults to False)\n            If True, the times are saved to the nwb file using recording.frame_to_time(). If False (default),\n            the sampling rate is used.\n        metadata: dict\n            metadata info for constructing the nwb file (optional). 
Should be\n            of the format\n                metadata['Ecephys'] = {}\n            with keys of the forms\n                metadata['Ecephys']['Device'] = [{'name': my_name,\n                                                  'description': my_description}, ...]\n                metadata['Ecephys']['ElectrodeGroup'] = [{'name': my_name,\n                                                          'description': my_description,\n                                                          'location': electrode_location,\n                                                          'device': my_device_name}, ...]\n                metadata['Ecephys']['Electrodes'] = [{'name': my_name,\n                                                      'description': my_description,\n                                                      'data': [my_electrode_data]}, ...]\n                metadata['Ecephys']['ElectricalSeries'] = {'name': my_name,\n                                                           'description': my_description}\n        write_as: str (optional, defaults to 'raw')\n            How to save the traces data in the nwb file. Options: \n            - 'raw' will save it in acquisition\n            - 'processed' will save it as FilteredEphys, in a processing module\n            - 'lfp' will save it as LFP, in a processing module\n        es_key: str (optional)\n            Key in metadata dictionary containing metadata info for the specific electrical series\n        write_scaled: bool (optional, defaults to False)\n            If True, writes the scaled traces (return_scaled=True)\n        \"\"\"\n        assert HAVE_NWB, NwbRecordingExtractor.installation_mesg\n\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be of type pynwb.NWBFile\"\n\n        assert parse(pynwb.__version__) >= parse('1.3.3'), \\\n            \"'write_recording' not supported for version < 1.3.3. 
Run pip install --upgrade pynwb\"\n\n        assert save_path is None or nwbfile is None, \\\n            \"Either pass a save_path location, or nwbfile object, but not both!\"\n\n        # Update any previous metadata with user passed dictionary\n        if hasattr(recording, 'nwb_metadata'):\n            metadata = update_dict(recording.nwb_metadata, metadata)\n        elif metadata is None:\n            # If not NWBRecording, make metadata from information available on Recording\n            metadata_0 = se.NwbRecordingExtractor.get_nwb_metadata(recording=recording)\n            metadata = update_dict(metadata_0, metadata)\n\n        if nwbfile is None:\n            if Path(save_path).is_file() and not overwrite:\n                read_mode = 'r+'\n            else:\n                read_mode = 'w'\n\n            with NWBHDF5IO(str(save_path), mode=read_mode) as io:\n                if read_mode == 'r+':\n                    nwbfile = io.read()\n                else:\n                    # Default arguments will be over-written if contained in metadata\n                    nwbfile_kwargs = dict(\n                        session_description=\"Auto-generated by NwbRecordingExtractor without description.\",\n                        identifier=str(uuid.uuid4()),\n                        session_start_time=datetime(1970, 1, 1)\n                    )\n                    if metadata is not None and 'NWBFile' in metadata:\n                        nwbfile_kwargs.update(metadata['NWBFile'])\n                    nwbfile = NWBFile(**nwbfile_kwargs)\n\n                se.NwbRecordingExtractor.add_all_to_nwbfile(\n                    recording=recording,\n                    nwbfile=nwbfile,\n                    buffer_mb=buffer_mb,\n                    metadata=metadata,\n                    use_times=use_times,\n                    write_as=write_as,\n                    es_key=es_key,\n                    write_scaled=write_scaled\n                )\n\n                # 
Write to file\n                io.write(nwbfile)\n        else:\n            se.NwbRecordingExtractor.add_all_to_nwbfile(\n                recording=recording,\n                nwbfile=nwbfile,\n                buffer_mb=buffer_mb,\n                use_times=use_times,\n                metadata=metadata,\n                write_as=write_as,\n                es_key=es_key,\n                write_scaled=write_scaled\n            )\n\n    @staticmethod\n    def get_nwb_metadata(recording: se.RecordingExtractor, metadata: dict = None):\n        \"\"\"\n        Parameters\n        ----------\n        recording: RecordingExtractor\n        metadata: dict\n            metadata info for constructing the nwb file (optional).\n        \"\"\"\n        metadata = dict(\n            NWBFile=dict(\n                session_description=\"Auto-generated by NwbRecordingExtractor without description.\",\n                identifier=str(uuid.uuid4()),\n                session_start_time=datetime(1970, 1, 1)\n            ),\n            Ecephys=dict(\n                Device=[dict(\n                    name=\"Device\",\n                    description=\"no description\"\n                )],\n                ElectrodeGroup=[\n                    dict(\n                        name=str(gn),\n                        description=\"no description\",\n                        location=\"unknown\",\n                        device=\"Device\"\n                    ) for gn in np.unique(recording.get_channel_groups())\n                ]\n            )\n        )\n        return metadata\n\n\nclass NwbSortingExtractor(se.SortingExtractor):\n    extractor_name = 'NwbSorting'\n    installed = HAVE_NWB  # check at class level if installed or not\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the Nwb extractors, install pynwb: \\n\\n pip install pynwb\\n\\n\"\n\n    def __init__(self, file_path, electrical_series=None, sampling_frequency=None):\n        \"\"\"\n        
Parameters\n        ----------\n        path: path to NWB file\n        electrical_series: pynwb.ecephys.ElectricalSeries object\n        \"\"\"\n        assert self.installed, self.installation_mesg\n        se.SortingExtractor.__init__(self)\n        self._path = str(file_path)\n        with NWBHDF5IO(self._path, 'r') as io:\n            nwbfile = io.read()\n            if sampling_frequency is None:\n                # defines the electrical series from where the sorting came from\n                # important to know the sampling_frequency\n                if electrical_series is None:\n                    if len(nwbfile.acquisition) > 1:\n                        raise Exception('More than one acquisition found. You must specify electrical_series.')\n                    if len(nwbfile.acquisition) == 0:\n                        raise Exception(\"No acquisitions found in the .nwb file from which to read sampling frequency. \\\n                                         Please, specify 'sampling_frequency' parameter.\")\n                    es = list(nwbfile.acquisition.values())[0]\n                else:\n                    es = electrical_series\n                # get rate\n                if es.rate is not None:\n                    self._sampling_frequency = es.rate\n                else:\n                    self._sampling_frequency = 1 / (es.timestamps[1] - es.timestamps[0])\n            else:\n                self._sampling_frequency = sampling_frequency\n\n            # get all units ids\n            units_ids = nwbfile.units.id[:]\n\n            # store units properties and spike features to dictionaries\n            all_pr_ft = list(nwbfile.units.colnames)\n            all_names = [i.name for i in nwbfile.units.columns]\n            for item in all_pr_ft:\n                if item == 'spike_times':\n                    continue\n                # test if item is a unit_property or a spike_feature\n                if item + '_index' in all_names:  # if it 
has index, it is a spike_feature\n                    for u_id in units_ids:\n                        ind = list(units_ids).index(u_id)\n                        self.set_unit_spike_features(u_id, item, nwbfile.units[item][ind])\n                else:  # if it is unit_property\n                    for u_id in units_ids:\n                        ind = list(units_ids).index(u_id)\n                        if isinstance(nwbfile.units[item][ind], pd.DataFrame):\n                            prop_value = nwbfile.units[item][ind].index[0]\n                        else:\n                            prop_value = nwbfile.units[item][ind]\n\n                        if isinstance(prop_value, (list, np.ndarray)):\n                            self.set_unit_property(u_id, item, prop_value)\n                        else:\n                            if prop_value == prop_value:  # not nan\n                                self.set_unit_property(u_id, item, prop_value)\n\n            # Fill epochs dictionary\n            self._epochs = {}\n            if nwbfile.epochs is not None:\n                df_epochs = nwbfile.epochs.to_dataframe()\n                self._epochs = {row['tags'][0]: {\n                    'start_frame': self.time_to_frame(row['start_time']),\n                    'end_frame': self.time_to_frame(row['stop_time'])}\n                    for _, row in df_epochs.iterrows()}\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'electrical_series': electrical_series,\n                        'sampling_frequency': sampling_frequency}\n\n    def get_unit_ids(self):\n        \"\"\"This function returns a list of ids (ints) for each unit in the sorted result.\n        Returns\n        ----------\n        unit_ids: array_like\n            A list of the unit ids in the sorted result (ints).\n        \"\"\"\n        check_nwb_install()\n        with NWBHDF5IO(self._path, 'r') as io:\n            nwbfile = io.read()\n            unit_ids = [int(i) for i in 
nwbfile.units.id[:]]\n        return unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        check_nwb_install()\n        with NWBHDF5IO(self._path, 'r') as io:\n            nwbfile = io.read()\n            # chosen unit and interval\n            times = nwbfile.units['spike_times'][list(nwbfile.units.id[:]).index(unit_id)][:]\n            # spike times are measured in samples\n            frames = self.time_to_frame(times)\n        return frames[(frames > start_frame) & (frames < end_frame)]\n\n    @staticmethod\n    def write_units(\n            sorting: se.SortingExtractor,\n            nwbfile,\n            property_descriptions: Optional[dict] = None,\n            skip_properties: Optional[List[str]] = None,\n            skip_features: Optional[List[str]] = None,\n            use_times: bool = True\n    ):\n        \"\"\"Auxilliary function for write_sorting.\"\"\"\n        unit_ids = sorting.get_unit_ids()\n        fs = sorting.get_sampling_frequency()\n        if fs is None:\n            raise ValueError(\"Writing a SortingExtractor to an NWBFile requires a known sampling frequency!\")\n\n        all_properties = set()\n        all_features = set()\n        for unit_id in unit_ids:\n            all_properties.update(sorting.get_unit_property_names(unit_id))\n            all_features.update(sorting.get_unit_spike_feature_names(unit_id))\n\n        default_descriptions = dict(\n            isi_violation=\"Quality metric that measures the ISI violation ratio as a proxy for the purity of the unit.\",\n            firing_rate=\"Number of spikes per unit of time.\",\n            template=\"The extracellular average waveform.\",\n            max_channel=\"The recording channel id with the largest amplitude.\",\n            halfwidth=\"The full-width half maximum of the negative peak computed on the maximum channel.\",\n            peak_to_valley=\"The duration between the negative and 
the positive peaks computed on the maximum channel.\",\n            snr=\"The signal-to-noise ratio of the unit.\",\n            quality=\"Quality of the unit as defined by phy (good, mua, noise).\",\n            spike_amplitude=\"Average amplitude of peaks detected on the channel.\",\n            spike_rate=\"Average rate of peaks detected on the channel.\"\n        )\n        if property_descriptions is None:\n            property_descriptions = dict(default_descriptions)\n        else:\n            property_descriptions = dict(default_descriptions, **property_descriptions)\n        if skip_properties is None:\n            skip_properties = []\n        if skip_features is None:\n            skip_features = []\n\n        if nwbfile.units is None:\n            # Check that array properties have the same shape across units\n            property_shapes = dict()\n            for pr in all_properties:\n                shapes = []\n                for unit_id in unit_ids:\n                    if pr in sorting.get_unit_property_names(unit_id):\n                        prop_value = sorting.get_unit_property(unit_id, pr)\n                        if isinstance(prop_value, (int, np.integer, float, str, bool)):\n                            shapes.append(1)\n                        elif isinstance(prop_value, (list, np.ndarray)):\n                            if np.array(prop_value).ndim == 1:\n                                shapes.append(len(prop_value))\n                            else:\n                                shapes.append(np.array(prop_value).shape)\n                        elif isinstance(prop_value, dict):\n                            print(f\"Skipping property '{pr}' because dictionaries are not supported.\")\n                            skip_properties.append(pr)\n                            break\n                    else:\n                        shapes.append(np.nan)\n                property_shapes[pr] = shapes\n\n            for pr in 
property_shapes.keys():\n                elems = [elem for elem in property_shapes[pr] if not np.any(np.isnan(elem))]\n                if not np.all([elem == elems[0] for elem in elems]):\n                    print(f\"Skipping property '{pr}' because it has variable size across units.\")\n                    skip_properties.append(pr)\n\n            write_properties = set(all_properties) - set(skip_properties)\n            for pr in write_properties:\n                if pr not in property_descriptions:\n                    warnings.warn(\n                        f\"Description for property {pr} not found in property_descriptions. \"\n                        \"Setting description to 'no description'\"\n                    )\n            for pr in write_properties:\n                unit_col_args = dict(name=pr, description=property_descriptions.get(pr, \"No description.\"))\n                if pr in ['max_channel', 'max_electrode'] and nwbfile.electrodes is not None:\n                    unit_col_args.update(table=nwbfile.electrodes)\n                nwbfile.add_unit_column(**unit_col_args)\n\n            for unit_id in unit_ids:\n                unit_kwargs = dict()\n                if use_times:\n                    spkt = sorting.frame_to_time(sorting.get_unit_spike_train(unit_id=unit_id))\n                else:\n                    spkt = sorting.get_unit_spike_train(unit_id=unit_id) / sorting.get_sampling_frequency()\n                for pr in write_properties:\n                    if pr in sorting.get_unit_property_names(unit_id):\n                        prop_value = sorting.get_unit_property(unit_id, pr)\n                        unit_kwargs.update({pr: prop_value})\n                    else:  # Case of missing data for this unit and this property\n                        unit_kwargs.update({pr: np.nan})\n                nwbfile.add_unit(id=int(unit_id), spike_times=spkt, **unit_kwargs)\n\n            # TODO\n            # # Stores average and std of spike 
traces\n            # This will soon be updated to the current NWB standard\n            # if 'waveforms' in sorting.get_unit_spike_feature_names(unit_id=id):\n            #     wf = sorting.get_unit_spike_features(unit_id=id,\n            #                                          feature_name='waveforms')\n            #     relevant_ch = most_relevant_ch(wf)\n            #     # Spike traces on the most relevant channel\n            #     traces = wf[:, relevant_ch, :]\n            #     traces_avg = np.mean(traces, axis=0)\n            #     traces_std = np.std(traces, axis=0)\n            #     nwbfile.add_unit(\n            #         id=id,\n            #         spike_times=spkt,\n            #         waveform_mean=traces_avg,\n            #         waveform_sd=traces_std\n            #     )\n\n            # Check that multidimensional features have the same shape across units\n            feature_shapes = dict()\n            for ft in all_features:\n                shapes = []\n                for unit_id in unit_ids:\n                    if ft in sorting.get_unit_spike_feature_names(unit_id):\n                        feat_value = sorting.get_unit_spike_features(unit_id, ft)\n                        if isinstance(feat_value[0], (int, np.integer, float, str, bool)):\n                            break\n                        elif isinstance(feat_value[0], (list, np.ndarray)):  # multidimensional features\n                            if np.array(feat_value).ndim > 1:\n                                shapes.append(np.array(feat_value).shape)\n                                feature_shapes[ft] = shapes\n                        elif isinstance(feat_value[0], dict):\n                            print(f\"Skipping feature '{ft}' because dictionaries are not supported.\")\n                            skip_features.append(ft)\n                            break\n                    else:\n                        print(f\"Skipping feature '{ft}' because not share 
across all units.\")\n                        skip_features.append(ft)\n                        break\n\n            nspikes = {k: get_nspikes(nwbfile.units, int(k)) for k in unit_ids}\n\n            for ft in feature_shapes.keys():\n                # skip first dimension (num_spikes) when comparing feature shape\n                if not np.all([elem[1:] == feature_shapes[ft][0][1:] for elem in feature_shapes[ft]]):\n                    print(f\"Skipping feature '{ft}' because it has variable size across units.\")\n                    skip_features.append(ft)\n\n            for ft in set(all_features) - set(skip_features):\n                values = []\n                if not ft.endswith('_idxs'):\n                    for unit_id in sorting.get_unit_ids():\n                        feat_vals = sorting.get_unit_spike_features(unit_id, ft)\n\n                        if len(feat_vals) < nspikes[unit_id]:\n                            skip_features.append(ft)\n                            print(f\"Skipping feature '{ft}' because it is not defined for all spikes.\")\n                            break\n                            # this means features are available for a subset of spikes\n                            # all_feat_vals = np.array([np.nan] * nspikes[unit_id])\n                            # feature_idxs = sorting.get_unit_spike_features(unit_id, feat_name + '_idxs')\n                            # all_feat_vals[feature_idxs] = feat_vals\n                        else:\n                            all_feat_vals = feat_vals\n                        values.append(all_feat_vals)\n\n                    flatten_vals = [item for sublist in values for item in sublist]\n                    nspks_list = [sp for sp in nspikes.values()]\n                    spikes_index = np.cumsum(nspks_list).astype('int64')\n                    if ft in nwbfile.units:  # If property already exists, skip it\n                        warnings.warn(f'Feature {ft} already present in units table, 
skipping it')\n                        continue\n                    set_dynamic_table_property(\n                        dynamic_table=nwbfile.units,\n                        row_ids=[int(k) for k in unit_ids],\n                        property_name=ft,\n                        values=flatten_vals,\n                        index=spikes_index,\n                    )\n        else:\n            warnings.warn(\"The nwbfile already contains units. These units will not be over-written.\")\n\n    @staticmethod\n    def write_sorting(\n            sorting: se.SortingExtractor,\n            save_path: PathType = None,\n            overwrite: bool = False,\n            nwbfile=None,\n            property_descriptions: Optional[dict] = None,\n            skip_properties: Optional[List[str]] = None,\n            skip_features: Optional[List[str]] = None,\n            use_times: bool = True,\n            **nwbfile_kwargs\n    ):\n        \"\"\"\n        Primary method for writing a SortingExtractor object to an NWBFile.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n        save_path: PathType\n            Required if an nwbfile is not passed. The location where the NWBFile either exists, or will be written.\n        overwrite: bool\n            If using save_path, whether or not to overwrite the NWBFile if it already exists.\n        nwbfile: NWBFile\n            Required if a save_path is not specified. If passed, this function\n            will fill the relevant fields within the nwbfile. 
E.g., calling\n            spikeextractors.NwbSortingExtractor.write_sorting(\n                my_sorting_extractor, nwbfile=my_nwbfile\n            )\n            will result in the appropriate changes to the my_nwbfile object.\n        property_descriptions: dict\n            For each key in this dictionary which matches the name of a unit\n            property in sorting, adds the value as a description to that\n            custom unit column.\n        skip_properties: list of str\n            Each string in this list that matches a unit property will not be written to the NWBFile.\n        skip_features: list of str\n            Each string in this list that matches a spike feature will not be written to the NWBFile.\n        use_times: bool (optional, defaults to True)\n            If True (default), the times are saved to the nwb file using sorting.frame_to_time(). If False,\n            the sampling rate is used.\n        nwbfile_kwargs: dict\n            Information for constructing the nwb file (optional).\n            Only used if no nwbfile exists at the save_path, and no nwbfile\n            was directly passed.\n        \"\"\"\n        assert HAVE_NWB, NwbSortingExtractor.installation_mesg\n        assert save_path is None or nwbfile is None, \\\n            \"Either pass a save_path location, or nwbfile object, but not both!\"\n        if nwbfile is not None:\n            assert isinstance(nwbfile, NWBFile), \"'nwbfile' should be a pynwb.NWBFile object!\"\n\n        if nwbfile is None:\n            if Path(save_path).is_file() and not overwrite:\n                read_mode = 'r+'\n            else:\n                read_mode = 'w'\n\n            with NWBHDF5IO(str(save_path), mode=read_mode) as io:\n                if read_mode == 'r+':\n                    nwbfile = io.read()\n                else:\n                    default_nwbfile_kwargs = dict(\n                        session_description=\"Auto-generated by NwbRecordingExtractor without 
description.\",\n                        identifier=str(uuid.uuid4()),\n                        session_start_time=datetime(1970, 1, 1)\n                    )\n                    default_nwbfile_kwargs.update(**nwbfile_kwargs)\n                    nwbfile = NWBFile(**default_nwbfile_kwargs)\n                se.NwbSortingExtractor.write_units(\n                    sorting=sorting,\n                    nwbfile=nwbfile,\n                    property_descriptions=property_descriptions,\n                    skip_properties=skip_properties,\n                    skip_features=skip_features,\n                    use_times=use_times\n                )\n                io.write(nwbfile)\n        else:\n            se.NwbSortingExtractor.write_units(\n                sorting=sorting,\n                nwbfile=nwbfile,\n                property_descriptions=property_descriptions,\n                skip_properties=skip_properties,\n                skip_features=skip_features,\n                use_times=use_times\n            )\n"
  },
  {
    "path": "spikeextractors/extractors/openephysextractors/__init__.py",
    "content": "from .openephysextractors import OpenEphysRecordingExtractor, OpenEphysSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/openephysextractors/openephysextractors.py",
    "content": "from spikeextractors import RecordingExtractor, SortingExtractor\nfrom pathlib import Path\nimport numpy as np\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_unit_spike_train, check_get_ttl_args\nfrom packaging.version import parse\nimport warnings\n\n\ntry:\n    import pyopenephys\n    HAVE_OE = True\n\n    if parse(pyopenephys.__version__) >= parse(\"1.1.2\"):\n        HAVE_OE_11 = True\n    else:\n        warnings.warn(\"pyopenephys>=1.1.2 should be installed. Support for older versions will be removed in \"\n                      \"future releases. Install with:\\n\\n pip install --upgrade pyopenephys\\n\\n\")\n        HAVE_OE_11 = False\nexcept ImportError:\n    HAVE_OE = False\n    HAVE_OE_11 = False\n\nextractors_dir = Path(__file__).parent.parent\n\n\nclass OpenEphysRecordingExtractor(RecordingExtractor):\n    extractor_name = 'OpenEphysRecording'\n    has_default_locations = False\n    has_unscaled = True\n    installed = HAVE_OE  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"To use the OpenEphys extractor, install pyopenephys: \\n\\n pip install pyopenephys\\n\\n\"\n\n    def __init__(self, folder_path, experiment_id=0, recording_id=0):\n        assert self.installed, self.installation_mesg\n        RecordingExtractor.__init__(self)\n        self._recording_file = folder_path\n\n        self._fileobj = pyopenephys.File(folder_path)\n        self._recording = self._fileobj.experiments[experiment_id].recordings[recording_id]\n        self._set_analogsignal(self._recording.analog_signals[0])\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id,\n                        'recording_id': recording_id}\n\n    def _set_analogsignal(self, analogsignals):\n        self._analogsignals = analogsignals\n        # Set gains: int16 to uV\n        if HAVE_OE_11:\n            self.set_channel_gains(gains=self._analogsignals.gains)\n        else:\n            self.set_channel_gains(gains=self._analogsignals.gain)\n\n    def get_channel_ids(self):\n        if HAVE_OE_11:\n            return list(self._analogsignals.channel_ids)\n        else:\n            return list(range(self._analogsignals.signal.shape[0]))\n\n    def get_num_frames(self):\n        return self._analogsignals.signal.shape[1]\n\n    def get_sampling_frequency(self):\n        if HAVE_OE_11:\n            return self._analogsignals.sample_rate\n        else:\n            return float(self._recording.sample_rate.rescale('Hz').magnitude)\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        return self._analogsignals.signal[channel_ids, start_frame:end_frame]\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        channels = [np.unique(ev.channels)[0] for ev in self._recording.events]\n        assert channel_id in channels, f\"Specified 'channel' not found. Available channels are {channels}\"\n        ev = self._recording.events[channels.index(channel_id)]\n\n        ttl_frames = (ev.times.rescale(\"s\") * self.get_sampling_frequency()).magnitude.astype(int)\n        ttl_states = np.sign(ev.channel_states)\n        ttl_valid_idxs = np.where((ttl_frames >= start_frame) & (ttl_frames < end_frame))[0]\n        return ttl_frames[ttl_valid_idxs], ttl_states[ttl_valid_idxs]\n\n\nclass OpenEphysNPIXRecordingExtractor(OpenEphysRecordingExtractor):\n    extractor_name = 'OpenEphysNPIXRecording'\n    has_default_locations = False\n    installed = HAVE_OE_11  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"To use the OpenEphys extractor, \" \\\n                        \"install pyopenephys >= 1.1: \\n\\n pip install pyopenephys>=1.1\\n\\n\"\n\n    def __init__(self, folder_path, experiment_id=0, recording_id=0, stream=\"AP\"):\n        assert self.installed, self.installation_mesg\n        assert stream.upper() in [\"AP\", \"LFP\"]\n        OpenEphysRecordingExtractor.__init__(self, folder_path, experiment_id, recording_id)\n\n        analogsignals = self._recording.analog_signals\n        for analog in analogsignals:\n            channel_names = analog.channel_names\n            if np.all([stream.upper() in chan for chan in channel_names]):\n                self._set_analogsignal(analog)\n                # load neuropixels locations\n                channel_locations = np.loadtxt(extractors_dir / 'neuropixelsdatrecordingextractor' /\n                                               'channel_positions_neuropixels.txt').T\n                # get correct channel ID from channel name (e.g. AP32 --> 32)\n                channel_ids = [int(chan_name[chan_name.find(stream.upper())+len(stream):]) - 1\n                               for chan_name in channel_names]\n                locations = channel_locations[channel_ids]\n                self.set_channel_locations(locations)\n                for i, ch in enumerate(self.get_channel_ids()):\n                    self.set_channel_property(ch, \"channel_name\", channel_names[i])\n                break\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id,\n                        'recording_id': recording_id, 'stream': stream}\n\n\nclass OpenEphysSortingExtractor(SortingExtractor):\n    extractor_name = 'OpenEphysSorting'\n    installed = HAVE_OE  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = \"To use the OpenEphys extractor, install pyopenephys: \\n\\n pip install pyopenephys\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path, experiment_id=0, recording_id=0):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        self._recording_file = folder_path\n        self._recording = pyopenephys.File(folder_path).experiments[experiment_id].recordings[recording_id]\n        self._spiketrains = self._recording.spiketrains\n        self._unit_ids = list([np.unique(st.clusters)[0] for st in self._spiketrains])\n        self._sampling_frequency = float(self._recording.sample_rate.rescale('Hz').magnitude)\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'experiment_id': experiment_id,\n                        'recording_id': recording_id}\n\n    def get_unit_ids(self):\n        return self._unit_ids\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        # unit ids are cluster ids (see __init__), not positions in self._spiketrains,\n        # so the matching spike train is found through the position of unit_id in self._unit_ids\n        st = self._spiketrains[self._unit_ids.index(unit_id)]\n        frames = st.times * self._recording.sample_rate\n        inds = np.where((start_frame <= frames) & (frames < end_frame))\n        return frames[inds].magnitude\n"
  },
  {
    "path": "spikeextractors/extractors/phyextractors/__init__.py",
    "content": "from .phyextractors import PhyRecordingExtractor, PhySortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/phyextractors/phyextractors.py",
    "content": "import numpy as np\nfrom pathlib import Path\nimport csv\nfrom typing import Union, Optional\n\nfrom spikeextractors import SortingExtractor, RecordingExtractor\nfrom spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nfrom spikeextractors.extraction_tools import read_python, check_get_unit_spike_train\n\nPathType = Union[str, Path]\n\n\nclass PhyRecordingExtractor(BinDatRecordingExtractor):\n    \"\"\"\n    RecordingExtractor for a Phy output folder\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the output Phy folder (containing the params.py)\n    \"\"\"\n    extractor_name = 'PhyRecording'\n    has_default_locations = True\n    has_unscaled = False\n    installed = True  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, folder_path: PathType):\n        RecordingExtractor.__init__(self)\n        phy_folder = Path(folder_path)\n\n        self.params = read_python(str(phy_folder / 'params.py'))\n        datfile = [x for x in phy_folder.iterdir() if x.suffix == '.dat' or x.suffix == '.bin']\n\n        if (phy_folder / 'channel_map_si.npy').is_file():\n            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map_si.npy')))\n            assert max(channel_map) < self.params['n_channels_dat'], \"Channel map inconsistent with dat file.\"\n        elif (phy_folder / 'channel_map.npy').is_file():\n            channel_map = list(np.squeeze(np.load(phy_folder / 'channel_map.npy')))\n            assert max(channel_map) < self.params['n_channels_dat'], \"Channel map inconsistent with dat file.\"\n        else:\n            channel_map = list(range(self.params['n_channels_dat']))\n\n        BinDatRecordingExtractor.__init__(self, datfile[0], sampling_frequency=float(self.params['sample_rate']),\n                                          dtype=self.params['dtype'], numchan=self.params['n_channels_dat'],\n                                          recording_channels=list(channel_map))\n\n        if (phy_folder / 'channel_groups.npy').is_file():\n            channel_groups = np.load(phy_folder / 'channel_groups.npy')\n            assert len(channel_groups) == self.get_num_channels()\n            self.set_channel_groups(channel_groups)\n\n        if (phy_folder / 'channel_positions.npy').is_file():\n            channel_locations = np.load(phy_folder / 'channel_positions.npy')\n            assert len(channel_locations) == self.get_num_channels()\n            self.set_channel_locations(channel_locations)\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n\nclass PhySortingExtractor(SortingExtractor):\n    \"\"\"\n    SortingExtractor for a Phy output folder\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the output Phy folder (containing the params.py)\n    exclude_cluster_groups: list (optional)\n        List of cluster groups to exclude (e.g. [\"noise\", \"mua\"]\n    \"\"\"\n    extractor_name = 'PhySorting'\n    installed = True  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, folder_path: PathType, exclude_cluster_groups: Optional[list] = None):\n        SortingExtractor.__init__(self)\n        phy_folder = Path(folder_path)\n\n        spike_times = np.load(phy_folder / 'spike_times.npy')\n        spike_templates = np.load(phy_folder / 'spike_templates.npy')\n\n        if (phy_folder / 'spike_clusters.npy').is_file():\n            spike_clusters = np.load(phy_folder / 'spike_clusters.npy')\n        else:\n            spike_clusters = spike_templates\n\n        if (phy_folder / 'amplitudes.npy').is_file():\n            amplitudes = np.squeeze(np.load(phy_folder / 'amplitudes.npy'))\n        else:\n            amplitudes = np.ones(len(spike_times))\n\n        if (phy_folder / 'pc_features.npy').is_file():\n            pc_features = np.squeeze(np.load(phy_folder / 'pc_features.npy'))\n        else:\n            pc_features = None\n\n        clust_id = np.unique(spike_clusters)\n        self._unit_ids = list(clust_id)\n        # astype returns a new array: assign it back so the spike trains really are ints\n        spike_times = spike_times.astype(int)\n        self.params = read_python(str(phy_folder / 'params.py'))\n        self._sampling_frequency = self.params['sample_rate']\n\n        # set unit quality properties\n        csv_tsv_files = [x for x in phy_folder.iterdir() if x.suffix == '.csv' or x.suffix == '.tsv']\n        for f in csv_tsv_files:\n            if f.suffix == '.csv':\n                with f.open() as csv_file:\n                    csv_reader = csv.reader(csv_file, delimiter=',')\n                    line_count = 0\n                    for row in csv_reader:\n                        if line_count == 0:\n                            tokens = row[0].split(\"\\t\")\n                            property_name = tokens[1]\n                        else:\n                            tokens = row[0].split(\"\\t\")\n                            if int(tokens[0]) in self.get_unit_ids():\n                                if 'cluster_group' in str(f):\n                                    self.set_unit_property(int(tokens[0]), 'quality', tokens[1])\n                                elif property_name == 'chan_grp' or property_name == 'ch_group':\n                                    self.set_unit_property(int(tokens[0]), 'group', int(tokens[1]))\n                                else:\n                                    if isinstance(tokens[1], (int, float, str)):\n                                        self.set_unit_property(int(tokens[0]), property_name, tokens[1])\n                        # count every row (header included) so that only the first row is read as the header\n                        line_count += 1\n            elif f.suffix == '.tsv':\n                with f.open() as csv_file:\n                    csv_reader = csv.reader(csv_file, delimiter='\\t')\n                    line_count = 0\n                    for row in csv_reader:\n                        if line_count == 0:\n                            property_name = row[1]\n                        else:\n                            if len(row) == 2:\n                                if int(row[0]) in self.get_unit_ids():\n                                    if 'cluster_group' in str(f):\n                                        self.set_unit_property(int(row[0]), 'quality', row[1])\n                                    elif property_name == 'chan_grp' or property_name == 'ch_group':\n                                        self.set_unit_property(int(row[0]), 'group', int(row[1]))\n                                    else:\n                                        if isinstance(row[1], (int, float, str)) and len(row) == 2:\n                                            self.set_unit_property(int(row[0]), property_name, row[1])\n                        line_count += 1\n\n        for unit in self.get_unit_ids():\n            if 'quality' not in self.get_unit_property_names(unit):\n                self.set_unit_property(unit, 'quality', 'unsorted')\n\n        if exclude_cluster_groups is not None:\n            if len(exclude_cluster_groups) > 0:\n                included_units = []\n                for u in self.get_unit_ids():\n                    if self.get_unit_property(u, 'quality') not in exclude_cluster_groups:\n                        included_units.append(u)\n            else:\n                included_units = self._unit_ids\n        else:\n            included_units = self._unit_ids\n\n        self._unit_ids = included_units\n        # set features\n        self._spiketrains = []\n        for clust in self._unit_ids:\n            idx = np.where(spike_clusters == clust)[0]\n            self._spiketrains.append(spike_times[idx])\n            self.set_unit_spike_features(clust, 'amplitudes', amplitudes[idx])\n            if pc_features is not None:\n                self.set_unit_spike_features(clust, 'pc_features', pc_features[idx])\n\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()),\n                        'exclude_cluster_groups': exclude_cluster_groups}\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        times = self._spiketrains[self.get_unit_ids().index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n"
  },
  {
    "path": "spikeextractors/extractors/shybridextractors/__init__.py",
    "content": "from .shybridextractors import SHYBRIDRecordingExtractor, SHYBRIDSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/shybridextractors/shybridextractors.py",
    "content": "import os\nfrom pathlib import Path\nimport numpy as np\nfrom spikeextractors import RecordingExtractor, SortingExtractor\nfrom spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nfrom spikeextractors.extraction_tools import save_to_probe_file, load_probe_file, check_get_unit_spike_train\n\ntry:\n    import hybridizer.io as sbio\n    import hybridizer.probes as sbprb\n    import yaml\n\n    HAVE_SBEX = True\nexcept ImportError:\n    HAVE_SBEX = False\n\n\nclass SHYBRIDRecordingExtractor(RecordingExtractor):\n    extractor_name = 'SHYBRIDRecording'\n    installed = HAVE_SBEX\n    has_default_locations = True\n    has_unscaled = False\n    is_writable = True\n    mode = 'file'\n    installation_mesg = \"To use the SHYBRID extractors, install SHYBRID and pyyaml: \" \\\n                        \"\\n\\n pip install shybrid pyyaml\\n\\n\"\n\n    def __init__(self, file_path):\n        # load params file related to the given shybrid recording\n        assert self.installed, self.installation_mesg\n        RecordingExtractor.__init__(self)\n        params = sbio.get_params(file_path)['data']\n\n        # create a shybrid probe object\n        probe = sbprb.Probe(params['probe'])\n        nb_channels = probe.total_nb_channels\n\n        # translate the byte ordering\n        # TODO still ambiguous, shybrid should assume time_axis=1, since spike interface makes an assumption on the byte ordering\n        byte_order = params['order']\n        if byte_order == 'C':\n            time_axis = 1\n        elif byte_order == 'F':\n            time_axis = 0\n        else:\n            # without this branch time_axis would be unbound and fail later with an opaque error\n            raise ValueError(\"Unsupported byte ordering '{}' in the SHYBRID parameters: \"\n                             \"expected 'C' or 'F'\".format(byte_order))\n\n        # piggyback on binary data recording extractor\n        recording = BinDatRecordingExtractor(\n            file_path,\n            params['fs'],\n            nb_channels,\n            params['dtype'],\n            time_axis=time_axis)\n\n        # load probe file\n        self._recording = load_probe_file(recording, params['probe'])\n        self._kwargs = {'file_path': str(Path(file_path).absolute())}\n\n    def get_channel_ids(self):\n        return self._recording.get_channel_ids()\n\n    def get_num_frames(self):\n        return self._recording.get_num_frames()\n\n    def get_sampling_frequency(self):\n        return self._recording.get_sampling_frequency()\n\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        return self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame,\n                                          return_scaled=return_scaled)\n\n    @staticmethod\n    def write_recording(recording, save_path, initial_sorting_fn, dtype='float32', **write_binary_kwargs):\n        \"\"\" Convert and save the recording extractor to SHYBRID format\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n            The recording extractor to be converted and saved\n        save_path: str\n            Full path to desired target folder\n        initial_sorting_fn: str\n            Full path to the initial sorting csv file (can also be generated\n            using write_sorting static method from the SHYBRIDSortingExtractor)\n        dtype: dtype\n            Type of the saved data. Default float32.\n        **write_binary_kwargs: keyword arguments for write_to_binary_dat_format() function\n        \"\"\"\n        assert HAVE_SBEX, SHYBRIDRecordingExtractor.installation_mesg\n        RECORDING_NAME = 'recording.bin'\n        PROBE_NAME = 'probe.prb'\n        PARAMETERS_NAME = 'recording.yml'\n\n        # location information has to be present in order for shybrid to\n        # be able to operate on the recording\n        if 'location' not in recording.get_shared_channel_property_names():\n            raise GeometryNotLoadedError(\"Channel locations were not found\")\n\n        # write recording\n        recording_fn = os.path.join(save_path, RECORDING_NAME)\n        recording.write_to_binary_dat_format(save_path=recording_fn,\n                                             time_axis=0, dtype=dtype,\n                                             **write_binary_kwargs)\n\n        # write probe file\n        probe_fn = os.path.join(save_path, PROBE_NAME)\n        save_to_probe_file(recording, probe_fn)\n\n        # create parameters file\n        parameters = dict(clusters=initial_sorting_fn,\n                          data=dict(dtype=dtype,\n                                    fs=str(recording.get_sampling_frequency()),\n                                    order='F',\n                                    probe=probe_fn))\n\n        # write parameters file\n        parameters_fn = os.path.join(save_path, PARAMETERS_NAME)\n        with open(parameters_fn, 'w') as fp:\n            yaml.dump(parameters, fp)\n\n\nclass SHYBRIDSortingExtractor(SortingExtractor):\n    extractor_name = 'SHYBRIDSorting'\n    installed = HAVE_SBEX\n    is_writable = True\n    installation_mesg = \"To use the SHYBRID extractors, install SHYBRID: \\n\\n pip install shybrid\\n\\n\"\n\n    def __init__(self, file_path, delimiter=','):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n\n        if os.path.isfile(file_path):\n            self._spike_clusters = sbio.SpikeClusters()\n            self._spike_clusters.fromCSV(file_path, None, delimiter=delimiter)\n        else:\n            raise FileNotFoundError('the ground truth file \"{}\" could not be found'.format(file_path))\n        self._kwargs = {'file_path': str(Path(file_path).absolute()), 'delimiter': delimiter}\n\n    def get_unit_ids(self):\n        # return a real list (not a keys() view), like the other sorting extractors\n        return list(self._spike_clusters.keys())\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        train = self._spike_clusters[unit_id].get_actual_spike_train().spikes\n        idxs = np.where((start_frame <= train) & (train < end_frame))\n        return train[idxs]\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        \"\"\" Convert and save the sorting extractor to SHYBRID CSV format\n\n        Parameters\n        ----------\n        sorting : SortingExtractor\n            The sorting extractor to be converted and saved\n        save_path : str\n            Full path to the desired target folder\n        \"\"\"\n        assert HAVE_SBEX, SHYBRIDSortingExtractor.installation_mesg\n        dump = np.empty((0, 2))\n\n        for unit_id in sorting.get_unit_ids():\n            spikes = sorting.get_unit_spike_train(unit_id)[:, np.newaxis]\n            expanded_id = (np.ones(spikes.size) * unit_id)[:, np.newaxis]\n            tmp_concat = np.concatenate((expanded_id, spikes), axis=1)\n\n            dump = np.concatenate((dump, tmp_concat), axis=0)\n\n        sorting_fn = os.path.join(save_path, 'initial_sorting.csv')\n        np.savetxt(sorting_fn, dump, delimiter=',', fmt='%i')\n\n\nclass GeometryNotLoadedError(Exception):\n    \"\"\" Raised when the recording extractor has no associated channel locations\n    \"\"\"\n    pass\n\n\nparams_template = \\\n    \"\"\"clusters:\n      csv: {initial_sorting_fn}\n    data:\n      dtype: {data_type}\n      fs: {sampling_frequency}\n      order: {byte_ordering}\n      probe: {probe_fn}\n    \"\"\"\n"
  },
  {
    "path": "spikeextractors/extractors/spikeglxrecordingextractor/__init__.py",
    "content": "from .spikeglxrecordingextractor import SpikeGLXRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/spikeglxrecordingextractor/readSGLX.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\n----------------------------------------------------------------\nThis is an adapted version of auxiliary functions to read from SpikeGLX data files.\nThe original code can be found at:\n    https://billkarsh.github.io/SpikeGLX/#offline-analysis-tools\n----------------------------------------------------------------\nRequires python 3\n\nThe main() function at the bottom of this file can run from an\ninterpreter, or, the helper functions can be imported into a\nnew module or Jupyter notebook (an example is included).\n\nSimple helper functions and python dictionary demonstrating\nhow to read and manipulate SpikeGLX meta and binary files.\n\nThe most important part of the demo is readMeta().\nPlease read the comments for that function. Use of\nthe 'meta' dictionary will make your data handling\nmuch easier!\n\n\"\"\"\nimport numpy as np\n# import matplotlib.pyplot as plt\nfrom pathlib import Path\n# from tkinter import Tk\n# from tkinter import filedialog\n\n\n# Parse ini file returning a dictionary whose keys are the metadata\n# left-hand-side-tags, and values are string versions of the right-hand-side\n# metadata values. We remove any leading '~' characters in the tags to match\n# the MATLAB version of readMeta.\n#\n# The string values are converted to numbers using the \"int\" and \"float\"\n# fucntions. 
Note that python 3 has no size limit for integers.\n#\ndef readMeta(binFullPath):\n    metaName = binFullPath.stem + \".meta\"\n    metaPath = Path(binFullPath.parent / metaName)\n    metaDict = {}\n    if metaPath.exists():\n        # print(\"meta file present\")\n        with metaPath.open() as f:\n            mdatList = f.read().splitlines()\n            # convert the list entries into key value pairs\n            for m in mdatList:\n                csList = m.split(sep='=')\n                if csList[0][0] == '~':\n                    currKey = csList[0][1:len(csList[0])]\n                else:\n                    currKey = csList[0]\n                metaDict.update({currKey: csList[1]})\n    else:\n        print(\"no meta file\")\n    return(metaDict)\n\n\n# Return sample rate as python float.\n# On most systems, this will be implemented as C++ double.\n# Use python command sys.float_info to get properties of float on your system.\n#\ndef SampRate(meta):\n    if meta['typeThis'] == 'imec':\n        srate = float(meta['imSampRate'])\n    else:\n        srate = float(meta['niSampRate'])\n    return(srate)\n\n\n# Return a multiplicative factor for converting 16-bit file data\n# to volatge. This does not take gain into account. The full\n# conversion with gain is:\n#         dataVolts = dataInt * fI2V / gain\n# Note that each channel may have its own gain.\n#\ndef Int2Volts(meta):\n    if meta['typeThis'] == 'imec':\n        fI2V = float(meta['imAiRangeMax'])/512\n    else:\n        fI2V = float(meta['niAiRangeMax'])/32768\n    return(fI2V)\n\n\n# Return array of original channel IDs. As an example, suppose we want the\n# imec gain for the ith channel stored in the binary data. A gain array\n# can be obtained using ChanGainsIM(), but we need an original channel\n# index to do the lookup. 
Because you can selectively save channels, the\n# ith channel in the file isn't necessarily the ith acquired channel.\n# Use this function to convert from ith stored to original index.\n# Note that the SpikeGLX channels are 0 based.\n#\ndef OriginalChans(meta):\n    if meta['snsSaveChanSubset'] == 'all':\n        # output = int32, 0 to nSavedChans - 1\n        chans = np.arange(0, int(meta['nSavedChans']))\n    else:\n        # parse the snsSaveChanSubset string\n        # split at commas\n        chStrList = meta['snsSaveChanSubset'].split(sep=',')\n        chans = np.arange(0, 0)  # creates an empty array of int32\n        for sL in chStrList:\n            currList = sL.split(sep=':')\n            if len(currList) > 1:\n                # each set of contiguous channels specified by\n                # chan1:chan2 inclusive\n                newChans = np.arange(int(currList[0]), int(currList[1])+1)\n            else:\n                newChans = np.arange(int(currList[0]), int(currList[0])+1)\n            chans = np.append(chans, newChans)\n    return(chans)\n\n\n# Return counts of each nidq channel type that composes the timepoints\n# stored in the binary file.\n#\ndef ChannelCountsNI(meta):\n    chanCountList = meta['snsMnMaXaDw'].split(sep=',')\n    MN = int(chanCountList[0])\n    MA = int(chanCountList[1])\n    XA = int(chanCountList[2])\n    DW = int(chanCountList[3])\n    return(MN, MA, XA, DW)\n\n\n# Return counts of each imec channel type that composes the timepoints\n# stored in the binary files.\n#\ndef ChannelCountsIM(meta):\n    chanCountList = meta['snsApLfSy'].split(sep=',')\n    AP = int(chanCountList[0])\n    LF = int(chanCountList[1])\n    SY = int(chanCountList[2])\n    return(AP, LF, SY)\n\n\n# Return gain for ith channel stored in nidq file.\n# ichan is a saved channel index, rather than the original (acquired) index.\n#\ndef ChanGainNI(ichan, savedMN, savedMA, meta):\n    if ichan < savedMN:\n        gain = float(meta['niMNGain'])\n    elif 
ichan < (savedMN + savedMA):\n        gain = float(meta['niMAGain'])\n    else:\n        gain = 1    # non multiplexed channels have no extra gain\n    return(gain)\n\n\n# Return gain for imec channels.\n# Index into these with the original (acquired) channel IDs.\n#\ndef ChanGainsIM(meta):\n    imroList = meta['imroTbl'].split(sep=')')\n    # One entry for each channel plus header entry,\n    # plus a final empty entry following the last ')'\n    nChan = len(imroList) - 2\n    APgain = np.zeros(nChan)        # default type = float\n    LFgain = np.zeros(nChan)\n    for i in range(0, nChan):\n        currList = imroList[i+1].split(sep=' ')\n        APgain[i] = currList[3]\n        LFgain[i] = currList[4]\n    return(APgain, LFgain)\n\n\n# Having accessed a block of raw nidq data using makeMemMapRaw, convert\n# values to gain-corrected voltage. The conversion is only applied to the\n# saved-channel indicies in chanList. Remember, saved-channel indicies are\n# in the range [0:nSavedChans-1]. The dimensions of dataArray remain\n# unchanged. ChanList examples:\n# [0:MN-1]    all MN channels (MN from ChannelCountsNI)\n# [2,6,20]  just these three channels (zero based, as they appear in SGLX).\n#\ndef GainCorrectNI(dataArray, chanList, meta):\n    MN, MA, XA, DW = ChannelCountsNI(meta)\n    fI2V = Int2Volts(meta)\n    # print statements used for testing...\n    # print(\"NI fI2V: %.3e\" % (fI2V))\n    # print(\"NI ChanGainNI: %.3f\" % (ChanGainNI(0, MN, MA, meta)))\n\n    # make array of floats to return. 
dataArray contains only the channels\n    # in chanList, so output matches that shape\n    # convArray = np.zeros(dataArray.shape, dtype=float)\n    conv = np.zeros(len(chanList), dtype=float)\n    for i in range(0, len(chanList)):\n        j = chanList[i]             # index into timepoint\n        conv[i] = fI2V/ChanGainNI(j, MN, MA, meta)\n        # dataArray contains only the channels in chanList\n        #convArray[i, :] = dataArray[i, :] * conv[i]\n    return conv\n\n\n# Having accessed a block of raw imec data using makeMemMapRaw, convert\n# values to gain corrected voltages. The conversion is only applied to\n# the saved-channel indicies in chanList. Remember saved-channel indicies\n# are in the range [0:nSavedChans-1]. The dimensions of the dataArray\n# remain unchanged. ChanList examples:\n# [0:AP-1]    all AP channels\n# [2,6,20]    just these three channels (zero based)\n# Remember that for an lf file, the saved channel indicies (fetched by\n# OriginalChans) will be in the range 384-767 for a standard 3A or 3B probe.\n#\ndef GainCorrectIM(dataArray, chanList, meta):\n    # Look up gain with acquired channel ID\n    chans = OriginalChans(meta)\n    APgain, LFgain = ChanGainsIM(meta)\n    nAP = len(APgain)\n    nNu = nAP * 2\n\n    # Common converstion factor\n    fI2V = Int2Volts(meta)\n\n    # make array of floats to return. 
dataArray contains only the channels\n    # in chanList, so output matches that shape\n    # convArray = np.zeros(dataArray.shape, dtype='float')\n    conv = np.zeros(len(chanList), dtype=float)\n    for i in range(0, len(chanList)):\n        j = chanList[i]     # index into timepoint\n        k = chans[j]        # acquisition index\n        if k < nAP:\n            conv[i] = fI2V / APgain[k]\n        elif k < nNu:\n            conv[i] = fI2V / LFgain[k - nAP]\n        else:\n            conv[i] = 1\n        # The dataArray contains only the channels in chList\n        #convArray[i, :] = dataArray[i, :]*conv[i]\n    return conv\n\n\ndef makeMemMapRaw(binFullPath, meta):\n    nChan = int(meta['nSavedChans'])\n    nFileSamp = int(int(meta['fileSizeBytes'])/(2*nChan))\n    # print(\"nChan: %d, nFileSamp: %d\" % (nChan, nFileSamp))\n    rawData = np.memmap(binFullPath, dtype='int16', mode='r',\n                        shape=(nChan, nFileSamp), offset=0, order='F')\n    return(rawData)\n\n\n# Return an array [lines X timepoints] of uint8 values for a\n# specified set of digital lines.\n#\n# - dwReq is the zero-based index into the saved file of the\n#    16-bit word that contains the digital lines of interest.\n# - dLineList is a zero-based list of one or more lines/bits\n#    to scan from word dwReq.\n#\ndef ExtractDigital(rawData, firstSamp, lastSamp, dwReq, dLineList, meta):\n    # Get channel index of requested digial word dwReq\n    if meta['typeThis'] == 'imec':\n        AP, LF, SY = ChannelCountsIM(meta)\n        if SY == 0:\n            print(\"No imec sync channel saved.\")\n            digArray = np.zeros((0), 'uint8')\n            return(digArray)\n        else:\n            digCh = AP + LF + dwReq\n    else:\n        MN, MA, XA, DW = ChannelCountsNI(meta)\n        if dwReq > DW-1:\n            print(\"Maximum digital word in file = %d\" % (DW-1))\n            digArray = np.zeros((0), 'uint8')\n            return(digArray)\n        else:\n            digCh = 
MN + MA + XA + dwReq\n\n    selectData = np.ascontiguousarray(rawData[digCh, firstSamp:lastSamp], 'int16')\n    nSamp = lastSamp-firstSamp\n\n    # unpack bits of selectData; unpack bits works with uint8\n    # origintal data is int16\n    bitWiseData = np.unpackbits(selectData.view(dtype='uint8'))\n    # output is 1-D array, nSamp*16. Reshape and transpose\n    bitWiseData = np.transpose(np.reshape(bitWiseData, (nSamp, 16)))\n\n    nLine = len(dLineList)\n    digArray = np.zeros((nLine, nSamp), 'uint8')\n    for i in range(0, nLine):\n        byteN, bitN = np.divmod(dLineList[i], 8)\n        targI = byteN*8 + (7 - bitN)\n        digArray[i, :] = bitWiseData[targI, :]\n    return (digArray)\n\n\n# Sample calling program to get a file from the user,\n# read metadata fetch sample rate, voltage conversion\n# values for this file and channel, and plot a small range\n# of voltages from a single channel.\n# Note that this code merely demonstrates indexing into the\n# data file, without any optimization for efficiency.\n#\n# def main():\n# \n#     # Get file from user\n#     root = Tk()         # create the Tkinter widget\n#     root.withdraw()     # hide the Tkinter root window\n# \n#     # Windows specific; forces the window to appear in front\n#     root.attributes(\"-topmost\", True)\n# \n#     binFullPath = Path(filedialog.askopenfilename(title=\"Select binary file\"))\n#     root.destroy()      # destroy the Tkinter widget\n# \n#     # Other parameters about what data to read\n#     tStart = 0\n#     tEnd = 1\n#     dataType = 'D'    # 'A' for analog, 'D' for digital data\n# \n#     # For analog channels: zero-based index of a channel to extract,\n#     # gain correct and plot (plots first channel only)\n#     chanList = [0]\n# \n#     # For a digital channel: zero based index of the digital word in\n#     # the saved file. 
For imec data there is never more than one digital word.\n#     dw = 0\n# \n#     # Zero-based Line indicies to read from the digital word and plot.\n#     # For 3B2 imec data: the sync pulse is stored in line 6.\n#     dLineList = [0, 1, 6]\n# \n#     # Read in metadata; returns a dictionary with string for values\n#     meta = readMeta(binFullPath)\n# \n#     # parameters common to NI and imec data\n#     sRate = SampRate(meta)\n#     firstSamp = int(sRate*tStart)\n#     lastSamp = int(sRate*tEnd)\n#     # array of times for plot\n#     tDat = np.arange(firstSamp, lastSamp+1)\n#     tDat = 1000*tDat/sRate      # plot time axis in msec\n# \n#     rawData = makeMemMapRaw(binFullPath, meta)\n# \n#     if dataType == 'A':\n#         selectData = rawData[chanList, firstSamp:lastSamp+1]\n#         if meta['typeThis'] == 'imec':\n#             # apply gain correction and convert to uV\n#             convData = 1e6*GainCorrectIM(selectData, chanList, meta)\n#         else:\n#             MN, MA, XA, DW = ChannelCountsNI(meta)\n#             # print(\"NI channel counts: %d, %d, %d, %d\" % (MN, MA, XA, DW))\n#             # apply gain coorection and conver to mV\n#             convData = 1e3*GainCorrectNI(selectData, chanList, meta)\n# \n#         # # Plot the first of the extracted channels\n#         # fig, ax = plt.subplots()\n#         # ax.plot(tDat, convData[0, :])\n#         # plt.show()\n# \n#     else:\n#         digArray = ExtractDigital(rawData, firstSamp, lastSamp, dw,\n#                                   dLineList, meta)\n# \n#         # # Plot the first of the extracted channels\n#         # fig, ax = plt.subplots()\n#         # \n#         # for i in range(0, len(dLineList)):\n#         #    ax.plot(tDat, digArray[i, :])\n#         # plt.show()\n# \n# \n# if __name__ == \"__main__\":\n#     main()\n"
  },
  {
    "path": "spikeextractors/extractors/spikeglxrecordingextractor/spikeglxrecordingextractor.py",
    "content": "from .readSGLX import readMeta, SampRate, makeMemMapRaw, GainCorrectIM, GainCorrectNI, ExtractDigital\nimport numpy as np\nfrom pathlib import Path\n\nfrom spikeextractors import RecordingExtractor\nfrom spikeextractors.extraction_tools import check_get_traces_args, check_get_ttl_args\nimport re\n\n\nclass SpikeGLXRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    RecordingExtractor from a SpikeGLX Neuropixels file\n\n    get_traces returns slices of the memory-mapped int16 data; the channel gains set at\n    construction convert these values to uV.\n\n    Parameters\n    ----------\n    file_path: str or Path\n        Path to the ap.bin, lf.bin, or nidq.bin file\n    x_pitch: int\n        The x pitch of the probe (default 32)\n    y_pitch: int\n        The y pitch of the probe (default 20)\n    \"\"\"\n    extractor_name = 'SpikeGLXRecording'\n    has_default_locations = True\n    has_unscaled = True\n    installed = True  # check at class level if installed or not\n    is_writable = False\n    mode = 'file'\n    installation_mesg = \"To use the SpikeGLXRecordingExtractor run:\\n\\n pip install mtscomp\\n\\n\"  # error message when not installed\n\n    def __init__(self, file_path: str, x_pitch: int = 32, y_pitch: int = 20):\n        RecordingExtractor.__init__(self)\n        self._npxfile = Path(file_path)\n        self._basepath = self._npxfile.parents[0]\n\n        # Gets file type: 'imec0.ap', 'imec0.lf' or 'nidq'.\n        # The dots in the pattern are escaped so that they only match a literal '.'\n        assert re.search(r'imec[0-9]*\\.(ap|lf)\\.bin$', self._npxfile.name) or 'nidq' in self._npxfile.name, \\\n            \"'file_path' can be an imec.ap, imec.lf, imec0.ap, imec0.lf, or nidq file\"\n\n        if 'ap.bin' in str(self._npxfile):\n            rec_type = \"ap\"\n            self.is_filtered = True\n        elif 'lf.bin' in str(self._npxfile):\n            rec_type = \"lf\"\n        else:\n            rec_type = \"nidq\"\n        aux = self._npxfile.stem.split('.')[-1]\n        if aux == 'nidq':\n            self._ftype = aux\n        else:\n            self._ftype = self._npxfile.stem.split('.')[-2] + '.' + aux\n\n        # Metafile\n        self._metafile = self._basepath.joinpath(self._npxfile.stem+'.meta')\n        if not self._metafile.exists():\n            raise Exception(\"'meta' file for '\"+self._ftype+\"' traces should be in the same folder.\")\n        # Read in metadata, returns a dictionary\n        meta = readMeta(self._npxfile)\n        self._meta = meta\n\n        # Traces in 16-bit format\n        self._raw = makeMemMapRaw(self._npxfile, meta)  # [chanList, firstSamp:lastSamp+1]\n\n        # sampling rate and ap channels\n        self._sampling_frequency = SampRate(meta)\n\n        tot_chan, ap_chan, lfp_chan, locations, channel_ids, channel_names \\\n            = _parse_spikeglx_metafile(self._metafile,\n                                       x_pitch=x_pitch,\n                                       y_pitch=y_pitch,\n                                       rec_type=rec_type)\n        if rec_type in (\"ap\", \"lf\"):\n            self._channels = channel_ids\n            # locations\n            if len(locations) > 0:\n                self.set_channel_locations(locations)\n            if len(channel_names) > 0:\n                if len(channel_names) == len(self._channels):\n                    for i, ch in enumerate(self._channels):\n                        self.set_channel_property(ch, \"channel_name\", channel_names[i])\n\n            # Keep only the AP (or LF) channels, which precede the sync channel in the file.\n            # The slice is taken unconditionally so that '_timeseries' is also defined when no\n            # sync channel was saved (i.e. when the neural channel count equals tot_chan).\n            n_neural_chan = ap_chan if rec_type == \"ap\" else lfp_chan\n            self._timeseries = self._raw[0:n_neural_chan, :]\n        else:\n            # nidq\n            self._channels = list(range(int(tot_chan)))\n            self._timeseries = self._raw\n\n        # get gains\n        if meta['typeThis'] == 'imec':\n            gains = GainCorrectIM(self._timeseries, self._channels, meta)\n        elif meta['typeThis'] == 'nidq':\n            gains = GainCorrectNI(self._timeseries, self._channels, meta)\n        else:\n            raise ValueError(\"Unsupported 'typeThis' value in meta file: \" + str(meta['typeThis']))\n\n        # set gains - convert from int16 to uVolt\n        self.set_channel_gains(gains=gains*1e6, channel_ids=self._channels)\n        self._kwargs = {'file_path': str(Path(file_path).absolute()),\n                        'x_pitch': x_pitch, 'y_pitch': y_pitch}\n\n    def get_channel_ids(self):\n        return self._channels\n\n    def get_num_frames(self):\n        return self._timeseries.shape[1]\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        channel_idxs = np.array([self.get_channel_ids().index(ch) for ch in channel_ids])\n        if np.array_equal(channel_ids, self.get_channel_ids()):\n            traces = self._timeseries[:, start_frame:end_frame]\n        else:\n            if np.all(np.diff(channel_idxs) == 1):\n                traces = self._timeseries[channel_idxs[0]:channel_idxs[0]+len(channel_idxs), start_frame:end_frame]\n            else:\n                # This block of the execution will return the data as an array, not a memmap\n                traces = self._timeseries[channel_idxs, start_frame:end_frame]\n\n        return traces\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        channel = [channel_id]\n        dw = 0\n        dig = ExtractDigital(self._raw, firstSamp=start_frame, lastSamp=end_frame, dwReq=dw, dLineList=channel,\n                             meta=self._meta)\n        dig = np.squeeze(dig)\n        diff_dig = np.diff(dig.astype(int))\n\n        rising = np.where(diff_dig > 0)[0] + start_frame\n        falling = np.where(diff_dig < 0)[0] + start_frame\n\n        ttl_frames = np.concatenate((rising, falling))\n        ttl_states = np.array([1] * len(rising) + [-1] * len(falling))\n        sort_idxs = np.argsort(ttl_frames)\n        return ttl_frames[sort_idxs], ttl_states[sort_idxs]\n\n\ndef _parse_spikeglx_metafile(metafile, x_pitch, y_pitch, rec_type):\n    \"\"\"\n    Parse a SpikeGLX .meta file.\n\n    Returns\n    -------\n    tot_channels: int\n        Number of saved channels ('nSavedChans')\n    ap_channels: int\n        First entry of 'snsApLfSy' (number of AP channels)\n    lfp_channels: int\n        Second-to-last entry of 'snsApLfSy' (number of LF channels)\n    locations: list\n        [x, y] position of each 'snsShankMap' entry (only filled for 'ap' and 'lf' files)\n    channel_ids: list\n        Integer ids of the 'AP' (ap files) or 'LF' (lf files) channels in 'snsChanMap'\n    channel_names: list\n        Names of the same 'snsChanMap' channels, e.g. 'AP0'\n    \"\"\"\n    tot_channels = None\n    ap_channels = None\n    lfp_channels = None\n\n    y_offset = 20\n    x_offset = 11\n\n    locations = []\n    channel_names = []\n    channel_ids = []\n    with Path(metafile).open() as f:\n        for line in f.readlines():\n            if 'nSavedChans' in line:\n                tot_channels = int(line.split('=')[-1])\n            if 'snsApLfSy' in line:\n                ap_channels = int(line.split('=')[-1].split(',')[0].strip())\n                lfp_channels = int(line.split(',')[-2].strip())\n            if rec_type in (\"ap\", \"lf\"):\n                if 'snsChanMap' in line:\n                    chan_map = line.split('=')[-1]\n                    chans = chan_map.split(')')[1:]\n                    for chan in chans:\n                        chan_name = chan[1:].split(';')[0]\n                        if rec_type == \"ap\":\n                            if \"AP\" in chan_name:\n                                channel_names.append(chan_name)\n                                chan_id = int(chan_name[2:])\n                                channel_ids.append(chan_id)\n                        elif rec_type == \"lf\":\n                            if \"LF\" in chan_name:\n                                channel_names.append(chan_name)\n                                chan_id = int(chan_name[2:])\n                                channel_ids.append(chan_id)\n                if 'snsShankMap' in line:\n                    shank_map = line.split('=')[-1]\n                    chans = shank_map.split(')')[1:]\n                    for chan in chans:\n                        chan = chan[1:]\n                        if len(chan) > 0:\n                            x_idx = int(chan.split(':')[1])\n                            y_idx = int(chan.split(':')[2])\n                            stagger = np.mod(y_idx + 0, 2) * x_pitch / 2\n                            x_pos = (1 - x_idx) * x_pitch + stagger + x_offset\n                            y_pos = y_idx * y_pitch + y_offset\n                            locations.append([x_pos, y_pos])\n    return tot_channels, ap_channels, lfp_channels, locations, channel_ids, channel_names\n"
  },
  {
    "path": "spikeextractors/extractors/spykingcircusextractors/__init__.py",
    "content": "from .spykingcircusextractors import SpykingCircusSortingExtractor, SpykingCircusRecordingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/spykingcircusextractors/spykingcircusextractors.py",
    "content": "from spikeextractors import RecordingExtractor, SortingExtractor\nfrom spikeextractors.extractors.numpyextractors import NumpyRecordingExtractor\nfrom spikeextractors.extractors.bindatrecordingextractor import BinDatRecordingExtractor\nimport numpy as np\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\n\ntry:\n    import h5py\n    HAVE_SCSX = True\nexcept ImportError:\n    HAVE_SCSX = False\n\n\nclass SpykingCircusRecordingExtractor(RecordingExtractor):\n    \"\"\"\n    RecordingExtractor for a SpykingCircus output folder\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the output Spyking Circus folder or result folder\n    \"\"\"\n    extractor_name = 'SpykingCircusRecording'\n    has_default_locations = False\n    has_unscaled = False\n    installed = True  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, folder_path):\n        RecordingExtractor.__init__(self)\n        spykingcircus_folder = Path(folder_path)\n        listfiles = spykingcircus_folder.iterdir()\n\n        parent_folder = None\n        result_folder = None\n        for f in listfiles:\n            if f.is_dir():\n                if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):\n                    parent_folder = spykingcircus_folder\n                    result_folder = f\n\n        if parent_folder is None:\n            parent_folder = spykingcircus_folder.parent\n            for f in parent_folder.iterdir():\n                if f.is_dir():\n                    if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):\n                        result_folder = spykingcircus_folder\n\n        assert isinstance(parent_folder, Path) and isinstance(result_folder, Path), \"Not a valid spyking circus folder\"\n\n        params = None\n        params_file = None\n        for f in parent_folder.iterdir():\n            if f.suffix == '.params':\n                params = _load_params(f)\n                params_file = f\n                break\n        assert params is not None, \"Could not find the .params file\"\n        recording_name = params_file.stem\n\n        file_format = params[\"file_format\"].lower()\n        if file_format == \"numpy\":\n            recording_file = parent_folder / f\"{recording_name}.npy\"\n            self._recording = NumpyRecordingExtractor(recording_file, params[\"sampling_frequency\"])\n        elif file_format == \"raw_binary\":\n            recording_file = parent_folder / f\"{recording_name}.dat\"\n            self._recording = BinDatRecordingExtractor(recording_file, sampling_frequency=params[\"sampling_frequency\"],\n                                                       numchan=params[\"nb_channels\"], dtype=params[\"dtype\"],\n                                                       time_axis=0)\n        else:\n            raise Exception(f\"'file_format' {params['file_format']} is not supported by the \"\n                            f\"SpykingCircusRecordingExtractor\")\n\n        # The probe mapping is loaded on the wrapped recording, not on 'self': load_probe_file wraps\n        # the extractor it is called on, and every method of this class delegates to self._recording,\n        # so wrapping 'self' would make those methods call back into themselves.\n        # The 'mapping' entry is optional in the .params file.\n        mapping = params.get(\"mapping\")\n        if mapping is not None and mapping.is_file():\n            self._recording = self._recording.load_probe_file(mapping)\n\n        self.params = params\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n    def get_channel_ids(self):\n        return self._recording.get_channel_ids()\n\n    def get_num_frames(self):\n        return self._recording.get_num_frames()\n\n    def get_sampling_frequency(self):\n        return self._recording.get_sampling_frequency()\n\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        return self._recording.get_traces(channel_ids=channel_ids, start_frame=start_frame, end_frame=end_frame,\n                                          return_scaled=return_scaled)\n\n\nclass SpykingCircusSortingExtractor(SortingExtractor):\n    \"\"\"\n    SortingExtractor for SpykingCircus output folder or file\n\n    Parameters\n    ----------\n    file_or_folder_path: str or Path\n        Path to the output Spyking Circus folder, the result folder, or a specific hdf5 file in the result folder\n    load_templates: bool\n        If True, templates are loaded from Spyking Circus output\n    \"\"\"\n    extractor_name = 'SpykingCircusSorting'\n    installed = HAVE_SCSX  # check at class level if installed or not\n    is_writable = True\n    mode = 'folder'\n    installation_mesg = \"To use the SpykingCircusSortingExtractor install h5py: \\n\\n pip install h5py\\n\\n\"\n\n    def __init__(self, file_or_folder_path, load_templates=False):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n        file_or_folder_path = Path(file_or_folder_path)\n\n        if file_or_folder_path.is_dir():\n            listfiles = file_or_folder_path.iterdir()\n            results = None\n            parent_folder = None\n            result_folder = None\n            for f in listfiles:\n                if f.is_dir():\n                    if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):\n                        parent_folder = file_or_folder_path\n                        result_folder = f\n\n            if parent_folder is None:\n                parent_folder = file_or_folder_path.parent\n                for f in parent_folder.iterdir():\n                    if f.is_dir():\n                        if any([f_.suffix == '.hdf5' for f_ in f.iterdir()]):\n                            result_folder = file_or_folder_path\n            # fail with a clear message (rather than an AttributeError on None) if no result folder was found\n            assert isinstance(result_folder, Path), \"Not a valid spyking circus folder\"\n            # load files\n            for f in result_folder.iterdir():\n                if 'result.hdf5' in str(f):\n                    results = f\n                    result_extension = ''\n                    base_name = f.name[:f.name.find(\"result\")-1]\n                if 'result-merged.hdf5' in str(f):\n                    results = f\n                    result_extension = '-merged'\n                    base_name = f.name[:f.name.find(\"result\")-1]\n                    break\n        else:\n            assert file_or_folder_path.suffix in ['.h5', '.hdf5']\n            result_folder = file_or_folder_path.parent\n            parent_folder = result_folder.parent\n            results = file_or_folder_path\n            result_extension = results.stem[results.stem.find(\"result\") + 6:]\n            base_name = file_or_folder_path.name[:file_or_folder_path.name.find(\"result\") - 1]\n\n        assert isinstance(parent_folder, Path) and isinstance(result_folder, Path), \"Not a valid spyking circus folder\"\n\n        # load params\n        params = {}\n        for f in parent_folder.iterdir():\n            if f.suffix == '.params':\n                params = _load_params(f)\n\n        if \"sampling_frequency\" in params.keys():\n            self._sampling_frequency = params[\"sampling_frequency\"]\n\n        if results is None:\n            raise Exception(f\"{file_or_folder_path} is not a spyking circus folder\")\n\n        self._spiketrains = []\n        self._unit_ids = []\n        # the context manager releases the hdf5 file handle once the spike trains are in memory\n        with h5py.File(results, 'r') as f_results:\n            for temp in f_results['spiketimes'].keys():\n                self._spiketrains.append(np.array(f_results['spiketimes'][temp]).astype('int64'))\n                self._unit_ids.append(int(temp.split('_')[-1]))\n\n        if load_templates:\n            try:\n                # the submodule is imported explicitly: a bare 'import scipy' does not load\n                # 'scipy.sparse' in older scipy versions\n                import scipy.sparse\n            except ImportError:\n                raise ImportError(\"'scipy' is needed to load templates from Spyking Circus\")\n\n            filename = result_folder / f\"{base_name}.templates{result_extension}.hdf5\"\n            with h5py.File(filename, 'r', libver='earliest') as f:\n                temp_x = f.get('temp_x')[:].ravel()\n                temp_y = f.get('temp_y')[:].ravel()\n                temp_data = f.get('temp_data')[:].ravel()\n                N_e, N_t, nb_templates = f.get('temp_shape')[:].ravel().astype(np.int32)\n            templates = scipy.sparse.csc_matrix((temp_data, (temp_x, temp_y)), shape=(N_e * N_t, nb_templates))\n            templates = np.array([templates[:, i].toarray().reshape(N_e, N_t) for i in range(templates.shape[1])])\n\n            templates = templates[:len(templates)//2]\n            for u_i, unit in enumerate(self.get_unit_ids()):\n                self.set_unit_property(unit, 'template', templates[u_i])\n\n        self._kwargs = {'file_or_folder_path': str(Path(file_or_folder_path).absolute()),\n                        'load_templates': load_templates}\n\n    def get_unit_ids(self):\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n\n        times = self._spiketrains[self.get_unit_ids().index(unit_id)]\n        inds = np.where((start_frame <= times) & (times < end_frame))\n        return times[inds]\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        assert HAVE_SCSX, SpykingCircusSortingExtractor.installation_mesg\n        save_path = Path(save_path)\n        if save_path.is_dir():\n            save_path = save_path / 'data.result.hdf5'\n        elif save_path.suffix == '.hdf5':\n            # reject file names that end with neither of the two accepted suffixes\n            if not (str(save_path).endswith('result.hdf5') or str(save_path).endswith('result-merged.hdf5')):\n                raise AttributeError(\"'save_path' is either a folder or an hdf5 file \"\n                                     \"ending with 'result.hdf5' or 'result-merged.hdf5'\")\n        else:\n            save_path.mkdir()\n            save_path = save_path / 'data.result.hdf5'\n        # the context manager flushes and closes the file once all spike trains are written\n        with h5py.File(save_path, 'w') as F:\n            spiketimes = F.create_group('spiketimes')\n            for unit_id in sorting.get_unit_ids():\n                spiketimes.create_dataset('tmp_' + str(unit_id), data=sorting.get_unit_spike_train(unit_id))\n\n\ndef _load_params(params_file):\n    params = {}\n    with params_file.open('r') as f:\n        for r in f.readlines():\n            if 'sampling_rate' in r:\n                sampling_frequency = r.split('=')[-1]\n                if '#' in sampling_frequency:\n                    sampling_frequency = sampling_frequency[:sampling_frequency.find('#')]\n                sampling_frequency = sampling_frequency.strip(\" \").strip(\"\\n\")\n                sampling_frequency = float(sampling_frequency)\n                params[\"sampling_frequency\"] = sampling_frequency\n            if 'file_format' in r:\n                file_format = r.split('=')[-1]\n                if '#' in file_format:\n                    file_format = file_format[:file_format.find('#')]\n                file_format = file_format.strip(\" \").strip(\"\\n\")\n                params[\"file_format\"] = file_format\n            if 'nb_channels' in r:\n                nb_channels = r.split('=')[-1]\n                if '#' in nb_channels:\n                    nb_channels = nb_channels[:nb_channels.find('#')]\n                nb_channels = nb_channels.strip(\" \").strip(\"\\n\")\n                params[\"nb_channels\"] = int(nb_channels)\n            if 'data_dtype' in r:\n                dtype = r.split('=')[-1]\n                if '#' in dtype:\n                    dtype = dtype[:dtype.find('#')]\n                dtype = dtype.strip(\" \").strip(\"\\n\")\n                params[\"dtype\"] = dtype\n            if 'mapping' in r:\n                mapping = r.split('=')[-1]\n                if '#' in mapping:\n                    mapping = mapping[:mapping.find('#')]\n                mapping = mapping.strip(\" \").strip(\"\\n\")\n                params[\"mapping\"] = Path(mapping)\n    return params\n"
  },
  {
    "path": "spikeextractors/extractors/tridescloussortingextractor/__init__.py",
    "content": "from .tridescloussortingextractor import TridesclousSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/tridescloussortingextractor/tridescloussortingextractor.py",
    "content": "from spikeextractors import SortingExtractor\nfrom pathlib import Path\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\n\ntry:\n    import tridesclous as tdc\n\n    HAVE_TDC = True\nexcept ImportError:\n    HAVE_TDC = False\n\n\nclass TridesclousSortingExtractor(SortingExtractor):\n    \"\"\"\n    SortingExtractor for a Tridesclous output folder\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the Tridesclous output folder\n    chan_grp: channel group key or None\n        Channel group to load (a key of the folder's channel groups). If None, the folder must\n        contain exactly one channel group, which is used\n    \"\"\"\n    extractor_name = 'TridesclousSorting'\n    installed = HAVE_TDC  # check at class level if installed or not\n    is_writable = False\n    mode = 'folder'\n    installation_mesg = \"To use the TridesclousSortingExtractor install tridesclous: \\n\\n pip install tridesclous\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path, chan_grp=None):\n        assert self.installed, self.installation_mesg\n        tdc_folder = Path(folder_path)\n        SortingExtractor.__init__(self)\n\n        dataio = tdc.DataIO(str(tdc_folder))\n        if chan_grp is None:\n            # if chan_grp is not provided, take the first one if unique\n            chan_grps = list(dataio.channel_groups.keys())\n            assert len(chan_grps) == 1, 'There are several groups in the folder, specify chan_grp=...'\n            chan_grp = chan_grps[0]\n\n        self.chan_grp = chan_grp\n\n        catalogue = dataio.load_catalogue(name='initial', chan_grp=chan_grp)\n\n        # only clusters with a non-negative label become units\n        labels = catalogue['clusters']['cluster_label']\n        labels = labels[labels >= 0]\n        self._unit_ids = list(labels)\n        # load all spikes of segment 0 in memory (this avoids locking the folder with a memmap through dataio)\n        self._all_spikes = dataio.get_spikes(seg_num=0, chan_grp=self.chan_grp, i_start=None, i_stop=None).copy()\n\n        self._sampling_frequency = dataio.sample_rate\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute()), 'chan_grp': chan_grp}\n\n    def get_unit_ids(self):\n        # a copy is returned so that callers cannot modify the internal list\n        return list(self._unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        spikes = self._all_spikes\n        spikes = spikes[spikes['cluster_label'] == unit_id]\n        spike_times = spikes['index']\n        if start_frame is not None:\n            spike_times = spike_times[spike_times >= start_frame]\n        if end_frame is not None:\n            spike_times = spike_times[spike_times < end_frame]\n        return spike_times.copy()\n"
  },
  {
    "path": "spikeextractors/extractors/waveclussortingextractor/__init__.py",
    "content": "from .waveclussortingextractor import WaveClusSortingExtractor"
  },
  {
    "path": "spikeextractors/extractors/waveclussortingextractor/waveclussortingextractor.py",
    "content": "from pathlib import Path\nfrom typing import Union\n\nimport numpy as np\n\nfrom spikeextractors.extractors.matsortingextractor.matsortingextractor import MATSortingExtractor\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\n\nPathType = Union[str, Path]\n\n\nclass WaveClusSortingExtractor(MATSortingExtractor):\n    \"\"\"\n    SortingExtractor for a Wave_clus .mat output file\n\n    The file must contain a 'cluster_class' array (column 0: cluster class, column 1: spike time in ms)\n    and a 'par' struct holding the sampling rate ('sr'). Spikes of class 0 are unsorted: they are not\n    exposed as a unit but through get_unsorted_spike_train().\n\n    Parameters\n    ----------\n    file_path: str or Path\n        Path to the Wave_clus .mat file\n    \"\"\"\n    extractor_name = \"WaveClusSortingExtractor\"\n    installation_mesg = \"\"  # error message when not installed\n\n    def __init__(self, file_path: PathType):\n        super().__init__(file_path)\n        cluster_classes = self._getfield(\"cluster_class\")\n        classes = cluster_classes[:, 0]\n        spike_times = cluster_classes[:, 1]\n        par = self._getfield(\"par\")\n        sample_rate = par[0, 0][np.where(np.array(par.dtype.names) == 'sr')[0][0]][0][0]\n\n        self.set_sampling_frequency(sample_rate)\n        self._unit_ids = np.unique(classes[classes > 0]).astype('int')\n\n        # spike times are converted from ms to (rounded) frame indices\n        self._spike_trains = {}\n        for uid in self._unit_ids:\n            mask = (classes == uid)\n            self._spike_trains[uid] = np.rint(spike_times[mask]*(sample_rate/1000))\n        self._unsorted_train = np.rint(spike_times[classes == 0] * (sample_rate / 1000))\n\n    @staticmethod\n    def _select_frames(frames, start_frame, end_frame):\n        # Return the frames in [start_frame, end_frame); None means unbounded.\n        # 'is None' checks keep an explicit end_frame=0 (an empty range) intact.\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.inf\n        return frames[(frames >= start_frame) & (frames < end_frame)]\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        return self._select_frames(self._spike_trains[unit_id], start_frame, end_frame)\n\n    def get_unit_ids(self):\n        return self._unit_ids.tolist()\n\n    def get_unsorted_spike_train(self, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        return self._select_frames(self._unsorted_train, start_frame, end_frame)\n"
  },
  {
    "path": "spikeextractors/extractors/yassextractors/__init__.py",
    "content": "from .yassextractors import YassSortingExtractor\n"
  },
  {
    "path": "spikeextractors/extractors/yassextractors/yassextractors.py",
    "content": "import numpy as np\nfrom pathlib import Path\n\nfrom spikeextractors import SortingExtractor\nfrom spikeextractors.extractors.numpyextractors import NumpyRecordingExtractor\nfrom spikeextractors.extraction_tools import check_get_unit_spike_train\n\ntry:\n    import yaml\n    HAVE_YASS = True\nexcept ImportError:\n    HAVE_YASS = False\n\n\nclass YassSortingExtractor(SortingExtractor):\n    \"\"\"\n    SortingExtractor for a YASS output folder\n\n    Parameters\n    ----------\n    folder_path: str or Path\n        Path to the YASS folder containing 'config.yaml' and 'tmp/output/spike_train.npy'\n    \"\"\"\n\n    extractor_name = 'YassSorting'\n    mode = 'folder'\n    installed = HAVE_YASS  # check at class level if installed or not\n\n    has_default_locations = False\n    is_writable = False\n    installation_mesg = \"To use the Yass extractor, install pyyaml: \\n\\n pip install pyyaml\\n\\n\"  # error message when not installed\n\n    def __init__(self, folder_path):\n        assert self.installed, self.installation_mesg\n        SortingExtractor.__init__(self)\n\n        self.root_dir = folder_path\n        r = Path(self.root_dir)\n\n        self.fname_spike_train = r / 'tmp' / 'output' / 'spike_train.npy'\n        self.fname_templates = r / 'tmp' / 'output' / 'templates' / 'templates_0sec.npy'\n        self.fname_config = r / 'config.yaml'\n\n        # set defaults to None so they are only loaded if user requires them\n        self.spike_train = None\n        self.temps = None\n\n        # Read CONFIG File\n        with open(self.fname_config, 'r') as stream:\n            self.config = yaml.safe_load(stream)\n\n        self._sampling_frequency = self.config['recordings']['sampling_rate']\n        self._kwargs = {'folder_path': str(Path(folder_path).absolute())}\n\n    def _load_spike_train(self):\n        # Lazily load the spike train array: column 0 holds spike frames, column 1 unit ids\n        if self.spike_train is None:\n            self.spike_train = np.load(self.fname_spike_train)\n        return self.spike_train\n\n    def get_unit_ids(self):\n        # returned as a list, like the other extractors, so that list methods such as index() work\n        return np.unique(self._load_spike_train()[:, 1]).tolist()\n\n    def get_temps(self):\n        # Electrical images/templates, loaded lazily\n        if self.temps is None:\n            self.temps = np.load(self.fname_templates)\n\n        return self.temps\n\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        \"\"\"Return the spike frames of 'unit_id' that fall in [start_frame, end_frame).\"\"\"\n        spike_train = self._load_spike_train()\n        # boolean masking always yields a 1-D array, even for a unit with a single spike\n        spike_times = spike_train[spike_train[:, 1] == unit_id, 0]\n\n        if start_frame is None:\n            start_frame = 0\n        if end_frame is None:\n            end_frame = np.inf\n\n        return spike_times[(spike_times >= start_frame) & (spike_times < end_frame)]\n"
  },
  {
    "path": "spikeextractors/multirecordingchannelextractor.py",
    "content": "from .recordingextractor import RecordingExtractor\nfrom .extraction_tools import check_get_traces_args\nimport numpy as np\nimport warnings\n\n\n# Concatenates the given recordings by channel\nclass MultiRecordingChannelExtractor(RecordingExtractor):\n    def __init__(self, recordings, groups=None):\n        self._recordings = recordings\n        self._all_channel_ids = []\n        self._channel_map = {}\n\n        # Sampling frequency based off the initial extractor\n        self._first_recording = recordings[0]\n        self._sampling_frequency = self._first_recording.get_sampling_frequency()\n        self._num_frames = self._first_recording.get_num_frames()\n\n        use_times = True\n        if np.all([rec._times is not None for rec in self._recordings]):\n            times_0 = self._recordings[0]._times\n            for rec in self._recordings[1:]:\n                times_i = rec._times\n                if not np.allclose(times_0, times_i):\n                    use_times = False\n                    warnings.warn(\"The recordings have different times! Reset times with the \"\n                                  \"'set_times() function\")\n        elif np.all([rec._times is not None for rec in self._recordings]):\n            warnings.warn(\"Not all the recordings have times! 
Reset times with the \"\n                          \"'set_times() function\")\n        else:\n            use_times = False\n\n        # Test if all recording extractors have same sampling frequency\n        for i, recording in enumerate(recordings[1:]):\n            sampling_frequency = recording.get_sampling_frequency()\n            if self._sampling_frequency != sampling_frequency:\n                raise ValueError(\"Inconsistent sampling frequency between extractor 0 and extractor \" + str(i + 1))\n\n        # set channel map for new channel ids to old channel ids\n        new_channel_id = 0\n        for r_i, recording in enumerate(self._recordings):\n            channel_ids = recording.get_channel_ids()\n            for channel_id in channel_ids:\n                self._all_channel_ids.append(new_channel_id)\n                self._channel_map[new_channel_id] = {'recording': r_i, 'channel_id': channel_id}\n                new_channel_id += 1\n\n        RecordingExtractor.__init__(self)\n\n        if use_times:\n            self.copy_times(self._recordings[0])\n\n        # set group information for channels if available\n        if groups is not None:\n            if len(groups) == len(recordings):\n                group_values = []\n                for i, group in enumerate(groups):\n                    recording = recordings[i]\n                    channel_ids = recording.get_channel_ids()\n                    recording_groups = [group] * len(channel_ids)\n                    group_values += recording_groups\n\n                self.set_channel_groups(groups=group_values)\n            else:\n                raise ValueError(\"recordings and groups must have same length\")\n\n        # set channel locations\n        locations = np.empty([0, 2])\n        for i, recording in enumerate(recordings):\n            locations = np.vstack((locations, recording.get_channel_locations()))\n        self.set_channel_locations(locations)\n\n        #set all normal properties\n  
      for channel_id in self.get_channel_ids():\n            recording = self._recordings[self._channel_map[channel_id]['recording']]\n            channel_id_recording = self._channel_map[channel_id]['channel_id']\n            for property_name in recording.get_channel_property_names(channel_id_recording):\n                if property_name not in (\"group\", \"location\"):\n                    value = recording.get_channel_property(channel_id_recording, property_name)\n                    self.set_channel_property(channel_id=channel_id, property_name=property_name, value=value)\n\n        # avoid rescaling twice\n        self.clear_channel_gains()\n        self.clear_channel_offsets()\n\n        self.is_filtered = self._first_recording.is_filtered\n        self.has_unscaled = self._first_recording.has_unscaled\n\n        self._kwargs = {'recordings': [rec.make_serialized_dict() for rec in recordings], 'groups': groups}\n\n    @property\n    def recordings(self):\n        return self._recordings\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        traces = []\n        if channel_ids is not None:\n            for channel_id in channel_ids:\n                recording = self._recordings[self._channel_map[channel_id]['recording']]\n                channel_id_recording = self._channel_map[channel_id]['channel_id']\n                traces_recording = recording.get_traces(channel_ids=[channel_id_recording], start_frame=start_frame,\n                                                        end_frame=end_frame, return_scaled=return_scaled)\n                traces.append(traces_recording)\n        else:\n            for recording in self._recordings:\n                traces_all_recording = recording.get_traces(channel_ids=channel_ids, start_frame=start_frame,\n                                                            end_frame=end_frame, return_scaled=return_scaled)\n                
traces.append(traces_all_recording)\n        return np.concatenate(traces, axis=0)\n\n    def get_channel_ids(self):\n        return self._all_channel_ids\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n\ndef concatenate_recordings_by_channel(recordings, groups=None):\n    \"\"\"\n    Concatenates recordings together by channel. The order of the recordings\n    determines the order of the channels in the concatenated recording.\n\n    Parameters\n    ----------\n    recordings: list\n        The list of RecordingExtractors to be concatenated by channel.\n    groups: list\n        A list of ints corresponding to the group identity of each recording's\n        channel ids.\n\n    Returns\n    -------\n    recording: MultiRecordingChannelExtractor\n        The concatenated recording extractors enscapsulated in the\n        MultiRecordingChannelExtractor object (which is also a recording extractor)\n    \"\"\"\n    return MultiRecordingChannelExtractor(\n        recordings=recordings,\n        groups=groups,\n    )\n"
  },
  {
    "path": "spikeextractors/multirecordingtimeextractor.py",
    "content": "from .recordingextractor import RecordingExtractor\nfrom .extraction_tools import check_get_traces_args, check_get_ttl_args\nimport numpy as np\n\n\n# Concatenates the given recordings by time\nclass MultiRecordingTimeExtractor(RecordingExtractor):\n    def __init__(self, recordings, epoch_names=None):\n        self._recordings = recordings\n\n        # Num channels and sampling frequency based off the initial extractor\n        self._first_recording = recordings[0]\n        self._num_channels = self._first_recording.get_num_channels()\n        self._channel_ids = self._first_recording.get_channel_ids()\n        self._sampling_frequency = self._first_recording.get_sampling_frequency()\n\n        if epoch_names is None:\n            epoch_names = [str(i) for i in range(len(recordings))]\n\n        RecordingExtractor.__init__(self)\n\n        # Add all epochs to the epochs data structure\n        start_frames = 0\n        for i, epoch_name in enumerate(epoch_names):\n            num_frames = recordings[i].get_num_frames()\n            self.add_epoch(epoch_name, start_frames, start_frames + num_frames)\n            start_frames += num_frames\n\n        # Test if all recording extractors have same num channels and sampling frequency\n        for i, recording in enumerate(recordings[1:]):\n            channel_ids = recording.get_channel_ids()\n            sampling_frequency = recording.get_sampling_frequency()\n\n            if self._channel_ids != channel_ids:\n                raise ValueError(\"Inconsistent channel ids between extractor 0 and extractor \" + str(i + 1))\n            if self._sampling_frequency != sampling_frequency:\n                raise ValueError(\"Inconsistent sampling frequency between extractor 0 and extractor \" + str(i + 1))\n\n        self._start_frames = []\n        self._end_frames = []\n        self._start_times = []\n        self._end_times = []\n        ff = 0\n        tt = 0.\n        for recording in self._recordings:\n  
          self._start_frames.append(ff)\n            self._start_times.append(tt)\n            ff = ff + recording.get_num_frames()\n            tt = tt + recording.frame_to_time(recording.get_num_frames() - 1) - recording.frame_to_time(0)\n            self._end_frames.append(ff)\n            self._end_times.append(tt)\n        self._num_frames = ff\n\n        # Set the channel properties based on the first recording extractor\n        self.copy_channel_properties(self._first_recording)\n\n        # avoid rescaling twice\n        self.clear_channel_gains()\n        self.clear_channel_offsets()\n\n        self.is_filtered = self._first_recording.is_filtered\n        self.has_unscaled = self._first_recording.has_unscaled\n\n        self._kwargs = {'recordings': [rec.make_serialized_dict() for rec in recordings], 'epoch_names': epoch_names}\n\n    @property\n    def recordings(self):\n        return self._recordings\n\n    def _find_section_for_frame(self, frame):\n        start_frames = np.array(self._start_frames)\n        end_frames = np.array(self._end_frames)\n        inds = np.where((frame >= start_frames) & (frame < end_frames))[0]\n        if len(inds) == 0:\n            # can only happen if frame == end_frame\n            ind = len(self._start_frames) - 1\n        else:\n            ind = inds[0]\n        return self._recordings[ind], ind, frame - self._start_frames[ind]\n\n    def _find_section_for_time(self, time):\n        start_times = np.array(self._start_times)\n        end_times = np.array(self._end_times)\n        inds = np.where((time >= start_times) & (time < end_times))[0]\n        if len(inds) == 0:\n            # can only happen if frame == end_frame\n            ind = len(self._start_times) - 1\n        else:\n            ind = inds[0]\n        return self._recordings[ind], ind, time - self._start_times[ind]\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        
recording1, i_sec1, i_start_frame = self._find_section_for_frame(start_frame)\n        _, i_sec2, i_end_frame = self._find_section_for_frame(end_frame)\n        if i_sec1 == i_sec2:\n            return recording1.get_traces(channel_ids=channel_ids, start_frame=i_start_frame, end_frame=i_end_frame,\n                                         return_scaled=return_scaled)\n        traces = []\n        traces.append(\n            self._recordings[i_sec1].get_traces(channel_ids=channel_ids, start_frame=i_start_frame,\n                                                end_frame=self._recordings[i_sec1].get_num_frames(),\n                                                return_scaled=return_scaled)\n        )\n        for i_sec in range(i_sec1 + 1, i_sec2):\n            traces.append(\n                self._recordings[i_sec].get_traces(channel_ids=channel_ids, return_scaled=return_scaled)\n            )\n        if i_end_frame != 0:\n            traces.append(\n                self._recordings[i_sec2].get_traces(channel_ids=channel_ids, start_frame=0, end_frame=i_end_frame,\n                                                    return_scaled=return_scaled)\n            )\n        return np.concatenate(traces, axis=1)\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        recording1, i_sec1, i_start_frame = self._find_section_for_frame(start_frame)\n        _, i_sec2, i_end_frame = self._find_section_for_frame(end_frame)\n\n        if i_sec1 == i_sec2:\n            ttl_frames, ttl_states = recording1.get_ttl_events(start_frame=i_start_frame,\n                                                               end_frame=i_end_frame,\n                                                               channel_id=channel_id)\n            ttl_frames += self._start_frames[i_sec1]\n        else:\n            ttl_frames, ttl_states = [], []\n\n            ttl_frames_1, ttl_states_1 = self._recordings[i_sec1].get_ttl_events(\n             
   start_frame=i_start_frame,\n                end_frame=self._recordings[i_sec1].get_num_frames(),\n                channel_id=channel_id)\n            ttl_frames_1 = (ttl_frames_1 + self._start_frames[i_sec1]).astype('int64')\n            ttl_frames.append(ttl_frames_1)\n            ttl_states.append(ttl_states_1)\n\n            for i_sec in range(i_sec1 + 1, i_sec2):\n                ttl_frames_i, ttl_states_i = self._recordings[i_sec].get_ttl_events(channel_id=channel_id)\n                ttl_frames_i = (ttl_frames_i + self._start_frames[i_sec]).astype('int64')\n                ttl_frames.append(ttl_frames_i)\n                ttl_states.append(ttl_states_i)\n\n            ttl_frames_2, ttl_states_2 = self._recordings[i_sec2].get_ttl_events(start_frame=0,\n                                                                                 end_frame=i_end_frame,\n                                                                                 channel_id=channel_id)\n            ttl_frames_2 = (ttl_frames_2 + self._start_frames[i_sec2]).astype('int64')\n            ttl_frames.append(ttl_frames_2)\n            ttl_states.append(ttl_states_2)\n\n            ttl_frames = np.concatenate(np.array(ttl_frames))\n            ttl_states = np.concatenate(np.array(ttl_states))\n\n        return ttl_frames, ttl_states\n\n    def get_channel_ids(self):\n        return self._channel_ids\n\n    def get_num_frames(self):\n        return self._num_frames\n\n    def get_sampling_frequency(self):\n        return self._sampling_frequency\n\n    def frame_to_time(self, frame):\n        recording, i_epoch, rel_frame = self._find_section_for_frame(frame)\n        return np.round(recording.frame_to_time(rel_frame) + self._start_times[i_epoch], 6)\n\n    def time_to_frame(self, time):\n        recording, i_epoch, rel_time = self._find_section_for_time(time)\n        return (recording.time_to_frame(rel_time) + self._start_frames[i_epoch]).astype('int64')\n\n\ndef 
concatenate_recordings_by_time(recordings, epoch_names=None):\n    \"\"\"\n    Concatenates recordings together by time. The order of the recordings\n    determines the order of the time series in the concatenated recording.\n\n    Parameters\n    ----------\n    recordings: list\n        The list of RecordingExtractors to be concatenated by time\n    epoch_names: list\n        The list of strings corresponding to the names of recording time period.\n\n    Returns\n    -------\n    recording: MultiRecordingTimeExtractor\n        The concatenated recording extractors enscapsulated in the\n        MultiRecordingTimeExtractor object (which is also a recording extractor)\n    \"\"\"\n    return MultiRecordingTimeExtractor(\n        recordings=recordings,\n        epoch_names=epoch_names,\n    )\n"
  },
  {
    "path": "spikeextractors/multisortingextractor.py",
    "content": "from .sortingextractor import SortingExtractor\nimport numpy as np\nfrom .extraction_tools import check_get_unit_spike_train\n\n\n# Encapsulates a grouping of non-continuous sorting extractors\nclass MultiSortingExtractor(SortingExtractor):\n    def __init__(self, sortings):\n        SortingExtractor.__init__(self)\n        self._sortings = sortings\n        self._all_unit_ids = []\n        self._unit_map = {}\n\n        u_id = 0\n        for s_i, sorting in enumerate(self._sortings):\n            unit_ids = sorting.get_unit_ids()\n            for unit_id in unit_ids:\n                self._all_unit_ids.append(u_id)\n                self._unit_map[u_id] = {'sorting_id': s_i, 'unit_id': unit_id}\n                u_id += 1\n        self._kwargs = {'sortings': [sort.make_serialized_dict() for sort in sortings]}\n\n    @property\n    def sortings(self):\n        return self._sortings\n\n    def get_unit_ids(self):\n        return list(self._all_unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        return self._sortings[sorting_id].get_unit_spike_train(unit_id_sorting, start_frame, end_frame)\n\n    def set_sampling_frequency(self, sampling_frequency):\n        for sorting in self._sortings:\n            sorting.set_sampling_frequency(sampling_frequency)\n\n    def get_sampling_frequency(self):\n        return self._sortings[0].get_sampling_frequency()\n\n    def set_unit_property(self, unit_id, property_name, value):\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        self._sortings[sorting_id].set_unit_property(unit_id_sorting, property_name, value)\n\n    def 
get_unit_property(self, unit_id, property_name):\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        return self._sortings[sorting_id].get_unit_property(unit_id_sorting, property_name)\n\n    def get_unit_property_names(self, unit_id):\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        property_names = self._sortings[sorting_id].get_unit_property_names(unit_id_sorting)\n        return property_names\n\n    def clear_unit_property(self, unit_id, property_name):\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        self._sortings[sorting_id].clear_unit_property(unit_id_sorting, property_name)\n\n    def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        return self._sortings[sorting_id].get_unit_spike_features(unit_id_sorting, feature_name, start_frame=start_frame, end_frame=end_frame)\n\n    def get_unit_spike_feature_names(self, unit_id):\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._unit_map.keys():\n                    raise ValueError(\"Non-valid unit_id\")\n                sorting_id = self._unit_map[unit_id]['sorting_id']\n                unit_id_sorting = 
self._unit_map[unit_id]['unit_id']\n                feature_names = sorted(self._sortings[sorting_id].get_unit_spike_feature_names(unit_id_sorting))\n                return feature_names\n            else:\n                raise ValueError(\"Non-valid unit_id\")\n        else:\n            raise ValueError(\"unit_id must be an int\")\n\n    def set_unit_spike_features(self, unit_id, feature_name, value, indexes=None):\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        self._sortings[sorting_id].set_unit_spike_features(unit_id_sorting, feature_name, value, indexes)\n\n    def clear_unit_spike_features(self, unit_id, feature_name):\n        if unit_id not in self._unit_map.keys():\n            raise ValueError(\"Non-valid unit_id\")\n        sorting_id = self._unit_map[unit_id]['sorting_id']\n        unit_id_sorting = self._unit_map[unit_id]['unit_id']\n        self._sortings[sorting_id].clear_unit_spike_features(unit_id_sorting, feature_name)\n\n\ndef concatenate_sortings(sortings):\n    \"\"\"\n    Concatenates sortings together. The sortings should be non-continuous\n\n    Parameters\n    ----------\n    sortings: list\n        The list of SortingExtractors to be concatenated\n\n    Returns\n    -------\n    recording: MultiSortingExtractor\n        The concatenated sorting extractors enscapsulated in the\n        MultiSortingExtractor object (which is also a sorting extractor)\n    \"\"\"\n    return MultiSortingExtractor(\n        sortings=sortings,\n    )\n"
  },
  {
    "path": "spikeextractors/recordingextractor.py",
    "content": "from abc import ABC, abstractmethod\nimport numpy as np\nfrom copy import deepcopy\n\nfrom .extraction_tools import load_probe_file, save_to_probe_file, write_to_binary_dat_format, \\\n    write_to_h5_dataset_format, get_sub_extractors_by_property, cast_start_end_frame\nfrom .baseextractor import BaseExtractor\n\n\nclass RecordingExtractor(ABC, BaseExtractor):\n    \"\"\"A class that contains functions for extracting important information\n    from recorded extracellular data. It is an abstract class so all\n    functions with the @abstractmethod tag must be implemented for the\n    initialization to work.\n    \"\"\"\n\n    _default_filename = \"spikeinterface_recording\"\n\n    def __init__(self):\n        BaseExtractor.__init__(self)\n        self._key_properties = {'group': None, 'location': None, 'gain': None, 'offset': None}\n        self.is_filtered = False\n\n    @abstractmethod\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        \"\"\"This function extracts and returns a trace from the recorded data from the\n        given channels ids and the given start and end frame. It will return\n        traces from within three ranges:\n\n            [start_frame, start_frame+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_recording_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_recording_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Traces are returned in a 2D array that\n        contains all of the traces from each channel with dimensions\n        (num_channels x num_frames). 
In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        channel_ids: array_like\n            A list or 1D array of channel ids (ints) from which each trace will be extracted.\n        start_frame: int\n            The starting frame of the trace to be returned (inclusive).\n        end_frame: int\n            The ending frame of the trace to be returned (exclusive).\n        return_scaled: bool\n            If True, traces are returned after scaling (using gain/offset).\n            If False, the raw traces are returned.\n\n        Returns\n        -------\n        traces: numpy.ndarray\n            A 2D array that contains all of the traces from each channel.\n            Dimensions are: (num_channels x num_frames)\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_num_frames(self):\n        \"\"\"This function returns the number of frames in the recording\n\n        Returns\n        -------\n        num_frames: int\n            Number of frames in the recording (duration of recording)\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_sampling_frequency(self):\n        \"\"\"This function returns the sampling frequency in units of Hz.\n\n        Returns\n        -------\n        fs: float\n            Sampling frequency of the recordings in Hz\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_channel_ids(self):\n        \"\"\"Returns the list of channel ids. 
If not specified, the range from 0 to num_channels - 1 is returned.\n\n        Returns\n        -------\n        channel_ids: list\n            Channel list\n        \"\"\"\n        pass\n\n    def get_num_channels(self):\n        \"\"\"This function returns the number of channels in the recording.\n\n        Returns\n        -------\n        num_channels: int\n            Number of channels in the recording\n        \"\"\"\n        return len(self.get_channel_ids())\n\n    def get_dtype(self, return_scaled=True):\n        \"\"\"This function returns the traces dtype\n\n        Parameters\n        ----------\n        return_scaled: bool\n            If False and the recording extractor has unscaled traces, it returns the dtype of unscaled traces.\n            If True (default) it returns the dtype of the scaled traces\n\n        Returns\n        -------\n        dtype: np.dtype\n            The dtype of the traces\n        \"\"\"\n        return self.get_traces(channel_ids=[self.get_channel_ids()[0]], start_frame=0, end_frame=1,\n                               return_scaled=return_scaled).dtype\n\n    def set_times(self, times):\n        \"\"\"This function sets the recording times (in seconds) for each frame\n\n        Parameters\n        ----------\n        times: array-like\n            The times in seconds for each frame\n        \"\"\"\n        assert len(times) == self.get_num_frames(), \"'times' should have the same length of the \" \\\n                                                    \"number of frames\"\n        self._times = times.astype('float64')\n\n    def copy_times(self, extractor):\n        \"\"\"This function copies times from another extractor.\n\n        Parameters\n        ----------\n        extractor: BaseExtractor\n            The extractor from which the epochs will be copied\n        \"\"\"\n        if extractor._times is not None:\n            self.set_times(deepcopy(extractor._times))\n\n    def frame_to_time(self, frames):\n        
\"\"\"This function converts user-inputted frame indexes to times with units of seconds.\n\n        Parameters\n        ----------\n        frames: float or array-like\n            The frame or frames to be converted to times\n\n        Returns\n        -------\n        times: float or array-like\n            The corresponding times in seconds\n        \"\"\"\n        # Default implementation\n        if self._times is None:\n            return np.round(frames / self.get_sampling_frequency(), 6)\n        else:\n            return self._times[frames]\n\n    def time_to_frame(self, times):\n        \"\"\"This function converts a user-inputted times (in seconds) to a frame indexes.\n\n        Parameters\n        -------\n        times: float or array-like\n            The times (in seconds) to be converted to frame indexes\n\n        Returns\n        -------\n        frames: float or array-like\n            The corresponding frame indexes\n        \"\"\"\n        # Default implementation\n        if self._times is None:\n            return np.round(times * self.get_sampling_frequency()).astype('int64')\n        else:\n            return np.searchsorted(self._times, times).astype('int64')\n\n    def get_snippets(self, reference_frames, snippet_len, channel_ids=None, return_scaled=True):\n        \"\"\"This function returns data snippets from the given channels that\n        are starting on the given frames and are the length of the given snippet\n        lengths before and after.\n\n        Parameters\n        ----------\n        reference_frames: array_like\n            A list or array of frames that will be used as the reference frame of each snippet.\n        snippet_len: int or tuple\n            If int, the snippet will be centered at the reference frame and\n            and return half before and half after of the length. 
If tuple,\n            it will return the first value of before frames and the second value\n            of after frames around the reference frame (allows for asymmetry).\n        channel_ids: array_like\n            A list or array of channel ids (ints) from which each trace will be\n            extracted\n        return_scaled: bool\n            If True, snippets are returned after scaling (using gain/offset).\n            If False, the raw traces are returned.\n\n        Returns\n        -------\n        snippets: numpy.ndarray\n            Returns a list of the snippets as numpy arrays.\n            The length of the list is len(reference_frames)\n            Each array has dimensions: (num_channels x snippet_len)\n            Out-of-bounds cases should be handled by filling in zeros in the snippet\n        \"\"\"\n        # Default implementation\n        if isinstance(snippet_len, (tuple, list, np.ndarray)):\n            snippet_len_before = int(snippet_len[0])\n            snippet_len_after = int(snippet_len[1])\n        else:\n            snippet_len_before = int((snippet_len + 1) / 2)\n            snippet_len_after = int(snippet_len - snippet_len_before)\n\n        if channel_ids is None:\n            channel_ids = self.get_channel_ids()\n\n        num_snippets = len(reference_frames)\n        num_channels = len(channel_ids)\n        num_frames = self.get_num_frames()\n        snippet_len_total = int(snippet_len_before + snippet_len_after)\n        snippets = np.zeros((num_snippets, num_channels, snippet_len_total), dtype=self.get_dtype(return_scaled))\n\n        for i in range(num_snippets):\n            snippet_chunk = np.zeros((num_channels, snippet_len_total), dtype=self.get_dtype(return_scaled))\n            if 0 <= reference_frames[i] < num_frames:\n                snippet_range = np.array([int(reference_frames[i]) - snippet_len_before,\n                                          int(reference_frames[i]) + snippet_len_after])\n                
snippet_buffer = np.array([0, snippet_len_total], dtype='int')\n                # The following handles the out-of-bounds cases\n                if snippet_range[0] < 0:\n                    snippet_buffer[0] -= snippet_range[0]\n                    snippet_range[0] -= snippet_range[0]\n                if snippet_range[1] >= num_frames:\n                    snippet_buffer[1] -= snippet_range[1] - num_frames\n                    snippet_range[1] -= snippet_range[1] - num_frames\n                snippet_chunk[:, snippet_buffer[0]:snippet_buffer[1]] = self.get_traces(channel_ids=channel_ids,\n                                                                                        start_frame=snippet_range[0],\n                                                                                        end_frame=snippet_range[1],\n                                                                                        return_scaled=return_scaled)\n            snippets[i] = snippet_chunk\n        return snippets\n\n    def set_channel_locations(self, locations, channel_ids=None):\n        \"\"\"This function sets the location key properties of each specified channel\n        id with the corresponding locations of the passed in locations list.\n\n        Parameters\n        ----------\n        locations: array_like\n            A list of corresponding locations (array_like) for the given channel_ids\n        channel_ids: array-like or int\n            The channel ids (ints) for which the locations will be specified. 
If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n            locations = [locations]\n        # Only None upon initialization\n        if self._key_properties['location'] is None:\n            default_locations = np.empty((self.get_num_channels(), 3), dtype='float')\n            default_locations[:] = np.nan\n            self._key_properties['location'] = default_locations\n        if len(channel_ids) == len(locations):\n            for i in range(len(channel_ids)):\n                if isinstance(locations[i], (list, np.ndarray, tuple)):\n                    location = np.asarray(locations[i])\n                    channel_idx = list(self.get_channel_ids()).index(channel_ids[i])\n                    if len(location) == 2:\n                        self._key_properties['location'][channel_idx, :2] = location\n                    elif len(location) == 3:\n                        self._key_properties['location'][channel_idx] = location\n                    else:\n                        raise TypeError(\"'location' must be 2d ior 3d\")\n                else:\n                    raise TypeError(\"'location' must be an array like object\")\n        else:\n            raise ValueError(\"channel_ids and locations must have same length\")\n\n    def get_channel_locations(self, channel_ids=None, locations_2d=True):\n        \"\"\"This function returns the location of each channel specified by channel_ids\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the locations will be returned. 
If None, all channel ids are assumed.\n        locations_2d: bool\n            If True (default), first two dimensions are returned\n\n        Returns\n        -------\n        locations: array_like\n            Returns a list of corresponding locations (floats) for the given\n            channel_ids\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        locations = self._key_properties['location']\n        # Only None upon initialization\n        if locations is None:\n            locations = np.empty((self.get_num_channels(), 3), dtype='float')\n            locations[:] = np.nan\n            self._key_properties['location'] = locations\n        locations = np.array(locations)\n        channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids])\n        if locations_2d:\n            locations = np.array(locations)[:, :2]\n        return locations[channel_idxs]\n\n    def clear_channel_locations(self, channel_ids=None):\n        \"\"\"This function clears the location of each channel specified by channel_ids.\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the locations will be cleared. 
If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        # Reset to default locations (NaN)\n        default_locations =  np.array([[np.nan, np.nan, np.nan] for i in range(len(channel_ids))])\n        self.set_channel_locations(default_locations, channel_ids)\n\n    def set_channel_groups(self, groups, channel_ids=None):\n        \"\"\"This function sets the group key property of each specified channel\n        id with the corresponding group of the passed in groups list.\n\n        Parameters\n        ----------\n        groups: array-like or int\n            A list of groups (ints) for the channel_ids\n        channel_ids: array_like or None\n            The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        if isinstance(groups, (int, np.integer)):\n            groups = [groups]\n        # Only None upon initialization\n        if self._key_properties['group'] is None:\n            self._key_properties['group'] = np.zeros(self.get_num_channels(), dtype='int')\n        if len(channel_ids) == len(groups):\n            for i in range(len(channel_ids)):\n                if isinstance(groups[i], (int, np.integer)):\n                    channel_idx = list(self.get_channel_ids()).index(channel_ids[i])\n                    self._key_properties['group'][channel_idx] = int(groups[i])\n                else:\n                    raise TypeError(\"'group' must be an int\")\n        else:\n            raise ValueError(\"channel_ids and groups must have same length\")\n\n    def get_channel_groups(self, channel_ids=None):\n        
\"\"\"This function returns the group of each channel specified by\n        channel_ids\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the groups will be returned\n\n        Returns\n        -------\n        groups: array_like\n            Returns a list of corresponding groups (ints) for the given channel_ids\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        groups = self._key_properties['group']\n        # Only None upon initialization\n        if groups is None:\n            groups = np.zeros(self.get_num_channels(), dtype='int')\n            self._key_properties['group'] = groups\n        groups = np.array(groups)\n        channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids])\n        return groups[channel_idxs]\n\n    def clear_channel_groups(self, channel_ids=None):\n        \"\"\"This function clears the group of each channel specified by channel_ids\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the groups will be cleared. 
If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        # Reset to default groups (0)\n        default_groups = [0] * len(channel_ids)\n        self.set_channel_groups(default_groups, channel_ids)\n\n    def set_channel_gains(self, gains, channel_ids=None):\n        \"\"\"This function sets the gain key property of each specified channel\n        id with the corresponding group of the passed in gains float/list.\n\n        Parameters\n        ----------\n        gains: float/array_like\n            If a float, each channel will be assigned the corresponding gain.\n            If a list, each channel will be given a gain from the list\n        channel_ids: array_like or None\n            The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        if isinstance(gains, (int, np.integer, float)):\n            gains = [gains] * len(channel_ids)\n        # Only None upon initialization\n        if self._key_properties['gain'] is None:\n            self._key_properties['gain'] = np.ones(self.get_num_channels(), dtype='float')\n        if len(channel_ids) == len(gains):\n            for i in range(len(channel_ids)):\n                if isinstance(gains[i], (int, np.integer, float)):\n                    channel_idx = list(self.get_channel_ids()).index(channel_ids[i])\n                    self._key_properties['gain'][channel_idx] = float(gains[i])\n                else:\n                    raise TypeError(\"'gain' must be an int or float\")\n        else:\n            raise ValueError(\"channel_ids and gains must have same length\")\n\n 
   def get_channel_gains(self, channel_ids=None):\n        \"\"\"This function returns the gain of each channel specified by channel_ids.\n\n        Parameters\n        ----------\n        channel_ids: array_like\n            The channel ids (ints) for which the gains will be returned\n\n        Returns\n        -------\n        gains: array_like\n            Returns a list of corresponding gains (floats) for the given channel_ids\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        gains = self._key_properties['gain']\n        # Only None upon initialization\n        if gains is None:\n            gains = np.ones(self.get_num_channels(), dtype='float')\n            self._key_properties['gain'] = gains\n        gains = np.array(gains)\n        channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids])\n        return gains[channel_idxs]\n\n    def clear_channel_gains(self, channel_ids=None):\n        \"\"\"This function clears the gains of each channel specified by channel_ids\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the groups will be cleared. If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        # Reset to default gains (1)\n        default_gains = [1.] 
* len(channel_ids)\n        self.set_channel_gains(default_gains, channel_ids)\n\n    def set_channel_offsets(self, offsets, channel_ids=None):\n        \"\"\"This function sets the offset key property of each specified channel\n        id with the corresponding group of the passed in gains float/list.\n\n        Parameters\n        ----------\n        offsets: float/array_like\n            If a float, each channel will be assigned the corresponding offset.\n            If a list, each channel will be given an offset from the list\n        channel_ids: array_like or None\n            The channel ids (ints) for which the groups will be specified. If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        if isinstance(offsets, (int, np.integer, float)):\n            offsets = [offsets] * len(channel_ids)\n        # Only None upon initialization\n        if self._key_properties['offset'] is None:\n            self._key_properties['offset'] = np.zeros(self.get_num_channels(), dtype='float')\n        if len(channel_ids) == len(offsets):\n            for i in range(len(channel_ids)):\n                if isinstance(offsets[i], (int, np.integer, float)):\n                    channel_idx = list(self.get_channel_ids()).index(channel_ids[i])\n                    self._key_properties['offset'][channel_idx] = float(offsets[i])\n                else:\n                    raise TypeError(\"'offset' must be an int or float\")\n        else:\n            raise ValueError(\"channel_ids and offsets must have same length\")\n\n    def get_channel_offsets(self, channel_ids=None):\n        \"\"\"This function returns the offset of each channel specified by channel_ids.\n\n        Parameters\n        ----------\n        channel_ids: array_like\n            The channel ids (ints) for which the gains 
will be returned\n\n        Returns\n        -------\n        offsets: array_like\n            Returns a list of corresponding offsets for the given channel_ids\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        offsets = self._key_properties['offset']\n        # Only None upon initialization\n        if offsets is None:\n            offsets = np.zeros(self.get_num_channels(), dtype='float')\n            self._key_properties['offset'] = offsets\n        offsets = np.array(offsets)\n        channel_idxs = np.array([list(self.get_channel_ids()).index(ch) for ch in channel_ids])\n        return offsets[channel_idxs]\n\n    def clear_channel_offsets(self, channel_ids=None):\n        \"\"\"This function clears the gains of each channel specified by channel_ids.\n\n        Parameters\n        ----------\n        channel_ids: array-like or int\n            The channel ids (ints) for which the groups will be cleared. If None, all channel ids are assumed.\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = list(self.get_channel_ids())\n        if isinstance(channel_ids, (int, np.integer)):\n            channel_ids = [channel_ids]\n        # Reset to default offets (0)\n        default_offsets = [0.] * len(channel_ids)\n        self.set_channel_offsets(default_offsets, channel_ids)\n\n    def set_channel_property(self, channel_id, property_name, value):\n        \"\"\"This function adds a property dataset to the given channel under the property name.\n\n        Parameters\n        ----------\n        channel_id: int\n            The channel id for which the property will be added\n        property_name: str\n            A property stored by the RecordingExtractor (location, etc.)\n        value:\n            The data associated with the given property name. 
Could be many\n            formats as specified by the user\n        \"\"\"\n        if isinstance(channel_id, (int, np.integer)):\n            if channel_id in self.get_channel_ids():\n                if isinstance(property_name, str):\n                    if property_name == 'location':\n                        self.set_channel_locations(value, channel_id)\n                    elif property_name == 'group':\n                        self.set_channel_groups(value, channel_id)\n                    else:\n                        if channel_id not in self._properties.keys():\n                            self._properties[channel_id] = {}\n                        self._properties[channel_id][property_name] = value\n                else:\n                    raise TypeError(str(property_name) + \" must be a string\")\n            else:\n                raise ValueError(str(channel_id) + \" is not a valid channel_id\")\n        else:\n            raise TypeError(str(channel_id) + \" must be an int\")\n\n    def get_channel_property(self, channel_id, property_name):\n        \"\"\"This function returns the data stored under the property name from\n        the given channel.\n\n        Parameters\n        ----------\n        channel_id: int\n            The channel id for which the property will be returned\n        property_name: str\n            A property stored by the RecordingExtractor (location, etc.)\n\n        Returns\n        -------\n        property_data\n            The data associated with the given property name. 
Could be many\n            formats as specified by the user\n        \"\"\"\n        if not isinstance(channel_id, (int, np.integer)):\n            raise TypeError(str(channel_id) + \" must be an int\")\n        if channel_id not in self.get_channel_ids():\n            raise ValueError(str(channel_id) + \" is not a valid channel_id\")\n        if property_name == 'location':\n            return self.get_channel_locations(channel_id)[0]\n        if property_name == 'group':\n            return self.get_channel_groups(channel_id)[0]\n        if property_name == 'gain':\n            return self.get_channel_gains(channel_id)[0]\n        if property_name == 'offset':\n            return self.get_channel_offsets(channel_id)[0]\n        if channel_id not in self._properties.keys():\n            raise ValueError('no properties found for channel ' + str(channel_id))\n        if property_name not in self._properties[channel_id]:\n            raise RuntimeError(str(property_name) + \" has not been added to channel \" + str(channel_id))\n        if not isinstance(property_name, str):\n            raise TypeError(str(property_name) + \" must be a string\")\n        return self._properties[channel_id][property_name]\n\n    def get_channel_property_names(self, channel_id):\n        \"\"\"Get a list of property names for a given channel.\n\n        Parameters\n        ----------\n        channel_id: int\n            The channel id for which the property names will be returned\n            If None (default), will return property names for all channels\n\n        Returns\n        -------\n        property_names\n            The list of property names\n        \"\"\"\n        if isinstance(channel_id, (int, np.integer)):\n            if channel_id in self.get_channel_ids():\n                if channel_id not in self._properties.keys():\n                    self._properties[channel_id] = {}\n                property_names = list(self._properties[channel_id].keys())\n                if 
np.all(np.logical_not(np.isnan(self.get_channel_locations(channel_id)))):\n                    property_names.extend(['location'])\n                property_names.extend(['group'])\n                property_names.extend(['gain'])\n                property_names.extend(['offset'])\n                return sorted(property_names)\n            else:\n                raise ValueError(str(channel_id) + \" is not a valid channel_id\")\n        else:\n            raise TypeError(str(channel_id) + \" must be an int\")\n\n    def get_shared_channel_property_names(self, channel_ids=None):\n        \"\"\"Get the intersection of channel property names for a given set of channels or for all channels\n        if channel_ids is None.\n\n        Parameters\n        ----------\n        channel_ids: array_like\n            The channel ids for which the shared property names will be returned.\n            If None (default), will return shared property names for all channels\n\n        Returns\n        -------\n        property_names\n            The list of shared property names\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = self.get_channel_ids()\n        curr_property_name_set = set(self.get_channel_property_names(channel_id=channel_ids[0]))\n        for channel_id in channel_ids[1:]:\n            curr_channel_property_name_set = set(self.get_channel_property_names(channel_id=channel_id))\n            curr_property_name_set = curr_property_name_set.intersection(curr_channel_property_name_set)\n        property_names = list(curr_property_name_set)\n        return sorted(property_names)\n\n    def copy_channel_properties(self, recording, channel_ids=None):\n        \"\"\"Copy channel properties from another recording extractor to the current\n        recording extractor.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n            The recording extractor from which the properties will be copied\n        channel_ids: 
(array_like, (int, np.integer))\n            The list (or single value) of channel_ids for which the properties will be copied\n        \"\"\"\n        if channel_ids is None:\n            self._key_properties = deepcopy(recording._key_properties)\n            self._properties = deepcopy(recording._properties)\n        else:\n            if isinstance(channel_ids, (int, np.integer)):\n                channel_ids = [channel_ids]\n            # copy key properties\n            groups = recording.get_channel_groups(channel_ids=channel_ids)\n            locations = recording.get_channel_locations(channel_ids=channel_ids)\n            gains = recording.get_channel_gains(channel_ids=channel_ids)\n            offsets = recording.get_channel_offsets(channel_ids=channel_ids)\n            self.set_channel_groups(groups)\n            self.set_channel_locations(locations)\n            self.set_channel_gains(gains)\n            self.set_channel_offsets(offsets)\n\n            # copy normal properties\n            for channel_id in channel_ids:\n                curr_property_names = recording.get_channel_property_names(channel_id=channel_id)\n                for curr_property_name in curr_property_names:\n                    if curr_property_name not in self._key_properties.keys():  # key property\n                        value = recording.get_channel_property(channel_id=channel_id, property_name=curr_property_name)\n                        self.set_channel_property(channel_id=channel_id, property_name=curr_property_name, value=value)\n\n    def clear_channel_property(self, channel_id, property_name):\n        \"\"\"This function clears the channel property for the given property.\n\n        Parameters\n        ----------\n        channel_id: int\n            The id that specifies a channel in the recording\n        property_name: string\n            The name of the property to be cleared\n        \"\"\"\n        if property_name == 'location':\n            
self.clear_channel_locations(channel_id)\n        elif property_name == 'group':\n            self.clear_channel_groups(channel_id)\n        elif channel_id in self._properties.keys():\n            if property_name in self._properties[channel_id]:\n                del self._properties[channel_id][property_name]\n\n    def clear_channels_property(self, property_name, channel_ids=None):\n        \"\"\"This function clears the channels' properties for the given property.\n\n        Parameters\n        ----------\n        property_name: string\n            The name of the property to be cleared\n        channel_ids: list\n            A list of ids that specifies a set of channels in the recording. If None all channels are cleared\n        \"\"\"\n        if channel_ids is None:\n            channel_ids = self.get_channel_ids()\n        for channel_id in channel_ids:\n            self.clear_channel_property(channel_id, property_name)\n\n    def get_epoch(self, epoch_name):\n        \"\"\"This function returns a SubRecordingExtractor which is a view to the\n        given epoch\n\n        Parameters\n        ----------\n        epoch_name: str\n            The name of the epoch to be returned\n\n        Returns\n        -------\n        epoch_extractor: SubRecordingExtractor\n            A SubRecordingExtractor which is a view to the given epoch\n        \"\"\"\n        from .subrecordingextractor import SubRecordingExtractor\n\n        epoch_info = self.get_epoch_info(epoch_name)\n        start_frame = epoch_info['start_frame']\n        end_frame = epoch_info['end_frame']\n        return SubRecordingExtractor(parent_recording=self, start_frame=start_frame,\n                                     end_frame=end_frame)\n\n    def load_probe_file(self, probe_file, channel_map=None, channel_groups=None, verbose=False):\n        \"\"\"This function returns a SubRecordingExtractor that contains information from the given\n        probe file (channel locations, groups, etc.) 
If a .prb file is given, then 'location' and 'group'\n        information for each channel is added to the SubRecordingExtractor. If a .csv file is given, then\n        it will only add 'location' to the SubRecordingExtractor.\n\n        Parameters\n        ----------\n        probe_file: str\n            Path to probe file. Either .prb or .csv\n        channel_map : array-like\n            A list of channel IDs to set in the loaded file.\n            Only used if the loaded file is a .csv.\n        channel_groups : array-like\n            A list of groups (ints) for the channel_ids to set in the loaded file.\n            Only used if the loaded file is a .csv.\n        verbose: bool\n            If True, output is verbose\n\n        Returns\n        -------\n        subrecording = SubRecordingExtractor\n            The extractor containing all of the probe information.\n        \"\"\"\n        subrecording = load_probe_file(self, probe_file, channel_map=channel_map,\n                                       channel_groups=channel_groups, verbose=verbose)\n        return subrecording\n\n    def save_to_probe_file(self, probe_file, grouping_property=None, radius=None,\n                           graph=True, geometry=True, verbose=False):\n        \"\"\"Saves probe file from the channel information of this recording extractor.\n\n        Parameters\n        ----------\n        probe_file: str\n            file name of .prb or .csv file to save probe information to\n        grouping_property: str (default None)\n            If grouping_property is a shared_channel_property, different groups are saved based on the property.\n        radius: float (default None)\n            Adjacency radius (used by some sorters). 
If None it is not saved to the probe file.\n        graph: bool\n            If True, the adjacency graph is saved (default=True)\n        geometry: bool\n            If True, the geometry is saved (default=True)\n        verbose: bool\n            If True, output is verbose\n        \"\"\"\n        save_to_probe_file(self, probe_file, grouping_property=grouping_property, radius=radius,\n                           graph=graph, geometry=geometry, verbose=verbose)\n\n    def write_to_binary_dat_format(self, save_path, time_axis=0, dtype=None, chunk_size=None, chunk_mb=500,\n                                   n_jobs=1, joblib_backend='loky', return_scaled=True, verbose=False):\n        \"\"\"Saves the traces of this recording extractor into binary .dat format.\n\n        Parameters\n        ----------\n        save_path: str\n            The path to the file.\n        time_axis: 0 (default) or 1\n            If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n            If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n        dtype: dtype\n            Type of the saved data. Default float32\n        chunk_size: None or int\n            Size of each chunk in number of frames.\n            If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n        chunk_mb: None or int\n            Chunk size in Mb (default 500Mb)\n        n_jobs: int\n            Number of jobs to use (Default 1)\n        joblib_backend: str\n            Joblib backend for parallel processing ('loky', 'threading', 'multiprocessing')\n        return_scaled: bool\n            If True, traces are returned after scaling (using gain/offset). 
If False, the raw traces are returned\n        verbose: bool\n            If True, output is verbose (when chunks are used)\n        \"\"\"\n        write_to_binary_dat_format(self, save_path=save_path, time_axis=time_axis, dtype=dtype, chunk_size=chunk_size,\n                                   chunk_mb=chunk_mb, n_jobs=n_jobs, joblib_backend=joblib_backend,\n                                   return_scaled=return_scaled, verbose=verbose)\n\n    def write_to_h5_dataset_format(self, dataset_path, save_path=None, file_handle=None,\n                                   time_axis=0, dtype=None, chunk_size=None, chunk_mb=500, verbose=False):\n        \"\"\"Saves the traces of a recording extractor in an h5 dataset.\n\n        Parameters\n        ----------\n        dataset_path: str\n            Path to dataset in h5 file (e.g. '/dataset')\n        save_path: str\n            The path to the file.\n        file_handle: file handle\n            The file handle to dump data. This can be used to append data to an header. In case file_handle is given,\n            the file is NOT closed after writing the binary data.\n        time_axis: 0 (default) or 1\n            If 0 then traces are transposed to ensure (nb_sample, nb_channel) in the file.\n            If 1, the traces shape (nb_channel, nb_sample) is kept in the file.\n        dtype: dtype\n            Type of the saved data. 
Default float32.\n        chunk_size: None or int\n            Size of each chunk in number of frames.\n            If None (default) and 'chunk_mb' is given, the file is saved in chunks of 'chunk_mb' Mb (default 500Mb)\n        chunk_mb: None or int\n            Chunk size in Mb (default 500Mb)\n        verbose: bool\n            If True, output is verbose (when chunks are used)\n        \"\"\"\n        write_to_h5_dataset_format(self, dataset_path, save_path, file_handle, time_axis, dtype, chunk_size, chunk_mb,\n                                   verbose)\n\n    def get_sub_extractors_by_property(self, property_name, return_property_list=False):\n        \"\"\"Returns a list of SubRecordingExtractors from this RecordingExtractor based on the given\n        property_name (e.g. group)\n\n        Parameters\n        ----------\n        property_name: str\n            The property used to subdivide the extractor\n        return_property_list: bool\n            If True the property list is returned\n\n        Returns\n        -------\n        sub_list: list\n            The list of subextractors to be returned\n        OR\n        sub_list, prop_list\n            If return_property_list is True, the property list will be returned as well\n\n        \"\"\"\n        if return_property_list:\n            sub_list, prop_list = get_sub_extractors_by_property(self, property_name=property_name,\n                                                                 return_property_list=return_property_list)\n            return sub_list, prop_list\n        else:\n            sub_list = get_sub_extractors_by_property(self, property_name=property_name,\n                                                      return_property_list=return_property_list)\n            return sub_list\n\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        \"\"\"\n        Returns an array with frames of TTL signals. 
To be implemented in sub-classes\n\n        Parameters\n        ----------\n        start_frame: int\n            The starting frame of the ttl to be returned (inclusive)\n        end_frame: int\n            The ending frame of the ttl to be returned (exclusive)\n        channel_id: int\n            The TTL channel id\n\n        Returns\n        -------\n        ttl_frames: array-like\n            Frames of TTL signal for the specified channel\n        ttl_state: array-like\n            State of the transition: 1 - rising, -1 - falling\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def write_recording(recording, save_path):\n        \"\"\"This function writes out the recorded file of a given recording\n        extractor to the file format of this current recording extractor. Allows\n        for easy conversion between recording file formats. It is a static\n        method so it can be used without instantiating this recording extractor.\n\n        Parameters\n        ----------\n        recording: RecordingExtractor\n            An RecordingExtractor that can extract information from the recording\n            file to be converted to the new format.\n\n        save_path: string\n            A path to where the converted recorded data will be saved, which may\n            either be a file or a folder, depending on the format.\n        \"\"\"\n        raise NotImplementedError(\"The write_recording function is not \\\n                                  implemented for this extractor\")\n"
  },
  {
    "path": "spikeextractors/save_tools.py",
    "content": "from pathlib import Path\n\nfrom .cacheextractors import CacheRecordingExtractor, CacheSortingExtractor\nfrom .recordingextractor import RecordingExtractor\nfrom .sortingextractor import SortingExtractor\n\n\ndef save_si_object(object_name: str, si_object, output_folder,\n                   cache_raw=False, include_properties=True, include_features=False):\n    \"\"\"\n    Save an arbitrary SpikeInterface object (recording or sorting) to a folder.\n\n    The object is always pickled to '<output_folder>/<object_name>.pkl'. If the extractor is not dumpable,\n    or if cache_raw is True, it is first cached to a file inside output_folder ('raw.dat' for recordings,\n    'sorting.npz' for sortings) and the cache extractor is pickled instead. When a dumpable object is cached,\n    it is also dumped to '<output_folder>/<object_name>.json' beforehand to keep its history.\n\n    Parameters\n    ----------\n    object_name: str\n        The unique name of the SpikeInterface object (used as base name of the saved files).\n    si_object: RecordingExtractor or SortingExtractor\n        The extractor to be saved.\n    output_folder: str or Path\n        The folder where the object is saved. It is created if it does not exist.\n    cache_raw: bool\n        If True, the Extractor is cached to a binary file (not recommended for RecordingExtractor objects)\n        (default False).\n    include_properties: bool\n        If True, properties (channel or unit) are saved (default True).\n    include_features: bool\n        If True, spike features are saved (default False)\n\n    Raises\n    ------\n    ValueError\n        If si_object is neither a RecordingExtractor nor a SortingExtractor.\n    \"\"\"\n    # Pick the cache class / file for this extractor type before touching the file system.\n    if isinstance(si_object, RecordingExtractor):\n        cache_class, cache_file_name = CacheRecordingExtractor, \"raw.dat\"\n    elif isinstance(si_object, SortingExtractor):\n        cache_class, cache_file_name = CacheSortingExtractor, \"sorting.npz\"\n    else:\n        raise ValueError(\"The 'si_object' argument should be a SpikeInterface Extractor!\")\n\n    # Normalize to Path: a plain str output_folder would break the '/' joins below.\n    output_folder = Path(output_folder)\n    output_folder.mkdir(parents=True, exist_ok=True)\n\n    if not si_object.is_dumpable:\n        cache = cache_class(si_object, save_path=output_folder / cache_file_name)\n    elif cache_raw:\n        # save to json before caching to keep history (in case it's needed)\n        si_object.dump_to_json(output_folder / f\"{object_name}.json\")\n        cache = cache_class(si_object, save_path=output_folder / cache_file_name)\n    else:\n        cache = si_object\n\n    # Join only the file name to output_folder: joining a path that already starts with a\n    # relative output_folder would duplicate it (e.g. 'out/out/name.pkl').\n    cache.dump_to_pickle(\n        output_folder / f\"{object_name}.pkl\",\n        include_properties=include_properties,\n        include_features=include_features\n    )\n"
  },
  {
    "path": "spikeextractors/sortingextractor.py",
    "content": "from abc import ABC, abstractmethod\nimport numpy as np\nfrom copy import deepcopy\n\nfrom .extraction_tools import get_sub_extractors_by_property\nfrom .baseextractor import BaseExtractor\n\n\nclass SortingExtractor(ABC, BaseExtractor):\n    \"\"\"A class that contains functions for extracting important information\n    from spiked sorted data given a spike sorting software. It is an abstract\n    class so all functions with the @abstractmethod tag must be implemented for\n    the initialization to work.\n    \"\"\"\n\n    _default_filename = \"spikeinterface_sorting\"\n\n    def __init__(self):\n        BaseExtractor.__init__(self)\n        self._sampling_frequency = None\n\n    @abstractmethod\n    def get_unit_ids(self):\n        \"\"\"This function returns a list of ids (ints) for each unit in the sorsted result.\n\n        Returns\n        -------\n        unit_ids: array_like\n            A list of the unit ids in the sorted result (ints).\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        \"\"\"This function extracts spike frames from the specified unit.\n        It will return spike frames from within three ranges:\n\n            [start_frame, t_start+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_unit_spike_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_unit_spike_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Spike frames are returned in the form of an\n        array_like of spike frames. 
In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        unit_id: int\n            The id that specifies a unit in the recording\n        start_frame: int\n            The frame above which a spike frame is returned  (inclusive)\n        end_frame: int\n            The frame below which a spike frame is returned  (exclusive)\n\n        Returns\n        -------\n        spike_train: numpy.ndarray\n            An 1D array containing all the frames for each spike in the\n            specified unit given the range of start and end frames\n        \"\"\"\n        pass\n\n    def get_units_spike_train(self, unit_ids=None, start_frame=None, end_frame=None):\n        \"\"\"This function extracts spike frames from the specified units.\n\n        Parameters\n        ----------\n        unit_ids: array_like\n            The unit ids from which to return spike trains. If None, all unit\n            spike trains will be returned\n        start_frame: int\n            The frame above which a spike frame is returned  (inclusive)\n        end_frame: int\n            The frame below which a spike frame is returned  (exclusive)\n\n        Returns\n        -------\n        spike_train: numpy.ndarray\n            An 2D array containing all the frames for each spike in the\n            specified units given the range of start and end frames\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        spike_trains = [self.get_unit_spike_train(uid, start_frame, end_frame) for uid in unit_ids]\n        return spike_trains\n\n    def get_sampling_frequency(self):\n        \"\"\"\n        It returns the sampling frequency.\n\n        Returns\n        -------\n        sampling_frequency: float\n            The sampling frequency\n        \"\"\"\n        return self._sampling_frequency\n\n    def set_sampling_frequency(self, sampling_frequency):\n 
       \"\"\"\n        It sets the sorting extractor sampling frequency.\n\n        Parameters\n        ----------\n        sampling_frequency: float\n            The sampling frequency\n        \"\"\"\n        self._sampling_frequency = sampling_frequency\n\n    def set_unit_spike_features(self, unit_id, feature_name, value, indexes=None):\n        \"\"\"This function adds a unit features data set under the given features\n        name to the given unit.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit id for which the features will be set\n        feature_name: str\n            The name of the feature to be stored\n        value: array_like\n            The data associated with the given feature name. Could be many\n            formats as specified by the user.\n        indexes: array_like\n            The indices of the specified spikes (if the number of spike features\n            is less than the length of the unit's spike train). If None, it is\n            assumed that value has the same length as the spike train.\n        \"\"\"\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._features.keys():\n                    self._features[unit_id] = {}\n                if indexes is None:\n                    if isinstance(feature_name, str) and len(value) == len(self.get_unit_spike_train(unit_id)):\n                        self._features[unit_id][feature_name] = value\n                    else:\n                        if not isinstance(feature_name, str):\n                            raise ValueError(\"feature_name must be a string\")\n                        else:\n                            raise ValueError(\"feature values should have the same length as the spike train\")\n                else:\n                    if isinstance(feature_name, str) and len(value) == len(indexes):\n                        indexes = np.array(indexes)\n   
                     self._features[unit_id][feature_name] = value\n                        self._features[unit_id][feature_name + '_idxs'] = indexes\n                    else:\n                        if not isinstance(feature_name, str):\n                            raise ValueError(\"feature_name must be a string\")\n                        else:\n                            raise ValueError(\"feature values should have the same length as indexes\")\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit_id\")\n        else:\n            raise ValueError(str(unit_id) + \" must be an int\")\n\n    def get_unit_spike_features(self, unit_id, feature_name, start_frame=None, end_frame=None):\n        \"\"\"This function extracts the specified spike features from the specified unit.\n        It will return spike features from within three ranges:\n\n            [start_frame, t_start+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_unit_spike_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_unit_spike_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Spike features are returned in the form of an\n        array_like of spike features. 
In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        unit_id: int\n            The id that specifies a unit in the recording\n        feature_name: string\n            The name of the feature to be returned\n        start_frame: int\n            The frame above which a spike frame is returned  (inclusive)\n        end_frame: int\n            The frame below which a spike frame is returned  (exclusive)\n\n        Returns\n        -------\n        spike_features: numpy.ndarray\n            An array containing all the features for each spike in the\n            specified unit given the range of start and end frames\n        \"\"\"\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._features.keys():\n                    self._features[unit_id] = {}\n                if isinstance(feature_name, str):\n                    if feature_name in self._features[unit_id].keys():\n                        spike_train = self.get_unit_spike_train(unit_id)\n                        if start_frame is None:\n                            start_frame = 0\n                        if end_frame is None:\n                            end_frame = np.inf\n                        if start_frame == 0 and end_frame == np.inf:\n                            # keep memmap objects\n                            return self._features[unit_id][feature_name]\n                        else:\n                            if len(self._features[unit_id][feature_name]) == len(spike_train):\n                                spike_indices = np.where(np.logical_and(spike_train >= start_frame,\n                                                                        spike_train < end_frame))\n                            elif 
len(self._features[unit_id][feature_name]) < len(spike_train):\n                                if not feature_name.endswith('idxs'):\n                                    # retrieve features on the correct idxs\n                                    assert feature_name + '_idxs' in self.get_unit_spike_feature_names(unit_id=unit_id)\n                                    feature_name_idxs = feature_name + '_idxs'\n                                    value_idxs = np.array(self.get_unit_spike_features(unit_id=unit_id,\n                                                                                       feature_name=feature_name_idxs))\n                                    spike_train = spike_train[value_idxs]\n                                    spike_indices = np.where(np.logical_and(spike_train >= start_frame,\n                                                                            spike_train < end_frame))\n                                else:\n                                    # retrieve idxs features\n                                    value_idxs = np.array(self.get_unit_spike_features(unit_id=unit_id,\n                                                                                       feature_name=feature_name))\n                                    spike_train = spike_train[value_idxs]\n                                    spike_indices = np.where(np.logical_and(spike_train >= start_frame,\n                                                                            spike_train < end_frame))\n                            else:\n                                raise ValueError(str(feature_name) + \" dimensions are inconsistent for unit \"\n                                                 + str(unit_id))\n                            if isinstance(self._features[unit_id][feature_name], list):\n                                return list(np.array(self._features[unit_id][feature_name])[spike_indices])\n                            else:\n                          
      return np.array(self._features[unit_id][feature_name])[spike_indices]\n                    else:\n                        raise ValueError(str(feature_name) + \" has not been added to unit \" + str(unit_id))\n                else:\n                    raise ValueError(str(feature_name) + \" must be a string\")\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit_id\")\n        else:\n            raise ValueError(str(unit_id) + \" must be an int\")\n\n    def set_times(self, times):\n        \"\"\"This function sets the sorting times to convert spike trains to seconds\n\n        Parameters\n        ----------\n        times: array-like\n            The times in seconds for each frame\n        \"\"\"\n        max_frames = np.array([np.max(self.get_unit_spike_train(u)) for u in self.get_unit_ids()])\n        assert np.all(max_frames < len(times)), \"The length of 'times' should be greater than the maximum \" \\\n                                                     \"spike frame index\"\n        self._times = times.astype('float64')\n\n    def copy_times(self, extractor):\n        \"\"\"This function copies times from another extractor.\n\n        Parameters\n        ----------\n        extractor: BaseExtractor\n            The extractor from which the epochs will be copied\n        \"\"\"\n        if extractor._times is not None:\n            self.set_times(deepcopy(extractor._times))\n\n    def frame_to_time(self, frames):\n        \"\"\"This function converts user-inputted frame indexes to times with units of seconds.\n\n        Parameters\n        ----------\n        frames: float or array-like\n            The frame or frames to be converted to times\n\n        Returns\n        -------\n        times: float or array-like\n            The corresponding times in seconds\n        \"\"\"\n        # Default implementation\n        if self._times is None:\n            return np.round(frames / self.get_sampling_frequency(), 
6)\n        else:\n            return self._times[frames]\n\n    def time_to_frame(self, times):\n        \"\"\"This function converts a user-inputted times (in seconds) to a frame indexes.\n\n        Parameters\n        ----------\n        times: float or array-like\n            The times (in seconds) to be converted to frame indexes\n\n        Returns\n        -------\n        frames: float or array-like\n            The corresponding frame indexes\n        \"\"\"\n        # Default implementation\n        if self._times is None:\n            return np.round(times * self.get_sampling_frequency()).astype('int64')\n        else:\n            return np.searchsorted(self._times, times).astype('int64')\n\n    def clear_unit_spike_features(self, unit_id, feature_name):\n        \"\"\"This function clears the unit spikes features for the given feature.\n\n        Parameters\n        ----------\n        unit_id: int\n            The id that specifies a unit in the sorting\n        feature_name: string\n            The name of the feature to be cleared\n        \"\"\"\n        if unit_id in self._features.keys():\n            if feature_name in self._features[unit_id]:\n                del self._features[unit_id][feature_name]\n\n    def clear_units_spike_features(self, feature_name, unit_ids=None):\n        \"\"\"This function clears the units' spikes features for the given feature.\n\n        Parameters\n        ----------\n        feature_name: string\n            The name of the feature to be cleared\n        unit_ids: list\n            A list of ids that specifies a set of units in the sorting. 
If None, all units are cleared\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        for unit_id in unit_ids:\n            self.clear_unit_spike_features(unit_id, feature_name)\n\n    def get_unit_spike_feature_names(self, unit_id):\n        \"\"\"This function returns the list of feature names for the given unit\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit id for which the feature names will be returned\n\n        Returns\n        -------\n        property_names\n            The list of feature names.\n        \"\"\"\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._features.keys():\n                    self._features[unit_id] = {}\n                feature_names = sorted(self._features[unit_id].keys())\n                return feature_names\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit_id\")\n        else:\n            raise ValueError(str(unit_id) + \" must be an int\")\n\n    def get_shared_unit_spike_feature_names(self, unit_ids=None):\n        \"\"\"Get the intersection of unit feature names for a given set of units or for all units if unit_ids is None.\n\n        Parameters\n        ----------\n        unit_ids: array_like\n            The unit ids for which the shared feature names will be returned.\n            If None (default), will return shared feature names for all units\n\n        Returns\n        -------\n        property_names\n            The list of shared feature names\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        if len(unit_ids) > 0:\n            curr_feature_name_set = set(self.get_unit_spike_feature_names(unit_id=unit_ids[0]))\n            for unit_id in unit_ids[1:]:\n                curr_unit_feature_name_set = set(self.get_unit_spike_feature_names(unit_id=unit_id))\n     
           curr_feature_name_set = curr_feature_name_set.intersection(curr_unit_feature_name_set)\n            feature_names = sorted(list(curr_feature_name_set))\n        else:\n            feature_names = []\n        return feature_names\n\n    def set_unit_property(self, unit_id, property_name, value):\n        \"\"\"This function adds a unit property data set under the given property\n        name to the given unit.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit id for which the property will be set\n        property_name: str\n            The name of the property to be stored\n        value\n            The data associated with the given property name. Could be many\n            formats as specified by the user\n        \"\"\"\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._properties.keys():\n                    self._properties[unit_id] = {}\n                if isinstance(property_name, str):\n                    self._properties[unit_id][property_name] = value\n                else:\n                    raise ValueError(str(property_name) + \" must be a string\")\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit_id\")\n        else:\n            raise ValueError(str(unit_id) + \" must be an int\")\n\n    def set_units_property(self, *, unit_ids=None, property_name, values):\n        \"\"\"Sets unit property data for a list of units\n\n        Parameters\n        ----------\n        unit_ids: list\n            The list of unit ids for which the property will be set\n            Defaults to get_unit_ids()\n        property_name: str\n            The name of the property\n        value: list\n            The list of values to be set\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        for i, unit in enumerate(unit_ids):\n            
self.set_unit_property(unit_id=unit, property_name=property_name, value=values[i])\n\n    def get_unit_property(self, unit_id, property_name):\n        \"\"\"This function returns the data stored under the property name given\n        from the given unit.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit id for which the property will be returned\n        property_name: str\n            The name of the property\n\n        Returns\n        -------\n        value\n            The data associated with the given property name. Could be many\n            formats as specified by the user\n        \"\"\"\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._properties.keys():\n                    self._properties[unit_id] = {}\n                if isinstance(property_name, str):\n                    if property_name in list(self._properties[unit_id].keys()):\n                        return self._properties[unit_id][property_name]\n                    else:\n                        raise ValueError(str(property_name) + \" has not been added to unit \" + str(unit_id))\n                else:\n                    raise ValueError(str(property_name) + \" must be a string\")\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit_id\")\n        else:\n            raise ValueError(str(unit_id) + \" must be an int\")\n\n    def get_units_property(self, *, unit_ids=None, property_name):\n        \"\"\"Returns a list of values stored under the property name corresponding\n        to a list of units\n\n        Parameters\n        ----------\n        unit_ids: list\n            The unit ids for which the property will be returned\n            Defaults to get_unit_ids()\n        property_name: str\n            The name of the property\n\n        Returns\n        -------\n        values\n            The list of values\n        
\"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        values = [self.get_unit_property(unit_id=unit, property_name=property_name) for unit in unit_ids]\n        return values\n\n    def get_unit_property_names(self, unit_id):\n        \"\"\"Get a list of property names for a given unit.\n\n        Parameters\n        ----------\n        unit_id: int\n            The unit id for which the property names will be returned\n\n        Returns\n        -------\n        property_names\n            The list of property names\n        \"\"\"\n        if isinstance(unit_id, (int, np.integer)):\n            if unit_id in self.get_unit_ids():\n                if unit_id not in self._properties.keys():\n                    self._properties[unit_id] = {}\n                property_names = sorted(self._properties[unit_id].keys())\n                return property_names\n            else:\n                raise ValueError(str(unit_id) + \" is not a valid unit id\")\n        else:\n            raise TypeError(str(unit_id) + \" must be an int\")\n\n    def get_shared_unit_property_names(self, unit_ids=None):\n        \"\"\"Get the intersection of unit property names for a given set of units or for all units if unit_ids is None.\n\n        Parameters\n        ----------\n        unit_ids: array_like\n            The unit ids for which the shared property names will be returned.\n            If None (default), will return shared property names for all units\n\n        Returns\n        -------\n        property_names\n            The list of shared property names\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        if len(unit_ids) > 0:\n            curr_property_name_set = set(self.get_unit_property_names(unit_id=unit_ids[0]))\n            for unit_id in unit_ids[1:]:\n                curr_unit_property_name_set = set(self.get_unit_property_names(unit_id=unit_id))\n                curr_property_name_set = 
curr_property_name_set.intersection(curr_unit_property_name_set)\n            property_names = sorted(list(curr_property_name_set))\n        else:\n            property_names = []\n        return property_names\n\n    def copy_unit_properties(self, sorting, unit_ids=None):\n        \"\"\"Copy unit properties from another sorting extractor to the current\n        sorting extractor.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            The sorting extractor from which the properties will be copied\n        unit_ids: (array_like, (int, np.integer))\n            The list (or single value) of unit_ids for which the properties will be copied\n        \"\"\"\n        # Second condition: Ensure dictionary is not empty\n        if unit_ids is None and len(self._properties.keys()) > 0:\n            self._properties = deepcopy(sorting._properties)\n        else:\n            if unit_ids is None:\n                unit_ids = sorting.get_unit_ids()\n            if isinstance(unit_ids, (int, np.integer)):\n                curr_property_names = sorting.get_unit_property_names(unit_id=unit_ids)\n                for curr_property_name in curr_property_names:\n                    value = sorting.get_unit_property(unit_id=unit_ids, property_name=curr_property_name)\n                    self.set_unit_property(unit_id=unit_ids, property_name=curr_property_name, value=value)\n            else:\n                for unit_id in unit_ids:\n                    curr_property_names = sorting.get_unit_property_names(unit_id=unit_id)\n                    for curr_property_name in curr_property_names:\n                        value = sorting.get_unit_property(unit_id=unit_id, property_name=curr_property_name)\n                        self.set_unit_property(unit_id=unit_id, property_name=curr_property_name, value=value)\n\n    def clear_unit_property(self, unit_id, property_name):\n        \"\"\"This function clears the unit property for the given property.\n\n   
     Parameters\n        ----------\n        unit_id: int\n            The id that specifies a unit in the sorting\n        property_name: string\n            The name of the property to be cleared\n        \"\"\"\n        if unit_id in self._properties.keys():\n            if property_name in self._properties[unit_id]:\n                del self._properties[unit_id][property_name]\n\n    def clear_units_property(self, property_name, unit_ids=None):\n        \"\"\"This function clears the units' properties for the given property.\n\n        Parameters\n        ----------\n        property_name: string\n            The name of the property to be cleared\n        unit_ids: list\n            A list of ids that specifies a set of units in the sorting. If None, all units are cleared\n        \"\"\"\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        for unit_id in unit_ids:\n            self.clear_unit_property(unit_id, property_name)\n\n    def copy_unit_spike_features(self, sorting, unit_ids=None):\n        \"\"\"Copy unit spike features from another sorting extractor to the current\n        sorting extractor.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            The sorting extractor from which the spike features will be copied\n        unit_ids: (array_like, (int, np.integer))\n            The list (or single value) of unit_ids for which the spike features will be copied\n        \"\"\"\n        if unit_ids is None:\n            self._features = deepcopy(sorting._features)\n        else:\n            if isinstance(unit_ids, (int, np.integer)):\n                unit_ids = [unit_ids]\n            for unit_id in unit_ids:\n                curr_feature_names = sorting.get_unit_spike_feature_names(unit_id=unit_id)\n                for curr_feature_name in curr_feature_names:\n                    value = sorting.get_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name)\n                    if 
len(value) < len(sorting.get_unit_spike_train(unit_id)):\n                        if not curr_feature_name.endswith('idxs'):\n                            assert curr_feature_name + '_idxs' in \\\n                                   sorting.get_unit_spike_feature_names(unit_id=unit_id)\n                            curr_feature_name_idxs = curr_feature_name + '_idxs'\n                            value_idxs = np.array(sorting.get_unit_spike_features(unit_id=unit_id,\n                                                                                  feature_name=curr_feature_name_idxs))\n                            # find index of first spike\n                            self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name,\n                                                         value=value, indexes=value_idxs)\n                    else:\n                        self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value)\n\n    def get_epoch(self, epoch_name):\n        \"\"\"This function returns a SubSortingExtractor which is a view to the given epoch.\n\n        Parameters\n        ----------\n        epoch_name: str\n            The name of the epoch to be returned\n\n        Returns\n        -------\n        epoch_extractor: SubRecordingExtractor\n            A SubRecordingExtractor which is a view to the given epoch\n        \"\"\"\n        epoch_info = self.get_epoch_info(epoch_name)\n        start_frame = epoch_info['start_frame']\n        end_frame = epoch_info['end_frame']\n        from .subsortingextractor import SubSortingExtractor\n        return SubSortingExtractor(parent_sorting=self, start_frame=start_frame,\n                                   end_frame=end_frame)\n\n    def get_sub_extractors_by_property(self, property_name, return_property_list=False):\n        \"\"\"Returns a list of SubSortingExtractors from this SortingExtractor based on the given\n        property_name (e.g. 
group)\n\n        Parameters\n        ----------\n        property_name: str\n            The property used to subdivide the extractor\n        return_property_list: bool\n            If True the property list is returned\n\n        Returns\n        -------\n        sub_list: list\n            The list of subextractors to be returned\n        \"\"\"\n        if return_property_list:\n            sub_list, prop_list = get_sub_extractors_by_property(self, property_name=property_name,\n                                                                 return_property_list=return_property_list)\n            return sub_list, prop_list\n        else:\n            sub_list = get_sub_extractors_by_property(self, property_name=property_name,\n                                                      return_property_list=return_property_list)\n            return sub_list\n\n    @staticmethod\n    def write_sorting(sorting, save_path):\n        \"\"\"This function writes out the spike sorted data file of a given sorting\n        extractor to the file format of this current sorting extractor. Allows\n        for easy conversion between spike sorting file formats. 
It is a static\n        method so it can be used without instantiating this sorting extractor.\n\n        Parameters\n        ----------\n        sorting: SortingExtractor\n            A SortingExtractor that can extract information from the sorted data\n            file to be converted to the new format\n        save_path: string\n            A path to where the converted sorted data will be saved, which may\n            either be a file or a folder, depending on the format\n        \"\"\"\n        raise NotImplementedError(\"The write_sorting function is not \\\n                                  implemented for this extractor\")\n\n    def get_unsorted_spike_train(self, start_frame=None, end_frame=None):\n        \"\"\"This function extracts spike frames from the unsorted events.\n        It will return spike frames from within three ranges:\n\n            [start_frame, t_start+1, ..., end_frame-1]\n            [start_frame, start_frame+1, ..., final_unit_spike_frame - 1]\n            [0, 1, ..., end_frame-1]\n            [0, 1, ..., final_unit_spike_frame - 1]\n\n        if both start_frame and end_frame are given, if only start_frame is\n        given, if only end_frame is given, or if neither start_frame or end_frame\n        are given, respectively. Spike frames are returned in the form of an\n        array_like of spike frames. In this implementation, start_frame is inclusive\n        and end_frame is exclusive conforming to numpy standards.\n\n        Parameters\n        ----------\n        start_frame: int\n            The frame above which a spike frame is returned  (inclusive)\n        end_frame: int\n            The frame below which a spike frame is returned  (exclusive)\n        Returns\n        ----------\n        spike_train: numpy.ndarray\n            An 1D array containing all the frames for each spike in the\n            specified unit given the range of start and end frames\n        \"\"\"\n\n        raise NotImplementedError\n"
  },
  {
    "path": "spikeextractors/subrecordingextractor.py",
    "content": "from .recordingextractor import RecordingExtractor\nfrom .extraction_tools import check_get_traces_args, cast_start_end_frame, check_get_ttl_args\nimport numpy as np\n\n\n# Encapsulates a sub-dataset\nclass SubRecordingExtractor(RecordingExtractor):\n    def __init__(self, parent_recording, *, channel_ids=None, renamed_channel_ids=None, start_frame=None,\n                 end_frame=None):\n        start_frame, end_frame = cast_start_end_frame(start_frame, end_frame)\n        self._parent_recording = parent_recording\n        self._channel_ids = channel_ids\n        self._renamed_channel_ids = renamed_channel_ids\n        self._start_frame = start_frame\n        self._end_frame = end_frame\n\n        if self._channel_ids is None:\n            self._channel_ids = self._parent_recording.get_channel_ids()\n        if self._renamed_channel_ids is None:\n            self._renamed_channel_ids = self._channel_ids\n        if self._start_frame is None:\n            self._start_frame = 0\n        if self._end_frame is None:\n            self._end_frame = self._parent_recording.get_num_frames()\n        if self._end_frame > self._parent_recording.get_num_frames():\n            self._end_frame = self._parent_recording.get_num_frames()\n        self._original_channel_id_lookup = {}\n\n        for i in range(len(self._channel_ids)):\n            self._original_channel_id_lookup[self._renamed_channel_ids[i]] = self._channel_ids[i]\n        RecordingExtractor.__init__(self)\n        self.copy_channel_properties(parent_recording, channel_ids=self._renamed_channel_ids)\n\n        # avoid rescaling twice\n        self.clear_channel_gains()\n        self.clear_channel_offsets()\n\n        self.is_filtered = self._parent_recording.is_filtered\n        self.has_unscaled = self._parent_recording.has_unscaled\n\n        # update dump dict\n        self._kwargs = {'parent_recording': parent_recording.make_serialized_dict(), 'channel_ids': channel_ids,\n                      
  'renamed_channel_ids': renamed_channel_ids, 'start_frame': start_frame, 'end_frame': end_frame}\n\n    @check_get_traces_args\n    def get_traces(self, channel_ids=None, start_frame=None, end_frame=None, return_scaled=True):\n        sf = self._start_frame + start_frame\n        ef = self._start_frame + end_frame\n        original_ch_ids = self.get_original_channel_ids(channel_ids)\n        return self._parent_recording.get_traces(channel_ids=original_ch_ids, start_frame=sf, end_frame=ef,\n                                                 return_scaled=return_scaled)\n\n    @check_get_ttl_args\n    def get_ttl_events(self, start_frame=None, end_frame=None, channel_id=0):\n        sf = self._start_frame + start_frame\n        ef = self._start_frame + end_frame\n        sf, ef = cast_start_end_frame(sf, ef)\n        try:\n            ttl_frames, ttl_states = self._parent_recording.get_ttl_events(start_frame=sf, end_frame=ef,\n                                                                           channel_id=channel_id)\n            ttl_frames -= self._start_frame\n            return ttl_frames, ttl_states\n        except NotImplementedError:\n            raise NotImplementedError(\"The parent recording does implement the 'get_ttl_events method'\")\n\n    def get_channel_ids(self):\n        return list(self._renamed_channel_ids)\n\n    def get_num_frames(self):\n        return self._end_frame - self._start_frame\n\n    def get_sampling_frequency(self):\n        return self._parent_recording.get_sampling_frequency()\n\n    def frame_to_time(self, frame):\n        frame2 = frame + self._start_frame\n        time1 = self._parent_recording.frame_to_time(frame2)\n        start_time = self._parent_recording.frame_to_time(self._start_frame)\n        return np.round(time1 - start_time, 6)\n\n    def time_to_frame(self, time):\n        time2 = time + self._parent_recording.frame_to_time(self._start_frame)\n        frame1 = self._parent_recording.time_to_frame(time2)\n      
  frame2 = frame1 - self._start_frame\n        return frame2.astype('int64')\n\n    def get_snippets(self, reference_frames, snippet_len, channel_ids=None, return_scaled=True):\n        if channel_ids is None:\n            channel_ids = self.get_channel_ids()\n        reference_frames_shift = self._start_frame + np.array(reference_frames)\n        original_ch_ids = self.get_original_channel_ids(channel_ids)\n        return self._parent_recording.get_snippets(reference_frames=reference_frames_shift, snippet_len=snippet_len,\n                                                   channel_ids=original_ch_ids, return_scaled=return_scaled)\n\n    def copy_channel_properties(self, recording, channel_ids=None):\n        if channel_ids is None:\n            channel_ids = self.get_channel_ids()\n        if isinstance(channel_ids, (int, np.integer)):\n            recording_ch_id = channel_ids\n            if recording is self._parent_recording:\n                recording_ch_id = self.get_original_channel_ids(channel_ids)\n            curr_property_names = recording.get_channel_property_names(channel_id=recording_ch_id)\n            for curr_property_name in curr_property_names:\n                if curr_property_name not in self._key_properties.keys():  # key property\n                    value = recording.get_channel_property(channel_id=recording_ch_id, property_name=curr_property_name)\n                    self.set_channel_property(channel_id=channel_ids, property_name=curr_property_name, value=value)\n                else:\n                    if curr_property_name == 'group':\n                        group = recording.get_channel_groups(channel_ids=recording_ch_id)\n                        self.set_channel_groups(groups=group, channel_ids=channel_ids)\n                    elif curr_property_name == 'location':\n                        location = recording.get_channel_locations(channel_ids=recording_ch_id)\n                        self.set_channel_locations(locations=location, 
channel_ids=channel_ids)\n        else:\n            # copy key properties\n            original_channel_ids = self.get_original_channel_ids(channel_ids)\n            groups = recording.get_channel_groups(channel_ids=original_channel_ids)\n            locations = recording.get_channel_locations(channel_ids=original_channel_ids)\n            gains = recording.get_channel_gains(channel_ids=original_channel_ids)\n            offsets = recording.get_channel_offsets(channel_ids=original_channel_ids)\n            self.set_channel_groups(groups=groups, channel_ids=channel_ids)\n            self.set_channel_locations(locations=locations, channel_ids=channel_ids)\n            self.set_channel_gains(gains=gains, channel_ids=channel_ids)\n            self.set_channel_offsets(offsets=offsets, channel_ids=channel_ids)\n\n            # copy normal properties\n            for channel_id in channel_ids:\n                recording_ch_id = channel_id\n                if recording is self._parent_recording:\n                    recording_ch_id = self.get_original_channel_ids(channel_id)\n                curr_property_names = recording.get_channel_property_names(channel_id=recording_ch_id)\n                for curr_property_name in curr_property_names:\n                    if curr_property_name not in self._key_properties.keys():  # key property\n                        value = recording.get_channel_property(channel_id=recording_ch_id,\n                                                               property_name=curr_property_name)\n                        self.set_channel_property(channel_id=channel_id, property_name=curr_property_name, value=value)\n\n    def get_original_channel_ids(self, channel_ids):\n        if isinstance(channel_ids, (int, np.integer)):\n            if channel_ids in self.get_channel_ids():\n                original_ch_ids = self._original_channel_id_lookup[channel_ids]\n            else:\n                raise ValueError(\"Non-valid channel_id\")\n        
else:\n            original_ch_ids = []\n            for channel_id in channel_ids:\n                if isinstance(channel_id, (int, np.integer)):\n                    if channel_id in self.get_channel_ids():\n                        original_ch_id = self._original_channel_id_lookup[channel_id]\n                        original_ch_ids.append(original_ch_id)\n                    else:\n                        raise ValueError(\"Non-valid channel_id\")\n                else:\n                    raise ValueError(\"channel_id must be an int\")\n        return original_ch_ids\n"
  },
  {
    "path": "spikeextractors/subsortingextractor.py",
    "content": "from .sortingextractor import SortingExtractor\nimport numpy as np\nfrom .extraction_tools import check_get_unit_spike_train\n\n\n# Encapsulates a subset of a spike sorted data file\nclass SubSortingExtractor(SortingExtractor):\n    def __init__(self, parent_sorting, *, unit_ids=None, renamed_unit_ids=None, start_frame=None, end_frame=None):\n        SortingExtractor.__init__(self)\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        self._parent_sorting = parent_sorting\n        self._unit_ids = unit_ids\n        self._renamed_unit_ids = renamed_unit_ids\n        self._start_frame = start_frame\n        self._end_frame = end_frame\n        if self._unit_ids is None:\n            self._unit_ids = self._parent_sorting.get_unit_ids()\n        if self._renamed_unit_ids is None:\n            self._renamed_unit_ids = self._unit_ids\n        if self._start_frame is None:\n            self._start_frame = 0\n        if self._end_frame is None:\n            self._end_frame = np.Inf\n        self._original_unit_id_lookup = {}\n        for i in range(len(self._unit_ids)):\n            self._original_unit_id_lookup[self._renamed_unit_ids[i]] = self._unit_ids[i]\n        self.copy_unit_properties(parent_sorting, unit_ids=self._renamed_unit_ids)\n        self.copy_unit_spike_features(parent_sorting, unit_ids=self._renamed_unit_ids, start_frame=start_frame,\n                                      end_frame=end_frame)\n        self._kwargs = {'parent_sorting': parent_sorting.make_serialized_dict(), 'unit_ids': unit_ids,\n                        'renamed_unit_ids': renamed_unit_ids, 'start_frame': start_frame, 'end_frame': end_frame}\n\n    def get_unit_ids(self):\n        return list(self._renamed_unit_ids)\n\n    @check_get_unit_spike_train\n    def get_unit_spike_train(self, unit_id, start_frame=None, end_frame=None):\n        original_unit_id = self._original_unit_id_lookup[unit_id]\n        sf = self._start_frame + 
start_frame\n        ef = self._start_frame + end_frame\n        if sf < self._start_frame:\n            sf = self._start_frame\n        if ef > self._end_frame:\n            ef = self._end_frame\n        if ef == np.Inf:\n            ef = None\n        return self._parent_sorting.get_unit_spike_train(unit_id=original_unit_id, start_frame=sf,\n                                                         end_frame=ef) - self._start_frame\n\n    def get_sampling_frequency(self):\n        return self._parent_sorting.get_sampling_frequency()\n\n    def frame_to_time(self, frame):\n        frame2 = frame + self._start_frame\n        time1 = self._parent_sorting.frame_to_time(frame2)\n        start_time = self._parent_sorting.frame_to_time(self._start_frame)\n        return np.round(time1 - start_time, 6)\n\n    def time_to_frame(self, time):\n        time2 = time + self._parent_sorting.frame_to_time(self._start_frame)\n        frame1 = self._parent_sorting.time_to_frame(time2)\n        frame2 = frame1 - self._start_frame\n        return frame2.astype('int64')\n\n    def copy_unit_properties(self, sorting, unit_ids=None):\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        if isinstance(unit_ids, (int, np.integer)):\n            sorting_unit_id = unit_ids\n            if sorting is self._parent_sorting:\n                sorting_unit_id = self.get_original_unit_ids(unit_ids)\n            curr_property_names = sorting.get_unit_property_names(unit_id=sorting_unit_id)\n            for curr_property_name in curr_property_names:\n                value = sorting.get_unit_property(unit_id=sorting_unit_id, property_name=curr_property_name)\n                self.set_unit_property(unit_id=unit_ids, property_name=curr_property_name, value=value)\n        else:\n            for unit_id in unit_ids:\n                sorting_unit_id = unit_id\n                if sorting is self._parent_sorting:\n                    sorting_unit_id = 
self.get_original_unit_ids(unit_id)\n                curr_property_names = sorting.get_unit_property_names(unit_id=sorting_unit_id)\n                for curr_property_name in curr_property_names:\n                    value = sorting.get_unit_property(unit_id=sorting_unit_id, property_name=curr_property_name)\n                    self.set_unit_property(unit_id=unit_id, property_name=curr_property_name, value=value)\n\n    def copy_unit_spike_features(self, sorting, unit_ids=None, start_frame=None, end_frame=None):\n        start_frame, end_frame = self._cast_start_end_frame(start_frame, end_frame)\n        if unit_ids is None:\n            unit_ids = self.get_unit_ids()\n        if isinstance(unit_ids, (int, np.integer)):\n            unit_ids = [unit_ids]\n        for unit_id in unit_ids:\n            sorting_unit_id = unit_id\n            if sorting is self._parent_sorting:\n                sorting_unit_id = self.get_original_unit_ids(unit_id)\n            curr_feature_names = sorting.get_unit_spike_feature_names(unit_id=sorting_unit_id)\n            for curr_feature_name in curr_feature_names:\n                value = sorting.get_unit_spike_features(unit_id=sorting_unit_id, feature_name=curr_feature_name,\n                                                        start_frame=start_frame, end_frame=end_frame)\n                if len(value) < len(sorting.get_unit_spike_train(sorting_unit_id, start_frame=start_frame,\n                                                                 end_frame=end_frame)):\n                    if not curr_feature_name.endswith('idxs'):\n                        assert curr_feature_name + '_idxs' in \\\n                               sorting.get_unit_spike_feature_names(unit_id=sorting_unit_id)\n                        curr_feature_name_idxs = curr_feature_name + '_idxs'\n                        value_idxs = np.array(sorting.get_unit_spike_features(unit_id=sorting_unit_id,\n                                                                  
            feature_name=curr_feature_name_idxs,\n                                                                              start_frame=start_frame,\n                                                                              end_frame=end_frame))\n                        # find index of first spike\n                        if start_frame is not None:\n                            discarded_spikes_idxs = np.where(sorting.get_unit_spike_train(sorting_unit_id) <\n                                                             start_frame)\n                            if len(discarded_spikes_idxs) > 0:\n                                n_discarded = len(discarded_spikes_idxs[0])\n                                value_idxs = value_idxs - n_discarded\n                        self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name,\n                                                     value=value,\n                                                     indexes=value_idxs)\n                else:\n                    self.set_unit_spike_features(unit_id=unit_id, feature_name=curr_feature_name, value=value)\n\n    def get_original_unit_ids(self, unit_ids):\n        if isinstance(unit_ids, (int, np.integer)):\n            if unit_ids in self.get_unit_ids():\n                original_unit_ids = self._original_unit_id_lookup[unit_ids]\n            else:\n                raise ValueError(\"Non-valid unit_id\")\n        else:\n            original_unit_ids = []\n            for unit_id in unit_ids:\n                if isinstance(unit_id, (int, np.integer)):\n                    if unit_id in self.get_unit_ids():\n                        original_unit_id = self._original_unit_id_lookup[unit_id]\n                        original_unit_ids.append(original_unit_id)\n                    else:\n                        raise ValueError(\"Non-valid unit_id\")\n                else:\n                    raise ValueError(\"unit_id must be an int\")\n        return 
original_unit_ids\n"
  },
  {
    "path": "spikeextractors/testing.py",
    "content": "import os\nimport shutil\nfrom pathlib import Path\nimport uuid\nfrom datetime import datetime\nimport numpy as np\n\nfrom .extraction_tools import load_extractor_from_pickle, load_extractor_from_dict, \\\n    load_extractor_from_json\n\n\ndef check_recordings_equal(RX1, RX2, return_scaled=True, force_dtype=None, check_times=True):\n    N = RX1.get_num_frames()\n    # get_channel_ids\n    assert np.allclose(RX1.get_channel_ids(), RX2.get_channel_ids())\n    # get_num_channels\n    assert np.allclose(RX1.get_num_channels(), RX2.get_num_channels())\n    # get_num_frames\n    assert np.allclose(RX1.get_num_frames(), RX2.get_num_frames())\n    # get_sampling_frequency\n    assert np.allclose(RX1.get_sampling_frequency(), RX2.get_sampling_frequency())\n    # get_traces\n    if force_dtype is None:\n        assert np.allclose(RX1.get_traces(return_scaled=return_scaled), RX2.get_traces(return_scaled=return_scaled))\n    else:\n        assert np.allclose(RX1.get_traces(return_scaled=return_scaled).astype(force_dtype),\n                           RX2.get_traces(return_scaled=return_scaled).astype(force_dtype))\n    sf = 0\n    ef = N\n    if RX1.get_num_channels() > 1:\n        ch = [RX1.get_channel_ids()[0], RX1.get_channel_ids()[-1]]\n    else:\n        ch = RX1.get_channel_ids()\n    if force_dtype is None:\n        assert np.allclose(RX1.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled),\n                           RX2.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef, return_scaled=return_scaled))\n    else:\n        assert np.allclose(RX1.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef,\n                                          return_scaled=return_scaled).astype(force_dtype),\n                           RX2.get_traces(channel_ids=ch, start_frame=sf, end_frame=ef,\n                                          return_scaled=return_scaled).astype(force_dtype))\n    if check_times:\n        for f in range(0, 
RX1.get_num_frames(), 10):\n            assert np.isclose(RX1.frame_to_time(f), RX2.frame_to_time(f))\n            assert np.isclose(RX1.time_to_frame(RX1.frame_to_time(f)), RX2.time_to_frame(RX2.frame_to_time(f)))\n    # get_snippets\n    frames = [30, 50, 80]\n    snippets1 = RX1.get_snippets(reference_frames=frames, snippet_len=20, return_scaled=return_scaled)\n    snippets2 = RX2.get_snippets(reference_frames=frames, snippet_len=(10, 10), return_scaled=return_scaled)\n    if force_dtype is None:\n        for ii in range(len(frames)):\n            assert np.allclose(snippets1[ii], snippets2[ii])\n    else:\n        for ii in range(len(frames)):\n            assert np.allclose(snippets1[ii].astype(force_dtype), snippets2[ii].astype(force_dtype))\n\n\ndef check_recording_properties(RX1, RX2):\n    # check properties\n    assert sorted(RX1.get_shared_channel_property_names()) == sorted(RX2.get_shared_channel_property_names())\n    for prop in RX1.get_shared_channel_property_names():\n        for ch in RX1.get_channel_ids():\n            if not isinstance(RX1.get_channel_property(ch, prop), str):\n                assert np.allclose(np.array(RX1.get_channel_property(ch, prop)),\n                                   np.array(RX2.get_channel_property(ch, prop)))\n            else:\n                assert RX1.get_channel_property(ch, prop) == RX2.get_channel_property(ch, prop)\n\n\ndef check_recording_return_types(RX):\n    channel_ids = RX.get_channel_ids()\n    assert isinstance(RX.get_num_channels(), (int, np.integer))\n    assert isinstance(RX.get_num_frames(), (int, np.integer))\n    assert isinstance(RX.get_sampling_frequency(), float)\n    assert isinstance(RX.get_traces(start_frame=0, end_frame=10), (np.ndarray, np.memmap))\n\n    for channel_id in channel_ids:\n        assert isinstance(channel_id, (int, np.integer))\n\n\ndef check_sorting_return_types(SX):\n    unit_ids = SX.get_unit_ids()\n    assert (all(isinstance(id, (int, np.integer)) or isinstance(id, 
np.integer) for id in unit_ids))\n    for id in unit_ids:\n        train = SX.get_unit_spike_train(id)\n        # print(train)\n        assert (all(isinstance(x, (int, np.integer)) or isinstance(x, np.integer) for x in train))\n\n\ndef check_sortings_equal(SX1, SX2):\n    # get_unit_ids\n    ids1 = np.sort(np.array(SX1.get_unit_ids()))\n    ids2 = np.sort(np.array(SX2.get_unit_ids()))\n    assert (np.allclose(ids1, ids2))\n    for id in ids1:\n        train1 = np.sort(SX1.get_unit_spike_train(id))\n        train2 = np.sort(SX2.get_unit_spike_train(id))\n        assert np.array_equal(train1, train2)\n\n\ndef check_sorting_properties_features(SX1, SX2):\n    # check properties\n    print(SX1.__class__)\n    print('Properties', sorted(SX1.get_shared_unit_property_names()), sorted(SX2.get_shared_unit_property_names()))\n    assert sorted(SX1.get_shared_unit_property_names()) == sorted(SX2.get_shared_unit_property_names())\n    for prop in SX1.get_shared_unit_property_names():\n        for u in SX1.get_unit_ids():\n            if not isinstance(SX1.get_unit_property(u, prop), str):\n                assert np.allclose(np.array(SX1.get_unit_property(u, prop)),\n                                   np.array(SX2.get_unit_property(u, prop)))\n            else:\n                assert SX1.get_unit_property(u, prop) == SX2.get_unit_property(u, prop)\n    # check features\n    print('Features', sorted(SX1.get_shared_unit_spike_feature_names()),\n          sorted(SX2.get_shared_unit_spike_feature_names()))\n    assert sorted(SX1.get_shared_unit_spike_feature_names()) == sorted(SX2.get_shared_unit_spike_feature_names())\n    for feat in SX1.get_shared_unit_spike_feature_names():\n        for u in SX1.get_unit_ids():\n            assert np.allclose(np.array(SX1.get_unit_spike_features(u, feat)),\n                               np.array(SX2.get_unit_spike_features(u, feat)))\n\n\ndef check_dumping(extractor, test_relative=False):\n    # dump to dict\n    d = 
extractor.dump_to_dict()\n    extractor_loaded = load_extractor_from_dict(d)\n\n    if 'Recording' in str(type(extractor)):\n        check_recordings_equal(extractor, extractor_loaded, check_times=False)\n    elif 'Sorting' in str(type(extractor)):\n        check_sortings_equal(extractor, extractor_loaded)\n\n    # dump to json\n    # without file_name\n    extractor.dump_to_json()\n\n    if 'Recording' in str(type(extractor)):\n        extractor_loaded = load_extractor_from_json('spikeinterface_recording.json')\n        check_recordings_equal(extractor, extractor_loaded, check_times=False)\n    elif 'Sorting' in str(type(extractor)):\n        extractor_loaded = load_extractor_from_json('spikeinterface_sorting.json')\n        check_sortings_equal(extractor, extractor_loaded)\n\n    # with file_name\n    extractor.dump_to_json(file_path='test_dumping/test.json')\n    extractor_loaded = load_extractor_from_json('test_dumping/test.json')\n\n    if 'Recording' in str(type(extractor)):\n        check_recordings_equal(extractor, extractor_loaded, check_times=False)\n    elif 'Sorting' in str(type(extractor)):\n        check_sortings_equal(extractor, extractor_loaded)\n\n    # dump to pickle\n    # without file_name\n    extractor.dump_to_pickle()\n\n    if 'Recording' in str(type(extractor)):\n        extractor_loaded = load_extractor_from_pickle('spikeinterface_recording.pkl')\n        check_recordings_equal(extractor, extractor_loaded, check_times=True)\n        check_recording_properties(extractor, extractor_loaded)\n    elif 'Sorting' in str(type(extractor)):\n        extractor_loaded = load_extractor_from_pickle('spikeinterface_sorting.pkl')\n        check_sortings_equal(extractor, extractor_loaded)\n        check_sorting_properties_features(extractor, extractor_loaded)\n\n    # with file_name\n    extractor.dump_to_pickle(file_path='test_dumping/test.pkl')\n    extractor_loaded = load_extractor_from_pickle('test_dumping/test.pkl')\n\n    if 'Recording' in 
str(type(extractor)):\n        check_recordings_equal(extractor, extractor_loaded, check_times=True)\n        check_recording_properties(extractor, extractor_loaded)\n    elif 'Sorting' in str(type(extractor)):\n        check_sortings_equal(extractor, extractor_loaded)\n        check_sorting_properties_features(extractor, extractor_loaded)\n\n    if test_relative:\n        # dump to dict with relative path\n        d = extractor.dump_to_dict(relative_to=\".\")\n        extractor_loaded = load_extractor_from_dict(d)\n\n        if 'Recording' in str(type(extractor)):\n            check_recordings_equal(extractor, extractor_loaded, check_times=False)\n        elif 'Sorting' in str(type(extractor)):\n            check_sortings_equal(extractor, extractor_loaded)\n\n        # dump to json with relative path\n        extractor.dump_to_json(file_path='test_dumping/test_rel.json', relative_to=\".\")\n        extractor_loaded = load_extractor_from_json('test_dumping/test_rel.json')\n\n        if 'Recording' in str(type(extractor)):\n            check_recordings_equal(extractor, extractor_loaded, check_times=False)\n        elif 'Sorting' in str(type(extractor)):\n            check_sortings_equal(extractor, extractor_loaded)\n\n        # dump to pickle with relative path\n        extractor.dump_to_pickle(file_path='test_dumping/test_rel.pkl', relative_to=\".\")\n        extractor_loaded = load_extractor_from_pickle('test_dumping/test_rel.pkl')\n\n        if 'Recording' in str(type(extractor)):\n            check_recordings_equal(extractor, extractor_loaded, check_times=True)\n        elif 'Sorting' in str(type(extractor)):\n            check_sortings_equal(extractor, extractor_loaded)\n\n    shutil.rmtree('test_dumping')\n    if Path('spikeinterface_recording.json').is_file():\n        os.remove('spikeinterface_recording.json')\n    if Path('spikeinterface_sorting.json').is_file():\n        os.remove('spikeinterface_sorting.json')\n    if 
Path('spikeinterface_recording.pkl').is_file():\n        os.remove('spikeinterface_recording.pkl')\n    if Path('spikeinterface_sorting.pkl').is_file():\n        os.remove('spikeinterface_sorting.pkl')\n\n\ndef get_default_nwbfile_metadata():\n    \"\"\"\n    Returns structure with defaulted metadata values required for a NWBFile.\n    \"\"\"\n    metadata = dict(\n        NWBFile=dict(\n            session_description=\"no description\",\n            session_start_time=datetime(1970, 1, 1),\n            identifier=str(uuid.uuid4())\n        ),\n        Ecephys=dict(\n            Device=[dict(\n                name='Device_ecephys',\n                description='no description'\n            )],\n            ElectrodeGroup=[],\n            ElectricalSeries_raw=dict(\n                name='raw_traces',\n                description='those are the raw traces'\n            ),\n            ElectricalSeries_processed=dict(\n                name='processed_traces',\n                description='those are the processed traces'\n            ),\n            ElectricalSeries_lfp=dict(\n                name='lfp_traces',\n                description='those are the lfp traces'\n            )\n        )\n    )\n    return metadata\n"
  },
  {
    "path": "spikeextractors/version.py",
    "content": "version = '0.9.11'\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/probe_test.prb",
    "content": "channel_groups = {\n    1: {\n        'channels': list(range(16)),\n        'graph' : [],\n        'geometry': {\n            0:  [  0.0 ,   0.0],\n            1:  [  0.0 ,  50.0],\n            2:  [+21.65, 262.5],\n            3:  [+21.65, 237.5],\n            4:  [+21.65, 187.5],\n            5:  [+21.65, 137.5],\n            6:  [+21.65,  87.5],\n            7:  [+21.65,  37.5],\n            8:  [  0.0 , 200.0],\n            9:  [  0.0 , 250.0],\n            10: [+21.65,  62.5],\n            11: [+21.65, 112.5],\n            12: [+21.65, 162.5],\n            13: [+21.65, 212.5],\n            14: [  0.0 , 150.0],\n            15: [  0.0 , 100.0],\n\t}\n    },\n\n    2: {\n\t'channels': list(range(16,32)),\n\t'graph' : [],\n\t'geometry': {\n\t    16: [  0.0 , 125.0],\n\t    17: [  0.0 , 175.0],\n\t    18: [-21.65, 212.5],\n\t    19: [-21.65, 162.5],\n\t    20: [-21.65, 112.5],\n\t    21: [-21.65,  62.5],\n\t    22: [  0.0 , 275.0],\n\t    23: [  0.0 , 225.0],\n\t    24: [-21.65,  37.5],\n\t    25: [-21.65,  87.5],\n\t    26: [-21.65, 137.5],\n\t    27: [-21.65, 187.5],\n\t    28: [-21.65, 237.5],\n\t    29: [-21.65, 262.5],\n\t    30: [  0.0 ,  75.0],\n\t    31: [  0.0 ,  25.0],\n\t}\n    }\n}\n"
  },
  {
    "path": "tests/test_extractors.py",
    "content": "import os\nimport shutil\nimport tempfile\nimport unittest\nfrom pathlib import Path\n\nimport numpy as np\n\nimport spikeextractors as se\nfrom spikeextractors.exceptions import NotDumpableExtractorError\nfrom spikeextractors.testing import (check_sortings_equal, check_recordings_equal, check_dumping,\n    check_recording_return_types, check_sorting_return_types, get_default_nwbfile_metadata)\n\n\nclass TestExtractors(unittest.TestCase):\n    def setUp(self):\n        self.RX, self.RX2, self.RX3, self.SX, self.SX2, self.SX3, self.example_info = self._create_example(seed=0)\n        self.test_dir = tempfile.mkdtemp()\n        # self.test_dir = '.'\n\n    def tearDown(self):\n        # Remove the directory after the test\n        del self.RX, self.RX2, self.RX3, self.SX, self.SX2, self.SX3\n        shutil.rmtree(self.test_dir)\n        # pass\n\n    def _create_example(self, seed):\n        channel_ids = [0, 1, 2, 3]\n        num_channels = 4\n        num_frames = 10000\n        num_ttls = 30\n        sampling_frequency = 30000\n        X = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, num_frames))\n        geom = np.random.RandomState(seed=seed).normal(0, 1, (num_channels, 2))\n        X = (X * 100).astype(int)\n        ttls = np.sort(np.random.permutation(num_frames)[:num_ttls])\n\n        RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)\n        RX.set_ttls(ttls)\n        RX.set_channel_locations([0, 0], channel_ids=0)\n        RX.add_epoch(\"epoch1\", 0, 10)\n        RX.add_epoch(\"epoch2\", 10, 20)\n        for i, channel_id in enumerate(RX.get_channel_ids()):\n            RX.set_channel_property(channel_id=channel_id, property_name='shared_channel_prop', value=i)\n\n        RX2 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)\n        RX2.copy_epochs(RX)\n        times = np.arange(RX.get_num_frames()) / RX.get_sampling_frequency() + 5\n     
   RX2.set_times(times)\n\n        RX3 = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)\n\n        SX = se.NumpySortingExtractor()\n        SX.set_sampling_frequency(sampling_frequency)\n        spike_times = [200, 300, 400]\n        train1 = np.sort(np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[0])).astype(int))\n        SX.add_unit(unit_id=1, times=train1)\n        SX.add_unit(unit_id=2, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[1])))\n        SX.add_unit(unit_id=3, times=np.sort(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times[2])))\n        SX.set_unit_property(unit_id=1, property_name='stability', value=80)\n        SX.add_epoch(\"epoch1\", 0, 10)\n        SX.add_epoch(\"epoch2\", 10, 20)\n\n        SX2 = se.NumpySortingExtractor()\n        SX2.set_sampling_frequency(sampling_frequency)\n        spike_times2 = [100, 150, 450]\n        train2 = np.rint(np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[0])).astype(int)\n        SX2.add_unit(unit_id=3, times=train2)\n        SX2.add_unit(unit_id=4, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[1]))\n        SX2.add_unit(unit_id=5, times=np.random.RandomState(seed=seed).uniform(0, num_frames, spike_times2[2]))\n        SX2.set_unit_property(unit_id=4, property_name='stability', value=80)\n        SX2.set_unit_spike_features(unit_id=3, feature_name='widths', value=np.asarray([3] * spike_times2[0]))\n        SX2.copy_epochs(SX)\n        SX2.copy_times(RX2)\n        for i, unit_id in enumerate(SX2.get_unit_ids()):\n            SX2.set_unit_property(unit_id=unit_id, property_name='shared_unit_prop', value=i)\n            SX2.set_unit_spike_features(\n                unit_id=unit_id,\n                feature_name='shared_unit_feature',\n                value=np.asarray([i] * spike_times2[i])\n            )\n\n        SX3 = 
se.NumpySortingExtractor()\n        train3 = np.asarray([1, 20, 21, 35, 38, 45, 46, 47])\n        SX3.add_unit(unit_id=0, times=train3)\n        features3 = np.asarray([0, 5, 10, 15, 20, 25, 30, 35])\n        features4 = np.asarray([0, 10, 20, 30])\n        feature4_idx = np.asarray([0, 2, 4, 6])\n        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy', value=features3)\n        SX3.set_unit_spike_features(unit_id=0, feature_name='dummy2', value=features4, indexes=feature4_idx)\n\n        example_info = dict(\n            channel_ids=channel_ids,\n            num_channels=num_channels,\n            num_frames=num_frames,\n            sampling_frequency=sampling_frequency,\n            unit_ids=[1, 2, 3],\n            train1=train1,\n            train2=train2,\n            train3=train3,\n            features3=features3,\n            unit_prop=80,\n            channel_prop=(0, 0),\n            ttls=ttls,\n            epochs_info=((0, 10), (10, 20)),\n            geom=geom,\n            times=times\n        )\n\n        return (RX, RX2, RX3, SX, SX2, SX3, example_info)\n\n    def test_example(self):\n        self.assertEqual(self.RX.get_channel_ids(), self.example_info['channel_ids'])\n        self.assertEqual(self.RX.get_num_channels(), self.example_info['num_channels'])\n        self.assertEqual(self.RX.get_num_frames(), self.example_info['num_frames'])\n        self.assertEqual(self.RX.get_sampling_frequency(), self.example_info['sampling_frequency'])\n        self.assertEqual(self.SX.get_unit_ids(), self.example_info['unit_ids'])\n        self.assertEqual(self.RX.get_channel_locations(0)[0][0], self.example_info['channel_prop'][0])\n        self.assertEqual(self.RX.get_channel_locations(0)[0][1], self.example_info['channel_prop'][1])\n        self.assertTrue(np.array_equal(self.RX.get_ttl_events()[0], self.example_info['ttls']))\n        self.assertEqual(self.SX.get_unit_property(unit_id=1, property_name='stability'),\n                         
self.example_info['unit_prop'])\n        self.assertTrue(np.array_equal(self.SX.get_unit_spike_train(1), self.example_info['train1']))\n\n        self.assertTrue(issubclass(self.SX.get_unit_spike_train(1).dtype.type, np.integer))\n        self.assertTrue(self.RX.get_shared_channel_property_names(), ['group', 'location', 'shared_channel_prop'])\n        self.assertTrue(self.RX.get_channel_property_names(0), ['group', 'location', 'shared_channel_prop'])\n        self.assertTrue(self.SX2.get_shared_unit_property_names(), ['shared_unit_prop'])\n        self.assertTrue(self.SX2.get_unit_property_names(4), ['shared_unit_prop', 'stability'])\n        self.assertTrue(self.SX2.get_shared_unit_spike_feature_names(), ['shared_unit_feature'])\n        self.assertTrue(self.SX2.get_unit_spike_feature_names(3), ['shared_channel_prop', 'widths'])\n\n        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy'), self.example_info['features3']))\n        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=4),\n                                       self.example_info['features3'][1:]))\n        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', end_frame=4),\n                                       self.example_info['features3'][:1]))\n        self.assertTrue(np.array_equal(self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46),\n                                       self.example_info['features3'][1:6]))\n        self.assertTrue('dummy2' in self.SX3.get_unit_spike_feature_names(0))\n        self.assertTrue('dummy2_idxs' in self.SX3.get_unit_spike_feature_names(0))\n\n        sub_extractor_full = se.SubSortingExtractor(self.SX3)\n        sub_extractor_partial = se.SubSortingExtractor(self.SX3, start_frame=20, end_frame=46)\n\n        self.assertTrue(np.array_equal(sub_extractor_full.get_unit_spike_features(0, 'dummy'),\n                                       
self.SX3.get_unit_spike_features(0, 'dummy')))\n        self.assertTrue(np.array_equal(sub_extractor_partial.get_unit_spike_features(0, 'dummy'),\n                                       self.SX3.get_unit_spike_features(0, 'dummy', start_frame=20, end_frame=46)))\n\n        self.assertEqual(tuple(self.RX.get_epoch_info(\"epoch1\").values()), self.example_info['epochs_info'][0])\n        self.assertEqual(tuple(self.RX.get_epoch_info(\"epoch2\").values()), self.example_info['epochs_info'][1])\n        self.assertEqual(tuple(self.SX.get_epoch_info(\"epoch1\").values()), self.example_info['epochs_info'][0])\n        self.assertEqual(tuple(self.SX.get_epoch_info(\"epoch2\").values()), self.example_info['epochs_info'][1])\n\n        self.assertEqual(tuple(self.RX.get_epoch_info(\"epoch1\").values()),\n                         tuple(self.RX2.get_epoch_info(\"epoch1\").values()))\n        self.assertEqual(tuple(self.RX.get_epoch_info(\"epoch2\").values()),\n                         tuple(self.RX2.get_epoch_info(\"epoch2\").values()))\n        self.assertEqual(tuple(self.SX.get_epoch_info(\"epoch1\").values()),\n                         tuple(self.SX2.get_epoch_info(\"epoch1\").values()))\n        self.assertEqual(tuple(self.SX.get_epoch_info(\"epoch2\").values()),\n                         tuple(self.SX2.get_epoch_info(\"epoch2\").values()))\n\n        self.assertTrue(np.array_equal(self.RX2.frame_to_time(np.arange(self.RX2.get_num_frames())),\n                                       self.example_info['times']))\n        self.assertTrue(np.array_equal(self.SX2.get_unit_spike_train(3) / self.SX2.get_sampling_frequency() + 5,\n                                       self.SX2.frame_to_time(self.SX2.get_unit_spike_train(3))))\n\n        self.RX3.clear_channel_locations()\n        self.assertTrue('location' not in self.RX3.get_shared_channel_property_names())\n        self.RX3.set_channel_locations(self.example_info['geom'])\n        
self.assertTrue(np.array_equal(self.RX3.get_channel_locations(),\n                                       self.RX2.get_channel_locations()))\n        self.RX3.set_channel_groups(groups=[1], channel_ids=[1])\n        self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 1)\n        self.RX3.clear_channel_groups()\n        self.assertEqual(self.RX3.get_channel_groups(channel_ids=[1]), 0)\n        self.RX3.set_channel_locations(locations=[[np.nan, np.nan, np.nan]], channel_ids=[1])\n        self.assertTrue('location' not in self.RX3.get_shared_channel_property_names())\n        self.RX3.set_channel_locations(locations=[[0, 0, 0]], channel_ids=[1])\n        self.assertTrue('location' in self.RX3.get_shared_channel_property_names())\n        check_recording_return_types(self.RX)\n\n    def test_allocate_arrays(self):\n        shape = (30, 1000)\n        dtype = 'int16'\n\n        arr_in_memory = self.RX.allocate_array(shape=shape, dtype=dtype, memmap=False)\n        arr_memmap = self.RX.allocate_array(shape=shape, dtype=dtype, memmap=True)\n\n        assert isinstance(arr_in_memory, np.ndarray)\n        assert isinstance(arr_memmap, np.memmap)\n        assert arr_in_memory.shape == shape\n        assert arr_memmap.shape == shape\n        assert arr_in_memory.dtype == dtype\n        assert arr_memmap.dtype == dtype\n\n        arr_in_memory = self.SX.allocate_array(shape=shape, dtype=dtype, memmap=False)\n        arr_memmap = self.SX.allocate_array(shape=shape, dtype=dtype, memmap=True)\n\n        assert isinstance(arr_in_memory, np.ndarray)\n        assert isinstance(arr_memmap, np.memmap)\n        assert arr_in_memory.shape == shape\n        assert arr_memmap.shape == shape\n        assert arr_in_memory.dtype == dtype\n        assert arr_memmap.dtype == dtype\n\n    def test_cache_extractor(self):\n        cache_rec = se.CacheRecordingExtractor(self.RX)\n        check_recording_return_types(cache_rec)\n        check_recordings_equal(self.RX, cache_rec)\n        
cache_rec.move_to('cache_rec')\n\n        assert cache_rec.filename == 'cache_rec.dat'\n        check_dumping(cache_rec, test_relative=True)\n\n        cache_rec = se.CacheRecordingExtractor(self.RX, save_path='cache_rec2')\n        check_recording_return_types(cache_rec)\n        check_recordings_equal(self.RX, cache_rec)\n\n        assert cache_rec.filename == 'cache_rec2.dat'\n        check_dumping(cache_rec, test_relative=True)\n\n        # test saving to file\n        del cache_rec\n        assert Path('cache_rec2.dat').is_file()\n\n        # test tmp\n        cache_rec = se.CacheRecordingExtractor(self.RX)\n        tmp_file = cache_rec.filename\n        del cache_rec\n        assert not Path(tmp_file).is_file()\n\n        cache_sort = se.CacheSortingExtractor(self.SX)\n        check_sorting_return_types(cache_sort)\n        check_sortings_equal(self.SX, cache_sort)\n        cache_sort.move_to('cache_sort')\n\n        assert cache_sort.filename == 'cache_sort.npz'\n        check_dumping(cache_sort, test_relative=True)\n\n        # test saving to file\n        del cache_sort\n        assert Path('cache_sort.npz').is_file()\n\n        cache_sort = se.CacheSortingExtractor(self.SX, save_path='cache_sort2')\n        check_sorting_return_types(cache_sort)\n        check_sortings_equal(self.SX, cache_sort)\n\n        assert cache_sort.filename == 'cache_sort2.npz'\n        check_dumping(cache_sort, test_relative=True)\n\n        # test saving to file\n        del cache_sort\n        assert Path('cache_sort2.npz').is_file()\n\n        # test tmp\n        cache_sort = se.CacheSortingExtractor(self.SX)\n        tmp_file = cache_sort.filename\n        del cache_sort\n        assert not Path(tmp_file).is_file()\n\n        # cleanup\n        os.remove('cache_rec.dat')\n        os.remove('cache_rec2.dat')\n        os.remove('cache_sort.npz')\n        os.remove('cache_sort2.npz')\n\n    def test_not_dumpable_exception(self):\n        try:\n            
self.RX.dump_to_json()\n        except Exception as e:\n            assert isinstance(e, NotDumpableExtractorError)\n\n        try:\n            self.RX.dump_to_pickle()\n        except Exception as e:\n            assert isinstance(e, NotDumpableExtractorError)\n\n    def test_mda_extractor(self):\n        path1 = self.test_dir + '/mda'\n        path2 = path1 + '/firings_true.mda'\n        se.MdaRecordingExtractor.write_recording(self.RX, path1)\n        se.MdaSortingExtractor.write_sorting(self.SX, path2)\n        RX_mda = se.MdaRecordingExtractor(path1)\n        SX_mda = se.MdaSortingExtractor(path2)\n        check_recording_return_types(RX_mda)\n        check_recordings_equal(self.RX, RX_mda)\n        check_sorting_return_types(SX_mda)\n        check_sortings_equal(self.SX, SX_mda)\n        check_dumping(RX_mda)\n        check_dumping(SX_mda)\n\n    def test_hdsort_extractor(self):\n        path = self.test_dir + '/results_test_hdsort_extractor.mat'\n        locations = np.ones((10, 2))\n        se.HDSortSortingExtractor.write_sorting(self.SX, path, locations=locations, noise_std_by_channel=None)\n        SX_hd = se.HDSortSortingExtractor(path)\n        check_sorting_return_types(SX_hd)\n        check_sortings_equal(self.SX, SX_hd)\n        check_dumping(SX_hd)\n\n    def test_npz_extractor(self):\n        path = self.test_dir + '/sorting.npz'\n        se.NpzSortingExtractor.write_sorting(self.SX, path)\n        SX_npz = se.NpzSortingExtractor(path)\n\n        # empty write\n        sorting_empty = se.NumpySortingExtractor()\n        path_empty = self.test_dir + '/sorting_empty.npz'\n        se.NpzSortingExtractor.write_sorting(sorting_empty, path_empty)\n\n        check_sorting_return_types(SX_npz)\n        check_sortings_equal(self.SX, SX_npz)\n        check_dumping(SX_npz)\n\n    def test_biocam_extractor(self):\n        path1 = self.test_dir + '/raw.brw'\n        se.BiocamRecordingExtractor.write_recording(self.RX, path1)\n        RX_biocam = 
se.BiocamRecordingExtractor(path1)\n        check_recording_return_types(RX_biocam)\n        check_recordings_equal(self.RX, RX_biocam)\n        check_dumping(RX_biocam)\n\n    def test_mearec_extractors(self):\n        path1 = self.test_dir + '/raw.h5'\n        se.MEArecRecordingExtractor.write_recording(self.RX, path1)\n        RX_mearec = se.MEArecRecordingExtractor(path1)\n        tr = RX_mearec.get_traces(channel_ids=[0, 1], end_frame=1000)\n        check_recording_return_types(RX_mearec)\n        check_recordings_equal(self.RX, RX_mearec)\n        check_dumping(RX_mearec)\n\n        path2 = self.test_dir + '/firings_true.h5'\n        se.MEArecSortingExtractor.write_sorting(self.SX, path2, self.RX.get_sampling_frequency())\n        SX_mearec = se.MEArecSortingExtractor(path2)\n        check_sorting_return_types(SX_mearec)\n        check_sortings_equal(self.SX, SX_mearec)\n        check_dumping(SX_mearec)\n\n    def test_hs2_extractor(self):\n        path1 = self.test_dir + '/firings_true.hdf5'\n        se.HS2SortingExtractor.write_sorting(self.SX, path1)\n        SX_hs2 = se.HS2SortingExtractor(path1)\n        check_sorting_return_types(SX_hs2)\n        check_sortings_equal(self.SX, SX_hs2)\n        self.assertEqual(SX_hs2.get_sampling_frequency(), self.SX.get_sampling_frequency())\n        check_dumping(SX_hs2)\n\n    def test_exdir_extractors(self):\n        path1 = self.test_dir + '/raw.exdir'\n        se.ExdirRecordingExtractor.write_recording(self.RX, path1)\n        RX_exdir = se.ExdirRecordingExtractor(path1)\n        check_recording_return_types(RX_exdir)\n        check_recordings_equal(self.RX, RX_exdir)\n        check_dumping(RX_exdir)\n\n        path2 = self.test_dir + '/firings.exdir'\n        se.ExdirSortingExtractor.write_sorting(self.SX, path2, self.RX)\n        SX_exdir = se.ExdirSortingExtractor(path2)\n        check_sorting_return_types(SX_exdir)\n        check_sortings_equal(self.SX, SX_exdir)\n        check_dumping(SX_exdir)\n\n    def 
test_spykingcircus_extractor(self):\n        path1 = self.test_dir + '/sc'\n        se.SpykingCircusSortingExtractor.write_sorting(self.SX, path1)\n        SX_spy = se.SpykingCircusSortingExtractor(path1)\n        check_sorting_return_types(SX_spy)\n        check_sortings_equal(self.SX, SX_spy)\n        check_dumping(SX_spy)\n\n    def test_multi_sub_recording_extractor(self):\n        RX_multi = se.MultiRecordingTimeExtractor(\n            recordings=[self.RX, self.RX, self.RX],\n            epoch_names=['A', 'B', 'C']\n        )\n        RX_sub = RX_multi.get_epoch('C')\n        check_recordings_equal(self.RX, RX_sub)\n        check_recordings_equal(self.RX, RX_multi.recordings[0])\n        check_recordings_equal(self.RX, RX_multi.recordings[1])\n        check_recordings_equal(self.RX, RX_multi.recordings[2])\n        self.assertEqual(4, len(RX_sub.get_channel_ids()))\n\n        RX_multi = se.MultiRecordingChannelExtractor(\n            recordings=[self.RX, self.RX2, self.RX3],\n            groups=[1, 2, 3]\n        )\n        RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3])\n        # RX2 has times\n        check_recordings_equal(self.RX2, RX_sub, check_times=False)\n        check_recordings_equal(self.RX, RX_multi.recordings[0])\n        check_recordings_equal(self.RX2, RX_multi.recordings[1], check_times=False)\n        check_recordings_equal(self.RX3, RX_multi.recordings[2])\n        self.assertEqual([2, 2, 2, 2], list(RX_sub.get_channel_groups()))\n        self.assertEqual(12, len(RX_multi.get_channel_ids()))\n\n        RX_multi = se.MultiRecordingChannelExtractor(\n            recordings=[self.RX2, self.RX2, self.RX2],\n            groups=[1, 2, 3]\n        )\n        RX_sub = se.SubRecordingExtractor(RX_multi, channel_ids=[4, 5, 6, 7], renamed_channel_ids=[0, 1, 2, 3])\n        check_recordings_equal(self.RX2, RX_sub, check_times=False)\n        check_recordings_equal(self.RX2, 
RX_multi.recordings[0])\n        check_recordings_equal(self.RX2, RX_multi.recordings[1], check_times=False)\n        check_recordings_equal(self.RX2, RX_multi.recordings[2])\n        self.assertTrue(np.array_equal([2, 2, 2, 2], list(RX_sub.get_channel_groups())))\n        self.assertTrue(12 == len(RX_multi.get_channel_ids()))\n        self.assertTrue(np.array_equal(RX_multi.frame_to_time(np.arange(RX_multi.get_num_frames())),\n                        np.arange(RX_multi.get_num_frames()) / RX_multi.get_sampling_frequency() + 5))\n\n        rx1 = self.RX\n        rx2 = self.RX2\n        rx3 = self.RX3\n        rx2.set_channel_property(0, \"foo\", 100)\n        rx3.set_channel_locations([11, 11], channel_ids=0)\n        RX_multi_c = se.MultiRecordingChannelExtractor(\n            recordings=[rx1, rx2, rx3],\n            groups=[0, 0, 1]\n        )\n        self.assertTrue(np.array_equal([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], RX_multi_c.get_channel_ids()))\n        self.assertTrue(np.array_equal([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], RX_multi_c.get_channel_groups()))\n        self.assertEqual(rx2.get_channel_property(0, \"foo\"), RX_multi_c.get_channel_property(4, \"foo\"))\n        self.assertTrue(np.array_equal(rx3.get_channel_locations([0])[0], RX_multi_c.get_channel_locations([8])[0]))\n\n    def test_ttl_frames_in_sub_multi(self):\n        # sub recording\n        start_frame = self.example_info['num_frames'] // 3\n        end_frame = 2 * self.example_info['num_frames'] // 3\n        RX_sub = se.SubRecordingExtractor(self.RX, start_frame=start_frame, end_frame=end_frame)\n        original_ttls = self.RX.get_ttl_events()[0]\n        ttls_in_sub = original_ttls[np.where((original_ttls >= start_frame) & (original_ttls < end_frame))[0]]\n        self.assertTrue(np.array_equal(RX_sub.get_ttl_events()[0], ttls_in_sub - start_frame))\n\n        # multirecording\n        RX_multi = se.MultiRecordingTimeExtractor(recordings=[self.RX, self.RX, self.RX])\n        
ttls_originals = self.RX.get_ttl_events()[0]\n        num_ttls = len(ttls_originals)\n        self.assertEqual(len(RX_multi.get_ttl_events()[0]), 3 * num_ttls)\n        self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][:num_ttls], ttls_originals))\n        self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][num_ttls:2 * num_ttls],\n                                       ttls_originals + self.RX.get_num_frames()))\n        self.assertTrue(np.array_equal(RX_multi.get_ttl_events()[0][2 * num_ttls:],\n                                       ttls_originals + 2 * self.RX.get_num_frames()))\n\n    def test_multi_sub_sorting_extractor(self):\n        N = self.RX.get_num_frames()\n        SX_multi = se.MultiSortingExtractor(\n            sortings=[self.SX, self.SX, self.SX],\n        )\n        SX_multi.set_unit_property(unit_id=1, property_name='dummy', value=5)\n        SX_sub = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0)\n        check_sortings_equal(SX_multi, SX_sub)\n        self.assertEqual(SX_multi.get_unit_property(1, 'dummy'), SX_sub.get_unit_property(1, 'dummy'))\n\n        N = self.RX.get_num_frames()\n        SX_multi = se.MultiSortingExtractor(\n            sortings=[self.SX, self.SX2],\n        )\n        SX_sub1 = se.SubSortingExtractor(parent_sorting=SX_multi, start_frame=0, end_frame=N)\n        check_sortings_equal(SX_multi, SX_sub1)\n        check_sortings_equal(self.SX, SX_multi.sortings[0])\n        check_sortings_equal(self.SX2, SX_multi.sortings[1])\n\n    def test_dump_load_multi_sub_extractor(self):\n        # generate dumpable formats\n        path1 = self.test_dir + '/mda'\n        path2 = path1 + '/firings_true.mda'\n        se.MdaRecordingExtractor.write_recording(self.RX, path1)\n        se.MdaSortingExtractor.write_sorting(self.SX, path2)\n        RX_mda = se.MdaRecordingExtractor(path1)\n        SX_mda = se.MdaSortingExtractor(path2)\n\n        RX_multi_chan = 
se.MultiRecordingChannelExtractor(recordings=[RX_mda, RX_mda, RX_mda])\n        check_dumping(RX_multi_chan)\n        RX_multi_time = se.MultiRecordingTimeExtractor(recordings=[RX_mda, RX_mda, RX_mda], )\n        check_dumping(RX_multi_time)\n        RX_multi_chan = se.SubRecordingExtractor(RX_mda, channel_ids=[0, 1])\n        check_dumping(RX_multi_chan)\n\n        SX_sub = se.SubSortingExtractor(SX_mda, unit_ids=[1, 2])\n        check_dumping(SX_sub)\n        SX_multi = se.MultiSortingExtractor(sortings=[SX_mda, SX_mda, SX_mda])\n        check_dumping(SX_multi)\n\n    def test_nwb_extractor(self):\n        path1 = self.test_dir + '/test.nwb'\n        se.NwbRecordingExtractor.write_recording(self.RX, path1)\n        RX_nwb = se.NwbRecordingExtractor(path1)\n        check_recording_return_types(RX_nwb)\n        check_recordings_equal(self.RX, RX_nwb)\n        check_dumping(RX_nwb)\n\n        del RX_nwb\n        se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True)\n        RX_nwb = se.NwbRecordingExtractor(path1)\n        check_recording_return_types(RX_nwb)\n        check_recordings_equal(self.RX, RX_nwb)\n        check_dumping(RX_nwb)\n\n        # append sorting to existing file\n        se.NwbSortingExtractor.write_sorting(sorting=self.SX, save_path=path1, overwrite=False)\n\n        path2 = self.test_dir + \"/firings_true.nwb\"\n        se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path2)\n        se.NwbSortingExtractor.write_sorting(sorting=self.SX, save_path=path2)\n        SX_nwb = se.NwbSortingExtractor(path2)\n        check_sortings_equal(self.SX, SX_nwb)\n        check_dumping(SX_nwb)\n\n        # Test for handling unit property descriptions argument\n        property_descriptions = dict(stability=\"This is a description of stability.\")\n        se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True)\n        se.NwbSortingExtractor.write_sorting(\n         
   sorting=self.SX,\n            save_path=path1,\n            property_descriptions=property_descriptions\n        )\n        SX_nwb = se.NwbSortingExtractor(path1)\n        check_sortings_equal(self.SX, SX_nwb)\n        check_dumping(SX_nwb)\n\n        # Test for handling skip_properties argument\n        se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True)\n        se.NwbSortingExtractor.write_sorting(\n            sorting=self.SX,\n            save_path=path1,\n            skip_properties=['stability']\n        )\n        SX_nwb = se.NwbSortingExtractor(path1)\n        assert 'stability' not in SX_nwb.get_shared_unit_property_names()\n        check_sortings_equal(self.SX, SX_nwb)\n        check_dumping(SX_nwb)\n\n        # Test for handling skip_features argument\n        se.NwbRecordingExtractor.write_recording(recording=self.RX, save_path=path1, overwrite=True)\n        # SX2 has timestamps, so loading it back from Nwb will not recover the same spike frames. 
USe use_times=False\n        se.NwbSortingExtractor.write_sorting(\n            sorting=self.SX2,\n            save_path=path1,\n            skip_features=['widths'],\n            use_times=False\n        )\n        SX_nwb = se.NwbSortingExtractor(path1)\n        assert 'widths' not in SX_nwb.get_shared_unit_spike_feature_names()\n        check_sortings_equal(self.SX2, SX_nwb)\n        check_dumping(SX_nwb)\n\n        # Test writting multiple recordings using metadata\n        metadata = get_default_nwbfile_metadata()\n        path_nwb = self.test_dir + '/test_multiple.nwb'\n        se.NwbRecordingExtractor.write_recording(\n            recording=self.RX, \n            save_path=path_nwb,\n            metadata=metadata,\n            write_as='raw',\n            es_key='ElectricalSeries_raw',\n        )\n        se.NwbRecordingExtractor.write_recording(\n            recording=self.RX2, \n            save_path=path_nwb,\n            metadata=metadata,\n            write_as='processed',\n            es_key='ElectricalSeries_processed',\n        )\n        se.NwbRecordingExtractor.write_recording(\n            recording=self.RX3, \n            save_path=path_nwb,\n            metadata=metadata,\n            write_as='lfp',\n            es_key='ElectricalSeries_lfp',\n        )\n\n        RX_nwb = se.NwbRecordingExtractor(\n            file_path=path_nwb,\n            electrical_series_name='raw_traces'\n        )\n        check_recording_return_types(RX_nwb)\n        check_recordings_equal(self.RX, RX_nwb)\n        check_dumping(RX_nwb)\n        del RX_nwb\n\n    def test_nixio_extractor(self):\n        path1 = os.path.join(self.test_dir, 'raw.nix')\n        se.NIXIORecordingExtractor.write_recording(self.RX, path1)\n        RX_nixio = se.NIXIORecordingExtractor(path1)\n        check_recording_return_types(RX_nixio)\n        check_recordings_equal(self.RX, RX_nixio)\n        check_dumping(RX_nixio)\n        del RX_nixio\n        # test force overwrite\n        
se.NIXIORecordingExtractor.write_recording(self.RX, path1,\n                                                   overwrite=True)\n\n        path2 = self.test_dir + '/firings_true.nix'\n        se.NIXIOSortingExtractor.write_sorting(self.SX, path2)\n        SX_nixio = se.NIXIOSortingExtractor(path2)\n        check_sorting_return_types(SX_nixio)\n        check_sortings_equal(self.SX, SX_nixio)\n        check_dumping(SX_nixio)\n\n    def test_shybrid_extractors(self):\n        # test sorting extractor\n        se.SHYBRIDSortingExtractor.write_sorting(self.SX, self.test_dir)\n        initial_sorting_file = os.path.join(self.test_dir, 'initial_sorting.csv')\n        SX_shybrid = se.SHYBRIDSortingExtractor(initial_sorting_file)\n        check_sorting_return_types(SX_shybrid)\n        check_sortings_equal(self.SX, SX_shybrid)\n        check_dumping(SX_shybrid)\n\n        # test recording extractor\n        se.SHYBRIDRecordingExtractor.write_recording(self.RX,\n                                                     self.test_dir,\n                                                     initial_sorting_file)\n        RX_shybrid = se.SHYBRIDRecordingExtractor(os.path.join(self.test_dir,\n                                                               'recording.bin'))\n        check_recording_return_types(RX_shybrid)\n        check_recordings_equal(self.RX, RX_shybrid)\n        check_dumping(RX_shybrid)\n\n    def test_neuroscope_extractors(self):\n        # NeuroscopeRecordingExtractor tests\n        nscope_dir = Path(self.test_dir) / 'neuroscope_rec0'\n        dat_file = nscope_dir / 'neuroscope_rec0.dat'\n        se.NeuroscopeRecordingExtractor.write_recording(self.RX, nscope_dir)\n        RX_ns = se.NeuroscopeRecordingExtractor(dat_file)\n\n        check_recording_return_types(RX_ns)\n        check_recordings_equal(self.RX, RX_ns, force_dtype='int32')\n        check_dumping(RX_ns)\n\n        check_recording_return_types(RX_ns)\n        check_recordings_equal(self.RX, RX_ns, 
force_dtype='int32')\n        check_dumping(RX_ns)\n\n        del RX_ns\n        # overwrite\n        nscope_dir = Path(self.test_dir) / 'neuroscope_rec1'\n        dat_file = nscope_dir / 'neuroscope_rec1.dat'\n        se.NeuroscopeRecordingExtractor.write_recording(recording=self.RX, save_path=nscope_dir)\n        RX_ns = se.NeuroscopeRecordingExtractor(dat_file)\n        check_recording_return_types(RX_ns)\n        check_recordings_equal(self.RX, RX_ns)\n        check_dumping(RX_ns)\n\n        # NeuroscopeMultiRecordingTimeExtractor tests\n        nscope_dir = Path(self.test_dir) / \"neuroscope_rec2\"\n        dat_file = nscope_dir / \"neuroscope_rec2.dat\"\n        RX_multirecording = se.MultiRecordingTimeExtractor(recordings=[self.RX, self.RX])\n        se.NeuroscopeMultiRecordingTimeExtractor.write_recording(recording=RX_multirecording, save_path=nscope_dir)\n        RX_mre = se.NeuroscopeMultiRecordingTimeExtractor(folder_path=nscope_dir)\n        check_recording_return_types(RX_mre)\n        check_recordings_equal(RX_multirecording, RX_mre)\n        check_dumping(RX_mre)\n\n        # NeuroscopeSortingExtractor tests\n        nscope_dir = Path(self.test_dir) / 'neuroscope_sort0'\n        sort_name = 'neuroscope_sort0'\n        initial_sorting_resfile = Path(self.test_dir) / sort_name / f'{sort_name}.res'\n        initial_sorting_clufile = Path(self.test_dir) / sort_name / f'{sort_name}.clu'\n        se.NeuroscopeSortingExtractor.write_sorting(self.SX, nscope_dir)\n        SX_neuroscope = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile,\n                                                      clufile_path=initial_sorting_clufile)\n        check_sorting_return_types(SX_neuroscope)\n        check_sortings_equal(self.SX, SX_neuroscope)\n        check_dumping(SX_neuroscope)\n        SX_neuroscope_no_mua = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile,\n                                                             
clufile_path=initial_sorting_clufile,\n                                                             keep_mua_units=False)\n        check_sorting_return_types(SX_neuroscope_no_mua)\n        check_dumping(SX_neuroscope_no_mua)\n\n        # Test for extra argument 'keep_mua_units' resulted in the right output\n        SX_neuroscope_no_mua = se.NeuroscopeSortingExtractor(resfile_path=initial_sorting_resfile,\n                                                             clufile_path=initial_sorting_clufile,\n                                                             keep_mua_units=False)\n        check_sorting_return_types(SX_neuroscope_no_mua)\n        check_dumping(SX_neuroscope_no_mua)\n\n        num_original_units = len(SX_neuroscope.get_unit_ids())\n        self.assertEqual(list(SX_neuroscope.get_unit_ids()), list(range(1, num_original_units + 1)))\n        self.assertEqual(list(SX_neuroscope_no_mua.get_unit_ids()), list(range(2, num_original_units + 1)))\n\n        # Tests for the auto-detection of format for NeuroscopeSortingExtractor\n        SX_neuroscope_from_fp = se.NeuroscopeSortingExtractor(folder_path=nscope_dir)\n        check_sorting_return_types(SX_neuroscope_from_fp)\n        check_sortings_equal(self.SX, SX_neuroscope_from_fp)\n        check_dumping(SX_neuroscope_from_fp)\n\n        # Tests for the NeuroscopeMultiSortingExtractor\n        nscope_dir = Path(self.test_dir) / 'neuroscope_sort1'\n        SX_multisorting = se.MultiSortingExtractor(sortings=[self.SX, self.SX])\n        se.NeuroscopeMultiSortingExtractor.write_sorting(SX_multisorting, nscope_dir)\n        SX_neuroscope_mse = se.NeuroscopeMultiSortingExtractor(nscope_dir)\n        check_sorting_return_types(SX_neuroscope_mse)\n        check_sortings_equal(SX_multisorting, SX_neuroscope_mse)\n        check_dumping(SX_neuroscope_mse)\n\n    def test_cell_explorer_extractor(self):\n        sorter_id = \"cell_explorer_sorter\"\n        cell_explorer_dir = Path(self.test_dir) / sorter_id\n       
 spikes_matfile_path = cell_explorer_dir / f\"{sorter_id}.spikes.cellinfo.mat\"\n        se.CellExplorerSortingExtractor.write_sorting(sorting=self.SX, save_path=spikes_matfile_path)\n        SX_cell_explorer = se.CellExplorerSortingExtractor(spikes_matfile_path=spikes_matfile_path)\n        check_sorting_return_types(SX_cell_explorer)\n        check_sortings_equal(self.SX, SX_cell_explorer)\n        check_dumping(SX_cell_explorer)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_gin_repo.py",
    "content": "import tempfile\nimport unittest\nfrom pathlib import Path\nimport numpy as np\nimport sys\n\nfrom datalad.api import install, Dataset\nfrom parameterized import parameterized\n\nimport spikeextractors as se\nfrom spikeextractors.testing import check_recordings_equal, check_sortings_equal\n\nrun_local = False\ntest_nwb = True\ntest_caching = True\n\nif sys.platform == \"linux\" or run_local:\n    class TestNwbConversions(unittest.TestCase):\n\n        def setUp(self):\n            pt = Path.cwd() / 'ephy_testing_data'\n            if pt.exists():\n                self.dataset = Dataset(pt)\n            else:\n                self.dataset = install('https://gin.g-node.org/NeuralEnsemble/ephy_testing_data')\n            # Must pin to previous dataset version\n            # See https://github.com/SpikeInterface/spikeextractors/pull/675\n            self.dataset.repo.call_git(['checkout', '17e8f37674d70af84cdba6acd83df964a8e09f0c'])\n            self.savedir = Path(tempfile.mkdtemp())\n\n        @parameterized.expand([\n            (\n                se.AxonaRecordingExtractor,\n                \"axona\",\n                dict(filename=str(Path.cwd() / \"ephy_testing_data\" / \"axona\" / \"axona_raw.set\"))\n            ),\n            (\n                se.BlackrockRecordingExtractor,\n                \"blackrock/blackrock_2_1\",\n                dict(\n                    filename=str(Path.cwd() / \"ephy_testing_data\" / \"blackrock\" / \"blackrock_2_1\" / \"l101210-001\"),\n                    seg_index=0,\n                    nsx_to_load=5\n                )\n            ),\n            (\n                se.IntanRecordingExtractor,\n                \"intan\",\n                dict(file_path=Path.cwd() / \"ephy_testing_data\" / \"intan\" / \"intan_rhd_test_1.rhd\")\n            ),\n            (\n                se.IntanRecordingExtractor,\n                \"intan\",\n                dict(file_path=Path.cwd() / \"ephy_testing_data\" / \"intan\" / 
\"intan_rhs_test_1.rhs\")\n            ),\n            # Klusta - no .prm config file in ephy_testing\n            # (\n            #     se.KlustaRecordingExtractor,\n            #     \"kwik\",\n            #     dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"kwik\")\n            # ),\n            (\n                se.MEArecRecordingExtractor,\n                \"mearec/mearec_test_10s.h5\",\n                dict(file_path=Path.cwd() / \"ephy_testing_data\" / \"mearec\" / \"mearec_test_10s.h5\")\n            ),\n            (\n                se.NeuralynxRecordingExtractor,\n                \"neuralynx/Cheetah_v5.7.4/original_data\",\n                dict(\n                    dirname=Path.cwd() / \"ephy_testing_data\" / \"neuralynx\" / \"Cheetah_v5.7.4\" / \"original_data\",\n                    seg_index=0\n                )\n            ),\n            (\n                se.NeuroscopeRecordingExtractor,\n                \"neuroscope/test1\",\n                dict(file_path=Path.cwd() / \"ephy_testing_data\" / \"neuroscope\" / \"test1\" / \"test1.dat\")\n            ),\n            # Nixio - RuntimeError: Cannot open non-existent file in ReadOnly mode!\n            # (\n            #     se.NIXIORecordingExtractor,\n            #     \"nix\",\n            #     dict(file_path=str(Path.cwd() / \"ephy_testing_data\" / \"neoraw.nix\"))\n            # ),\n            (\n                se.OpenEphysRecordingExtractor,\n                \"openephys/OpenEphys_SampleData_1\",\n                dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"openephys\" / \"OpenEphys_SampleData_1\")\n            ),\n            (\n                se.OpenEphysRecordingExtractor,\n                \"openephysbinary/v0.4.4.1_with_video_tracking\",\n                dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"openephysbinary\" / \"v0.4.4.1_with_video_tracking\")\n            ),\n            (\n                se.OpenEphysNPIXRecordingExtractor,\n                
\"openephysbinary/v0.5.3_two_neuropixels_stream\",\n                dict(\n                    folder_path=Path.cwd() / \"ephy_testing_data\" / \"openephysbinary\" / \"v0.5.3_two_neuropixels_stream\"\n                                / \"Record_Node_107\")\n            ),\n            (\n                se.NeuropixelsDatRecordingExtractor,\n                \"openephysbinary/v0.5.3_two_neuropixels_stream\",\n                dict(\n                    file_path=Path.cwd() / \"ephy_testing_data\" / \"openephysbinary\" / \"v0.5.3_two_neuropixels_stream\" /\n                              \"Record_Node_107\" / \"experiment1\" / \"recording1\" / \"continuous\" /\n                              \"Neuropix-PXI-116.0\" / \"continuous.dat\",\n                    settings_file=Path.cwd() / \"ephy_testing_data\" / \"openephysbinary\" /\n                                  \"v0.5.3_two_neuropixels_stream\" / \"Record_Node_107\" / \"settings.xml\")\n            ),\n            (\n                se.PhyRecordingExtractor,\n                \"phy/phy_example_0\",\n                dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"phy\" / \"phy_example_0\")\n            ),\n            # Plexon - AssertionError: This file have several channel groups spikeextractors support only one groups\n            # (\n            #     se.PlexonRecordingExtractor,\n            #     \"plexon\",\n            #     dict(filename=Path.cwd() / \"ephy_testing_data\" / \"plexon\" / \"File_plexon_2.plx\")\n            # ),\n            (\n                    se.CEDRecordingExtractor,\n                    \"spike2/m365_1sec.smrx\",\n                    dict(\n                        file_path=Path.cwd() / \"ephy_testing_data\" / \"spike2\" / \"m365_1sec.smrx\",\n                        smrx_channel_ids=range(10)\n                    )\n            ),\n            (\n                se.SpikeGLXRecordingExtractor,\n                \"spikeglx/Noise4Sam_g0\",\n                dict(\n                    
file_path=Path.cwd() / \"ephy_testing_data\" / \"spikeglx\" / \"Noise4Sam_g0\" / \"Noise4Sam_g0_imec0\" /\n                    \"Noise4Sam_g0_t0.imec0.ap.bin\"\n                )\n            )\n        ])\n        def test_convert_recording_extractor_to_nwb(self, se_class, dataset_path, se_kwargs):\n            print(f\"\\n\\n\\n TESTING {se_class.extractor_name}...\")\n            dataset_stem = Path(dataset_path).stem\n            self.dataset.get(dataset_path)\n            recording = se_class(**se_kwargs)\n\n            # # test writing to NWB\n            if test_nwb:\n                nwb_save_path = self.savedir / f\"{se_class.__name__}_test_{dataset_stem}.nwb\"\n                se.NwbRecordingExtractor.write_recording(recording, nwb_save_path, write_scaled=True)\n                nwb_recording = se.NwbRecordingExtractor(nwb_save_path)\n                check_recordings_equal(recording, nwb_recording, check_times=False)\n\n                if recording.has_unscaled:\n                    nwb_save_path_unscaled = self.savedir / f\"{se_class.__name__}_test_{dataset_stem}_unscaled.nwb\"\n                    if np.all(recording.get_channel_offsets() == 0):\n                        se.NwbRecordingExtractor.write_recording(recording, nwb_save_path_unscaled, write_scaled=False)\n                        nwb_recording = se.NwbRecordingExtractor(nwb_save_path_unscaled)\n                        check_recordings_equal(recording, nwb_recording, return_scaled=False, check_times=False)\n                        # Skip check when NWB converts uint to int\n                        if recording.get_dtype(return_scaled=False) == nwb_recording.get_dtype(return_scaled=False):\n                            check_recordings_equal(recording, nwb_recording, return_scaled=True, check_times=False)\n\n            # test caching\n            if test_caching:\n                rec_cache = se.CacheRecordingExtractor(recording)\n                check_recordings_equal(recording, rec_cache)\n        
        if recording.has_unscaled:\n                    rec_cache_unscaled = se.CacheRecordingExtractor(recording, return_scaled=False)\n                    check_recordings_equal(recording, rec_cache_unscaled, return_scaled=False)\n                    check_recordings_equal(recording, rec_cache_unscaled, return_scaled=True)\n\n        @parameterized.expand([\n            (\n                se.BlackrockSortingExtractor,\n                \"blackrock/blackrock_2_1\",\n                dict(\n                    filename=str(Path.cwd() / \"ephy_testing_data\" / \"blackrock\" / \"blackrock_2_1\" / \"l101210-001\"),\n                    seg_index=0,\n                    nsx_to_load=5\n                 )\n            ),\n            (\n                se.KlustaSortingExtractor,\n                \"kwik\",\n                dict(file_or_folder_path=Path.cwd() / \"ephy_testing_data\" / \"kwik\" / \"neo.kwik\")\n            ),\n            # Neuralynx - units_ids = nwbfile.units.id[:] - AttributeError: 'NoneType' object has no attribute 'id'\n            # Is the GIN data OK? 
Or are there no units?\n            # (\n            #     se.NeuralynxSortingExtractor,\n            #     \"neuralynx/Cheetah_v5.7.4/original_data\",\n            #     dict(\n            #         dirname=Path.cwd() / \"ephy_testing_data\" / \"neuralynx\" / \"Cheetah_v5.7.4\" / \"original_data\",\n            #         seg_index=0\n            #     )\n            # ),\n            # NIXIO - return [int(da.label) for da in self._spike_das]\n            # TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'\n            # (\n            #     se.NIXIOSortingExtractor,\n            #     \"nix/nixio_fr.nix\",\n            #     dict(file_path=str(Path.cwd() / \"ephy_testing_data\" / \"nix\" / \"nixio_fr.nix\"))\n            # ),\n            (\n                se.MEArecSortingExtractor,\n                \"mearec/mearec_test_10s.h5\",\n                dict(file_path=Path.cwd() / \"ephy_testing_data\" / \"mearec\" / \"mearec_test_10s.h5\")\n            ),\n            (\n                se.PhySortingExtractor,\n                \"phy/phy_example_0\",\n                dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"phy\" / \"phy_example_0\")\n            ),\n            (\n                se.PlexonSortingExtractor,\n                \"plexon\",\n                dict(filename=Path.cwd() / \"ephy_testing_data\" / \"plexon\" / \"File_plexon_2.plx\")\n            ),\n            (\n                se.SpykingCircusSortingExtractor,\n                \"spykingcircus/spykingcircus_example0\",\n                dict(\n                    file_or_folder_path=Path.cwd() / \"ephy_testing_data\" / \"spykingcircus\" / \"spykingcircus_example0\" /\n                                        \"recording\"\n                )\n            ),\n            # # Tridesclous - dataio error, GIN data is not correct?\n            # (\n            #     se.TridesclousSortingExtractor,\n            #     \"tridesclous/tdc_example0\",\n            #     
dict(folder_path=Path.cwd() / \"ephy_testing_data\" / \"tridesclous\" / \"tdc_example0\")\n            # )\n        ])\n        def test_convert_sorting_extractor_to_nwb(self, se_class, dataset_path, se_kwargs):\n            print(f\"\\n\\n\\n TESTING {se_class.extractor_name}...\")\n            dataset_stem = Path(dataset_path).stem\n            self.dataset.get(dataset_path)\n\n            sorting = se_class(**se_kwargs)\n            sf = sorting.get_sampling_frequency()\n            if sf is None:  # need to set dummy sampling frequency since no associated acquisition in file\n                sf = 30000\n                sorting.set_sampling_frequency(sf)\n\n            if test_nwb:\n                nwb_save_path = self.savedir / f\"{se_class.__name__}_test_{dataset_stem}.nwb\"\n                se.NwbSortingExtractor.write_sorting(sorting, nwb_save_path)\n                nwb_sorting = se.NwbSortingExtractor(nwb_save_path, sampling_frequency=sf)\n                check_sortings_equal(sorting, nwb_sorting)\n\n            if test_caching:\n                sort_cache = se.CacheSortingExtractor(sorting)\n                check_sortings_equal(sorting, sort_cache)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_numpy_extractors.py",
    "content": "import numpy as np\nimport unittest\nimport spikeextractors as se\n\n\nclass TestNumpyExtractors(unittest.TestCase):\n    def setUp(self):\n        M = 4\n        N = 10000\n        N_ttl = 50\n        seed = 0\n        sampling_frequency = 30000\n        X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))\n        geom = np.random.RandomState(seed=seed).normal(0, 1, (M, 2))\n        self._X = X\n        self._geom = geom\n        self._sampling_frequency = sampling_frequency\n        self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency, geom=geom)\n        self._ttl_frames = np.sort(np.random.permutation(N)[:N_ttl])\n        self.RX.set_ttls(self._ttl_frames)\n        self.SX = se.NumpySortingExtractor()\n        L = 200\n        self._train1 = np.rint(np.random.RandomState(seed=seed).uniform(0, N, L)).astype(int)\n        self.SX.add_unit(unit_id=1, times=self._train1)\n        self.SX.add_unit(unit_id=2, times=np.random.RandomState(seed=seed).uniform(0, N, L))\n        self.SX.add_unit(unit_id=3, times=np.random.RandomState(seed=seed).uniform(0, N, L))\n\n    def tearDown(self):\n        pass\n\n    def test_recording_extractor(self):\n        # get_channel_ids\n        self.assertEqual(self.RX.get_channel_ids(), [i for i in range(self._X.shape[0])])\n        # get_num_channels\n        self.assertEqual(self.RX.get_num_channels(), self._X.shape[0])\n        # get_num_frames\n        self.assertEqual(self.RX.get_num_frames(), self._X.shape[1])\n        # get_sampling_frequency\n        self.assertEqual(self.RX.get_sampling_frequency(), self._sampling_frequency)\n        # get_traces\n        self.assertTrue(np.allclose(self.RX.get_traces(), self._X))\n        self.assertTrue(\n            np.allclose(self.RX.get_traces(channel_ids=[0, 3], start_frame=0, end_frame=12), self._X[[0, 3], 0:12]))\n        # get_channel_property - location\n        
self.assertTrue(np.allclose(np.array(self.RX.get_channel_locations(1)), self._geom[1, :]))\n        # time_to_frame / frame_to_time\n        self.assertEqual(self.RX.time_to_frame(12), 12 * self.RX.get_sampling_frequency())\n        self.assertEqual(self.RX.frame_to_time(12), 12 / self.RX.get_sampling_frequency())\n        # get_snippets\n        snippets = self.RX.get_snippets(reference_frames=[0, 30, 50], snippet_len=20)\n        self.assertTrue(np.allclose(snippets[1], self._X[:, 20:40]))\n        # get_ttl_events\n        self.assertTrue(np.array_equal(self.RX.get_ttl_events()[0], self._ttl_frames))\n\n    def test_sorting_extractor(self):\n        unit_ids = [1, 2, 3]\n        # get_unit_ids\n        self.assertEqual(self.SX.get_unit_ids(), unit_ids)\n        # get_unit_spike_train\n        st = self.SX.get_unit_spike_train(unit_id=1)\n        self.assertTrue(np.allclose(st, self._train1))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_tools.py",
    "content": "import numpy as np\nimport unittest\nimport tempfile\nimport shutil\nimport spikeextractors as se\nimport os\nfrom copy import copy\nfrom pathlib import Path\n\nthis_file = Path(__file__).parent\n\n\nclass TestTools(unittest.TestCase):\n    def setUp(self):\n        M = 32\n        N = 10000\n        seed = 0\n        sampling_frequency = 30000\n        X = np.random.RandomState(seed=seed).normal(0, 1, (M, N))\n        self._X = X\n        self._sampling_frequency = sampling_frequency\n        self.RX = se.NumpyRecordingExtractor(timeseries=X, sampling_frequency=sampling_frequency)\n        self.RX.set_channel_locations(np.random.randn(32, 3))\n        self.test_dir = Path(tempfile.mkdtemp())\n\n    def tearDown(self):\n        shutil.rmtree(self.test_dir)\n\n    def test_load_save_probes(self):\n        sub_RX = se.load_probe_file(self.RX, this_file / 'probe_test.prb')\n        # print(SX.get_channel_property_names())\n        assert 'location' in sub_RX.get_shared_channel_property_names()\n        assert 'group' in sub_RX.get_shared_channel_property_names()\n        positions = [sub_RX.get_channel_locations(chan)[0] for chan in range(self.RX.get_num_channels())]\n        # save in csv\n        sub_RX.save_to_probe_file(self.test_dir / 'geom.csv')\n        # load csv locations\n        sub_RX_load = sub_RX.load_probe_file(self.test_dir / 'geom.csv')\n        position_loaded = [sub_RX_load.get_channel_locations(chan)[0] for\n                           chan in range(sub_RX_load.get_num_channels())]\n        self.assertTrue(np.allclose(positions[10], position_loaded[10]))\n\n        # prb file\n        RX = copy(self.RX)\n        channel_groups = []\n        n_group = 4\n        for i in RX.get_channel_ids():\n            channel_groups.append(i // n_group)\n        RX.set_channel_groups(channel_groups)\n        RX.save_to_probe_file(this_file / 'probe_test_no_groups.prb')\n        RX.save_to_probe_file(this_file / 'probe_test_groups.prb', 
grouping_property='group')\n\n        # load\n        RX_loaded_no_groups = se.load_probe_file(RX, this_file / 'probe_test_no_groups.prb')\n        RX_loaded_groups = se.load_probe_file(RX, this_file / 'probe_test_groups.prb')\n\n        assert len(np.unique(RX_loaded_no_groups.get_channel_groups())) == 1\n        assert len(np.unique(RX_loaded_groups.get_channel_groups())) == RX.get_num_channels() // n_group\n\n        # cleanup\n        (this_file / 'probe_test_no_groups.prb').unlink()\n        (this_file / 'probe_test_groups.prb').unlink()\n\n    def test_write_dat_file(self):\n        nb_sample = self.RX.get_num_frames()\n        nb_chan = self.RX.get_num_channels()\n\n        # time_axis=0 chunk_size=None\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0, dtype='float32', chunk_size=None)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n        # time_axis=1 chunk_size=None\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1, dtype='float32', chunk_size=None)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_chan, nb_sample))\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n        # time_axis=0 chunk_size=99\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0, dtype='float32', chunk_size=99)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T\n        assert np.allclose(data, self.RX.get_traces())\n        del(data) # this closes the file\n\n        # time_axis=0 chunk_mb=2\n        self.RX.write_to_binary_dat_format(self.test_dir /  'rec.dat', time_axis=0, dtype='float32', chunk_mb=2)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', 
mode='r', shape=(nb_sample, nb_chan)).T\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n        # time_axis=1 chunk_mb=2\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1, dtype='float32', chunk_mb=2)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_chan, nb_sample))\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n        # time_axis=0 chunk_mb=10, n_jobs=2\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=0,\n                                           dtype='float32', chunk_mb=10, n_jobs=2)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_sample, nb_chan)).T\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n        # time_axis=1 chunk_mb=2 n_jobs=2\n        self.RX.write_to_binary_dat_format(self.test_dir / 'rec.dat', time_axis=1,\n                                           dtype='float32', chunk_mb=2, n_jobs=2)\n        data = np.memmap(self.test_dir / 'rec.dat', dtype='float32', mode='r', shape=(nb_chan, nb_sample))\n        assert np.allclose(data, self.RX.get_traces())\n        del data  # this closes the file\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  }
]