Repository: google/weather-tools Branch: main Commit: cc9d5b2f2bae Files: 238 Total size: 183.3 MB Directory structure: gitextract_haxxml5e/ ├── .github/ │ └── workflows/ │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── .read-the-docs.yaml ├── CONTRIBUTING.md ├── Configuration.md ├── Dockerfile ├── Efficient-Requests.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── Runners.md ├── Runtime-Container.md ├── bin/ │ ├── install-branch │ └── post-push ├── ci3.8.yml ├── ci3.9.yml ├── configs/ │ ├── era5_example_config.cfg │ ├── era5_example_config_local_run.cfg │ ├── era5_example_config_local_run.json │ ├── era5_example_config_preproc.cfg │ ├── era5_example_config_using_date.cfg │ ├── era5_example_monthly_soil.cfg │ ├── mars_example_config.cfg │ ├── mars_example_config.json │ ├── multiple_multiple_licenses/ │ │ ├── era5_pressure500.cfg │ │ ├── era5_pressure600.cfg │ │ └── era5_pressure700.cfg │ ├── multiple_single_license/ │ │ ├── era5_pressure500.cfg │ │ ├── era5_pressure600.cfg │ │ └── era5_pressure700.cfg │ ├── s2s_operational_forecast_example.cfg │ ├── seasonal_forecast_example_config.cfg │ ├── tigge_example_config.cfg │ └── yesterdays_surface_example.cfg ├── docs/ │ ├── Makefile │ ├── Private-IP-Configuration.md │ ├── _static/ │ │ └── custom.css │ ├── conf.py │ ├── download_pipeline.md │ ├── index.md │ ├── loader_pipeline.md │ ├── make.bat │ ├── modules.md │ ├── requirements.txt │ └── splitter_pipeline.md ├── environment.yml ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tox.ini ├── weather_dl/ │ ├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── download_pipeline/ │ │ ├── __init__.py │ │ ├── clients.py │ │ ├── clients_test.py │ │ ├── config.py │ │ ├── config_test.py │ │ ├── fetcher.py │ │ ├── fetcher_test.py │ │ ├── manifest.py │ │ ├── manifest_test.py │ │ ├── parsers.py │ │ ├── parsers_test.py │ │ ├── partition.py │ │ ├── partition_test.py │ │ ├── pipeline.py │ │ ├── pipeline_test.py │ │ ├── stores.py │ │ ├── stores_test.py │ │ ├── util.py │ │ └── util_test.py │ ├── setup.py │ └── weather-dl ├── weather_dl_v2/ │ ├── README.md │ ├── __init__.py │ ├── cli/ │ │ ├── CLI-Documentation.md │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── VERSION.txt │ │ ├── app/ │ │ │ ├── __init__.py │ │ │ ├── cli_config.py │ │ │ ├── data/ │ │ │ │ └── cli_config.json │ │ │ ├── main.py │ │ │ ├── services/ │ │ │ │ ├── __init__.py │ │ │ │ ├── download_service.py │ │ │ │ ├── license_service.py │ │ │ │ ├── network_service.py │ │ │ │ └── queue_service.py │ │ │ ├── subcommands/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ ├── download.py │ │ │ │ ├── license.py │ │ │ │ └── queue.py │ │ │ └── utils.py │ │ ├── environment.yml │ │ ├── setup.py │ │ └── vm-startup.sh │ ├── cloudbuild.yml │ ├── config.json │ ├── downloader_kubernetes/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── VERSION.txt │ │ ├── downloader.py │ │ ├── downloader_config.py │ │ ├── environment.yml │ │ ├── manifest.py │ │ └── util.py │ ├── fastapi-server/ │ │ ├── API-Interactions.md │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── VERSION.txt │ │ ├── __init__.py │ │ ├── config_processing/ │ │ │ ├── config.py │ │ │ ├── manifest.py │ │ │ ├── parsers.py │ │ │ ├── partition.py │ │ │ ├── pipeline.py │ │ │ ├── stores.py │ │ │ └── util.py │ │ ├── database/ │ │ │ ├── __init__.py │ │ │ ├── download_handler.py │ │ │ ├── license_handler.py │ │ │ ├── manifest_handler.py │ │ │ ├── queue_handler.py │ │ │ ├── session.py │ │ │ └── storage_handler.py │ │ ├── environment.yml │ │ ├── example.cfg │ │ ├── license_dep/ │ │ │ ├── deployment_creator.py │ │ │ └── 
license_deployment.yaml │ │ ├── logging.conf │ │ ├── main.py │ │ ├── routers/ │ │ │ ├── download.py │ │ │ ├── license.py │ │ │ └── queues.py │ │ ├── server.yaml │ │ ├── server_config.py │ │ └── tests/ │ │ ├── __init__.py │ │ ├── integration/ │ │ │ ├── __init__.py │ │ │ ├── test_download.py │ │ │ ├── test_license.py │ │ │ └── test_queues.py │ │ └── test_data/ │ │ ├── example.cfg │ │ └── not_exist.cfg │ └── license_deployment/ │ ├── Dockerfile │ ├── README.md │ ├── VERSION.txt │ ├── __init__.py │ ├── clients.py │ ├── config.py │ ├── database.py │ ├── deployment_config.py │ ├── downloader.yaml │ ├── environment.yml │ ├── fetch.py │ ├── job_creator.py │ ├── manifest.py │ └── util.py ├── weather_mv/ │ ├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── loader_pipeline/ │ │ ├── __init__.py │ │ ├── bq.py │ │ ├── bq_test.py │ │ ├── ee.py │ │ ├── ee_test.py │ │ ├── execution_test.py │ │ ├── metrics.py │ │ ├── pipeline.py │ │ ├── pipeline_test.py │ │ ├── regrid.py │ │ ├── regrid_test.py │ │ ├── sinks.py │ │ ├── sinks_test.py │ │ ├── streaming.py │ │ ├── streaming_test.py │ │ ├── util.py │ │ └── util_test.py │ ├── setup.py │ ├── test_data/ │ │ ├── test_data.zarr/ │ │ │ ├── .zattrs │ │ │ ├── .zgroup │ │ │ ├── .zmetadata │ │ │ ├── cape/ │ │ │ │ ├── .zarray │ │ │ │ ├── .zattrs │ │ │ │ └── 0.0.0 │ │ │ ├── d2m/ │ │ │ │ ├── .zarray │ │ │ │ ├── .zattrs │ │ │ │ └── 0.0.0 │ │ │ ├── latitude/ │ │ │ │ ├── .zarray │ │ │ │ ├── .zattrs │ │ │ │ └── 0 │ │ │ ├── longitude/ │ │ │ │ ├── .zarray │ │ │ │ ├── .zattrs │ │ │ │ └── 0 │ │ │ └── time/ │ │ │ ├── .zarray │ │ │ ├── .zattrs │ │ │ └── 0 │ │ ├── test_data_20180101.nc │ │ ├── test_data_corrupt_grib │ │ ├── test_data_grib_multiple_edition_single_timestep.bz2 │ │ ├── test_data_grib_single_timestep │ │ ├── test_data_has_nan.nc │ │ ├── test_data_single_point.nc │ │ └── test_data_tif_time.tif │ └── weather-mv ├── weather_sp/ │ ├── MANIFEST.in │ ├── README.md │ ├── __init__.py │ ├── setup.py │ ├── splitter_pipeline/ │ │ ├── __init__.py │ │ ├── file_name_utils.py │ │ ├── file_name_utils_test.py │ │ ├── file_splitters.py │ │ ├── file_splitters_test.py │ │ ├── pipeline.py │ │ ├── pipeline_test.py │ │ └── streaming.py │ ├── test_data/ │ │ ├── era5_sample.grib │ │ ├── era5_sample.nc │ │ └── era5_sample_grib │ └── weather-sp └── xql/ ├── README.md ├── main.py ├── setup.py └── src/ ├── __init__.py ├── weather_lm/ │ ├── __init__.py │ ├── constant.py │ ├── gemini.py │ ├── template.py │ └── utils.py └── xql/ ├── __init__.py ├── apply.py ├── constant.py ├── open.py ├── utils.py └── where.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/ci.yml ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
name: CI on: # Triggers the workflow on push or pull request events but only for the main branch push: branches: [ main ] pull_request: branches: [ main ] # Allows you to run this workflow manually from the Actions tab workflow_dispatch: env: CDSAPI_URL: https://cds.climate.copernicus.eu/api CDSAPI_KEY: 1234567-ab12-34cd-9876-4o4fake90909 # A fake key for testing jobs: build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.8", "3.9"] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 with: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v2 - name: conda cache uses: actions/cache@v3 env: # Increase this value to reset cache if etc/example-environment.yml has not changed CACHE_NUMBER: 0 with: path: ~/conda_pkgs_dir key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ matrix.python-version }}-${{ hashFiles('ci3.8.yml') }} - name: Setup conda environment uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python-version }} channels: conda-forge environment-file: ci${{ matrix.python-version}}.yml activate-environment: weather-tools miniforge-variant: Miniforge3 miniforge-version: latest use-mamba: true - name: Check MetView's installation shell: bash -l {0} run: python -m metview selfcheck - name: Run unit tests shell: bash -l {0} run: pytest --memray --ignore=weather_dl_v2 # Ignoring dl-v2 as it only supports py3.10 lint: runs-on: ubuntu-latest strategy: fail-fast: false steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 with: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir id: pip-cache run: | python -m pip install --upgrade pip wheel echo "dir=$(pip cache dir)" >> "$GITHUB_OUTPUT" - name: Install linter run: | pip install ruff==0.1.2 - name: Lint project run: ruff check . type-check: runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.8"] steps: - name: Cancel previous uses: styfle/cancel-workflow-action@0.7.0 with: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v2 - name: conda cache uses: actions/cache@v3 env: # Increase this value to reset cache if etc/example-environment.yml has not changed CACHE_NUMBER: 0 with: path: ~/conda_pkgs_dir key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ matrix.python-version }}-${{ hashFiles('ci3.8.yml') }} - name: Setup conda environment uses: conda-incubator/setup-miniconda@v3 with: python-version: ${{ matrix.python-version }} channels: conda-forge environment-file: ci${{ matrix.python-version}}.yml activate-environment: weather-tools miniforge-variant: Miniforge3 miniforge-version: latest use-mamba: true - name: Install weather-tools[test] run: | conda run -n weather-tools pip install -e .[test] --use-deprecated=legacy-resolver - name: Run type checker run: conda run -n weather-tools pytype ================================================ FILE: .github/workflows/publish.yml ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. name: Build and Upload weather-tools to PyPI on: release: types: [published] jobs: build-artifacts: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2.3.1 with: python-version: 3.8 - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install setuptools setuptools-scm wheel twine check-manifest - name: Build tarball and wheels run: | git clean -xdf git restore -SW . python -m build --sdist --wheel . - name: Check built artifacts run: | python -m twine check dist/* pwd if [ -f dist/weather_tools-0.0.0.tar.gz ]; then echo "❌ INVALID VERSION NUMBER" exit 1 else echo "✅ Looks good" fi - uses: actions/upload-artifact@v2 with: name: releases path: dist test-built-dist: needs: build-artifacts runs-on: ubuntu-latest steps: - uses: actions/setup-python@v2.3.1 name: Install Python with: python-version: 3.8 - uses: actions/download-artifact@v4.1.7 with: name: releases path: dist - name: List contents of built dist run: | ls -ltrh ls -ltrh dist - name: Publish package to TestPyPI if: github.event_name == 'push' uses: pypa/gh-action-pypi-publish@v1.4.2 with: user: __token__ password: ${{ secrets.TESTPYPI_TOKEN }} repository_url: https://test.pypi.org/legacy/ verbose: true - name: Check uploaded package if: github.event_name == 'push' run: | sleep 3 python -m pip install --upgrade pip python -m pip install --extra-index-url https://test.pypi.org/simple --upgrade weather-tools weather-dl --help && weather-mv --help && weather-sp --help upload-to-pypi: needs: test-built-dist if: github.event_name == 'release' runs-on: ubuntu-latest steps: - uses: actions/download-artifact@v4.1.7 with: name: releases path: dist - name: Publish package to PyPI uses: pypa/gh-action-pypi-publish@v1.4.2 with: user: __token__ password: ${{ secrets.PYPI_TOKEN }} verbose: true ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # JetBrains / PyCharm .idea/ # VSCode .vscode/ launch.json ================================================ FILE: .read-the-docs.yaml ================================================ # .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 build: image: latest # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optionally set the version of Python and requirements required to build your docs python: version: "3.8" install: - requirements: docs/requirements.txt - method: pip path: . system_packages: false ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code Reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Pull Request Etiquette We're thrilled to have your pull request! However, we ask that you file an issue (or check if one exist) before submitting a change-list (CL). 
In general, we're happier reviewing many small PRs over fewer, larger changes. Please try to submit [small CLs](https://google.github.io/eng-practices/review/developer/small-cls.html). ## Project Structure ``` bin/ # Tools for development configs/ # Save example and general purpose configs. docs/ # Sphinx documentation, deployed to Github pages (or readthedocs). weather_dl/ # Main python package, other pipelines can come later. __init__.py download_pipeline/ # sources for pipeline __init__.py ... .py _test.py # We'll follow this naming pattern for tests. weather-dl # script to run pipeline setup.py # packages sources for execution in beam worker # pipeline-specific requirements managed here setup.py # Project is pip-installable, project requirements managed here. ``` ## Developer Installation 1. Set up a Python development environment with *Anaconda* or [*Miniconda*](https://docs.conda.io/en/latest/miniconda.html). * On a Mac, you can run `brew install miniconda`. * Use Python version 3.8.5+ for development. We have not yet upgraded the project to use Python 3.9 (see [#166](https://github.com/google/weather-tools/issues/166)). 3. Clone the repo and install dependencies with anaconda. ```shell # clone with HTTPS git clone http://github.com/google/weather-tools.git # clone with SSH git clone git@github.com:google/weather-tools.git cd weather-tools conda env create -f environment.yml conda activate weather-tools pip install -e ".[dev]" ``` 4. Install `gcloud`, the Google Cloud CLI, following these [instructions](https://cloud.google.com/sdk/docs/install). 5. Acquire adequate permissions for the project. * Run `gcloud auth application-default login`. * Make sure your account has *write* permissions to the storage bucket as well as the permissions to create a Dataflow job. * Make sure that both Dataflow and Cloud Storage are enabled on your Google Cloud Platform project. ### Windows Developer Instructions Windows support for each CLI is currently under development ( See [#64](https://github.com/google/weather-tools/issues/64)). However, there are workarounds available for running the nweather tools outside of installation with `pip`. First, the would-be pip-installed script can be run directly with python like so: ```shell python weather_dl/weather-dl --help ``` ## Testing For testing at development time, we make use of three tools: * `pytype` for checking types: ```shell # check everything pytype # check a specific tool pytype weather_dl ``` * `ruff` for linting: ```shell # lint everything ruff check . # lint a specific tool ruff check weather_mv ``` * `pytest` for running tests: ```shell # test everything pytest # test a specific tool pytest weather_sp ``` If you'd like to automate running these checks, we provide a post-push git hook: ```shell cp bin/post-push .git/hooks/ ``` This script can be run manually, too. If you'd like to locally run these checks for all versions of python this project supports, you can use `tox`. Tox will make use of the python versions installed on your machine and create virtual test environments on your behalf. ```shell tox ``` In addition, we provide a simple script to install _other_ branches locally. Run `bin/install-branch ` to pip install that branches working copy of weather-tools. Hopefully, this script facilitates testing of work-in-progress contributions. Please review the [Beam testing docs](https://beam.apache.org/documentation/pipelines/test-your-pipeline/) for guidance in how to write tests for the pipeline. 
## Documentation Documents are generated with Sphinx and the myst-parser. To generate the documents locally, simply invoke `make`: ```shell cd docs rm -r _build make html ``` > Note: Due to the idiosyncrasies of how Sphinx detects updates and our use of symbolic links, we recommend deleting the > `_build` folder. Or, you can run the following subshell command to re-generate everything without having to leave the project root: ```shell (cd docs && rm -r _build && make html) ``` After the docs are re-generated, you can view them by starting a local file server, for example: ```shell python -m http.server -d docs/_build/html ``` ## Versions & Releasing We aim to represent the version of each tool using [semver.org](https://semver.org/) semantic versions. To that end, we will abide by the following pattern: - When making a change to a particular tool, please remember to update its semantic version in the `setup.py` file. - The version for _all the tools_ should be incremented on update to _any_ tool. If one tool changes, the whole `google-weather-tools` package should have its version incremented. - The representation for the version of all of `google-weather-tools` is via [git tags](https://git-scm.com/book/en/v2/Git-Basics-Tagging). To update the version of the package, please make an annotated tag with a short description of the change. - When it's time for release, choose the latest tag and fill out the release description. Check out previous release notes to get an idea of how to structure the next one. These release notes are the primary changelog for the project. ================================================ FILE: Configuration.md ================================================ # Configuration Files Config files describe both _what_ to download and _how_ it should be downloaded. To this end, configs have two sections: `selection` that describes the data desired from data-sources and `parameters` that define the details of the download. By convention, the `parameters` section comes first. Configuration files can be written in `*.cfg` or `*.json`, but typically, they're written in the former format (i.e. Python's native config language, which is similar to INI format). Before jumping into the details of each section, let's look at a few example configs. ## Examples The following demonstrate how to download weather data from ECMWF's Copernicus (CDS) and Meteorological Archival and Retrieval System (MARS) catalogues. ### Download Era5 Pressure Level Reanalysis from Copernicus ``` [parameters] client=cds ; choose a data source client dataset=reanalysis-era5-pressure-levels ; specify a dataset to download from the data source (CDS-specific) target_path=gs://ecmwf-output-test/era5/{}/{}/{}-pressure-{}.nc ; create a template for the output file path partition_keys= ; define how we should partition the download by the "keys" in the `selection` section. year ; See docs below for more explanation month day pressure_level [selection] product_type=ensemble_mean format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 year= ; we can specify a list of values using multiple lines 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ``` ### Download Yesterday's Surface Temperatures from MARS. ``` [parameters] client=mars ; download from MARS (this is the default data source). target_path=all.an ; Download all the data from the selection section into one file. 
[selection] class = od type = analysis levtype = surface date = -1 ; Yesterday -- see link to MARS request syntax in docs linked below. time = 00/06/12/18 ; We can specify multiple values using `/` delimiters -- see MARS request syntax docs param = z/sp ``` ## `parameters` Section _Parameters for the pipeline_ These describe which data source to download, where the data should live, and how the download should be partitioned. * `client`: (required) Select the weather API client. Supported values are `cds` for Copernicus, and `mars` for MARS. * `dataset`: (optional) Name of the target dataset. Allowed options are dictated by the client. * `target_path`: (required) Download artifact filename template. Can use Python string format symbols. Must have the same number of format symbols as the number of partition keys. * `partition_keys`: (optional) This determines how download jobs will be divided. * Value can be a single item or a list. * Each value must appear as a key in the `selection` section. * Each downloader will receive a config file with every parameter listed in the `selection`, _except_ for the fields specified by the `partition_keys`. * The downloader config will contain one instance of the cross-product of every key in `partition_keys`. * E.g. `['year', 'month']` will lead to a config set like `[(2015, 01), (2015, 02), (2015, 03), ...]`. * The list of keys will be used to format the `target_path`. > **NOTE**: `target_path` template is totally compatible with Python's standard string formatting. > This includes being able to use named arguments (e.g. 'gs://bucket/{year}/{month}/{day}.nc') as well as specifying formats for strings > (e.g. 'gs://bucket/{year:04d}/{month:02d}/{day:02d}.nc'). ### Creating a date-based directory hierarchy The date-based directory hierarchy can be created using Python's standard string formatting. Below are some examples of how to use `target_path` with Python's standard string formatting.
Examples Note that any parameters that are not relevant to the target path have been omitted. ``` [parameters] target_path=gs://ecmwf-output-test/era5/{date:%%Y/%%m/%%d}.nc partition_keys= date [selection] date=2017-01-01/to/2017-01-02 ``` will create `gs://ecmwf-output-test/era5/2017/01/01.nc` and `gs://ecmwf-output-test/era5/2017/01/02.nc`. ``` [parameters] target_path=gs://ecmwf-output-test/era5/{date:%%Y/%%m/%%d}-pressure-{pressure_level}.nc partition_keys= date pressure_level [selection] pressure_level= 500 date=2017-01-01/to/2017-01-02 ``` will create `gs://ecmwf-output-test/era5/2017/01/01-pressure-500.nc` and `gs://ecmwf-output-test/era5/2017/01/02-pressure-500.nc`. ``` [parameters] target_path=gs://ecmwf-output-test/pressure-{pressure_level}/era5/{date:%%Y/%%m/%%d}.nc partition_keys= date pressure_level [selection] pressure_level= 500 date=2017-01-01/to/2017-01-02 ``` will create `gs://ecmwf-output-test/pressure-500/era5/2017/01/01.nc` and `gs://ecmwf-output-test/pressure-500/era5/2017/01/02.nc`. ``` [parameters] target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level [selection] pressure_level= 500 year= 2017 month= 01 day= 01 02 ``` will create `gs://ecmwf-output-test/era5/2017/01/01-pressure-500.nc` and `gs://ecmwf-output-test/era5/2017/01/02-pressure-500.nc`. > **Note**: Replacing the `target_path` of the above example with this `target_path=gs://ecmwf-output-test/era5/{year}/{month}/{day}-pressure- >{pressure_level}.nc` > > will create > > `gs://ecmwf-output-test/era5/2017/1/1-pressure-500.nc` and > `gs://ecmwf-output-test/era5/2017/1/2-pressure-500.nc`. In addition to the above, the table below presents further partitioning examples based on recent enhancements to the weather-dl tool, particularly around date and related keywords.
| Keyword | Description | Sample Config | Output Partitions |
|---|---|---|---|
| `date` | Specifies the calendar date(s) for which data is retrieved in an ECMWF MARS request. | `[parameters]`<br>`target_path={date}.nc`<br>`partition_keys=date`<br>`[selection]`<br>`date=2017-01-01/to/2017-01-02` | `2017-01-01.nc`<br>`2017-01-02.nc` |
| `date` [`step`] [`time`] [`var`] | Along with specifying `date` in the ECMWF MARS request, users can also partition data using one or more of forecast-step, initialization-time, or variable. | **Case 1**<br>`[parameters]`<br>`target_path={date}_{time}.nc`<br>`partition_keys=`<br>`  date`<br>`  time`<br>`[selection]`<br>`date=2025-01-01/to/2025-01-08/by/5`<br>`time=00/12`<br>**Case 2**<br>`[parameters]`<br>`target_path={date}_{step}.nc`<br>`partition_keys=`<br>`  date`<br>`  step`<br>`[selection]`<br>`date=2024-12-31/to/2024-12-26/by/-3`<br>`step=0` | **Case 1**<br>`2025-01-01_00:00:00.nc`<br>`2025-01-01_12:00:00.nc`<br>`2025-01-06_00:00:00.nc`<br>`2025-01-06_12:00:00.nc`<br>**Case 2**<br>`2024-12-31_0.nc`<br>`2024-12-28_0.nc` |
| `year`, `month`, `day` | Specifies dates in a decomposed form (separate fields) for ECMWF MARS requests, enabling flexible selection across ranges and combinations. Supports multiple year and month inputs, allowing users to define broad time spans without enumerating each period. | **Case 1**<br>`[parameters]`<br>`target_path={year}/{year}-{month:02d}.grb2`<br>`partition_keys=`<br>`  year`<br>`  month`<br>`[selection]`<br>`year=2021`<br>`month=1/to/2`<br>`day=all`<br>**Case 2: Full year**<br>`[parameters]`<br>`target_path=full_{year}.nc`<br>`partition_keys=year`<br>`[selection]`<br>`year=2001/to/2002`<br>`month=1/to/12`<br>`day=all`<br>**Case 3: Odd months**<br>`[parameters]`<br>`target_path=odd_{year}.nc`<br>`partition_keys=year`<br>`[selection]`<br>`year=2001/to/2002`<br>`month=1/to/12/by/2`<br>`day=all`<br>**Case 4: Specific days**<br>`[parameters]`<br>`target_path=misc_{year}.nc`<br>`partition_keys=year`<br>`[selection]`<br>`year=2001/to/2002`<br>`month=1/to/12`<br>`day=1/5/10/15` | **Case 1**<br>`2021/2021-01.grb2`<br>`2021/2021-02.grb2`<br>**Case 2**<br>`full_2001.nc`<br>`full_2002.nc`<br>**Case 3**<br>`odd_2001.nc`<br>`odd_2002.nc`<br>**Case 4**<br>`misc_2001.nc`<br>`misc_2002.nc` |
| `year-month` | Added support for a `year-month` key, allowing users to specify downloads at month granularity instead of using ranged formats. | `[parameters]`<br>`target_path={year-month}.gb`<br>`partition_keys=year-month`<br>`[selection]`<br>`year-month=2024-11/to/2025-02` | `2024-11.gb`<br>`2024-12.gb`<br>`2025-01.gb`<br>`2025-02.gb` |
| `date_range` | Added support for specifying one or more date-range values, enabling users to download data across multiple date intervals in a single run.<br>**Note**: `date_range` must be specified in `partition_keys`. | `[parameters]`<br>`target_path={date_range}.nc`<br>`partition_keys=date_range`<br>`[selection]`<br>`date_range=`<br>`  2017-01-01/to/2017-01-10`<br>`  2017-01-21/to/2017-01-31` | `2017-01-01_to_2017-01-10.nc`<br>`2017-01-21_to_2017-01-31.nc` |
| `hdate` | This parameter allows weather-dl to explicitly specify historical target dates for downloads, giving users precise control over which past dates are retrieved. | `[parameters]`<br>`target_path={date}.gb`<br>`[selection]`<br>`date=2020-01-02`<br>`hdate=1/to/3` | `2019-01-02`<br>`2018-01-02`<br>`2017-01-02` |
| (no partition keys specified) | Introduced support for creating a single output file when no `partition_keys` are specified, ensuring that all data is written into one consolidated file. | `[parameters]`<br>`target_path=data.nc`<br>`[selection]`<br>`date=2017-01-01/to/2017-01-05`<br>`time=00/06/12/18` | `data.nc` |
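To make the partitioning behaviour in these examples concrete, here is a minimal sketch of how a cross-product of `partition_keys` values can be expanded and used to fill a `target_path` template. This is not weather-dl's actual implementation; the selection values and bucket path below are hypothetical.

```python
import itertools

# Hypothetical selection and parameters, mirroring the examples above.
selection = {
    "year": ["2015", "2016"],
    "month": ["01"],
    "day": ["01", "15"],
}
partition_keys = ["year", "month", "day"]
target_path = "gs://bucket/era5/{year}/{month}/{day}.nc"

# One download request is issued per element of the cross-product of the
# partition-key values; the same element formats the output path.
for values in itertools.product(*(selection[k] for k in partition_keys)):
    partition = dict(zip(partition_keys, values))
    print(target_path.format(**partition))
# gs://bucket/era5/2015/01/01.nc
# gs://bucket/era5/2015/01/15.nc
# gs://bucket/era5/2016/01/01.nc
# gs://bucket/era5/2016/01/15.nc
```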
### Subsections Sometimes, we'd like to alternate passing certain parameters to each client. For example, certain data sources have limits on the number of API requests that can be made, enforcing a maximum per license. In these cases, the user can specify a parameters subsection. The downloader will overwrite the base parameters with the key-value pairs in each subsection, evenly alternating between each parameter set across the partitions. To specify a subsection, create a new section with the following naming pattern: `[parameters.]`. The `` can be any string, but it's recommended to chose a name that describes the grouping of values in the section. Here's an example of this type of configuration: ``` [parameters] dataset=ecmwf-mars-output target_template=gs://ecmwf-downloads/hres-single-level/{}.nc partition_keys= date [parameters.deepmind] api_key=KKKKK1 api_url=UUUUU1 [parameters.research] api_key=KKKKK2 api_url=UUUUU2 [parameters.cloud] api_key=KKKKK3 api_url=UUUUU3 ``` ## `selection` Section _Parameters used to select desired data_ These will be passed as request parameters to the specified API client. Selections are dependent on how each data source's catalog is structured. ### Copernicus / CDS **License**: By using Copernicus / CDS Dataset, users agree to the terms and conditions specified in dataset. i.e. [License](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-single-levels?tab=download#manage-licences). **Catalog**: [https://cds.climate.copernicus.eu/datasets](https://cds.climate.copernicus.eu/datasets) Visit the follow to register / acquire API credentials: _[Install the CDS API key](https://cds.climate.copernicus.eu/how-to-api)_. After, please set the `api_url` and `api_key` arguments in the `parameters` section of your configuration. Alternatively, one can set these values as environment variables: ```shell export CDSAPI_URL=$api_url export CDSAPI_KEY=$api_key ``` For CDS parameter options, check out the [Copernicus documentation](https://cds.climate.copernicus.eu/datasets). See [this example](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-pressure-levels?tab=overview) for what kind of requests one can make. ### MARS **License**: By using MARS Dataset, users agree to the terms and conditions specified in [License](https://www.ecmwf.int/en/forecasts/accessing-forecasts/licences-available) document. **Catalog**: [https://apps.ecmwf.int/archive-catalogue/](https://apps.ecmwf.int/archive-catalogue/) Visit the following to register / acquire API credentials: _[Install ECWMF Key](https://confluence.ecmwf.int/display/WEBAPI/Access+MARS#AccessMARS-key)_. After, please set the `api_url`, `api_key`, and `api_email` arguments in the `parameters` section of your configuration. Alternatively, one can set these values as environment variables: ```shell export MARSAPI_URL=$api_url export MARSAPI_EMAIL=$api_email export MARSAPI_KEY=$api_key ``` For MARS parameter options, first read up on [MARS request syntax](https://confluence.ecmwf.int/display/WEBAPI/Brief+MARS+request+syntax). For a full range of what data can be requested, please consult the [MARS catalog](https://apps.ecmwf.int/archive-catalogue/). See [these examples](https://confluence.ecmwf.int/display/UDOC/MARS+example+requests) to discover the kinds of requests that can be made. > **NOTE**: MARS data is stored on tape drives. It takes longer for multiple workers to request data than a single > worker. Thus, it's recommended _not_ to set a partition key when writing MARS data configurations. 
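Because these configuration files use Python's native config syntax, a short sketch of how a `.cfg` with multi-line values parses under the standard `configparser` module may help when authoring new configs. The snippet is illustrative only, assumes a trimmed-down hypothetical config, and is not how weather-dl itself reads configs.

```python
import configparser

# A hypothetical, trimmed-down config in the same format as the examples above.
cfg_text = """
[parameters]
client=cds
target_path=gs://bucket/era5/{year}.nc
partition_keys=
    year

[selection]
year=
    2015
    2016
"""

parser = configparser.ConfigParser()
parser.read_string(cfg_text)

# Indented continuation lines arrive as a newline-separated string; splitting
# the value recovers the list of entries a key expands to.
print(parser.get("selection", "year").split())          # ['2015', '2016']
print(parser.get("parameters", "partition_keys").split())  # ['year']
```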
================================================ FILE: Dockerfile ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== ARG py_version=3.8 FROM apache/beam_python${py_version}_sdk:2.40.0 as beam_sdk FROM continuumio/miniconda3:v25.11.1 # Add the mamba solver for faster builds RUN conda install -n base conda-libmamba-solver RUN conda config --set solver libmamba # Create conda env using environment.yml ARG weather_tools_git_rev=main RUN git clone https://github.com/google/weather-tools.git /weather WORKDIR /weather RUN git checkout "${weather_tools_git_rev}" RUN rm -r /weather/weather_*/test_data/ RUN conda env create -f environment.yml --debug && \ conda clean --all -f --yes # Activate the conda env and update the PATH ARG CONDA_ENV_NAME=weather-tools RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH # Install gcloud alpha RUN apt-get update -y RUN gcloud components install alpha --quiet # Copy files from official SDK image, including script/dependencies. COPY --from=beam_sdk /opt/apache/beam /opt/apache/beam # Set the entrypoint to Apache Beam SDK launcher. ENTRYPOINT ["/opt/apache/beam/boot"] ================================================ FILE: Efficient-Requests.md ================================================ # Writing Efficient Data Requests TODO([#26](https://github.com/googlestaging/weather-tools/issues/26)). In the mean-time, please consult this ECMWF documentation: * [Web API Retrieval Efficiency](https://confluence.ecmwf.int/display/WEBAPI/Retrieval+efficiency) * [Era 5 daily data retrieval efficiency](https://confluence.ecmwf.int/display/WEBAPI/ERA5+daily+retrieval+efficiency) ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ global-exclude *.nc global-exclude *.gb *.grib global-exclude *.bz2 global-exclude *.yaml *.yml global-exclude *_test.py prune .github prune bin prune docs prune weather_*/test_data exclude tox.ini ================================================ FILE: README.md ================================================ # weather-tools Apache Beam pipelines to make weather data accessible and useful. [![CI](https://github.com/googlestaging/weather-tools/actions/workflows/ci.yml/badge.svg)](https://github.com/googlestaging/weather-tools/actions/workflows/ci.yml) [![Documentation Status](https://readthedocs.org/projects/weather-tools/badge/?version=latest)](https://weather-tools.readthedocs.io/en/latest/?badge=latest) ## Introduction This project contributes a series of command-line tools to make common data engineering tasks easier for researchers in climate and weather. These solutions were born out of the need to improve repeated work performed by research teams across Alphabet. 
The first tool created was the weather downloader (`weather-dl`). This makes it easier to ingest data from the European Center for Medium Range Forecasts (ECMWF). `weather-dl` enables users to describe very specifically what data they'd like to ingest from ECMWF's catalogs. It also offers them control over how to parallelize requests, empowering users to [retrieve data efficiently](Efficient-Requests.md). Downloads are driven from a [configuration file](Configuration.md), which can be reviewed (and version-controlled) independently of pipeline or analysis code. We also provide two additional tools to aid climate and weather researchers: the weather mover (`weather-mv`) and the weather splitter (`weather-sp`). These CLIs are still in their alpha stages of development. Yet, they have been used for production workflows for several partner teams. We created the weather mover (`weather-mv`) to load geospatial data from cloud buckets into [Google BigQuery](https://cloud.google.com/bigquery). This enables rapid exploratory analysis and visualization of weather data: From BigQuery, scientists can load arbitrary climate data fields into a Pandas or XArray dataframe via a simple SQL query. The weather splitter (`weather-sp`) helps normalize how archival weather data is stored in cloud buckets: Whether you're trying to merge two datasets with overlapping variables — or, you simply need to [open Grib data from XArray](https://github.com/ecmwf/cfgrib/issues/2), it's really useful to split datasets into their component variables. ## Installing It is currently recommended that you create a local python environment (with [Anaconda](https://www.anaconda.com/products/individual)) and install the sources as follows: ```shell conda env create --name weather-tools --file=environment.yml conda activate weather-tools ``` > Note: Due to its use of 3rd-party binary dependencies such as GDAL and MetView, `weather-tools` > is transitioning from PyPi to Conda for its main release channel. The instructions above > are a temporary workaround before our Conda-forge release. From here, you can use the `weather-*` tools from your python environment. Currently, the following tools are available: - [⛈ `weather-dl`](weather_dl/README.md) (_beta_) – Download weather data (namely, from ECMWF's API). - [⛅️ `weather-mv`](weather_mv/README.md) (_alpha_) – Load weather data into analytics engines, like BigQuery. - [🌪 `weather-sp`](weather_sp/README.md) (_alpha_) – Split weather data by arbitrary dimensions. ## Quickstart In this tutorial, we will download the [Era 5 pressure level dataset](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-pressure-levels?tab=overview) and ingest it into Google BigQuery using `weather-dl` and `weather-mv`, respectively. ### Prerequisites 1. [Register here](https://www.ecmwf.int/) and [here](https://cds.climate.copernicus.eu/) for a license from ECMWF's [Copernicus (CDS) API](https://cds.climate.copernicus.eu/api-how-to). 2. User must agree to the Terms of Use of a dataset before downloading any data out of dataset.(E.g.: accept [terms & condition](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-single-levels?tab=download#manage-licences) from here.) 3. Install your license by copying your API url & key from [this page](https://cds.climate.copernicus.eu/how-to-api) to a new file `$HOME/.cdsapirc`.[^1] The file should look like this: ``` url: https://cds.climate.copernicus.eu/api key: : ``` 4. 
If you do not already have a Google Cloud project, create one by following [these steps](https://cloud.google.com/docs/get-started). If you are working on an existing project, make sure your user has the [BigQuery Admin role](https://cloud.google.com/bigquery/docs/access-control#bigquery.admin). To learn more about granting IAM roles to users in Google Cloud, visit the [official docs](https://cloud.google.com/iam/docs/granting-changing-revoking-access#grant-single-role). 5. Create an empty BigQuery Dataset. This can be done using the [Google Cloud Console](https://cloud.google.com/bigquery/docs/quickstarts/quickstart-cloud-console#create_a_dataset) or via the [`bq` CLI tool](https://cloud.google.com/bigquery/docs/quickstarts/quickstart-command-line). For example: ```shell bq mk --project_id=$PROJECT_ID $DATASET_ID ``` 6. Follow [these steps](https://cloud.google.com/storage/docs/creating-buckets) to create a bucket for staging temporary files in [Google Cloud Storage](https://cloud.google.com/storage). ### Steps For the purpose of this tutorial, we will use your local machine to run the data pipelines. Note that all `weather-tools` can also be run in [Cloud Dataflow](https://cloud.google.com/dataflow) which is easier to scale and fully managed. 1. Use `weather-dl` to download the *Era 5 pressure level* dataset. ```bash weather-dl configs/era5_example_config_local_run.cfg \ --local-run # Use the local machine ``` > Recommendation: Pass the `-d, --dry-run` flag to any of these commands to preview the effects. **NOTE:** By default, local downloads are saved to the `./local_run` directory unless another file system is specified. The recommended output location for `weather-dl` is [Cloud Storage](https://cloud.google.com/storage). The source and destination of the download are configured using the `.cfg` configuration file which is passed to the command. To learn more about this configuration file's format and features, see [this reference](Configuration.md). To learn more about the `weather-dl` command, visit [here](weather_dl/README.md). 2. *(optional)* Split your downloaded dataset up with `weather-sp`: ```shell weather-sp --input-pattern "./local_run/era5-*.nc" \ --output-dir "split_data" ``` Visit the `weather-sp` [docs](weather_sp/README.md) for more information. 3. Use `weather-mv` to ingest the downloaded data into BigQuery, in a structured format. ```bash weather-mv bigquery --uris "./local_run/**.nc" \ # or "./split_data/**.nc" if weather-sp is used --output_table "$PROJECT.$DATASET_ID.$TABLE_ID" \ # The path to the destination BigQuery table --temp_location "gs://$BUCKET/tmp" \ # Needed for stage temporary files before writing to BigQuery --direct_num_workers 2 ``` See [these docs](weather_mv/README.md) for more about the `weather-mv` command. That's it! After the pipeline is completed, you should be able to query the ingested dataset in [BigQuery SQL workspace](https://cloud.google.com/bigquery/docs/bigquery-web-ui) and analyze it using [BigQuery ML](https://cloud.google.com/bigquery-ml/docs/introduction). ## Contributing The weather tools are under active development, and contributions are welcome! Please check out our [guide](CONTRIBUTING.md) to get started. ## License This is not an official Google product. ``` Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ``` [^1]: Note that you need to be logged in for the [CDS API page](https://cds.climate.copernicus.eu/api-how-to#install-the-cds-api-key) to actually show your user ID and API key. Otherwise, it will display a placeholder, which is confusing to some users. ================================================ FILE: Runners.md ================================================ # Choosing a Beam Runner All tools use Apache Beam pipelines. By default, pipelines run locally using the `DirectRunner`. You can optionally choose to run the pipelines on [Google Cloud Dataflow](https://cloud.google.com/dataflow) by selection the `DataflowRunner`. When working with GCP, it's recommended you set the project ID up front with the command: ```shell gcloud config set project ``` ## _Direct Runner options_: * `--direct_num_workers`: The number of workers to use. We recommend 2 for local development. Example run: ```shell weather-mv -i gs://netcdf_file.nc \ -o $PROJECT.$DATASET_ID.$TABLE_ID \ -t gs://$BUCKET/tmp \ --direct_num_workers 2 ``` For a full list of how to configure the direct runner, please review [this page](https://beam.apache.org/documentation/runners/direct/). ## _Dataflow options_: * `--runner`: The `PipelineRunner` to use. This field can be either `DirectRunner` or `DataflowRunner`. Default: `DirectRunner` (local mode) * `--project`: The project ID for your Google Cloud Project. This is required if you want to run your pipeline using the Dataflow managed service (i.e. `DataflowRunner`). * `--temp_location`: Cloud Storage path for temporary files. Must be a valid Cloud Storage URL, beginning with `gs://`. * `--region`: Specifies a regional endpoint for deploying your Dataflow jobs. Default: `us-central1`. * `--job_name`: The name of the Dataflow job being executed as it appears in Dataflow's jobs list and job details. Example run: ```shell weather-dl configs/seasonal_forecast_example_config.cfg \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location gs://$BUCKET/tmp/ ``` For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ## Monitoring When running Dataflow, you can [monitor jobs through UI](https://cloud.google.com/dataflow/docs/guides/using-monitoring-intf), or [via Dataflow's CLI commands](https://cloud.google.com/dataflow/docs/guides/using-command-line-intf): For example, to see all outstanding Dataflow jobs, simply run: ```shell gcloud dataflow jobs list ``` To describe stats about a particular Dataflow job, run: ```shell gcloud dataflow jobs describe $JOBID ``` In addition, Dataflow provides a series of [Beta CLI commands](https://cloud.google.com/sdk/gcloud/reference/beta/dataflow). 
These can be used to keep track of job metrics, like so: ```shell JOBID= gcloud beta dataflow metrics list $JOBID --source=user ``` You can even [view logs via the beta commands](https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/logs/list): ```shell gcloud beta dataflow logs list $JOBID ``` ================================================ FILE: Runtime-Container.md ================================================ # Runtime Containers _How to build & public custom Dataflow images in GCS_ *Pre-requisites*: Install the gcloud CLI ([instructions are here](https://cloud.google.com/sdk/docs/install)). Then, log in to your cloud account: ```shell gcloud auth login ``` > Please follow all the instructions from the CLI. This will involve running an > auth script on your local machine, which will open a browser window to log you > in. Last, make sure you have adequate permissions to use Google Cloud Build (see [IAM options here](https://cloud.google.com/build/docs/iam-roles-permissions)). [This documentation](https://cloud.google.com/build/docs/securing-builds/configure-access-to-resources) will help you configure your project to use Cloud Build. *Updating the image*: Please modify the `Dockerfile` in the root directory. Then, build and upload the image with Google Cloud Build (updating the tag, as is appropriate): ```shell export PROJECT= export REPO=weather-tools export IMAGE_URI=gcr.io/$PROJECT/$REPO export TAG="0.0.0" # Please increment on every update. # from the project root... # dev release gcloud builds submit . --tag "$IMAGE_URI:dev" # release: gcloud builds submit . --tag "$IMAGE_URI:$TAG" && gcloud builds submit weather_mv/ --tag "$IMAGE_URI:latest" ``` ================================================ FILE: bin/install-branch ================================================ #!/usr/bin/env sh # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Set the first argument as the name of the branch. If unset, default to "main". pip install -e "git+http://github.com/google/weather-tools.git@${1:-main}#egg=google-weather-tools" ================================================ FILE: bin/post-push ================================================ #!/usr/bin/env sh # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. MAG="\033[95m" BLD="\033[1m" END="\033[0m" # End script early on error set -e # In project root... cd "$(dirname $0)/.." 
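# Presubmit checks: lint (flake8), type-check (pytype), and run unit tests (pytest); 'set -e' above aborts the hook on the first failure.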
flake8 pytype pytest echo "$MAG$BLD--- Presubmit: Success ---$END" exit 0 ================================================ FILE: ci3.8.yml ================================================ name: weather-tools channels: - conda-forge - defaults dependencies: - python=3.8.13 - apache-beam=2.40.0 - pytest=7.2.0 - pytest-subtests=0.8.0 - cfgrib=0.9.10.2 - dask=2022.10.0 - dataclasses=0.8 - distributed=2022.10.0 - eccodes=2.27.0 - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 - gdal=3.5.1 - pyproj=3.4.0 - geojson=2.5.0=py_0 - simplejson=3.17.6 - metview-batch=5.17.0 - numpy=1.22.4 - pandas=1.5.1 - pip=24.3.1 - pygrib=2.1.4 - xarray==2023.1.0 - ruff==0.1.2 - google-cloud-sdk=410.0.0 - aria2=1.36.0 - zarr=2.15.0 - google-cloud-monitoring=2.22.2 - pillow=10.4.0 - cdsapi=0.7.5 - pip: - cython==0.29.34 - earthengine-api==0.1.329 - pyparsing==3.1.4 - .[test] ================================================ FILE: ci3.9.yml ================================================ name: weather-tools channels: - conda-forge - defaults dependencies: - python=3.9.13 - apache-beam=2.40.0 - pytest=7.2.0 - pytest-subtests=0.8.0 - cfgrib=0.9.10.2 - dask=2022.10.0 - dataclasses=0.8 - distributed=2022.10.0 - eccodes=2.27.0 - requests=2.28.1 - netcdf4=1.6.1 - rioxarray=0.13.4 - xarray-beam=0.6.2 - ecmwf-api-client=1.6.3 - fsspec=2022.11.0 - gcsfs=2022.11.0 - gdal=3.5.1 - pyproj=3.4.0 - geojson=2.5.0=py_0 - simplejson=3.17.6 - metview-batch=5.17.0 - numpy=1.22.4 - pandas=1.5.1 - pip=24.3.1 - pygrib=2.1.4 - google-cloud-sdk=410.0.0 - aria2=1.36.0 - xarray==2023.1.0 - ruff==0.1.2 - zarr=2.15.0 - google-cloud-monitoring=2.22.2 - pillow=10.4.0 - cdsapi=0.7.5 - pip: - cython==0.29.34 - earthengine-api==0.1.329 - pyparsing==3.1.4 - .[test] ================================================ FILE: configs/era5_example_config.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/era5_example_config_local_run.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ## Run with # $ weather-dl configs/era5_example_config_local_run.cfg --local-run # This will create a folder '$CWD/local_run' with a manifest.json file and four data files # era5-20160101-pressure-500.nc # era5-20160115-pressure-500.nc # era5-20170101-pressure-500.nc # era5-20170115-pressure-500.nc [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=era5-{year:04d}{month:02d}{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 year= 2016 2017 month= 01 day= 01 15 time= 00:00 12:00 ================================================ FILE: configs/era5_example_config_local_run.json ================================================ { "parameters": { "client": "cds", "dataset": "reanalysis-era5-pressure-levels", "target_path": "era5-{}{}{}-pressure-{}.nc", "partition_keys": ["year","month","day","pressure_level"] }, "selection": { "product_type": "ensemble_mean", "format": "netcdf", "variable": ["divergence","fraction_of_cloud_cover","geopotential"], "pressure_level": [500], "year": [2016, 2017], "month": [1], "day": [1, 15], "time": ["00:00", "12:00"] } } ================================================ FILE: configs/era5_example_config_preproc.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-downloads/test/o1280-{year:04d}-{month:02d}-{day:02d}.grib partition_keys= year month day [selection] # These are the averages of an ensemble of analyses. product_type=ensemble_mean format=grib grid=o1280 variable= relative_humidity temperature fraction_of_cloud_cover geopotential u_component_of_wind v_component_of_wind pressure_level= 500 year= 2020 month= 01 day= 01 02 time= 00:00 12:00 ================================================ FILE: configs/era5_example_config_using_date.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels # This config creates a date-based directory hierarchy. # In this case, the two files that will be created are # gs://ecmwf-output-test/era5/2017/01/01-pressure-500.nc # gs://ecmwf-output-test/era5/2017/01/02-pressure-500.nc # gs://ecmwf-output-test/era5/2017/01/01-pressure-1000.nc # gs://ecmwf-output-test/era5/2017/01/02-pressure-1000.nc target_path=gs://ecmwf-output-test/era5/{date:%%Y/%%m/%%d}-pressure-{pressure_level}.nc partition_keys= date pressure_level [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 1000 date=2017-01-01/to/2017-01-02 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/era5_example_monthly_soil.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-complete target_path=gs://ecmwf-output-test/era5/ERA5GRIB/HRES/Month/{year}/{year}{month:02d}_hres_soil.grb2 partition_keys= year month [selection] class=ea stream=oper expver=1 levtype=sfc type=an year=1979/to/2021 month=01/to/12 day=all time=00/to/23/by/1 param=139.128/170.128/183.128/236.128/238.128/39.128/40.128/41.128/42.128/35.128/36.128/37.128/38.128 ================================================ FILE: configs/mars_example_config.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
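# Example MARS request: downloads ECMWF HRES (operational, surface-level) forecast fields to Cloud Storage as NetCDF, one file per forecast date.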
[parameters] client=mars dataset=ecmwf-mars-output target_path=gs://ecmwf-downloads/hres-single-level/{date:%%Y/%%m/%%d}.nc partition_keys= date [selection] # example requests # https://confluence.ecmwf.int/display/UDOC/MARS+example+requests # hres fields # https://www.ecmwf.int/en/forecasts/datasets/set-i stream=oper levtype=sfc param=10fg6/10u/10v/100u/100vcrr/2t/2d/200u/200v/cp/dsrp/hcc/i10fg/lcc/lsp/lspf/lsrr/msl/ptype/sf/sp/ssr/tcrw/tclw/tcsw/tcw/tcwv/tp padding=0 step=0/1/2/3/4/5/6/7/8/9/10/11/12/24/48/72/96/120/144/168/192/216/240 grid=0.125/0.125 expver=1 time=0000/1800 date=2017-01-01/to/2017-01-07 # these are weather forecasts type=fc class=od expect=anymars format=netcdf ================================================ FILE: configs/mars_example_config.json ================================================ { "parameters": { "client": "mars", "dataset": "ecmwf-mars-output", "target_path": "gs://ecmwf-downloads/hres-single-level/{:%Y/%m/%d}.nc", "partition_keys": "date" }, "selection": { "stream": "oper", "levtype": "sfc", "param": ["10fg6","10u","10v","100u","100vcrr","2t","2d","200u","200v","cp","dsrp","hcc","i10fg","lcc","lsp","lspf","lsrr","msl","ptype","sf","sp","ssr","tcrw","tclw","tcsw","tcw","tcwv","tp"], "padding": 0, "step": [0,1,2,3,4,5,6,7,8,9,10,11,12,24,48,72,96,120,144,168,192,216,240], "grid": [0.125,0.125], "expver": 1, "time": ["0000","1800"], "date": ["2017-01-01","2017-01-07"], "type": "fc", "class": "od", "expect": "anymars", "format": "netcdf" } } ================================================ FILE: configs/multiple_multiple_licenses/era5_pressure500.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level [parameters.a] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.b] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.c] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/multiple_multiple_licenses/era5_pressure600.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level [parameters.a] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.b] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.c] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 600 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/multiple_multiple_licenses/era5_pressure700.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level [parameters.a] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.b] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [parameters.c] api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 700 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/multiple_single_license/era5_pressure500.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
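# One of several example configs that share a single CDS license: the api_url/api_key pair is declared once in [parameters] (compare with configs/multiple_multiple_licenses/, where [parameters.a], [parameters.b], and [parameters.c] hold separate keys).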
[parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 500 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/multiple_single_license/era5_pressure600.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 600 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/multiple_single_license/era5_pressure700.cfg ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=reanalysis-era5-pressure-levels target_path=gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}-pressure-{pressure_level}.nc partition_keys= year month day pressure_level api_url=https://cds.climate.copernicus.eu/api # a fake key for tests api_key=12345:1234567-ab12-34cd-9876-4o4fake90909 [selection] product_type=reanalysis format=netcdf variable= divergence fraction_of_cloud_cover geopotential pressure_level= 700 year= 2015 2016 2017 month= 01 day= 01 15 time= 00:00 06:00 12:00 18:00 ================================================ FILE: configs/s2s_operational_forecast_example.cfg ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=ecpublic dataset=s2s target_path=gs://ecmwf-downloads/s2s/pressure-levels/{date}-all-levs.gb # target_path=gs://ecmwf-downloads/s2s/pressure-levels/{date}-hc-{hdate}-all-levs.gb partition_keys= date # hdate [selection] class=s2 date=2016-01-04/2016-01-07/2016-01-11/2016-01-14/2016-01-18/2016-01-21/2016-01-25 # hdate = , The number of years to subtract. # Eg input: # date = 2020-01-02 # hdate = 1/to/6 or 1/2/3/4/5/6 # Code will do: # hdate = 2019-01-02/2018-01-02/2017-01-02/2016-01-02/2015-01-02/2014-01-02 # Note: If 'hdate' is specified in the 'selection' section, then 'date' is required as a partition keys. hdate=1/to/20 expver=prod levelist=10/50/100/200/300/500/700/850/925/1000 levtype=pl model=glob number=1/to/50 origin=ecmf # All except q param=130/131/132/135/156 step=0/to/1104/by/24 stream=enfo time=00:00:00 type=pf ================================================ FILE: configs/seasonal_forecast_example_config.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=cds dataset=seasonal-original-single-levels target_path=gs://ecmwf-output-test/seasonal-forecast/seasonal-forecast-{year:04d}-{month:02d}.nc partition_keys= year month [selection] format=netcdf originating_centre=ecmwf variable= 10m_u_component_of_wind 10m_v_component_of_wind year= 2016 2017 month= 01 day= 01 leadtime_hour= 6 area= 10 -10 -10 10 ================================================ FILE: configs/tigge_example_config.cfg ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
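# Example TIGGE request via ECMWF's public dataset client: downloads perturbed-forecast (ensemble) surface fields for parameter 167 (2m temperature), partitioned by date.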
[parameters] client=ecpublic dataset=tigge target_path=gs://ecmwf-output-test/tigge/{}.gb partition_keys= date [selection] class=ti date=2020-01-01/to/2020-01-31 expver=prod grid=0.25/0.25 levtype=sfc number=1/to/50 origin=ecmf param=167 step=0/to/360/by/6 time=00/12 type=pf ================================================ FILE: configs/yesterdays_surface_example.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [parameters] client=mars partition_keys= date target_path=all-{}.an [selection] class = od type = analysis levtype = surface date = -1 time = 00/06/12/18 param = z/sp ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/Private-IP-Configuration.md ================================================ # Private IP Configuration _A Guide for Dataflow Pipeline Execution_ ## Goals In this document, we’ll describe how to use Private IP for the execution of a dataflow pipeline. ## Background When we are running the dataflow pipeline, GCP decides to spawn one or more new VM-instances. By default, each VM-instance will have an External IP address. ![VM Instance with External IP Address](_static/vm_instance_with_external_ip_address.png "VM Instance with External IP Address") Considering a billing account has a limited number of External IP addresses, we can skip this overhead by providing VPC-parameters as CLI-input of dataflow. Following table* summarizes the required input parameters. | Field | Description | |-------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | network | The Compute Engine network for launching Compute Engine instances to run your pipeline. If not set, Google Cloud assumes that you intend to use a network named default. | | subnetwork | The Compute Engine subnetwork for launching Compute Engine instances to run your pipeline. | | no_use_public_ips | Command-line flag that sets use_public_ips to False. If the option is not explicitly enabled or disabled, the Dataflow workers use public IP addresses. | (* excerpt from [GCP pipeline-options](https://cloud.google.com/dataflow/docs/reference/pipeline-options)) ## Steps to Configure VPC 1. 
Configure VPC-Network & Subnetwork 2. Configure Firewall-Rule 3. Configure NAT & Router 4. Sample commands to trigger dataflow pipeline execution using above options ## Configure VPC-Network & Subnetwork 1. Open GCP’s Create VPC-Network page. 2. Provide name, description. 3. Select “Subnet creation mode” as Custom. 4. Provide name, description, region, IP address range & other details for the new subnet. 5. Select “Private Google Access” as On. 6. Complete VPC-Network creation by providing other required parameters. Refer to GCP’s [Create-VPC-Network](https://cloud.google.com/vpc/docs/create-modify-vpc-networks) documentation for more details. ![VPC Network Details](_static/vpc_network_details.png "VPC Network Details") ## Configure Firewall-Rule 1. Open GCP’s Create a firewall rule page. 2. Provide name, description. You may set “Logs” as Off. 3. For the “Network” drop-down, select the network that we created in the previous step. 4. Select “Direction of traffic” as Ingress & “Action on match” as Allow. 5. For the "Targets" drop-down, select "All instances in the network". 6. For the "Source IPv4 ranges", pass in the following: "0.0.0.0/0". 7. Select "Protocols and ports" as Allow all. 8. Complete Firewall-Rule creation by providing other necessary information. Refer to GCP’s [Configuring-Firewall](https://cloud.google.com/filestore/docs/configuring-firewall) documentation for more details. ![Firewall Rule Details](_static/firewall_rule_details.png "Firewall Rule Details") ## Configure NAT & Router 1. Open GCP’s Create a NAT gateway page. 2. Provide name, region. 3. For the “Network” drop-down, select the network that we created earlier. 4. For the “Router” drop-down, EITHER select pre-created router OR click on “create new router”.
a. Complete router creation by providing name, description & region. Refer to GCP’s [Create-Router](https://cloud.google.com/network-connectivity/docs/router/how-to/create-router-vpc-on-premises-network) documentation for more details. 5. Complete NAT gateway creation by providing required details. Refer to GCP’s [Create-NAT-Gateway](https://cloud.google.com/nat/docs/set-up-manage-network-address-translation) documentation for more details. ![Router Details](_static/router_details.png "Router Details") ![NAT Gateways](_static/nat_gateways.png "NAT Gateways") ## Sample commands to trigger dataflow pipeline execution using above options Following section showcases how VPC-parameters can be given as CLI inputs to weather-mv dataflow pipeline. ```bash weather-mv --uris "gs://$STORAGE_BUCKET/*.nc" --output_table "$HOST_PROJECT_ID.$DATASET_ID.$TABLE_ID" --temp_location "gs://$STORAGE_BUCKET/tmp" --runner DataflowRunner --project $HOST_PROJECT_ID --region $REGION_NAME --no_use_public_ips --network=$NETWORK_NAME --subnetwork=regions/$REGION_NAME/subnetworks/$SUBNETWORK_NAME ``` Replace the following: - STORAGE_BUCKET: the storage bucket, e.g. bucket_58231 - HOST_PROJECT_ID: the host project ID, e.g. weather_tools - DATASET_ID: the name of dataset, e.g. weather_mv_ds - TABLE_ID: the name of table, e.g. tbl_2017_01 - REGION_NAME: the regional endpoint of your Dataflow job, e.g. us-central1 - NETWORK_NAME: the name of your Compute Engine network, e.g. dataflow - Provide network_name same as what we created in Step-1. - SUBNETWORK_NAME: the name of your Compute Engine subnetwork, e.g. private - Provide a subnetwork_name same as what we created in Step-1. Alternatively, you may also execute following command, ```bash weather-mv --uris "gs://$STORAGE_BUCKET/*.nc" --output_table "$HOST_PROJECT_ID.$DATASET_ID.$TABLE_ID" --temp_location "gs://$STORAGE_BUCKET/tmp" --runner DataflowRunner --project $HOST_PROJECT_ID --region $REGION_NAME --no_use_public_ips –-subnetwork=https://www.googleapis.com/compute/v1/projects/$HOST_PROJECT_ID/regions/$REGION_NAME/subnetworks/$SUBNETWORK_NAME ``` ================================================ FILE: docs/_static/custom.css ================================================ p { text-align: justify; } body { min-width: 250px; } div.body { min-width: 250px; } ================================================ FILE: docs/conf.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
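# Add the three pipeline packages (weather_dl, weather_mv, weather_sp) to sys.path so that 'sphinx.ext.autodoc' can import them when building the API docs.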
import os import sys sys.path.insert(0, os.path.abspath('../weather_dl')) sys.path.insert(1, os.path.abspath('../weather_mv')) sys.path.insert(2, os.path.abspath('../weather_sp')) # -- Project information ----------------------------------------------------- project = 'weather-tools' copyright = '2021 Google' author = 'Anthromets' # The full version, including alpha/beta/rc tags release = '0.0' # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'myst_parser', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] # https://stackoverflow.com/a/66295922/809705 autodoc_typehints = "description" ================================================ FILE: docs/download_pipeline.md ================================================ # download_pipeline package ========================= ## Submodules ### download_pipeline.clients module ```{eval-rst} .. automodule:: weather_dl.download_pipeline.clients :members: :undoc-members: :show-inheritance: ``` ### download_pipeline.manifest module ```{eval-rst} .. automodule:: weather_dl.download_pipeline.manifest :members: :undoc-members: :show-inheritance: ``` ### download_pipeline.parsers module ```{eval-rst} .. automodule:: weather_dl.download_pipeline.parsers :members: :undoc-members: :show-inheritance: ``` ### download_pipeline.pipeline module ```{eval-rst} .. automodule:: weather_dl.download_pipeline.pipeline :members: :undoc-members: :show-inheritance: ``` ### download_pipeline.store module ```{eval-rst} .. automodule:: weather_dl.download_pipeline.stores :members: :undoc-members: :show-inheritance: ``` ## Module contents ```{eval-rst} .. automodule:: download_pipeline :members: :undoc-members: :show-inheritance: ``` ================================================ FILE: docs/index.md ================================================ ```{include} README.md ``` ## Contents ```{eval-rst} .. toctree:: :maxdepth: 2 :titlesonly: :glob: :includehidden: weather_dl/README weather_mv/README weather_sp/README Configuration Efficient-Requests Runners Private-IP-Configuration Runtime-Container CONTRIBUTING modules ``` # Indices and tables ```{eval-rst} * :ref:`genindex` * :ref:`modindex` * :ref:`search` ``` ================================================ FILE: docs/loader_pipeline.md ================================================ # loader_pipeline package ## Submodules ### loader_pipeline.netcdf_loader module ```{eval-rst} .. automodule:: weather_mv.loader_pipeline :members: :undoc-members: :show-inheritance: ``` ## Module contents ```{eval-rst} .. 
automodule:: loader_pipeline :members: :undoc-members: :show-inheritance: ``` ================================================ FILE: docs/make.bat ================================================ REM Copyright 2021 Google LLC REM REM Licensed under the Apache License, Version 2.0 (the "License"); REM you may not use this file except in compliance with the License. REM You may obtain a copy of the License at REM REM https://www.apache.org/licenses/LICENSE-2.0 REM REM Unless required by applicable law or agreed to in writing, software REM distributed under the License is distributed on an "AS IS" BASIS, REM WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. REM See the License for the specific language governing permissions and REM limitations under the License. @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/modules.md ================================================ # Modules ```{toctree} :maxdepth: 4 download_pipeline loader_pipeline splitter_pipeline ``` ================================================ FILE: docs/requirements.txt ================================================ # doc requirements myst-parser==0.13.7 sphinx>=2.1 Jinja2<3.1 ================================================ FILE: docs/splitter_pipeline.md ================================================ # splitter_pipeline package ========================= ## Submodules ### splitter_pipeline.file_splitters module ```{eval-rst} .. automodule:: weather_sp.splitter_pipeline.file_splitters :members: :undoc-members: :show-inheritance: ``` ### splitter_pipeline.pipeline module ```{eval-rst} .. automodule:: weather_sp.splitter_pipeline.pipeline :members: :undoc-members: :show-inheritance: ``` ## Module contents ```{eval-rst} .. automodule:: splitter_pipeline :members: :undoc-members: :show-inheritance: ``` ================================================ FILE: environment.yml ================================================ name: weather-tools channels: - conda-forge dependencies: - python=3.8.13 - apache-beam=2.40.0 - xarray-beam=0.6.2 - xarray=2023.1.0 - fsspec=2022.11.0 - gcsfs=2022.11.0 - rioxarray=0.13.4 - gdal=3.5.1 - pyproj=3.4.0 - ecmwf-api-client=1.6.3 - eccodes=2.27.0 - cfgrib=0.9.10.2 - pygrib=2.1.4 - metview-batch=5.17.0 - netcdf4=1.6.1 - geojson=2.5.0=py_0 - simplejson=3.17.6 - numpy=1.22.4 - pandas=1.5.1 - google-cloud-sdk=410.0.0 - aria2=1.36.0 - pip=24.3.1 - zarr=2.15.0 - google-cloud-monitoring=2.22.2 - pillow=10.4.0 - cdsapi=0.7.5 - pip: - cython==0.29.34 - earthengine-api==0.1.329 - firebase-admin==6.0.1 - contourpy==1.1.1 - google-crc32c==1.1.2 - MarkupSafe==2.1.5 - setuptools==70.3.0 - pyparsing==3.1.4 - . 
- ./weather_dl - ./weather_mv - ./weather_sp ================================================ FILE: pyproject.toml ================================================ [tool.ruff] # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. select = ["E", "F", "W"] ignore = [] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = ["A", "B", "C", "D", "E", "F"] unfixable = [] # Exclude a variety of commonly ignored directories. exclude = [ ".bzr", ".direnv", ".eggs", ".git", ".hg", ".mypy_cache", ".nox", ".pants.d", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", "__pypackages__", "_build", "buck-out", "build", "dist", "node_modules", "venv", ] # Same as Black. line-length = 120 # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Assume Python 3.10. target-version = "py310" [tool.ruff.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 ================================================ FILE: setup.cfg ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [pytype] inputs = weather_dl weather_mv weather_sp ================================================ FILE: setup.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from setuptools import find_packages, setup beam_gcp_requirements = [ "google-cloud-bigquery==2.34.4", "google-cloud-bigquery-storage==2.14.1", "google-cloud-bigtable==1.7.2", "google-cloud-core==1.7.3", "google-cloud-datastore==1.15.5", "google-cloud-dlp==3.8.0", "google-cloud-language==1.3.2", "google-cloud-pubsub==2.13.4", "google-cloud-pubsublite==1.4.2", "google-cloud-recommendations-ai==0.2.0", "google-cloud-spanner==1.19.3", "google-cloud-videointelligence==1.16.3", "google-cloud-vision==1.0.2", "apache-beam[gcp]==2.40.0", ] weather_dl_requirements = [ "cdsapi==0.7.5", "ecmwf-api-client", "numpy>=1.19.1", "pandas", "xarray", "requests>=2.24.0", "google-cloud-firestore", "firebase-admin", "urllib3==1.26.5", ] weather_mv_requirements = [ "dataclasses", "numpy", "pandas", "xarray", "cfgrib", "netcdf4", "geojson", "simplejson", "rioxarray", "rasterio", "earthengine-api>=0.1.263", "pyproj", # requires separate binary installation! "gdal", # requires separate binary installation! 
"xarray-beam==0.6.2", "gcsfs==2022.11.0", "zarr==2.15.0", ] weather_sp_requirements = [ "numpy>=1.20.3", "pygrib", "xarray", "scipy", ] test_requirements = [ "pytype==2021.11.29", "ruff==0.1.2", "pytest", "pytest-subtests", "netcdf4", "numpy", "xarray", "xarray-beam", "absl-py", "metview", "memray", "pytest-memray", "h5py", "pooch", ] all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \ weather_mv_requirements + weather_sp_requirements + \ test_requirements setup( name='google-weather-tools', packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', url='https://weather-tools.readthedocs.io/', description='Apache Beam pipelines to make weather data accessible and useful.', long_description=open('README.md', 'r', encoding='utf-8').read(), long_description_content_type='text/markdown', platforms=['darwin', 'linux'], license='License :: OSI Approved :: Apache Software License', classifiers=[ 'Development Status :: 4 - Beta', 'Environment :: Console', 'Intended Audience :: Science/Research', 'Intended Audience :: Developers', 'Intended Audience :: Information Technology', 'License :: OSI Approved :: Apache Software License', 'Operating System :: MacOS :: MacOS X', # 'Operating System :: Microsoft :: Windows', # TODO(#64): Fully support Windows. 'Operating System :: POSIX', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Topic :: Scientific/Engineering :: Atmospheric Science', ], python_requires='>=3.8, <3.10', install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'], use_scm_version=True, setup_requires=['setuptools_scm'], scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'], tests_require=test_requirements, extras_require={ 'docs': ['tox', 'sphinx>=2.1', 'myst-parser', 'Jinja2<3.1'], 'test': all_test_requirements, 'dev': ['google-weather-tools[docs,test]'], 'regrid': ['metview'] }, project_urls={ 'Issue Tracking': 'http://github.com/google/weather-tools/issues', }, ) ================================================ FILE: tox.ini ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. [tox] envlist = py37,py38,lint,type [testenv] deps = .[test] commands = pytest [testenv:type] deps = .[test] commands = pytype [testenv:lint] deps = .[test] commands = flake8 weather_dl flake8 weather_mv flake8 weather_sp [flake8] max-line-length = 120 [gh-actions] python = 3.8: py38, lint, type 3.9: py39 ================================================ FILE: weather_dl/MANIFEST.in ================================================ global-exclude *_test.py include README.md exclude download_status*.py ================================================ FILE: weather_dl/README.md ================================================ # ⛈ `weather-dl` – Weather Downloader Weather Downloader ingests weather data to cloud buckets, such as [Google Cloud Storage](https://cloud.google.com/storage) (_beta_). 
## Features * **Flexible Pipelines**: `weather-dl` offers a high degree of control over what is downloaded via configuration files. Separate scripts need not be written to get new data or add parameters. For more, see the [configuration docs](../Configuration.md). * **Efficient Parallelization**: The tool gives you full control over how downloads are sharded and parallelized (with good defaults). This lets you focus on the data and not the plumbing. * **Hassle-Free Dev-Ops**. `weather-dl` and Dataflow make it easy to spin up VMs on your behalf with one command. No need to keep your local machine online all night to acquire data. * **Robust Downloads**. If an error occurs when fetching a shard, Dataflow will automatically retry the download for you. Previously downloaded shards will be skipped by default, so you can re-run the tool without having to worry about duplication of work. > Note: Currently, only ECMWF's MARS and CDS clients are supported. If you'd like to use `weather-dl` to work with other > data sources, please [file an issue](https://github.com/googlestaging/weather-tools/issues) (or consider > [making a contribution](../CONTRIBUTING.md)). ## Usage ``` usage: weather-dl [-h] [-f] [-d] [-l] [-m MANIFEST_LOCATION] [-n NUM_REQUESTS_PER_KEY] [-p PARTITION_CHUNKS] [-s {in-order,fair}] config [config ...] Weather Downloader ingests weather data to cloud storage. positional arguments: config path/to/configs.cfg, containing client and data information. Can take multiple configs.Accepts *.cfg and *.json files. ``` _Common options_: * `-f, --force-download`: Force redownload of partitions that were previously downloaded. * `-d, --dry-run`: Run pipeline steps without _actually_ downloading or writing to cloud storage. * `-l, --local-run`: Run locally and download to local hard drive. The data and manifest directory is set by default to '<$CWD>/local_run'. The runner will be set to `DirectRunner`. The only other relevant option is the config and `--direct_num_workers` * `-m, --manifest-location MANIFEST_LOCATION`: Location of the manifest. By default, it will use Cloud Logging (stdout for direct runner). You can set the name of the manifest as the hostname of a URL with the 'cli' protocol. For example, `cli://manifest` will prefix all the manifest logs as '[manifest]'. In addition, users can specify either a Firestore collection URI (`fs://?projectId=`), or BigQuery table (`bq://..`), or `noop://` for an in-memory location. * `-n, --num-requests-per-key`: Number of concurrent requests to make per API key. Default: make an educated guess per client & config. Please see the client documentation for more details. * `-p, --partition-chunks`: Group shards into chunks of this size when computing the partitions. Specifically, this controls how we chunk elements in a cartesian product, which affects parallelization of that step. Default: chunks of 1000 elements for 'in-order' scheduling. Chunks of 1 element for 'fair' scheduling. * `-s, --schedule {in-order,fair}`: When using multiple configs, decide how partitions are scheduled: 'in-order' implies that partitions will be processed in sequential order of each config; 'fair' means that partitions from each config will be interspersed evenly. Note: When using 'fair' scheduling, we recommend you set the '--partition-chunks' to a much smaller number. Default: 'in-order'. * `--log-level`: An integer to configure log level. Default: 2(INFO). * `--use-local-code`: Supply local code to the Runner. Default: False. 
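As an illustration of how these flags compose, the following sketch combines 'fair' scheduling across two of the example configs in this repository with smaller partition chunks and a BigQuery-backed manifest (the table components `$PROJECT`, `$DATASET_ID`, and `$MANIFEST_TABLE_ID` are placeholders, not values defined by this repository):

```bash
# Interleave partitions from both configs evenly, track download progress in a
# BigQuery manifest, and preview the work without actually downloading anything.
weather-dl configs/era5_example_config.cfg configs/seasonal_forecast_example_config.cfg \
  --schedule fair \
  --partition-chunks 10 \
  --manifest-location "bq://$PROJECT.$DATASET_ID.$MANIFEST_TABLE_ID" \
  --dry-run
```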
> Note: > * In case of BigQuery manifest tool will create the BQ table itself, if not already present. > Or it will use the existing table but can report errors in case of schema mismatch. > * To run complex queries on the Firestore manifest, users may find it helpful to replicate Firestore to BigQuery > using the automated process described in > [this article](https://medium.com/@ammppp/automated-firestore-replication-to-bigquery-15915d518e38). > By following the step-by-step instructions, users can easily set up the automated replication and then use BigQuery > to perform advanced analysis on the data. Invoke with `-h` or `--help` to see the full range of options. For further information on how to write config files, please consult [this documentation](../Configuration.md). _Usage Examples_: ```bash weather-dl configs/era5_example_config_local_run.cfg --local-run ``` Preview download with a dry run: ```bash weather-dl configs/mars_example_config.cfg --dry-run ``` Using DataflowRunner ```bash weather-dl configs/mars_example_config.cfg \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME ``` Using DataflowRunner and using local code for pipeline ```bash weather-dl configs/mars_example_config.cfg \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME \ --use-local-code ``` Using the DataflowRunner and specifying 3 requests per license ```bash weather-dl configs/mars_example_config.cfg \ -n 3 \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME ``` For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ## Monitoring You can view how your ECMWF API jobs are by visitng the client-specific job queue: * [MARS](https://apps.ecmwf.int/mars-activity/) * [Copernicus](https://cds.climate.copernicus.eu/requests?tab=all) If you use Google Cloud Storage, we recommend using [`gsutil` (link)](https://cloud.google.com/storage/docs/gsutil) to inspect the progress of your downloads. For example: ```shell # Check that the file-sizes of your downloads look alright gcloud storage du --readable-sizes gs://your-cloud-bucket/mars-data/*T00z.nc # See how many downloads have finished gcloud storage du --readable-sizes gs://your-cloud-bucket/mars-data/*T00z.nc | wc -l ``` ================================================ FILE: weather_dl/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl/download_pipeline/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .pipeline import run, pipeline def cli(extra=[]): import sys pipeline(run(sys.argv + extra)) ================================================ FILE: weather_dl/download_pipeline/clients.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ECMWF Downloader Clients.""" import abc import collections import contextlib import datetime import io import json import logging import os import time import typing as t import warnings from urllib.parse import urljoin from cdsapi import api as cds_api import urllib3 from ecmwfapi import api from .config import Config, optimize_selection_partition from .manifest import Manifest, Stage from .util import download_with_aria2, retry_with_exponential_backoff warnings.simplefilter( "ignore", category=urllib3.connectionpool.InsecureRequestWarning) class Client(abc.ABC): """Weather data provider client interface. Defines methods and properties required to efficiently interact with weather data providers. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. 
""" def __init__(self, config: Config, level: int = logging.INFO) -> None: """Clients are initialized with the general CLI configuration.""" self.config = config self.logger = logging.getLogger(f'{__name__}.{type(self).__name__}') self.logger.setLevel(level) @abc.abstractmethod def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: """Download from data source.""" pass @classmethod @abc.abstractmethod def num_requests_per_key(cls, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass @property @abc.abstractmethod def license_url(self): """Specifies the License URL.""" pass class SplitCDSRequest(): """Extended CDS class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): self.__cds_client = cds_api.Client(*args, **kwargs) @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: self.__cds_client.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size)) start = time.time() download_with_aria2(url, path) elapsed = time.time() - start if elapsed: self.__cds_client.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed)) def fetch(self, request: t.Dict, dataset: str) -> t.Dict: result = self.__cds_client.retrieve(dataset, request) return {'href': result.location, 'size': result.content_length} def download(self, result: cds_api.Result, target: t.Optional[str] = None) -> None: if target: if os.path.exists(target): # Empty the target file, if it already exists, otherwise the # transfer below might be fooled into thinking we're resuming # an interrupted download. open(target, "w").close() self._download(result["href"], target, result["size"]) class CdsClient(Client): """A client to access weather data from the Cloud Data Store (CDS). Datasets on CDS can be found at: https://cds.climate.copernicus.eu/cdsapp#!/search?type=dataset The parameters section of the input `config` requires two values: `api_url` and `api_key`. Or, these values can be set as the environment variables: `CDSAPI_URL` and `CDSAPI_KEY`. These can be acquired from the following URL, which requires creating a free account: https://cds.climate.copernicus.eu/api-how-to The CDS global queues for data access has dynamic rate limits. These can be viewed live here: https://cds.climate.copernicus.eu/live/limits. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. 
""" """Name patterns of datasets that are hosted internally on CDS servers.""" cds_hosted_datasets = {'reanalysis-era'} def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: c = CDSClientExtended( url=self.config.kwargs.get('api_url', os.environ.get('CDSAPI_URL')), key=self.config.kwargs.get('api_key', os.environ.get('CDSAPI_KEY')), debug_callback=self.logger.debug, info_callback=self.logger.info, warning_callback=self.logger.warning, error_callback=self.logger.error, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(selection_, dataset) manifest.set_stage(Stage.DOWNLOAD) precise_download_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_download_start_time c.download(result, target=output) @property def license_url(self): return 'https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf' @classmethod def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key from the CDS API. CDS has dynamic, data-specific limits, defined here: https://cds.climate.copernicus.eu/live/limits Typically, the reanalysis dataset allows for 3-5 simultaneous requets. For all standard CDS data (backed on disk drives), it's common that 2 requests are allowed, though this is dynamically set, too. If the Beam pipeline encounters a user request limit error, please cancel all outstanding requests (per each user account) at the following link: https://cds.climate.copernicus.eu/cdsapp#!/yourrequests """ # TODO(#15): Parse live CDS limits API to set data-specific limits. 
for internal_set in cls.cds_hosted_datasets: if dataset.startswith(internal_set): return 5 return 2 class StdoutLogger(io.StringIO): """Special logger to redirect stdout to logs.""" def __init__(self, logger_: logging.Logger, level: int = logging.INFO): super().__init__() self.logger = logger_ self.level = level self._redirector = contextlib.redirect_stdout(self) def log(self, msg) -> None: self.logger.log(self.level, msg) def write(self, msg): if msg and not msg.isspace(): self.logger.log(self.level, msg) def __enter__(self): self._redirector.__enter__() return self def __exit__(self, exc_type, exc_value, traceback): # let contextlib do any exception handling here self._redirector.__exit__(exc_type, exc_value, traceback) class SplitMARSRequest(api.APIRequest): """Extended MARS APIRequest class that separates fetch and download stage.""" @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: self.log( "Transferring %s into %s" % (self._bytename(size), path) ) self.log("From %s" % (url,)) download_with_aria2(url, path) def fetch(self, request: t.Dict, dataset: str) -> t.Dict: status = None self.connection.submit("%s/%s/requests" % (self.url, self.service), request) self.log("Request submitted") self.log("Request id: " + self.connection.last.get("name")) if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) while not self.connection.ready(): if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) self.connection.wait() if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) result = self.connection.result() return result def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: if target: if os.path.exists(target): # Empty the target file, if it already exists, otherwise the # transfer below might be fooled into thinking we're resuming # an interrupted download. open(target, "w").close() self._download(urljoin(self.url, result["href"]), target, result["size"]) self.connection.cleanup() class SplitRequestMixin: c = None def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict: return self.c.fetch(req, dataset) def download(self, res: t.Dict, target: str) -> None: self.c.download(res, target) class CDSClientExtended(SplitRequestMixin): """Extended CDS Client class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): self.c = SplitCDSRequest(*args, **kwargs) class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin): """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.c = SplitMARSRequest( self.url, "services/%s" % (self.service,), email=self.email, key=self.key, log=self.log, verbose=self.verbose, quiet=self.quiet, ) class PublicECMWFServerExtended(api.ECMWFDataServer, SplitRequestMixin): def __init__(self, *args, dataset='', **kwargs): super().__init__(*args, **kwargs) self.c = SplitMARSRequest( self.url, "datasets/%s" % (dataset,), email=self.email, key=self.key, log=self.log, verbose=self.verbose, ) class MarsClient(Client): """A client to access data from the Meteorological Archival and Retrieval System (MARS). See https://www.ecmwf.int/en/forecasts/datasets for a summary of datasets available on MARS. Most notable, MARS provides access to ECMWF's Operational Archive https://www.ecmwf.int/en/forecasts/dataset/operational-archive. 
The client config must contain three parameters to authenticate access to the MARS archive: `api_key`, `api_url`, and `api_email`. These can also be configured by setting the commensurate environment variables: `MARSAPI_KEY`, `MARSAPI_URL`, and `MARSAPI_EMAIL`. These credentials can be looked up after registering for an ECMWF account (https://apps.ecmwf.int/registration/) and visiting: https://api.ecmwf.int/v1/key/. MARS server activity can be observed at https://apps.ecmwf.int/mars-activity/. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. """ def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: c = MARSECMWFServiceExtended( "mars", key=self.config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), url=self.config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), email=self.config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), log=self.logger.debug, verbose=True, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(req=selection_) manifest.set_stage(Stage.DOWNLOAD) precise_download_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_download_start_time c.download(result, target=output) @property def license_url(self): return 'https://apps.ecmwf.int/datasets/licences/general/' @classmethod def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key (or user) for the Mars API. Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021. To ensure we never hit a rate limit error during download, we only make use of the active requests. See: https://confluence.ecmwf.int/display/UDOC/Total+number+of+requests+a+user+can+submit+-+Web+API+FAQ Queued requests can _only_ be canceled manually from a web dashboard. If the `ERROR 101 (USER_QUEUED_LIMIT_EXCEEDED)` error occurs in the Beam pipeline, then go to http://apps.ecmwf.int/webmars/joblist/ and cancel queued jobs.
""" return 2 class ECMWFPublicClient(Client): """A client for ECMWF's public datasets, like TIGGE.""" def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: c = PublicECMWFServerExtended( url=self.config.kwargs.get('api_url', os.environ.get("MARSAPI_URL")), key=self.config.kwargs.get('api_key', os.environ.get("MARSAPI_KEY")), email=self.config.kwargs.get('api_email', os.environ.get("MARSAPI_EMAIL")), log=self.logger.debug, verbose=True, dataset=dataset, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(req=selection_) manifest.set_stage(Stage.DOWNLOAD) precise_download_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_download_start_time c.download(result, target=output) @classmethod def num_requests_per_key(cls, dataset: str) -> int: # Experimentally validated request limit. return 5 @property def license_url(self): if not self.config.dataset: raise ValueError('must specify a dataset for this client!') return f'https://apps.ecmwf.int/datasets/data/{self.config.dataset.lower()}/licence/' class FakeClient(Client): """A client that writes the selection arguments to the output file.""" def retrieve(self, dataset: str, selection: t.Dict, output: str, manifest: Manifest) -> None: manifest.set_stage(Stage.RETRIEVE) precise_retrieve_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) manifest.prev_stage_precise_start_time = precise_retrieve_start_time self.logger.debug(f'Downloading {dataset} to {output}') with open(output, 'w') as f: json.dump({dataset: selection}, f) @property def license_url(self): return 'lorem ipsum' @classmethod def num_requests_per_key(cls, dataset: str) -> int: return 1 CLIENTS = collections.OrderedDict( cds=CdsClient, mars=MarsClient, ecpublic=ECMWFPublicClient, fake=FakeClient, ) ================================================ FILE: weather_dl/download_pipeline/clients_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import unittest from .clients import FakeClient, CdsClient, MarsClient class MaxWorkersTest(unittest.TestCase): def test_cdsclient_internal(self): self.assertEqual(CdsClient.num_requests_per_key("reanalysis-era5-some-data"), 5) def test_cdsclient_mars_hosted(self): self.assertEqual(CdsClient.num_requests_per_key("reanalysis-carra-height-levels"), 2) def test_marsclient(self): self.assertEqual(MarsClient.num_requests_per_key("reanalysis-era5-some-data"), 2) def test_fakeclient(self): self.assertEqual(FakeClient.num_requests_per_key("reanalysis-era5-some-data"), 1) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_dl/download_pipeline/config.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import calendar import copy import dataclasses import itertools import typing as t Values = t.Union[t.List['Values'], t.Dict[str, 'Values'], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class Config: """Contains pipeline parameters. Attributes: config_name: Name of the config file. client: Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable. dataset (optional): Name of the target dataset. Allowed options are dictated by the client. partition_keys (optional): Choose the keys from the selection section to partition the data request. This will compute a cartesian cross product of the selected keys and assign each as their own download. target_path: Download artifact filename template. Can make use of Python's standard string formatting. It can contain format symbols to be replaced by partition keys; if this is used, the total number of format symbols must match the number of partition keys. subsection_name: Name of the particular subsection. 'default' if there is no subsection. force_download: Force redownload of partitions that were previously downloaded. user_id: Username from the environment variables. kwargs (optional): For representing subsections or any other parameters. selection: Contains parameters used to select desired data. 
""" config_name: str = "" client: str = "" dataset: t.Optional[str] = "" target_path: str = "" partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) subsection_name: str = "default" force_download: bool = False user_id: str = "unknown" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict) -> 'Config': config_instance = cls() for section_key, section_value in config.items(): if section_key == "parameters": for key, value in section_value.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value if section_key == "selection": config_instance.selection = section_value return config_instance def optimize_selection_partition(selection: t.Dict) -> t.Dict: """Compute right-hand-side values for the selection section of a single partition. Used to support custom syntax and optimizations, such as 'all'. """ selection_ = copy.deepcopy(selection) if 'date_range' in selection_.keys(): selection_['date'] = selection_['date_range'][0] del selection_['date_range'] if 'day' in selection_.keys() and selection_['day'] == 'all': years, months = selection_['year'], selection_['month'] multiples_error = "When using day='all' in selection, '/' is not allowed in {type}." if isinstance(years, str): years = [years] if isinstance(months, str): months = [months] date_ranges = [] # Generating dates for every year-month. for year, month in itertools.product(years, months): if isinstance(year, str): assert '/' not in year, multiples_error.format(type='year') if isinstance(month, str): assert '/' not in month, multiples_error.format(type='month') year, month = int(year), int(month) _, n_days_in_month = calendar.monthrange(year, month) date_range = [f'{year:04d}-{month:02d}-{day:02d}' for day in range(1, n_days_in_month + 1)] date_ranges.extend(date_range) selection_['date'] = date_ranges del selection_['day'] del selection_['month'] del selection_['year'] return selection_ ================================================ FILE: weather_dl/download_pipeline/config_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import calendar import itertools import os import unittest import weather_dl from .config import optimize_selection_partition from .pipeline import run class ConfigTest(unittest.TestCase): def setUp(self): self.data_dir = f'{next(iter(weather_dl.__path__))}/../configs' def test_process_config_files(self): for filename in os.listdir(self.data_dir): if filename.startswith('.'): continue # only files, no directories if os.path.isdir(os.path.join(self.data_dir, filename)): continue with self.subTest(filename=filename): config = os.path.join(self.data_dir, filename) try: run(["weather-dl", config, "--dry-run"]) except: # noqa: E722 self.fail(f'Config {filename!r} is incorrect.') def test_process_multi_config_files(self): for filename in os.listdir(self.data_dir): if filename.startswith('.'): continue # only directories, no files if not os.path.isdir(os.path.join(self.data_dir, filename)): continue with self.subTest(dirname=filename): configs_dir = os.path.join(self.data_dir, filename) configs = [os.path.join(configs_dir, c) for c in os.listdir(configs_dir)] try: run(["weather-dl"] + configs + ["--dry-run"]) except: # noqa: E722 self.fail(f'Configs {filename!r} is incorrect.') if __name__ == '__main__': unittest.main() class SelectionSyntaxTest(unittest.TestCase): def test_all_days__invalid_year(self): selection_with_multiple_years = {'year': '2020/2021', 'month': '2', 'day': 'all'} with self.assertRaisesRegex(AssertionError, "When using day='all' in selection, '/' is not allowed in year."): optimize_selection_partition(selection_with_multiple_years) selection_with_multiple_years = {'year': ['2020/2021'], 'month': '2', 'day': 'all'} with self.assertRaisesRegex(AssertionError, "When using day='all' in selection, '/' is not allowed in year."): optimize_selection_partition(selection_with_multiple_years) def test_all_days__invalid_month(self): selection_with_multiple_years = {'year': '2020', 'month': '1/2/3', 'day': 'all'} with self.assertRaisesRegex(AssertionError, "When using day='all' in selection, '/' is not allowed in month."): optimize_selection_partition(selection_with_multiple_years) selection_with_multiple_years = {'year': '2020', 'month': ['1/2/3'], 'day': 'all'} with self.assertRaisesRegex(AssertionError, "When using day='all' in selection, '/' is not allowed in month."): optimize_selection_partition(selection_with_multiple_years) def test_date_range(self): selection_with_date_range = {'date_range': ['2017-01-01/to/2017-01-10']} actual = optimize_selection_partition(selection_with_date_range) self.assertEqual(actual['date'], selection_with_date_range['date_range'][0]) def test_with_year_month_as_string(self): selection_with_multiple_months = {'year': '2017', 'month': '12', 'day':'all'} actual = optimize_selection_partition(selection_with_multiple_months) expected = [] _, n_days_in_month = calendar.monthrange(2017, 12) expected = [f'2017-12-{day:02d}' for day in range(1, n_days_in_month + 1)] self.assertEqual(actual['date'], expected) self.assertNotIn('day', actual) self.assertNotIn('month', actual) self.assertNotIn('year', actual) def test_with_multiple_months(self): selection_with_multiple_months = {'year': ['2017'], 'month': ['2', '4', '6', '8'], 'day':'all'} actual = optimize_selection_partition(selection_with_multiple_months) expected = [] for y, m in itertools.product(selection_with_multiple_months['year'], selection_with_multiple_months['month']): y, m = int(y), int(m) _, n_days_in_month = calendar.monthrange(y, m) date_range = [f'{y:04d}-{m:02d}-{day:02d}' for day in range(1, 
n_days_in_month + 1)] expected.extend(date_range) self.assertEqual(actual['date'], expected) self.assertNotIn('day', actual) self.assertNotIn('month', actual) self.assertNotIn('year', actual) def test_with_multiple_years(self): selection_with_multiple_years = {'year': ['2017', '2018'], 'month': ['2', '4', '6', '8'], 'day':'all'} actual = optimize_selection_partition(selection_with_multiple_years) expected = [] for y, m in itertools.product(selection_with_multiple_years['year'], selection_with_multiple_years['month']): y, m = int(y), int(m) _, n_days_in_month = calendar.monthrange(y, m) date_range = [f'{y:04d}-{m:02d}-{day:02d}' for day in range(1, n_days_in_month + 1)] expected.extend(date_range) self.assertEqual(actual['date'], expected) self.assertNotIn('day', actual) self.assertNotIn('month', actual) self.assertNotIn('year', actual) ================================================ FILE: weather_dl/download_pipeline/fetcher.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import datetime import logging import tempfile import typing as t import apache_beam as beam from .clients import CLIENTS, Client from .config import Config from .manifest import Manifest, NoOpManifest, Location, Stage from .parsers import prepare_target_name from .partition import skip_partition from .stores import Store, FSStore from .util import copy, retry_with_exponential_backoff logger = logging.getLogger(__name__) @dataclasses.dataclass class Fetcher(beam.DoFn): """Executes client download requests. Given a sequence of configs (keyed by subsection parameters and number of allowed requests), this will execute retrievals via each client. The keyed strucutre, the result of a `beam.GroupBy` operation, will ensure that all licenses and requests are utilized without any conflict. Attributes: client_name: The name of the download client to construct per each request. manifest: A manifest to keep track of the status of requests store: To manage where downloads are persisted. 
""" client_name: str manifest: Manifest = NoOpManifest(Location('noop://in-memory')) store: t.Optional[Store] = None log_level: t.Optional[int] = logging.INFO def __post_init__(self): if self.store is None: self.store = FSStore() @retry_with_exponential_backoff def retrieve(self, client: Client, dataset: str, selection: t.Dict, dest: str) -> None: """Retrieve from download client, with retries.""" client.retrieve(dataset, selection, dest, self.manifest) def fetch_data(self, config: Config, *, worker_name: str = 'default') -> None: """Download data from a client to a temp file, then upload to Cloud Storage.""" if not config: return if skip_partition(config, self.store, self.manifest): return client = CLIENTS[self.client_name](config, self.log_level) target = prepare_target_name(config) with tempfile.NamedTemporaryFile() as temp: logger.info(f'[{worker_name}] Fetching data for {target!r}.') with self.manifest.transact(config.config_name, config.dataset, config.selection, target, config.user_id): self.retrieve(client, config.dataset, config.selection, temp.name) self.manifest.set_stage(Stage.UPLOAD) precise_upload_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) self.manifest.prev_stage_precise_start_time = precise_upload_start_time logger.info(f'[{worker_name}] Uploading to store for {target!r}.') # In dry-run mode we actually aren't required to upload a file. if not self.client_name == "fake": copy(temp.name, target) logger.info(f'[{worker_name}] Upload to store complete for {target!r}.') def process(self, element) -> None: # element: Tuple[Tuple[str, int], Iterator[Config]] """Execute download requests one-by-one.""" (subsection, request_idx), partitions = element worker_name = f'{subsection}.{request_idx}' logger.info(f'[{worker_name}] Starting requests...') logger.debug(f'[{worker_name}] Partitions: {partitions!r}.') for partition in partitions: beam.metrics.Metrics.counter('Fetcher', subsection).inc() self.fetch_data(partition, worker_name=worker_name) ================================================ FILE: weather_dl/download_pipeline/fetcher_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io import json import os import tempfile import unittest from unittest.mock import patch from .config import Config from .fetcher import Fetcher from .manifest import MockManifest, Location from .stores import InMemoryStore class FetchDataTest(unittest.TestCase): def setUp(self) -> None: self.dummy_manifest = MockManifest(Location('dummy-manifest')) @patch('weather_dl.download_pipeline.clients.CDSClientExtended.download') @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') def test_fetch_data(self, mock_fetch, mock_download): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], 'target_path': os.path.join(tmpdir, 'download-{:02d}-{:02d}.nc'), 'api_url': 'https//api-url.com/v1/', 'api_key': '12345', }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['01'] } }) fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) self.assertTrue(os.path.exists(os.path.join(tmpdir, 'download-01-12.nc'))) mock_fetch.assert_called_with( config.selection, 'reanalysis-era5-pressure-levels', ) @patch('weather_dl.download_pipeline.clients.CDSClientExtended.download') @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') def test_fetch_data__manifest__returns_success(self, mock_fetch, mock_download): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], 'target_path': os.path.join(tmpdir, 'download-{:02d}-{:02d}.nc'), 'api_url': 'https//api-url.com/v1/', 'api_key': '12345', }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['01'] } }) fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) self.assertDictContainsSubset(dict( selection=json.dumps(config.selection), location=os.path.join(tmpdir, 'download-01-12.nc'), stage='upload', status='success', error=None, username='unknown', ), list(self.dummy_manifest.records.values())[0]) @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') def test_fetch_data__manifest__records_retrieve_failure(self, mock_fetch): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], 'target_path': os.path.join(tmpdir, 'download-{:02d}-{:02d}.nc'), 'api_url': 'https//api-url.com/v1/', 'api_key': '12345', }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['01'] } }) error = IOError("We don't have enough permissions to download this.") mock_fetch.side_effect = error with self.assertRaises(IOError) as e: fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) actual = list(self.dummy_manifest.records.values())[0] self.assertDictContainsSubset(dict( selection=json.dumps(config.selection), location=os.path.join(tmpdir, 'download-01-12.nc'), stage='fetch', status='failure', username='unknown', ), actual) self.assertIn(error.args[0], actual['error']) self.assertIn(error.args[0], e.exception.args[0]) @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') def test_fetch_data__manifest__records_gcs_failure(self, mock_fetch): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], 'target_path': 
os.path.join(tmpdir, 'download-{:02d}-{:02d}.nc'), 'api_url': 'https//api-url.com/v1/', 'api_key': '12345', }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['01'] } }) error = IOError("Can't open gcs file.") mock_fetch.side_effect = error with self.assertRaises(IOError) as e: fetcher = Fetcher('cds', self.dummy_manifest, InMemoryStore()) fetcher.fetch_data(config) actual = list(self.dummy_manifest.records.values())[0] self.assertDictContainsSubset(dict( selection=json.dumps(config.selection), location=os.path.join(tmpdir, 'download-01-12.nc'), stage='fetch', status='failure', username='unknown', ), actual) self.assertIn(error.args[0], actual['error']) self.assertIn(error.args[0], e.exception.args[0]) @patch('weather_dl.download_pipeline.stores.InMemoryStore.open', return_value=io.StringIO()) @patch('weather_dl.download_pipeline.clients.CDSClientExtended.fetch') def test_fetch_data__skips_existing_download(self, mock_fetch, mock_gcs_file): with tempfile.TemporaryDirectory() as tmpdir: config = Config.from_dict({ 'parameters': { 'dataset': 'reanalysis-era5-pressure-levels', 'partition_keys': ['year', 'month'], 'target_path': os.path.join(tmpdir, 'download-{:02d}-{:02d}.nc'), 'api_url': 'https//api-url.com/v1/', 'api_key': '12345', }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['01'] } }) # target file already exists in store... store = InMemoryStore() store.store[os.path.join(tmpdir, 'download-01-12.nc')] = '' fetcher = Fetcher('cds', self.dummy_manifest, store) fetcher.fetch_data(config) self.assertFalse(mock_gcs_file.called) self.assertFalse(mock_fetch.called) ================================================ FILE: weather_dl/download_pipeline/manifest.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Client interface for connecting to a manifest.""" import abc import collections import dataclasses import datetime import enum import json import logging import os import pandas as pd import threading import time import traceback import typing as t from urllib.parse import urlparse, parse_qsl from .util import ( to_json_serializable_type, fetch_geo_polygon, get_file_size, get_wait_interval, generate_md5_hash, retry_with_exponential_backoff, GLOBAL_COVERAGE_AREA ) import firebase_admin from firebase_admin import firestore from google.cloud import bigquery from google.cloud.firestore_v1 import DocumentReference from google.cloud.firestore_v1.types import WriteResult """An implementation-dependent Manifest URI.""" Location = t.NewType('Location', str) logger = logging.getLogger(__name__) class ManifestException(Exception): """Errors that occur in Manifest Clients.""" pass class Stage(enum.Enum): """A request can be either in one of the following stages at a time: fetch : This represents request is currently in fetch stage i.e. request placed on the client's server & waiting for some result before starting download (eg. MARS client). 
download : This represents request is currently in download stage i.e. data is being downloading from client's server to the worker's local file system. upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local file system to target location (GCS path). retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), request will be in the retrieve stage i.e. fetch + download. """ RETRIEVE = 'retrieve' FETCH = 'fetch' DOWNLOAD = 'download' UPLOAD = 'upload' class Status(enum.Enum): """Depicts the request's state status: scheduled : A request partition is created & scheduled for processing. Note: Its corresponding state can be None only. in-progress : This represents the request state is currently in-progress (i.e. running). The next status would be "success" or "failure". success : This represents the request state execution completed successfully without any error. failure : This represents the request state execution failed. """ SCHEDULED = 'scheduled' IN_PROGRESS = 'in-progress' SUCCESS = 'success' FAILURE = 'failure' @dataclasses.dataclass class DownloadStatus(): """Data recorded in `Manifest`s reflecting the status of a download.""" """The name of the config file associated with the request.""" config_name: str = "" """Represents the dataset field of the configuration.""" dataset: t.Optional[str] = "" """Copy of selection section of the configuration.""" selection: t.Dict = dataclasses.field(default_factory=dict) """Location of the downloaded data.""" location: str = "" """Represents area covered by the shard.""" area: str = "" """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" stage: t.Optional[Stage] = None """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" status: t.Optional[Status] = None """Cause of error, if any.""" error: t.Optional[str] = "" """Identifier for the user running the download.""" username: str = "" """Shard size in GB.""" size: t.Optional[float] = 0 """A UTC datetime when download was scheduled.""" scheduled_time: t.Optional[str] = "" """A UTC datetime when the retrieve stage starts.""" retrieve_start_time: t.Optional[str] = "" """A UTC datetime when the retrieve state ends.""" retrieve_end_time: t.Optional[str] = "" """A UTC datetime when the fetch state starts.""" fetch_start_time: t.Optional[str] = "" """A UTC datetime when the fetch state ends.""" fetch_end_time: t.Optional[str] = "" """A UTC datetime when the download state starts.""" download_start_time: t.Optional[str] = "" """A UTC datetime when the download state ends.""" download_end_time: t.Optional[str] = "" """A UTC datetime when the upload state starts.""" upload_start_time: t.Optional[str] = "" """A UTC datetime when the upload state ends.""" upload_end_time: t.Optional[str] = "" @classmethod def from_dict(cls, download_status: t.Dict) -> 'DownloadStatus': """Instantiate DownloadStatus dataclass from dict.""" download_status_instance = cls() for key, value in download_status.items(): if key == 'status': setattr(download_status_instance, key, Status(value)) elif key == 'stage' and value is not None: setattr(download_status_instance, key, Stage(value)) else: setattr(download_status_instance, key, value) return download_status_instance @classmethod def to_dict(cls, instance) -> t.Dict: """Return the fields of a dataclass instance as a manifest ingestible dictionary mapping of field names to field values.""" download_status_dict = {} for field in 
dataclasses.fields(instance): key = field.name value = getattr(instance, field.name) if isinstance(value, Status) or isinstance(value, Stage): download_status_dict[key] = value.value elif isinstance(value, pd.Timestamp): download_status_dict[key] = value.isoformat() elif key == 'selection' and value is not None: download_status_dict[key] = json.dumps(value) else: download_status_dict[key] = value return download_status_dict @dataclasses.dataclass class Manifest(abc.ABC): """Abstract manifest of download statuses. Update download statuses to some storage medium. This class lets one indicate that a download is `scheduled` or in a transaction process. In the event of a transaction, a download will be updated with an `in-progress`, `success` or `failure` status (with accompanying metadata). Example: ``` my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) # Schedule data for download my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') # ... # Initiate a transaction – it will record that the download is `in-progess` with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: # download logic here pass # ... # on error, will record the download as a `failure` before propagating the error. By default, it will # record download as a `success`. ``` Attributes: location: An implementation-specific manifest URI. status: The current `DownloadStatus` of the Manifest. """ location: Location # To reduce the impact of _read() and _update() calls # on the start time of the stage. prev_stage_precise_start_time: t.Optional[str] = None status: t.Optional[DownloadStatus] = None # This is overridden in subclass. def __post_init__(self): """Initialize the manifest.""" pass def schedule(self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str) -> None: """Indicate that a job has been scheduled for download. 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. """ scheduled_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).isoformat(timespec='seconds') self.status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get('area', GLOBAL_COVERAGE_AREA)), username=user, stage=None, status=Status.SCHEDULED, error=None, size=None, scheduled_time=scheduled_time, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=None, upload_end_time=None, ) self._update(self.status) def skip(self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str) -> None: """Updates the manifest to mark the shards that were skipped in the current job as 'upload' stage and 'success' status, indicating that they have already been downloaded. """ old_status = self._read(location) # The manifest needs to be updated for a skipped shard if its entry is not present, or # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. 
if old_status.location != location or old_status.stage != Stage.UPLOAD or old_status.status != Status.SUCCESS: current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) size = get_file_size(location) status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get('area', GLOBAL_COVERAGE_AREA)), username=user, stage=Stage.UPLOAD, status=Status.SUCCESS, error=None, size=size, scheduled_time=None, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=current_utc_time, upload_end_time=current_utc_time, ) self._update(status) logger.info(f'Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}.') def _set_for_transaction(self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str) -> None: """Reset Manifest state in preparation for a new transaction.""" self.status = dataclasses.replace(self._read(location)) self.status.config_name = config_name self.status.dataset = dataset if dataset else None self.status.selection = selection self.status.location = location self.status.username = user def __enter__(self) -> None: pass def __exit__(self, exc_type, exc_inst, exc_tb) -> None: """Record end status of a transaction as either 'success' or 'failure'.""" if exc_type is None: status = Status.SUCCESS error = None else: status = Status.FAILURE # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception error = '\n'.join(traceback.format_exception(exc_type, exc_inst, exc_tb)) new_status = dataclasses.replace(self.status) new_status.error = error new_status.status = status current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) # This is necessary for setting the precise start time of the previous stage # and end time of the final stage, as well as handling the case of Status.FAILURE. 
if new_status.stage == Stage.FETCH: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time elif new_status.stage == Stage.RETRIEVE: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time elif new_status.stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.upload_start_time = self.prev_stage_precise_start_time new_status.upload_end_time = current_utc_time new_status.size = get_file_size(new_status.location) self.status = new_status self._update(self.status) def transact(self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str) -> 'Manifest': """Create a download transaction.""" self._set_for_transaction(config_name, dataset, selection, location, user) return self def set_stage(self, stage: Stage) -> None: """Sets the current stage in manifest.""" prev_stage = self.status.stage new_status = dataclasses.replace(self.status) new_status.stage = stage new_status.status = Status.IN_PROGRESS current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec='seconds') ) if stage == Stage.FETCH: new_status.fetch_start_time = current_utc_time new_status.fetch_end_time = None new_status.download_start_time = None new_status.download_end_time = None elif stage == Stage.RETRIEVE: new_status.retrieve_start_time = current_utc_time elif stage == Stage.DOWNLOAD: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time new_status.download_start_time = current_utc_time else: if prev_stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time new_status.upload_start_time = current_utc_time self.status = new_status self._update(self.status) @abc.abstractmethod def _read(self, location: str) -> DownloadStatus: pass @abc.abstractmethod def _update(self, download_status: DownloadStatus) -> None: pass class ConsoleManifest(Manifest): def __post_init__(self): self.name = urlparse(self.location).hostname def _read(self, location: str) -> DownloadStatus: return DownloadStatus() def _update(self, download_status: DownloadStatus) -> None: logger.info(f'[{self.name}] {DownloadStatus.to_dict(download_status)!r}') class LocalManifest(Manifest): """Writes a JSON representation of the manifest to local file.""" _lock = threading.Lock() def __init__(self, location: Location) -> None: super().__init__(Location(os.path.join(location, 'manifest.json'))) if location and not os.path.exists(location): os.makedirs(location) # If the file is empty, it should start out as an empty JSON object. if not os.path.exists(self.location) or os.path.getsize(self.location) == 0: with open(self.location, 'w') as file: json.dump({}, file) def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" assert os.path.exists(self.location), f'{self.location} must exist!' with LocalManifest._lock: with open(self.location, 'r') as file: manifest = json.load(file) return DownloadStatus.from_dict(manifest.get(location, {})) def _update(self, download_status: DownloadStatus) -> None: """Writes the JSON data to a manifest.""" assert os.path.exists(self.location), f'{self.location} must exist!' 
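# The whole manifest file is read, modified, and rewritten under a class-level lock so that concurrent updates from multiple threads do not clobber one another.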
with LocalManifest._lock: with open(self.location, 'r') as file: manifest = json.load(file) status = DownloadStatus.to_dict(download_status) manifest[status['location']] = status with open(self.location, 'w') as file: json.dump(manifest, file) logger.debug('Manifest written to.') logger.debug(download_status) class BQManifest(Manifest): """Writes a JSON representation of the manifest to BQ file. This is an append-only implementation, the latest value in the manifest represents the current state of a download. """ def __init__(self, location: Location) -> None: super().__init__(Location(location[5:])) TABLE_SCHEMA = [ bigquery.SchemaField('config_name', 'STRING', mode='REQUIRED', description="The name of the config file associated with the request."), bigquery.SchemaField('dataset', 'STRING', mode='NULLABLE', description="Represents the dataset field of the configuration."), bigquery.SchemaField('selection', 'JSON', mode='REQUIRED', description="Copy of selection section of the configuration."), bigquery.SchemaField('location', 'STRING', mode='REQUIRED', description="Location of the downloaded data."), bigquery.SchemaField('area', 'STRING', mode='NULLABLE', description="Represents area covered by the shard. " "ST_GeogFromGeoJson(area): To convert a GeoJSON geometry object into a " "GEOGRAPHY value. " "ST_COVERS(geography_expression, ST_GEOGPOINT(longitude, latitude)): To check " "if a point lies in the given area or not."), bigquery.SchemaField('stage', 'STRING', mode='NULLABLE', description="Current stage of request : 'fetch', 'download', 'retrieve', 'upload' " "or None."), bigquery.SchemaField('status', 'STRING', mode='REQUIRED', description="Download status: 'scheduled', 'in-progress', 'success', or 'failure'."), bigquery.SchemaField('error', 'STRING', mode='NULLABLE', description="Cause of error, if any."), bigquery.SchemaField('username', 'STRING', mode='REQUIRED', description="Identifier for the user running the download."), bigquery.SchemaField('size', 'FLOAT', mode='NULLABLE', description="Shard size in GB."), bigquery.SchemaField('scheduled_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when download was scheduled."), bigquery.SchemaField('retrieve_start_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the retrieve stage starts."), bigquery.SchemaField('retrieve_end_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the retrieve state ends."), bigquery.SchemaField('fetch_start_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the fetch state starts."), bigquery.SchemaField('fetch_end_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the fetch state ends."), bigquery.SchemaField('download_start_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the download state starts."), bigquery.SchemaField('download_end_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the download state ends."), bigquery.SchemaField('upload_start_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the upload state starts."), bigquery.SchemaField('upload_end_time', 'TIMESTAMP', mode='NULLABLE', description="A UTC datetime when the upload state ends."), ] table = bigquery.Table(self.location, schema=TABLE_SCHEMA) with bigquery.Client() as client: client.create_table(table, exists_ok=True) def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" with bigquery.Client() as client: select_statement = f"SELECT * FROM 
{self.location} WHERE location = @location" # Build the QueryJobConfig object with the parameters. job_config = bigquery.QueryJobConfig() job_config.query_parameters = [bigquery.ScalarQueryParameter('location', 'STRING', location)] # Execute the merge statement with the parameters. query_job = client.query(select_statement, job_config=job_config) # Wait for the query to execute. result = query_job.result() row = {} if result.total_rows > 0: records = result.to_dataframe().to_dict('records') row = {n: to_json_serializable_type(v) for n, v in records[0].items()} return DownloadStatus.from_dict(row) # Added retry here to handle the concurrency issue in BigQuery. # Eg: 400 Resources exceeded during query execution: Too many DML statements outstanding # against table , limit is 20 @retry_with_exponential_backoff def _update(self, download_status: DownloadStatus) -> None: """Writes the JSON data to a manifest.""" with bigquery.Client() as client: status = DownloadStatus.to_dict(download_status) table = client.get_table(self.location) columns = [field.name for field in table.schema] parameter_type_mapping = {field.name: field.field_type for field in table.schema} update_dml = [f"{col} = @{col}" for col in columns] insert_dml = [f"@{col}" for col in columns] params = {col: status[col] for col in columns} # Build the merge statement as a string with parameter placeholders. merge_statement = f""" MERGE {self.location} T USING ( SELECT @location as location ) S ON T.location = S.location WHEN MATCHED THEN UPDATE SET {', '.join(update_dml)} WHEN NOT MATCHED THEN INSERT ({", ".join(columns)}) VALUES ({', '.join(insert_dml)}) """ logger.debug(merge_statement) # Build the QueryJobConfig object with the parameters. job_config = bigquery.QueryJobConfig() job_config.query_parameters = [bigquery.ScalarQueryParameter(col, parameter_type_mapping[col], value) for col, value in params.items()] # Execute the merge statement with the parameters. query_job = client.query(merge_statement, job_config=job_config) # Wait for the query to execute. query_job.result() logger.debug('Manifest written to.') logger.debug(download_status) class FirestoreManifest(Manifest): """A Firestore Manifest. This Manifest implementation stores DownloadStatuses in a Firebase document store. The document hierarchy for the manifest is as follows: [manifest ] ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } └── etc... Where `[]` indicates a collection and ` {...}` indicates a document. """ def _get_db(self) -> firestore.firestore.Client: """Acquire a firestore client, initializing the firebase app if necessary. Will attempt to get the db client five times. If it's still unsuccessful, a `ManifestException` will be raised. """ db = None attempts = 0 while db is None: try: db = firestore.client() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. 
firebase_admin.initialize_app(options=self.get_firestore_config()) logger.info('Initialized Firebase App.') if attempts > 4: raise ManifestException('Exceeded number of retries to get firestore client.') from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" doc_id = generate_md5_hash(location) # Update document with download status download_doc_ref = ( self.root_document_for_store(doc_id) ) result = download_doc_ref.get() row = {} if result.exists: records = result.to_dict() row = {n: to_json_serializable_type(v) for n, v in records.items()} return DownloadStatus.from_dict(row) def _update(self, download_status: DownloadStatus) -> None: """Update or create a download status record.""" logger.debug('Updating Firestore Manifest.') status = DownloadStatus.to_dict(download_status) doc_id = generate_md5_hash(status['location']) # Update document with download status download_doc_ref = ( self.root_document_for_store(doc_id) ) result: WriteResult = download_doc_ref.set(status) logger.debug(f'Firestore manifest updated. ' f'update_time={result.update_time}, ' f'filename={download_status.location}.') def root_document_for_store(self, store_scheme: str) -> DocumentReference: """Get the root manifest document given the user's config and current document's storage location.""" # Get user-defined collection for manifest. root_collection = self.get_firestore_config().get('collection', 'manifest') return self._get_db().collection(root_collection).document(store_scheme) def get_firestore_config(self) -> t.Dict: """Parse firestore Location format: 'fs://?projectId=' Users must specify a 'projectId' query parameter in the firestore location. If this argument isn't passed in, users must set the `GOOGLE_CLOUD_PROJECT` environment variable. Users may specify options to `firebase_admin.initialize_app()` via query arguments in the URL. For more information about what options are available, consult this documentation: https://firebase.google.com/docs/reference/admin/python/firebase_admin#initialize_app Note: each query key-value pair may only appear once. If there are duplicates, the last pair will be used. Optionally, users may configure these options via the `FIREBASE_CONFIG` environment variable, which is typically a path/to/a/file.json. Examples: >>> location = Location("fs://my-collection?projectId=my-project-id&storageBucket=foo") >>> FirestoreManifest(location).get_firestore_config() {'collection': 'my-collection', 'projectId': 'my-project-id', 'storageBucket': 'foo'} Raises: ValueError: If query parameters are malformed. AssertionError: If the 'projectId' query parameter is not set. 
""" parsed = urlparse(self.location) query_params = {} if parsed.query: query_params = dict(parse_qsl(parsed.query, strict_parsing=True)) return {'collection': parsed.netloc, **query_params} class MockManifest(Manifest): """In-memory mock manifest.""" def __init__(self, location: Location) -> None: super().__init__(location) self.records = {} def _read(self, location: str) -> DownloadStatus: manifest = self.records return DownloadStatus.from_dict(manifest.get(location, {})) def _update(self, download_status: DownloadStatus) -> None: status = DownloadStatus.to_dict(download_status) self.records.update({status.get('location'): status}) logger.debug('Manifest updated.') logger.debug(download_status) class NoOpManifest(Manifest): """A manifest that performs no operations.""" def _read(self, location: str) -> DownloadStatus: return DownloadStatus() def _update(self, download_status: DownloadStatus) -> None: pass """Exposed manifest implementations. Users can choose their preferred manifest implementation by via the protocol of the Manifest Location. The protocol corresponds to the keys of this ordered dictionary. If no protocol is specified, we assume the user wants to write to the local file system. If no key is found, the `NoOpManifest` option will be chosen. See `parsers:parse_manifest_location`. """ MANIFESTS = collections.OrderedDict({ 'cli': ConsoleManifest, 'fs': FirestoreManifest, 'bq': BQManifest, '': LocalManifest, }) if __name__ == '__main__': # Execute doc tests import doctest doctest.testmod() ================================================ FILE: weather_dl/download_pipeline/manifest_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import random import string import tempfile import typing as t import unittest from .manifest import LocalManifest, Location, DownloadStatus, Status, Stage def rand_str(max_len=32): return ''.join([random.choice(string.printable) for _ in range(random.randint(0, max_len))]) def make_download_status(location: t.Optional[str] = None) -> DownloadStatus: return DownloadStatus( selection={}, location=rand_str() if location is None else location, status=random.choice([Status.SCHEDULED, Status.IN_PROGRESS, Status.SUCCESS, Status.FAILURE]), error=random.choice([None] + [rand_str(100) for _ in range(4)]), username=random.choice(['user', 'alice', 'bob', 'root']), stage=random.choice([Stage.FETCH, Stage.DOWNLOAD, Stage.UPLOAD, Stage.RETRIEVE, None]), size=0.039322572, scheduled_time="2023-02-07T17:15:26+00:00", retrieve_start_time=None, retrieve_end_time=None, fetch_start_time="2023-02-07T17:15:29+00:00", fetch_end_time="2023-02-07T17:21:37+00:00", download_start_time="2023-02-07T17:21:39+00:00", download_end_time="2023-02-07T17:21:56+00:00", upload_start_time="2023-02-07T17:21:59+00:00", upload_end_time="2023-02-07T17:22:03+00:00" ) class LocalManifestTest(unittest.TestCase): NUM_RUNS = 128 def test_empty_manifest_is_valid_json(self): with tempfile.TemporaryDirectory() as dir_: manifest = LocalManifest(Location(dir_)) with open(manifest.location) as file: self.is_valid_json(file) def does_not_overwrite_existing_manifest(self): with tempfile.TemporaryDirectory() as dir_: with open(f'{dir}/manifest.json', 'w') as file: json.dump({'foo': 'bar'}, file) manifest = LocalManifest(Location(dir_)) with open(manifest.location, 'r') as file: manifest = json.load(file) self.assertIn('foo', manifest) self.assertEqual(manifest['foo'], 'bar') def test_writes_valid_json(self): with tempfile.TemporaryDirectory() as dir_: manifest = LocalManifest(Location(dir_)) for _ in range(self.NUM_RUNS): status = make_download_status() manifest._update(status) with open(manifest.location) as file: self.is_valid_json(file) def test_overwrites_existing_statuses(self): locations = ['a', 'b', 'c'] with tempfile.TemporaryDirectory() as dir_: manifest = LocalManifest(Location(dir_)) for i in range(self.NUM_RUNS): status = make_download_status(location=locations[i % 3]) manifest._update(status) with open(manifest.location) as file: self.is_valid_json(file) with open(manifest.location) as file: manifest = json.load(file) self.assertEqual(set(locations), set(manifest.keys())) def is_valid_json(self, file: t.IO) -> None: """Fails test on error decoding JSON.""" try: json.dumps(json.load(file)) except json.JSONDecodeError: self.fail('JSON is invalid.') ================================================ FILE: weather_dl/download_pipeline/parsers.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
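# A sketch of the kind of configuration these parsers consume (illustrative
# values only; see Configuration.md for the authoritative schema):
#
#   [parameters]
#   client=cds
#   dataset=reanalysis-era5-pressure-levels
#   target_path=gs://my-bucket/era5-{year}-{month}.nc
#   partition_keys=
#       year
#       month
#
#   [selection]
#   pressure_level=500
#   year=2020/to/2021
#   month=01/to/12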
"""Parsers for ECMWF download configuration.""" import ast import configparser import copy as cp import datetime import json import string import textwrap import typing as t import numpy as np from collections import OrderedDict from dateutil.relativedelta import relativedelta from urllib.parse import urlparse from .clients import CLIENTS from .config import Config from .manifest import MANIFESTS, Manifest, Location, NoOpManifest def date(candidate: str) -> datetime.date: """Converts ECMWF-format date strings into a `datetime.date`. Accepted absolute date formats: - YYYY-MM-DD - YYYYMMDD - YYYY-DDD, where DDD refers to the day of the year For example: - 2021-10-31 - 19700101 - 1950-007 See https://confluence.ecmwf.int/pages/viewpage.action?pageId=118817289 for date format spec. Note: Name of month is not supported. """ converted = None # Parse relative day value. if candidate.startswith('-'): return datetime.date.today() + datetime.timedelta(days=int(candidate)) accepted_formats = ["%Y-%m-%d", "%Y%m%d", "%Y-%j"] for fmt in accepted_formats: try: converted = datetime.datetime.strptime(candidate, fmt).date() break except ValueError: pass if converted is None: raise ValueError( f"Not a valid date: '{candidate}'. Please use valid relative or absolute format." ) return converted def time(candidate: str) -> datetime.time: """Converts ECMWF-format time strings into a `datetime.time`. Accepted time formats: - HH:MM - HHMM - HH For example: - 18:00 - 1820 - 18 Note: If MM is omitted it defaults to 00. """ converted = None accepted_formats = ["%H", "%H:%M", "%H%M"] for fmt in accepted_formats: try: converted = datetime.datetime.strptime(candidate, fmt).time() break except ValueError: pass if converted is None: raise ValueError( f"Not a valid time: '{candidate}'. Please use valid format." ) return converted def day_month_year(candidate: t.Any) -> int: """Converts day, month and year strings into 'int'.""" try: if isinstance(candidate, str) or isinstance(candidate, int): return int(candidate) raise ValueError('must be a str or int.') except ValueError as e: raise ValueError( f"Not a valid day, month, or year value: {candidate}. Please use valid value." ) from e def date_range_converter(candidate: str) -> str: """Replace / with _ to avoid directory creation.""" return candidate.replace('/', '_') def parse_literal(candidate: t.Any) -> t.Any: try: # Support parsing ints with leading zeros, e.g. '01' if isinstance(candidate, str) and candidate.isdigit(): return int(candidate) return ast.literal_eval(candidate) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): return candidate def validate(key: str, value: int) -> None: """Validates value based on the key.""" if key == "day": assert 1 <= value <= 31, "Day value must be between 1 to 31." if key == "month": assert 1 <= value <= 12, "Month value must be between 1 to 12." 
def typecast(key: str, value: t.Any) -> t.Any: """Type the value to its appropriate datatype.""" SWITCHER = { 'date': date, 'time': time, 'day': day_month_year, 'month': day_month_year, 'year': day_month_year, 'date_range': date_range_converter, } converted = SWITCHER.get(key, parse_literal)(value) validate(key, converted) return converted def _read_config_file(file: t.IO) -> t.Dict: """Reads `*.json` or `*.cfg` files.""" try: return json.load(file) except json.JSONDecodeError: pass file.seek(0) try: config = configparser.ConfigParser() config.read_file(file) config = {s: dict(config.items(s)) for s in config.sections()} return config except configparser.ParsingError: return {} def parse_config(file: t.IO) -> t.Dict: """Parses a `*.json` or `*.cfg` file into a configuration dictionary.""" config = _read_config_file(file) config_by_section = {s: _parse_lists(v, s) for s, v in config.items()} config_with_nesting = parse_subsections(config_by_section) return config_with_nesting def parse_manifest(location: Location, pipeline_opts: t.Dict) -> Manifest: """Constructs a manifest object by parsing the location.""" project_id__exists = 'project' in pipeline_opts project_id__not_set = 'projectId' not in location # If the firestore location doesn't specify which project (and, the pipeline # knows which project)... if location.startswith('fs://') and project_id__not_set and project_id__exists: # ...Set the project query param in the Firestore URI. start_char = '&' if '?' in location else '?' project = pipeline_opts.get('project') location += f'{start_char}projectId={project}' parsed = urlparse(location) return MANIFESTS.get(parsed.scheme, NoOpManifest)(location) def _splitlines(block: str) -> t.List[str]: """Converts a multi-line block into a list of strings.""" return [line.strip() for line in block.strip().splitlines()] def mars_range_value(token: str, key: str) -> t.Union[datetime.date, int, float]: """Converts a range token into either a date, int, or float.""" # TODO(b/175432034): Recognize time values try: if key == 'year-month': return datetime.datetime.strptime(token, "%Y-%m").date() else: return date(token) except ValueError: pass if token.isdecimal(): return int(token) try: return float(token) except ValueError: raise ValueError("Token string must be an 'int', 'float', or 'datetime.date()'.") def mars_increment_value(token: str) -> t.Union[int, float]: """Converts an increment token into either an int or a float.""" try: return int(token) except ValueError: pass try: return float(token) except ValueError: raise ValueError("Token string must be an 'int' or a 'float'.") def parse_mars_syntax(block: str, key: str) -> t.List[str]: """Parses MARS list or range into a list of arguments; ranges are inclusive. Types for the range and value are inferred. Examples: >>> parse_mars_syntax("10/to/12") ['10', '11', '12'] >>> parse_mars_syntax("12/to/10/by/-1") ['12', '11', '10'] >>> parse_mars_syntax("0.0/to/0.5/by/0.1") ['0.0', '0.1', '0.2', '0.30000000000000004', '0.4', '0.5'] >>> parse_mars_syntax("2020-01-07/to/2020-01-14/by/2") ['2020-01-07', '2020-01-09', '2020-01-11', '2020-01-13'] >>> parse_mars_syntax("2020-01-14/to/2020-01-07/by/-2") ['2020-01-14', '2020-01-12', '2020-01-10', '2020-01-08'] Returns: A list of strings representing a range from start to finish, based on the type of the values in the range. If all range values are integers, it will return a list of strings of integers. If range values are floats, it will return a list of strings of floats. 
If the range values are dates, it will return a list of strings of dates in YYYY-MM-DD format. (Note: here, the increment value should be an integer). """ # Split into tokens, omitting empty strings. tokens = [b.strip() for b in block.split('/') if b != ''] # Return list if no range operators are present. if 'to' not in tokens and 'by' not in tokens: return tokens # Parse range values, honoring 'to' and 'by' operators. try: to_idx = tokens.index('to') assert to_idx != 0, "There must be a start token." start_token, end_token = tokens[to_idx - 1], tokens[to_idx + 1] start, end = mars_range_value(start_token, key), mars_range_value(end_token, key) # Parse increment token, or choose default increment. increment_token = '1' increment = 1 if 'by' in tokens: increment_token = tokens[tokens.index('by') + 1] increment = mars_increment_value(increment_token) except (AssertionError, IndexError, ValueError): raise SyntaxError(f"Improper range syntax in '{block}'.") # Return a range of values with appropriate data type. if (key == 'year-month' and isinstance(start, datetime.date) and isinstance(end, datetime.date) and isinstance(increment, int)): result = [] offset = 1 if start <= end else -1 if increment >= 0: increment *= offset # ensure increment has correct direction current = start while current <= end if offset > 0 else current >= end: result.append(current.strftime("%Y-%m")) current += relativedelta(months=increment) return result elif isinstance(start, datetime.date) and isinstance(end, datetime.date) and key != 'year-month': increment *= -1 if start > end and increment > 0 else 1 if not isinstance(increment, int): raise ValueError( f"Increments on a date range must be integer number of days, '{increment_token}' is invalid." ) return [d.strftime("%Y-%m-%d") for d in date_range(start, end, increment)] elif (isinstance(start, float) or isinstance(end, float)) and not isinstance(increment, datetime.date): # Increment can be either an int or a float. _round_places = 4 return [str(round(x, _round_places)).zfill(len(start_token)) for x in np.arange(start, end + increment, increment)] elif isinstance(start, int) and isinstance(end, int) and isinstance(increment, int): # Honor leading zeros. offset = 1 if start <= end else -1 return [str(x).zfill(len(start_token)) for x in range(start, end + offset, increment)] else: raise ValueError( f"Range tokens (start='{start_token}', end='{end_token}', increment='{increment_token}')" f" are inconsistent types." 
) def date_range(start: datetime.date, end: datetime.date, increment: int = 1) -> t.Iterable[datetime.date]: """Gets a range of dates, inclusive.""" offset = 1 if start <= end else -1 return (start + datetime.timedelta(days=x) for x in range(0, (end - start).days + offset, increment)) def _parse_lists(config: dict, section: str = '') -> t.Dict: """Parses multiline blocks in *.cfg and *.json files as lists.""" for key, val in config.items(): # Checks str type for backward compatibility since it also support "padding": 0 in json config if not isinstance(val, str): continue if '/' in val and 'parameters' not in section and key != 'date_range': config[key] = parse_mars_syntax(val, key) elif '\n' in val: config[key] = _splitlines(val) return config def _number_of_replacements(s: t.Text): format_names = [v[1] for v in string.Formatter().parse(s) if v[1] is not None] num_empty_names = len([empty for empty in format_names if empty == '']) if num_empty_names != 0: num_empty_names -= 1 return len(set(format_names)) + num_empty_names def parse_subsections(config: t.Dict) -> t.Dict: """Interprets [section.subsection] as nested dictionaries in `.cfg` files.""" copy = cp.deepcopy(config) for key, val in copy.items(): path = key.split('.') runner = copy parent = {} p = None for p in path: if p not in runner: runner[p] = {} parent = runner runner = runner[p] parent[p] = val for_cleanup = [key for key, _ in copy.items() if '.' in key] for target in for_cleanup: del copy[target] return copy def require(condition: bool, message: str, error_type: t.Type[Exception] = ValueError) -> None: """A assert-like helper that wraps text and throws an error.""" if not condition: raise error_type(textwrap.dedent(message)) def process_config(file: t.IO, config_name: str) -> Config: """Read the config file and prompt the user if it is improperly structured.""" config = parse_config(file) require(bool(config), "Unable to parse configuration file.") require('parameters' in config, """ 'parameters' section required in configuration file. The 'parameters' section specifies the 'client', 'dataset', 'target_path', and 'partition_key' for the API client. Please consult the documentation for more information.""") params = config.get('parameters', {}) require('target_template' not in params, """ 'target_template' is deprecated, use 'target_path' instead. Please consult the documentation for more information.""") require('target_path' in params, """ 'parameters' section requires a 'target_path' key. The 'target_path' is used to format the name of the output files. It accepts Python 3.5+ string format symbols (e.g. '{}'). The number of symbols should match the length of the 'partition_keys', as the 'partition_keys' args are used to create the templates.""") require('client' in params, """ 'parameters' section requires a 'client' key. Supported clients are {} """.format(str(list(CLIENTS.keys())))) require(params.get('client') in CLIENTS.keys(), """ Invalid 'client' parameter. Supported clients are {} """.format(str(list(CLIENTS.keys())))) require('append_date_dirs' not in params, """ The current version of 'google-weather-tools' no longer supports 'append_date_dirs'! Please refer to documentation for creating date-based directory hierarchy : https://weather-tools.readthedocs.io/en/latest/Configuration.html#""" """creating-a-date-based-directory-hierarchy.""", NotImplementedError) require('target_filename' not in params, """ The current version of 'google-weather-tools' no longer supports 'target_filename'! 
Please refer to documentation : https://weather-tools.readthedocs.io/en/latest/Configuration.html#parameters-section.""", NotImplementedError) partition_keys = params.get('partition_keys', list()) if isinstance(partition_keys, str): partition_keys = [partition_keys.strip()] selection = config.get('selection', dict()) require(all((key in selection for key in partition_keys)), """ All 'partition_keys' must appear in the 'selection' section. 'partition_keys' specify how to split data for workers. Please consult documentation for more information.""") num_template_replacements = _number_of_replacements(params['target_path']) num_partition_keys = len(partition_keys) require(num_template_replacements == num_partition_keys, """ 'target_path' has {0} replacements. Expected {1}, since there are {1} partition keys. """.format(num_template_replacements, num_partition_keys)) if 'day' in partition_keys: require(selection['day'] != 'all', """If 'all' is used for a selection value, it cannot appear as a partition key.""") if 'hdate' in selection: require('date' in partition_keys, """"If 'hdate' is specified in the 'selection' section, then 'date' is required as a partition keys.""") if 'date_range' in selection: require('date_range' in partition_keys, """"If 'date_range' is specified in the 'selection' section, then it is also required as a partition keys.""") # Ensure consistent lookup. config['parameters']['partition_keys'] = partition_keys # Add config file name. config['parameters']['config_name'] = config_name # Ensure the cartesian-cross can be taken on singleton values for the partition. for key in partition_keys: if not isinstance(selection[key], list): selection[key] = [selection[key]] return Config.from_dict(config) def prepare_target_name(config: Config) -> str: """Returns name of target location.""" partition_dict = OrderedDict((key, typecast(key, config.selection[key][0])) for key in config.partition_keys) target = config.target_path.format(*partition_dict.values(), **partition_dict) return target def get_subsections(config: Config) -> t.List[t.Tuple[str, t.Dict]]: """Collect parameter subsections from main configuration. If the `parameters` section contains subsections (e.g. '[parameters.1]', '[parameters.2]'), collect the subsection key-value pairs. Otherwise, return an empty dictionary (i.e. there are no subsections). This is useful for specifying multiple API keys for your configuration. For example: ``` [parameters.alice] api_key=KKKKK1 api_url=UUUUU1 [parameters.bob] api_key=KKKKK2 api_url=UUUUU2 [parameters.eve] api_key=KKKKK3 api_url=UUUUU3 ``` """ return [(name, params) for name, params in config.kwargs.items() if isinstance(params, dict)] or [('default', {})] def all_equal(iterator): iterator = iter(iterator) try: first = next(iterator) except StopIteration: return True return all(first == x for x in iterator) def validate_all_configs(configs: t.List[Config]) -> None: clients = [conf.client for conf in configs] require(all_equal(clients), f'All configs must request data from the same client, {clients[0]!r}.') kwargs = [conf.kwargs for conf in configs] require(all_equal(kwargs), 'Discrepancy in config parameters! Please check for consistency across all configs.') ================================================ FILE: weather_dl/download_pipeline/parsers_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import io import unittest from .manifest import MockManifest, Location from .parsers import ( date, parse_config, process_config, _number_of_replacements, parse_subsections, prepare_target_name, ) from .config import Config class DateTest(unittest.TestCase): def test_parses_relative_date(self): self.assertEqual(date('-0'), datetime.date.today()) self.assertEqual(date('-2'), datetime.date.today() + datetime.timedelta(days=-2)) def test_parses_kebob_date(self): self.assertEqual(date('2020-08-22'), datetime.date(year=2020, month=8, day=22)) self.assertEqual(date('1900-11-01'), datetime.date(year=1900, month=11, day=1)) def test_parses_smooshed_date(self): self.assertEqual(date('20200822'), datetime.date(year=2020, month=8, day=22)) self.assertEqual(date('19001101'), datetime.date(year=1900, month=11, day=1)) def test_parses_year_and_day_of_year(self): self.assertEqual(date('2020-235'), datetime.date(year=2020, month=8, day=22)) self.assertEqual(date('2021-007'), datetime.date(year=2021, month=1, day=7)) self.assertEqual(date('1900-305'), datetime.date(year=1900, month=11, day=1)) def test_throws_error(self): with self.assertRaises(ValueError): date('2020-08-22-12') with self.assertRaises(ValueError): date('2020-0822') with self.assertRaises(ValueError): date('20-08-22') with self.assertRaises(ValueError): date('') class ParseConfigTest(unittest.TestCase): def _assert_no_newlines_in_section(self, dictionary) -> None: for val in dictionary['section'].values(): self.assertNotIn('\n', val) def test_json(self): with io.StringIO('{"section": {"key": "value"}}') as f: actual = parse_config(f) self.assertDictEqual(actual, {'section': {'key': 'value'}}) def test_bad_json(self): with io.StringIO('{"section": {"key": "value", "brokenKey": }}') as f: actual = parse_config(f) self.assertDictEqual(actual, {}) def test_json_produces_lists(self): with io.StringIO('{"section": {"key": "value", "list": [0, 10, 20, 30, 40]}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], [0, 10, 20, 30, 40]) def test_json_parses_mars_list(self): with io.StringIO('{"section": {"key": "value", "list": "1/2/3"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3']) def test_json_parses_mars_int_range(self): with io.StringIO('{"section": {"key": "value", "list": "1/to/5"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3', '4', '5']) def test_json_parses_mars_int_range_padded(self): with io.StringIO('{"section": {"key": "value", "list": "00/to/05"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['00', '01', '02', '03', '04', '05']) def test_json_parses_mars_int_range_incremented(self): with io.StringIO('{"section": {"key": "value", "list": "1/to/5/by/2"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', 
'3', '5']) def test_json_parses_mars_float_range(self): with io.StringIO('{"section": {"key": "value", "list": "1.0/to/5.0"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1.0', '2.0', '3.0', '4.0', '5.0']) def test_json_parses_mars_float_range_incremented(self): with io.StringIO('{"section": {"key": "value", "list": "1.0/to/5.0/by/2.0"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1.0', '3.0', '5.0']) def test_json_parses_mars_float_range_incremented_by_float(self): with io.StringIO('{"section": {"key": "value", "list": "0.0/to/0.5/by/0.1"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertEqual(actual['section']['list'], ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']) def test_json_parses_mars_date_range(self): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/to/2020-01-09"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['2020-01-07', '2020-01-08', '2020-01-09']) def test_json_parses_mars_relative_date_range(self): with io.StringIO('{"section": {"key": "value", "list": "-3/to/-1"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) dates = [ datetime.date.today() + datetime.timedelta(-3), datetime.date.today() + datetime.timedelta(-2), datetime.date.today() + datetime.timedelta(-1), ] self.assertListEqual(actual['section']['list'], [d.strftime("%Y-%m-%d") for d in dates]) def test_json_parses_mars_date_range_incremented(self): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/to/2020-01-12/by/2"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['2020-01-07', '2020-01-09', '2020-01-11']) def test_json_raises_syntax_error_missing_right(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/to/'."): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/to/"}}') as f: parse_config(f) def test_json_raises_syntax_error_missing_left(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '/to/2020-01-07'."): with io.StringIO('{"section": {"key": "value", "list": "/to/2020-01-07"}}') as f: parse_config(f) def test_json_raises_syntax_error_missing_increment(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/to/2020-01-11/by/'."): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/to/2020-01-11/by/"}}') as f: parse_config(f) def test_json_raises_syntax_error_no_range(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/by/2020-01-11'."): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/by/2020-01-11"}}') as f: parse_config(f) def test_json_raises_value_error_date_types(self): with self.assertRaisesRegex( ValueError, "Increments on a date range must be integer number of days, '2.0' is invalid." 
): with io.StringIO('{"section": {"key": "value", "list": "2020-01-07/to/2020-01-11/by/2.0"}}') as f: parse_config(f) def test_json_raises_value_error_float_types(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '1.0/to/10.0/by/2020-01-07'."): with io.StringIO('{"section": {"key": "value", "list": "1.0/to/10.0/by/2020-01-07"}}') as f: parse_config(f) def test_json_raises_value_error_int_types(self): with self.assertRaisesRegex(ValueError, "inconsistent types."): with io.StringIO('{"section": {"key": "value", "list": "1/to/10/by/2.0"}}') as f: parse_config(f) def test_json_parses_accidental_extra_whitespace(self): with io.StringIO('{"section": {"key": "value", "list": "1/to/5"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3', '4', '5']) def test_json_parses_parameter_subsections(self): with io.StringIO('{"parameters": {"api_url": "https://google.com/", \ "alice": {"api_key": "123"}, \ "bob": {"api_key": "456"}}}') as f: actual = parse_config(f) self.assertEqual(actual, { 'parameters': { 'api_url': 'https://google.com/', 'alice': {'api_key': '123'}, 'bob': {'api_key': '456'}, }, }) def test_cfg(self): with io.StringIO( """ [section] key=value """ ) as f: actual = parse_config(f) self.assertDictEqual(actual, {'section': {'key': 'value'}}) def test_bad_cfg(self): with io.StringIO( """ key=value """ ) as f: actual = parse_config(f) self.assertDictEqual(actual, {}) def test_cfg_produces_lists(self): with io.StringIO( """ [section] key=value list=00 10 20 30 40 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['00', '10', '20', '30', '40']) def test_cfg_parses_mars_list(self): with io.StringIO( """ [section] key=value list=1/2/3 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3']) def test_cfg_parses_mars_int_range(self): with io.StringIO( """ [section] key=value list=1/to/5 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3', '4', '5']) def test_cfg_parses_mars_int_range_padded(self): with io.StringIO( """ [section] key=value list=00/to/05 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['00', '01', '02', '03', '04', '05']) def test_cfg_parses_mars_int_range_incremented(self): with io.StringIO( """ [section] key=value list=1/to/5/by/2 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '3', '5']) def test_cfg_parses_mars_int_reverse_range_incremented(self): with io.StringIO( """ [section] key=value list=5/to/1/by/-2 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['5', '3', '1']) def test_cfg_parses_mars_float_range(self): with io.StringIO( """ [section] key=value list=1.0/to/5.0 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1.0', '2.0', '3.0', '4.0', '5.0']) def test_cfg_parses_mars_float_range_incremented(self): with io.StringIO( """ [section] key=value list=1.0/to/5.0/by/2.0 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) 
self.assertListEqual(actual['section']['list'], ['1.0', '3.0', '5.0']) def test_cfg_parses_mars_float_range_incremented_by_float(self): with io.StringIO( """ [section] key=value list=0.0/to/0.5/by/0.1 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['0.0', '0.1', '0.2', '0.3', '0.4', '0.5']) def test_cfg_parses_mars_float_reverse_range_incremented_by_float(self): with io.StringIO( """ [section] key=value list=0.5/to/0.0/by/-0.1 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['0.5', '0.4', '0.3', '0.2', '0.1', '0.0']) def test_cfg_parses_mars_date_range(self): with io.StringIO( """ [section] key=value list=2020-01-07/to/2020-01-09 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['2020-01-07', '2020-01-08', '2020-01-09']) def test_cfg_parses_mars_relative_date_range(self): with io.StringIO( """ [section] key=value list=-3/to/-1 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) dates = [ datetime.date.today() + datetime.timedelta(-3), datetime.date.today() + datetime.timedelta(-2), datetime.date.today() + datetime.timedelta(-1), ] self.assertListEqual(actual['section']['list'], [d.strftime("%Y-%m-%d") for d in dates]) def test_cfg_parses_mars_relative_date_reverse_range(self): with io.StringIO( """ [section] key=value list=-1/to/-3/by/-1 """ ) as f: actual = parse_config(f) for key, val in actual['section'].items(): self.assertNotIn('\n', val) dates = [ datetime.date.today() + datetime.timedelta(-1), datetime.date.today() + datetime.timedelta(-2), datetime.date.today() + datetime.timedelta(-3), ] self.assertListEqual(actual['section']['list'], [d.strftime("%Y-%m-%d") for d in dates]) def test_cfg_parses_mars_date_range_incremented(self): with io.StringIO( """ [section] key=value list=2020-01-07/to/2020-01-12/by/2 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['2020-01-07', '2020-01-09', '2020-01-11']) def test_cfg_parses_mars_date_reverse_range_incremented(self): with io.StringIO( """ [section] key=value list=2020-01-12/to/2020-01-07/by/-2 """ ) as f: actual = parse_config(f) for key, val in actual['section'].items(): self.assertNotIn('\n', val) self.assertListEqual(actual['section']['list'], ['2020-01-12', '2020-01-10', '2020-01-08']) def test_cfg_raises_syntax_error_missing_right(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/to/'."): with io.StringIO( """ [section] key=value list=2020-01-07/to/ """ ) as f: parse_config(f) def test_cfg_raises_syntax_error_missing_left(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '/to/2020-01-07'."): with io.StringIO( """ [section] key=value list=/to/2020-01-07 """ ) as f: parse_config(f) def test_cfg_raises_syntax_error_missing_increment(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/to/2020-01-11/by/'."): with io.StringIO( """ [section] key=value list=2020-01-07/to/2020-01-11/by/ """ ) as f: parse_config(f) def test_cfg_raises_syntax_error_no_range(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '2020-01-07/by/2020-01-11'."): with io.StringIO( """ [section] key=value list=2020-01-07/by/2020-01-11 """ ) as f: parse_config(f) def test_cfg_raises_value_error_date_types(self): 
with self.assertRaisesRegex( ValueError, "Increments on a date range must be integer number of days, '2.0' is invalid." ): with io.StringIO( """ [section] key=value list=2020-01-07/to/2020-01-11/by/2.0 """ ) as f: parse_config(f) def test_cfg_raises_value_error_float_types(self): with self.assertRaisesRegex(SyntaxError, "Improper range syntax in '1.0/to/10.0/by/2020-01-07'."): with io.StringIO( """ [section] key=value list=1.0/to/10.0/by/2020-01-07 """ ) as f: parse_config(f) def test_cfg_raises_value_error_int_types(self): with self.assertRaisesRegex(ValueError, "inconsistent types."): with io.StringIO( """ [section] key=value list=1/to/10/by/2.0 """ ) as f: parse_config(f) def test_cfg_parses_accidental_extra_whitespace(self): with io.StringIO( """ [section] key=value list= 1/to/5 """ ) as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual(actual['section']['list'], ['1', '2', '3', '4', '5']) def test_json_parse_year_mon_one(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2024-10"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10'] ) def test_json_parse_year_mon_one_by_one(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2024-10/by/1"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10'] ) def test_json_parse_year_mon_one_by_two(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2024-10/by/2"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10'] ) def test_json_parse_year_mon_six(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2025-3"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10', '2024-11', '2024-12', '2025-01', '2025-02', '2025-03'] ) def test_json_parse_year_mon_six_by_one(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2025-3/by/1"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10', '2024-11', '2024-12', '2025-01', '2025-02', '2025-03'] ) def test_json_parse_year_mon_six_by_three(self): with io.StringIO('{"section": {"year-month": "2024-10/to/2025-3/by/3"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2024-10', '2025-01'] ) def test_json_parse_year_mon_six_rev(self): with io.StringIO('{"section": {"year-month": "2025-3/to/2024-10"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2025-03', '2025-02', '2025-01', '2024-12', '2024-11', '2024-10'] ) def test_json_parse_year_mon_six_by_one_rev(self): with io.StringIO('{"section": {"year-month": "2025-3/to/2024-10/by/-1"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2025-03', '2025-02', '2025-01', '2024-12', '2024-11', '2024-10'] ) def test_json_parse_year_mon_six_by_three_rev(self): with io.StringIO('{"section": {"year-month": "2025-3/to/2024-10/by/-3"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2025-03', 
'2024-12'] ) def test_json_parse_year_mon_eighteen(self): with io.StringIO('{"section": {"year-month": "2023-10/to/2025-3"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2023-10', '2023-11', '2023-12', '2024-01', '2024-02', '2024-03', '2024-04', '2024-05', '2024-06', '2024-07', '2024-08', '2024-09', '2024-10', '2024-11', '2024-12', '2025-01', '2025-02', '2025-03'] ) def test_json_parse_year_mon_eighteen_by_two(self): with io.StringIO('{"section": {"year-month": "2023-10/to/2025-3/by/2"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2023-10', '2023-12', '2024-02', '2024-04', '2024-06', '2024-08', '2024-10', '2024-12', '2025-02'] ) def test_json_parse_year_mon_eighteen_by_six(self): with io.StringIO('{"section": {"year-month": "2023-10/to/2025-3/by/6"}}') as f: actual = parse_config(f) self._assert_no_newlines_in_section(actual) self.assertListEqual( actual['section']['year-month'], ['2023-10', '2024-04', '2024-10'] ) def test_cfg_parses_parameter_subsections(self): with io.StringIO( """ [parameters] api_url=https://google.com/ [parameters.alice] api_key=123 [parameters.bob] api_key=456 """ ) as f: actual = parse_config(f) self.assertEqual(actual, { 'parameters': { 'api_url': 'https://google.com/', 'alice': {'api_key': '123'}, 'bob': {'api_key': '456'}, }, }) class HelpersTest(unittest.TestCase): CASES = [('', 0), ('{} blah', 1), ('{} {}', 2), ('{0}, {1}', 2), ('%s hello', 0), ('hello {.2f}', 1), ('ear5-{year}{year}-{month}', 2), ('era5-{year}/{year}-{}-{}', 3), ('{year}{year}{year}{year}', 1)] def test_number_of_replacements(self): for s, want in self.CASES: with self.subTest(s=s, want=want): actual = _number_of_replacements(s) self.assertEqual(actual, want) class SubsectionsTest(unittest.TestCase): def test_parses_config_subsections(self): config = {"parsers": {'a': 1, 'b': 2}, "parsers.1": {'b': 3}} actual = parse_subsections(config) self.assertEqual(actual, {'parsers': {'a': 1, 'b': 2, '1': {'b': 3}}}) class ApiKeyCountingTest(unittest.TestCase): def test_no_keys(self): config = {"parameters": {'a': 1, 'b': 2}, "parameters.1": {'b': 3}} actual = parse_subsections(config) self.assertEqual(actual, {'parameters': {'a': 1, 'b': 2, '1': {'b': 3}}}) def test_api_keys(self): config = {"parameters": {'a': 1, 'b': 2}, "parameters.param1": {'api_key': 'key1'}, "parameters.param2": {'api_key': 'key2'}} actual = parse_subsections(config) self.assertEqual(actual, {'parameters': {'a': 1, 'b': 2, 'param1': {'api_key': 'key1'}, 'param2': {'api_key': 'key2'}}}) class ProcessConfigTest(unittest.TestCase): def test_parse_config(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ key=value """ ) as f: process_config(f, 'test.cfg') self.assertEqual("Unable to parse configuration file.", ctx.exception.args[0]) def test_require_params_section(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [selection] key=value """ ) as f: process_config(f, 'test.cfg') self.assertIn( "'parameters' section required in configuration file.", ctx.exception.args[0]) def test_accepts_parameters_section(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] key=value """ ) as f: process_config(f, 'test.cfg') self.assertNotIn( "'parameters' section required in configuration file.", ctx.exception.args[0]) def test_requires_target_path_param(self): with self.assertRaises(ValueError) as 
ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target=bar """ ) as f: process_config(f, 'test.cfg') self.assertIn( "'parameters' section requires a 'target_path' key.", ctx.exception.args[0]) def test_requires_target_template_param_not_present(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_template=bar """ ) as f: process_config(f, 'test.cfg') self.assertIn( "'target_template' is deprecated, use 'target_path' instead.", ctx.exception.args[0]) def test_accepts_target_path_param(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_path """ ) as f: process_config(f, 'test.cfg') self.assertNotIn( "'parameters' section requires a 'target_path' key.", ctx.exception.args[0]) def test_requires_partition_keys_to_match_sections(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{}-{} partition_keys= year month [selection] day= 01 02 03 decade= 1950 1960 1970 1980 1990 2000 2010 2020 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "All 'partition_keys' must appear in the 'selection' section.", ctx.exception.args[0]) def test_accepts_partition_keys_matching_sections(self): with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{}-{} partition_keys= year month [selection] month= 01 02 03 year= 1950 1960 1970 1980 1990 2000 2010 2020 """ ) as f: config = process_config(f, 'test.cfg') self.assertTrue(bool(config)) def test_accepts_partition_keys_not_present(self): with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar [selection] month= 01 02 03 year= 1950 1960 1970 1980 1990 2000 2010 2020 """ ) as f: config = process_config(f, 'test.cfg') self.assertTrue(bool(config)) def test_treats_partition_keys_as_list(self): with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{} partition_keys=month [selection] month= 01 02 03 """ ) as f: config = process_config(f, 'test.cfg') self.assertIsInstance(config.partition_keys, list) def test_mismatched_template_partition_keys(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{} partition_keys= year month [selection] month= 01 02 03 year= 1950 1960 1970 1980 1990 2000 2010 2020 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "'target_path' has 1 replacements. Expected 2", ctx.exception.args[0]) def test_append_date_dirs_raise_error(self): with self.assertRaises(NotImplementedError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_path=somewhere/bar-{} append_date_dirs=true partition_keys= date [selection] date=2017-01-01/to/2017-01-01 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "The current version of 'google-weather-tools' no longer supports 'append_date_dirs'!" 
"\n\nPlease refer to documentation for creating date-based directory hierarchy :\n" "https://weather-tools.readthedocs.io/en/latest/Configuration.html" "#creating-a-date-based-directory-hierarchy.", ctx.exception.args[0]) def test_target_filename_raise_error(self): with self.assertRaises(NotImplementedError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=cds target_path=somewhere/ target_filename=bar-{} partition_keys= date [selection] date=2017-01-01/to/2017-01-01 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "The current version of 'google-weather-tools' no longer supports 'target_filename'!" "\n\nPlease refer to documentation :\n" "https://weather-tools.readthedocs.io/en/latest/Configuration.html#parameters-section.", ctx.exception.args[0]) def test_client_not_set(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo target_path=bar-{} partition_keys= year [selection] year= 1969 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "'parameters' section requires a 'client' key.", ctx.exception.args[0]) def test_client_invalid(self): with self.assertRaises(ValueError) as ctx: with io.StringIO( """ [parameters] dataset=foo client=nope target_path=bar-{} partition_keys= year [selection] year= 1969 """ ) as f: process_config(f, 'test.cfg') self.assertIn( "Invalid 'client' parameter.", ctx.exception.args[0]) def test_partition_cannot_include_all(self): with self.assertRaisesRegex(ValueError, 'cannot appear as a partition key.'): with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{}-{} partition_keys= month day [selection] year=2012 month= 01 02 03 day=all """ ) as f: process_config(f, 'test.cfg') def test_partition_key_contains_date_range(self): with self.assertRaisesRegex(ValueError, """If 'date_range' is specified in the 'selection' section, then it is also required as a partition keys."""): with io.StringIO( """ [parameters] dataset=foo client=ecpublic target_path=bar-{} partition_keys= time [selection] date_range=2017-01-01/to/2017-01-10 time=00/12 """ ) as f: process_config(f, 'test.cfg') def test_partition_key_contains_date__in_case_of_hdate(self): with self.assertRaisesRegex(ValueError, "'date' is required as a partition keys."): with io.StringIO( """ [parameters] dataset=foo client=ecpublic target_path=bar-{} partition_keys= step [selection] date=2020-01-02 step=1/2/3/4 hdate=1/to/6 """ ) as f: process_config(f, 'test.cfg') def test_singleton_partitions_are_converted_to_lists(self): with io.StringIO( """ [parameters] dataset=foo client=cds target_path=bar-{}-{} partition_keys= month year [selection] month=01 year=2018 """ ) as f: config = process_config(f, 'test.cfg') self.assertEqual(config.selection['month'], ['01']) self.assertEqual(config.selection['year'], ['2018']) class PrepareTargetNameTest(unittest.TestCase): TEST_CASES = [ dict(case='No date.', config={ 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['02'] } }, expected='download-2-12.nc'), dict(case='Has date but no target directory.', config={ 'parameters': { 'partition_keys': ['date'], 'target_path': 'download-{}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'date': ['2017-01-15'], } }, expected='download-2017-01-15.nc'), dict(case='Has Directory, but no date', config={ 'parameters': { 'target_path': 'somewhere/download/{:02d}/{:02d}.nc', 'partition_keys': ['year', 
'month'], 'force_download': False }, 'selection': { 'features': ['pressure'], 'month': ['02'], 'year': ['02'] } }, expected='somewhere/download/02/02.nc'), dict(case='Had date and target directory', config={ 'parameters': { 'partition_keys': ['date'], 'target_path': 'somewhere/{date:%Y/%m/%d}-download.nc', 'force_download': False }, 'selection': { 'date': ['2017-01-15'], } }, expected='somewhere/2017/01/15-download.nc'), dict(case='Had date, target directory, and additional params.', config={ 'parameters': { 'partition_keys': ['date', 'pressure_level'], 'target_path': 'somewhere/{date:%Y/%m/%d}-pressure-{pressure_level}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'pressure_level': ['500'], 'date': ['2017-01-15'], } }, expected='somewhere/2017/01/15-pressure-500.nc'), dict(case='Has date and target directory, including parameters in path.', config={ 'parameters': { 'partition_keys': ['date', 'expver', 'pressure_level'], 'target_path': 'somewhere/expver-{expver}/{date:%Y/%m/%d}-pressure-{pressure_level}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'pressure_level': ['500'], 'date': ['2017-01-15'], 'expver': ['1'], } }, expected='somewhere/expver-1/2017/01/15-pressure-500.nc'), dict(case='Has date_range', config={ 'parameters': { 'partition_keys': ['date_range', 'time'], 'target_path': 'somewhere/{date_range}-{time}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'date_range': ['2017-01-01/to/2017-01-10'], 'time': ['00'] } }, expected='somewhere/2017-01-01_to_2017-01-10-00:00:00.nc'), ] def setUp(self) -> None: self.dummy_manifest = MockManifest(Location('dummy-manifest')) def test_target_name(self): for it in self.TEST_CASES: with self.subTest(msg=it['case'], **it): actual = prepare_target_name(Config.from_dict(it['config'])) self.assertEqual(actual, it['expected']) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_dl/download_pipeline/partition.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy as cp import dataclasses import itertools import logging import math import typing as t import apache_beam as beam from .config import Config from .manifest import Manifest, NoOpManifest, Location from .parsers import prepare_target_name from .stores import Store, FSStore from .util import ichunked, generate_hdate Partition = t.Tuple[str, t.Dict, Config] Index = t.Tuple[int] logger = logging.getLogger(__name__) @dataclasses.dataclass class PartitionConfig(beam.PTransform): """Partition a config into multiple data requests. Partitioning involves four main operations: First, we fan-out shards based on partition keys (a cross product of the values). Second, we filter out existing downloads (unless we want to force downloads). Next, we add subsections to the configs in a cycle (to ensure an even distribution of extra parameters). Last, We assemble each partition into a single Config. 
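    For example (an illustrative sketch), partition_keys=['year', 'month'] with
    year=['2020', '2021'] and month=['01', '02'] fans out into four configs whose
    selections are each narrowed to a single pair:

        {'year': ['2020'], 'month': ['01']}, {'year': ['2020'], 'month': ['02']},
        {'year': ['2021'], 'month': ['01']}, {'year': ['2021'], 'month': ['02']}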
Attributes: store: A cloud storage system, used for checking the existence of downloads. subsections: A cycle of (name, parameter) tuples. manifest: A download manifest to register preparation state. scheduling: How to sort partitions from multiple configs: in order of each config (default), or in "fair" order, where partitions from each config are evenly rotated. partition_chunks: The size of chunks of partition shards to use during the fan-out stage. By default, this operation will aim to divide all the partitions into groups of about 1000 sized-chunks. num_groups: The the number of groups (for fair scheduling). Default: 1. """ store: Store subsections: itertools.cycle manifest: Manifest scheduling: str partition_chunks: t.Optional[int] = None update_manifest: bool = False num_groups: int = 1 def expand(self, configs): def loop_through_subsections(it: Config) -> Partition: """Assign a subsection to each config in a loop. If the `parameters` section contains subsections (e.g. '[parameters.1]', '[parameters.2]'), collect a repeating cycle of the subsection key-value pairs. Otherwise, assign a default section to each config. This is useful for specifying multiple API keys for your configuration. For example: ``` [parameters.alice] api_key=KKKKK1 api_url=UUUUU1 [parameters.bob] api_key=KKKKK2 api_url=UUUUU2 [parameters.eve] api_key=KKKKK3 api_url=UUUUU3 ``` """ name, params = next(self.subsections) return name, params, it if self.scheduling == 'fair': config_idxs = ( configs | beam.combiners.ToList() | 'Fair Fan-out' >> beam.FlatMap(prepare_fair_partition_index, chunk_size=self.partition_chunks, groups=self.num_groups) ) else: config_idxs = ( configs | 'Fan-out' >> beam.FlatMap(prepare_partition_index, chunk_size=self.partition_chunks) ) return ( config_idxs | beam.Reshuffle() | 'To configs' >> beam.FlatMapTuple(prepare_partitions_from_index) | 'Skip existing' >> beam.Filter(new_downloads_only, store=self.store, manifest=self.manifest) | 'Cycle subsections' >> beam.Map(loop_through_subsections) | 'Assemble' >> beam.Map(assemble_config, manifest=self.manifest) ) def _create_partition_config(option: t.Tuple, config: Config) -> Config: """Create a config for a single partition option. Output a config dictionary, overriding the range of values for each key with the partition instance in 'selection'. Continuing the example from prepare_partitions, the selection section would be: { 'foo': ..., 'year': ['2020'], 'month': ['01'], ... } { 'foo': ..., 'year': ['2020'], 'month': ['02'], ... } { 'foo': ..., 'year': ['2020'], 'month': ['03'], ... } Args: option: A single item in the range of partition_keys. config: The download config, including the parameters and selection sections. Returns: A configuration with that selects a single download partition. """ copy = cp.deepcopy(config.selection) out = cp.deepcopy(config) for idx, key in enumerate(config.partition_keys): copy[key] = [option[idx]] # Replace hdate with actual value. 
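    # (Each value in the 'hdate' selection is resolved by `generate_hdate` against the
    # partition's single 'date' value, so the request carries concrete hindcast dates.)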
if 'hdate' in copy: copy['hdate'] = [generate_hdate(copy['date'][0], v) for v in copy['hdate']] out.selection = copy return out def skip_partition(config: Config, store: Store, manifest: Manifest) -> bool: """Return true if partition should be skipped.""" if config.force_download: return False target = prepare_target_name(config) if store.exists(target): logger.info(f'file {target} found, skipping.') manifest.skip(config.config_name, config.dataset, config.selection, target, config.user_id) return True return False def prepare_partition_index(config: Config, chunk_size: t.Optional[int] = None) -> t.Iterator[t.Tuple[Config, t.List[Index]]]: """Produce indexes over client parameters, partitioning over `partition_keys` This produces a Cartesian-Cross over the range of keys. For example, if the keys were 'year' and 'month', it would produce an iterable like: ( (0, 0), (0, 1), (0, 2), ...) After the indexes were converted back to keys, it would produce values like: ( ('2020', '01'), ('2020', '02'), ('2020', '03'), ...) Returns: An iterator of index tuples. """ dims = [range(len(config.selection[key])) for key in config.partition_keys] n_partitions = math.prod([len(d) for d in dims]) logger.info(f'Creating {n_partitions} partitions.') if chunk_size is None: chunk_size = 1000 if not dims: yield config, [] else: for option_idx in ichunked(itertools.product(*dims), chunk_size): yield config, list(option_idx) def prepare_partitions_from_index(config: Config, indexes: t.List[Index]) -> t.Iterator[Config]: """Convert a partition index into a config. Returns: from an option index. A partition `Config` from an option index. """ if not indexes: yield _create_partition_config((), config) else: for index in indexes: option = tuple( config.selection[config.partition_keys[key_idx]][val_idx] for key_idx, val_idx in enumerate(index) ) yield _create_partition_config(option, config) def new_downloads_only(candidate: Config, store: t.Optional[Store] = None, manifest: Manifest = NoOpManifest(Location('noop://in-memory'))) -> bool: """Predicate function to skip already downloaded partitions.""" if store is None: store = FSStore() should_skip = skip_partition(candidate, store, manifest) if should_skip: beam.metrics.Metrics.counter('Prepare', 'skipped').inc() return not should_skip def assemble_config(partition: Partition, manifest: Manifest) -> Config: """Assemble the configuration for a single partition. For each cross product of the 'selection' sections, the output dictionary will overwrite parameters from the extra param subsections, evenly cycling through each subsection. For example: { 'parameters': {... 'api_key': KKKKK1, ... }, ... } { 'parameters': {... 'api_key': KKKKK2, ... }, ... } { 'parameters': {... 'api_key': KKKKK3, ... }, ... } { 'parameters': {... 'api_key': KKKKK1, ... }, ... } { 'parameters': {... 'api_key': KKKKK2, ... }, ... } { 'parameters': {... 'api_key': KKKKK3, ... }, ... } ... Returns: An `Config` assembled out of subsection parameters and config shards. """ name, params, out = partition out.kwargs.update(params) out.subsection_name = name location = prepare_target_name(out) user = out.user_id manifest.schedule(out.config_name, out.dataset, out.selection, location, user) logger.info(f'[{name}] Created partition {location!r}.') beam.metrics.Metrics.counter('Subsection', name).inc() return out def cycle_iters(iters: t.List[t.Iterator], take: int = 1) -> t.Iterator: """Evenly cycle through a list of iterators. Args: iters: A list of iterators to evely cycle through. 
take: Yield N items at a time. When not set to 1, this will yield multiple items from the same collection. Returns: An iteration across several iterators in a round-robin order. """ while iters: for i, it in enumerate(iters): try: for j in range(take): logger.debug(f'yielding item {j!r} from iterable {i!r}.') yield next(it) except StopIteration: iters.remove(it) def prepare_fair_partition_index(configs: t.List[Config], chunk_size: t.Optional[int], groups: int) -> t.Iterator[t.Tuple[Config, t.List[Index]]]: """Given a list of all configs, evenly cycle through each partition chunked by the 'chunk_size'.""" if chunk_size is None: chunk_size = 1 iters = [prepare_partition_index(config, chunk_size) for config in configs] yield from cycle_iters(iters, take=groups) ================================================ FILE: weather_dl/download_pipeline/partition_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools import json import tempfile import typing as t import unittest from unittest.mock import MagicMock import apache_beam as beam from xarray_beam._src.test_util import EagerPipeline from .manifest import MockManifest, Location, DownloadStatus, LocalManifest, Status, Stage from .parsers import get_subsections from .partition import skip_partition, PartitionConfig from .stores import InMemoryStore, Store from .config import Config class OddFilesDoNotExistStore(InMemoryStore): def __init__(self): super().__init__() self.count = 0 def exists(self, filename: str) -> bool: ret = self.count % 2 == 0 self.count += 1 return ret class PreparePartitionTest(unittest.TestCase): def setUp(self) -> None: self.dummy_manifest = MockManifest(Location('mock://dummy')) def create_partition_configs(self, configs, store: t.Optional[Store] = None, schedule='in-order', n_requests_per: int = 1) -> t.List[Config]: subsections = get_subsections(configs[0]) params_cycle = itertools.cycle(subsections) return (EagerPipeline() | beam.Create(configs) | PartitionConfig(store, params_cycle, self.dummy_manifest, schedule, len(subsections) * n_requests_per)) def test_partition_single_key(self): config = { 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) actual = self.create_partition_configs([config_obj]) self.assertListEqual([d.selection for d in actual], [ {**config['selection'], **{'year': [str(i)]}} for i in range(2015, 2021) ]) def test_partition_multi_key(self): config = { 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 3)], 'year': [str(i) for i in range(2015, 2017)] } } config_obj = Config.from_dict(config) actual = 
self.create_partition_configs([config_obj]) self.assertListEqual([d.selection for d in actual], [ {**config['selection'], **{'year': ['2015'], 'month': ['1']}}, {**config['selection'], **{'year': ['2015'], 'month': ['2']}}, {**config['selection'], **{'year': ['2016'], 'month': ['1']}}, {**config['selection'], **{'year': ['2016'], 'month': ['2']}}, ]) def test_partition_multi_key_single_values(self): config = { 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': ['1'], 'year': ['2015'], } } config_obj = Config.from_dict(config) actual = self.create_partition_configs([config_obj]) self.assertListEqual([d.selection for d in actual], [ {**config['selection'], **{'year': ['2015'], 'month': ['1']}}, ]) def test_partition_multi_params_multi_key(self): config = { 'parameters': dict( partition_keys=['year', 'month'], target_path='download-{}-{}.nc', research={ 'api_key': 'KKKK1', 'api_url': 'UUUU1' }, cloud={ 'api_key': 'KKKK2', 'api_url': 'UUUU2' }, deepmind={ 'api_key': 'KKKK3', 'api_url': 'UUUU3' } ), 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 3)], 'year': [str(i) for i in range(2015, 2017)] } } config_obj = Config.from_dict(config) actual = self.create_partition_configs([config_obj]) expected = [Config.from_dict(it) for it in [ {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', subsection_name='research'), 'selection': {**config['selection'], **{'year': ['2015'], 'month': ['1']}}}, {'parameters': dict(config['parameters'], api_key='KKKK2', api_url='UUUU2', subsection_name='cloud'), 'selection': {**config['selection'], **{'year': ['2015'], 'month': ['2']}}}, {'parameters': dict(config['parameters'], api_key='KKKK3', api_url='UUUU3', subsection_name='deepmind'), 'selection': {**config['selection'], **{'year': ['2016'], 'month': ['1']}}}, {'parameters': dict(config['parameters'], api_key='KKKK1', api_url='UUUU1', subsection_name='research'), 'selection': {**config['selection'], **{'year': ['2016'], 'month': ['2']}}}, ]] self.assertListEqual(actual, expected) def test_prepare_partition_records_download_status_to_manifest(self): config = { 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) with tempfile.TemporaryDirectory() as tmpdir: self.dummy_manifest = LocalManifest(Location(tmpdir)) self.create_partition_configs([config_obj]) with open(self.dummy_manifest.location, 'r') as f: actual = json.load(f) self.assertListEqual( [d['selection'] for d in actual.values()], [ json.dumps({**config['selection'], **{'year': [str(i)]}}) for i in range(2015, 2021) ]) self.assertTrue( all([d['status'] == 'scheduled' for d in actual.values()]) ) def test_prepare_partition_records_download_status_to_manifest_for_already_downloaded_shard(self): config = { 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) with tempfile.TemporaryDirectory() as tmpdir: self.mock_store = InMemoryStore() 
self.mock_store.open('download-2015.nc') self.dummy_manifest = LocalManifest(Location(tmpdir)) self.create_partition_configs(configs=[config_obj], store=self.mock_store) with open(self.dummy_manifest.location, 'r') as f: actual = json.load(f) self.assertListEqual( [d['selection'] for d in actual.values()], [ json.dumps({**config['selection'], **{'year': [str(i)]}}) for i in range(2015, 2021) ]) self.assertTrue( all([d['status'] == 'scheduled' if d['location'] != 'download-2015.nc' else d['stage'] == 'upload' and d['status'] == 'success' for d in actual.values()]) ) def test_prepare_partition_update_download_status_for_downloaded_shard_missing_upload_entry(self): config = { 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) with tempfile.TemporaryDirectory() as tmpdir: self.mock_store = InMemoryStore() self.mock_store.open('download-2015.nc') self.dummy_manifest = LocalManifest(Location(tmpdir)) download_status = DownloadStatus(selection={**config['selection'], **{'year': [2015]}}, location='download-2015.nc', status=Status.SUCCESS, stage=Stage.DOWNLOAD) self.dummy_manifest._update(download_status) self.create_partition_configs(configs=[config_obj], store=self.mock_store) with open(self.dummy_manifest.location, 'r') as f: actual = json.load(f) self.assertListEqual( [d['selection'] for d in actual.values()], [ json.dumps({**config['selection'], **{'year': [str(i)]}}) for i in range(2015, 2021) ]) self.assertTrue( all([d['status'] == 'scheduled' if d['location'] != 'download-2015.nc' else d['stage'] == 'upload' and d['status'] == 'success' for d in actual.values()]) ) def test_prepare_partition_update_manifest_for_failed_upload_status_of_downloaded_shard(self): config = { 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) with tempfile.TemporaryDirectory() as tmpdir: self.mock_store = InMemoryStore() self.mock_store.open('download-2015.nc') self.dummy_manifest = LocalManifest(Location(tmpdir)) download_status = DownloadStatus(selection={**config['selection'], **{'year': [2015]}}, location='download-2015.nc', status=Status.FAILURE, error='error', stage=Stage.UPLOAD) self.dummy_manifest._update(download_status) self.create_partition_configs(configs=[config_obj], store=self.mock_store) with open(self.dummy_manifest.location, 'r') as f: actual = json.load(f) self.assertListEqual( [d['selection'] for d in actual.values()], [ json.dumps({**config['selection'], **{'year': [str(i)]}}) for i in range(2015, 2021) ]) self.assertTrue( all([d['status'] == 'scheduled' if d['location'] != 'download-2015.nc' else d['stage'] == 'upload' and d['status'] == 'success' for d in actual.values()]) ) def test_skip_partitions__never_unbalances_licenses(self): skip_odd_files = OddFilesDoNotExistStore() config = { 'parameters': dict( partition_keys=['year', 'month'], target_path='download-{}-{}.nc', research={ 'api_key': 'KKKK1', 'api_url': 'UUUU1' }, cloud={ 'api_key': 'KKKK2', 'api_url': 'UUUU2' } ), 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 3)], 'year': 
[str(i) for i in range(2016, 2020)] } } config_obj = Config.from_dict(config) actual = self.create_partition_configs([config_obj], store=skip_odd_files) research_configs = [cfg for cfg in actual if cfg and t.cast('str', cfg.kwargs.get('api_url', "")).endswith('1')] cloud_configs = [cfg for cfg in actual if cfg and t.cast('str', cfg.kwargs.get('api_url', "")).endswith('2')] self.assertEqual(len(research_configs), len(cloud_configs)) def test_multi_config_partition_single_key(self): configs = [ Config.from_dict({ 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}-%d.nc' % level, }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)], 'level': [str(level)] } }) for level in range(500, 901, 100) ] actual = self.create_partition_configs(configs) self.assertListEqual([d.selection for d in actual], [ {**configs[0].selection, **{'year': [str(i)], 'level': [str(level)]}} for level in range(500, 901, 100) for i in range(2015, 2021) ]) def test_multi_config_partition_single_key_fair_schedule(self): configs = [ Config.from_dict({ 'parameters': { 'partition_keys': ['year'], 'target_path': 'download-{}-%d.nc' % level, }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)], 'level': [str(level)] } }) for level in range(500, 901, 100) ] actual = self.create_partition_configs(configs, schedule='fair') self.assertListEqual([d.selection for d in actual], [ {**configs[0].selection, **{'year': [str(i)], 'level': [str(level)]}} for i in range(2015, 2021) for level in range(500, 901, 100) ]) def test_multi_config_partition_multi_params_multi_key(self): config_dicts = [ { 'parameters': dict( partition_keys=['year', 'month'], target_path='download-{}-{}-%d.nc' % level, research={ 'api_key': 'KKKK1', 'api_url': 'UUUU1' }, cloud={ 'api_key': 'KKKK2', 'api_url': 'UUUU2' }, deepmind={ 'api_key': 'KKKK3', 'api_url': 'UUUU3' } ), 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 3)], 'year': [str(i) for i in range(2015, 2017)], 'level': [str(level)] } } for level in range(500, 901, 100) ] actual = self.create_partition_configs([Config.from_dict(c) for c in config_dicts]) combinations = [ { 'parameters': config['parameters'], 'selection': {**config['selection'], **{'year': [str(year)], 'month': [str(month)]}} } for config in config_dicts for year in range(2015, 2017) for month in range(1, 3) ] subsections = get_subsections(Config.from_dict(config_dicts[0])) expected = [] for config, (name, params) in zip(combinations, itertools.cycle(subsections)): config['parameters'].update(params) config['parameters']['subsection_name'] = name expected.append(Config.from_dict(config)) self.assertListEqual(actual, expected) def test_multi_config_partition_multi_params_fair_schedule(self): config_dicts = [ { 'parameters': dict( partition_keys=['year'], target_path='download-{}-%d.nc' % level, research={ 'api_key': 'KKKK1', 'api_url': 'UUUU1' }, cloud={ 'api_key': 'KKKK2', 'api_url': 'UUUU2' }, deepmind={ 'api_key': 'KKKK3', 'api_url': 'UUUU3' } ), 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(1)], 'year': [str(i) for i in range(2015, 2017)], 'level': [str(level)] } } for level in range(500, 901, 100) ] # Three licenses, with one request per license 
(see groups) actual = self.create_partition_configs([Config.from_dict(c) for c in config_dicts], schedule='fair', n_requests_per=3) combinations = [ { 'parameters': config['parameters'], 'selection': {**config['selection'], **{'year': [str(year)]}} } for config in config_dicts for year in range(2015, 2017) ] subsections = get_subsections(Config.from_dict(config_dicts[0])) expected = [] for config, (name, params) in zip(combinations, itertools.cycle(subsections)): config['parameters'].update(params) config['parameters']['subsection_name'] = name expected.append(Config.from_dict(config)) self.assertListEqual(actual, expected) def test_multi_config_partition_multi_params_keys_and_requests_with_fair_schedule(self): self.maxDiff = None config_dicts = [ { 'parameters': dict( partition_keys=['year', 'month'], target_path='download-{}-{}-%d.nc' % level, research={ 'api_key': 'KKKK1', 'api_url': 'UUUU1' }, cloud={ 'api_key': 'KKKK2', 'api_url': 'UUUU2' }, deepmind={ 'api_key': 'KKKK3', 'api_url': 'UUUU3' } ), 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 3)], 'year': [str(i) for i in range(2015, 2017)], 'level': [str(level)] } } for level in range(500, 901, 100) ] # Three licenses, with two requests per license (see groups) actual = self.create_partition_configs([Config.from_dict(c) for c in config_dicts], schedule='fair', n_requests_per=3 * 2) combinations = [ { 'parameters': config['parameters'], 'selection': {**config['selection'], **{'year': [str(year)], 'month': [str(month)]}} } for config in config_dicts for year in range(2015, 2017) for month in range(1, 3) ] subsections = get_subsections(Config.from_dict(config_dicts[0])) expected = [] for config, (name, params) in zip(combinations, itertools.cycle(subsections)): config['parameters'].update(params) config['parameters']['subsection_name'] = name expected.append(Config.from_dict(config)) self.assertListEqual(actual, expected) def test_hdate_partition_single_key(self): config = Config.from_dict({ 'parameters': { 'partition_keys': ['date'], 'target_path': 'download-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'date': ['2016-01-04', '2016-01-07', '2016-01-11'], 'hdate': ['1', '2', '3'], } }) actual = self.create_partition_configs([config]) expected = [{'date': ['2016-01-04'], 'hdate': ['2015-01-04', '2014-01-04', '2013-01-04']}, {'date': ['2016-01-07'], 'hdate': ['2015-01-07', '2014-01-07', '2013-01-07']}, {'date': ['2016-01-11'], 'hdate': ['2015-01-11', '2014-01-11', '2013-01-11']}, ] self.assertListEqual([d.selection for d in actual], [{**config.selection, **e} for e in expected]) def test_partition_with_no_keys(self): config = Config.from_dict({ 'parameters': { 'target_path': 'download.nc', }, 'selection': { 'features': ['2m_temperature'], 'date': ['2017-01-01', '2016-01-02', '2016-01-03'], } }) actual = self.create_partition_configs([config]) expected = [config] self.assertListEqual(actual, expected) class SkipPartitionsTest(unittest.TestCase): def setUp(self) -> None: self.mock_store = InMemoryStore() self.dummy_manifest = MockManifest(Location('mock://dummy')) def test_skip_partition_missing_force_download(self): config = { 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = 
Config.from_dict(config) actual = skip_partition(config_obj, self.mock_store, self.dummy_manifest) self.assertEqual(actual, False) def test_skip_partition_force_download_true(self): config = { 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', 'force_download': True }, 'selection': { 'features': ['pressure', 'temperature', 'wind_speed_U', 'wind_speed_V'], 'month': [str(i) for i in range(1, 13)], 'year': [str(i) for i in range(2015, 2021)] } } config_obj = Config.from_dict(config) actual = skip_partition(config_obj, self.mock_store, self.dummy_manifest) self.assertEqual(actual, False) def test_skip_partition_force_download_false(self): config = { 'parameters': { 'partition_keys': ['year', 'month'], 'target_path': 'download-{}-{}.nc', 'force_download': False }, 'selection': { 'features': ['pressure'], 'month': ['12'], 'year': ['02'] } } config_obj = Config.from_dict(config) self.mock_store.exists = MagicMock(return_value=True) actual = skip_partition(config_obj, self.mock_store, self.dummy_manifest) self.assertEqual(actual, True) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_dl/download_pipeline/pipeline.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Primary ECMWF Downloader Workflow.""" import argparse import dataclasses import getpass import itertools import logging import os import typing as t import apache_beam as beam from apache_beam.options.pipeline_options import ( PipelineOptions, StandardOptions, WorkerOptions, ) from .clients import CLIENTS from .config import Config from .fetcher import Fetcher from .manifest import ( Location, LocalManifest, Manifest, NoOpManifest, ) from .parsers import ( parse_manifest, process_config, get_subsections, validate_all_configs, ) from .partition import PartitionConfig from .stores import TempFileStore, LocalFileStore logger = logging.getLogger(__name__) def configure_logger(verbosity: int) -> None: """Configures logging from verbosity. Default verbosity will show errors.""" level = 40 - verbosity * 10 logger = logging.getLogger(__package__) fmt = '%(levelname)s %(asctime)s %(name)s: %(message)s' datefmt = '%y-%m-%d %H:%M:%S' formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) handler = logging.StreamHandler() handler.setFormatter(formatter) logger.root.addHandler(handler) logger.setLevel(level) @dataclasses.dataclass class PipelineArgs: """Options for download pipeline. Attributes: known_args: Parsed arguments. Includes user-defined args and defaults. pipeline_options: The apache_beam pipeline options. configs: The download configs / data requests. client_name: The type of download client (e.g. Copernicus, Mars, or a fake). store: A Store, which is responsible for where downloads end up. manifest: A Manifest, which records download progress. num_requesters_per_key: Number of requests per subsection (license). 
""" known_args: argparse.Namespace pipeline_options: PipelineOptions configs: t.List[Config] client_name: str store: None manifest: Manifest num_requesters_per_key: int def pipeline(args: PipelineArgs) -> None: """Main pipeline entrypoint.""" import builtins import typing as t logger.info(f"Using '{args.num_requesters_per_key}' requests per subsection (license).") subsections = get_subsections(args.configs[0]) # Capping the max number of workers to N i.e. possible simultaneous requests + fudge factor if args.pipeline_options.view_as(WorkerOptions).max_num_workers is None: max_num_workers = len(subsections) * args.num_requesters_per_key + 10 args.pipeline_options.view_as(WorkerOptions).max_num_workers = max_num_workers logger.info(f"Capped the max number of workers to '{max_num_workers}'.") request_idxs = {name: itertools.cycle(range(args.num_requesters_per_key)) for name, _ in subsections} def subsection_and_request(it: Config) -> t.Tuple[str, int]: subsection = it.subsection_name return subsection, builtins.next(request_idxs[subsection]) subsections_cycle = itertools.cycle(subsections) partition = PartitionConfig(args.store, subsections_cycle, args.manifest, args.known_args.schedule, args.known_args.partition_chunks, args.known_args.update_manifest, len(subsections) * args.num_requesters_per_key) with beam.Pipeline(options=args.pipeline_options) as p: partitions = ( p | 'Create Configs' >> beam.Create(args.configs) | 'Prepare Partitions' >> partition ) # When the --update_manifest flag is passed, the tool will only update the manifest # for already downloaded shards and then exit. if not args.known_args.update_manifest: ( partitions | 'GroupBy Request Limits' >> beam.GroupBy(subsection_and_request) | 'Fetch Data' >> beam.ParDo(Fetcher(args.client_name, args.manifest, args.store, args.known_args.log_level)) ) def run(argv: t.List[str], save_main_session: bool = True) -> PipelineArgs: """Parse user arguments and configure the pipeline.""" parser = argparse.ArgumentParser( prog='weather-dl', description='Weather Downloader ingests weather data to cloud storage.' ) parser.add_argument('config', type=str, nargs='+', help="path/to/configs.cfg, containing client and data information. Can take multiple configs." "Accepts *.cfg and *.json files.") parser.add_argument('-f', '--force-download', action="store_true", default=False, help="Force redownload of partitions that were previously downloaded.") parser.add_argument('-d', '--dry-run', action='store_true', default=False, help='Run pipeline steps without _actually_ downloading or writing to cloud storage.') parser.add_argument('-l', '--local-run', action='store_true', default=False, help="Run pipeline locally, downloads to local hard drive.") parser.add_argument('-m', '--manifest-location', type=Location, default='cli://manifest', help="Location of the manifest. By default, it will use Cloud Logging (stdout for direct " "runner). You can set the name of the manifest as the hostname of a URL with the 'cli' " "protocol. For example, 'cli://manifest' will prefix all the manifest logs as " "'[manifest]'. In addition, users can specify either a Firestore collection URI " "('fs://?projectId='), or BigQuery table " "('bq://..') [Note: Tool will create the BQ table " "itself, if not already present. 
Or it will use the existing table but can report errors " "in case of schema mismatch.], or 'noop://' for an in-memory location.") parser.add_argument('-n', '--num-requests-per-key', type=int, default=-1, help='Number of concurrent requests to make per API key. ' 'Default: make an educated guess per client & config. ' 'Please see the client documentation for more details.') parser.add_argument('-p', '--partition-chunks', type=int, default=None, help='Group shards into chunks of this size when computing the partitions. Specifically, ' 'this controls how we chunk elements in a cartesian product, which affects ' "parallelization of that step. Default: chunks of 1000 elements for 'in-order' scheduling." " Chunks of 1 element for 'fair' scheduling.") parser.add_argument('-s', '--schedule', choices=['in-order', 'fair'], default='in-order', help="When using multiple configs, decide how partitions are scheduled: 'in-order' implies " "that partitions will be processed in sequential order of each config; 'fair' means that " "partitions from each config will be interspersed evenly. " "Note: When using 'fair' scheduling, we recommend you set the '--partition-chunks' to a " "much smaller number. Default: 'in-order'.") parser.add_argument('--check-skip-in-dry-run', action='store_true', default=False, help="To enable file skipping logic in dry-run mode. Default: 'false'.") parser.add_argument('-u', '--update-manifest', action='store_true', default=False, help="Update the manifest for the already downloaded shards and exit. Default: 'false'.") parser.add_argument('--log-level', type=int, default=2, help='An integer to configure log level. Default: 2(INFO)') parser.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.') known_args, pipeline_args = parser.parse_known_args(argv[1:]) configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug configs = [] for cfg in known_args.config: with open(cfg, 'r', encoding='utf-8') as f: # configs/example.cfg -> example.cfg config_name = os.path.split(cfg)[1] config = process_config(f, config_name) config.force_download = known_args.force_download config.user_id = getpass.getuser() configs.append(config) # This enables support for updating just the manifest for multiple config (*.cfg) # when running the tool with the '-u' or '--update-manifest' flag. if not known_args.update_manifest: validate_all_configs(configs) if known_args.check_skip_in_dry_run and not known_args.dry_run: raise RuntimeError('--check-skip-in-dry-run can only be used along with --dry-run flag.') # We use the save_main_session option because one or more DoFn's in this # workflow rely on global context (e.g., a module imported at module level). save_main_session_args = ['--save_main_session'] + ['True' if save_main_session else 'False'] pipeline_options = PipelineOptions(pipeline_args + save_main_session_args) client_name = config.client store = None # will default to using FileSystems() manifest = parse_manifest(known_args.manifest_location, pipeline_options.get_all_options()) if known_args.dry_run: client_name = 'fake' if not known_args.check_skip_in_dry_run: store = TempFileStore('dry_run') logger.warning('File skipping logic is disabled by default in dry-run mode.' 
'To enable please pass the flag --check-skip-in-dry-run along with the' 'dry run flag.') for config in configs: config.force_download = True manifest = NoOpManifest(Location('noop://dry-run')) if known_args.local_run: local_dir = '{}/local_run'.format(os.getcwd()) store = LocalFileStore(local_dir) pipeline_options.view_as(StandardOptions).runner = 'DirectRunner' manifest = LocalManifest(Location(local_dir)) num_requesters_per_key = known_args.num_requests_per_key known_args.log_level = 40 - known_args.log_level * 10 client = CLIENTS[client_name](configs[0], known_args.log_level) if num_requesters_per_key == -1: num_requesters_per_key = client.num_requests_per_key(config.dataset) logger.warning(f'By using {client_name} datasets, ' f'users agree to the terms and conditions specified in {client.license_url!r}') return PipelineArgs( known_args, pipeline_options, configs, client_name, store, manifest, num_requesters_per_key ) ================================================ FILE: weather_dl/download_pipeline/pipeline_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import copy import dataclasses import getpass import os import typing as t import unittest from apache_beam.options.pipeline_options import PipelineOptions import weather_dl from .config import Config from .manifest import Location, NoOpManifest, LocalManifest, ConsoleManifest from .pipeline import run, PipelineArgs from .stores import TempFileStore, LocalFileStore PATH_TO_CONFIG = os.path.join(os.path.dirname(list(weather_dl.__path__)[0]), 'configs', 'era5_example_config.cfg') CONFIG = { 'parameters': {'client': 'cds', 'dataset': 'reanalysis-era5-pressure-levels', 'target_path': 'gs://ecmwf-output-test/era5/{year:04d}/{month:02d}/{day:02d}' '-pressure-{pressure_level}.nc', 'partition_keys': ['year', 'month', 'day', 'pressure_level'], 'force_download': False, 'user_id': getpass.getuser(), 'api_url': 'https://cds.climate.copernicus.eu/api', 'api_key': '12345:1234567-ab12-34cd-9876-4o4fake90909', # fake key for testing. 
}, 'selection': {'product_type': 'reanalysis', 'format': 'netcdf', 'variable': ['divergence', 'fraction_of_cloud_cover', 'geopotential'], 'pressure_level': ['500'], 'year': ['2015', '2016', '2017'], 'month': ['01'], 'day': ['01', '15'], 'time': ['00:00', '06:00', '12:00', '18:00']} } DEFAULT_ARGS = PipelineArgs( known_args=argparse.Namespace(config=[PATH_TO_CONFIG], force_download=False, dry_run=False, local_run=False, manifest_location='cli://manifest', num_requests_per_key=-1, partition_chunks=None, schedule='in-order', check_skip_in_dry_run=False, update_manifest=False, log_level=20, use_local_code=False), pipeline_options=PipelineOptions('--save_main_session True'.split()), configs=[Config.from_dict(CONFIG)], client_name='cds', store=None, manifest=ConsoleManifest(Location('cli://manifest')), num_requesters_per_key=5, ) def default_args(parameters: t.Optional[t.Dict] = None, selection: t.Optional[t.Dict] = None, known_args: t.Optional[t.Dict] = None, **kwargs) -> PipelineArgs: if parameters is None: parameters = {} if selection is None: selection = {} if known_args is None: known_args = {} args = dataclasses.replace(DEFAULT_ARGS, **kwargs) temp_config = copy.deepcopy(CONFIG) temp_config['parameters'].update(parameters) temp_config['selection'].update(selection) args.configs = [Config.from_dict(temp_config)] args.configs[0].config_name = 'era5_example_config.cfg' args.configs[0].user_id = getpass.getuser() args.configs[0].force_download = parameters.get('force_download', False) args.known_args = copy.deepcopy(args.known_args) for k, v in known_args.items(): setattr(args.known_args, k, v) return args class ParsePipelineArgs(unittest.TestCase): DEFAULT_CMD = f'weather-dl {PATH_TO_CONFIG}' def assert_pipeline(self, args, expected): actual = run(args.split()) self.assertEqual(vars(actual.known_args), vars(expected.known_args)) self.assertEqual( actual.pipeline_options.get_all_options(drop_default=True), expected.pipeline_options.get_all_options(drop_default=True) ) self.assertEqual(actual.configs, expected.configs) self.assertEqual(actual.client_name, expected.client_name) self.assertEqual(type(actual.store), type(expected.store)) self.assertEqual(actual.manifest, expected.manifest) self.assertEqual(type(actual.manifest), type(expected.manifest)) self.assertEqual(actual.num_requesters_per_key, expected.num_requesters_per_key) def test_happy_path(self): self.assert_pipeline(self.DEFAULT_CMD, default_args()) def test_force_download(self): self.assert_pipeline( f'{self.DEFAULT_CMD} -f', default_args(dict(force_download=True), known_args=dict(force_download=True)) ) def test_dry_run(self): self.assert_pipeline( f'{self.DEFAULT_CMD} -d', default_args( dict(force_download=True), known_args=dict(dry_run=True), client_name='fake', store=TempFileStore('dry_run'), manifest=NoOpManifest(Location('noop://dry-run')), num_requesters_per_key=1 ) ) def test_local_run(self): self.assert_pipeline( f'{self.DEFAULT_CMD} -l', default_args( known_args=dict(local_run=True), store=LocalFileStore(f'{os.getcwd()}/local_run'), manifest=LocalManifest(Location(f'{os.getcwd()}/local_run')), pipeline_options=PipelineOptions('--runner DirectRunner --save_main_session True'.split()) ) ) def test_update_manifest(self): self.assert_pipeline( f'{self.DEFAULT_CMD} -u', default_args(known_args=dict(update_manifest=True)) ) def test_user_specified_num_requests_per_key(self): self.assert_pipeline( f'{self.DEFAULT_CMD} -n 7', default_args( known_args=dict(num_requests_per_key=7), num_requesters_per_key=7 ) ) def 
test_check_skip_in_dry_run_raise_error_if_dry_run_flag_is_absent(self): with self.assertRaisesRegex(RuntimeError, 'can only be used along with --dry-run flag.'): run(f'{self.DEFAULT_CMD} --check-skip-in-dry-run'.split()) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_dl/download_pipeline/stores.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Download destinations, or `Store`s.""" import abc import io import os import tempfile import typing as t from apache_beam.io.filesystems import FileSystems from .util import retry_with_exponential_backoff class Store(abc.ABC): """A interface to represent where downloads are stored. Default implementation uses Apache Beam's Filesystems. """ @abc.abstractmethod def open(self, filename: str, mode: str = 'r') -> t.IO: pass @abc.abstractmethod def exists(self, filename: str) -> bool: pass class InMemoryStore(Store): """Store file data in memory.""" def __init__(self): self.store = {} def open(self, filename: str, mode: str = 'r') -> t.IO: """Create or read in-memory data.""" if 'b' in mode: file = io.BytesIO() else: file = io.StringIO() self.store[filename] = file return file def exists(self, filename: str) -> bool: """Return true if the 'file' exists in memory.""" return filename in self.store class TempFileStore(Store): """Store data into temporary files.""" def __init__(self, directory: t.Optional[str] = None) -> None: """Optionally specify the directory that contains all temporary files.""" self.dir = directory if self.dir and not os.path.exists(self.dir): os.makedirs(self.dir) def open(self, filename: str, mode: str = 'r') -> t.IO: """Create a temporary file in the store directory.""" return tempfile.TemporaryFile(mode, dir=self.dir) def exists(self, filename: str) -> bool: """Return true if file exists.""" return os.path.exists(filename) class LocalFileStore(Store): """Store data into local files.""" def __init__(self, directory: t.Optional[str] = None) -> None: """Optionally specify the directory that contains all downloaded files.""" self.dir = directory if self.dir and not os.path.exists(self.dir): os.makedirs(self.dir) def open(self, filename: str, mode: str = 'r') -> t.IO: """Open a local file from the store directory.""" return open(os.sep.join([self.dir, filename]), mode) def exists(self, filename: str) -> bool: """Returns true if local file exists.""" return os.path.exists(os.sep.join([self.dir, filename])) class FSStore(Store): """Store data into any store supported by Apache Beam's FileSystems.""" @retry_with_exponential_backoff def open(self, filename: str, mode: str = 'r') -> t.IO: """Open object in cloud bucket (or local file system) as a read or write channel. To work with cloud storage systems, only a read or write channel can be openend at one time. Data will be treated as bytes, not text (equivalent to `rb` or `wb`). 
Further, append operations, or writes on existing objects, are dissallowed (the error thrown will depend on the implementation of the underlying cloud provider). """ if 'r' in mode and 'w' not in mode: return FileSystems().open(filename) if 'w' in mode and 'r' not in mode: return FileSystems().create(filename) raise ValueError( f"invalid mode {mode!r}: mode must have either 'r' or 'w', but not both." ) def exists(self, filename: str) -> bool: """Returns true if object exists.""" return FileSystems().exists(filename) ================================================ FILE: weather_dl/download_pipeline/stores_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import tempfile import unittest from .stores import FSStore class FSStoreTest(unittest.TestCase): def test_writes(self): with tempfile.TemporaryDirectory() as tmpdir: target = f'{tmpdir}/my-file' with FSStore().open(target, 'wb') as f: f.write(b'{"key": 1}') self.assertTrue(os.path.exists(target)) def test_reads(self): with tempfile.TemporaryDirectory() as tmpdir: target = f'{tmpdir}/my-file' with open(target, 'w') as f: f.write('data') with FSStore().open(target, 'rb') as f: self.assertEqual(f.readlines(), [b'data']) def test_reads__default_argument(self): with tempfile.TemporaryDirectory() as tmpdir: target = f'{tmpdir}/my-file' with open(target, 'w') as f: f.write('data') with FSStore().open(target) as f: self.assertEqual(f.readlines(), [b'data']) def test_asserts_bad_mode__both(self): with self.assertRaisesRegex(ValueError, "invalid mode 'rw':"): with tempfile.TemporaryDirectory() as tmpdir: target = f'{tmpdir}/my-file' with FSStore().open(target, 'rw') as f: f.read() def test_asserts_bad_mode__neither(self): with self.assertRaisesRegex(ValueError, "invalid mode '':"): with tempfile.TemporaryDirectory() as tmpdir: target = f'{tmpdir}/my-file' with FSStore().open(target, '') as f: f.read() ================================================ FILE: weather_dl/download_pipeline/util.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import datetime from dateutil.relativedelta import relativedelta import geojson import hashlib import itertools import logging import os import shutil import socket import subprocess import sys import typing as t import numpy as np import pandas as pd from apache_beam.io.gcp import gcsio from apache_beam.utils import retry from xarray.core.utils import ensure_us_time_resolution from urllib.parse import urlparse from google.api_core.exceptions import BadRequest logger = logging.getLogger(__name__) LATITUDE_RANGE = (-90, 90) LONGITUDE_RANGE = (-180, 180) GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter(exception) -> bool: if isinstance(exception, socket.timeout): return True if isinstance(exception, TimeoutError): return True # To handle the concurrency issue in BigQuery. if isinstance(exception, BadRequest): return True return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) class _FakeClock: def sleep(self, value): pass def retry_with_exponential_backoff(fun): """A retry decorator that doesn't apply during test time.""" clock = retry.Clock() # Use a fake clock only during test time... if 'unittest' in sys.modules.keys(): clock = _FakeClock() return retry.with_exponential_backoff( retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, clock=clock, )(fun) # TODO(#245): Group with common utilities (duplicated) def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: """Yield evenly-sized chunks from an iterable.""" input_ = iter(iterable) try: while True: it = itertools.islice(input_, n) # peek to check if 'it' has next item. first = next(it) yield itertools.chain([first], it) except StopIteration: pass # TODO(#245): Group with common utilities (duplicated) def copy(src: str, dst: str) -> None: """Copy data via `gsutil` or local filesystem.""" is_gs = src.startswith("gs://") or dst.startswith("gs://") try: if is_gs: subprocess.run(['gcloud', 'storage', 'cp', src, dst], check=True, capture_output=True, text=True, input="n/n") else: os.makedirs(os.path.dirname(dst) or '.', exist_ok=True) shutil.copy(src, dst) except Exception as e: error_detail = getattr(e, "stderr", str(e)).strip() msg = f"Failed to copy {src!r} to {dst!r} due to {error_detail}" logger.error(msg) raise EnvironmentError(msg) from e # TODO(#245): Group with common utilities (duplicated) def to_json_serializable_type(value: t.Any) -> t.Any: """Returns the value with a type serializable to JSON""" # Note: The order of processing is significant. logger.debug('Serializing to JSON') if pd.isna(value) or value is None: return None elif np.issubdtype(type(value), np.floating): return float(value) elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. return value.tolist() elif isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) except ValueError: # ... if they are not, assume serialization is already correct. return value except TypeError: # ... maybe value is a numpy datetime ... try: value = ensure_us_time_resolution(value).astype(datetime.datetime) except AttributeError: # ... value is a datetime object, continue. pass # We use a string timestamp representation. if value.tzname(): return value.isoformat() # We assume here that naive timestamps are in UTC timezone. 
return value.replace(tzinfo=datetime.timezone.utc).isoformat() elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, 's')) # This check must happen after processing np.timedelta64 and np.datetime64. elif np.issubdtype(type(value), np.integer): return int(value) return value def fetch_geo_polygon(area: t.Union[list, str]) -> str: """Calculates a geography polygon from an input area.""" # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 if isinstance(area, str): # European area if area == 'E': area = [73.5, -27, 33, 45] # Global area elif area == 'G': area = GLOBAL_COVERAGE_AREA else: raise RuntimeError(f'Not a valid value for area in config: {area}.') n, w, s, e = [float(x) for x in area] if s < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value for south: '{s}'") if n > LATITUDE_RANGE[1]: raise ValueError(f"Invalid latitude value for north: '{n}'") if w < LONGITUDE_RANGE[0]: raise ValueError(f"Invalid longitude value for west: '{w}'") if e > LONGITUDE_RANGE[1]: raise ValueError(f"Invalid longitude value for east: '{e}'") # Define the coordinates of the bounding box. coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] # Create the GeoJSON polygon object. polygon = geojson.dumps(geojson.Polygon([coords])) return polygon def get_file_size(path: str) -> float: parsed_gcs_path = urlparse(path) if parsed_gcs_path.scheme != 'gs' or parsed_gcs_path.netloc == '': return os.stat(path).st_size / (1024 ** 3) if os.path.exists(path) else 0 else: return gcsio.GcsIO().size(path) / (1024 ** 3) if gcsio.GcsIO().exists(path) else 0 def get_wait_interval(num_retries: int = 0) -> float: """Returns next wait interval in seconds, using an exponential backoff algorithm.""" if 0 == num_retries: return 0 return 2 ** num_retries def generate_md5_hash(input: str) -> str: """Generates md5 hash for the input string.""" return hashlib.md5(input.encode('utf-8')).hexdigest() def download_with_aria2(url: str, path: str) -> None: """Downloads a file from the given URL using the `aria2c` command-line utility, with options set to improve download speed and reliability.""" dir_path, file_name = os.path.split(path) try: subprocess.run( ['aria2c', '-x', '16', '-s', '16', url, '-d', dir_path, '-o', file_name, '--allow-overwrite'], check=True, capture_output=True) except subprocess.CalledProcessError as e: logger.error(f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}') raise def generate_hdate(date: str, subtract_year: str) -> str: """Generate a historical date by subtracting a specified number of years from the given date. If input date is leap day (Feb 29), return Feb 28 even if target hdate is also a leap year. This is expected in ECMWF API. Args: date (str): The input date in the format 'YYYY-MM-DD'. subtract_year (str): The number of years to subtract. Returns: str: The historical date in the format 'YYYY-MM-DD'. 
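    Examples (mirroring the unit tests in util_test.py):
        generate_hdate('2020-01-02', '4')  ->  '2016-01-02'
        generate_hdate('2020-02-29', '4')  ->  '2016-02-28'  (leap day maps to Feb 28)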
""" try: input_date = datetime.datetime.strptime(date, "%Y-%m-%d") # Check for leap day if input_date.month == 2 and input_date.day == 29: input_date = input_date - datetime.timedelta(days=1) subtract_year = int(subtract_year) except (ValueError, TypeError): logger.error("Invalid input.") raise hdate = input_date - relativedelta(years=subtract_year) return hdate.strftime("%Y-%m-%d") ================================================ FILE: weather_dl/download_pipeline/util_test.py ================================================ import unittest from .util import fetch_geo_polygon, generate_hdate, ichunked # TODO(#245): Duplicate tests; remove. class IChunksTests(unittest.TestCase): def setUp(self) -> None: self.items = range(20) def test_even_chunks(self): actual = [] for chunk in ichunked(self.items, 4): actual.append(list(chunk)) self.assertEqual(actual, [ [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], ]) def test_odd_chunks(self): actual = [] for chunk in ichunked(self.items, 7): actual.append(list(chunk)) self.assertEqual(actual, [ [0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13], [14, 15, 16, 17, 18, 19] ]) class TestFetchGeoPolygon(unittest.TestCase): def test_valid_area(self): # Test with valid area values. area = ['40', '-75', '39', '-74'] expected_result = ( '{"type": "Polygon", "coordinates": ' '[[[-75.0, 40.0], [-75.0, 39.0], [-74.0, 39.0], [-74.0, 40.0], [-75.0, 40.0]]]}' ) self.assertEqual(fetch_geo_polygon(area), expected_result) def test_valid_string_area_value(self): # Test with valid string area value. area = 'E' expected_result = ( '{"type": "Polygon", "coordinates": ' '[[[-27.0, 73.5], [-27.0, 33.0], [45.0, 33.0], [45.0, 73.5], [-27.0, 73.5]]]}' ) self.assertEqual(fetch_geo_polygon(area), expected_result) def test_invalid_string_area_value(self): # Test with invalid string area value. area = 'B' with self.assertRaises(RuntimeError): fetch_geo_polygon(area) def test_invalid_latitude_south(self): # Test with invalid south latitude value area = [40, -75, -91, -74] with self.assertRaises(ValueError): fetch_geo_polygon(area) def test_invalid_latitude_north(self): # Test with invalid north latitude value area = [91, -75, 39, -74] with self.assertRaises(ValueError): fetch_geo_polygon(area) def test_invalid_longitude_west(self): # Test with invalid west longitude value area = [40, -181, 39, -74] with self.assertRaises(ValueError): fetch_geo_polygon(area) def test_invalid_longitude_east(self): # Test with invalid east longitude value area = [40, -75, 39, 181] with self.assertRaises(ValueError): fetch_geo_polygon(area) class TestGenerateHdate(unittest.TestCase): def test_valid_hdate(self): date = '2020-01-02' substract_year = '4' expected_result = '2016-01-02' self.assertEqual(generate_hdate(date, substract_year), expected_result) # Also test for leap day correctness date = '2020-02-29' substract_year = '3' expected_result = '2017-02-28' self.assertEqual(generate_hdate(date, substract_year), expected_result) substract_year = '4' expected_result = '2016-02-28' self.assertEqual(generate_hdate(date, substract_year), expected_result) ================================================ FILE: weather_dl/setup.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from setuptools import setup, find_packages beam_gcp_requirements = [ "google-cloud-bigquery==2.34.4", "google-cloud-bigquery-storage==2.14.1", "google-cloud-bigtable==1.7.2", "google-cloud-core==1.7.3", "google-cloud-datastore==1.15.5", "google-cloud-dlp==3.8.0", "google-cloud-language==1.3.2", "google-cloud-pubsub==2.13.4", "google-cloud-pubsublite==1.4.2", "google-cloud-recommendations-ai==0.2.0", "google-cloud-spanner==1.19.3", "google-cloud-videointelligence==1.16.3", "google-cloud-vision==1.0.2", "apache-beam[gcp]==2.40.0", ] base_requirements = [ "ecmwf-api-client==1.6.3", "numpy>=1.19.1", "pandas==1.5.1", "xarray==2023.1.0", "requests>=2.24.0", "urllib3==1.26.5", "google-cloud-firestore==2.6.0", "firebase-admin==6.0.1", # "gcloud" should already be installed in the host image. # If we install it here, we'll hit auth issues. ] setup( name='download_pipeline', packages=find_packages(), version='0.1.28', author='Anthromets', author_email='anthromets-ecmwf@google.com', url='https://weather-tools.readthedocs.io/en/latest/weather_dl/', description='A tool to download weather data.', install_requires=beam_gcp_requirements + base_requirements, ) ================================================ FILE: weather_dl/weather-dl ================================================ #!/usr/bin/env python3 # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import glob import logging import os import subprocess import sys import tarfile import tempfile import weather_dl SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) site_pkg = weather_dl.__path__[0] try: from download_pipeline import cli except ImportError: # Install the subpackage. subprocess.check_call(f'{sys.executable} -m pip -q install -e {site_pkg}'.split()) # Re-load sys.path import site from importlib import reload reload(site) # Re-attempt import. If this fails, the user probably has an older version of # the package already installed on their machine that breaks this process. # If that's the case, it's best to start from a clean virtual environment. 
try: from download_pipeline import cli except ImportError as e: raise ImportError('please re-install package in a clean python environment.') from e args = [] if "DataflowRunner" in sys.argv and "--sdk_container_image" not in sys.argv: args.extend(['--sdk_container_image', os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE), '--experiments', 'use_runner_v2']) if "--use-local-code" in sys.argv: with tempfile.TemporaryDirectory() as tmpdir: original_dir = os.getcwd() # Convert subpackage to a tarball os.chdir(site_pkg) subprocess.check_call( f'{sys.executable} ./setup.py -q sdist --dist-dir {tmpdir}'.split(), ) os.chdir(original_dir) # Set tarball as extra packages for Beam. pkg_archive = glob.glob(os.path.join(tmpdir, '*.tar.gz'))[0] with tarfile.open(pkg_archive, 'r') as tar: assert any([f.endswith('.py') for f in tar.getnames()]), 'extra_package must include python files!' # cleanup memory to prevent pickling error. tar = None weather_dl = None args.extend(['--extra_package', pkg_archive]) cli(args) else: cli(args) ================================================ FILE: weather_dl_v2/README.md ================================================ ## weather-dl-v2 > **_NOTE:_** weather-dl-v2 only supports python 3.10 ### Sequence of steps: 1) Refer to downloader_kubernetes/README.md 2) Refer to license_deployment/README.md 3) Refer to fastapi-server/README.md 4) Refer to cli/README.md ### To create docker images of services ``` export PROJECT_ID= export REPO= eg:weather-tools ``` Choose any ONE service from below: 1. CLI ``` export SERVICE=weather-dl-v2-cli export FOLDER=cli ``` 2. weather-dl-v2-downloader ``` export SERVICE=weather-dl-v2-downloader export FOLDER=downloader_kubernetes ``` 3. weather-dl-v2-license-dep ``` export SERVICE=weather-dl-v2-license-dep export FOLDER=license_deployment ``` 4. weather-dl-v2-server ``` export SERVICE=weather-dl-v2-server export FOLDER=fastapi-server ``` Finally run: ``` export VER=$(cat $FOLDER/VERSION.txt) gcloud builds submit --config=cloudbuild.yml --substitutions=_PROJECT_ID=$PROJECT_ID,_REPO=$REPO,_SERVICE=$SERVICE,_VER=$VER,_FOLDER=$FOLDER ``` ================================================ FILE: weather_dl_v2/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/cli/CLI-Documentation.md ================================================ # CLI Documentation The following doc provides cli commands and their various arguments and options. Base Command: ``` weather-dl-v2 ``` ## Ping Ping the FastAPI server and check if it’s live and reachable. weather-dl-v2 ping ##### Usage ``` weather-dl-v2 ping ```
## Download

Manage download configs.

### Add Downloads
`weather-dl-v2 download add`
Adds a new download config to specific licenses.
##### Arguments
> `FILE_PATH` : Path to config file.

##### Options
> `-l/--license` (Required): License ID to which this download has to be added.
> `-f/--force-download` : Force redownload of partitions that were previously downloaded.

##### Usage
```
weather-dl-v2 download add /path/to/example.cfg -l L1 -l L2 [--force-download]
```
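For reference, a download config (the `FILE_PATH` argument above) is split into a `parameters` section and a `selection` section. The JSON sketch below is illustrative only: the field names are taken from the test config in `weather_dl/download_pipeline/pipeline_test.py`, the values and bucket name are placeholders, and real configs in this repo are typically written as `.cfg` files (see the `configs/` directory for complete examples).
```
{
    "parameters": {
        "client": "cds",
        "dataset": "reanalysis-era5-pressure-levels",
        "target_path": "gs://<your-bucket>/era5/{year}/{month}-pressure-{pressure_level}.nc",
        "partition_keys": ["year", "month", "pressure_level"]
    },
    "selection": {
        "product_type": "reanalysis",
        "format": "netcdf",
        "variable": ["geopotential"],
        "pressure_level": ["500"],
        "year": ["2015", "2016"],
        "month": ["01"],
        "time": ["00:00", "12:00"]
    }
}
```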
### List Downloads
`weather-dl-v2 download list`

List all the active downloads.
The list can also be filtered by client name or status. Available filters:
```
Filter Key: client_name
Values: cds, mars, ecpublic

Filter Key: status
Values: completed, failed, in-progress
```

##### Options
> `--filter` : Filter the list by some key and value. Format of filter: `filter_key=filter_value`.

##### Usage
```
weather-dl-v2 download list
weather-dl-v2 download list --filter client_name=cds
weather-dl-v2 download list --filter status=success
weather-dl-v2 download list --filter status=failed
weather-dl-v2 download list --filter status=in-progress
weather-dl-v2 download list --filter client_name=cds --filter status=success
```

### Download Get
`weather-dl-v2 download get`
Get a particular download by config name.
##### Arguments
> `CONFIG_NAME` : Name of the download config.

##### Usage
```
weather-dl-v2 download get example.cfg
```

### Download Show
`weather-dl-v2 download show`
Get contents of a particular config by config name.
##### Arguments
> `CONFIG_NAME` : Name of the download config.

##### Usage
```
weather-dl-v2 download show example.cfg
```

### Download Remove
`weather-dl-v2 download remove`
Remove a download by config name.
##### Arguments
> `CONFIG_NAME` : Name of the download config.

##### Usage
```
weather-dl-v2 download remove example.cfg
```

### Download Refetch
`weather-dl-v2 download refetch`
Refetch all non-successful partitions of a config.
##### Arguments
> `CONFIG_NAME` : Name of the download config.

##### Options
> `-l/--license` (Required): License ID to which this download has to be added.

##### Usage
```
weather-dl-v2 download refetch example.cfg -l L1 -l L2
```
## License

Manage licenses.

### License Add
`weather-dl-v2 license add`
Add a new license. New licenses are added using a json file.
The json file should be in this format:
```
{
    "license_id": <license_id>,
    "client_name": <client_name>,
    "number_of_requests": <number_of_requests>,
    "secret_id": <secret_id>
}
```
NOTE: `license_id` is case insensitive and has to be unique for each license.

##### Arguments
> `FILE_PATH` : Path to the license json.

##### Usage
```
weather-dl-v2 license add /path/to/new-license.json
```
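For illustration, a filled-in license file might look like the sketch below. The values are placeholders: `L1` and `cds` simply mirror the IDs used elsewhere in this document, `number_of_requests` is presumably the number of parallel requests the license supports, and `<secret-id>` stands for whatever secret holds the client's API credentials.
```
{
    "license_id": "L1",
    "client_name": "cds",
    "number_of_requests": 5,
    "secret_id": "<secret-id>"
}
```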
### License Get
`weather-dl-v2 license get`

Get a particular license by license ID.
##### Arguments
> `LICENSE` : License ID of the license to be fetched.

##### Usage
```
weather-dl-v2 license get L1
```

### License Remove
`weather-dl-v2 license remove`
Remove a particular license by license ID.
##### Arguments
> `LICENSE` : License ID of the license to be removed.

##### Usage
```
weather-dl-v2 license remove L1
```

### License List
`weather-dl-v2 license list`
List all the licenses available.
The list can also be filtered by client name.

##### Options
> `--filter` : Filter the list by some key and value. Format of filter: `filter_key=filter_value`.

##### Usage
```
weather-dl-v2 license list
weather-dl-v2 license list --filter client_name=cds
```

### License Update
`weather-dl-v2 license update`
Update an existing license using License ID and a license json.
The json should be of the same format used to add a new license.

##### Arguments
> `LICENSE` : License ID of the license to be edited.
> `FILE_PATH` : Path to the license json.

##### Usage
```
weather-dl-v2 license update L1 /path/to/license.json
```
## Queue Manage all the license queues. ### Queue List weather-dl-v2 queue list
List all the queues.
The list can also be filtered by client name. ##### Options > `--filter` : Filter the list by some key and value. Format: filter_key=filter_value. ##### Usage ``` weather-dl-v2 queue list weather-dl-v2 queue list --filter client_name=cds ``` ### Queue Get weather-dl-v2 queue get
Get a queue by license ID.
##### Arguments > `LICENSE` : License ID of the queue to be fetched. ##### Usage ``` weather-dl-v2 queue get L1 ``` ### Queue Edit weather-dl-v2 queue edit
Edit the priority of configs inside a license queue.
Priority can be edited in two ways: 1. The new priority queue is passed using a priority json file in the following format (see the example file at the end of this section): ``` { "priority": ["c1.cfg", "c3.cfg", "c2.cfg"] } ``` 2. A config file name and its absolute priority can be passed; this updates the priority for that particular config file in the given license queue. ##### Arguments > `LICENSE` : License ID of the queue to be edited. ##### Options > `-f/--file` : Path of the new priority json file. > `-c/--config` : Config name for absolute priority. > `-p/--priority`: Absolute priority for the config in a license queue. Priority decreases in ascending order, with 0 having the highest priority. ##### Usage ``` weather-dl-v2 queue edit L1 --file /path/to/priority.json weather-dl-v2 queue edit L1 --config example.cfg --priority 0 ```
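For the file-based edit, the priority file is plain JSON whose list is ordered from highest to lowest priority (index 0 is served first). A hypothetical `priority.json` with placeholder config names:

```
{
    "priority": ["hourly_era5.cfg", "daily_era5.cfg", "monthly_soil.cfg"]
}
```

Passing this file with `weather-dl-v2 queue edit L1 --file /path/to/priority.json` replaces the existing order of the L1 queue.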
## Config Configurations for cli. ### Config Show IP weather-dl-v2 config show-ip
See the current server IP address.
##### Usage ``` weather-dl-v2 config show-ip ``` ### Config Set IP weather-dl-v2 config set-ip
Update the server IP address.
##### Arguments > `NEW_IP` : New IP address. (Do not add port or protocol). ##### Usage ``` weather-dl-v2 config set-ip 127.0.0.1 ``` ================================================ FILE: weather_dl_v2/cli/Dockerfile ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM continuumio/miniconda3:latest COPY . . # Add the mamba solver for faster builds RUN conda install -n base conda-libmamba-solver RUN conda config --set solver libmamba # Create conda env using environment.yml RUN conda update conda -y RUN conda env create --name weather-dl-v2-cli --file=environment.yml # Activate the conda env and update the PATH ARG CONDA_ENV_NAME=weather-dl-v2-cli RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH RUN apt-get update -y RUN apt-get install nano -y RUN apt-get install vim -y RUN apt-get install curl -y # Install gsutil RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-443.0.0-linux-arm.tar.gz RUN tar -xf google-cloud-cli-443.0.0-linux-arm.tar.gz RUN ./google-cloud-sdk/install.sh --quiet RUN echo "if [ -f '/google-cloud-sdk/path.bash.inc' ]; then . '/google-cloud-sdk/path.bash.inc'; fi" >> /root/.bashrc RUN echo "if [ -f '/google-cloud-sdk/completion.bash.inc' ]; then . '/google-cloud-sdk/completion.bash.inc'; fi" >> /root/.bashrc ================================================ FILE: weather_dl_v2/cli/README.md ================================================ # weather-dl-cli This is a command line interface for talking to the weather-dl-v2 FastAPI server. - Due to our org level policy we can't expose external-ip using LoadBalancer Service while deploying our FastAPI server. Hence we need to deploy the CLI on a VM to interact through our fastapi server. Replace the FastAPI server pod's IP in cli_config.json. ``` Please make approriate changes in cli_config.json, if required. ``` > Note: Command to get the Pod IP : `kubectl get pods -o wide`. > > Though note that in case of Pod restart IP might get change. So we need to look > for better solution for the same. 
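For reference, a `cli_config.json` with the pod IP filled in might look like the following; the IP below is a placeholder for your pod's address, and the port should match the one the FastAPI server listens on:

```
{
    "pod_ip": "10.24.0.17",
    "port": 8080
}
```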
## Create docker image for weather-dl-cli Refer instructions in weather_dl_v2/README.md ## Create a VM using above created docker-image ``` export ZONE= eg: us-west1-a export SERVICE_ACCOUNT= # Let's keep this as Compute Engine Default Service Account export IMAGE_PATH= # The above created image-path gcloud compute instances create-with-container weather-dl-v2-cli \ --project=$PROJECT_ID \ --zone=$ZONE \ --machine-type=e2-medium \ --network-interface=network-tier=PREMIUM,subnet=default \ --maintenance-policy=MIGRATE \ --provisioning-model=STANDARD \ --service-account=$SERVICE_ACCOUNT \ --scopes=https://www.googleapis.com/auth/cloud-platform \ --tags=http-server,https-server \ --image=projects/cos-cloud/global/images/cos-stable-105-17412-101-24 \ --boot-disk-size=10GB \ --boot-disk-type=pd-balanced \ --boot-disk-device-name=weather-dl-v2-cli \ --container-image=$IMAGE_PATH \ --container-restart-policy=on-failure \ --container-tty \ --no-shielded-secure-boot \ --shielded-vtpm \ --labels=goog-ec-src=vm_add-gcloud,container-vm=cos-stable-105-17412-101-24 \ --metadata-from-file=startup-script=vm-startup.sh ``` ## Use the cli after doing ssh in the above created VM ``` weather-dl-v2 --help ``` ================================================ FILE: weather_dl_v2/cli/VERSION.txt ================================================ 1.0.4 ================================================ FILE: weather_dl_v2/cli/app/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/cli/app/cli_config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import json import typing as t import pkg_resources Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class CliConfig: pod_ip: str = "" port: str = "" @property def BASE_URI(self) -> str: # If pod IP is not present assume in dev environment. 
if(self.pod_ip == ""): return "http://127.0.0.1:8000" return f"http://{self.pod_ip}:{self.port}" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict): config_instance = cls() for key, value in config.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value return config_instance cli_config = None def get_config(): global cli_config cli_config_json = pkg_resources.resource_filename('app', 'data/cli_config.json') if cli_config is None: with open(cli_config_json) as file: firestore_dict = json.load(file) cli_config = CliConfig.from_dict(firestore_dict) return cli_config ================================================ FILE: weather_dl_v2/cli/app/data/cli_config.json ================================================ { "pod_ip": "", "port": 8080 } ================================================ FILE: weather_dl_v2/cli/app/main.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import requests import typer from app.cli_config import get_config from app.subcommands import config, download, license, queue from app.utils import Loader logger = logging.getLogger(__name__) app = typer.Typer( help="weather-dl-v2 is a cli tool for communicating with FastAPI server." ) app.add_typer(download.app, name="download", help="Manage downloads.") app.add_typer(queue.app, name="queue", help="Manage queues.") app.add_typer(license.app, name="license", help="Manage licenses.") app.add_typer(config.app, name="config", help="Configurations for cli.") @app.command("ping", help="Check if FastAPI server is live and rechable.") def ping(): uri = f"{get_config().BASE_URI}/" try: with Loader("Sending request..."): x = requests.get(uri) except Exception as e: print(f"error {e}") return print(x.text) if __name__ == "__main__": app() ================================================ FILE: weather_dl_v2/cli/app/services/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/cli/app/services/download_service.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import json import logging import typing as t from app.cli_config import get_config from app.services.network_service import network_service logger = logging.getLogger(__name__) class DownloadService(abc.ABC): @abc.abstractmethod def _list_all_downloads(self): pass @abc.abstractmethod def _list_all_downloads_by_filter(self, filter_dict: dict): pass @abc.abstractmethod def _get_download_by_config(self, config_name: str): pass @abc.abstractmethod def _show_config_content(self, config_name: str): pass @abc.abstractmethod def _add_new_download( self, file_path: str, licenses: t.List[str], force_download: bool, priority: int | None ): pass @abc.abstractmethod def _remove_download(self, config_name: str): pass @abc.abstractmethod def _refetch_config_partitions(self, config_name: str, licenses: t.List[str], only_failed: bool): pass class DownloadServiceNetwork(DownloadService): def __init__(self): self.endpoint = f"{get_config().BASE_URI}/download" def _list_all_downloads(self): return network_service.get( uri=self.endpoint, header={"accept": "application/json"} ) def _list_all_downloads_by_filter(self, filter_dict: dict): return network_service.get( uri=self.endpoint, header={"accept": "application/json"}, query=filter_dict, ) def _get_download_by_config(self, config_name: str): return network_service.get( uri=f"{self.endpoint}/{config_name}", header={"accept": "application/json"}, ) def _show_config_content(self, config_name: str): return network_service.get( uri=f"{self.endpoint}/show/{config_name}", header={"accept": "application/json"}, ) def _add_new_download( self, file_path: str, licenses: t.List[str], force_download: bool, priority: int | None ): try: file = {"file": open(file_path, "rb")} except FileNotFoundError: return "File not found." return network_service.post( uri=self.endpoint, header={"accept": "application/json"}, file=file, payload={"licenses": licenses}, query={"force_download": force_download, "priority": priority}, ) def _remove_download(self, config_name: str): return network_service.delete( uri=f"{self.endpoint}/{config_name}", header={"accept": "application/json"} ) def _refetch_config_partitions(self, config_name: str, licenses: t.List[str], only_failed: bool): return network_service.post( uri=f"{self.endpoint}/retry/{config_name}", header={"accept": "application/json"}, payload=json.dumps({"licenses": licenses}), query={'only_failed': only_failed} ) class DownloadServiceMock(DownloadService): pass def get_download_service(test: bool = False): if test: return DownloadServiceMock() else: return DownloadServiceNetwork() download_service = get_download_service() ================================================ FILE: weather_dl_v2/cli/app/services/license_service.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import json import logging from app.cli_config import get_config from app.services.network_service import network_service logger = logging.getLogger(__name__) class LicenseService(abc.ABC): @abc.abstractmethod def _get_all_license(self): pass @abc.abstractmethod def _get_all_license_by_client_name(self, client_name: str): pass @abc.abstractmethod def _get_license_by_license_id(self, license_id: str): pass @abc.abstractmethod def _add_license(self, license_dict: dict): pass @abc.abstractmethod def _remove_license(self, license_id: str): pass @abc.abstractmethod def _update_license(self, license_id: str, license_dict: dict): pass @abc.abstractmethod def _redeploy_license_by_license_id(self, license_id: str): pass @abc.abstractmethod def _redeploy_licenses_by_client(self, client_name: str): pass class LicenseServiceNetwork(LicenseService): def __init__(self): self.endpoint = f"{get_config().BASE_URI}/license" def _get_all_license(self): return network_service.get( uri=self.endpoint, header={"accept": "application/json"} ) def _get_all_license_by_client_name(self, client_name: str): return network_service.get( uri=self.endpoint, header={"accept": "application/json"}, query={"client_name": client_name}, ) def _get_license_by_license_id(self, license_id: str): return network_service.get( uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"}, ) def _add_license(self, license_dict: dict): return network_service.post( uri=self.endpoint, header={"accept": "application/json"}, payload=json.dumps(license_dict), ) def _remove_license(self, license_id: str): return network_service.delete( uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"}, ) def _update_license(self, license_id: str, license_dict: dict): return network_service.put( uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"}, payload=json.dumps(license_dict), ) def _redeploy_license_by_license_id(self, license_id: str): return network_service.patch( uri=f"{self.endpoint}/redeploy", header={"accept": "application/json"}, query={"license_id": license_id} ) def _redeploy_licenses_by_client(self, client_name: str): return network_service.patch( uri=f"{self.endpoint}/redeploy", header={"accept": "application/json"}, query={"client_name": client_name} ) class LicenseServiceMock(LicenseService): pass def get_license_service(test: bool = False): if test: return LicenseServiceMock() else: return LicenseServiceNetwork() license_service = get_license_service() ================================================ FILE: weather_dl_v2/cli/app/services/network_service.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import json import logging import requests from app.utils import Loader, timeit logger = logging.getLogger(__name__) class NetworkService: def parse_response(self, response: requests.Response): try: parsed = json.loads(response.text) except Exception as e: logger.info(f"Parsing error: {e}.") logger.info(f"Status code {response.status_code}") logger.info(f"Response {response.text}") return if isinstance(parsed, list): print(f"[Total {len(parsed)} items.]") return json.dumps(parsed, indent=3) @timeit def get(self, uri, header, query=None, payload=None): try: with Loader("Sending request..."): x = requests.get(uri, params=query, headers=header, data=payload) return self.parse_response(x) except requests.exceptions.RequestException as e: logger.error(f"request error: {e}") raise SystemExit(e) @timeit def post(self, uri, header, query=None, payload=None, file=None): try: with Loader("Sending request..."): x = requests.post( uri, params=query, headers=header, data=payload, files=file ) return self.parse_response(x) except requests.exceptions.RequestException as e: logger.error(f"request error: {e}") raise SystemExit(e) @timeit def put(self, uri, header, query=None, payload=None, file=None): try: with Loader("Sending request..."): x = requests.put( uri, params=query, headers=header, data=payload, files=file ) return self.parse_response(x) except requests.exceptions.RequestException as e: logger.error(f"request error: {e}") raise SystemExit(e) @timeit def delete(self, uri, header, query=None): try: with Loader("Sending request..."): x = requests.delete(uri, params=query, headers=header) return self.parse_response(x) except requests.exceptions.RequestException as e: logger.error(f"request error: {e}") raise SystemExit(e) @timeit def patch(self, uri, header, query=None): try: with Loader("Sending request..."): x = requests.patch(uri, params=query, headers=header) return self.parse_response(x) except requests.exceptions.RequestException as e: logger.error(f"request error: {e}") raise SystemExit(e) network_service = NetworkService() ================================================ FILE: weather_dl_v2/cli/app/services/queue_service.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import abc import json import logging import os import typing as t from app.cli_config import get_config from app.services.network_service import network_service logger = logging.getLogger(__name__) class QueueService(abc.ABC): @abc.abstractmethod def _get_all_license_queues(self): pass @abc.abstractmethod def _get_license_queue_by_client_name(self, client_name: str): pass @abc.abstractmethod def _get_queue_by_license(self, license_id: str): pass @abc.abstractmethod def _edit_license_queue(self, license_id: str, priority_list: t.List[str]): pass @abc.abstractmethod def _edit_config_absolute_priority( self, license_id: str, config_name: str, priority: int ): pass @abc.abstractmethod def _save_queue_to_file(self, license_id: str, dir: str): pass class QueueServiceNetwork(QueueService): def __init__(self): self.endpoint = f"{get_config().BASE_URI}/queues" def _get_all_license_queues(self): return network_service.get( uri=self.endpoint, header={"accept": "application/json"} ) def _get_license_queue_by_client_name(self, client_name: str): return network_service.get( uri=self.endpoint, header={"accept": "application/json"}, query={"client_name": client_name}, ) def _get_queue_by_license(self, license_id: str): return network_service.get( uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json"} ) def _edit_license_queue(self, license_id: str, priority_list: t.List[str]): return network_service.post( uri=f"{self.endpoint}/{license_id}", header={"accept": "application/json", "Content-Type": "application/json"}, payload=json.dumps(priority_list), ) def _edit_config_absolute_priority( self, license_id: str, config_name: str, priority: int ): return network_service.put( uri=f"{self.endpoint}/priority/{license_id}", header={"accept": "application/json"}, query={"config_name": config_name, "priority": priority}, ) def _save_queue_to_file(self, license_id: str, dir_path: str) -> str: if not os.path.isdir(dir_path): print(f"{dir_path} is a not directory.") return None response = self._get_queue_by_license(license_id) parsed_response = json.loads(response) if "queue" not in parsed_response: print(response) return None json_data = {"priority": parsed_response["queue"]} file_path = os.path.join(dir_path, f"{license_id}.json") with open(file_path, 'w') as json_file: json.dump(json_data, json_file) return file_path class QueueServiceMock(QueueService): pass def get_queue_service(test: bool = False): if test: return QueueServiceMock() else: return QueueServiceNetwork() queue_service = get_queue_service() ================================================ FILE: weather_dl_v2/cli/app/subcommands/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
================================================ FILE: weather_dl_v2/cli/app/subcommands/config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import subprocess import pkg_resources import typer from typing_extensions import Annotated from app.cli_config import get_config from app.utils import Validator, confirm_action app = typer.Typer() class ConfigValidator(Validator): pass @app.command("update", help="Update the cli.") def update_cli(): confirm_action("Are you sure you want to update cli?") try: print("Updating CLI. This will take some time...") subprocess.run(['pip', 'uninstall', 'weather-dl-v2', '-y', '-q']) subprocess.run(['pip', 'install', 'git+http://github.com/google/weather-tools#subdirectory=weather_dl_v2/cli']) subprocess.run(['clear']) print("CLI updated successfully. ✨") except Exception as e: print(f"Couldn't update CLI. Error: {e}.") @app.command("show_ip", help="See the current server IP address.") def show_server_ip(): print(f"Current pod IP: {get_config().pod_ip}") @app.command("set_ip", help="Update the server IP address.") def update_server_ip( new_ip: Annotated[ str, typer.Argument(help="New IP address. (Do not add port or protocol).") ], ): file_path = pkg_resources.resource_filename('app', 'data/cli_config.json') cli_config = {} with open(file_path, "r") as file: cli_config = json.load(file) old_ip = cli_config["pod_ip"] cli_config["pod_ip"] = new_ip with open(file_path, "w") as file: json.dump(cli_config, file) validator = ConfigValidator(valid_keys=["pod_ip", "port"]) try: cli_config = validator.validate_json(file_path=file_path) except Exception as e: print(f"payload error: {e}") return print(f"Pod IP Updated {old_ip} -> {new_ip} .") ================================================ FILE: weather_dl_v2/cli/app/subcommands/download.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List import typer from typing_extensions import Annotated from app.services.download_service import download_service from app.utils import Validator, as_table, confirm_action app = typer.Typer(rich_markup_mode="markdown") class DowloadFilterValidator(Validator): pass download_key_order = [ 'config_name', 'client_name', 'partitioning_status', 'scheduled_shards', 'in-progress_shards', 'downloaded_shards', 'failed_shards', 'total_shards' ] @app.command("list", help="List out all the configs.") def get_downloads( filter: Annotated[ List[str], typer.Option( help="""Filter by some value. Format: filter_key=filter_value. Available filters """ """[key: client_name, values: cds, mars, ecpublic] """ """[key: status, values: completed, failed, in-progress]""" ), ] = [] ): if len(filter) > 0: validator = DowloadFilterValidator(valid_keys=["client_name", "status"]) try: filter_dict = validator.validate(filters=filter, allow_missing=True) except Exception as e: print(f"filter error: {e}") return print(as_table(download_service._list_all_downloads_by_filter(filter_dict), download_key_order)) return print(as_table(download_service._list_all_downloads(), download_key_order)) # TODO: Add support for submitting multiple configs using *.cfg notation. @app.command("add", help="Submit new config to download.") def submit_download( file_path: Annotated[ str, typer.Argument(help="File path of config to be uploaded.") ], license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")], force_download: Annotated[ bool, typer.Option( "-f", "--force-download", help="Force redownload of partitions that were previously downloaded.", ), ] = False, priority: Annotated[ int, typer.Option( "-p", "--priority", help="Set the priority for submitted config in ALL licenses. If not added, the config is added" \ "at the end of the queue. Priority decreases in ascending order with 0 having highest priority.", ), ] = None, ): print(download_service._add_new_download(file_path, license, force_download, priority)) @app.command("get", help="Get a particular config.") def get_download_by_config( config_name: Annotated[str, typer.Argument(help="Config file name.")] ): print(as_table(download_service._get_download_by_config(config_name), download_key_order)) @app.command("show", help="Show contents of a particular config.") def show_config( config_name: Annotated[str, typer.Argument(help="Config file name.")] ): print(download_service._show_config_content(config_name)) @app.command("remove", help="Remove existing config.") def remove_download( config_name: Annotated[str, typer.Argument(help="Config file name.")], auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): if not auto_confirm: confirm_action(f"Are you sure you want to remove {config_name}?") print(download_service._remove_download(config_name)) @app.command( "refetch", help="Reschedule all partitions of a config that are not successful." 
) def refetch_config( config_name: Annotated[str, typer.Argument(help="Config file name.")], license: Annotated[List[str], typer.Option("--license", "-l", help="License ID.")], only_failed: Annotated[bool, typer.Option("--only_failed", help="Only refetch failed partitions.")] = False, auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): if not auto_confirm: confirm_action(f"Are you sure you want to refetch {config_name}?") print(download_service._refetch_config_partitions(config_name, license, only_failed)) ================================================ FILE: weather_dl_v2/cli/app/subcommands/license.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typer from typing_extensions import Annotated from app.services.license_service import license_service from app.utils import Validator, as_table, confirm_action app = typer.Typer() class LicenseValidator(Validator): pass license_key_order = [ 'license_id', 'client_name', 'status', 'number_of_requests', 'secret_id', 'k8s_deployment_id' ] @app.command("list", help="List all licenses.") def get_all_license( filter: Annotated[ str, typer.Option(help="Filter by some value. Format: filter_key=filter_value") ] = None ): if filter: validator = LicenseValidator(valid_keys=["client_name"]) try: data = validator.validate(filters=[filter]) client_name = data["client_name"] except Exception as e: print(f"filter error: {e}") return print(as_table(license_service._get_all_license_by_client_name(client_name), license_key_order)) return print(as_table(license_service._get_all_license(), license_key_order)) @app.command("get", help="Get a particular license by ID.") def get_license(license: Annotated[str, typer.Argument(help="License ID.")]): print(as_table(license_service._get_license_by_license_id(license), license_key_order)) @app.command("add", help="Add new license.") def add_license( file_path: Annotated[ str, typer.Argument( help="""Input json file. 
Example json for new license-""" """{"license_id" : , "client_name" : , "number_of_requests" : , "secret_id" : }""" """\nNOTE: license_id is case insensitive and has to be unique for each license.""" ), ], ): validator = LicenseValidator( valid_keys=["license_id", "client_name", "number_of_requests", "secret_id"] ) try: license_dict = validator.validate_json(file_path=file_path) except Exception as e: print(f"payload error: {e}") return print(license_service._add_license(license_dict)) @app.command("remove", help="Remove a license.") def remove_license( license: Annotated[str, typer.Argument(help="License ID.")], auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): if not auto_confirm: confirm_action(f"Are you sure you want to remove {license}?") print(license_service._remove_license(license)) @app.command("update", help="Update existing license.") def update_license( license: Annotated[str, typer.Argument(help="License ID.")], file_path: Annotated[ str, typer.Argument( help="""Input json file. Example json for updated license- """ """{"license_id": , "client_name" : , "number_of_requests" : , "secret_id" : }""" ), ], auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): validator = LicenseValidator( valid_keys=["license_id", "client_name", "number_of_requests", "secret_id"] ) try: license_dict = validator.validate_json(file_path=file_path) except Exception as e: print(f"payload error: {e}") return if not auto_confirm: confirm_action(f"Are you sure you want to update {license}?") print(license_service._update_license(license, license_dict)) @app.command("redeploy", help="""Redeploy licenses.""" """ CAUTION: Redeploying will cause licenses to stop whatever they are doing.""" """ This can cause queues to be filled with stray requests from previous deployments.""" ) def redeploy_license( license_id: Annotated[ str, typer.Option( "--license_id", help="""Mention license_id of license to redeploy.""" """ Send 'all' if want to redeploy all licenses.""" ) ] = None, client_name: Annotated[ str, typer.Option( "--client_name", help="Redeploy all licenses of a particular client." ) ] = None, auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): if license_id is not None and client_name is not None: print("Can't pass both license_id and client_name. Please pass only one.") return if license_id is None and client_name is None: print("Please pass --license_id or --client_name.") return if license_id is not None: if not auto_confirm: confirm_action(f"Are you sure you want to redeploy {license_id}?") print(license_service._redeploy_license_by_license_id(license_id)) return if client_name is not None: if not auto_confirm: confirm_action(f"Are you sure you want to redeploy licenses from {client_name}?") print(license_service._redeploy_licenses_by_client(client_name)) return ================================================ FILE: weather_dl_v2/cli/app/subcommands/queue.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import typer from typing_extensions import Annotated from app.services.queue_service import queue_service from app.utils import Validator, as_table, confirm_action app = typer.Typer() class QueueValidator(Validator): pass queue_key_order = ['license_id', 'client_name', 'queue'] @app.command("list", help="List all the license queues.") def get_all_license_queue( filter: Annotated[ str, typer.Option(help="Filter by some value. Format: filter_key=filter_value") ] = None ): if filter: validator = QueueValidator(valid_keys=["client_name"]) try: data = validator.validate(filters=[filter]) client_name = data["client_name"] except Exception as e: print(f"filter error: {e}") return print(as_table(queue_service._get_license_queue_by_client_name(client_name), queue_key_order)) return print(as_table(queue_service._get_all_license_queues(), queue_key_order)) @app.command("get", help="Get queue of particular license.") def get_license_queue(license: Annotated[str, typer.Argument(help="License ID")]): print(as_table(queue_service._get_queue_by_license(license), queue_key_order)) @app.command( "edit", help="Edit existing license queue. Queue can edited via a priority" "file or my moving a single config to a given priority.", ) def modify_license_queue( license: Annotated[str, typer.Argument(help="License ID.")], empty: Annotated[ bool, typer.Option( "--empty", help="""Empties the license queue. If this is passed, other options are ignored.""" ) ] = False, save_dir_path: Annotated[ str, typer.Option( "--save_and_empty", help="""Saves the license queue to a file and empties the queue.""" """ Pass in path of directory. File will be saved as .json .""" ) ] = None, file: Annotated[ str, typer.Option( "--file", "-f", help="""File path of priority json file. Example json: {"priority": ["c1.cfg", "c2.cfg",...]}""", ), ] = None, config: Annotated[ str, typer.Option("--config", "-c", help="Config name for absolute priority.") ] = None, priority: Annotated[ int, typer.Option( "--priority", "-p", help="Absolute priority for the config in a license queue." " Priority decreases in ascending order with 0 having highest priority.", ), ] = None, auto_confirm: Annotated[bool, typer.Option("-y", help="Automically confirm any promt.")] = False ): if empty and save_dir_path: print("Both --empty and --save_and_empty can't be passed. 
Use only one.") return if empty: if not auto_confirm: confirm_action(f"Are you sure you want to empty queue for {license}?") print("Emptying license queue...") print(queue_service._edit_license_queue(license, [])) return if save_dir_path: if not auto_confirm: confirm_action(f"Are you sure you want to empty queue for {license}?") print("Saving and Emptying license queue...") file_path = queue_service._save_queue_to_file(license, save_dir_path) print(f"Queue saved at {file_path}.") print(queue_service._edit_license_queue(license, [])) return if file is None and (config is None and priority is None): print("Priority file or config name with absolute priority must be passed.") return if file is not None and (config is not None or priority is not None): print("--config & --priority can't be used along with --file argument.") return if file is not None: validator = QueueValidator(valid_keys=["priority"]) try: data = validator.validate_json(file_path=file) priority_list = data["priority"] except Exception as e: print(f"key error: {e}") return if not auto_confirm: confirm_action(f"Are you sure you want to edit {license} queue priority?") print(queue_service._edit_license_queue(license, priority_list)) return elif config is not None and priority is not None: if priority < 0: print("Priority can not be negative.") return if not auto_confirm: confirm_action(f"Are you sure you want to edit {license} queue priority?") print(queue_service._edit_config_absolute_priority(license, config, priority)) return else: print("--config & --priority arguments should be used together.") return ================================================ FILE: weather_dl_v2/cli/app/utils.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import dataclasses import json import logging import typer import typing as t from itertools import cycle from shutil import get_terminal_size from threading import Thread from time import sleep, time from tabulate import tabulate logger = logging.getLogger(__name__) def timeit(func): def wrap_func(*args, **kwargs): t1 = time() result = func(*args, **kwargs) t2 = time() print(f"[executed in {(t2-t1):.4f}s.]") return result return wrap_func def confirm_action(message: str = "Are you sure you want to continue?"): _ = typer.confirm(message, abort=True) def order_dict_fields(dictionary, key_order): return {key: dictionary[key] for key in key_order if key in dictionary} def as_table(response: str, key_order=None): data = json.loads(response) if not isinstance(data, list): # convert response to list if not a list. data = [data] if len(data) == 0: return "" if key_order is None: key_order = list(data[0].keys()) ordered_data = [order_dict_fields(d, key_order) for d in data] # if any column has lists, convert that to a string. 
rows = [ [ ",\n".join([f"{i} {ele}" for i, ele in enumerate(val)]) if isinstance(val, list) else val for _, val in x.items() ] for x in ordered_data ] rows.insert(0, list(key_order)) return tabulate( rows, showindex=True, tablefmt="grid", maxcolwidths=[16] * len(key_order) ) class Loader: def __init__(self, desc="Loading...", end="", timeout=0.1): """ A loader-like context manager Args: desc (str, optional): The loader's description. Defaults to "Loading...". end (str, optional): Final print. Defaults to "Done!". timeout (float, optional): Sleep time between prints. Defaults to 0.1. """ self.desc = desc self.end = end self.timeout = timeout self._thread = Thread(target=self._animate, daemon=True) self.steps = ["⢿", "⣻", "⣽", "⣾", "⣷", "⣯", "⣟", "⡿"] self.done = False def start(self): self._thread.start() return self def _animate(self): for c in cycle(self.steps): if self.done: break print(f"\r{self.desc} {c}", flush=True, end="") sleep(self.timeout) def __enter__(self): self.start() def stop(self): self.done = True cols = get_terminal_size((80, 20)).columns print("\r" + " " * cols, end="", flush=True) def __exit__(self, exc_type, exc_value, tb): # handle exceptions with those variables ^ self.stop() @dataclasses.dataclass class Validator(abc.ABC): valid_keys: t.List[str] def validate( self, filters: t.List[str], show_valid_filters=True, allow_missing: bool = False ): filter_dict = {} for filter in filters: _filter = filter.split("=") if len(_filter) != 2: if show_valid_filters: logger.info(f"valid filters are: {self.valid_keys}.") raise ValueError("Incorrect Filter. Please Try again.") key, value = _filter filter_dict[key] = value data_set = set(filter_dict.keys()) valid_set = set(self.valid_keys) if self._validate_keys(data_set, valid_set, allow_missing): return filter_dict def validate_json(self, file_path, allow_missing: bool = False): try: with open(file_path) as f: data: dict = json.load(f) data_keys = data.keys() data_set = set(data_keys) valid_set = set(self.valid_keys) if self._validate_keys(data_set, valid_set, allow_missing): return data except FileNotFoundError: logger.info("file not found.") raise FileNotFoundError def _validate_keys(self, data_set: set, valid_set: set, allow_missing: bool): missing_keys = valid_set.difference(data_set) invalid_keys = data_set.difference(valid_set) if not allow_missing and len(missing_keys) > 0: raise ValueError(f"keys {missing_keys} are missing in file.") if len(invalid_keys) > 0: raise ValueError(f"keys {invalid_keys} are invalid keys.") if allow_missing or data_set == valid_set: return True return False ================================================ FILE: weather_dl_v2/cli/environment.yml ================================================ name: weather-dl-v2-cli channels: - conda-forge dependencies: - python=3.10 - pip=23.0.1 - typer=0.9.0 - tabulate=0.9.0 - pip: - requests - ruff - pytype - pytest - . ================================================ FILE: weather_dl_v2/cli/setup.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. from setuptools import setup requirements = ["typer", "requests", "tabulate"] setup( name="weather-dl-v2", packages=["app", "app.subcommands", "app.services"], install_requires=requirements, version="1.0.3", author='Anthromets', author_email='anthromets-ecmwf@google.com', description=( "This cli tools helps in interacting with weather dl v2 fast API server." ), entry_points={"console_scripts": ["weather-dl-v2=app.main:app", "dl-v2=app.main:app"]}, package_data={'app': ['data/*.json']}, ) ================================================ FILE: weather_dl_v2/cli/vm-startup.sh ================================================ #! /bin/bash command="docker exec -it \\\$(docker ps -qf name=weather-dl-v2-cli) /bin/bash" sudo sh -c "echo \"$command\" >> /etc/profile" ================================================ FILE: weather_dl_v2/cloudbuild.yml ================================================ steps: - name: 'gcr.io/cloud-builders/docker' args: ['build', '-t', 'gcr.io/$_PROJECT_ID/$_REPO:$_SERVICE', '$_FOLDER'] - name: 'gcr.io/cloud-builders/docker' args: ['push', 'gcr.io/$_PROJECT_ID/$_REPO:$_SERVICE'] - name: 'gcr.io/cloud-builders/docker' args: ['tag', 'gcr.io/$_PROJECT_ID/$_REPO:$_SERVICE', 'gcr.io/$_PROJECT_ID/$_REPO:$_SERVICE-$_VER'] - name: 'gcr.io/cloud-builders/docker' args: ['push', 'gcr.io/$_PROJECT_ID/$_REPO:$_SERVICE-$_VER'] timeout: 79200s options: machineType: E2_HIGHCPU_32 ================================================ FILE: weather_dl_v2/config.json ================================================ { "download_collection": "download", "queues_collection": "queues", "license_collection": "license", "manifest_collection": "manifest", "storage_bucket": "XXXXXXX", "gcs_project": "XXXXXXX", "license_deployment_image": "XXXXXXX", "downloader_k8_image": "XXXXXXX", "welcome_message": "Greetings from weather-dl v2 from weather-dl-v2-cluster-2!" } ================================================ FILE: weather_dl_v2/downloader_kubernetes/Dockerfile ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== FROM continuumio/miniconda3:latest # Update miniconda RUN conda update conda -y # Add the mamba solver for faster builds RUN conda install -n base conda-libmamba-solver RUN conda config --set solver libmamba # Create conda env using environment.yml COPY . . RUN conda env create -f environment.yml --debug # Activate the conda env and update the PATH ARG CONDA_ENV_NAME=weather-dl-v2-downloader RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH ================================================ FILE: weather_dl_v2/downloader_kubernetes/README.md ================================================ # Deployment / Usage Instruction ### User authorization required to set up the environment: * roles/container.admin ### Authorization needed for the tool to operate: We are not configuring any service account here hence make sure that compute engine default service account have roles: * roles/storage.admin * roles/bigquery.dataEditor * roles/bigquery.jobUser ### Make changes in weather_dl_v2/config.json, if required [for running locally] ``` export CONFIG_PATH=/path/to/weather_dl_v2/config.json ``` ### Create docker image for downloader: Refer instructions in weather_dl_v2/README.md ================================================ FILE: weather_dl_v2/downloader_kubernetes/VERSION.txt ================================================ 1.0.2 ================================================ FILE: weather_dl_v2/downloader_kubernetes/downloader.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ This program downloads ECMWF data & upload it into GCS. """ import tempfile import os import sys import time from manifest import FirestoreManifest, Stage from util import copy, download_with_aria2, download_with_wget import datetime def download(url: str, path: str) -> None: """Download data from client, with retries.""" if path: if os.path.exists(path): # Empty the target file, if it already exists, otherwise the # transfer below might be fooled into thinking we're resuming # an interrupted download. open(path, "w").close() download_methods = [ download_with_aria2, download_with_aria2, download_with_wget, ] errors = [] for method in download_methods: print(f"Trying {method.__name__}.") try: method(url, path) return except Exception as e: print(f"{method.__name__} failed. Error: {e}.") errors.append(str(e)) print("Waiting for 2 mins.") time.sleep(120) err_msgs = "\n".join(errors) print(f"Failed to download {url}. Error Msg: {err_msgs}.") raise Exception( f"Downloading failed for url {url} & path {path}.\nError Msg: {err_msgs}." 
) def main( config_name, dataset, selection, user_id, url, target_path, license_id ) -> None: """Download data from a client to a temp file.""" manifest = FirestoreManifest(license_id=license_id) temp_name = "" with manifest.transact(config_name, dataset, selection, target_path, user_id): with tempfile.NamedTemporaryFile(delete=False) as temp: temp_name = temp.name manifest.set_stage(Stage.DOWNLOAD) precise_download_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_download_start_time print(f"Downloading data for {target_path!r}.") download(url, temp_name) print(f"Download completed for {target_path!r}.") manifest.set_stage(Stage.UPLOAD) precise_upload_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_upload_start_time print(f"Uploading to store for {target_path!r}.") copy(temp_name, target_path) print(f"Upload to store complete for {target_path!r}.") os.unlink(temp_name) if __name__ == "__main__": main(*sys.argv[1:]) ================================================ FILE: weather_dl_v2/downloader_kubernetes/downloader_config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import dataclasses import typing as t import json import os import logging logger = logging.getLogger(__name__) Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class DownloaderConfig: manifest_collection: str = "" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict): config_instance = cls() for key, value in config.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value return config_instance downloader_config = None def get_config(): global downloader_config if downloader_config: return downloader_config downloader_config_json = "config/config.json" if not os.path.exists(downloader_config_json): downloader_config_json = os.environ.get("CONFIG_PATH", None) if downloader_config_json is None: logger.error("Couldn't load config file for downloader.") raise FileNotFoundError("Couldn't load config file for downloader.") with open(downloader_config_json) as file: config_dict = json.load(file) downloader_config = DownloaderConfig.from_dict(config_dict) return downloader_config ================================================ FILE: weather_dl_v2/downloader_kubernetes/environment.yml ================================================ name: weather-dl-v2-downloader channels: - conda-forge dependencies: - python=3.10 - google-cloud-sdk=410.0.0 - aria2=1.36.0 - geojson=2.5.0=py_0 - xarray=2022.11.0 - google-apitools - pip=22.3 - pip: - apache_beam[gcp]==2.40.0 - firebase-admin - google-cloud-pubsub - kubernetes - psutil ================================================ FILE: weather_dl_v2/downloader_kubernetes/manifest.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Client interface for connecting to a manifest.""" import abc import dataclasses import datetime import enum import json import pandas as pd import time import traceback import typing as t from util import ( to_json_serializable_type, fetch_geo_polygon, get_file_size, get_wait_interval, generate_md5_hash, GLOBAL_COVERAGE_AREA, ) import firebase_admin from firebase_admin import credentials from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentReference from google.cloud.firestore_v1.types import WriteResult from downloader_config import get_config """An implementation-dependent Manifest URI.""" Location = t.NewType("Location", str) class ManifestException(Exception): """Errors that occur in Manifest Clients.""" pass class Stage(enum.Enum): """A request can be either in one of the following stages at a time: fetch : This represents request is currently in fetch stage i.e. request placed on the client's server & waiting for some result before starting download (eg. MARS client). download : This represents request is currently in download stage i.e. 
data is being downloading from client's server to the worker's local file system. upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local file system to target location (GCS path). retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), request will be in the retrieve stage i.e. fetch + download. """ RETRIEVE = "retrieve" FETCH = "fetch" DOWNLOAD = "download" UPLOAD = "upload" class Status(enum.Enum): """Depicts the request's state status: scheduled : A request partition is created & scheduled for processing. Note: Its corresponding state can be None only. in-progress : This represents the request state is currently in-progress (i.e. running). The next status would be "success" or "failure". success : This represents the request state execution completed successfully without any error. failure : This represents the request state execution failed. """ SCHEDULED = "scheduled" IN_PROGRESS = "in-progress" SUCCESS = "success" FAILURE = "failure" @dataclasses.dataclass class DownloadStatus: """Data recorded in `Manifest`s reflecting the status of a download.""" """The name of the config file associated with the request.""" config_name: str = "" """Represents the dataset field of the configuration.""" dataset: t.Optional[str] = "" """Copy of selection section of the configuration.""" selection: t.Dict = dataclasses.field(default_factory=dict) """Location of the downloaded data.""" location: str = "" """Represents area covered by the shard.""" area: str = "" """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" stage: t.Optional[Stage] = None """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" status: t.Optional[Status] = None """Cause of error, if any.""" error: t.Optional[str] = "" """Identifier for the user running the download.""" username: str = "" """Shard size in GB.""" size: t.Optional[float] = 0 """A UTC datetime when download was scheduled.""" scheduled_time: t.Optional[str] = "" """A UTC datetime when the retrieve stage starts.""" retrieve_start_time: t.Optional[str] = "" """A UTC datetime when the retrieve state ends.""" retrieve_end_time: t.Optional[str] = "" """A UTC datetime when the fetch state starts.""" fetch_start_time: t.Optional[str] = "" """A UTC datetime when the fetch state ends.""" fetch_end_time: t.Optional[str] = "" """A UTC datetime when the download state starts.""" download_start_time: t.Optional[str] = "" """A UTC datetime when the download state ends.""" download_end_time: t.Optional[str] = "" """A UTC datetime when the upload state starts.""" upload_start_time: t.Optional[str] = "" """A UTC datetime when the upload state ends.""" upload_end_time: t.Optional[str] = "" @classmethod def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": """Instantiate DownloadStatus dataclass from dict.""" download_status_instance = cls() for key, value in download_status.items(): if key == "status": setattr(download_status_instance, key, Status(value)) elif key == "stage" and value is not None: setattr(download_status_instance, key, Stage(value)) else: setattr(download_status_instance, key, value) return download_status_instance @classmethod def to_dict(cls, instance) -> t.Dict: """Return the fields of a dataclass instance as a manifest ingestible dictionary mapping of field names to field values.""" download_status_dict = {} for field in dataclasses.fields(instance): key = field.name value = getattr(instance, 
field.name) if isinstance(value, Status) or isinstance(value, Stage): download_status_dict[key] = value.value elif isinstance(value, pd.Timestamp): download_status_dict[key] = value.isoformat() elif key == "selection" and value is not None: download_status_dict[key] = json.dumps(value) else: download_status_dict[key] = value return download_status_dict @dataclasses.dataclass class Manifest(abc.ABC): """Abstract manifest of download statuses. Update download statuses to some storage medium. This class lets one indicate that a download is `scheduled` or in a transaction process. In the event of a transaction, a download will be updated with an `in-progress`, `success` or `failure` status (with accompanying metadata). Example: ``` my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) # Schedule data for download my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') # ... # Initiate a transaction – it will record that the download is `in-progess` with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: # download logic here pass # ... # on error, will record the download as a `failure` before propagating the error. By default, it will # record download as a `success`. ``` Attributes: status: The current `DownloadStatus` of the Manifest. """ # To reduce the impact of _read() and _update() calls # on the start time of the stage. license_id: str = "" prev_stage_precise_start_time: t.Optional[str] = None status: t.Optional[DownloadStatus] = None # This is overridden in subclass. def __post_init__(self): """Initialize the manifest.""" pass def schedule( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Indicate that a job has been scheduled for download. 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. """ scheduled_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) self.status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=None, status=Status.SCHEDULED, error=None, size=None, scheduled_time=scheduled_time, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=None, upload_end_time=None, ) self._update(self.status) def skip( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Updates the manifest to mark the shards that were skipped in the current job as 'upload' stage and 'success' status, indicating that they have already been downloaded. """ old_status = self._read(location) # The manifest needs to be updated for a skipped shard if its entry is not present, or # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. 
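# Note: when no manifest entry exists for this location, _read() returns a default
# DownloadStatus (its `location` is an empty string), so the first comparison below
# also covers the "entry not present" case.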
if ( old_status.location != location or old_status.stage != Stage.UPLOAD or old_status.status != Status.SUCCESS ): current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) size = get_file_size(location) status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=Stage.UPLOAD, status=Status.SUCCESS, error=None, size=size, scheduled_time=None, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=current_utc_time, upload_end_time=current_utc_time, ) self._update(status) print( f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." ) def _set_for_transaction( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Reset Manifest state in preparation for a new transaction.""" self.status = dataclasses.replace(self._read(location)) self.status.config_name = config_name self.status.dataset = dataset if dataset else None self.status.selection = selection self.status.location = location self.status.username = user def __enter__(self) -> None: pass def __exit__(self, exc_type, exc_inst, exc_tb) -> None: """Record end status of a transaction as either 'success' or 'failure'.""" if exc_type is None: status = Status.SUCCESS error = None else: status = Status.FAILURE # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception error = f"license_id: {self.license_id} " error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) new_status = dataclasses.replace(self.status) new_status.error = error new_status.status = status current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) # This is necessary for setting the precise start time of the previous stage # and end time of the final stage, as well as handling the case of Status.FAILURE. 
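# In this worker, downloader.py records a fresh timestamp in
# `prev_stage_precise_start_time` immediately after each set_stage() call, so the
# start time of the final stage is backfilled from that value here while its end
# time is stamped at exit. The FETCH and RETRIEVE branches below appear to be kept
# for parity with the fuller manifest implementation elsewhere in weather-dl-v2;
# this downloader itself only moves through the DOWNLOAD and UPLOAD stages.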
if new_status.stage == Stage.FETCH: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time elif new_status.stage == Stage.RETRIEVE: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time elif new_status.stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.upload_start_time = self.prev_stage_precise_start_time new_status.upload_end_time = current_utc_time new_status.size = get_file_size(new_status.location) self.status = new_status self._update(self.status) def transact( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> "Manifest": """Create a download transaction.""" self._set_for_transaction(config_name, dataset, selection, location, user) return self def set_stage(self, stage: Stage) -> None: """Sets the current stage in manifest.""" new_status = dataclasses.replace(self.status) new_status.stage = stage new_status.status = Status.IN_PROGRESS current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) if stage == Stage.DOWNLOAD: new_status.download_start_time = current_utc_time else: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time new_status.upload_start_time = current_utc_time self.status = new_status self._update(self.status) @abc.abstractmethod def _read(self, location: str) -> DownloadStatus: pass @abc.abstractmethod def _update(self, download_status: DownloadStatus) -> None: pass class FirestoreManifest(Manifest): """A Firestore Manifest. This Manifest implementation stores DownloadStatuses in a Firebase document store. The document hierarchy for the manifest is as follows: [manifest ] ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } └── etc... Where `[]` indicates a collection and ` {...}` indicates a document. """ def _get_db(self) -> firestore.firestore.Client: """Acquire a firestore client, initializing the firebase app if necessary. Will attempt to get the db client five times. If it's still unsuccessful, a `ManifestException` will be raised. """ db = None attempts = 0 while db is None: try: db = firestore.client() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. # Use the application default credentials. cred = credentials.ApplicationDefault() firebase_admin.initialize_app(cred) print("Initialized Firebase App.") if attempts > 4: raise ManifestException( "Exceeded number of retries to get firestore client." 
) from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" doc_id = generate_md5_hash(location) # Update document with download status download_doc_ref = self.root_document_for_store(doc_id) result = download_doc_ref.get() row = {} if result.exists: records = result.to_dict() row = {n: to_json_serializable_type(v) for n, v in records.items()} return DownloadStatus.from_dict(row) def _update(self, download_status: DownloadStatus) -> None: """Update or create a download status record.""" print("Updating Firestore Manifest.") status = DownloadStatus.to_dict(download_status) doc_id = generate_md5_hash(status["location"]) # Update document with download status download_doc_ref = self.root_document_for_store(doc_id) result: WriteResult = download_doc_ref.set(status) print( f"Firestore manifest updated. " f"update_time={result.update_time}, " f"filename={download_status.location}." ) def root_document_for_store(self, store_scheme: str) -> DocumentReference: """Get the root manifest document given the user's config and current document's storage location.""" return ( self._get_db() .collection(get_config().manifest_collection) .document(store_scheme) ) ================================================ FILE: weather_dl_v2/downloader_kubernetes/util.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import geojson import hashlib import itertools import os import socket import subprocess import sys import typing as t import numpy as np import pandas as pd from apache_beam.io.gcp import gcsio from apache_beam.utils import retry from xarray.core.utils import ensure_us_time_resolution from urllib.parse import urlparse from google.api_core.exceptions import BadRequest LATITUDE_RANGE = (-90, 90) LONGITUDE_RANGE = (-180, 180) GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( exception, ) -> bool: if isinstance(exception, socket.timeout): return True if isinstance(exception, TimeoutError): return True # To handle the concurrency issue in BigQuery. if isinstance(exception, BadRequest): return True return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) class _FakeClock: def sleep(self, value): pass def retry_with_exponential_backoff(fun): """A retry decorator that doesn't apply during test time.""" clock = retry.Clock() # Use a fake clock only during test time... 
if "unittest" in sys.modules.keys(): clock = _FakeClock() return retry.with_exponential_backoff( retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, clock=clock, )(fun) # TODO(#245): Group with common utilities (duplicated) def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: """Yield evenly-sized chunks from an iterable.""" input_ = iter(iterable) try: while True: it = itertools.islice(input_, n) # peek to check if 'it' has next item. first = next(it) yield itertools.chain([first], it) except StopIteration: pass # TODO(#245): Group with common utilities (duplicated) def copy(src: str, dst: str) -> None: """Copy data via `gcloud storage cp`.""" try: subprocess.run(["gcloud", "storage", "cp", src, dst], check=True, capture_output=True) except subprocess.CalledProcessError as e: print( f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}' ) raise # TODO(#245): Group with common utilities (duplicated) def to_json_serializable_type(value: t.Any) -> t.Any: """Returns the value with a type serializable to JSON""" # Note: The order of processing is significant. print("Serializing to JSON") if pd.isna(value) or value is None: return None elif np.issubdtype(type(value), np.floating): return float(value) elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. return value.tolist() elif ( isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64) ): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) except ValueError: # ... if they are not, assume serialization is already correct. return value except TypeError: # ... maybe value is a numpy datetime ... try: value = ensure_us_time_resolution(value).astype(datetime.datetime) except AttributeError: # ... value is a datetime object, continue. pass # We use a string timestamp representation. if value.tzname(): return value.isoformat() # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, "s")) # This check must happen after processing np.timedelta64 and np.datetime64. elif np.issubdtype(type(value), np.integer): return int(value) return value def fetch_geo_polygon(area: t.Union[list, str]) -> str: """Calculates a geography polygon from an input area.""" # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 if isinstance(area, str): # European area if area == "E": area = [73.5, -27, 33, 45] # Global area elif area == "G": area = GLOBAL_COVERAGE_AREA else: raise RuntimeError(f"Not a valid value for area in config: {area}.") n, w, s, e = [float(x) for x in area] if s < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value for south: '{s}'") if n > LATITUDE_RANGE[1]: raise ValueError(f"Invalid latitude value for north: '{n}'") if w < LONGITUDE_RANGE[0]: raise ValueError(f"Invalid longitude value for west: '{w}'") if e > LONGITUDE_RANGE[1]: raise ValueError(f"Invalid longitude value for east: '{e}'") # Define the coordinates of the bounding box. coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] # Create the GeoJSON polygon object. 
polygon = geojson.dumps(geojson.Polygon([coords])) return polygon def get_file_size(path: str) -> float: parsed_gcs_path = urlparse(path) if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 else: return ( gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 ) def get_wait_interval(num_retries: int = 0) -> float: """Returns next wait interval in seconds, using an exponential backoff algorithm.""" if 0 == num_retries: return 0 return 2**num_retries def generate_md5_hash(input: str) -> str: """Generates md5 hash for the input string.""" return hashlib.md5(input.encode("utf-8")).hexdigest() def download_with_aria2(url: str, path: str) -> None: """Downloads a file from the given URL using the `aria2c` command-line utility, with options set to improve download speed and reliability.""" dir_path, file_name = os.path.split(path) try: subprocess.run( [ "aria2c", "-x", "16", "-s", "16", url, "-d", dir_path, "-o", file_name, "--allow-overwrite", ], check=True, capture_output=True, ) except subprocess.CalledProcessError as e: print( f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}' ) raise def download_with_wget(url: str, path: str) -> None: """Downloads a file from the given URL using the `wget` command.""" try: subprocess.run(["wget", url, "-O", path], check=True, capture_output=True) except subprocess.CalledProcessError as e: print( f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.' ) raise ================================================ FILE: weather_dl_v2/fastapi-server/API-Interactions.md ================================================ # API Interactions
| Command | Type | Endpoint |
|---|---|---|
| `weather-dl-v2 ping` | `get` | `/` |
| Download | | |
| `weather-dl-v2 download add -l [--force-download]` | `post` | `/download?force_download={value}` |
| `weather-dl-v2 download list` | `get` | `/download/` |
| `weather-dl-v2 download list --filter client_name=` | `get` | `/download?client_name={name}` |
| `weather-dl-v2 download get ` | `get` | `/download/{config_name}` |
| `weather-dl-v2 download show ` | `get` | `/download/show/{config_name}` |
| `weather-dl-v2 download remove ` | `delete` | `/download/{config_name}` |
| `weather-dl-v2 download refetch -l ` | `post` | `/download/refetch/{config_name}` |
| License | | |
| `weather-dl-v2 license add ` | `post` | `/license/` |
| `weather-dl-v2 license get ` | `get` | `/license/{license_id}` |
| `weather-dl-v2 license remove ` | `delete` | `/license/{license_id}` |
| `weather-dl-v2 license list` | `get` | `/license/` |
| `weather-dl-v2 license list --filter client_name=` | `get` | `/license?client_name={name}` |
| `weather-dl-v2 license edit ` | `put` | `/license/{license_id}` |
| Queue | | |
| `weather-dl-v2 queue list` | `get` | `/queues/` |
| `weather-dl-v2 queue list --filter client_name=` | `get` | `/queues?client_name={name}` |
| `weather-dl-v2 queue get ` | `get` | `/queues/{license_id}` |
| `queue edit --config --priority ` | `post` | `/queues/{license_id}` |
| `queue edit --file ` | `put` | `/queues/priority/{license_id}` |
================================================ FILE: weather_dl_v2/fastapi-server/Dockerfile ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM continuumio/miniconda3:latest EXPOSE 8080 # Update miniconda RUN conda update conda -y # Add the mamba solver for faster builds RUN conda install -n base conda-libmamba-solver RUN conda config --set solver libmamba COPY . . # Create conda env using environment.yml RUN conda env create -f environment.yml --debug # Activate the conda env and update the PATH ARG CONDA_ENV_NAME=weather-dl-v2-server RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH # Use the ping endpoint as a healthcheck, # so Docker knows if the API is still running ok or needs to be restarted HEALTHCHECK --interval=21s --timeout=3s --start-period=10s CMD curl --fail http://localhost:8080/ping || exit 1 CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"] ================================================ FILE: weather_dl_v2/fastapi-server/README.md ================================================ # Deployment Instructions & General Notes ### User authorization required to set up the environment: * roles/container.admin ### Authorization needed for the tool to operate: We are not configuring any service account here hence make sure that compute engine default service account have roles: * roles/pubsub.subscriber * roles/storage.admin * roles/bigquery.dataEditor * roles/bigquery.jobUser ### Install kubectl: ``` apt-get update apt-get install -y kubectl ``` ### Create cluster: ``` export PROJECT_ID=anthromet-ingestion export REGION=us-west1 export ZONE=us-west1-a export CLUSTER_NAME=weather-dl-v2-cluster-2 export DOWNLOAD_NODE_POOL=downloader-pool gcloud beta container --project $PROJECT_ID clusters create $CLUSTER_NAME --zone $ZONE --no-enable-basic-auth --cluster-version "1.29.5-gke.1091002" --release-channel "regular" --machine-type "e2-standard-2" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "200" --node-labels preemptible=false --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/cloud-platform" --max-pods-per-node "16" --num-nodes "3" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM,STORAGE,POD,DEPLOYMENT,STATEFULSET,DAEMONSET,HPA,CADVISOR,KUBELET --enable-ip-alias --network "projects/$PROJECT_ID/global/networks/default" --subnetwork "projects/$PROJECT_ID/regions/$REGION/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "110" --enable-autoscaling --min-nodes "1" --max-nodes "200" --location-policy "BALANCED" --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --binauthz-evaluation-mode=DISABLED --enable-managed-prometheus --enable-shielded-nodes --node-locations $ZONE gcloud beta container --project $PROJECT_ID node-pools create $DOWNLOAD_NODE_POOL --cluster $CLUSTER_NAME --zone $ZONE --machine-type "e2-standard-8" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "1000" --node-labels preemptible=false --metadata disable-legacy-endpoints=true --scopes 
"https://www.googleapis.com/auth/cloud-platform" --num-nodes "1" --enable-autoscaling --min-nodes "1" --max-nodes "100" --location-policy "BALANCED" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --max-pods-per-node "16" --node-locations $ZONE ### Connect to Cluster: ``` gcloud container clusters get-credentials $CLUSTER_NAME --zone $ZONE --project $PROJECT_ID ``` ### How to create environment: ``` conda env create --name weather-dl-v2-server --file=environment.yml conda activate weather-dl-v2-server ``` ### Make changes in weather_dl_v2/config.json, if required [for running locally] ``` export CONFIG_PATH=/path/to/weather_dl_v2/config.json ``` ### To run fastapi server: ``` uvicorn main:app --reload ``` * Open your browser at http://127.0.0.1:8000. ### Create docker image for server: Refer instructions in weather_dl_v2/README.md ### Add path of created server image in server.yaml: ``` Please write down the fastAPI server's docker image path at Line 42 of server.yaml. ``` ### Create ConfigMap of common configurations for services: Make necessary changes to weather_dl_v2/config.json and run following command. ConfigMap is used for: - Having a common configuration file for all services. - Decoupling docker image and config files. ``` kubectl create configmap dl-v2-config --from-file=/path/to/weather_dl_v2/config.json ``` ### Deploy fastapi server on kubernetes: ``` kubectl apply -f server.yaml --force ``` ## General Commands ### For viewing the current pods: ``` kubectl get pods ``` ### For deleting existing deployment: ``` kubectl delete -f server.yaml --force ================================================ FILE: weather_dl_v2/fastapi-server/VERSION.txt ================================================ 1.0.12 ================================================ FILE: weather_dl_v2/fastapi-server/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import calendar import copy import dataclasses import itertools import typing as t Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class Config: """Contains pipeline parameters. Attributes: config_name: Name of the config file. 
client: Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable. dataset (optional): Name of the target dataset. Allowed options are dictated by the client. partition_keys (optional): Choose the keys from the selection section to partition the data request. This will compute a cartesian cross product of the selected keys and assign each as their own download. target_path: Download artifact filename template. Can make use of Python's standard string formatting. It can contain format symbols to be replaced by partition keys; if this is used, the total number of format symbols must match the number of partition keys. subsection_name: Name of the particular subsection. 'default' if there is no subsection. force_download: Force redownload of partitions that were previously downloaded. user_id: Username from the environment variables. kwargs (optional): For representing subsections or any other parameters. selection: Contains parameters used to select desired data. """ config_name: str = "" client: str = "" dataset: t.Optional[str] = "" target_path: str = "" partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) subsection_name: str = "default" force_download: bool = False user_id: str = "unknown" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict) -> "Config": config_instance = cls() for section_key, section_value in config.items(): if section_key == "parameters": for key, value in section_value.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value if section_key == "selection": config_instance.selection = section_value return config_instance def optimize_selection_partition(selection: t.Dict) -> t.Dict: """Compute right-hand-side values for the selection section of a single partition. Used to support custom syntax and optimizations, such as 'all'. """ selection_ = copy.deepcopy(selection) if "date_range" in selection_.keys(): selection_["date"] = selection_["date_range"][0] del selection_["date_range"] if "day" in selection_.keys() and selection_["day"] == "all": years, months = selection_["year"], selection_["month"] multiples_error = "When using day='all' in selection, '/' is not allowed in {type}." if isinstance(years, str): years = [years] if isinstance(months, str): months = [months] date_ranges = [] # Generating dates for every year-month. for year, month in itertools.product(years, months): if isinstance(year, str): assert "/" not in year, multiples_error.format(type="year") if isinstance(month, str): assert "/" not in month, multiples_error.format(type="month") year, month = int(year), int(month) _, n_days_in_month = calendar.monthrange(year, month) date_range = [f'{year:04d}-{month:02d}-{day:02d}' for day in range(1, n_days_in_month + 1)] date_ranges.extend(date_range) selection_["date"] = date_ranges del selection_["day"] del selection_["month"] del selection_["year"] return selection_ ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/manifest.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Client interface for connecting to a manifest.""" import abc import dataclasses import logging import datetime import enum import json import pandas as pd import time import traceback import typing as t from .util import ( to_json_serializable_type, fetch_geo_polygon, get_file_size, get_wait_interval, generate_md5_hash, GLOBAL_COVERAGE_AREA, ) import firebase_admin from firebase_admin import credentials from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentReference from google.cloud.firestore_v1.types import WriteResult from server_config import get_config from database.session import Database """An implementation-dependent Manifest URI.""" Location = t.NewType("Location", str) logger = logging.getLogger(__name__) class ManifestException(Exception): """Errors that occur in Manifest Clients.""" pass class Stage(enum.Enum): """A request can be either in one of the following stages at a time: fetch : This represents request is currently in fetch stage i.e. request placed on the client's server & waiting for some result before starting download (eg. MARS client). download : This represents request is currently in download stage i.e. data is being downloading from client's server to the worker's local file system. upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local file system to target location (GCS path). retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), request will be in the retrieve stage i.e. fetch + download. """ RETRIEVE = "retrieve" FETCH = "fetch" DOWNLOAD = "download" UPLOAD = "upload" class Status(enum.Enum): """Depicts the request's state status: scheduled : A request partition is created & scheduled for processing. Note: Its corresponding state can be None only. in-progress : This represents the request state is currently in-progress (i.e. running). The next status would be "success" or "failure". success : This represents the request state execution completed successfully without any error. failure : This represents the request state execution failed. 
""" SCHEDULED = "scheduled" IN_PROGRESS = "in-progress" SUCCESS = "success" FAILURE = "failure" @dataclasses.dataclass class DownloadStatus: """Data recorded in `Manifest`s reflecting the status of a download.""" """The name of the config file associated with the request.""" config_name: str = "" """Represents the dataset field of the configuration.""" dataset: t.Optional[str] = "" """Copy of selection section of the configuration.""" selection: t.Dict = dataclasses.field(default_factory=dict) """Location of the downloaded data.""" location: str = "" """Represents area covered by the shard.""" area: str = "" """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" stage: t.Optional[Stage] = None """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" status: t.Optional[Status] = None """Cause of error, if any.""" error: t.Optional[str] = "" """Identifier for the user running the download.""" username: str = "" """Shard size in GB.""" size: t.Optional[float] = 0 """A UTC datetime when download was scheduled.""" scheduled_time: t.Optional[str] = "" """A UTC datetime when the retrieve stage starts.""" retrieve_start_time: t.Optional[str] = "" """A UTC datetime when the retrieve state ends.""" retrieve_end_time: t.Optional[str] = "" """A UTC datetime when the fetch state starts.""" fetch_start_time: t.Optional[str] = "" """A UTC datetime when the fetch state ends.""" fetch_end_time: t.Optional[str] = "" """A UTC datetime when the download state starts.""" download_start_time: t.Optional[str] = "" """A UTC datetime when the download state ends.""" download_end_time: t.Optional[str] = "" """A UTC datetime when the upload state starts.""" upload_start_time: t.Optional[str] = "" """A UTC datetime when the upload state ends.""" upload_end_time: t.Optional[str] = "" @classmethod def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": """Instantiate DownloadStatus dataclass from dict.""" download_status_instance = cls() for key, value in download_status.items(): if key == "status": setattr(download_status_instance, key, Status(value)) elif key == "stage" and value is not None: setattr(download_status_instance, key, Stage(value)) else: setattr(download_status_instance, key, value) return download_status_instance @classmethod def to_dict(cls, instance) -> t.Dict: """Return the fields of a dataclass instance as a manifest ingestible dictionary mapping of field names to field values.""" download_status_dict = {} for field in dataclasses.fields(instance): key = field.name value = getattr(instance, field.name) if isinstance(value, Status) or isinstance(value, Stage): download_status_dict[key] = value.value elif isinstance(value, pd.Timestamp): download_status_dict[key] = value.isoformat() elif key == "selection" and value is not None: download_status_dict[key] = json.dumps(value) else: download_status_dict[key] = value return download_status_dict @dataclasses.dataclass class Manifest(abc.ABC): """Abstract manifest of download statuses. Update download statuses to some storage medium. This class lets one indicate that a download is `scheduled` or in a transaction process. In the event of a transaction, a download will be updated with an `in-progress`, `success` or `failure` status (with accompanying metadata). Example: ``` my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) # Schedule data for download my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') # ... 
# Initiate a transaction – it will record that the download is `in-progess` with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: # download logic here pass # ... # on error, will record the download as a `failure` before propagating the error. By default, it will # record download as a `success`. ``` Attributes: status: The current `DownloadStatus` of the Manifest. """ # To reduce the impact of _read() and _update() calls # on the start time of the stage. prev_stage_precise_start_time: t.Optional[str] = None status: t.Optional[DownloadStatus] = None # This is overridden in subclass. def __post_init__(self): """Initialize the manifest.""" pass def schedule( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Indicate that a job has been scheduled for download. 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. """ scheduled_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) self.status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=None, status=Status.SCHEDULED, error=None, size=None, scheduled_time=scheduled_time, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=None, upload_end_time=None, ) self._update(self.status) def skip( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Updates the manifest to mark the shards that were skipped in the current job as 'upload' stage and 'success' status, indicating that they have already been downloaded. """ old_status = self._read(location) # The manifest needs to be updated for a skipped shard if its entry is not present, or # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. if ( old_status.location != location or old_status.stage != Stage.UPLOAD or old_status.status != Status.SUCCESS ): current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) size = get_file_size(location) status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=Stage.UPLOAD, status=Status.SUCCESS, error=None, size=size, scheduled_time=None, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=current_utc_time, upload_end_time=current_utc_time, ) self._update(status) logger.info( f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." 
) def _set_for_transaction( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Reset Manifest state in preparation for a new transaction.""" self.status = dataclasses.replace(self._read(location)) self.status.config_name = config_name self.status.dataset = dataset if dataset else None self.status.selection = selection self.status.location = location self.status.username = user def __enter__(self) -> None: pass def __exit__(self, exc_type, exc_inst, exc_tb) -> None: """Record end status of a transaction as either 'success' or 'failure'.""" if exc_type is None: status = Status.SUCCESS error = None else: status = Status.FAILURE # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception error = "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) new_status = dataclasses.replace(self.status) new_status.error = error new_status.status = status current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) # This is necessary for setting the precise start time of the previous stage # and end time of the final stage, as well as handling the case of Status.FAILURE. if new_status.stage == Stage.FETCH: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time elif new_status.stage == Stage.RETRIEVE: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time elif new_status.stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.upload_start_time = self.prev_stage_precise_start_time new_status.upload_end_time = current_utc_time new_status.size = get_file_size(new_status.location) self.status = new_status self._update(self.status) def transact( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> "Manifest": """Create a download transaction.""" self._set_for_transaction(config_name, dataset, selection, location, user) return self def set_stage(self, stage: Stage) -> None: """Sets the current stage in manifest.""" prev_stage = self.status.stage new_status = dataclasses.replace(self.status) new_status.stage = stage new_status.status = Status.IN_PROGRESS current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) if stage == Stage.FETCH: new_status.fetch_start_time = current_utc_time elif stage == Stage.RETRIEVE: new_status.retrieve_start_time = current_utc_time elif stage == Stage.DOWNLOAD: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time new_status.download_start_time = current_utc_time else: if prev_stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time new_status.upload_start_time = current_utc_time self.status = new_status self._update(self.status) @abc.abstractmethod def _read(self, location: str) -> DownloadStatus: pass @abc.abstractmethod def _update(self, download_status: DownloadStatus) -> None: pass class FirestoreManifest(Manifest, Database): """A Firestore Manifest. This Manifest implementation stores DownloadStatuses in a Firebase document store. 
The document hierarchy for the manifest is as follows: [manifest ] ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } └── etc... Where `[]` indicates a collection and ` {...}` indicates a document. """ def _get_db(self) -> firestore.firestore.Client: """Acquire a firestore client, initializing the firebase app if necessary. Will attempt to get the db client five times. If it's still unsuccessful, a `ManifestException` will be raised. """ db = None attempts = 0 while db is None: try: db = firestore.client() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. # Use the application default credentials. cred = credentials.ApplicationDefault() firebase_admin.initialize_app(cred) logger.info("Initialized Firebase App.") if attempts > 4: raise ManifestException( "Exceeded number of retries to get firestore client." ) from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" doc_id = generate_md5_hash(location) # Update document with download status download_doc_ref = self.root_document_for_store(doc_id) result = download_doc_ref.get() row = {} if result.exists: records = result.to_dict() row = {n: to_json_serializable_type(v) for n, v in records.items()} return DownloadStatus.from_dict(row) def _update(self, download_status: DownloadStatus) -> None: """Update or create a download status record.""" logger.info("Updating Firestore Manifest.") status = DownloadStatus.to_dict(download_status) doc_id = generate_md5_hash(status["location"]) # Update document with download status download_doc_ref = self.root_document_for_store(doc_id) result: WriteResult = download_doc_ref.set(status) logger.info( f"Firestore manifest updated. " f"update_time={result.update_time}, " f"filename={download_status.location}." ) def root_document_for_store(self, store_scheme: str) -> DocumentReference: """Get the root manifest document given the user's config and current document's storage location.""" root_collection = get_config().manifest_collection return self._get_db().collection(root_collection).document(store_scheme) ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/parsers.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Parsers for ECMWF download configuration.""" import ast import configparser import copy as cp import datetime import json import string import textwrap import typing as t import numpy as np from collections import OrderedDict from dateutil.relativedelta import relativedelta from .config import Config CLIENTS = ["cds", "mars", "ecpublic"] def date(candidate: str) -> datetime.date: """Converts ECMWF-format date strings into a `datetime.date`. 
Accepted absolute date formats: - YYYY-MM-DD - YYYYMMDD - YYYY-DDD, where DDD refers to the day of the year For example: - 2021-10-31 - 19700101 - 1950-007 See https://confluence.ecmwf.int/pages/viewpage.action?pageId=118817289 for date format spec. Note: Name of month is not supported. """ converted = None # Parse relative day value. if candidate.startswith("-"): return datetime.date.today() + datetime.timedelta(days=int(candidate)) accepted_formats = ["%Y-%m-%d", "%Y%m%d", "%Y-%j"] for fmt in accepted_formats: try: converted = datetime.datetime.strptime(candidate, fmt).date() break except ValueError: pass if converted is None: raise ValueError( f"Not a valid date: '{candidate}'. Please use valid relative or absolute format." ) return converted def time(candidate: str) -> datetime.time: """Converts ECMWF-format time strings into a `datetime.time`. Accepted time formats: - HH:MM - HHMM - HH For example: - 18:00 - 1820 - 18 Note: If MM is omitted it defaults to 00. """ converted = None accepted_formats = ["%H", "%H:%M", "%H%M"] for fmt in accepted_formats: try: converted = datetime.datetime.strptime(candidate, fmt).time() break except ValueError: pass if converted is None: raise ValueError(f"Not a valid time: '{candidate}'. Please use valid format.") return converted def day_month_year(candidate: t.Any) -> int: """Converts day, month and year strings into 'int'.""" try: if isinstance(candidate, str) or isinstance(candidate, int): return int(candidate) raise ValueError("must be a str or int.") except ValueError as e: raise ValueError( f"Not a valid day, month, or year value: {candidate}. Please use valid value." ) from e def date_range_converter(candidate: str) -> str: """Replace / with _ to avoid directory creation.""" return candidate.replace('/', '_') def parse_literal(candidate: t.Any) -> t.Any: try: # Support parsing ints with leading zeros, e.g. '01' if isinstance(candidate, str) and candidate.isdigit(): return int(candidate) return ast.literal_eval(candidate) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): return candidate def validate(key: str, value: int) -> None: """Validates value based on the key.""" if key == "day": assert 1 <= value <= 31, "Day value must be between 1 to 31." if key == "month": assert 1 <= value <= 12, "Month value must be between 1 to 12." 
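# A few illustrative examples of how the converters above combine in typecast()
# (defined just below); the values are examples only:
#
#     typecast("date", "2021-10-31")  -> datetime.date(2021, 10, 31)
#     typecast("time", "1820")        -> datetime.time(18, 20)
#     typecast("month", "07")         -> 7    (validated to lie between 1 and 12)
#     typecast("level", "500")        -> 500  (unknown keys fall through to parse_literal)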
def typecast(key: str, value: t.Any) -> t.Any: """Type the value to its appropriate datatype.""" SWITCHER = { "date": date, "time": time, "day": day_month_year, "month": day_month_year, "year": day_month_year, 'date_range': date_range_converter, } converted = SWITCHER.get(key, parse_literal)(value) validate(key, converted) return converted def _read_config_file(file: t.IO) -> t.Dict: """Reads `*.json` or `*.cfg` files.""" try: return json.load(file) except json.JSONDecodeError: pass file.seek(0) try: config = configparser.ConfigParser() config.read_file(file) config = {s: dict(config.items(s)) for s in config.sections()} return config except configparser.ParsingError: return {} def parse_config(file: t.IO) -> t.Dict: """Parses a `*.json` or `*.cfg` file into a configuration dictionary.""" config = _read_config_file(file) config_by_section = {s: _parse_lists(v, s) for s, v in config.items()} config_with_nesting = parse_subsections(config_by_section) return config_with_nesting def _splitlines(block: str) -> t.List[str]: """Converts a multi-line block into a list of strings.""" return [line.strip() for line in block.strip().splitlines()] def mars_range_value(token: str, key: str) -> t.Union[datetime.date, int, float]: """Converts a range token into either a date, int, or float.""" try: if key == 'year-month': return datetime.datetime.strptime(token, "%Y-%m").date() else: return date(token) except ValueError: pass if token.isdecimal(): return int(token) try: return float(token) except ValueError: raise ValueError( "Token string must be an 'int', 'float', or 'datetime.date()'." ) def mars_increment_value(token: str) -> t.Union[int, float]: """Converts an increment token into either an int or a float.""" try: return int(token) except ValueError: pass try: return float(token) except ValueError: raise ValueError("Token string must be an 'int' or a 'float'.") def parse_mars_syntax(block: str, key: str) -> t.List[str]: """Parses MARS list or range into a list of arguments; ranges are inclusive. Types for the range and value are inferred. Examples: >>> parse_mars_syntax("10/to/12") ['10', '11', '12'] >>> parse_mars_syntax("12/to/10/by/-1") ['12', '11', '10'] >>> parse_mars_syntax("0.0/to/0.5/by/0.1") ['0.0', '0.1', '0.2', '0.30000000000000004', '0.4', '0.5'] >>> parse_mars_syntax("2020-01-07/to/2020-01-14/by/2") ['2020-01-07', '2020-01-09', '2020-01-11', '2020-01-13'] >>> parse_mars_syntax("2020-01-14/to/2020-01-07/by/-2") ['2020-01-14', '2020-01-12', '2020-01-10', '2020-01-08'] Returns: A list of strings representing a range from start to finish, based on the type of the values in the range. If all range values are integers, it will return a list of strings of integers. If range values are floats, it will return a list of strings of floats. If the range values are dates, it will return a list of strings of dates in YYYY-MM-DD format. (Note: here, the increment value should be an integer). """ # Split into tokens, omitting empty strings. tokens = [b.strip() for b in block.split("/") if b != ""] # Return list if no range operators are present. if "to" not in tokens and "by" not in tokens: return tokens # Parse range values, honoring 'to' and 'by' operators. try: to_idx = tokens.index("to") assert to_idx != 0, "There must be a start token." start_token, end_token = tokens[to_idx - 1], tokens[to_idx + 1] start, end = mars_range_value(start_token, key), mars_range_value(end_token, key) # Parse increment token, or choose default increment. 
increment_token = "1" increment = 1 if "by" in tokens: increment_token = tokens[tokens.index("by") + 1] increment = mars_increment_value(increment_token) except (AssertionError, IndexError, ValueError): raise SyntaxError(f"Improper range syntax in '{block}'.") # Return a range of values with appropriate data type. if (key == 'year-month' and isinstance(start, datetime.date) and isinstance(end, datetime.date) and isinstance(increment, int)): result = [] offset = 1 if start <= end else -1 if increment >= 0: increment *= offset # ensure increment has correct direction current = start while current <= end if offset > 0 else current >= end: result.append(current.strftime("%Y-%m")) current += relativedelta(months=increment) return result elif isinstance(start, datetime.date) and isinstance(end, datetime.date) and key != 'year-month': increment *= -1 if start > end and increment > 0 else 1 if not isinstance(increment, int): raise ValueError( f"Increments on a date range must be integer number of days, '{increment_token}' is invalid." ) return [d.strftime("%Y-%m-%d") for d in date_range(start, end, increment)] elif (isinstance(start, float) or isinstance(end, float)) and not isinstance( increment, datetime.date ): # Increment can be either an int or a float. _round_places = 4 return [ str(round(x, _round_places)).zfill(len(start_token)) for x in np.arange(start, end + increment, increment) ] elif isinstance(start, int) and isinstance(end, int) and isinstance(increment, int): # Honor leading zeros. offset = 1 if start <= end else -1 return [ str(x).zfill(len(start_token)) for x in range(start, end + offset, increment) ] else: raise ValueError( f"Range tokens (start='{start_token}', end='{end_token}', increment='{increment_token}')" f" are inconsistent types." ) def date_range( start: datetime.date, end: datetime.date, increment: int = 1 ) -> t.Iterable[datetime.date]: """Gets a range of dates, inclusive.""" offset = 1 if start <= end else -1 return ( start + datetime.timedelta(days=x) for x in range(0, (end - start).days + offset, increment) ) def _parse_lists(config: dict, section: str = "") -> t.Dict: """Parses multiline blocks in *.cfg and *.json files as lists.""" for key, val in config.items(): # Checks str type for backward compatibility since it also support "padding": 0 in json config if not isinstance(val, str): continue if "/" in val and "parameters" not in section and key != "date_range": config[key] = parse_mars_syntax(val, key) elif "\n" in val: config[key] = _splitlines(val) return config def _number_of_replacements(s: t.Text): format_names = [v[1] for v in string.Formatter().parse(s) if v[1] is not None] num_empty_names = len([empty for empty in format_names if empty == ""]) if num_empty_names != 0: num_empty_names -= 1 return len(set(format_names)) + num_empty_names def parse_subsections(config: t.Dict) -> t.Dict: """Interprets [section.subsection] as nested dictionaries in `.cfg` files.""" copy = cp.deepcopy(config) for key, val in copy.items(): path = key.split(".") runner = copy parent = {} p = None for p in path: if p not in runner: runner[p] = {} parent = runner runner = runner[p] parent[p] = val for_cleanup = [key for key, _ in copy.items() if "." 
in key] for target in for_cleanup: del copy[target] return copy def require( condition: bool, message: str, error_type: t.Type[Exception] = ValueError ) -> None: """A assert-like helper that wraps text and throws an error.""" if not condition: raise error_type(textwrap.dedent(message)) def process_config(file: t.IO, config_name: str) -> Config: """Read the config file and prompt the user if it is improperly structured.""" config = parse_config(file) require(bool(config), "Unable to parse configuration file.") require( "parameters" in config, """ 'parameters' section required in configuration file. The 'parameters' section specifies the 'client', 'dataset', 'target_path', and 'partition_key' for the API client. Please consult the documentation for more information.""", ) params = config.get("parameters", {}) require( "target_template" not in params, """ 'target_template' is deprecated, use 'target_path' instead. Please consult the documentation for more information.""", ) require( "target_path" in params, """ 'parameters' section requires a 'target_path' key. The 'target_path' is used to format the name of the output files. It accepts Python 3.5+ string format symbols (e.g. '{}'). The number of symbols should match the length of the 'partition_keys', as the 'partition_keys' args are used to create the templates.""", ) require( "client" in params, """ 'parameters' section requires a 'client' key. Supported clients are {} """.format( str(CLIENTS) ), ) require( params.get("client") in CLIENTS, """ Invalid 'client' parameter. Supported clients are {} """.format( str(CLIENTS) ), ) require( "append_date_dirs" not in params, """ The current version of 'google-weather-tools' no longer supports 'append_date_dirs'! Please refer to documentation for creating date-based directory hierarchy : https://weather-tools.readthedocs.io/en/latest/Configuration.html#""" """creating-a-date-based-directory-hierarchy.""", NotImplementedError, ) require( "target_filename" not in params, """ The current version of 'google-weather-tools' no longer supports 'target_filename'! Please refer to documentation : https://weather-tools.readthedocs.io/en/latest/Configuration.html#parameters-section.""", NotImplementedError, ) partition_keys = params.get("partition_keys", list()) if isinstance(partition_keys, str): partition_keys = [partition_keys.strip()] selection = config.get("selection", dict()) require( all((key in selection for key in partition_keys)), """ All 'partition_keys' must appear in the 'selection' section. 'partition_keys' specify how to split data for workers. Please consult documentation for more information.""", ) num_template_replacements = _number_of_replacements(params["target_path"]) num_partition_keys = len(partition_keys) require( num_template_replacements == num_partition_keys, """ 'target_path' has {0} replacements. Expected {1}, since there are {1} partition keys. """.format( num_template_replacements, num_partition_keys ), ) if "day" in partition_keys: require( selection["day"] != "all", """If 'all' is used for a selection value, it cannot appear as a partition key.""", ) if 'hdate' in selection: require('date' in partition_keys, """"If 'hdate' is specified in the 'selection' section, then 'date' is required as a partition keys.""") if 'date_range' in selection: require('date_range' in partition_keys, """"If 'date_range' is specified in the 'selection' section, then it is also required as a partition keys.""") # Ensure consistent lookup. 
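# The normalization below writes the parsed 'partition_keys' back onto the
# 'parameters' section, records the config file name, and wraps singleton
# selection values in lists so that the cartesian product over partition keys
# is well defined (e.g. a lone year value '2020' becomes ['2020']).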
config["parameters"]["partition_keys"] = partition_keys # Add config file name. config["parameters"]["config_name"] = config_name # Ensure the cartesian-cross can be taken on singleton values for the partition. for key in partition_keys: if not isinstance(selection[key], list): selection[key] = [selection[key]] return Config.from_dict(config) def prepare_target_name(config: Config) -> str: """Returns name of target location.""" partition_dict = OrderedDict( (key, typecast(key, config.selection[key][0])) for key in config.partition_keys ) target = config.target_path.format(*partition_dict.values(), **partition_dict) return target def get_subsections(config: Config) -> t.List[t.Tuple[str, t.Dict]]: """Collect parameter subsections from main configuration. If the `parameters` section contains subsections (e.g. '[parameters.1]', '[parameters.2]'), collect the subsection key-value pairs. Otherwise, return an empty dictionary (i.e. there are no subsections). This is useful for specifying multiple API keys for your configuration. For example: ``` [parameters.alice] api_key=KKKKK1 api_url=UUUUU1 [parameters.bob] api_key=KKKKK2 api_url=UUUUU2 [parameters.eve] api_key=KKKKK3 api_url=UUUUU3 ``` """ return [ (name, params) for name, params in config.kwargs.items() if isinstance(params, dict) ] or [("default", {})] ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/partition.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import copy as cp import dataclasses import itertools import typing as t from .manifest import Manifest from .parsers import prepare_target_name from .config import Config from .stores import Store, FSStore from .util import generate_hdate logger = logging.getLogger(__name__) @dataclasses.dataclass class PartitionConfig: """Partition a config into multiple data requests. Partitioning involves four main operations: First, we fan-out shards based on partition keys (a cross product of the values). Second, we filter out existing downloads (unless we want to force downloads). Last, we assemble each partition into a single Config. Attributes: store: A cloud storage system, used for checking the existence of downloads. manifest: A download manifest to register preparation state. """ config: Config store: Store manifest: Manifest def _create_partition_config(self, option: t.Tuple) -> Config: """Create a config for a single partition option. Output a config dictionary, overriding the range of values for each key with the partition instance in 'selection'. Continuing the example from prepare_partitions, the selection section would be: { 'foo': ..., 'year': ['2020'], 'month': ['01'], ... } { 'foo': ..., 'year': ['2020'], 'month': ['02'], ... } { 'foo': ..., 'year': ['2020'], 'month': ['03'], ... } Args: option: A single item in the range of partition_keys. config: The download config, including the parameters and selection sections. 
Returns: A configuration with that selects a single download partition. """ copy = cp.deepcopy(self.config.selection) out = cp.deepcopy(self.config) for idx, key in enumerate(self.config.partition_keys): copy[key] = [option[idx]] # Replace hdate with actual value. if 'hdate' in copy: copy['hdate'] = [generate_hdate(copy['date'][0], v) for v in copy['hdate']] out.selection = copy return out def skip_partition(self, config: Config) -> bool: """Return true if partition should be skipped.""" if config.force_download: return False target = prepare_target_name(config) if self.store.exists(target): logger.info(f"file {target} found, skipping.") self.manifest.skip( config.config_name, config.dataset, config.selection, target, config.user_id, ) return True return False def prepare_partitions(self) -> t.Iterator[Config]: """Iterate over client parameters, partitioning over `partition_keys`. This produces a Cartesian-Cross over the range of keys. For example, if the keys were 'year' and 'month', it would produce an iterable like: ( ('2020', '01'), ('2020', '02'), ('2020', '03'), ...) Returns: An iterator of `Config`s. """ for option in itertools.product( *[self.config.selection[key] for key in self.config.partition_keys] ): yield self._create_partition_config(option) def new_downloads_only(self, candidate: Config) -> bool: """Predicate function to skip already downloaded partitions.""" if self.store is None: self.store = FSStore() should_skip = self.skip_partition(candidate) return not should_skip def update_manifest_collection(self, partition: Config) -> Config: """Updates the DB.""" location = prepare_target_name(partition) self.manifest.schedule( partition.config_name, partition.dataset, partition.selection, location, partition.user_id, ) logger.info(f"Created partition {location!r}.") ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/pipeline.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import getpass import logging import os from .parsers import process_config from .partition import PartitionConfig from .manifest import FirestoreManifest from database.download_handler import get_download_handler from database.queue_handler import get_queue_handler from fastapi.concurrency import run_in_threadpool logger = logging.getLogger(__name__) download_handler = get_download_handler() queue_handler = get_queue_handler() def _do_partitions(partition_obj: PartitionConfig): for partition in partition_obj.prepare_partitions(): # Skip existing downloads if partition_obj.new_downloads_only(partition): partition_obj.update_manifest_collection(partition) # TODO: Make partitioning faster. 
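# The fan-out performed by PartitionConfig.prepare_partitions() above is a plain
# cartesian product over the selection values of each partition key. A minimal
# sketch with hypothetical values (not part of the pipeline itself):
#
#   >>> import itertools
#   >>> selection = {"year": ["2020"], "month": ["01", "02"]}
#   >>> partition_keys = ["year", "month"]
#   >>> list(itertools.product(*[selection[k] for k in partition_keys]))
#   [('2020', '01'), ('2020', '02')]
#
# Each resulting tuple is turned into its own Config by _create_partition_config.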
async def start_processing_config(config_file, licenses, force_download, priority = None): config = {} manifest = FirestoreManifest() with open(config_file, "r", encoding="utf-8") as f: # configs/example.cfg -> example.cfg config_name = os.path.split(config_file)[1] config = process_config(f, config_name) config.force_download = force_download config.user_id = getpass.getuser() partition_obj = PartitionConfig(config, None, manifest) # Make entry in 'download' & 'queues' collection. await download_handler._start_download(config_name, config.client) await download_handler._mark_partitioning_status( config_name, "Partitioning in-progress." ) try: # Prepare partitions await run_in_threadpool(_do_partitions, partition_obj) await download_handler._mark_partitioning_status( config_name, "Partitioning completed." ) for license_id in licenses: await queue_handler._update_config_priority_in_license(license_id, config_name, priority) except Exception as e: error_str = f"Partitioning failed for {config_name} due to {e}." logger.error(error_str) await download_handler._mark_partitioning_status(config_name, error_str) ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/stores.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Download destinations, or `Store`s.""" import abc import io import os import tempfile import typing as t from apache_beam.io.filesystems import FileSystems class Store(abc.ABC): """A interface to represent where downloads are stored. Default implementation uses Apache Beam's Filesystems. 
""" @abc.abstractmethod def open(self, filename: str, mode: str = "r") -> t.IO: pass @abc.abstractmethod def exists(self, filename: str) -> bool: pass class InMemoryStore(Store): """Store file data in memory.""" def __init__(self): self.store = {} def open(self, filename: str, mode: str = "r") -> t.IO: """Create or read in-memory data.""" if "b" in mode: file = io.BytesIO() else: file = io.StringIO() self.store[filename] = file return file def exists(self, filename: str) -> bool: """Return true if the 'file' exists in memory.""" return filename in self.store class TempFileStore(Store): """Store data into temporary files.""" def __init__(self, directory: t.Optional[str] = None) -> None: """Optionally specify the directory that contains all temporary files.""" self.dir = directory if self.dir and not os.path.exists(self.dir): os.makedirs(self.dir) def open(self, filename: str, mode: str = "r") -> t.IO: """Create a temporary file in the store directory.""" return tempfile.TemporaryFile(mode, dir=self.dir) def exists(self, filename: str) -> bool: """Return true if file exists.""" return os.path.exists(filename) class LocalFileStore(Store): """Store data into local files.""" def __init__(self, directory: t.Optional[str] = None) -> None: """Optionally specify the directory that contains all downloaded files.""" self.dir = directory if self.dir and not os.path.exists(self.dir): os.makedirs(self.dir) def open(self, filename: str, mode: str = "r") -> t.IO: """Open a local file from the store directory.""" return open(os.sep.join([self.dir, filename]), mode) def exists(self, filename: str) -> bool: """Returns true if local file exists.""" return os.path.exists(os.sep.join([self.dir, filename])) class FSStore(Store): """Store data into any store supported by Apache Beam's FileSystems.""" def open(self, filename: str, mode: str = "r") -> t.IO: """Open object in cloud bucket (or local file system) as a read or write channel. To work with cloud storage systems, only a read or write channel can be openend at one time. Data will be treated as bytes, not text (equivalent to `rb` or `wb`). Further, append operations, or writes on existing objects, are dissallowed (the error thrown will depend on the implementation of the underlying cloud provider). """ if "r" in mode and "w" not in mode: return FileSystems().open(filename) if "w" in mode and "r" not in mode: return FileSystems().create(filename) raise ValueError( f"invalid mode {mode!r}: mode must have either 'r' or 'w', but not both." ) def exists(self, filename: str) -> bool: """Returns true if object exists.""" return FileSystems().exists(filename) ================================================ FILE: weather_dl_v2/fastapi-server/config_processing/util.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from dateutil.relativedelta import relativedelta import logging import datetime import geojson import hashlib import itertools import os import socket import subprocess import sys import typing as t import numpy as np import pandas as pd from apache_beam.io.gcp import gcsio from apache_beam.utils import retry from xarray.core.utils import ensure_us_time_resolution from urllib.parse import urlparse from google.api_core.exceptions import BadRequest LATITUDE_RANGE = (-90, 90) LONGITUDE_RANGE = (-180, 180) GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] logger = logging.getLogger(__name__) def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( exception, ) -> bool: if isinstance(exception, socket.timeout): return True if isinstance(exception, TimeoutError): return True # To handle the concurrency issue in BigQuery. if isinstance(exception, BadRequest): return True return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) class _FakeClock: def sleep(self, value): pass def retry_with_exponential_backoff(fun): """A retry decorator that doesn't apply during test time.""" clock = retry.Clock() # Use a fake clock only during test time... if "unittest" in sys.modules.keys(): clock = _FakeClock() return retry.with_exponential_backoff( retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, clock=clock, )(fun) # TODO(#245): Group with common utilities (duplicated) def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: """Yield evenly-sized chunks from an iterable.""" input_ = iter(iterable) try: while True: it = itertools.islice(input_, n) # peek to check if 'it' has next item. first = next(it) yield itertools.chain([first], it) except StopIteration: pass # TODO(#245): Group with common utilities (duplicated) def copy(src: str, dst: str) -> None: """Copy data via `gcloud storage cp`.""" try: subprocess.run(["gcloud", "storage", "cp", src, dst], check=True, capture_output=True) except subprocess.CalledProcessError as e: logger.info( f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.' ) raise # TODO(#245): Group with common utilities (duplicated) def to_json_serializable_type(value: t.Any) -> t.Any: """Returns the value with a type serializable to JSON""" # Note: The order of processing is significant. logger.info("Serializing to JSON.") if pd.isna(value) or value is None: return None elif np.issubdtype(type(value), np.floating): return float(value) elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. return value.tolist() elif ( isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64) ): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) except ValueError: # ... if they are not, assume serialization is already correct. return value except TypeError: # ... maybe value is a numpy datetime ... try: value = ensure_us_time_resolution(value).astype(datetime.datetime) except AttributeError: # ... value is a datetime object, continue. pass # We use a string timestamp representation. if value.tzname(): return value.isoformat() # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, "s")) # This check must happen after processing np.timedelta64 and np.datetime64. 
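# (np.timedelta64 is a subclass of np.signedinteger, so it would otherwise be
# caught by the np.integer branch below.)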
elif np.issubdtype(type(value), np.integer): return int(value) return value def fetch_geo_polygon(area: t.Union[list, str]) -> str: """Calculates a geography polygon from an input area.""" # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 if isinstance(area, str): # European area if area == "E": area = [73.5, -27, 33, 45] # Global area elif area == "G": area = GLOBAL_COVERAGE_AREA else: raise RuntimeError(f"Not a valid value for area in config: {area}.") n, w, s, e = [float(x) for x in area] if s < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value for south: '{s}'") if n > LATITUDE_RANGE[1]: raise ValueError(f"Invalid latitude value for north: '{n}'") if w < LONGITUDE_RANGE[0]: raise ValueError(f"Invalid longitude value for west: '{w}'") if e > LONGITUDE_RANGE[1]: raise ValueError(f"Invalid longitude value for east: '{e}'") # Define the coordinates of the bounding box. coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] # Create the GeoJSON polygon object. polygon = geojson.dumps(geojson.Polygon([coords])) return polygon def get_file_size(path: str) -> float: parsed_gcs_path = urlparse(path) if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 else: return ( gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 ) def get_wait_interval(num_retries: int = 0) -> float: """Returns next wait interval in seconds, using an exponential backoff algorithm.""" if 0 == num_retries: return 0 return 2**num_retries def generate_md5_hash(input: str) -> str: """Generates md5 hash for the input string.""" return hashlib.md5(input.encode("utf-8")).hexdigest() def download_with_aria2(url: str, path: str) -> None: """Downloads a file from the given URL using the `aria2c` command-line utility, with options set to improve download speed and reliability.""" dir_path, file_name = os.path.split(path) try: subprocess.run( [ "aria2c", "-x", "16", "-s", "16", url, "-d", dir_path, "-o", file_name, "--allow-overwrite", ], check=True, capture_output=True, ) except subprocess.CalledProcessError as e: logger.info( f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.' ) raise def generate_hdate(date: str, subtract_year: str) -> str: """Generate a historical date by subtracting a specified number of years from the given date. If input date is leap day (Feb 29), return Feb 28 even if target hdate is also a leap year. This is expected in ECMWF API. Args: date (str): The input date in the format 'YYYY-MM-DD'. subtract_year (str): The number of years to subtract. Returns: str: The historical date in the format 'YYYY-MM-DD'. """ try: input_date = datetime.datetime.strptime(date, "%Y-%m-%d") # Check for leap day if input_date.month == 2 and input_date.day == 29: input_date = input_date - datetime.timedelta(days=1) subtract_year = int(subtract_year) except (ValueError, TypeError): logger.error("Invalid input.") raise hdate = input_date - relativedelta(years=subtract_year) return hdate.strftime("%Y-%m-%d") ================================================ FILE: weather_dl_v2/fastapi-server/database/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/fastapi-server/database/download_handler.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import logging from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter from google.cloud.firestore_v1.types import WriteResult from database.session import get_async_client from server_config import get_config logger = logging.getLogger(__name__) def get_download_handler(): return DownloadHandlerFirestore(db=get_async_client()) def get_mock_download_handler(): return DownloadHandlerMock() class DownloadHandler(abc.ABC): @abc.abstractmethod async def _start_download(self, config_name: str, client_name: str) -> None: pass @abc.abstractmethod async def _stop_download(self, config_name: str) -> None: pass @abc.abstractmethod async def _mark_partitioning_status(self, config_name: str, status: str) -> None: pass @abc.abstractmethod async def _check_download_exists(self, config_name: str) -> bool: pass @abc.abstractmethod async def _get_downloads(self, client_name: str) -> list: pass @abc.abstractmethod async def _get_download_by_config_name(self, config_name: str): pass class DownloadHandlerMock(DownloadHandler): def __init__(self): pass async def _start_download(self, config_name: str, client_name: str) -> None: logger.info( f"Added {config_name} in 'download' collection. Update_time: 000000." ) async def _stop_download(self, config_name: str) -> None: logger.info( f"Removed {config_name} in 'download' collection. Update_time: 000000." ) async def _mark_partitioning_status(self, config_name: str, status: str) -> None: logger.info( f"Updated {config_name} in 'download' collection. Update_time: 000000." 
) async def _check_download_exists(self, config_name: str) -> bool: if config_name == "not_exist": return False elif config_name == "not_exist.cfg": return False else: return True async def _get_downloads(self, client_name: str) -> list: return [{"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."}] async def _get_download_by_config_name(self, config_name: str): if config_name == "not_exist": return None return {"config_name": "example.cfg", "client_name": "client", "status": "partitioning completed."} class DownloadHandlerFirestore(DownloadHandler): def __init__(self, db: firestore.firestore.Client): self.db = db self.collection = get_config().download_collection async def _start_download(self, config_name: str, client_name: str) -> None: result: WriteResult = ( await self.db.collection(self.collection) .document(config_name) .set({"config_name": config_name, "client_name": client_name}) ) logger.info( f"Added {config_name} in 'download' collection. Update_time: {result.update_time}." ) async def _stop_download(self, config_name: str) -> None: timestamp = ( await self.db.collection(self.collection).document(config_name).delete() ) logger.info( f"Removed {config_name} in 'download' collection. Update_time: {timestamp}." ) async def _mark_partitioning_status(self, config_name: str, status: str) -> None: timestamp = ( await self.db.collection(self.collection) .document(config_name) .update({"status": status}) ) logger.info( f"Updated {config_name} in 'download' collection. Update_time: {timestamp}." ) async def _check_download_exists(self, config_name: str) -> bool: result: DocumentSnapshot = ( await self.db.collection(self.collection).document(config_name).get() ) return result.exists async def _get_downloads(self, client_name: str) -> list: docs = [] if client_name: docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("client_name", "==", client_name)) .stream() ) else: docs = self.db.collection(self.collection).stream() return [doc.to_dict() async for doc in docs] async def _get_download_by_config_name(self, config_name: str): result: DocumentSnapshot = ( await self.db.collection(self.collection).document(config_name).get() ) if result.exists: return result.to_dict() else: return None ================================================ FILE: weather_dl_v2/fastapi-server/database/license_handler.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
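# A license document, as suggested by the mock handler below, carries fields
# such as 'license_id', 'secret_id', 'client_name', 'k8s_deployment_id', and
# 'number_of_requets'.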
import abc import logging from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter from google.cloud.firestore_v1.types import WriteResult from database.session import get_async_client from server_config import get_config logger = logging.getLogger(__name__) def get_license_handler(): return LicenseHandlerFirestore(db=get_async_client()) def get_mock_license_handler(): return LicenseHandlerMock() class LicenseHandler(abc.ABC): @abc.abstractmethod async def _add_license(self, license_dict: dict) -> str: pass @abc.abstractmethod async def _delete_license(self, license_id: str) -> None: pass @abc.abstractmethod async def _check_license_exists(self, license_id: str) -> bool: pass @abc.abstractmethod async def _get_license_by_license_id(self, license_id: str) -> dict: pass @abc.abstractmethod async def _get_license_by_client_name(self, client_name: str) -> list: pass @abc.abstractmethod async def _get_licenses(self) -> list: pass @abc.abstractmethod async def _update_license(self, license_id: str, license_dict: dict) -> None: pass @abc.abstractmethod async def _get_license_without_deployment(self) -> list: pass @abc.abstractmethod async def _mark_license_status(self, license_id: str, status: str) -> None: pass class LicenseHandlerMock(LicenseHandler): def __init__(self): pass async def _add_license(self, license_dict: dict) -> str: license_id = "L1" logger.info(f"Added {license_id} in 'license' collection. Update_time: 00000.") return license_id async def _delete_license(self, license_id: str) -> None: logger.info( f"Removed {license_id} in 'license' collection. Update_time: 00000." ) async def _update_license(self, license_id: str, license_dict: dict) -> None: logger.info( f"Updated {license_id} in 'license' collection. Update_time: 00000." ) async def _check_license_exists(self, license_id: str) -> bool: if license_id == "not_exist": return False elif license_id == "no-exists": return False else: return True async def _get_license_by_license_id(self, license_id: str) -> dict: if license_id == "not_exist": return None return { "license_id": license_id, "secret_id": "xxxx", "client_name": "dummy_client", "k8s_deployment_id": "k1", "number_of_requets": 100, } async def _get_license_by_client_name(self, client_name: str) -> list: return [{ "license_id": "L1", "secret_id": "xxxx", "client_name": client_name, "k8s_deployment_id": "k1", "number_of_requets": 100, }] async def _get_licenses(self) -> list: return [{ "license_id": "L1", "secret_id": "xxxx", "client_name": "dummy_client", "k8s_deployment_id": "k1", "number_of_requets": 100, }] async def _get_license_without_deployment(self) -> list: return [] class LicenseHandlerFirestore(LicenseHandler): def __init__(self, db: firestore.firestore.AsyncClient): self.db = db self.collection = get_config().license_collection async def _add_license(self, license_dict: dict) -> str: license_dict["license_id"] = license_dict["license_id"].lower() license_id = license_dict["license_id"] result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .set(license_dict) ) logger.info( f"Added {license_id} in 'license' collection. Update_time: {result.update_time}." ) return license_id async def _delete_license(self, license_id: str) -> None: timestamp = ( await self.db.collection(self.collection).document(license_id).delete() ) logger.info( f"Removed {license_id} in 'license' collection. Update_time: {timestamp}." 
) async def _update_license(self, license_id: str, license_dict: dict) -> None: result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .update(license_dict) ) logger.info( f"Updated {license_id} in 'license' collection. Update_time: {result.update_time}." ) async def _check_license_exists(self, license_id: str) -> bool: result: DocumentSnapshot = ( await self.db.collection(self.collection).document(license_id).get() ) return result.exists async def _get_license_by_license_id(self, license_id: str) -> dict: result: DocumentSnapshot = ( await self.db.collection(self.collection).document(license_id).get() ) return result.to_dict() async def _get_license_by_client_name(self, client_name: str) -> list: docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("client_name", "==", client_name)) .stream() ) return [doc.to_dict() async for doc in docs] async def _get_licenses(self) -> list: docs = self.db.collection(self.collection).stream() return [doc.to_dict() async for doc in docs] async def _get_license_without_deployment(self) -> list: docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("k8s_deployment_id", "==", "")) .stream() ) return [doc.to_dict() async for doc in docs] async def _mark_license_status(self, license_id: str, status: str) -> None: timestamp = ( await self.db.collection(self.collection) .document(license_id) .update({"status": status}) ) logger.info( f"Updated {license_id} in 'license' collection. Update_time: {timestamp}." ) ================================================ FILE: weather_dl_v2/fastapi-server/database/manifest_handler.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
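# The queries below filter manifest documents on their 'config_name', 'stage',
# and 'status' fields; composite Firestore queries of this shape generally need
# matching composite indexes (see the TODO about index creation in main.py).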
import abc import logging from firebase_admin import firestore from google.cloud.firestore_v1.base_query import FieldFilter, Or, And from server_config import get_config from database.session import get_async_client logger = logging.getLogger(__name__) def get_manifest_handler(): return ManifestHandlerFirestore(db=get_async_client()) def get_mock_manifest_handler(): return ManifestHandlerMock() class ManifestHandler(abc.ABC): @abc.abstractmethod async def _get_download_success_count(self, config_name: str) -> int: pass @abc.abstractmethod async def _get_download_failure_count(self, config_name: str) -> int: pass @abc.abstractmethod async def _get_download_scheduled_count(self, config_name: str) -> int: pass @abc.abstractmethod async def _get_download_inprogress_count(self, config_name: str) -> int: pass @abc.abstractmethod async def _get_download_total_count(self, config_name: str) -> int: pass @abc.abstractmethod async def _get_non_successfull_downloads(self, config_name: str) -> list: pass @abc.abstractmethod async def _get_failed_downloads(self, config_name: str) -> list: pass class ManifestHandlerMock(ManifestHandler): async def _get_download_failure_count(self, config_name: str) -> int: return 0 async def _get_download_inprogress_count(self, config_name: str) -> int: return 0 async def _get_download_scheduled_count(self, config_name: str) -> int: return 0 async def _get_download_success_count(self, config_name: str) -> int: return 0 async def _get_download_total_count(self, config_name: str) -> int: return 0 async def _get_non_successfull_downloads(self, config_name: str) -> list: return [] async def _get_failed_downloads(self, config_name: str) -> list: return [] class ManifestHandlerFirestore(ManifestHandler): def __init__(self, db: firestore.firestore.Client): self.db = db self.collection = get_config().manifest_collection async def _get_download_success_count(self, config_name: str) -> int: result = ( await self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=FieldFilter("stage", "==", "upload")) .where(filter=FieldFilter("status", "==", "success")) .count() .get() ) count = result[0][0].value return count async def _get_download_failure_count(self, config_name: str) -> int: result = ( await self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=FieldFilter("status", "==", "failure")) .count() .get() ) count = result[0][0].value return count async def _get_download_scheduled_count(self, config_name: str) -> int: result = ( await self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=FieldFilter("status", "==", "scheduled")) .count() .get() ) count = result[0][0].value return count async def _get_download_inprogress_count(self, config_name: str) -> int: and_filter = And( filters=[ FieldFilter("status", "==", "success"), FieldFilter("stage", "!=", "upload"), ] ) or_filter = Or(filters=[ FieldFilter("status", "==", "in-progress"), FieldFilter("status", "==", "processing"), and_filter] ) result = ( await self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=or_filter) .count() .get() ) count = result[0][0].value return count async def _get_download_total_count(self, config_name: str) -> int: result = ( await self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .count() .get() ) count = result[0][0].value return count 
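    # A download is treated as non-successful while it is still in the 'fetch'
    # or 'download' stage, has no stage recorded, or has reached 'upload'
    # without a 'success' status (see the filters in the method below).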
async def _get_non_successfull_downloads(self, config_name: str) -> list: or_filter = Or( filters=[ FieldFilter("stage", "==", "fetch"), FieldFilter("stage", "==", "download"), FieldFilter("stage", "==", None), And( filters=[ FieldFilter("status", "!=", "success"), FieldFilter("stage", "==", "upload"), ] ), ] ) docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=or_filter) .stream() ) return [doc.to_dict() async for doc in docs] async def _get_failed_downloads(self, config_name: str) -> list: docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("config_name", "==", config_name)) .where(filter=FieldFilter("status", "==", "failure")) .stream() ) return [doc.to_dict() async for doc in docs] ================================================ FILE: weather_dl_v2/fastapi-server/database/queue_handler.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import logging from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentSnapshot, FieldFilter from google.cloud.firestore_v1.types import WriteResult from database.session import get_async_client from server_config import get_config logger = logging.getLogger(__name__) def get_queue_handler(): return QueueHandlerFirestore(db=get_async_client()) def get_mock_queue_handler(): return QueueHandlerMock() class QueueHandler(abc.ABC): @abc.abstractmethod async def _create_license_queue(self, license_id: str, client_name: str) -> None: pass @abc.abstractmethod async def _remove_license_queue(self, license_id: str) -> None: pass @abc.abstractmethod async def _get_queues(self) -> list: pass @abc.abstractmethod async def _get_queue_by_license_id(self, license_id: str) -> dict: pass @abc.abstractmethod async def _get_queue_by_client_name(self, client_name: str) -> list: pass @abc.abstractmethod async def _update_license_queue(self, license_id: str, priority_list: list) -> None: pass @abc.abstractmethod async def _update_queues_on_start_download( self, config_name: str, licenses: list ) -> None: pass @abc.abstractmethod async def _update_queues_on_stop_download(self, config_name: str) -> None: pass @abc.abstractmethod async def _update_config_priority_in_license( self, license_id: str, config_name: str, priority: int ) -> None: pass @abc.abstractmethod async def _update_client_name_in_license_queue( self, license_id: str, client_name: str ) -> None: pass class QueueHandlerMock(QueueHandler): def __init__(self): pass async def _create_license_queue(self, license_id: str, client_name: str) -> None: logger.info( f"Added {license_id} queue in 'queues' collection. Update_time: 000000." ) async def _remove_license_queue(self, license_id: str) -> None: logger.info( f"Removed {license_id} queue in 'queues' collection. Update_time: 000000." 
) async def _get_queues(self) -> list: return [{"client_name": "dummy_client", "license_id": "L1", "queue": []}] async def _get_queue_by_license_id(self, license_id: str) -> dict: if license_id == "not_exist": return None return {"client_name": "dummy_client", "license_id": license_id, "queue": []} async def _get_queue_by_client_name(self, client_name: str) -> list: return [{"client_name": client_name, "license_id": "L1", "queue": []}] async def _update_license_queue(self, license_id: str, priority_list: list) -> None: logger.info( f"Updated {license_id} queue in 'queues' collection. Update_time: 00000." ) async def _update_queues_on_start_download( self, config_name: str, licenses: list ) -> None: logger.info( f"Updated {license} queue in 'queues' collection. Update_time: 00000." ) async def _update_queues_on_stop_download(self, config_name: str) -> None: logger.info( "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." ) async def _update_config_priority_in_license( self, license_id: str, config_name: str, priority: int ) -> None: logger.info( "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." ) async def _update_client_name_in_license_queue( self, license_id: str, client_name: str ) -> None: logger.info( "Updated snapshot.id queue in 'queues' collection. Update_time: 00000." ) class QueueHandlerFirestore(QueueHandler): def __init__(self, db: firestore.firestore.Client): self.db = db self.collection = get_config().queues_collection async def _create_license_queue(self, license_id: str, client_name: str) -> None: result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .set({"license_id": license_id, "client_name": client_name, "queue": []}) ) logger.info( f"Added {license_id} queue in 'queues' collection. Update_time: {result.update_time}." ) async def _remove_license_queue(self, license_id: str) -> None: timestamp = ( await self.db.collection(self.collection).document(license_id).delete() ) logger.info( f"Removed {license_id} queue in 'queues' collection. Update_time: {timestamp}." ) async def _get_queues(self) -> list: docs = self.db.collection(self.collection).stream() return [doc.to_dict() async for doc in docs] async def _get_queue_by_license_id(self, license_id: str) -> dict: result: DocumentSnapshot = ( await self.db.collection(self.collection).document(license_id).get() ) return result.to_dict() async def _get_queue_by_client_name(self, client_name: str) -> list: docs = ( self.db.collection(self.collection) .where(filter=FieldFilter("client_name", "==", client_name)) .stream() ) return [doc.to_dict() async for doc in docs] async def _update_license_queue(self, license_id: str, priority_list: list) -> None: result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .update({"queue": priority_list}) ) logger.info( f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." ) async def _update_queues_on_start_download( self, config_name: str, licenses: list ) -> None: for license in licenses: result: WriteResult = ( await self.db.collection(self.collection) .document(license) .update({"queue": firestore.ArrayUnion([config_name])}) ) logger.info( f"Updated {license} queue in 'queues' collection. Update_time: {result.update_time}." 
) async def _update_queues_on_stop_download(self, config_name: str) -> None: snapshot_list = await self.db.collection(self.collection).get() for snapshot in snapshot_list: result: WriteResult = ( await self.db.collection(self.collection) .document(snapshot.id) .update({"queue": firestore.ArrayRemove([config_name])}) ) logger.info( f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}." ) async def _update_config_priority_in_license( self, license_id: str, config_name: str, priority: int | None ) -> None: snapshot: DocumentSnapshot = ( await self.db.collection(self.collection).document(license_id).get() ) priority_list = snapshot.to_dict()["queue"] new_priority_list = [c for c in priority_list if c != config_name] if priority is None: # If no priority is given, insert at the end. priority = len(new_priority_list) new_priority_list.insert(priority, config_name) result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .update({"queue": new_priority_list}) ) logger.info( f"Updated {snapshot.id} queue in 'queues' collection. Update_time: {result.update_time}." ) async def _update_client_name_in_license_queue( self, license_id: str, client_name: str ) -> None: result: WriteResult = ( await self.db.collection(self.collection) .document(license_id) .update({"client_name": client_name}) ) logger.info( f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." ) ================================================ FILE: weather_dl_v2/fastapi-server/database/session.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import time import abc import logging import firebase_admin from google.cloud import firestore from firebase_admin import credentials from config_processing.util import get_wait_interval from server_config import get_config from gcloud import storage logger = logging.getLogger(__name__) class Database(abc.ABC): @abc.abstractmethod def _get_db(self): pass db: firestore.AsyncClient = None def get_async_client() -> firestore.AsyncClient: global db attempts = 0 while db is None: try: db = firestore.AsyncClient() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. # Use the application default credentials. cred = credentials.ApplicationDefault() firebase_admin.initialize_app(cred) logger.info("Initialized Firebase App.") if attempts > 4: raise RuntimeError( "Exceeded number of retries to get firestore client." 
) from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def get_gcs_client() -> storage.Client: try: gcs = storage.Client(project=get_config().gcs_project) except ValueError as e: logger.error(f"Error initializing GCS client: {e}.") return gcs ================================================ FILE: weather_dl_v2/fastapi-server/database/storage_handler.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import os import logging import tempfile import contextlib import typing as t from google.cloud import storage from database.session import get_gcs_client from server_config import get_config logger = logging.getLogger(__name__) def get_storage_handler(): return StorageHandlerGCS(client=get_gcs_client()) class StorageHandler(abc.ABC): @abc.abstractmethod def _upload_file(self, file_path) -> str: pass @abc.abstractmethod def _open_local(self, file_name) -> t.Iterator[str]: pass class StorageHandlerMock(StorageHandler): def __init__(self) -> None: pass def _upload_file(self, file_path) -> None: pass def _open_local(self, file_name) -> t.Iterator[str]: pass class StorageHandlerGCS(StorageHandler): def __init__(self, client: storage.Client) -> None: self.client = client self.bucket = self.client.get_bucket(get_config().storage_bucket) def _upload_file(self, file_path) -> str: filename = os.path.basename(file_path).split("/")[-1] blob = self.bucket.blob(filename) blob.upload_from_filename(file_path) logger.info(f"Uploaded {filename} to {self.bucket}.") return blob.public_url @contextlib.contextmanager def _open_local(self, file_name) -> t.Iterator[str]: blob = self.bucket.blob(file_name) with tempfile.NamedTemporaryFile() as dest_file: blob.download_to_filename(dest_file.name) yield dest_file.name ================================================ FILE: weather_dl_v2/fastapi-server/environment.yml ================================================ name: weather-dl-v2-server channels: - conda-forge dependencies: - python=3.10 - xarray - geojson - pip=22.3 - google-cloud-sdk=410.0.0 - pip: - kubernetes - fastapi[all]==0.97.0 - python-multipart - numpy - apache-beam[gcp] - aiohttp - firebase-admin - gcloud ================================================ FILE: weather_dl_v2/fastapi-server/example.cfg ================================================ [parameters] client=mars target_path=gs:///test-weather-dl-v2/{date}T00z.gb partition_keys= date # step # API Keys & Subsections go here... 
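# For example (placeholder values), keys can be supplied as parameter
# subsections, following the form described in parsers.get_subsections:
# [parameters.alice]
# api_key=KKKKK1
# api_url=UUUUU1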
[selection] class=od type=pf stream=enfo expver=0001 levtype=pl levelist=100 # params: # (z) Geopotential 129, (t) Temperature 130, # (u) U component of wind 131, (v) V component of wind 132, # (q) Specific humidity 133, (w) vertical velocity 135, # (vo) Vorticity (relative) 138, (d) Divergence 155, # (r) Relative humidity 157 param=129.128 # # next: 2019-01-01/to/existing # date=2019-07-18/to/2019-07-20 time=0000 step=0/to/2 number=1/to/2 grid=F640 ================================================ FILE: weather_dl_v2/fastapi-server/license_dep/deployment_creator.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from os import path import yaml from kubernetes import client, config from server_config import get_config logger = logging.getLogger(__name__) def create_license_deployment(license_id: str) -> str: """Creates a kubernetes workflow of type Job for downloading the data.""" config.load_config() with open(path.join(path.dirname(__file__), "license_deployment.yaml")) as f: deployment_manifest = yaml.safe_load(f) deployment_name = f"weather-dl-v2-license-dep-{license_id}".lower() # Update the deployment name with a unique identifier deployment_manifest["metadata"]["name"] = deployment_name deployment_manifest["spec"]["template"]["spec"]["containers"][0]["args"] = [ "--license", license_id, ] deployment_manifest["spec"]["template"]["spec"]["containers"][0][ "image" ] = get_config().license_deployment_image # Create an instance of the Kubernetes API client batch_api = client.BatchV1Api() # Create the deployment in the specified namespace response = batch_api.create_namespaced_job( body=deployment_manifest, namespace="default" ) logger.info(f"Deployment created successfully: {response.metadata.name}.") return deployment_name def terminate_license_deployment(license_id: str) -> None: # Load Kubernetes configuration config.load_config() # Create an instance of the Kubernetes API client batch_v1 = client.BatchV1Api() # Specify the name and namespace of the deployment to delete job_name = f"weather-dl-v2-license-dep-{license_id}".lower() # Delete the deployment batch_v1.delete_namespaced_job( name=job_name, namespace="default", body=client.V1DeleteOptions( propagation_policy='Foreground' ) ) logger.info(f"Deployment '{job_name}' deleted successfully.") ================================================ FILE: weather_dl_v2/fastapi-server/license_dep/license_deployment.yaml ================================================ # weather-dl-v2-license-dep Deployment # Defines the deployment of the app running in a pod on any worker node apiVersion: batch/v1 kind: Job metadata: name: weather-dl-v2-license-dep spec: backoffLimit: 5 podReplacementPolicy: Failed template: spec: restartPolicy: OnFailure containers: - name: weather-dl-v2-license-dep image: XXXXXXX imagePullPolicy: Always args: [] resources: requests: cpu: "1500m" # CPU: 1.5 vCPU memory: "2Gi" # RAM: 2 GiB ephemeral-storage: "10Gi" # Storage: 10 GiB volumeMounts: - 
name: config-volume mountPath: ./config terminationGracePeriodSeconds: 172800 # 48 hours volumes: - name: config-volume configMap: name: dl-v2-config ================================================ FILE: weather_dl_v2/fastapi-server/logging.conf ================================================ [loggers] keys=root,server [handlers] keys=consoleHandler,detailedConsoleHandler [formatters] keys=normalFormatter,detailedFormatter [logger_root] level=INFO handlers=consoleHandler [logger_server] level=DEBUG handlers=detailedConsoleHandler qualname=server propagate=0 [handler_consoleHandler] class=StreamHandler level=DEBUG formatter=normalFormatter args=(sys.stdout,) [handler_detailedConsoleHandler] class=StreamHandler level=DEBUG formatter=detailedFormatter args=(sys.stdout,) [formatter_normalFormatter] format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s [formatter_detailedFormatter] format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() msg:%(message)s call_trace=%(pathname)s L%(lineno)-4d ================================================ FILE: weather_dl_v2/fastapi-server/main.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import os import logging.config from contextlib import asynccontextmanager from fastapi import FastAPI from routers import license, download, queues from database.license_handler import get_license_handler from routers.license import get_create_deployment from server_config import get_config ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) # set up logger. logging.config.fileConfig("logging.conf", disable_existing_loggers=False) logger = logging.getLogger(__name__) async def create_pending_license_deployments(): """Creates license deployments for Licenses whose deployments does not exist.""" license_handler = get_license_handler() create_deployment = get_create_deployment() license_list = await license_handler._get_license_without_deployment() for _license in license_list: license_id = _license["license_id"] try: logger.info(f"Creating license deployment for {license_id}.") await create_deployment(license_id, license_handler) except Exception as e: logger.error(f"License deployment failed for {license_id}. Exception: {e}.") @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Started FastAPI server.") # Boot up # Make directory to store the uploaded config files. os.makedirs(os.path.join(os.getcwd(), "config_files"), exist_ok=True) # Retrieve license information & create license deployment if needed. await create_pending_license_deployments() # TODO: Automatically create required indexes on firestore collections on server startup. 
yield # Clean up app = FastAPI(lifespan=lifespan) app.include_router(license.router) app.include_router(download.router) app.include_router(queues.router) @app.get("/") async def main(): return {"msg": get_config().welcome_message} ================================================ FILE: weather_dl_v2/fastapi-server/routers/download.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio import logging import os import shutil import json from enum import Enum from config_processing.parsers import parse_config, process_config from config_processing.config import Config from fastapi import APIRouter, HTTPException, BackgroundTasks, UploadFile, Depends, Body from config_processing.pipeline import start_processing_config from database.download_handler import DownloadHandler, get_download_handler from database.queue_handler import QueueHandler, get_queue_handler from database.license_handler import LicenseHandler, get_license_handler from database.manifest_handler import ManifestHandler, get_manifest_handler from database.storage_handler import StorageHandler, get_storage_handler from config_processing.manifest import FirestoreManifest, Manifest from fastapi.concurrency import run_in_threadpool from routers.license import mark_license_active logger = logging.getLogger(__name__) router = APIRouter( prefix="/download", tags=["download"], responses={404: {"description": "Not found"}}, ) async def fetch_config_stats( config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler ): """Get all the config stats parallely.""" success_coroutine = manifest_handler._get_download_success_count(config_name) scheduled_coroutine = manifest_handler._get_download_scheduled_count(config_name) failure_coroutine = manifest_handler._get_download_failure_count(config_name) inprogress_coroutine = manifest_handler._get_download_inprogress_count(config_name) total_coroutine = manifest_handler._get_download_total_count(config_name) ( success_count, scheduled_count, failure_count, inprogress_count, total_count, ) = await asyncio.gather( success_coroutine, scheduled_coroutine, failure_coroutine, inprogress_coroutine, total_coroutine, ) return { "config_name": config_name, "client_name": client_name, "partitioning_status": status, "scheduled_shards": scheduled_count, "in-progress_shards": inprogress_count, "downloaded_shards": success_count, "failed_shards": failure_count, "total_shards": total_count, } def get_fetch_config_stats(): return fetch_config_stats def get_fetch_config_stats_mock(): async def fetch_config_stats( config_name: str, client_name: str, status: str, manifest_handler: ManifestHandler ): return { "config_name": config_name, "client_name": client_name, "scheduled_shards": 0, "in-progress_shards": 0, "downloaded_shards": 0, "failed_shards": 0, "total_shards": 0, } return fetch_config_stats def get_upload(): def upload(file: UploadFile): dest = os.path.join(os.getcwd(), "config_files", file.filename) with open(dest, 
"wb+") as dest_: shutil.copyfileobj(file.file, dest_) logger.info(f"Uploading {file.filename} to gcs bucket.") storage_handler: StorageHandler = get_storage_handler() storage_handler._upload_file(dest) return dest return upload def get_upload_mock(): def upload(file: UploadFile): return f"{os.getcwd()}/tests/test_data/{file.filename}" return upload def get_reschedule_partitions(): def invoke_manifest_schedule( partition_list: list, config: Config, manifest: Manifest ): for partition in partition_list: logger.info(f"Rescheduling partition {partition}.") manifest.schedule( config.config_name, config.dataset, json.loads(partition["selection"]), partition["location"], partition["username"], ) async def reschedule_partitions(config_name: str, licenses: list, only_failed: bool = False): manifest_handler: ManifestHandler = get_manifest_handler() download_handler: DownloadHandler = get_download_handler() queue_handler: QueueHandler = get_queue_handler() storage_handler: StorageHandler = get_storage_handler() if only_failed: partition_list = await manifest_handler._get_failed_downloads(config_name) else: partition_list = await manifest_handler._get_non_successfull_downloads( config_name ) config = None manifest = FirestoreManifest() with storage_handler._open_local(config_name) as local_path: with open(local_path, "r", encoding="utf-8") as f: config = process_config(f, config_name) await download_handler._mark_partitioning_status( config_name, "Partitioning in-progress." ) try: if config is None: logger.error( f"Failed reschedule_partitions. Could not open {config_name}." ) raise FileNotFoundError( f"Failed reschedule_partitions. Could not open {config_name}." ) await run_in_threadpool( invoke_manifest_schedule, partition_list, config, manifest ) await download_handler._mark_partitioning_status( config_name, "Partitioning completed." ) if len(licenses) > 0: await queue_handler._update_queues_on_start_download(config_name, licenses) except Exception as e: error_str = f"Partitioning failed for {config_name} due to {e}." logger.error(error_str) await download_handler._mark_partitioning_status(config_name, error_str) return reschedule_partitions def get_reschedule_partitions_mock(): async def reschedule_partitions(config_name: str, licenses: list, only_failed: bool = False): pass return reschedule_partitions # Can submit a config to the server. @router.post("/") async def submit_download( file: UploadFile | None = None, licenses: list = [], force_download: bool = False, priority: int | None = None, background_tasks: BackgroundTasks = BackgroundTasks(), download_handler: DownloadHandler = Depends(get_download_handler), license_handler: LicenseHandler = Depends(get_license_handler), upload=Depends(get_upload), ): if not file: logger.error("No upload file sent.") raise HTTPException(status_code=404, detail="No upload file sent.") else: if await download_handler._check_download_exists(file.filename): logger.error( f"Please stop the ongoing download of the config file '{file.filename}' " "before attempting to start a new download." ) raise HTTPException( status_code=400, detail=f"Please stop the ongoing download of the config file '{file.filename}' " "before attempting to start a new download.", ) for license_id in licenses: if not await license_handler._check_license_exists(license_id): logger.info(f"No such license {license_id}.") raise HTTPException( status_code=404, detail=f"No such license {license_id}." 
) await mark_license_active(license_id, license_handler) try: dest = upload(file) # Start processing config. background_tasks.add_task( start_processing_config, dest, licenses, force_download, priority ) return { "message": f"file '{file.filename}' saved at '{dest}' successfully." } except Exception as e: logger.error(f"Failed to save file '{file.filename} due to {e}.") raise HTTPException( status_code=500, detail=f"Failed to save file '{file.filename}'." ) class DownloadStatus(str, Enum): COMPLETED = "completed" FAILED = "failed" IN_PROGRESS = "in-progress" @router.get("/show/{config_name}") async def show_download_config( config_name: str, download_handler: DownloadHandler = Depends(get_download_handler), storage_handler: StorageHandler = Depends(get_storage_handler), ): if not await download_handler._check_download_exists(config_name): logger.error(f"No such download config {config_name} to show.") raise HTTPException( status_code=404, detail=f"No such download config {config_name} to show.", ) contents = None with storage_handler._open_local(config_name) as local_path: with open(local_path, "r", encoding="utf-8") as f: contents = parse_config(f) logger.info(f"Contents of {config_name}: {contents}.") return {"config_name": config_name, "contents": contents} # Can check the current status of the submitted config. # List status for all the downloads + handle filters @router.get("/") async def get_downloads( client_name: str | None = None, status: DownloadStatus | None = None, download_handler: DownloadHandler = Depends(get_download_handler), manifest_handler: ManifestHandler = Depends(get_manifest_handler), fetch_config_stats=Depends(get_fetch_config_stats), ): downloads = await download_handler._get_downloads(client_name) coroutines = [] for download in downloads: coroutines.append( fetch_config_stats( download["config_name"], download["client_name"], download["status"], manifest_handler, ) ) config_details = await asyncio.gather(*coroutines) if status is None: return config_details if status.value == DownloadStatus.COMPLETED: return list( filter( lambda detail: detail["downloaded_shards"] == detail["total_shards"], config_details, ) ) elif status.value == DownloadStatus.FAILED: return list(filter(lambda detail: detail["failed_shards"] > 0, config_details)) elif status.value == DownloadStatus.IN_PROGRESS: return list( filter( lambda detail: detail["downloaded_shards"] != detail["total_shards"], config_details, ) ) else: return config_details # Get status of particular download @router.get("/{config_name}") async def get_download_by_config_name( config_name: str, download_handler: DownloadHandler = Depends(get_download_handler), manifest_handler: ManifestHandler = Depends(get_manifest_handler), fetch_config_stats=Depends(get_fetch_config_stats), ): download = await download_handler._get_download_by_config_name(config_name) if download is None: logger.error(f"Download config {config_name} not found in weather-dl v2.") raise HTTPException( status_code=404, detail=f"Download config {config_name} not found in weather-dl v2.", ) return await fetch_config_stats( download["config_name"], download["client_name"], download["status"], manifest_handler, ) # Stop & remove the execution of the config. 
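# Stopping a config calls download_handler._stop_download and then clears the config
# from every license queue via queue_handler._update_queues_on_stop_download.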
@router.delete("/{config_name}") async def delete_download( config_name: str, download_handler: DownloadHandler = Depends(get_download_handler), queue_handler: QueueHandler = Depends(get_queue_handler), ): if not await download_handler._check_download_exists(config_name): logger.error(f"No such download config {config_name} to stop & remove.") raise HTTPException( status_code=404, detail=f"No such download config {config_name} to stop & remove.", ) await download_handler._stop_download(config_name) await queue_handler._update_queues_on_stop_download(config_name) return { "config_name": config_name, "message": "Download config stopped & removed successfully.", } @router.post("/retry/{config_name}") async def retry_config( config_name: str, licenses: list = Body(embed=True), only_failed: bool = False, background_tasks: BackgroundTasks = BackgroundTasks(), download_handler: DownloadHandler = Depends(get_download_handler), license_handler: LicenseHandler = Depends(get_license_handler), reschedule_partitions=Depends(get_reschedule_partitions), ): if not await download_handler._check_download_exists(config_name): logger.error(f"No such download config {config_name} to retry.") raise HTTPException( status_code=404, detail=f"No such download config {config_name} to retry.", ) for license_id in licenses: if not await license_handler._check_license_exists(license_id): logger.info(f"No such license {license_id}.") raise HTTPException( status_code=404, detail=f"No such license {license_id}." ) await mark_license_active(license_id, license_handler) background_tasks.add_task(reschedule_partitions, config_name, licenses, only_failed) return {"msg": "Refetch initiated successfully."} ================================================ FILE: weather_dl_v2/fastapi-server/routers/license.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import re from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends from pydantic import BaseModel from license_dep.deployment_creator import create_license_deployment, terminate_license_deployment from database.license_handler import LicenseHandler, get_license_handler from database.queue_handler import QueueHandler, get_queue_handler logger = logging.getLogger(__name__) class License(BaseModel): license_id: str client_name: str number_of_requests: int secret_id: str class LicenseInternal(License): k8s_deployment_id: str # Can perform CRUD on license table -- helps in handling API KEY expiry. router = APIRouter( prefix="/license", tags=["license"], responses={404: {"description": "Not found"}}, ) async def mark_license_active(license_id: str, license_handler: LicenseHandler): logger.info(f"Marking {license_id} active.") await license_handler._mark_license_status(license_id, 'License Active.') # Add/Update k8s deployment ID for existing license (intenally). 
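# Called from create_deployment() right after a license's Kubernetes Job is created,
# so the stored license document always records its current k8s_deployment_id.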
async def update_license_internal( license_id: str, k8s_deployment_id: str, license_handler: LicenseHandler, ): if not await license_handler._check_license_exists(license_id): logger.info(f"No such license {license_id} to update.") raise HTTPException( status_code=404, detail=f"No such license {license_id} to update." ) license_dict = {"k8s_deployment_id": k8s_deployment_id} await license_handler._update_license(license_id, license_dict) return {"license_id": license_id, "message": "License updated successfully."} def get_create_deployment(): async def create_deployment(license_id: str, license_handler: LicenseHandler): k8s_deployment_id = create_license_deployment(license_id) await update_license_internal(license_id, k8s_deployment_id, license_handler) return create_deployment def get_create_deployment_mock(): async def create_deployment_mock(license_id: str, license_handler: LicenseHandler): logger.info("create deployment mock.") return create_deployment_mock def get_terminate_license_deployment(): return terminate_license_deployment def get_terminate_license_deployment_mock(): def get_terminate_license_deployment_mock(license_id): logger.info(f"terminating license deployment for {license_id}.") return get_terminate_license_deployment_mock # List all the license + handle filters of {client_name} @router.get("/") async def get_licenses( client_name: str | None = None, license_handler: LicenseHandler = Depends(get_license_handler), ): if client_name: result = await license_handler._get_license_by_client_name(client_name) else: result = await license_handler._get_licenses() return result # Get particular license @router.get("/{license_id}") async def get_license_by_license_id( license_id: str, license_handler: LicenseHandler = Depends(get_license_handler) ): result = await license_handler._get_license_by_license_id(license_id) if not result: logger.info(f"License {license_id} not found.") raise HTTPException(status_code=404, detail=f"License {license_id} not found.") return result # Update existing license @router.put("/{license_id}") async def update_license( license_id: str, license: License, license_handler: LicenseHandler = Depends(get_license_handler), queue_handler: QueueHandler = Depends(get_queue_handler), create_deployment=Depends(get_create_deployment), terminate_license_deployment=Depends(get_terminate_license_deployment), ): if not await license_handler._check_license_exists(license_id): logger.error(f"No such license {license_id} to update.") raise HTTPException( status_code=404, detail=f"No such license {license_id} to update." ) license_dict = license.dict() await license_handler._update_license(license_id, license_dict) await mark_license_active(license_id, license_handler) await queue_handler._update_client_name_in_license_queue( license_id, license_dict["client_name"] ) terminate_license_deployment(license_id) await create_deployment(license_id, license_handler) return {"license_id": license_id, "name": "License updated successfully."} # Add new license @router.post("/") async def add_license( license: License, background_tasks: BackgroundTasks = BackgroundTasks(), license_handler: LicenseHandler = Depends(get_license_handler), queue_handler: QueueHandler = Depends(get_queue_handler), create_deployment=Depends(get_create_deployment), ): license_id = license.license_id.lower() # Check if license id is in correct format. 
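# (Same shape as a Kubernetes DNS-1123 subdomain name: the id is later embedded in the
# Job name weather-dl-v2-license-dep-<license_id>, so it must be a valid resource name.)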
LICENSE_REGEX = re.compile( r"[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*" ) if not bool(LICENSE_REGEX.fullmatch(license_id)): logger.error( """Invalid format for license_id. License id must consist of lower case alphanumeric""" """ characters, '-' or '.', and must start and end with an alphanumeric character""" ) raise HTTPException( status_code=400, detail="""Invalid format for license_id. License id must consist of lower case alphanumeric""" """ characters, '-' or '.', and must start and end with an alphanumeric character""", ) if await license_handler._check_license_exists(license_id): logger.error(f"License with license_id {license_id} already exist.") raise HTTPException( status_code=409, detail=f"License with license_id {license_id} already exist.", ) license_dict = license.dict() license_dict["k8s_deployment_id"] = "" license_id = await license_handler._add_license(license_dict) await mark_license_active(license_id, license_handler) await queue_handler._create_license_queue(license_id, license_dict["client_name"]) background_tasks.add_task(create_deployment, license_id, license_handler) return {"license_id": license_id, "message": "License added successfully."} # Remove license @router.delete("/{license_id}") async def delete_license( license_id: str, background_tasks: BackgroundTasks = BackgroundTasks(), license_handler: LicenseHandler = Depends(get_license_handler), queue_handler: QueueHandler = Depends(get_queue_handler), terminate_license_deployment=Depends(get_terminate_license_deployment), ): if not await license_handler._check_license_exists(license_id): logger.error(f"No such license {license_id} to delete.") raise HTTPException( status_code=404, detail=f"No such license {license_id} to delete." ) await license_handler._delete_license(license_id) await queue_handler._remove_license_queue(license_id) background_tasks.add_task(terminate_license_deployment, license_id) return {"license_id": license_id, "message": "License removed successfully."} # TODO: Create a better response for this route. @router.patch("/redeploy") async def redeploy_licenses( license_id: str = None, client_name: str = None, license_handler: LicenseHandler = Depends(get_license_handler), terminate_license_deployment=Depends(get_terminate_license_deployment), create_deployment=Depends(get_create_deployment), ): licenses = [] if license_id is not None: if license_id == "all": licenses = await license_handler._get_licenses() else: license = await license_handler._get_license_by_license_id(license_id) licenses = [license] if client_name is not None: licenses = await license_handler._get_license_by_client_name(client_name) if len(licenses) == 0: return {"message": "No license found."} for license in licenses: license_id = license['license_id'] logger.info(f"Terminating deployment {license['k8s_deployment_id']}.") try: terminate_license_deployment(license_id) except Exception as e: logger.error(f"Couldn't terminate Deployment {license_id}. Error: {e}.") logger.info(f"Creating deployment for {license_id}.") try: await create_deployment(license_id, license_handler) await mark_license_active(license_id, license_handler) except Exception as e: logger.error(f"Couldn't create Deployment {license_id}. 
Error: {e}.") return {"message": "Licenses redeployed."} ================================================ FILE: weather_dl_v2/fastapi-server/routers/queues.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from fastapi import APIRouter, HTTPException, Depends from database.queue_handler import QueueHandler, get_queue_handler from database.license_handler import LicenseHandler, get_license_handler from database.download_handler import DownloadHandler, get_download_handler logger = logging.getLogger(__name__) router = APIRouter( prefix="/queues", tags=["queues"], responses={404: {"description": "Not found"}}, ) # Users can change the execution order of config per license basis. # List the licenses priority + {client_name} filter @router.get("/") async def get_all_license_queue( client_name: str | None = None, queue_handler: QueueHandler = Depends(get_queue_handler), ): if client_name: result = await queue_handler._get_queue_by_client_name(client_name) else: result = await queue_handler._get_queues() return result # Get particular license priority @router.get("/{license_id}") async def get_license_queue( license_id: str, queue_handler: QueueHandler = Depends(get_queue_handler) ): result = await queue_handler._get_queue_by_license_id(license_id) if not result: logger.error(f"License priority for {license_id} not found.") raise HTTPException( status_code=404, detail=f"License priority for {license_id} not found." ) return result # Change priority queue of particular license @router.post("/{license_id}") async def modify_license_queue( license_id: str, priority_list: list | None = [], queue_handler: QueueHandler = Depends(get_queue_handler), license_handler: LicenseHandler = Depends(get_license_handler), download_handler: DownloadHandler = Depends(get_download_handler), ): if not await license_handler._check_license_exists(license_id): logger.error(f"License {license_id} not found.") raise HTTPException(status_code=404, detail=f"License {license_id} not found.") for config_name in priority_list: config = await download_handler._get_download_by_config_name(config_name) if config is None: logger.error(f"Download config {config_name} not found in weather-dl v2.") raise HTTPException( status_code=404, detail=f"Download config {config_name} not found in weather-dl v2.", ) try: await queue_handler._update_license_queue(license_id, priority_list) return {"message": f"'{license_id}' license priority updated successfully."} except Exception as e: logger.error(f"Failed to update '{license_id}' license priority due to {e}.") raise HTTPException( status_code=404, detail=f"Failed to update '{license_id}' license priority." 
) # Change config's priority in particular license @router.put("/priority/{license_id}") async def modify_config_priority_in_license( license_id: str, config_name: str, priority: int, queue_handler: QueueHandler = Depends(get_queue_handler), license_handler: LicenseHandler = Depends(get_license_handler), download_handler: DownloadHandler = Depends(get_download_handler), ): if not await license_handler._check_license_exists(license_id): logger.error(f"License {license_id} not found.") raise HTTPException(status_code=404, detail=f"License {license_id} not found.") config = await download_handler._get_download_by_config_name(config_name) if config is None: logger.error(f"Download config {config_name} not found in weather-dl v2.") raise HTTPException( status_code=404, detail=f"Download config {config_name} not found in weather-dl v2.", ) try: await queue_handler._update_config_priority_in_license( license_id, config_name, priority ) return { "message": f"'{license_id}' license -- '{config_name}' priority updated successfully." } except Exception as e: logger.error(f"Failed to update '{license_id}' license priority due to {e}.") raise HTTPException( status_code=404, detail=f"Failed to update '{license_id}' license priority." ) ================================================ FILE: weather_dl_v2/fastapi-server/server.yaml ================================================ # Due to our org level policy we can't expose external-ip. # In case your project don't have any such restriction a # then no need to create a nginx-server on VM to access this fastapi server # instead create the LoadBalancer Service given below. # # # weather-dl server LoadBalancer Service # # Enables the pods in a deployment to be accessible from outside the cluster # apiVersion: v1 # kind: Service # metadata: # name: weather-dl-v2-server-service # spec: # selector: # app: weather-dl-v2-server-api # ports: # - protocol: "TCP" # port: 8080 # targetPort: 8080 # type: LoadBalancer --- # weather-dl-server-api Deployment # Defines the deployment of the app running in a pod on any worker node apiVersion: apps/v1 kind: Deployment metadata: name: weather-dl-v2-server-api labels: app: weather-dl-v2-server-api spec: replicas: 1 selector: matchLabels: app: weather-dl-v2-server-api template: metadata: labels: app: weather-dl-v2-server-api spec: nodeSelector: cloud.google.com/gke-nodepool: default-pool containers: - name: weather-dl-v2-server-api image: XXXXX ports: - containerPort: 8080 imagePullPolicy: Always volumeMounts: - name: config-volume mountPath: ./config volumes: - name: config-volume configMap: name: dl-v2-config # resources: # # You must specify requests for CPU to autoscale # # based on CPU utilization # requests: # cpu: "250m" --- kind: Role apiVersion: rbac.authorization.k8s.io/v1 metadata: name: weather-dl-v2-server-api rules: - apiGroups: - "" - "apps" - "batch" resources: - endpoints - deployments - pods - jobs verbs: - get - list - watch - create - delete --- kind: RoleBinding apiVersion: rbac.authorization.k8s.io/v1 metadata: name: weather-dl-v2-server-api namespace: default subjects: - kind: ServiceAccount name: default namespace: default roleRef: apiGroup: rbac.authorization.k8s.io kind: Role name: weather-dl-v2-server-api --- ================================================ FILE: weather_dl_v2/fastapi-server/server_config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance 
with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import typing as t import json import os import logging logger = logging.getLogger(__name__) Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class ServerConfig: download_collection: str = "" queues_collection: str = "" license_collection: str = "" manifest_collection: str = "" storage_bucket: str = "" gcs_project: str = "" license_deployment_image: str = "" welcome_message: str = "" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict): config_instance = cls() for key, value in config.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value return config_instance server_config = None def get_config(): global server_config if server_config: return server_config server_config_json = "config/config.json" if not os.path.exists(server_config_json): server_config_json = os.environ.get("CONFIG_PATH", None) if server_config_json is None: logger.error("Couldn't load config file for fastAPI server.") raise FileNotFoundError("Couldn't load config file for fastAPI server.") with open(server_config_json) as file: config_dict = json.load(file) server_config = ServerConfig.from_dict(config_dict) return server_config ================================================ FILE: weather_dl_v2/fastapi-server/tests/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/fastapi-server/tests/integration/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
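A minimal usage sketch for the `ServerConfig.from_dict` helper defined in `server_config.py` above: recognised keys become attributes, anything else lands in `kwargs`. The values below are placeholders, and the snippet assumes `server_config.py` is on the import path.

```python
# Illustrative sketch only -- placeholder values, assumes server_config.py is importable.
from server_config import ServerConfig

cfg = ServerConfig.from_dict({
    "download_collection": "download",   # recognised field: set as an attribute
    "storage_bucket": "example-bucket",  # hypothetical bucket name
    "unknown_option": 42,                # unrecognised key: collected in cfg.kwargs
})
assert cfg.download_collection == "download"
assert cfg.kwargs["unknown_option"] == 42
```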
================================================ FILE: weather_dl_v2/fastapi-server/tests/integration/test_download.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import os from fastapi.testclient import TestClient from main import app, ROOT_DIR from database.download_handler import get_download_handler, get_mock_download_handler from database.license_handler import get_license_handler, get_mock_license_handler from database.queue_handler import get_queue_handler, get_mock_queue_handler from routers.download import get_upload, get_upload_mock, get_fetch_config_stats, get_fetch_config_stats_mock client = TestClient(app) logger = logging.getLogger(__name__) app.dependency_overrides[get_download_handler] = get_mock_download_handler app.dependency_overrides[get_license_handler] = get_mock_license_handler app.dependency_overrides[get_queue_handler] = get_mock_queue_handler app.dependency_overrides[get_upload] = get_upload_mock app.dependency_overrides[get_fetch_config_stats] = get_fetch_config_stats_mock def _get_download(headers, query, code, expected): response = client.get("/download", headers=headers, params=query) assert response.status_code == code assert response.json() == expected def test_get_downloads_basic(): headers = {} query = {} code = 200 expected = [{ "config_name": "example.cfg", "client_name": "client", "downloaded_shards": 0, "scheduled_shards": 0, "failed_shards": 0, "in-progress_shards": 0, "total_shards": 0, }] _get_download(headers, query, code, expected) def _submit_download(headers, file_path, licenses, code, expected): file = None try: file = {"file": open(file_path, "rb")} except FileNotFoundError: logger.info("file not found.") payload = {"licenses": licenses} response = client.post("/download", headers=headers, files=file, data=payload) logger.info(f"resp {response.json()}") assert response.status_code == code assert response.json() == expected def test_submit_download_basic(): header = { "accept": "application/json", } file_path = os.path.join(ROOT_DIR, "tests/test_data/not_exist.cfg") licenses = ["L1"] code = 200 expected = { "message": f"file 'not_exist.cfg' saved at '{os.getcwd()}/tests/test_data/not_exist.cfg' " "successfully." } _submit_download(header, file_path, licenses, code, expected) def test_submit_download_file_not_uploaded(): header = { "accept": "application/json", } file_path = os.path.join(ROOT_DIR, "tests/test_data/wrong_file.cfg") licenses = ["L1"] code = 404 expected = {"detail": "No upload file sent."} _submit_download(header, file_path, licenses, code, expected) def test_submit_download_file_alreadys_exist(): header = { "accept": "application/json", } file_path = os.path.join(ROOT_DIR, "tests/test_data/example.cfg") licenses = ["L1"] code = 400 expected = { "detail": "Please stop the ongoing download of the config file 'example.cfg' before attempting to start a new download." 
# noqa: E501 } _submit_download(header, file_path, licenses, code, expected) def _get_download_by_config(headers, config_name, code, expected): response = client.get(f"/download/{config_name}", headers=headers) assert response.status_code == code assert response.json() == expected def test_get_download_by_config_basic(): headers = {} config_name = "example.cfg" code = 200 expected = { "config_name": config_name, "client_name": "client", "downloaded_shards": 0, "scheduled_shards": 0, "failed_shards": 0, "in-progress_shards": 0, "total_shards": 0, } _get_download_by_config(headers, config_name, code, expected) def test_get_download_by_config_wrong_config(): headers = {} config_name = "not_exist" code = 404 expected = {"detail": "Download config not_exist not found in weather-dl v2."} _get_download_by_config(headers, config_name, code, expected) def _delete_download_by_config(headers, config_name, code, expected): response = client.delete(f"/download/{config_name}", headers=headers) assert response.status_code == code assert response.json() == expected def test_delete_download_by_config_basic(): headers = {} config_name = "dummy_config" code = 200 expected = { "config_name": "dummy_config", "message": "Download config stopped & removed successfully.", } _delete_download_by_config(headers, config_name, code, expected) def test_delete_download_by_config_wrong_config(): headers = {} config_name = "not_exist" code = 404 expected = {"detail": "No such download config not_exist to stop & remove."} _delete_download_by_config(headers, config_name, code, expected) ================================================ FILE: weather_dl_v2/fastapi-server/tests/integration/test_license.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging import json from fastapi.testclient import TestClient from main import app from database.download_handler import get_download_handler, get_mock_download_handler from database.license_handler import get_license_handler, get_mock_license_handler from routers.license import ( get_create_deployment, get_create_deployment_mock, get_terminate_license_deployment, get_terminate_license_deployment_mock, ) from database.queue_handler import get_queue_handler, get_mock_queue_handler client = TestClient(app) logger = logging.getLogger(__name__) app.dependency_overrides[get_download_handler] = get_mock_download_handler app.dependency_overrides[get_license_handler] = get_mock_license_handler app.dependency_overrides[get_queue_handler] = get_mock_queue_handler app.dependency_overrides[get_create_deployment] = get_create_deployment_mock app.dependency_overrides[ get_terminate_license_deployment ] = get_terminate_license_deployment_mock def _get_license(headers, query, code, expected): response = client.get("/license", headers=headers, params=query) assert response.status_code == code assert response.json() == expected def test_get_license_basic(): headers = {} query = {} code = 200 expected = [{ "license_id": "L1", "secret_id": "xxxx", "client_name": "dummy_client", "k8s_deployment_id": "k1", "number_of_requets": 100, }] _get_license(headers, query, code, expected) def test_get_license_client_name(): headers = {} client_name = "dummy_client" query = {"client_name": client_name} code = 200 expected = [{ "license_id": "L1", "secret_id": "xxxx", "client_name": client_name, "k8s_deployment_id": "k1", "number_of_requets": 100, }] _get_license(headers, query, code, expected) def _add_license(headers, payload, code, expected): response = client.post( "/license", headers=headers, data=json.dumps(payload), params={"license_id": "L1"}, ) print(f"test add license {response.json()}") assert response.status_code == code assert response.json() == expected def test_add_license_basic(): headers = {"accept": "application/json", "Content-Type": "application/json"} license = { "license_id": "no-exists", "client_name": "dummy_client", "number_of_requests": 0, "secret_id": "xxxx", } payload = license code = 200 expected = {"license_id": "L1", "message": "License added successfully."} _add_license(headers, payload, code, expected) def _get_license_by_license_id(headers, license_id, code, expected): response = client.get(f"/license/{license_id}", headers=headers) logger.info(f"response {response.json()}") assert response.status_code == code assert response.json() == expected def test_get_license_by_license_id(): headers = {"accept": "application/json", "Content-Type": "application/json"} license_id = "L1" code = 200 expected = { "license_id": license_id, "secret_id": "xxxx", "client_name": "dummy_client", "k8s_deployment_id": "k1", "number_of_requets": 100, } _get_license_by_license_id(headers, license_id, code, expected) def test_get_license_wrong_license(): headers = {} license_id = "not_exist" code = 404 expected = { "detail": "License not_exist not found.", } _get_license_by_license_id(headers, license_id, code, expected) def _update_license(headers, license_id, license, code, expected): response = client.put( f"/license/{license_id}", headers=headers, data=json.dumps(license) ) print(f"_update license {response.json()}") assert response.status_code == code assert response.json() == expected def test_update_license_basic(): headers = {} license_id = "L1" license = { "license_id": "L1", "client_name": 
"dummy_client", "number_of_requests": 0, "secret_id": "xxxx", } code = 200 expected = {"license_id": license_id, "name": "License updated successfully."} _update_license(headers, license_id, license, code, expected) def test_update_license_wrong_license_id(): headers = {} license_id = "no-exists" license = { "license_id": "no-exists", "client_name": "dummy_client", "number_of_requests": 0, "secret_id": "xxxx", } code = 404 expected = {"detail": "No such license no-exists to update."} _update_license(headers, license_id, license, code, expected) def _delete_license(headers, license_id, code, expected): response = client.delete(f"/license/{license_id}", headers=headers) assert response.status_code == code assert response.json() == expected def test_delete_license_basic(): headers = {} license_id = "L1" code = 200 expected = {"license_id": license_id, "message": "License removed successfully."} _delete_license(headers, license_id, code, expected) def test_delete_license_wrong_license(): headers = {} license_id = "not_exist" code = 404 expected = {"detail": "No such license not_exist to delete."} _delete_license(headers, license_id, code, expected) ================================================ FILE: weather_dl_v2/fastapi-server/tests/integration/test_queues.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import logging from main import app from fastapi.testclient import TestClient from database.download_handler import get_download_handler, get_mock_download_handler from database.license_handler import get_license_handler, get_mock_license_handler from database.queue_handler import get_queue_handler, get_mock_queue_handler client = TestClient(app) logger = logging.getLogger(__name__) app.dependency_overrides[get_download_handler] = get_mock_download_handler app.dependency_overrides[get_license_handler] = get_mock_license_handler app.dependency_overrides[get_queue_handler] = get_mock_queue_handler def _get_all_queue(headers, query, code, expected): response = client.get("/queues", headers=headers, params=query) assert response.status_code == code assert response.json() == expected def test_get_all_queues(): headers = {} query = {} code = 200 expected = [{"client_name": "dummy_client", "license_id": "L1", "queue": []}] _get_all_queue(headers, query, code, expected) def test_get_client_queues(): headers = {} client_name = "dummy_client" query = {"client_name": client_name} code = 200 expected = [{"client_name": client_name, "license_id": "L1", "queue": []}] _get_all_queue(headers, query, code, expected) def _get_queue_by_license(headers, license_id, code, expected): response = client.get(f"/queues/{license_id}", headers=headers) assert response.status_code == code assert response.json() == expected def test_get_queue_by_license_basic(): headers = {} license_id = "L1" code = 200 expected = {"client_name": "dummy_client", "license_id": license_id, "queue": []} _get_queue_by_license(headers, license_id, code, expected) def test_get_queue_by_license_wrong_license(): headers = {} license_id = "not_exist" code = 404 expected = {"detail": 'License priority for not_exist not found.'} _get_queue_by_license(headers, license_id, code, expected) def _modify_license_queue(headers, license_id, priority_list, code, expected): response = client.post(f"/queues/{license_id}", headers=headers, data=priority_list) assert response.status_code == code assert response.json() == expected def test_modify_license_queue_basic(): headers = {} license_id = "L1" priority_list = [] code = 200 expected = {"message": f"'{license_id}' license priority updated successfully."} _modify_license_queue(headers, license_id, priority_list, code, expected) def test_modify_license_queue_wrong_license_id(): headers = {} license_id = "not_exist" priority_list = [] code = 404 expected = {"detail": 'License not_exist not found.'} _modify_license_queue(headers, license_id, priority_list, code, expected) def _modify_config_priority_in_license(headers, license_id, query, code, expected): response = client.put(f"/queues/priority/{license_id}", params=query) logger.info(f"response {response.json()}") assert response.status_code == code assert response.json() == expected def test_modify_config_priority_in_license_basic(): headers = {} license_id = "L1" query = {"config_name": "example.cfg", "priority": 0} code = 200 expected = { "message": f"'{license_id}' license -- 'example.cfg' priority updated successfully." 
} _modify_config_priority_in_license(headers, license_id, query, code, expected) def test_modify_config_priority_in_license_wrong_license(): headers = {} license_id = "not_exist" query = {"config_name": "example.cfg", "priority": 0} code = 404 expected = {"detail": 'License not_exist not found.'} _modify_config_priority_in_license(headers, license_id, query, code, expected) def test_modify_config_priority_in_license_wrong_config(): headers = {} license_id = "not_exist" query = {"config_name": "wrong.cfg", "priority": 0} code = 404 expected = {"detail": 'License not_exist not found.'} _modify_config_priority_in_license(headers, license_id, query, code, expected) ================================================ FILE: weather_dl_v2/fastapi-server/tests/test_data/example.cfg ================================================ [parameters] client=mars target_path=gs:///test-weather-dl-v2/{date}T00z.gb partition_keys= date # step # API Keys & Subsections go here... [selection] class=od type=pf stream=enfo expver=0001 levtype=pl levelist=100 # params: # (z) Geopotential 129, (t) Temperature 130, # (u) U component of wind 131, (v) V component of wind 132, # (q) Specific humidity 133, (w) vertical velocity 135, # (vo) Vorticity (relative) 138, (d) Divergence 155, # (r) Relative humidity 157 param=129.128 # # next: 2019-01-01/to/existing # date=2019-07-18/to/2019-07-20 time=0000 step=0/to/2 number=1/to/2 grid=F640 ================================================ FILE: weather_dl_v2/fastapi-server/tests/test_data/not_exist.cfg ================================================ [parameters] client=mars target_path=gs:///test-weather-dl-v2/{date}T00z.gb partition_keys= date # step # API Keys & Subsections go here... [selection] class=od type=pf stream=enfo expver=0001 levtype=pl levelist=100 # params: # (z) Geopotential 129, (t) Temperature 130, # (u) U component of wind 131, (v) V component of wind 132, # (q) Specific humidity 133, (w) vertical velocity 135, # (vo) Vorticity (relative) 138, (d) Divergence 155, # (r) Relative humidity 157 param=129.128 # # next: 2019-01-01/to/existing # date=2019-07-18/to/2019-07-20 time=0000 step=0/to/2 number=1/to/2 grid=F640 ================================================ FILE: weather_dl_v2/license_deployment/Dockerfile ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. FROM continuumio/miniconda3:latest # Update miniconda RUN conda update conda -y # Add the mamba solver for faster builds RUN conda install -n base conda-libmamba-solver RUN conda config --set solver libmamba COPY . . 
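# The copy above brings the whole build context (fetch.py, environment.yml, and the rest
# of license_deployment/) into the image for the steps that follow.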
# Create conda env using environment.yml RUN conda env create -f environment.yml --debug # Activate the conda env and update the PATH ARG CONDA_ENV_NAME=weather-dl-v2-license-dep RUN echo "source activate ${CONDA_ENV_NAME}" >> ~/.bashrc ENV PATH /opt/conda/envs/${CONDA_ENV_NAME}/bin:$PATH ENTRYPOINT ["python", "-u", "fetch.py"] ================================================ FILE: weather_dl_v2/license_deployment/README.md ================================================ # Deployment Instructions & General Notes ### How to create environment ``` conda env create --name weather-dl-v2-license-dep --file=environment.yml conda activate weather-dl-v2-license-dep ``` ### Make changes in weather_dl_v2/config.json, if required [for running locally] ``` export CONFIG_PATH=/path/to/weather_dl_v2/config.json ``` ### Create docker image for license deployment Refer instructions in weather_dl_v2/README.md ================================================ FILE: weather_dl_v2/license_deployment/VERSION.txt ================================================ 1.0.9 ================================================ FILE: weather_dl_v2/license_deployment/__init__.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_dl_v2/license_deployment/clients.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ECMWF Downloader Clients.""" import abc import collections import contextlib import datetime import io import logging import os import time import typing as t import warnings from urllib.parse import urljoin from cdsapi import api as cds_api import urllib3 from ecmwfapi import api from config import optimize_selection_partition from manifest import Manifest, Stage from util import download_with_aria2, retry_with_exponential_backoff warnings.simplefilter("ignore", category=urllib3.connectionpool.InsecureRequestWarning) class Client(abc.ABC): """Weather data provider client interface. Defines methods and properties required to efficiently interact with weather data providers. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. 
""" def __init__(self, dataset: str, level: int = logging.INFO) -> None: """Clients are initialized with the general CLI configuration.""" self.dataset = dataset self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}") self.logger.setLevel(level) @abc.abstractmethod def retrieve( self, dataset: str, selection: t.Dict, output: str, manifest: Manifest ) -> None: """Download from data source.""" pass @classmethod @abc.abstractmethod def num_requests_per_key(cls, dataset: str) -> int: """Specifies the number of workers to be used per api key for the dataset.""" pass @property @abc.abstractmethod def license_url(self): """Specifies the License URL.""" pass class SplitCDSRequest(): """Extended CDS class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): self.__cds_client = cds_api.Client(*args, **kwargs) @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: self.__cds_client.info("Downloading %s to %s (%s)", url, path, cds_api.bytes_to_string(size)) start = time.time() download_with_aria2(url, path) elapsed = time.time() - start if elapsed: self.__cds_client.info("Download rate %s/s", cds_api.bytes_to_string(size / elapsed)) def fetch(self, request: t.Dict, dataset: str) -> t.Dict: result = self.__cds_client.retrieve(dataset, request) return {"href": result.location, "size": result.content_length} def download(self, result: cds_api.Result, target: t.Optional[str] = None) -> None: if target: if os.path.exists(target): # Empty the target file, if it already exists, otherwise the # transfer below might be fooled into thinking we're resuming # an interrupted download. open(target, "w").close() self._download(result["href"], target, result["size"]) class CdsClient(Client): """A client to access weather data from the Cloud Data Store (CDS). Datasets on CDS can be found at: https://cds.climate.copernicus.eu/cdsapp#!/search?type=dataset The parameters section of the input `config` requires two values: `api_url` and `api_key`. Or, these values can be set as the environment variables: `CDSAPI_URL` and `CDSAPI_KEY`. These can be acquired from the following URL, which requires creating a free account: https://cds.climate.copernicus.eu/api-how-to The CDS global queues for data access has dynamic rate limits. These can be viewed live here: https://cds.climate.copernicus.eu/live/limits. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. 
""" """Name patterns of datasets that are hosted internally on CDS servers.""" cds_hosted_datasets = {"reanalysis-era"} def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: c = CDSClientExtended( url=os.environ.get("CLIENT_URL"), key=os.environ.get("CLIENT_KEY"), debug_callback=self.logger.debug, info_callback=self.logger.info, warning_callback=self.logger.warning, error_callback=self.logger.error, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(selection_, dataset) return result @property def license_url(self): return "https://cds.climate.copernicus.eu/api/v2/terms/static/licence-to-use-copernicus-products.pdf" @classmethod def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key from the CDS API. CDS has dynamic, data-specific limits, defined here: https://cds.climate.copernicus.eu/live/limits Typically, the reanalysis dataset allows for 3-5 simultaneous requets. For all standard CDS data (backed on disk drives), it's common that 2 requests are allowed, though this is dynamically set, too. If the Beam pipeline encounters a user request limit error, please cancel all outstanding requests (per each user account) at the following link: https://cds.climate.copernicus.eu/cdsapp#!/yourrequests """ # TODO(#15): Parse live CDS limits API to set data-specific limits. for internal_set in cls.cds_hosted_datasets: if dataset.startswith(internal_set): return 5 return 2 class StdoutLogger(io.StringIO): """Special logger to redirect stdout to logs.""" def __init__(self, logger_: logging.Logger, level: int = logging.INFO): super().__init__() self.logger = logger_ self.level = level self._redirector = contextlib.redirect_stdout(self) def log(self, msg) -> None: self.logger.log(self.level, msg) def write(self, msg): if msg and not msg.isspace(): self.logger.log(self.level, msg) def __enter__(self): self._redirector.__enter__() return self def __exit__(self, exc_type, exc_value, traceback): # let contextlib do any exception handling here self._redirector.__exit__(exc_type, exc_value, traceback) class SplitMARSRequest(api.APIRequest): """Extended MARS APIRequest class that separates fetch and download stage.""" @retry_with_exponential_backoff def _download(self, url, path: str, size: int) -> None: self.log("Transferring %s into %s" % (self._bytename(size), path)) self.log("From %s" % (url,)) download_with_aria2(url, path) def fetch(self, request: t.Dict, dataset: str) -> t.Dict: status = None self.connection.submit("%s/%s/requests" % (self.url, self.service), request) self.log("Request submitted") self.log("Request id: " + self.connection.last.get("name")) if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) while not self.connection.ready(): if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) self.connection.wait() if self.connection.status != status: status = self.connection.status self.log("Request is %s" % (status,)) result = self.connection.result() return result def download(self, result: t.Dict, target: t.Optional[str] = None) -> None: if target: if os.path.exists(target): # Empty the target file, if it already exists, otherwise the # transfer 
below might be fooled into thinking we're resuming # an interrupted download. open(target, "w").close() self._download(urljoin(self.url, result["href"]), target, result["size"]) self.connection.cleanup() class SplitRequestMixin: c = None def fetch(self, req: t.Dict, dataset: t.Optional[str] = None) -> t.Dict: return self.c.fetch(req, dataset) def download(self, res: t.Dict, target: str) -> None: self.c.download(res, target) class CDSClientExtended(SplitRequestMixin): """Extended CDS Client class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): self.c = SplitCDSRequest(*args, **kwargs) class MARSECMWFServiceExtended(api.ECMWFService, SplitRequestMixin): """Extended MARS ECMFService class that separates fetch and download stage.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.c = SplitMARSRequest( self.url, "services/%s" % (self.service,), email=self.email, key=self.key, log=self.log, verbose=self.verbose, quiet=self.quiet, ) class PublicECMWFServerExtended(api.ECMWFDataServer, SplitRequestMixin): def __init__(self, *args, dataset="", **kwargs): super().__init__(*args, **kwargs) self.c = SplitMARSRequest( self.url, "datasets/%s" % (dataset,), email=self.email, key=self.key, log=self.log, verbose=self.verbose, ) class MarsClient(Client): """A client to access data from the Meteorological Archival and Retrieval System (MARS). See https://www.ecmwf.int/en/forecasts/datasets for a summary of datasets available on MARS. Most notable, MARS provides access to ECMWF's Operational Archive https://www.ecmwf.int/en/forecasts/dataset/operational-archive. The client config must contain three parameters to autheticate access to the MARS archive: `api_key`, `api_url`, and `api_email`. These can also be configued by setting the commensurate environment variables: `MARSAPI_KEY`, `MARSAPI_URL`, and `MARSAPI_EMAIL`. These credentials can be looked up by after registering for an ECMWF account (https://apps.ecmwf.int/registration/) and visitng: https://api.ecmwf.int/v1/key/. MARS server activity can be observed at https://apps.ecmwf.int/mars-activity/. Attributes: config: A config that contains pipeline parameters, such as API keys. level: Default log level for the client. """ def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: c = MARSECMWFServiceExtended( "mars", key=os.environ.get("CLIENT_KEY"), url=os.environ.get("CLIENT_URL"), email=os.environ.get("CLIENT_EMAIL"), log=self.logger.debug, verbose=True, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(req=selection_) return result @property def license_url(self): return "https://apps.ecmwf.int/datasets/licences/general/" @classmethod def num_requests_per_key(cls, dataset: str) -> int: """Number of requests per key (or user) for the Mars API. Mars allows 2 active requests per user and 20 queued requests per user, as of Sept 27, 2021. To ensure we never hit a rate limit error during download, we only make use of the active requests. See: https://confluence.ecmwf.int/display/UDOC/Total+number+of+requests+a+user+can+submit+-+Web+API+FAQ Queued requests can _only_ be canceled manually from a web dashboard. 
If the `ERROR 101 (USER_QUEUED_LIMIT_EXCEEDED)` error occurs in the Beam pipeline, then go to http://apps.ecmwf.int/webmars/joblist/ and cancel queued jobs. """ return 2 class ECMWFPublicClient(Client): """A client for ECMWF's public datasets, like TIGGE.""" def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: c = PublicECMWFServerExtended( url=os.environ.get("CLIENT_URL"), key=os.environ.get("CLIENT_KEY"), email=os.environ.get("CLIENT_EMAIL"), log=self.logger.debug, verbose=True, dataset=dataset, ) selection_ = optimize_selection_partition(selection) with StdoutLogger(self.logger, level=logging.DEBUG): manifest.set_stage(Stage.FETCH) precise_fetch_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_fetch_start_time result = c.fetch(req=selection_) return result @classmethod def num_requests_per_key(cls, dataset: str) -> int: # Experimentally validated request limit. return 5 @property def license_url(self): if not self.dataset: raise ValueError("must specify a dataset for this client!") return f"https://apps.ecmwf.int/datasets/data/{self.dataset.lower()}/licence/" class FakeClient(Client): """A client that writes the selection arguments to the output file.""" def retrieve(self, dataset: str, selection: t.Dict, manifest: Manifest) -> None: manifest.set_stage(Stage.RETRIEVE) precise_retrieve_start_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) manifest.prev_stage_precise_start_time = precise_retrieve_start_time self.logger.debug(f"Downloading {dataset}.") @property def license_url(self): return "lorem ipsum" @classmethod def num_requests_per_key(cls, dataset: str) -> int: return 1 CLIENTS = collections.OrderedDict( cds=CdsClient, mars=MarsClient, ecpublic=ECMWFPublicClient, fake=FakeClient, ) ================================================ FILE: weather_dl_v2/license_deployment/config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import calendar import copy import dataclasses import itertools import typing as t Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class Config: """Contains pipeline parameters. Attributes: config_name: Name of the config file. client: Name of the Weather-API-client. Supported clients are mentioned in the 'CLIENTS' variable. dataset (optional): Name of the target dataset. Allowed options are dictated by the client. partition_keys (optional): Choose the keys from the selection section to partition the data request. This will compute a cartesian cross product of the selected keys and assign each as their own download. target_path: Download artifact filename template. Can make use of Python's standard string formatting. 
It can contain format symbols to be replaced by partition keys; if this is used, the total number of format symbols must match the number of partition keys. subsection_name: Name of the particular subsection. 'default' if there is no subsection. force_download: Force redownload of partitions that were previously downloaded. user_id: Username from the environment variables. kwargs (optional): For representing subsections or any other parameters. selection: Contains parameters used to select desired data. """ config_name: str = "" client: str = "" dataset: t.Optional[str] = "" target_path: str = "" partition_keys: t.Optional[t.List[str]] = dataclasses.field(default_factory=list) subsection_name: str = "default" force_download: bool = False user_id: str = "unknown" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) selection: t.Dict[str, Values] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict) -> "Config": config_instance = cls() for section_key, section_value in config.items(): if section_key == "parameters": for key, value in section_value.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value if section_key == "selection": config_instance.selection = section_value return config_instance def optimize_selection_partition(selection: t.Dict) -> t.Dict: """Compute right-hand-side values for the selection section of a single partition. Used to support custom syntax and optimizations, such as 'all'. """ selection_ = copy.deepcopy(selection) if "date_range" in selection_.keys(): selection_["date"] = selection_["date_range"][0] del selection_["date_range"] if "day" in selection_.keys() and selection_["day"] == "all": years, months = selection_["year"], selection_["month"] multiples_error = "When using day='all' in selection, '/' is not allowed in {type}." if isinstance(years, str): years = [years] if isinstance(months, str): months = [months] date_ranges = [] # Generating dates for every year-month. for year, month in itertools.product(years, months): if isinstance(year, str): assert "/" not in year, multiples_error.format(type="year") if isinstance(month, str): assert "/" not in month, multiples_error.format(type="month") year, month = int(year), int(month) _, n_days_in_month = calendar.monthrange(year, month) date_range = [f'{year:04d}-{month:02d}-{day:02d}' for day in range(1, n_days_in_month + 1)] date_ranges.extend(date_range) selection_["date"] = date_ranges del selection_["day"] del selection_["month"] del selection_["year"] return selection_ ================================================ FILE: weather_dl_v2/license_deployment/database.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
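# Illustrative example for the optimize_selection_partition() helper in config.py above:
# a selection such as {"year": "2020", "month": "01", "day": "all"} is rewritten to
# {"date": ["2020-01-01", "2020-01-02", ..., "2020-01-31"]}, and the "day", "month", and
# "year" keys are dropped from the partition's selection.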
import abc import time import logging import firebase_admin from firebase_admin import firestore from firebase_admin import credentials from google.cloud.firestore_v1 import DocumentSnapshot, DocumentReference from google.cloud.firestore_v1.types import WriteResult from google.cloud.firestore_v1.base_query import FieldFilter, And from util import get_wait_interval from deployment_config import get_config logger = logging.getLogger(__name__) class Database(abc.ABC): @abc.abstractmethod def _get_db(self): pass class CRUDOperations(abc.ABC): @abc.abstractmethod def _initialize_license_deployment(self, license_id: str) -> dict: pass @abc.abstractmethod def _get_config_from_queue_by_license_id(self, license_id: str) -> dict: pass @abc.abstractmethod def _remove_config_from_license_queue( self, license_id: str, config_name: str ) -> None: pass @abc.abstractmethod def _empty_license_queue(self, license_id: str) -> None: pass @abc.abstractmethod def _get_partition_from_manifest(self, config_name: str) -> str: pass @abc.abstractmethod def _mark_license_status(self, license_id: str, status: str) -> None: pass class FirestoreClient(Database, CRUDOperations): def _get_db(self) -> firestore.firestore.Client: """Acquire a firestore client, initializing the firebase app if necessary. Will attempt to get the db client five times. If it's still unsuccessful, a `ManifestException` will be raised. """ db = None attempts = 0 while db is None: try: db = firestore.client() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. # Use the application default credentials. cred = credentials.ApplicationDefault() firebase_admin.initialize_app(cred) logger.info("Initialized Firebase App.") if attempts > 4: raise RuntimeError( "Exceeded number of retries to get firestore client." ) from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def _initialize_license_deployment(self, license_id: str) -> dict: result: DocumentSnapshot = ( self._get_db() .collection(get_config().license_collection) .document(license_id) .get() ) return result.to_dict() def _get_config_from_queue_by_license_id(self, license_id: str) -> str | None: result: DocumentSnapshot = ( self._get_db() .collection(get_config().queues_collection) .document(license_id) .get(["queue"]) ) if result.exists: queue = result.to_dict()["queue"] if len(queue) > 0: return queue[0] return None def _get_partition_from_manifest(self, config_name: str) -> str | None: transaction = self._get_db().transaction() return get_partition_from_manifest(transaction, config_name) def _remove_config_from_license_queue( self, license_id: str, config_name: str ) -> None: result: WriteResult = ( self._get_db() .collection(get_config().queues_collection) .document(license_id) .update({"queue": firestore.ArrayRemove([config_name])}) ) logger.info( f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." ) def _empty_license_queue(self, license_id: str) -> None: result: WriteResult = ( self._get_db() .collection(get_config().queues_collection) .document(license_id) .update({"queue": []}) ) logger.info( f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}." ) def _mark_license_status(self, license_id: str, status: str) -> None: timestamp = ( self._get_db() .collection(get_config().license_collection) .document(license_id) .update({"status": status}) ) logger.info( f"Updated {license_id} in 'license' collection. 
Update_time: {timestamp}." ) # TODO: Firestore transcational fails after reading a document 20 times with roll over. # This happens when too many licenses try to access the same partition document. # Find some alternative approach to handle this. @firestore.transactional def get_partition_from_manifest(transaction, config_name: str) -> str | None: db_client = FirestoreClient() filter_1 = FieldFilter("config_name", "==", config_name) filter_2 = FieldFilter("status", "==", "scheduled") and_filter = And(filters=[filter_1, filter_2]) snapshot = ( db_client._get_db() .collection(get_config().manifest_collection) .where(filter=and_filter) .limit(1) .get(transaction=transaction) ) if len(snapshot) > 0: snapshot = snapshot[0] else: return None ref: DocumentReference = ( db_client._get_db() .collection(get_config().manifest_collection) .document(snapshot.id) ) transaction.update(ref, {"status": "processing"}) return snapshot.to_dict() ================================================ FILE: weather_dl_v2/license_deployment/deployment_config.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import typing as t import json import os import logging logger = logging.getLogger(__name__) Values = t.Union[t.List["Values"], t.Dict[str, "Values"], bool, int, float, str] # pytype: disable=not-supported-yet @dataclasses.dataclass class DeploymentConfig: download_collection: str = "" queues_collection: str = "" license_collection: str = "" manifest_collection: str = "" downloader_k8_image: str = "" kwargs: t.Optional[t.Dict[str, Values]] = dataclasses.field(default_factory=dict) @classmethod def from_dict(cls, config: t.Dict): config_instance = cls() for key, value in config.items(): if hasattr(config_instance, key): setattr(config_instance, key, value) else: config_instance.kwargs[key] = value return config_instance deployment_config = None def get_config(): global deployment_config if deployment_config: return deployment_config deployment_config_json = "config/config.json" if not os.path.exists(deployment_config_json): deployment_config_json = os.environ.get("CONFIG_PATH", None) if deployment_config_json is None: logger.error("Couldn't load config file for license deployment.") raise FileNotFoundError("Couldn't load config file for license deployment.") with open(deployment_config_json) as file: config_dict = json.load(file) deployment_config = DeploymentConfig.from_dict(config_dict) return deployment_config ================================================ FILE: weather_dl_v2/license_deployment/downloader.yaml ================================================ apiVersion: batch/v1 kind: Job metadata: name: downloader-with-ttl spec: ttlSecondsAfterFinished: 0 backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-nodepool: downloader-pool containers: - name: downloader image: XXXXXXX imagePullPolicy: Always command: [] resources: requests: cpu: "1000m" # CPU: 1 vCPU memory: "2Gi" # RAM: 2 GiB ephemeral-storage: "100Gi" # 
Storage: 100 GiB volumeMounts: - name: data mountPath: /data - name: config-volume mountPath: ./config restartPolicy: Never volumes: - name: data emptyDir: sizeLimit: 100Gi - name: config-volume configMap: name: dl-v2-config ================================================ FILE: weather_dl_v2/license_deployment/environment.yml ================================================ name: weather-dl-v2-license-dep channels: - conda-forge dependencies: - python=3.10 - geojson - ecmwf-api-client=1.6.3 - pip=22.3 - cdsapi=0.7.5 - pip: - kubernetes - google-cloud-logging - google-cloud-secret-manager - aiohttp - numpy - xarray - apache-beam[gcp] - firebase-admin ================================================ FILE: weather_dl_v2/license_deployment/fetch.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from concurrent.futures import ThreadPoolExecutor from google.cloud import secretmanager import json import logging import time import sys import os import google.cloud.logging from database import FirestoreClient from job_creator import create_download_job from clients import CLIENTS from manifest import FirestoreManifest from util import exceptionit, ThreadSafeDict, GracefulKiller db_client = FirestoreClient() log_client = google.cloud.logging.Client() log_client.setup_logging() secretmanager_client = secretmanager.SecretManagerServiceClient() CONFIG_MAX_ERROR_COUNT = 10 logger = logging.getLogger(__name__) def create_job(request, result): res = { "config_name": request["config_name"], "dataset": request["dataset"], "selection": json.loads(request["selection"]), "user_id": request["username"], "url": result["href"], "target_path": request["location"], "license_id": license_id, } data_str = json.dumps(res) logger.info(f"Creating download job for res: {data_str}") create_download_job(data_str) @exceptionit def make_fetch_request(request, error_map: ThreadSafeDict): client = CLIENTS[client_name](request["dataset"]) manifest = FirestoreManifest(license_id=license_id) logger.info( f"By using {client_name} datasets, " f"users agree to the terms and conditions specified in {client.license_url!r}." ) target = request["location"] selection = json.loads(request["selection"]) logger.info(f"Fetching data for {target!r}.") config_name = request["config_name"] if not error_map.has_key(config_name): error_map[config_name] = 0 if error_map[config_name] >= CONFIG_MAX_ERROR_COUNT: logger.info(f"Error count for config {config_name} exceeded CONFIG_MAX_ERROR_COUNT ({CONFIG_MAX_ERROR_COUNT}).") error_map.remove(config_name) logger.info(f"Removing config {config_name} from license queue.") # Remove config from this license queue. db_client._remove_config_from_license_queue(license_id=license_id, config_name=config_name) return # Wait for exponential time based on error count. 
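# A rough worked example, using the ThreadSafeDict defaults defined in util.py
# (initial_delay=1, factor=0.5): exponential_time() returns roughly 1.5**error_count
# minutes expressed in seconds, e.g. 1 error -> 90s, 2 errors -> 135s, 3 errors -> ~202s.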
if error_map[config_name] > 0: logger.info(f"Error count for config {config_name}: {error_map[config_name]}.") sleep_time = error_map.exponential_time(config_name) logger.info(f"Sleeping for {sleep_time} secs.") time.sleep(sleep_time) try: with manifest.transact( request["config_name"], request["dataset"], selection, target, request["username"], ): result = client.retrieve(request["dataset"], selection, manifest) except Exception as e: # We are handling this as generic case as CDS client throws generic exceptions. # License expired. if "Access token expired" in str(e): logger.error(f"{license_id} expired. Emptying queue! error: {e}.") db_client._empty_license_queue(license_id=license_id) db_client._mark_license_status(license_id, "License Expired.") return if "Access token disabled'" in str(e): logger.error(f"{license_id} disabled. Emptying queue! error: {e}.") db_client._empty_license_queue(license_id=license_id) db_client._mark_license_status(license_id, "License Disabled.") return # License queue full on client side. if "USER_QUEUED_LIMIT_EXCEEDED" in str(e) or \ "Too many queued requests" in str(e): logger.error(f"{license_id} queue full. Emptying queue! error: {e}.") db_client._empty_license_queue(license_id=license_id) db_client._mark_license_status(license_id, "License Queue Full.") return # Increment error count for a config. logger.error(f"Partition fetching failed. Error {e}.") error_map.increment(config_name) return # If any partition in successful reset the error count. error_map[config_name] = 0 create_job(request, result) def fetch_request_from_db(): request = None config_name = db_client._get_config_from_queue_by_license_id(license_id) if config_name: try: logger.info(f"Fetching partition for {config_name}.") request = db_client._get_partition_from_manifest(config_name) if not request: db_client._remove_config_from_license_queue(license_id, config_name) except Exception as e: logger.error( f"Error in fetch_request_from_db for {config_name}. error: {e}." ) return request def main(): logger.info("Started looking at the request.") error_map = ThreadSafeDict() killer = GracefulKiller() with ThreadPoolExecutor(concurrency_limit) as executor: # Disclaimer: A license will pick always pick concurrency_limit + 1 # parition. One extra parition will be kept in threadpool task queue. log_count = 0 while True: # Check if SIGTERM was recived for graceful termination. if not killer.kill_now: # Fetch a request from the database. request = fetch_request_from_db() else: logger.warning('SIGTERM recieved. Stopping further requets processing.') break if request is not None: executor.submit(make_fetch_request, request, error_map) else: logger.info("No request available. Waiting...") time.sleep(30) # Each license should not pick more partitions than it's # concurrency_limit. We limit the threadpool queue size to just 1 # to prevent the license from picking more partitions than # it's concurrency_limit. When an executor is freed up, the task # in queue is picked and license fetches another task. while executor._work_queue.qsize() >= 1: # To prevent flooding of this log, we log this every 60 seconds. if log_count%60 == 0: logger.info("Worker busy. Waiting...") # Reset log_count if it goes beyond 3600. log_count = 1 if log_count >= 3600 else log_count + 1 time.sleep(1) logger.warning('Graceful Termination. Waiting for remaining requests to complete.') # Making sure all pending requests are completed. executor.shutdown(wait=True) logger.warning('Graceful Termination. 
Completed all pending requests.') # We want mark the pod as failed as we want to start a new pod which will # continue to fetch requests. raise RuntimeError('License Deployment was Graceful Terminated. ' \ 'Raising Error to mark the pod as failed.') def boot_up(license: str) -> None: global license_id, client_name, concurrency_limit result = db_client._initialize_license_deployment(license) license_id = license client_name = result["client_name"] concurrency_limit = result["number_of_requests"] response = secretmanager_client.access_secret_version( request={"name": result["secret_id"]} ) payload = response.payload.data.decode("UTF-8") secret_dict = json.loads(payload) os.environ.setdefault("CLIENT_URL", secret_dict.get("api_url", "")) os.environ.setdefault("CLIENT_KEY", secret_dict.get("api_key", "")) os.environ.setdefault("CLIENT_EMAIL", secret_dict.get("api_email", "")) if __name__ == "__main__": try: license = sys.argv[2] logger.info(f"Deployment for license: {license}.") boot_up(license) main() except Exception as e: logger.info(f"License error: {e}.") raise e logger.info('License deployment shutting down.') ================================================ FILE: weather_dl_v2/license_deployment/job_creator.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from os import path import yaml import json import uuid from kubernetes import client, config from deployment_config import get_config def create_download_job(message): """Creates a kubernetes workflow of type Job for downloading the data.""" parsed_message = json.loads(message) ( config_name, dataset, selection, user_id, url, target_path, license_id, ) = parsed_message.values() selection = str(selection).replace(" ", "") config.load_config() with open(path.join(path.dirname(__file__), "downloader.yaml")) as f: dep = yaml.safe_load(f) uid = uuid.uuid4() dep["metadata"]["name"] = f"downloader-job-id-{uid}" dep["spec"]["template"]["spec"]["containers"][0]["command"] = [ "python", "downloader.py", config_name, dataset, selection, user_id, url, target_path, license_id, ] dep["spec"]["template"]["spec"]["containers"][0][ "image" ] = get_config().downloader_k8_image batch_api = client.BatchV1Api() batch_api.create_namespaced_job(body=dep, namespace="default") ================================================ FILE: weather_dl_v2/license_deployment/manifest.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
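# Illustrative note for job_creator.create_download_job() above: the rendered Kubernetes Job is
# named "downloader-job-id-<uuid>", is created in the "default" namespace, pulls the image set
# as downloader_k8_image in the deployment config, and its container command expands to:
#   python downloader.py <config_name> <dataset> <selection> <user_id> <url> <target_path> <license_id>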
"""Client interface for connecting to a manifest.""" import abc import logging import dataclasses import datetime import enum import json import pandas as pd import time import traceback import typing as t from util import ( to_json_serializable_type, fetch_geo_polygon, get_file_size, get_wait_interval, generate_md5_hash, GLOBAL_COVERAGE_AREA, ) import firebase_admin from firebase_admin import credentials from firebase_admin import firestore from google.cloud.firestore_v1 import DocumentReference from google.cloud.firestore_v1.types import WriteResult from deployment_config import get_config from database import Database logger = logging.getLogger(__name__) """An implementation-dependent Manifest URI.""" Location = t.NewType("Location", str) class ManifestException(Exception): """Errors that occur in Manifest Clients.""" pass class Stage(enum.Enum): """A request can be either in one of the following stages at a time: fetch : This represents request is currently in fetch stage i.e. request placed on the client's server & waiting for some result before starting download (eg. MARS client). download : This represents request is currently in download stage i.e. data is being downloading from client's server to the worker's local file system. upload : This represents request is currently in upload stage i.e. data is getting uploaded from worker's local file system to target location (GCS path). retrieve : In case of clients where there is no proper separation of fetch & download stages (eg. CDS client), request will be in the retrieve stage i.e. fetch + download. """ RETRIEVE = "retrieve" FETCH = "fetch" DOWNLOAD = "download" UPLOAD = "upload" class Status(enum.Enum): """Depicts the request's state status: scheduled : A request partition is created & scheduled for processing. Note: Its corresponding state can be None only. processing: This represents that the request picked by license deployment. in-progress : This represents the request state is currently in-progress (i.e. running). The next status would be "success" or "failure". success : This represents the request state execution completed successfully without any error. failure : This represents the request state execution failed. 
""" PROCESSING = "processing" SCHEDULED = "scheduled" IN_PROGRESS = "in-progress" SUCCESS = "success" FAILURE = "failure" @dataclasses.dataclass class DownloadStatus: """Data recorded in `Manifest`s reflecting the status of a download.""" """The name of the config file associated with the request.""" config_name: str = "" """Represents the dataset field of the configuration.""" dataset: t.Optional[str] = "" """Copy of selection section of the configuration.""" selection: t.Dict = dataclasses.field(default_factory=dict) """Location of the downloaded data.""" location: str = "" """Represents area covered by the shard.""" area: str = "" """Current stage of request : 'fetch', 'download', 'retrieve', 'upload' or None.""" stage: t.Optional[Stage] = None """Download status: 'scheduled', 'in-progress', 'success', or 'failure'.""" status: t.Optional[Status] = None """Cause of error, if any.""" error: t.Optional[str] = "" """Identifier for the user running the download.""" username: str = "" """Shard size in GB.""" size: t.Optional[float] = 0 """A UTC datetime when download was scheduled.""" scheduled_time: t.Optional[str] = "" """A UTC datetime when the retrieve stage starts.""" retrieve_start_time: t.Optional[str] = "" """A UTC datetime when the retrieve state ends.""" retrieve_end_time: t.Optional[str] = "" """A UTC datetime when the fetch state starts.""" fetch_start_time: t.Optional[str] = "" """A UTC datetime when the fetch state ends.""" fetch_end_time: t.Optional[str] = "" """A UTC datetime when the download state starts.""" download_start_time: t.Optional[str] = "" """A UTC datetime when the download state ends.""" download_end_time: t.Optional[str] = "" """A UTC datetime when the upload state starts.""" upload_start_time: t.Optional[str] = "" """A UTC datetime when the upload state ends.""" upload_end_time: t.Optional[str] = "" @classmethod def from_dict(cls, download_status: t.Dict) -> "DownloadStatus": """Instantiate DownloadStatus dataclass from dict.""" download_status_instance = cls() for key, value in download_status.items(): if key == "status": setattr(download_status_instance, key, Status(value)) elif key == "stage" and value is not None: setattr(download_status_instance, key, Stage(value)) else: setattr(download_status_instance, key, value) return download_status_instance @classmethod def to_dict(cls, instance) -> t.Dict: """Return the fields of a dataclass instance as a manifest ingestible dictionary mapping of field names to field values.""" download_status_dict = {} for field in dataclasses.fields(instance): key = field.name value = getattr(instance, field.name) if isinstance(value, Status) or isinstance(value, Stage): download_status_dict[key] = value.value elif isinstance(value, pd.Timestamp): download_status_dict[key] = value.isoformat() elif key == "selection" and value is not None: download_status_dict[key] = json.dumps(value) else: download_status_dict[key] = value return download_status_dict @dataclasses.dataclass class Manifest(abc.ABC): """Abstract manifest of download statuses. Update download statuses to some storage medium. This class lets one indicate that a download is `scheduled` or in a transaction process. In the event of a transaction, a download will be updated with an `in-progress`, `success` or `failure` status (with accompanying metadata). Example: ``` my_manifest = parse_manifest_location(Location('fs://some-firestore-collection')) # Schedule data for download my_manifest.schedule({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') # ... 
# Initiate a transaction – it will record that the download is `in-progess` with my_manifest.transact({'some': 'metadata'}, 'path/to/downloaded/file', 'my-username') as tx: # download logic here pass # ... # on error, will record the download as a `failure` before propagating the error. By default, it will # record download as a `success`. ``` Attributes: status: The current `DownloadStatus` of the Manifest. """ # To reduce the impact of _read() and _update() calls # on the start time of the stage. license_id: str = "" prev_stage_precise_start_time: t.Optional[str] = None status: t.Optional[DownloadStatus] = None # This is overridden in subclass. def __post_init__(self): """Initialize the manifest.""" pass def schedule( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Indicate that a job has been scheduled for download. 'scheduled' jobs occur before 'in-progress', 'success' or 'finished'. """ scheduled_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) self.status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=None, status=Status.SCHEDULED, error=None, size=None, scheduled_time=scheduled_time, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=None, upload_end_time=None, ) self._update(self.status) def skip( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Updates the manifest to mark the shards that were skipped in the current job as 'upload' stage and 'success' status, indicating that they have already been downloaded. """ old_status = self._read(location) # The manifest needs to be updated for a skipped shard if its entry is not present, or # if the stage is not 'upload', or if the stage is 'upload' but the status is not 'success'. if ( old_status.location != location or old_status.stage != Stage.UPLOAD or old_status.status != Status.SUCCESS ): current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) size = get_file_size(location) status = DownloadStatus( config_name=config_name, dataset=dataset if dataset else None, selection=selection, location=location, area=fetch_geo_polygon(selection.get("area", GLOBAL_COVERAGE_AREA)), username=user, stage=Stage.UPLOAD, status=Status.SUCCESS, error=None, size=size, scheduled_time=None, retrieve_start_time=None, retrieve_end_time=None, fetch_start_time=None, fetch_end_time=None, download_start_time=None, download_end_time=None, upload_start_time=current_utc_time, upload_end_time=current_utc_time, ) self._update(status) logger.info( f"Manifest updated for skipped shard: {location!r} -- {DownloadStatus.to_dict(status)!r}." 
) def _set_for_transaction( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> None: """Reset Manifest state in preparation for a new transaction.""" self.status = dataclasses.replace(self._read(location)) self.status.config_name = config_name self.status.dataset = dataset if dataset else None self.status.selection = selection self.status.location = location self.status.username = user def __enter__(self) -> None: pass def __exit__(self, exc_type, exc_inst, exc_tb) -> None: """Record end status of a transaction as either 'success' or 'failure'.""" if exc_type is None: status = Status.SUCCESS error = None else: status = Status.FAILURE # For explanation, see https://docs.python.org/3/library/traceback.html#traceback.format_exception error = f"license_id: {self.license_id} " error += "\n".join(traceback.format_exception(exc_type, exc_inst, exc_tb)) new_status = dataclasses.replace(self.status) new_status.error = error new_status.status = status current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) # This is necessary for setting the precise start time of the previous stage # and end time of the final stage, as well as handling the case of Status.FAILURE. if new_status.stage == Stage.FETCH: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time elif new_status.stage == Stage.RETRIEVE: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time elif new_status.stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.upload_start_time = self.prev_stage_precise_start_time new_status.upload_end_time = current_utc_time new_status.size = get_file_size(new_status.location) self.status = new_status self._update(self.status) def transact( self, config_name: str, dataset: str, selection: t.Dict, location: str, user: str, ) -> "Manifest": """Create a download transaction.""" self._set_for_transaction(config_name, dataset, selection, location, user) return self def set_stage(self, stage: Stage) -> None: """Sets the current stage in manifest.""" prev_stage = self.status.stage new_status = dataclasses.replace(self.status) new_status.stage = stage new_status.status = Status.IN_PROGRESS current_utc_time = ( datetime.datetime.utcnow() .replace(tzinfo=datetime.timezone.utc) .isoformat(timespec="seconds") ) if stage == Stage.FETCH: new_status.fetch_start_time = current_utc_time elif stage == Stage.RETRIEVE: new_status.retrieve_start_time = current_utc_time elif stage == Stage.DOWNLOAD: new_status.fetch_start_time = self.prev_stage_precise_start_time new_status.fetch_end_time = current_utc_time new_status.download_start_time = current_utc_time else: if prev_stage == Stage.DOWNLOAD: new_status.download_start_time = self.prev_stage_precise_start_time new_status.download_end_time = current_utc_time else: new_status.retrieve_start_time = self.prev_stage_precise_start_time new_status.retrieve_end_time = current_utc_time new_status.upload_start_time = current_utc_time self.status = new_status self._update(self.status) @abc.abstractmethod def _read(self, location: str) -> DownloadStatus: pass @abc.abstractmethod def _update(self, download_status: DownloadStatus) -> None: pass class FirestoreManifest(Manifest, Database): """A Firestore Manifest. 
This Manifest implementation stores DownloadStatuses in a Firebase document store. The document hierarchy for the manifest is as follows: [manifest ] ├── doc_id (md5 hash of the path) { 'selection': {...}, 'location': ..., 'username': ... } └── etc... Where `[]` indicates a collection and ` {...}` indicates a document. """ def _get_db(self) -> firestore.firestore.Client: """Acquire a firestore client, initializing the firebase app if necessary. Will attempt to get the db client five times. If it's still unsuccessful, a `ManifestException` will be raised. """ db = None attempts = 0 while db is None: try: db = firestore.client() except ValueError as e: # The above call will fail with a value error when the firebase app is not initialized. # Initialize the app here, and try again. # Use the application default credentials. cred = credentials.ApplicationDefault() firebase_admin.initialize_app(cred) logger.info("Initialized Firebase App.") if attempts > 4: raise ManifestException( "Exceeded number of retries to get firestore client." ) from e time.sleep(get_wait_interval(attempts)) attempts += 1 return db def _read(self, location: str) -> DownloadStatus: """Reads the JSON data from a manifest.""" doc_id = generate_md5_hash(location) # Update document with download status download_doc_ref = self.root_document_for_store(doc_id) result = download_doc_ref.get() row = {} if result.exists: records = result.to_dict() row = {n: to_json_serializable_type(v) for n, v in records.items()} return DownloadStatus.from_dict(row) def _update(self, download_status: DownloadStatus) -> None: """Update or create a download status record.""" logger.info("Updating Firestore Manifest.") status = DownloadStatus.to_dict(download_status) doc_id = generate_md5_hash(status["location"]) # Update document with download status. download_doc_ref = self.root_document_for_store(doc_id) result: WriteResult = download_doc_ref.set(status) logger.info( "Firestore manifest updated. " + f"update_time={result.update_time}, " + f"status={status['status']} " + f"stage={status['stage']} " + f"filename={download_status.location}." ) def root_document_for_store(self, store_scheme: str) -> DocumentReference: """Get the root manifest document given the user's config and current document's storage location.""" return ( self._get_db() .collection(get_config().manifest_collection) .document(store_scheme) ) ================================================ FILE: weather_dl_v2/license_deployment/util.py ================================================ # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import datetime import logging import geojson import hashlib import itertools import os import signal import socket import subprocess import sys import typing as t import numpy as np import pandas as pd from apache_beam.io.gcp import gcsio from apache_beam.utils import retry from xarray.core.utils import ensure_us_time_resolution from urllib.parse import urlparse from google.api_core.exceptions import BadRequest from threading import Lock logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) LATITUDE_RANGE = (-90, 90) LONGITUDE_RANGE = (-180, 180) GLOBAL_COVERAGE_AREA = [90, -180, -90, 180] def exceptionit(func): def inner_function(*args, **kwargs): try: func(*args, **kwargs) except Exception as e: logger.error(f"exception in {func.__name__} {e.__class__.__name__} {e}.") return inner_function def _retry_if_valid_input_but_server_or_socket_error_and_timeout_filter( exception, ) -> bool: if isinstance(exception, socket.timeout): return True if isinstance(exception, TimeoutError): return True # To handle the concurrency issue in BigQuery. if isinstance(exception, BadRequest): return True return retry.retry_if_valid_input_but_server_error_and_timeout_filter(exception) class GracefulKiller: """Class to check for SIGTERM signal. Used to handle gracefull termination. If ever SIGTERM is recived by the process GracefulKiller.kill_now will be `true`.""" kill_now = False def __init__(self): signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully) def exit_gracefully(self, signum, frame): logger.warning('SIGTERM recieved.') self.kill_now = True class _FakeClock: def sleep(self, value): pass def retry_with_exponential_backoff(fun): """A retry decorator that doesn't apply during test time.""" clock = retry.Clock() # Use a fake clock only during test time... if "unittest" in sys.modules.keys(): clock = _FakeClock() return retry.with_exponential_backoff( retry_filter=_retry_if_valid_input_but_server_or_socket_error_and_timeout_filter, clock=clock, )(fun) # TODO(#245): Group with common utilities (duplicated) def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: """Yield evenly-sized chunks from an iterable.""" input_ = iter(iterable) try: while True: it = itertools.islice(input_, n) # peek to check if 'it' has next item. first = next(it) yield itertools.chain([first], it) except StopIteration: pass # TODO(#245): Group with common utilities (duplicated) def copy(src: str, dst: str) -> None: """Copy data via `gcloud storage cp`.""" try: subprocess.run(["gcloud", "storage", "cp", src, dst], check=True, capture_output=True) except subprocess.CalledProcessError as e: logger.info( f'Failed to copy file {src!r} to {dst!r} due to {e.stderr.decode("utf-8")}.' ) raise # TODO(#245): Group with common utilities (duplicated) def to_json_serializable_type(value: t.Any) -> t.Any: """Returns the value with a type serializable to JSON""" # Note: The order of processing is significant. logger.info("Serializing to JSON.") if pd.isna(value) or value is None: return None elif np.issubdtype(type(value), np.floating): return float(value) elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. return value.tolist() elif ( isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64) ): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) except ValueError: # ... if they are not, assume serialization is already correct. 
return value except TypeError: # ... maybe value is a numpy datetime ... try: value = ensure_us_time_resolution(value).astype(datetime.datetime) except AttributeError: # ... value is a datetime object, continue. pass # We use a string timestamp representation. if value.tzname(): return value.isoformat() # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, "s")) # This check must happen after processing np.timedelta64 and np.datetime64. elif np.issubdtype(type(value), np.integer): return int(value) return value def fetch_geo_polygon(area: t.Union[list, str]) -> str: """Calculates a geography polygon from an input area.""" # Ref: https://confluence.ecmwf.int/pages/viewpage.action?pageId=151520973 if isinstance(area, str): # European area if area == "E": area = [73.5, -27, 33, 45] # Global area elif area == "G": area = GLOBAL_COVERAGE_AREA else: raise RuntimeError(f"Not a valid value for area in config: {area}.") n, w, s, e = [float(x) for x in area] if s < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value for south: '{s}'") if n > LATITUDE_RANGE[1]: raise ValueError(f"Invalid latitude value for north: '{n}'") if w < LONGITUDE_RANGE[0]: raise ValueError(f"Invalid longitude value for west: '{w}'") if e > LONGITUDE_RANGE[1]: raise ValueError(f"Invalid longitude value for east: '{e}'") # Define the coordinates of the bounding box. coords = [[w, n], [w, s], [e, s], [e, n], [w, n]] # Create the GeoJSON polygon object. polygon = geojson.dumps(geojson.Polygon([coords])) return polygon def get_file_size(path: str) -> float: parsed_gcs_path = urlparse(path) if parsed_gcs_path.scheme != "gs" or parsed_gcs_path.netloc == "": return os.stat(path).st_size / (1024**3) if os.path.exists(path) else 0 else: return ( gcsio.GcsIO().size(path) / (1024**3) if gcsio.GcsIO().exists(path) else 0 ) def get_wait_interval(num_retries: int = 0) -> float: """Returns next wait interval in seconds, using an exponential backoff algorithm.""" if 0 == num_retries: return 0 return 2**num_retries def generate_md5_hash(input: str) -> str: """Generates md5 hash for the input string.""" return hashlib.md5(input.encode("utf-8")).hexdigest() def download_with_aria2(url: str, path: str) -> None: """Downloads a file from the given URL using the `aria2c` command-line utility, with options set to improve download speed and reliability.""" dir_path, file_name = os.path.split(path) try: subprocess.run( [ "aria2c", "-x", "16", "-s", "16", url, "-d", dir_path, "-o", file_name, "--allow-overwrite", ], check=True, capture_output=True, ) except subprocess.CalledProcessError as e: logger.info( f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.' 
) raise class ThreadSafeDict: """A thread safe dict with crud operations.""" def __init__(self) -> None: self._dict = {} self._lock = Lock() self.initial_delay = 1 self.factor = 0.5 def __getitem__(self, key): val = None with self._lock: val = self._dict[key] return val def __setitem__(self, key, value): with self._lock: self._dict[key] = value def remove(self, key): with self._lock: self._dict.__delitem__(key) def has_key(self, key): present = False with self._lock: present = key in self._dict return present def increment(self, key, delta=1): with self._lock: if key in self._dict: self._dict[key] += delta def decrement(self, key, delta=1): with self._lock: if key in self._dict: self._dict[key] -= delta def find_exponential_delay(self, n: int) -> int: delay = self.initial_delay for _ in range(n): delay += delay*self.factor return delay def exponential_time(self, key): """Returns exponential time based on dict value. Time in seconds.""" delay = 0 with self._lock: if key in self._dict: delay = self.find_exponential_delay(self._dict[key]) return delay * 60 ================================================ FILE: weather_mv/MANIFEST.in ================================================ global-exclude *_test.py include README.md prune test_data ================================================ FILE: weather_mv/README.md ================================================ # ⛅️ `weather-mv` – Weather Mover Weather Mover loads weather data from cloud storage into analytics engines, like [Google BigQuery](https://cloud.google.com/bigquery) (_alpha_). ## Features * **Rapid Querability**: After geospatial data is in BigQuery, data wranging becomes as simple as writing SQL. This allows for rapid data exploration, visualization, and model pipeline prototyping. * **Simple Versioning**: All rows in the table come with a `data_import_time` column. This provides some notion of how the data is versioned. Downstream analysis can adapt to data ingested at differen times by updating a `WHERE` clause. * **Parallel Upload**: Each file will be processed in parallel. With Dataflow autoscaling, even large datasets can be processed in a reasonable amount of time. * **Streaming support**: When running the mover in streaming mode, it will automatically process files as they appear in cloud buckets via PubSub. * _(new)_ **Grib Regridding**: `weather-mv regrid` uses [MetView](https://metview.readthedocs.io/en/latest/) to interpolate Grib files to a [range of grids.](https://metview.readthedocs.io/en/latest/metview/using_metview/regrid_intro.html?highlight=grid#grid) * _(new)_ **Earth Engine Ingestion**: `weather-mv earthengine` ingests weather data into [Google Earth Engine](https://earthengine.google.com/). ## Usage ``` usage: weather-mv [-h] {bigquery,bq,regrid,rg} ... Weather Mover loads weather data from cloud storage into analytics engines. positional arguments: {bigquery,bq,regrid,rg,earthengine,ee} help for subcommand bigquery (bq) Move data into Google BigQuery regrid (rg) Copy and regrid grib data with MetView. earthengine (ee) Move data into Google Earth Engine optional arguments: -h, --help show this help message and exit ``` The weather mover makes use of subcommands to distinguish between tasks. The above tasks are currently supported. _Common options_ * `-i, --uris`: (required) URI glob pattern matching input weather data, e.g. 'gs://ecmwf/era5/era5-2015-*.gb'. * `--topic`: A Pub/Sub topic for GCS OBJECT_FINALIZE events, or equivalent, of a cloud bucket. E.g. 'projects//topics/'. 
Cannot be used with `--subscription`. * `--subscription`: A Pub/Sub subscription for GCS OBJECT_FINALIZE events, or equivalent, of a cloud bucket. Cannot be used with `--topic`. * `--window_size`: Output file's window size in minutes. Only used with the `topic` flag. Default: 1.0 minute. * `--num_shards`: Number of shards to use when writing windowed elements to cloud storage. Only used with the `topic` flag. Default: 5 shards. * `-d, --dry-run`: Preview the load into BigQuery. Default: off. * `--log-level`: An integer to configure log level. Default: 2(INFO). * `--use-local-code`: Supply local code to the Runner. Default: False. Invoke with `-h` or `--help` to see the full range of options. ### `weather-mv bigquery` ``` usage: weather-mv bigquery [-h] -i URIS [--topic TOPIC] [--window_size WINDOW_SIZE] [--num_shards NUM_SHARDS] [-d] -o OUTPUT_TABLE --geo_data_parquet_path GEO_DATA_PARQUET [-v variables [variables ...]] [-a area [area ...]] [--import_time IMPORT_TIME] [--infer_schema] [--xarray_open_dataset_kwargs XARRAY_OPEN_DATASET_KWARGS] [--tif_metadata_for_start_time TIF_METADATA_FOR_START_TIME] [--tif_metadata_for_end_time TIF_METADATA_FOR_END_TIME] [-s] [--rows_chunk_size rows_chunk_size] [--skip_creating_polygon] [--skip_creating_geo_data_parquet] ``` The `bigquery` subcommand loads weather data into BigQuery. In addition to the common options above, users may specify command-specific options: _Command options_: * `-o, --output_table`: (required) Full name of destination BigQuery table. Ex: my_project.my_dataset.my_table * `--geo_data_parquet_path`: (required) A path to dump the geo data parquet. This parquet consists of columns: latitude, longitude, geo_point, and geo_polygon. We calculate all of this information upfront so that we do not need to process it every time we process a set of files. * `-v, --variables`: Target variables (or coordinates) for the BigQuery schema. Default: will import all data variables as columns. * `-a, --area`: Target area in [N, W, S, E]. Default: Will include all available area. * `--import_time`: When writing data to BigQuery, record that data import occurred at this time (format: YYYY-MM-DD HH:MM:SS.usec+offset). Default: now in UTC. * `--infer_schema`: Download one file in the URI pattern and infer a schema from that file. Default: off. * `--xarray_open_dataset_kwargs`: Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string. * `--rows_chunk_size`: The size of the chunk of rows to be loaded into memory for processing. Depending on your system's memory, use this to tune how much rows to process. Default: 1_000_000. * `--tif_metadata_for_start_time` : Metadata that contains tif file's start/initialization time. Applicable only for tif files. * `--tif_metadata_for_end_time` : Metadata that contains tif file's end/forecast time. Applicable only for tif files (optional). * `-s, --skip_region_validation` : Skip validation of regions for data migration. Default: off. * `--disable_grib_schema_normalization` : To disable grib's schema normalization. Default: off. * `--skip_creating_polygon` : Not ingest grid points as polygons in BigQuery. Default: Ingest grid points as Polygon in BigQuery. Note: This feature relies on the assumption that the provided grid has an equal distance between consecutive points of latitude and longitude. * `--skip_creating_geo_data_parquet`: Skip the generation of the geo data parquet if it already exists at the given --geo_data_parquet_path. 
Please note that the geo data parquet is mandatory for ingesting data into BigQuery. Default: Create geo data parquet file. Invoke with `bq -h` or `bigquery --help` to see the full range of options. > Note: In case of grib files, by default its schema will be normalized and the name of the data variables will look > like `___`. > > This solves the issue of skipping over some of the data due to: https://github.com/ecmwf/cfgrib#filter-heterogeneous-grib-files. _Usage examples_: ```bash weather-mv bigquery --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 ``` Using the subcommand alias `bq`: ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 ``` Preview load with a dry run: ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 \ --dry-run ``` Ingest grid points with skip creating polygon in BigQuery: ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 \ --skip_creating_polygon ``` Load COG's (.tif) files: ```bash weather-mv bq --uris "gs://your-bucket/*.tif" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ # Needed for batch writes to BigQuery --direct_num_workers 2 \ --tif_metadata_for_start_time start_time \ --tif_metadata_for_end_time end_time ``` Upload only a subset of variables: ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --variables u10 v10 t --temp_location "gs://$BUCKET/tmp" \ --direct_num_workers 2 ``` Upload all variables, but for a specific geographic region (for example, the continental US): ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --area 49 -124 24 -66 \ --temp_location "gs://$BUCKET/tmp" \ --direct_num_workers 2 ``` Upload a zarr file: ```bash weather-mv bq --uris "gs://your-bucket/*.zarr" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ --use-local-code \ --zarr \ --direct_num_workers 2 ``` Upload a specific date range's data from the zarr file: ```bash weather-mv bq --uris "gs://your-bucket/*.zarr" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ --use-local-code \ --zarr \ --zarr_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ --direct_num_workers 2 ``` Upload a specific date range's data from the file: ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location "gs://$BUCKET/tmp" \ --use-local-code \ --xarray_open_dataset_kwargs '{"start_date": "2021-07-18", "end_date": "2021-07-19"}' \ ``` Control how weather data is opened with XArray: ```bash weather-mv bq --uris "gs://your-bucket/*.grib" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --xarray_open_dataset_kwargs '{"engine": "cfgrib", "indexpath": "", "backend_kwargs": {"filter_by_keys": {"typeOfLevel": "surface", "edition": 1}}}' \ --temp_location "gs://$BUCKET/tmp" \ --direct_num_workers 2 ``` Using DataflowRunner: ```bash weather-mv bq --uris 
"gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location "gs://$BUCKET/tmp" \ --job_name $JOB_NAME ``` Using DataflowRunner and using local code for pipeline ```bash weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location "gs://$BUCKET/tmp" \ --job_name $JOB_NAME \ --use-local-code ``` For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ### `weather-mv regrid` ``` usage: weather-mv regrid [-h] -i URIS [--topic TOPIC] [--window_size WINDOW_SIZE] [--num_shards NUM_SHARDS] [-d] --output_path OUTPUT_PATH [--regrid_kwargs REGRID_KWARGS] [--to_netcdf] ``` The `regrid` subcommand makes a regridded copy of the input data with MetView. To use this capability of the weather mover, please use the `[regrid]` extra when installing: ```shell pip install google-weather-tools[regrid] ``` > **Warning**: MetView requires a decent amount of disk space in order to perform any regrid operation! Intermediary > regridding steps will write temporary grib data to disk. Thus, please make use of the `--disk_size_gb` Dataflow > option. A good rule of thumb would be to consume `30 + 2.5x` GBs of disk, where `x` is the size of each source data > file. > > TODO([#191](https://github.com/google/weather-tools/issues/191)): Find smaller disk space bound. In addition to the common options above, users may specify command-specific options: _Command options_: * `-o, --output_path`: (required) The destination path for the regridded files. * `-k, --regrid_kwargs`: Keyword-args to pass into `metview.regrid()` in the form of a JSON string. Will default to '{"grid": [0.25, 0.25]}'. * `--to_netcdf`: Write output file in NetCDF via XArray. Default: off * `-bz2`, `--apply_bz2_compression`: Enable bzip2 (.bz2) compression for the regridded file. Default: off. For a full range of grid options, please consult [this documentation.](https://metview.readthedocs.io/en/latest/metview/using_metview/regrid_intro.html?highlight=grid#grid) Invoke with `rg -h` or `regrid --help` to see the full range of options. > Note: Currently, `regrid` doesn't work out-of-the-box! Until [#172](https://github.com/google/weather-tools/issues/172) > is fixed, users will have to use a workaround in order to ensure MetView is installed in their runner environment > (instructions are below). _Usage examples_: ```bash weather-mv regrid --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" ``` Using the subcomand alias 'rg': ```bash weather-mv rg --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" ``` Preview regrid with a dry run: ```bash weather-mv rg --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" \ --dry-run ``` Interpolate to a finer grid resolution: ```bash weather-mv rg --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" \ --regrid_kwargs '{"grid": [0.1, 0.1]}'. ``` Interpolate to a high-resolution octahedral gaussian grid: ```bash weather-mv rg --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" \ --regrid_kwargs '{"grid": "O1280}'. 
``` Convert gribs to NetCDF on copy: ```bash weather-mv rg --uris "gs://your-bucket/*.gb" \ --output_path "gs://regrid-bucket/" \ --to_netcdf ``` Using DataflowRunner: ```bash weather-mv rg --uris "gs://your-bucket/*.nc" \ --output_path "gs://regrid-bucket/" \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location "gs://$BUCKET/tmp" \ --experiment=use_runner_v2 \ --sdk_container_image="gcr.io/$PROJECT/$REPO:latest" \ --job_name $JOB_NAME ``` Using DataflowRunner, with added disk per VM: ```bash weather-mv rg --uris "gs://your-bucket/*.nc" \ --output_path "gs://regrid-bucket/" \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --disk_size_gb 250 \ --temp_location "gs://$BUCKET/tmp" \ --experiment=use_runner_v2 \ --sdk_container_image="gcr.io/$PROJECT/$REPO:latest" \ --job_name $JOB_NAME ``` For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ### `weather-mv earthengine` ``` usage: weather-mv earthengine [-h] -i URIS --asset_location ASSET_LOCATION --ee_asset EE_ASSET [--ee_asset_type ASSET_TYPE] [--disable_grib_schema_normalization] [--use_personal_account] [-s] [--xarray_open_dataset_kwargs XARRAY_OPEN_DATASET_KWARGS] [--service_account my-service-account@...gserviceaccount.com --private_key PRIVATE_KEY_LOCATION] [--ee_qps EE_QPS] [--ee_latency EE_LATENCY] [--ee_max_concurrent EE_MAX_CONCURRENT] ``` The `earthengine` subcommand ingests weather data into Earth Engine. It includes a caching function that allows it to skip ingestion for assets that have already been created in Earth Engine or for which the asset file already exists in the GCS bucket. In addition to the common options above, users may specify command-specific options: _Command options_: * `--asset_location`: (required) Bucket location to which asset files will be pushed. * `--ee_asset`: (required) The asset folder path in the Earth Engine project where the asset files will be pushed. It should be in the format `projects/<project-id>/assets/<asset-folder>`, e.g. projects/my-gcp-project/assets/my/foo/bar. Make sure that `<asset-folder>` exists under `<project-id>` in the Earth Engine assets. * `--ee_asset_type`: The type of asset to ingest in Earth Engine. Default: IMAGE. Supported types are `IMAGE` and `TABLE`.\ `IMAGE`: Uploads georeferenced raster datasets in GeoTIFF format.\ `TABLE`: Uploads the datasets in the CSV format. Useful in case of point data (sparse data). * `--disable_grib_schema_normalization`: Restricts merging of grib datasets. Default: False. * `-u, --use_personal_account`: Use a personal account for Earth Engine authentication. * `--service_account`: Service account address when using a private key for Earth Engine authentication. * `--private_key`: Use a private key for Earth Engine authentication. Only used with the `service_account` flag. * `--xarray_open_dataset_kwargs`: Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string. * `-s, --skip_region_validation`: Skip validation of regions for data migration. Default: off. * `-f, --force`: A flag that allows overwriting of existing asset files in the GCS bucket. Default: off, which means that the ingestion of URIs for which asset files (GeoTIFF/CSV) already exist in the GCS bucket will be skipped. * `--ee_qps`: Maximum queries per second allowed by EE for your project. Default: 10. * `--ee_latency`: The expected latency per request, in seconds. Default: 0.5. * `--ee_max_concurrent`: Maximum concurrent API requests to EE allowed for your project.
Default: 10. * `--band_names_mapping`: A JSON file which contains the band names for the TIFF file. * `--initialization_time_regex`: A regex string to get the initialization time from the filename. * `--forecast_time_regex`: A regex string to get the forecast/end time from the filename. * `--group_common_hypercubes`: A flag that allows splitting up large grib files into multiple level-wise ImageCollections / COGs. * `--use_deflate`: A flag that allows you to use the deflate algorithm for better compression. Using deflate compression takes extra time in COG creation. Default: False. * `--use_metrics`: A flag that allows you to add Beam metrics to the pipeline. Default: False. * `--use_monitoring_metrics`: A flag that allows you to add Google Cloud Monitoring metrics to the pipeline. Default: False. Invoke with `ee -h` or `earthengine --help` to see the full range of options. _Usage examples_: ```bash weather-mv earthengine --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" ``` Using the subcommand alias `ee`: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" ``` Preview ingestion with a dry run: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --dry-run ``` Authenticate Earth Engine using a personal account: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --use_personal_account ``` Authenticate Earth Engine using a private key: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --service_account "my-service-account@...gserviceaccount.com" \ --private_key "path/to/private_key.json" ``` Ingest an asset as a table in Earth Engine: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --ee_asset_type "TABLE" ``` Restrict merging all bands or grib normalization: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --disable_grib_schema_normalization ``` Control how weather data is opened with XArray: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --xarray_open_dataset_kwargs '{"engine": "cfgrib", "indexpath": "", "backend_kwargs": {"filter_by_keys": {"typeOfLevel": "surface", "edition": 1}}}' \ --temp_location "gs://$BUCKET/tmp" ``` Limit EE requests: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" \ --ee_qps 10 \ --ee_latency 0.5 \ --ee_max_concurrent 10 ``` Custom band names: ```bash weather-mv ee --uris "gs://your-bucket/*.tif" \ --asset_location
"gs://$BUCKET/assets" \ # Needed to store assets generated from *.tif --ee_asset "projects/$PROJECT/assets/test_dir" \ --band_names_mapping "filename.json" ``` Getting initialization and forecast/end date-time from the filename: ```bash weather-mv ee --uris "gs://your-bucket/*.tif" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.tif --ee_asset "projects/$PROJECT/assets/test_dir" \ --initialization_time_regex "$REGEX" \ --forecast_time_regex "$REGEX" ``` Example: ```bash weather-mv ee --uris "gs://tmp-gs-bucket/3B-HHR-E_MS_MRG_3IMERG_20220901-S000000-E002959_0000_V06C_30min.tiff" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.tif --ee_asset "projects/$PROJECT/assets/test_dir" \ --initialization_time_regex "3B-HHR-E_MS_MRG_3IMERG_%Y%m%d-S%H%M%S-*tiff" \ --forecast_time_regex "3B-HHR-E_MS_MRG_3IMERG_%Y%m%d-S*-E%H%M%S*tiff" ``` Ingesting a file into Earth Engine in a levelwise manner: ```bash weather-mv ee --uris "gs://your-bucket/*.tif" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.tif --ee_asset "projects/$PROJECT/assets/test_dir" \ --group_common_hypercubes ``` Using DataflowRunner: ```bash weather-mv ee --uris "gs://your-bucket/*.grib" \ --asset_location "gs://$BUCKET/assets" \ # Needed to store assets generated from *.grib --ee_asset "projects/$PROJECT/assets/test_dir" --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location "gs://$BUCKET/tmp" \ --job_name $JOB_NAME ``` For a full list of how to configure the Dataflow pipeline, please review [this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ## Streaming ingestion `weather-mv` optionally provides the ability to react to [Pub/Sub events for objects added to GCS](https://cloud.google.com/storage/docs/pubsub-notifications). This can be used to automate ingestion into BigQuery as soon as weather data is disseminated. Another common use case it to automatically create a down-sampled version of a dataset with `regrid`. To set up the Weather Mover with streaming ingestion, use the `--topic` or `--subscription` flag (see "Common options" above). Objects that don't match the `--uris` glob pattern will be filtered out of ingestion. This way, a bucket can contain multiple types of data yet only have subsets processed with `weather-mv`. > It's worth noting: when setting up PubSub, **make sure to create a topic for GCS `OBJECT_FINALIZE` events only.** _Usage examples_: ```shell weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --topic "projects/$PROJECT/topics/$TOPIC_ID" \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME ``` Window incoming data every five minutes instead of every minute. ```shell weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --topic "projects/$PROJECT/topics/$TOPIC_ID" \ --window_size 5 \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME ``` Increase the number of shards per window. ```shell weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --topic "projects/$PROJECT/topics/$TOPIC_ID" \ --num_shards 10 \ --runner DataflowRunner \ --project $PROJECT \ --temp_location gs://$BUCKET/tmp \ --job_name $JOB_NAME ``` ### BigQuery Data is written into BigQuery using streaming inserts. 
It may take [up to 90 minutes](https://cloud.google.com/bigquery/streaming-data-into-bigquery#dataavailability) for buffers to persist into storage. However, weather data will be available for querying immediately. > Note: It's recommended that you specify variables to ingest (`-v, --variables`) instead of inferring the schema for > streaming pipelines. Not all variables will be distributed with every file, especially when they are in Grib format. ## Private Network Configuration While running a `weather-mv` pipeline in GCP, you may receive the following error: "Quotas were exceeded: IN_USE_ADDRESSES". This error occurs when GCP tries to add new worker instances and finds that the "Public IP" quota assigned to your project is exhausted. To solve this, we recommend using private IPs while running your Dataflow pipelines. ```shell weather-mv bq --uris "gs://your-bucket/*.nc" \ --output_table $PROJECT.$DATASET_ID.$TABLE_ID \ --temp_location gs://$BUCKET/tmp \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --no_use_public_ips \ --network=$NETWORK \ --subnetwork=regions/$REGION/subnetworks/$SUBNETWORK ``` _Common options_: * `--no_use_public_ips`: To make Dataflow workers use private IP addresses for all communication, specify this command-line flag. Make sure that the specified network or subnetwork has Private Google Access enabled. * `--network`: The Compute Engine network for launching Compute Engine instances to run your pipeline. * `--subnetwork`: The Compute Engine subnetwork for launching Compute Engine instances to run your pipeline. For more information regarding how to configure private IPs, please refer to the [Private IP Configuration Guide for Dataflow Pipeline Execution](../Private-IP-Configuration.md). For more information regarding pipeline options, please refer to [pipeline-options](https://cloud.google.com/dataflow/docs/reference/pipeline-options). ## Custom Dataflow Container for ECMWF dependencies (like MetView) It's difficult to install all necessary system dependencies on a Dataflow worker with a pure Python solution. For example, MetView requires binaries to be installed on the system machine, which are broken in the standard Debian install channels (they are only maintained via `conda-forge`). Thus, to include such dependencies, we've provided steps for you to build a [Beam container environment](https://beam.apache.org/documentation/runtime/environments/). In the near future, we'll arrange things so you don't have to worry about any of these extra steps ([#172](https://github.com/google/weather-tools/issues/172)). See [these instructions](../Runtime-Container.md) to learn how to build a custom image for this project. Currently, this image is necessary for the `weather-mv regrid` command, but not for other commands. To deploy this tool, please do the following: 1. Host a container image of the included Dockerfile in your repository of choice (instructions for building images in GCS are in the next section). 2. Add the following two flags to your regrid pipeline.
``` --experiment=use_runner_v2 \ --sdk_container_image=$CONTAINER_URL ``` For example, the full Dataflow command, assuming you follow the next section's instructions, should look like: ```bash weather-mv rg --uris "gs://your-bucket/*.nc" \ --output_path "gs://regrid-bucket/" \ --runner DataflowRunner \ --project $PROJECT \ --region $REGION \ --temp_location "gs://$BUCKET/tmp" \ --experiment=use_runner_v2 \ --sdk_container_image="gcr.io/$PROJECT/$REPO:latest" --job_name $JOB_NAME ``` ================================================ FILE: weather_mv/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_mv/loader_pipeline/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import sys from .pipeline import run, pipeline def cli(extra=[]): logging.getLogger().setLevel(logging.INFO) pipeline(*run(sys.argv + extra)) ================================================ FILE: weather_mv/loader_pipeline/bq.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
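# This module implements the BigQuery sink for `weather-mv bigquery`: the `ToBigQuery`
# transform derives a table schema from the input dataset (or from user-supplied
# variables), optionally generates a geo data parquet of latitude / longitude /
# geo_point / geo_polygon values, converts xarray data into rows, and appends those
# rows to the destination table with `WriteToBigQuery`.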
import argparse import dataclasses import datetime import itertools import json import logging import math import os import pandas as pd import tempfile import typing as t from pprint import pformat import apache_beam as beam import geojson import numpy as np import xarray as xr import xarray_beam as xbeam from apache_beam.io import WriteToBigQuery, BigQueryDisposition from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.transforms import window from google.cloud import bigquery from xarray.core.utils import ensure_us_time_resolution from .sinks import ToDataSink, open_dataset, copy, open_local from .util import ( to_json_serializable_type, validate_region, _only_target_vars, get_coordinates, BQ_EXCLUDE_COORDS, ) logger = logging.getLogger(__name__) DEFAULT_IMPORT_TIME = datetime.datetime.utcfromtimestamp(0).replace(tzinfo=datetime.timezone.utc).isoformat() DATA_IMPORT_TIME_COLUMN = 'data_import_time' DATA_URI_COLUMN = 'data_uri' DATA_FIRST_STEP = 'data_first_step' GEO_POINT_COLUMN = 'geo_point' GEO_POLYGON_COLUMN = 'geo_polygon' LATITUDE_RANGE = (-90, 90) LONGITUDE_RANGE = (-180, 180) @dataclasses.dataclass class ToBigQuery(ToDataSink): """Load weather data into Google BigQuery. A sink that loads de-normalized weather data into BigQuery. First, this sink will create a BigQuery table from user input (either from `variables` or by inferring the schema). Next, it will convert the weather data into rows and then write each row to the BigQuery table. During a batch job, this transform will use the BigQueryWriter's file processing step, which requires that a `temp_location` is passed into the main CLI. This transform will perform streaming writes to BigQuery during a streaming Beam job. See `these docs`_ for more. Attributes: output_table: The destination for where data should be written in BigQuery geo_data_parquet_path: A path to dump the geo data parquet. This parquet consists of columns: latitude, longitude, geo_point, and geo_polygon. We calculate all of this information upfront so that we do not need to process it every time we process a set of files. variables: Target variables (or coordinates) for the BigQuery schema. By default, all data variables will be imported as columns. area: Target area in [N, W, S, E]; by default, all available area is included. import_time: The time when data was imported. This is used as a simple way to version data — variables can be distinguished based on import time. If None, the system will recompute the current time upon row extraction for each file. infer_schema: If true, this sink will attempt to read in an example data file read all its variables, and generate a BigQuery schema. xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset(). tif_metadata_for_start_time: If the input is a .tif file, parse the tif metadata at this location for a start time / initialization time. tif_metadata_for_end_time: If the input is a .tif file, parse the tif metadata at this location for a end/forecast time. skip_region_validation: Turn off validation that checks if all Cloud resources are in the same region. skip_creating_geo_data_parquet: Skip the generation of the geo data parquet if it already exists at the given --geo_data_parquet_path. Please note that the geo data parquet is mandatory for ingesting data into BigQuery. disable_grib_schema_normalization: Turn off grib's schema normalization; Default: normalization enabled. rows_chunk_size: The size of the chunk of rows to be loaded into memory for processing. 
Depending on your system's memory, use this to tune how much rows to process. .. _these docs: https://beam.apache.org/documentation/io/built-in/google-bigquery/#setting-the-insertion-method """ output_table: str geo_data_parquet_path: str variables: t.List[str] area: t.List[float] import_time: t.Optional[datetime.datetime] infer_schema: bool xarray_open_dataset_kwargs: t.Dict tif_metadata_for_start_time: t.Optional[str] tif_metadata_for_end_time: t.Optional[str] skip_region_validation: bool disable_grib_schema_normalization: bool rows_chunk_size: int = 1_000_000 skip_creating_polygon: bool = False skip_creating_geo_data_parquet: bool = False lat_grid_resolution: t.Optional[float] = None lon_grid_resolution: t.Optional[float] = None @classmethod def add_parser_arguments(cls, subparser: argparse.ArgumentParser): subparser.add_argument('-o', '--output_table', type=str, required=True, help="Full name of destination BigQuery table (..). Table " "will be created if it doesn't exist.") subparser.add_argument('--geo_data_parquet_path', type=str, required=True, help="A path to dump the geo data parquet.") subparser.add_argument('--skip_creating_geo_data_parquet', action='store_true', default=False, help="Skip the generation of geo data parquet if it already exists at given " "--geo_data_parquet_path. Please note that the geo data parquet is manditory for " " ingesting data into BigQuery. Default: off.") subparser.add_argument('-v', '--variables', metavar='variables', type=str, nargs='+', default=list(), help='Target variables (or coordinates) for the BigQuery schema. Default: will import ' 'all data variables as columns.') subparser.add_argument('-a', '--area', metavar='area', type=float, nargs='+', default=list(), help='Target area in [N, W, S, E]. Default: Will include all available area.') subparser.add_argument('--skip_creating_polygon', action='store_true', help='Not ingest grid points as polygons in BigQuery. Default: Ingest grid points as ' 'Polygon in BigQuery. Note: This feature relies on the assumption that the ' 'provided grid has an equal distance between consecutive points of latitude and ' 'longitude.') subparser.add_argument('--import_time', type=str, default=datetime.datetime.utcnow().isoformat(), help=("When writing data to BigQuery, record that data import occurred at this " "time (format: YYYY-MM-DD HH:MM:SS.usec+offset). Default: now in UTC.")) subparser.add_argument('--infer_schema', action='store_true', default=False, help='Download one file in the URI pattern and infer a schema from that file. Default: ' 'off') subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}', help='Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.') subparser.add_argument('--tif_metadata_for_start_time', type=str, default=None, help='Metadata that contains tif file\'s start/initialization time. ' 'Applicable only for tif files.') subparser.add_argument('--tif_metadata_for_end_time', type=str, default=None, help='Metadata that contains tif file\'s end/forecast time. ' 'Applicable only for tif files.') subparser.add_argument('-s', '--skip_region_validation', action='store_true', default=False, help='Skip validation of regions for data migration. Default: off') subparser.add_argument('--rows_chunk_size', type=int, default=1_000_000, help="The size of the chunk of rows to be loaded into memory for processing. 
" "Depending on your system's memory, use this to tune how much rows to process.") subparser.add_argument('--disable_grib_schema_normalization', action='store_true', default=False, help="To disable grib's schema normalization. Default: off") @classmethod def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None: pipeline_options = PipelineOptions(pipeline_args) pipeline_options_dict = pipeline_options.get_all_options() if known_args.area: assert len(known_args.area) == 4, 'Must specify exactly 4 lat/long values for area: N, W, S, E boundaries.' # Add a check for group_common_hypercubes. if pipeline_options_dict.get('group_common_hypercubes'): raise RuntimeError('--group_common_hypercubes can be specified only for earth engine ingestions.') # Check that all arguments are supplied for COG input. _, uri_extension = os.path.splitext(known_args.uris) if (uri_extension in ['.tif', '.tiff'] and not known_args.tif_metadata_for_start_time): raise RuntimeError("'--tif_metadata_for_start_time' is required for tif files.") elif uri_extension not in ['.tif', '.tiff'] and ( known_args.tif_metadata_for_start_time or known_args.tif_metadata_for_end_time ): raise RuntimeError("'--tif_metadata_for_start_time' and " "'--tif_metadata_for_end_time' can be specified only for tif files.") if not known_args.geo_data_parquet_path.endswith(".parquet"): raise RuntimeError(f"'--geo_data_parquet_path' {known_args.geo_data_parquet_path} must " "end with '.parquet'.") # Check that Cloud resource regions are consistent. if not (known_args.dry_run or known_args.skip_region_validation): # Program execution will terminate on failure of region validation. logger.info('Validating regions for data migration. This might take a few seconds...') validate_region(known_args.output_table, temp_location=pipeline_options_dict.get('temp_location'), region=pipeline_options_dict.get('region')) logger.info('Region validation completed successfully.') def generate_parquet( self, parquet_path: str, lats: t.List, lons: t.List, lat_grid_resolution: float, lon_grid_resolution: float, skip_creating_polygon: bool = False, ): """Generates geo data parquet.""" logger.info("Generating geo data parquet ...") # Generate Cartesian product of latitudes and longitudes. lat_lon_pairs = itertools.product(lats, lons) # Create a temp parquet file for writing. with tempfile.NamedTemporaryFile(suffix='.parquet', mode='w+', newline='') as temp: # Define header. header = ['latitude', 'longitude', GEO_POINT_COLUMN, GEO_POLYGON_COLUMN] data = [] for lat, lon in lat_lon_pairs: lat = float(lat) lon = float(lon) row = [lat, lon] sanitized_lon = (((lon % 360) + 540) % 360) - 180 # Fetch the geo point. geo_point = fetch_geo_point(lat, sanitized_lon) row.append(geo_point) # Fetch the geo polygon if not skipped. if not skip_creating_polygon: geo_polygon = fetch_geo_polygon(lat, sanitized_lon, lat_grid_resolution, lon_grid_resolution) row.append(geo_polygon) else: row.append(None) data.append(row) df = pd.DataFrame(data, columns=header) # Write DataFrame to parquet. df.to_parquet(temp.name, index=False) logger.info(f"geo data parquet generated successfully. 
Uploading to {parquet_path}.") copy(temp.name, parquet_path) logger.info(f"geo data parquet uploaded successfully at {parquet_path}.") def __post_init__(self): """Initializes Sink by creating a BigQuery table based on user input.""" if self.zarr: self.xarray_open_dataset_kwargs = self.zarr_kwargs with open_dataset(self.first_uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as open_ds: if not self.skip_creating_polygon: logger.warning("Assumes that equal distance between consecutive points of latitude " "and longitude for the entire grid.") # Find the grid_resolution. if open_ds['latitude'].size > 1 and open_ds['longitude'].size > 1: latitude_length = len(open_ds['latitude']) longitude_length = len(open_ds['longitude']) latitude_range = np.ptp(open_ds["latitude"].values) longitude_range = np.ptp(open_ds["longitude"].values) self.lat_grid_resolution = abs(latitude_range / latitude_length) / 2 self.lon_grid_resolution = abs(longitude_range / longitude_length) / 2 else: self.skip_creating_polygon = True logger.warning("Polygon can't be genereated as provided dataset has a only single grid point.") else: logger.info("Polygon is not created as '--skip_creating_polygon' flag passed.") if not self.skip_creating_geo_data_parquet: if self.area: n, w, s, e = self.area open_ds = open_ds.sel(latitude=slice(n, s), longitude=slice(w, e)) lats = open_ds["latitude"].values.tolist() lons = open_ds["longitude"].values.tolist() self.generate_parquet( self.geo_data_parquet_path, [lats] if isinstance(lats, float) else lats, [lons] if isinstance(lons, float) else lons, self.lat_grid_resolution, self.lon_grid_resolution, self.skip_creating_polygon, ) else: logger.info("geo data parquet is not created as '--skip_creating_geo_data_parquet' flag passed.") # Define table from user input if self.variables and not self.infer_schema and not open_ds.attrs['is_normalized']: logger.info('Creating schema from input variables.') table_schema = to_table_schema( [('latitude', 'FLOAT64'), ('longitude', 'FLOAT64'), ('time', 'TIMESTAMP')] + [(var, 'FLOAT64') for var in self.variables] ) else: logger.info('Inferring schema from data.') ds: xr.Dataset = _only_target_vars(open_ds, self.variables) table_schema = dataset_to_table_schema(ds) if self.dry_run: logger.debug('Created the BigQuery table with schema...') logger.debug(f'\n{pformat(table_schema)}') return # Create the table in BigQuery try: table = bigquery.Table(self.output_table, schema=table_schema) self.table = bigquery.Client().create_table(table, exists_ok=True) except Exception as e: logger.error(f'Unable to create table in BigQuery: {e}') raise def prepare_coordinates(self, uri: str) -> t.Iterator[t.Tuple[str, t.Dict]]: """Open the dataset, filter by area, and prepare chunks of coordinates for parallel ingestion into BigQuery.""" logger.info(f'Preparing coordinates for: {uri!r}.') with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) for coordinate in get_coordinates(data_ds, uri): yield uri, coordinate def extract_rows(self, uri: str, coordinate: t.Dict) -> t.Iterator[t.Dict]: """Reads an asset and coordinates, then yields its rows as a mapping of column names to values.""" logger.info(f'Extracting rows for {coordinate!r} of {uri!r}.') # Re-calculate import time for 
streaming extractions. if not self.import_time: self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, self.tif_metadata_for_start_time, self.tif_metadata_for_end_time, is_zarr=self.zarr) as ds: data_ds: xr.Dataset = _only_target_vars(ds, self.variables) if self.area: n, w, s, e = self.area data_ds = data_ds.sel(latitude=slice(n, s), longitude=slice(w, e)) logger.info(f'Data filtered by area, size: {data_ds.nbytes}') yield from self.to_rows(coordinate, data_ds, uri) def to_rows(self, coordinate: t.Dict, ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]: first_ts_raw = ( ds.time[0].values if isinstance(ds.time.values, np.ndarray) else ds.time.values ) first_time_step = to_json_serializable_type(first_ts_raw) with open_local(self.geo_data_parquet_path) as master_lat_lon: selected_ds = ds.loc[coordinate] # Ensure that the latitude and longitude dimensions are in sync with the geo data parquet. if not BQ_EXCLUDE_COORDS - set(selected_ds.dims.keys()): selected_ds = selected_ds.transpose('latitude', 'longitude') vector_df = pd.read_parquet(master_lat_lon) if self.skip_creating_polygon: vector_df[GEO_POLYGON_COLUMN] = None # Add indexed coordinates. for k, v in coordinate.items(): vector_df[k] = to_json_serializable_type(v) # Add un-indexed coordinates. # Filter out excluded coordinates from coords. filtered_coords = (c for c in selected_ds.coords if c not in BQ_EXCLUDE_COORDS) for c in filtered_coords: if c not in coordinate and (not self.variables or c in self.variables): vector_df[c] = to_json_serializable_type(ensure_us_time_resolution(selected_ds[c].values)) # We are not directly assigning values to dataframe because we need to consider 'None'. # Vectorized operations are generally more faster and efficient than iterating over rows. # Furthermore, pd.Series does not enforces length consistency so just added a safety check. for var in selected_ds.data_vars: values = to_json_serializable_type(ensure_us_time_resolution(selected_ds[var].values.ravel())) if len(values) != len(vector_df): raise ValueError( f"Length of values {len(values)} does not match number of rows in DataFrame {len(vector_df)}." ) vector_df[var] = pd.Series(values, dtype=object) vector_df[DATA_IMPORT_TIME_COLUMN] = self.import_time vector_df[DATA_URI_COLUMN] = uri vector_df[DATA_FIRST_STEP] = first_time_step num_chunks = math.ceil(len(vector_df) / self.rows_chunk_size) logger.info(f"{uri!r} -- {coordinate!r}'s vector_df divided into {num_chunks} chunk(s).") for i in range(num_chunks): chunk = vector_df[i * self.rows_chunk_size:(i + 1) * self.rows_chunk_size] rows = chunk.to_dict('records') logger.info(f"{uri!r} -- {coordinate!r}'s rows for {i} chunk converted to dict.") yield from rows def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]: uri = ds.attrs.get(DATA_URI_COLUMN, '') # Re-calculate import time for streaming extractions. 
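        # (When reading from Zarr, the URI is recovered from the dataset attributes set in `expand`, and the import time is re-computed for each chunk of data.)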
if not self.import_time or self.zarr: self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) for coordinate in get_coordinates(ds, uri): yield from self.to_rows(coordinate, ds, uri) def expand(self, paths): """Extract rows of variables from data paths into a BigQuery table.""" if not self.zarr: extracted_rows = ( paths | 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates) | beam.Reshuffle() | 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows) ) else: xarray_open_dataset_kwargs = self.xarray_open_dataset_kwargs.copy() xarray_open_dataset_kwargs.pop('chunks') start_date = xarray_open_dataset_kwargs.pop('start_date', None) end_date = xarray_open_dataset_kwargs.pop('end_date', None) ds, chunks = xbeam.open_zarr(self.first_uri, **xarray_open_dataset_kwargs) if start_date is not None and end_date is not None: ds = ds.sel(time=slice(start_date, end_date)) ds.attrs[DATA_URI_COLUMN] = self.first_uri extracted_rows = ( paths | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks) | 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows) | 'Window' >> beam.WindowInto(window.FixedWindows(60)) | 'AddTimestamp' >> beam.Map(timestamp_row) ) if self.dry_run: return extracted_rows | 'Log Rows' >> beam.Map(logger.info) return ( extracted_rows | 'WriteToBigQuery' >> WriteToBigQuery( project=self.table.project, dataset=self.table.dataset_id, table=self.table.table_id, write_disposition=BigQueryDisposition.WRITE_APPEND, create_disposition=BigQueryDisposition.CREATE_NEVER) ) def map_dtype_to_sql_type(var_type: np.dtype) -> str: """Maps a np.dtype to a suitable BigQuery column type.""" if var_type in {np.dtype('float64'), np.dtype('float32'), np.dtype('timedelta64[ns]')}: return 'FLOAT64' elif var_type in {np.dtype('<M8[ns]')}: return 'TIMESTAMP' elif var_type in {np.dtype('int8'), np.dtype('int16'), np.dtype('int32'), np.dtype('int64')}: return 'INT64' raise ValueError(f"Unknown mapping from '{var_type}' to SQL type.") def dataset_to_table_schema(ds: xr.Dataset) -> t.List[bigquery.SchemaField]: """Returns a BigQuery table schema able to store the data in 'ds'.""" # Get the columns and data types for all variables in the dataframe columns = [ (str(col), map_dtype_to_sql_type(ds.variables[col].dtype)) for col in ds.variables.keys() if ds.variables[col].size != 0 ] return to_table_schema(columns) def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.SchemaField]: # Fields are all Nullable because data may have NANs. We treat these as null. fields = [ bigquery.SchemaField(column, var_type, mode='NULLABLE') for column, var_type in columns ] # Add extra columns for recording import metadata.
fields.append(bigquery.SchemaField(DATA_IMPORT_TIME_COLUMN, 'TIMESTAMP', mode='NULLABLE')) fields.append(bigquery.SchemaField(DATA_URI_COLUMN, 'STRING', mode='NULLABLE')) fields.append(bigquery.SchemaField(DATA_FIRST_STEP, 'TIMESTAMP', mode='NULLABLE')) fields.append(bigquery.SchemaField(GEO_POINT_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) fields.append(bigquery.SchemaField(GEO_POLYGON_COLUMN, 'GEOGRAPHY', mode='NULLABLE')) return fields def timestamp_row(it: t.Dict) -> window.TimestampedValue: """Associate an extracted row with the import_time timestamp.""" timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp() return window.TimestampedValue(it, timestamp) def fetch_geo_point(lat: float, long: float) -> str: """Calculates a geography point from an input latitude and longitude.""" if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]: raise ValueError(f"Invalid latitude value '{lat}'") if long > LONGITUDE_RANGE[1] or long < LONGITUDE_RANGE[0]: raise ValueError(f"Invalid longitude value '{long}'") point = geojson.dumps(geojson.Point((long, lat))) return point def fetch_geo_polygon(latitude: float, longitude: float, lat_grid_resolution: float, lon_grid_resolution: float) -> str: """Create a Polygon based on latitude, longitude and resolution. Example :: * - . - * | | . • . | | * - . - * In order to create the polygon, we require the `*` point as indicated in the above example. To determine the position of the `*` point, we find the `.` point. The `get_lat_lon_range` function gives the `.` point and `bound_point` gives the `*` point. """ lat_lon_bound = bound_point(latitude, longitude, lat_grid_resolution, lon_grid_resolution) polygon = geojson.dumps(geojson.Polygon([[ (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left (lat_lon_bound[1][0], lat_lon_bound[1][1]), # upper_left (lat_lon_bound[2][0], lat_lon_bound[2][1]), # upper_right (lat_lon_bound[3][0], lat_lon_bound[3][1]), # lower_right (lat_lon_bound[0][0], lat_lon_bound[0][1]), # lower_left ]])) return polygon def bound_point(latitude: float, longitude: float, lat_grid_resolution: float, lon_grid_resolution: float) -> t.List: """Calculate the bound point based on latitude, longitude and grid resolution. Example :: * - . - * | | . • . | | * - . - * This function gives the `*` point in the above example. """ lat_in_bound = latitude in [90.0, -90.0] lon_in_bound = longitude in [-180.0, 180.0] lat_range = get_lat_lon_range(latitude, "latitude", lat_in_bound, lat_grid_resolution, lon_grid_resolution) lon_range = get_lat_lon_range(longitude, "longitude", lon_in_bound, lat_grid_resolution, lon_grid_resolution) lower_left = [lon_range[1], lat_range[1]] upper_left = [lon_range[1], lat_range[0]] upper_right = [lon_range[0], lat_range[0]] lower_right = [lon_range[0], lat_range[1]] return [lower_left, upper_left, upper_right, lower_right] def get_lat_lon_range(value: float, lat_lon: str, is_point_out_of_bound: bool, lat_grid_resolution: float, lon_grid_resolution: float) -> t.List: """Calculate the latitude, longitude point range point latitude, longitude and grid resolution. Example :: * - . - * | | . • . | | * - . - * This function gives the `.` point in the above example. 
""" if lat_lon == 'latitude': if is_point_out_of_bound: return [-90 + lat_grid_resolution, 90 - lat_grid_resolution] else: return [value + lat_grid_resolution, value - lat_grid_resolution] else: if is_point_out_of_bound: return [-180 + lon_grid_resolution, 180 - lon_grid_resolution] else: return [value + lon_grid_resolution, value - lon_grid_resolution] ================================================ FILE: weather_mv/loader_pipeline/bq_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datetime import json import logging import os import tempfile import typing as t import unittest import geojson import numpy as np import pandas as pd import simplejson import xarray as xr from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that, is_not_empty from google.cloud.bigquery import SchemaField from .bq import ( DEFAULT_IMPORT_TIME, dataset_to_table_schema, fetch_geo_point, fetch_geo_polygon, ToBigQuery, ) from .sinks_test import TestDataBase, _handle_missing_grib_be from .util import _only_target_vars logger = logging.getLogger(__name__) class SchemaCreationTests(TestDataBase): def setUp(self) -> None: super().setUp() self.test_dataset = { "coords": {"a": {"dims": ("a",), "data": [pd.Timestamp(0)], "attrs": {}}}, "attrs": {"is_normalized": False}, "dims": "a", "data_vars": { "b": {"dims": ("a",), "data": [np.float32(1.0)]}, "c": {"dims": ("a",), "data": [np.float64(2.0)]}, "d": {"dims": ("a",), "data": [3.0]}, } } self.test_dataset__with_schema_normalization = { "coords": {"a": {"dims": ("a",), "data": [pd.Timestamp(0)], "attrs": {}}}, "attrs": {"is_normalized": True}, "dims": "a", "data_vars": { "e_0_00_instant_b": {"dims": ("a",), "data": [np.float32(1.0)]}, "e_0_00_instant_c": {"dims": ("a",), "data": [np.float64(2.0)]}, "e_0_00_instant_d": {"dims": ("a",), "data": [3.0]}, } } def test_schema_generation(self): ds = xr.Dataset.from_dict(self.test_dataset) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('b', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('c', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__with_schema_normalization(self): ds = xr.Dataset.from_dict(self.test_dataset__with_schema_normalization) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_b', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_c', 
'FLOAT64', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__with_target_columns(self): target_variables = ['c', 'd'] ds = _only_target_vars(xr.Dataset.from_dict(self.test_dataset), target_variables) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('c', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__with_target_columns__with_schema_normalization(self): target_variables = ['c', 'd'] ds = _only_target_vars(xr.Dataset.from_dict(self.test_dataset__with_schema_normalization), target_variables) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_c', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__no_targets_specified(self): target_variables = [] # intentionally empty ds = _only_target_vars(xr.Dataset.from_dict(self.test_dataset), target_variables) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('b', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('c', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__no_targets_specified__with_schema_normalization(self): target_variables = [] # intentionally empty ds = _only_target_vars(xr.Dataset.from_dict(self.test_dataset__with_schema_normalization), target_variables) schema = dataset_to_table_schema(ds) expected_schema = [ SchemaField('a', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_b', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('e_0_00_instant_c', 'FLOAT64', 'NULLABLE', None, (), None), 
SchemaField('e_0_00_instant_d', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) def test_schema_generation__missing_target(self): with self.assertRaisesRegex(AssertionError, 'Target variable must be in original dataset.'): target_variables = ['a', 'foobar', 'd'] _only_target_vars(xr.Dataset.from_dict(self.test_dataset), target_variables) def test_schema_generation__missing_target__with_schema_normalization(self): with self.assertRaisesRegex(AssertionError, 'Target variable must be in original dataset.'): target_variables = ['a', 'foobar', 'd'] _only_target_vars(xr.Dataset.from_dict(self.test_dataset__with_schema_normalization), target_variables) @_handle_missing_grib_be def test_schema_generation__non_index_coords(self): test_single_var = xr.open_dataset( f'{self.test_data_folder}/test_data_grib_single_timestep', engine='cfgrib' ) schema = dataset_to_table_schema(test_single_var) expected_schema = [ SchemaField('number', 'INT64', 'NULLABLE', None, (), None), SchemaField('time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('step', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('surface', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('latitude', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('longitude', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('valid_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('z', 'FLOAT64', 'NULLABLE', None, (), None), SchemaField('data_import_time', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('data_uri', 'STRING', 'NULLABLE', None, (), None), SchemaField('data_first_step', 'TIMESTAMP', 'NULLABLE', None, (), None), SchemaField('geo_point', 'GEOGRAPHY', 'NULLABLE', None, (), None), SchemaField('geo_polygon', 'GEOGRAPHY', 'NULLABLE', None, (), None) ] self.assertListEqual(schema, expected_schema) class ExtractRowsTestBase(TestDataBase): def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=None, import_time=DEFAULT_IMPORT_TIME, disable_grib_schema_normalization=False, tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, zarr: bool = False, zarr_kwargs=None, skip_creating_polygon: bool = False, geo_data_parquet_path, skip_creating_geo_data_parquet : bool = False) -> t.Iterator[t.Dict]: if zarr_kwargs is None: zarr_kwargs = {} op = ToBigQuery.from_kwargs( first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs, output_table='foo.bar.baz', variables=variables, area=area, xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False, tif_metadata_for_start_time=tif_metadata_for_start_time, tif_metadata_for_end_time=tif_metadata_for_end_time, skip_region_validation=True, disable_grib_schema_normalization=disable_grib_schema_normalization, rows_chunk_size=1_000_000, skip_creating_polygon=skip_creating_polygon, geo_data_parquet_path=geo_data_parquet_path, skip_creating_geo_data_parquet=skip_creating_geo_data_parquet) coords = op.prepare_coordinates(data_path) for uri, chunk in coords: yield from op.extract_rows(uri, chunk) def assertGeopointEqual(self, actual: str, expected: str) -> None: expected_json, actual_json = geojson.loads(expected), 
geojson.loads(actual) self.assertEqual(actual_json['type'], expected_json['type']) self.assertTrue(np.allclose(actual_json['coordinates'], expected_json['coordinates'])) def assertRowsEqual(self, actual: t.Dict, expected: t.Dict): self.assertEqual(expected.keys(), actual.keys()) for key in expected.keys(): if isinstance(expected[key], str): # Handle Geopoint JSON strings... try: self.assertGeopointEqual(actual[key], expected[key]) except (simplejson.JSONDecodeError, json.JSONDecodeError, KeyError): self.assertEqual(actual[key], expected[key]) continue self.assertAlmostEqual(actual[key], expected[key], places=4) self.assertNotIsInstance(actual[key], np.dtype) self.assertNotIsInstance(actual[key], np.float64) self.assertNotIsInstance(actual[key], np.float32) class ExtractRowsTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() self.test_data_path = f'{self.test_data_folder}/test_data_20180101.nc' self.geo_data_parquet_path = f'{self.test_data_folder}/test_data_20180101_geo_data.parquet' def test_01_extract_rows(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_polygon=False, skip_creating_geo_data_parquet=False ) ) expected = { 'd2m': 242.3035430908203, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 49.0, 'longitude': -108.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-108.098837, 48.900826), (-108.098837, 49.099174), (-107.901163, 49.099174), (-107.901163, 48.900826), (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) def test_02_extract_rows__with_subset_variables(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, skip_creating_polygon=True, variables=['u10'] ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 49.0, 'longitude': -108.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), 'geo_polygon': None } self.assertRowsEqual(actual, expected) def test_03_extract_rows__specific_area(self): actual = next( self.extract( self.test_data_path, area=[45, -103, 33, -92], geo_data_parquet_path='./geo_data.parquet', skip_creating_polygon=True ) ) expected = { 'd2m': 246.19993591308594, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 45.0, 'longitude': -103.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': 2.73445987701416, 'v10': 0.08277571201324463, 'geo_point': geojson.dumps(geojson.Point((-103.0, 45.0))), 'geo_polygon': None } self.assertRowsEqual(actual, expected) def test_04_extract_rows__specific_area_float_points(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path='./geo_data.parquet', area=[45.34, -103.45, 33.34, -92.87] ) ) expected = { 'd2m': 246.47116088867188, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 45.20000076293945, 'longitude': -103.4000015258789, 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.94743275642395, 'v10': -0.19749987125396729, 'geo_point': 
geojson.dumps(geojson.Point((-103.400002, 45.200001))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-103.498839, 45.100827), (-103.498839, 45.299174), (-103.301164, 45.299174), (-103.301164, 45.100827), (-103.498839, 45.100827)])) } self.assertRowsEqual(actual, expected) def test_05_extract_rows_raises_error_when_geo_data_parquet_dimensions_mismatch(self): with self.assertRaisesRegex(ValueError, 'Length of values '): next( self.extract( self.test_data_path, area=[45, -103, 33, -92], geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, skip_creating_polygon=True ) ) def test_06_extract_rows__specify_import_time(self): now = datetime.datetime.utcnow().isoformat() actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, import_time=now ) ) expected = { 'd2m': 242.3035430908203, 'data_import_time': now, 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 49.0, 'longitude': -108.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-108.098837, 48.900826), (-108.098837, 49.099174), (-107.901163, 49.099174), (-107.901163, 48.900826), (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) def test_07_extract_rows_single_point(self): self.test_data_path = f'{self.test_data_folder}/test_data_single_point.nc' self.geo_data_parquet_path = f'{self.test_data_folder}/test_data_single_point_geo_data.parquet' actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=False, ) ) expected = { 'd2m': 242.3035430908203, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 49.0, 'longitude': -108.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': 3.4776244163513184, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), 'geo_polygon': None } self.assertRowsEqual(actual, expected) def test_08_extract_rows_nan(self): self.test_data_path = f'{self.test_data_folder}/test_data_has_nan.nc' self.geo_data_parquet_path = f'{self.test_data_folder}/test_data_has_nan_geo_data.parquet' actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=False, ) ) expected = { 'd2m': 242.3035430908203, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2018-01-02T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 49.0, 'longitude': -108.0, 'time': '2018-01-02T06:00:00+00:00', 'u10': None, 'v10': 0.03294110298156738, 'geo_point': geojson.dumps(geojson.Point((-108.0, 49.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-108.098837, 48.900826), (-108.098837, 49.099174), (-107.901163, 49.099174), (-107.901163, 48.900826), (-108.098837, 48.900826)])) } self.assertRowsEqual(actual, expected) def test_09_extract_rows__with_valid_lat_long_with_point(self): valid_lat_long = [[-90, 0], [-90, 1], [-45, -180], [-45, -45], [0, 0], [45, 45], [45, -180], [90, -1], [90, 0]] actual_val = [ '{"type": "Point", "coordinates": [0, -90]}', '{"type": "Point", "coordinates": [1, -90]}', '{"type": "Point", "coordinates": [-180, -45]}', '{"type": "Point", "coordinates": [-45, -45]}', '{"type": "Point", "coordinates": [0, 0]}', '{"type": "Point", 
"coordinates": [45, 45]}', '{"type": "Point", "coordinates": [-180, 45]}', '{"type": "Point", "coordinates": [-1, 90]}', '{"type": "Point", "coordinates": [0, 90]}' ] for actual, (lat, long) in zip(actual_val, valid_lat_long): with self.subTest(): expected = fetch_geo_point(lat, long) self.assertEqual(actual, expected) def test_10_extract_rows__with_valid_lat_long_with_polygon(self): valid_lat_long = [[-90, 0], [-90, -180], [-45, -180], [-45, 180], [0, 0], [90, 180], [45, -180], [-90, 180], [90, 1], [0, 180], [1, -180], [90, -180]] actual_val = [ '{"type": "Polygon", "coordinates": [[[-1, 89], [-1, -89], [1, -89], [1, 89], [-1, 89]]]}', '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', '{"type": "Polygon", "coordinates": [[[179, -46], [179, -44], [-179, -44], [-179, -46], [179, -46]]]}', '{"type": "Polygon", "coordinates": [[[-1, -1], [-1, 1], [1, 1], [1, -1], [-1, -1]]]}', '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', '{"type": "Polygon", "coordinates": [[[179, 44], [179, 46], [-179, 46], [-179, 44], [179, 44]]]}', '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}', '{"type": "Polygon", "coordinates": [[[0, 89], [0, -89], [2, -89], [2, 89], [0, 89]]]}', '{"type": "Polygon", "coordinates": [[[179, -1], [179, 1], [-179, 1], [-179, -1], [179, -1]]]}', '{"type": "Polygon", "coordinates": [[[179, 0], [179, 2], [-179, 2], [-179, 0], [179, 0]]]}', '{"type": "Polygon", "coordinates": [[[179, 89], [179, -89], [-179, -89], [-179, 89], [179, 89]]]}' ] lat_grid_resolution = 1 lon_grid_resolution = 1 for actual, (lat, long) in zip(actual_val, valid_lat_long): with self.subTest(): expected = fetch_geo_polygon(lat, long, lat_grid_resolution, lon_grid_resolution) self.assertEqual(actual, expected) def test_11_extract_rows__with_invalid_lat_lon(self): invalid_lat_long = [[-100, -2000], [-100, -500], [100, 500], [100, 2000]] for (lat, long) in invalid_lat_long: with self.subTest(): with self.assertRaises(ValueError): fetch_geo_point(lat, long) def test_12_extract_rows_zarr(self): input_path = os.path.join(self.test_data_folder, 'test_data.zarr') geo_data_parquet_path = os.path.join(self.test_data_folder, 'test_data_zarr_geo_data.parquet') actual = next( self.extract( input_path, geo_data_parquet_path=geo_data_parquet_path, skip_creating_geo_data_parquet=False, zarr=True ) ) expected = { 'cape': 0.623291015625, 'd2m': 237.5404052734375, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '1959-01-01T00:00:00+00:00', 'data_uri': input_path, 'latitude': 90, 'longitude': 0, 'time': '1959-01-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((0.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-0.124913, 89.875173), (-0.124913, -89.875173), (0.124913, -89.875173), (0.124913, 89.875173), (-0.124913, 89.875173)])) } self.assertRowsEqual(actual, expected) def test_13_droping_variable_while_opening_zarr(self): input_path = os.path.join(self.test_data_folder, 'test_data.zarr') geo_data_parquet_path = os.path.join(self.test_data_folder, 'test_data_zarr_geo_data.parquet') actual = next( self.extract( input_path, geo_data_parquet_path=geo_data_parquet_path, skip_creating_geo_data_parquet=True, zarr=True, zarr_kwargs={'drop_variables': ['cape']} ) ) expected = { 'd2m': 237.5404052734375, 'data_import_time': 
'1970-01-01T00:00:00+00:00', 'data_first_step': '1959-01-01T00:00:00+00:00', 'data_uri': input_path, 'latitude': 90, 'longitude': 0, 'time': '1959-01-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((0.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-0.124913, 89.875173), (-0.124913, -89.875173), (0.124913, -89.875173), (0.124913, 89.875173), (-0.124913, 89.875173)])) } self.assertRowsEqual(actual, expected) class ExtractRowsTifSupportTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() self.test_data_path = f'{self.test_data_folder}/test_data_tif_time.tif' self.geo_data_parquet_path = f'{self.test_data_folder}/test_data_tif_time_geo_data.parquet' def test_01_extract_rows_with_end_time(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=False, tif_metadata_for_start_time='start_time', tif_metadata_for_end_time='end_time' ) ) expected = { 'dewpoint_temperature_2m': 281.09349060058594, 'temperature_2m': 296.8329772949219, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2020-07-01T00:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 42.09783344918844, 'longitude': -123.66686981141397, 'time': '2020-07-01T00:00:00+00:00', 'valid_time': '2020-07-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-123.669853, 42.095605), (-123.669853, 42.100066), (-123.663885, 42.100066), (-123.663885, 42.095605), (-123.669853, 42.095605)])) } self.assertRowsEqual(actual, expected) def test_02_extract_rows_without_end_time(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, tif_metadata_for_start_time='start_time' ) ) expected = { 'dewpoint_temperature_2m': 281.09349060058594, 'temperature_2m': 296.8329772949219, 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2020-07-01T00:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 42.09783344918844, 'longitude': -123.66686981141397, 'time': '2020-07-01T00:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((-123.66687, 42.097833))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (-123.669853, 42.095605), (-123.669853, 42.100066), (-123.663885, 42.100066), (-123.663885, 42.095605), (-123.669853, 42.095605)])) } self.assertRowsEqual(actual, expected) class ExtractRowsGribSupportTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() self.test_data_path = f'{self.test_data_folder}/test_data_grib_single_timestep' self.geo_data_parquet_path = f'{self.test_data_folder}/test_data_grib_single_timestep_geo_data.parquet' @_handle_missing_grib_be def test_01_extract_rows(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=False, disable_grib_schema_normalization=True ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-10-18T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'number': 0, 'step': 0.0, 'surface': 0.0, 'time': '2021-10-18T06:00:00+00:00', 'valid_time': '2021-10-18T06:00:00+00:00', 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } 
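        # Note: at longitude -180.0 the surrounding polygon wraps across the antimeridian,
        # which is why its corners jump from +179.95 to -179.95 degrees.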
self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_02_extract_rows__with_vars__excludes_non_index_coords__without_schema_normalization(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, disable_grib_schema_normalization=True, variables=['z'])) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-10-18T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_03_extract_rows__with_vars__includes_coordinates_in_vars__without_schema_normalization(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, disable_grib_schema_normalization=True, variables=['z', 'step'] ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-10-18T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'step': 0, 'z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_04_extract_rows__with_vars__excludes_non_index_coords__with_schema_normalization(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, variables=['z'] ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-10-18T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'surface_0_00_instant_z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_05_extract_rows__with_vars__includes_coordinates_in_vars__with_schema_normalization(self): actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, variables=['z', 'step'] ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-10-18T06:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'step': 0, 'surface_0_00_instant_z': 1.42578125, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_06_multiple_editions__without_schema_normalization(self): self.test_data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2' self.geo_data_parquet_path = ( f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep_geo_data.parquet' ) 
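        # With schema normalization disabled, variables from both GRIB editions keep their raw
        # names (e.g. 'd2m', 'stl1') rather than the hypercube-prefixed names checked in test_07 below.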
actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=False, disable_grib_schema_normalization=True ) ) expected = { 'cape': 0.0, 'cbh': None, 'cp': 0.0, 'crr': 0.0, 'd2m': 248.3846893310547, 'data_first_step': '2021-12-10T12:00:00+00:00', 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_uri': self.test_data_path, 'depthBelowLandLayer': 0.0, 'dsrp': 0.0, 'fdir': 0.0, 'hcc': 0.0, 'hcct': None, 'hwbt0': 0.0, 'i10fg': 7.41250467300415, 'latitude': 90.0, 'longitude': -180.0, 'lsp': 1.1444091796875e-05, 'mcc': 0.0, 'msl': 99867.3125, 'number': 0, 'p3020': 20306.701171875, 'sd': 0.0, 'sf': 1.049041748046875e-05, 'sp': 99867.15625, 'step': 28800.0, 'stl1': 251.02520751953125, 'surface': 0.0, 'swvl1': -1.9539930654413618e-13, 't2m': 251.18968200683594, 'tcc': 0.9609375, 'tcrw': 0.0, 'tcw': 2.314192295074463, 'tcwv': 2.314192295074463, 'time': '2021-12-10T12:00:00+00:00', 'tp': 1.1444091796875e-05, 'tsr': 0.0, 'u10': -4.6668853759765625, 'u100': -7.6197662353515625, 'u200': -9.176498413085938, 'v10': -3.2414093017578125, 'v100': -4.1650390625, 'v200': -3.6647186279296875, 'valid_time': '2021-12-10T20:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_07_multiple_editions__with_schema_normalization(self): self.test_data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2' self.geo_data_parquet_path = ( f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep_geo_data.parquet' ) actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, ) ) expected = { 'surface_0_00_instant_cape': 0.0, 'surface_0_00_instant_cbh': None, 'surface_0_00_instant_cp': 0.0, 'surface_0_00_instant_crr': 0.0, 'surface_0_00_instant_d2m': 248.3846893310547, 'data_first_step': '2021-12-10T12:00:00+00:00', 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_uri': self.test_data_path, 'surface_0_00_instant_dsrp': 0.0, 'surface_0_00_instant_fdir': 0.0, 'surface_0_00_instant_hcc': 0.0, 'surface_0_00_instant_hcct': None, 'surface_0_00_instant_hwbt0': 0.0, 'surface_0_00_instant_i10fg': 7.41250467300415, 'latitude': 90.0, 'longitude': -180.0, 'surface_0_00_instant_lsp': 1.1444091796875e-05, 'surface_0_00_instant_mcc': 0.0, 'surface_0_00_instant_msl': 99867.3125, 'number': 0, 'surface_0_00_instant_p3020': 20306.701171875, 'surface_0_00_instant_sd': 0.0, 'surface_0_00_instant_sf': 1.049041748046875e-05, 'surface_0_00_instant_sp': 99867.15625, 'step': 28800.0, 'depthBelowLandLayer_0_00_instant_stl1': 251.02520751953125, 'depthBelowLandLayer_0_00_instant_swvl1': -1.9539930654413618e-13, 'depthBelowLandLayer_7_00_instant_stl2': 253.54124450683594, 'entireAtmosphere_0_00_instant_litoti': 0.0, 'surface_0_00_instant_t2m': 251.18968200683594, 'surface_0_00_instant_tcc': 0.9609375, 'surface_0_00_instant_tcrw': 0.0, 'surface_0_00_instant_tcw': 2.314192295074463, 'surface_0_00_instant_tcwv': 2.314192295074463, 'time': '2021-12-10T12:00:00+00:00', 'surface_0_00_instant_tp': 1.1444091796875e-05, 'surface_0_00_instant_tsr': 0.0, 'surface_0_00_instant_u10': -4.6668853759765625, 'surface_0_00_instant_u100': -7.6197662353515625, 'surface_0_00_instant_u200': 
-9.176498413085938, 'surface_0_00_instant_v10': -3.2414093017578125, 'surface_0_00_instant_v100': -4.1650390625, 'surface_0_00_instant_v200': -3.6647186279296875, 'surface_0_00_instant_ptype': 5.0, 'surface_0_00_instant_tprate': 0.0, 'surface_0_00_instant_ceil': 179.17018127441406, 'valid_time': '2021-12-10T20:00:00+00:00', 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) @_handle_missing_grib_be def test_08_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema_normalization(self): self.test_data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2' self.geo_data_parquet_path = ( f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep_geo_data.parquet' ) actual = next( self.extract( self.test_data_path, geo_data_parquet_path=self.geo_data_parquet_path, skip_creating_geo_data_parquet=True, variables=['p3020', 'depthBelowLandLayer', 'step'] ) ) expected = { 'data_import_time': '1970-01-01T00:00:00+00:00', 'data_first_step': '2021-12-10T12:00:00+00:00', 'data_uri': self.test_data_path, 'latitude': 90.0, 'longitude': -180.0, 'step': 28800.0, 'surface_0_00_instant_p3020': 20306.701171875, 'depthBelowLandLayer_0_00_instant_swvl1': -1.9539930654413618e-13, 'depthBelowLandLayer_0_00_instant_stl1': 251.02520751953125, 'depthBelowLandLayer_7_00_instant_stl2': 253.54124450683594, 'geo_point': geojson.dumps(geojson.Point((-180.0, 90.0))), 'geo_polygon': geojson.dumps(geojson.Polygon([ (179.950014, 89.950028), (179.950014, -89.950028), (-179.950014, -89.950028), (-179.950014, 89.950028), (179.950014, 89.950028)])) } self.assertRowsEqual(actual, expected) class ExtractRowsFromZarrTest(ExtractRowsTestBase): def setUp(self) -> None: super().setUp() self.tmpdir = tempfile.TemporaryDirectory() def tearDown(self) -> None: super().tearDown() self.tmpdir.cleanup() def test_01_extracts_rows(self): input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr') ds = ( xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder) .isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4)) .rename(dict(lon='longitude', lat='latitude')) ) ds.to_zarr(input_zarr) op = ToBigQuery.from_kwargs( first_uri=input_zarr, zarr_kwargs=dict(chunks=None, consolidated=True), dry_run=True, zarr=True, output_table='foo.bar.baz', variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False, tif_metadata_for_start_time=None, tif_metadata_for_end_time=None, skip_region_validation=True, disable_grib_schema_normalization=False, geo_data_parquet_path=os.path.join(self.tmpdir.name, 'geo_data.parquet'), ) with TestPipeline() as p: result = p | op assert_that(result, is_not_empty()) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/ee.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import csv
import dataclasses
import json
import logging
import math
import os
import re
import shutil
import subprocess
import tempfile
import time
import typing as t
from multiprocessing import Process, Queue

import apache_beam as beam
import ee
import numpy as np
from apache_beam.io.filesystems import FileSystems
from apache_beam.metrics import metric
from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import retry
from google.auth import compute_engine, default, credentials
from google.auth.transport import requests
from google.auth.transport.requests import AuthorizedSession
from rasterio.io import MemoryFile

from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, copy
from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp
from .metrics import timeit, AddTimer, AddMetrics

logger = logging.getLogger(__name__)

COMPUTE_ENGINE_STR = 'Metadata-Flavor: Google'
# For EE ingestion retry logic.
INITIAL_DELAY = 1.0  # Initial delay in seconds.
MAX_DELAY = 600  # Maximum delay before giving up in seconds.
NUM_RETRIES = 10  # Number of tries with exponential backoff.
TASK_QUEUE_WAIT_TIME = 120  # Task queue wait time in seconds.

ASSET_TYPE_TO_EXTENSION_MAPPING = {
    'IMAGE': '.tiff',
    'TABLE': '.csv'
}
ROWS_PER_WRITE = 10_000  # Number of rows per feature collection write.


def is_compute_engine() -> bool:
    """Determines if the application is running in a Compute Engine environment."""
    command = ['curl', 'metadata.google.internal', '-i']
    result = subprocess.run(command, stdout=subprocess.PIPE)
    result_output = result.stdout.decode('utf-8')
    return COMPUTE_ENGINE_STR in result_output


def get_creds(use_personal_account: bool, service_account: str, private_key: str) -> credentials.Credentials:
    """Fetches credentials for authentication.

    If the `use_personal_account` argument is true, it will authenticate with a pop-up
    browser window using the personal account. Otherwise, if the application is running
    in compute engine, it will use the credentials of the service account bound to the VM.
    Otherwise, it will try to use user credentials.

    Args:
        use_personal_account: A flag to use personal account for ee authentication.
        service_account: Service account address when using a private key for earth engine authentication.
        private_key: A private key path to authenticate earth engine using private key.

    Returns:
        cred: Credentials object.
    """
    # TODO(Issue #197): Test private key authentication.
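    # Illustrative usage sketch (not part of this module); the service account address and
    # key path below are hypothetical placeholders:
    #
    #   creds = get_creds(use_personal_account=False,
    #                     service_account='ee-uploader@my-project.iam.gserviceaccount.com',
    #                     private_key='gs://my-bucket/keys/ee-key.json')
    #   ee.Initialize(creds)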
if service_account and private_key: try: with open_local(private_key) as local_path: creds = ee.ServiceAccountCredentials(service_account, local_path) except Exception: raise RuntimeError(f'Unable to open the private key {private_key}.') elif use_personal_account: ee.Authenticate() creds, _ = default() elif is_compute_engine(): creds = compute_engine.Credentials() else: creds, _ = default() creds.refresh(requests.Request()) return creds def ee_initialize(use_personal_account: bool = False, enforce_high_volume: bool = False, service_account: t.Optional[str] = None, private_key: t.Optional[str] = None, project_id: t.Optional[str] = None) -> None: """Initializes earth engine with the high volume API when using a compute engine VM. Args: use_personal_account: A flag to use personal account for ee authentication. Default: False. enforce_high_volume: A flag to use the high volume API when using a compute engine VM. Default: False. service_account: Service account address when using a private key for earth engine authentication. private_key: A private key path to authenticate earth engine using private key. Default: None. Project ID: An identifier that represents the name of a project present in Earth Engine. Raises: RuntimeError: Earth Engine did not initialize. """ creds = get_creds(use_personal_account, service_account, private_key) on_compute_engine = is_compute_engine() # Using the high volume api. if on_compute_engine: if project_id is None and use_personal_account: raise RuntimeError('Project_name should not be None!') params = {'credentials': creds, 'opt_url': 'https://earthengine-highvolume.googleapis.com'} if project_id: params['project'] = project_id ee.Initialize(**params) # Only the compute engine service service account can access the high volume api. elif enforce_high_volume and not on_compute_engine: raise RuntimeError( 'Must run on a compute engine VM to use the high volume earth engine api.' ) else: ee.Initialize(creds) class SetupEarthEngine(RateLimit): """A base class to setup the earth engine.""" def __init__(self, ee_qps: int, ee_latency: float, ee_max_concurrent: int, private_key: str, service_account: str, use_personal_account: bool, use_metrics: bool): super().__init__(global_rate_limit_qps=ee_qps, latency_per_request=ee_latency, max_concurrent_requests=ee_max_concurrent, use_metrics=use_metrics) self._has_setup = False self.private_key = private_key self.service_account = service_account self.use_personal_account = use_personal_account self.use_metrics = use_metrics def setup(self, project_id): """Makes sure ee is set up on every worker.""" ee_initialize(use_personal_account=self.use_personal_account, service_account=self.service_account, private_key=self.private_key, project_id=project_id) self._has_setup = True def check_setup(self, project_id: t.Optional[str] = None): """Ensures that setup has been called.""" if not self._has_setup: try: # This throws an exception if ee is not initialized. ee.data.getAlgorithms() self._has_setup = True except ee.EEException: self.setup(project_id) def process(self, *args, **kwargs): """Checks that setup has been called then call the process implementation.""" self.check_setup() def get_ee_safe_name(uri: str) -> str: """Extracts file name and converts it into an EE-safe name""" basename = os.path.basename(uri) # Strip the extension from the basename. basename, _ = os.path.splitext(basename) # An asset ID can only contain letters, numbers, hyphens, and underscores. # Converting everything else to underscore. 
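    # e.g. (illustrative, hypothetical path): 'gs://bucket/era5@2020-01-01&tp.grib' -> 'era5_2020-01-01_tp'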
asset_name = re.sub(r'[^a-zA-Z0-9-_]+', r'_', basename) return asset_name @dataclasses.dataclass class AssetData: """A class for holding the asset data. Attributes: name: The EE-safe name of the asset. target_path: The location of the asset in GCS. channel_names: A list of channel names in the asset. start_time: Image start time in floating point seconds since epoch. end_time: Image end time in floating point seconds since epoch. properties: A dictionary of asset metadata. """ name: str target_path: str channel_names: t.List[str] start_time: float end_time: float properties: t.Dict[str, t.Union[str, float, int]] @dataclasses.dataclass class ToEarthEngine(ToDataSink): """Loads weather data into Google Earth Engine. A sink that loads dataset (either normalized or read using user-provided kwargs). This sink will read each channel data and merge them into a single dataset if the `disable_grib_schema_normalization` flag is not specified. It will read the dataset and create an asset. Next, it will write the asset to the specified bucket path and initiate the earth engine upload request. When using the default service account bound to the VM, it is required to register the service account with EE from `here`_. See `this doc`_ for more detail. Attributes: asset_location: The bucket location at which asset files will be pushed. ee_asset: The asset folder path in earth engine project where the asset files will be pushed. ee_asset_type: The type of asset to ingest in the earth engine. Default: IMAGE. xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset(). disable_grib_schema_normalization: A flag to turn grib schema normalization off; Default: on. skip_region_validation: Turn off validation that checks if all Cloud resources are in the same region. use_personal_account: A flag to authenticate earth engine using personal account. Default: False. .. _here: https://signup.earthengine.google.com/#!/service_accounts .. _this doc: https://developers.google.com/earth-engine/guides/service_account """ asset_location: str ee_asset: str ee_asset_type: str xarray_open_dataset_kwargs: t.Dict disable_grib_schema_normalization: bool skip_region_validation: bool use_personal_account: bool force: bool service_account: str private_key: str ee_qps: int ee_latency: float ee_max_concurrent: int group_common_hypercubes: bool band_names_mapping: str initialization_time_regex: str forecast_time_regex: str ingest_as_virtual_asset: bool use_deflate: bool use_metrics: bool use_monitoring_metrics: bool topic: str # Pipeline arguments. job_name: str project: str region: str @classmethod def add_parser_arguments(cls, subparser: argparse.ArgumentParser): subparser.add_argument('--asset_location', type=str, required=True, default=None, help='The GCS location where the asset files will be pushed.') subparser.add_argument('--ee_asset', type=str, required=True, default=None, help='The asset folder path in earth engine project where the asset files' ' will be pushed.') subparser.add_argument('--ee_asset_type', type=str, choices=['IMAGE', 'TABLE'], default='IMAGE', help='The type of asset to ingest in the earth engine.') subparser.add_argument('--xarray_open_dataset_kwargs', type=json.loads, default='{}', help='Keyword-args to pass into `xarray.open_dataset()` in the form of a JSON string.') subparser.add_argument('--disable_grib_schema_normalization', action='store_true', default=False, help='To disable merge of grib datasets. 
Default: False') subparser.add_argument('-s', '--skip-region-validation', action='store_true', default=False, help='Skip validation of regions for data migration. Default: off') subparser.add_argument('-u', '--use_personal_account', action='store_true', default=False, help='To use personal account for earth engine authentication.') subparser.add_argument('-f', '--force', action='store_true', default=False, help='A flag that allows overwriting of existing asset files in the GCS bucket.' ' Default: off, which means that the ingestion of URIs for which assets files' ' (GeoTiff/CSV) already exist in the GCS bucket will be skipped.') subparser.add_argument('--service_account', type=str, default=None, help='Service account address when using a private key for earth engine authentication.') subparser.add_argument('--private_key', type=str, default=None, help='To use a private key for earth engine authentication.') subparser.add_argument('--ee_qps', type=int, default=10, help='Maximum queries per second allowed by EE for your project. Default: 10') subparser.add_argument('--ee_latency', type=float, default=0.5, help='The expected latency per requests, in seconds. Default: 0.5') subparser.add_argument('--ee_max_concurrent', type=int, default=10, help='Maximum concurrent api requests to EE allowed for your project. Default: 10') subparser.add_argument('--group_common_hypercubes', action='store_true', default=False, help='To group common hypercubes into image collections when loading grib data.') subparser.add_argument('--band_names_mapping', type=str, default=None, help='A JSON file which contains the band names for the TIFF file.') subparser.add_argument('--initialization_time_regex', type=str, default=None, help='A Regex string to get the initialization time from the filename.') subparser.add_argument('--forecast_time_regex', type=str, default=None, help='A Regex string to get the forecast/end time from the filename.') subparser.add_argument('--ingest_as_virtual_asset', action='store_true', default=False, help='To ingest image as a virtual asset. Default: False') subparser.add_argument('--use_deflate', action='store_true', default=False, help='To use deflate compression algorithm. Default: False') subparser.add_argument('--use_metrics', action='store_true', default=False, help='If you want to add Beam metrics to your pipeline. Default: False') subparser.add_argument('--use_monitoring_metrics', action='store_true', default=False, help='If you want to add GCP Monitoring metrics to your pipeline. Default: False') @classmethod def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None: pipeline_options = PipelineOptions(pipeline_args) pipeline_options_dict = pipeline_options.get_all_options() if known_args.zarr: raise RuntimeError('Reading Zarr is not (yet) supported.') # Check that ee_asset is in correct format. if not re.match("^projects/.+/assets.*", known_args.ee_asset): raise RuntimeError("'--ee_asset' is required to be in format: projects/+/assets/*.") # Check that both service_account and private_key are provided, or none is. if bool(known_args.service_account) ^ bool(known_args.private_key): raise RuntimeError("'--service_account' and '--private_key' both are required.") # Check that either personal or service account is asked to use. 
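        # (Illustrative) e.g. passing both `--use_personal_account` and `--service_account ...`
        # is rejected below, because the two authentication modes are mutually exclusive.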
if known_args.use_personal_account and known_args.service_account: raise RuntimeError("Both personal and service account cannot be used at once.") if known_args.ee_qps and known_args.ee_qps < 1: raise RuntimeError("Queries per second should not be less than 1.") if known_args.ee_latency and known_args.ee_latency < 0.001: raise RuntimeError("Latency per request should not be less than 0.001.") if known_args.ee_max_concurrent and known_args.ee_max_concurrent < 1: raise RuntimeError("Maximum concurrent requests should not be less than 1.") # Check that when ingesting as a virtual asset, asset type is image. if known_args.ingest_as_virtual_asset and known_args.ee_asset_type != "IMAGE": raise RuntimeError("Only assets with IMAGE type can be ingested as a virtual asset.") # Check that Cloud resource regions are consistent. if not (known_args.dry_run or known_args.skip_region_validation): # Program execution will terminate on failure of region validation. logger.info('Validating regions for data migration. This might take a few seconds...') validate_region(temp_location=pipeline_options_dict.get('temp_location'), region=pipeline_options_dict.get('region')) logger.info('Region validation completed successfully.') # Check for the band_names_mapping json file. if known_args.band_names_mapping: if not os.path.exists(known_args.band_names_mapping): raise RuntimeError("--band_names_mapping file does not exist.") _, band_names_mapping_extension = os.path.splitext(known_args.band_names_mapping) if not band_names_mapping_extension == '.json': raise RuntimeError("--band_names_mapping should contain a json file as input.") # Check the initialization_time_regex and forecast_time_regex strings. if bool(known_args.initialization_time_regex) ^ bool(known_args.forecast_time_regex): raise RuntimeError("Both --initialization_time_regex & --forecast_time_regex flags need to be present") logger.info(f"Add metrics to pipeline: {known_args.use_metrics}") logger.info(f"Add Google Cloud Monitoring metrics to pipeline: {known_args.use_monitoring_metrics}") def expand(self, paths): """Converts input data files into assets and uploads them into the earth engine.""" band_names_dict = {} if self.band_names_mapping: with open(self.band_names_mapping, 'r', encoding='utf-8') as f: band_names_dict = json.load(f) if self.use_metrics: paths = paths | 'AddTimer' >> beam.ParDo(AddTimer.from_kwargs(**vars(self))) if not self.dry_run: output = ( paths | 'FilterFiles' >> FilterFilesTransform.from_kwargs(**vars(self)) | 'ReshuffleFiles' >> beam.Reshuffle() | 'ConvertToAsset' >> beam.ParDo( ConvertToAsset.from_kwargs(band_names_dict=band_names_dict, **vars(self)) ) | 'IngestIntoEE' >> IngestIntoEETransform.from_kwargs(**vars(self)) ) if self.use_metrics: output | 'AddMetrics' >> AddMetrics.from_kwargs(**vars(self)) else: ( paths | 'Log Files' >> beam.Map(logger.info) ) class FilterFilesTransform(SetupEarthEngine, KwargsFactoryMixin): """Filters out paths for which the assets that are already in the earth engine. Attributes: ee_asset: The asset folder path in earth engine project where the asset files will be pushed. ee_qps: Maximum queries per second allowed by EE for your project. ee_latency: The expected latency per requests, in seconds. ee_max_concurrent: Maximum concurrent api requests to EE allowed for your project. force: A flag that allows overwriting of existing asset files in the GCS bucket. private_key: A private key path to authenticate earth engine using private key. Default: None. 
service_account: Service account address when using a private key for earth engine authentication. use_personal_account: A flag to authenticate earth engine using personal account. Default: False. """ def __init__(self, asset_location: str, ee_asset: str, ee_asset_type: str, ee_qps: int, ee_latency: float, ee_max_concurrent: int, force: bool, private_key: str, service_account: str, use_personal_account: bool, use_metrics: bool): """Sets up rate limit and initializes the earth engine.""" super().__init__(ee_qps=ee_qps, ee_latency=ee_latency, ee_max_concurrent=ee_max_concurrent, private_key=private_key, service_account=service_account, use_personal_account=use_personal_account, use_metrics=use_metrics) self.asset_location = asset_location self.ee_asset = ee_asset self.ee_asset_type = ee_asset_type self.force_overwrite = force self.use_metrics = use_metrics @timeit('FilterFileTransform') def process(self, uri: str) -> t.Iterator[str]: """Yields uri if the asset does not already exist.""" project_id = self.ee_asset.split('/')[1] self.check_setup(project_id) asset_name = get_ee_safe_name(uri) # Checks if the asset is already present in the GCS bucket or not. target_path = os.path.join( self.asset_location, f'{asset_name}{ASSET_TYPE_TO_EXTENSION_MAPPING[self.ee_asset_type]}') if not self.force_overwrite and FileSystems.exists(target_path): logger.info(f'Asset file {target_path} already exists in GCS bucket. Skipping...') return asset_id = os.path.join(self.ee_asset, asset_name) try: ee.data.getAsset(asset_id) logger.info(f'Asset {asset_id} already exists in EE. Skipping...') except ee.EEException: yield uri @dataclasses.dataclass class ConvertToAsset(beam.DoFn, KwargsFactoryMixin): """Writes asset after extracting input data and uploads it to GCS. Attributes: ee_asset_type: The type of asset to ingest in the earth engine. Default: IMAGE. asset_location: The bucket location at which asset files will be pushed. xarray_open_dataset_kwargs: A dictionary of kwargs to pass to xr.open_dataset(). disable_grib_schema_normalization: A flag to turn grib schema normalization off; Default: on. """ asset_location: str ee_asset_type: str = 'IMAGE' xarray_open_dataset_kwargs: t.Optional[t.Dict] = None disable_grib_schema_normalization: bool = False group_common_hypercubes: t.Optional[bool] = False band_names_dict: t.Optional[t.Dict] = None initialization_time_regex: t.Optional[str] = None forecast_time_regex: t.Optional[str] = None use_deflate: t.Optional[bool] = False use_metrics: t.Optional[bool] = False def add_to_queue(self, queue: Queue, item: t.Any): """Adds a new item to the queue. It will wait until the queue has a room to add a new item. 
""" while queue.full(): pass queue.put_nowait(item) def convert_to_asset(self, queue: Queue, uri: str): """Converts source data into EE asset (GeoTiff or CSV) and uploads it to the bucket.""" child_logger = logging.getLogger(__name__) child_logger.info(f'Converting {uri!r} to COGs...') job_start_time = get_utc_timestamp() with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization, initialization_time_regex=self.initialization_time_regex, forecast_time_regex=self.forecast_time_regex, group_common_hypercubes=self.group_common_hypercubes) as ds_list: if not isinstance(ds_list, list): ds_list = [ds_list] for ds in ds_list: attrs = ds.attrs data = list(ds.values()) asset_name = get_ee_safe_name(uri) channel_names = [ self.band_names_dict.get(da.name, da.name) if self.band_names_dict else da.name for da in data ] dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform']) # Adding job_start_time to properites. attrs["job_start_time"] = job_start_time # Make attrs EE ingestable. attrs = make_attrs_ee_compatible(attrs) start_time, end_time = (attrs.get(key) for key in ('start_time', 'end_time')) if self.group_common_hypercubes: level, height = (attrs.pop(key) for key in ['level', 'height']) safe_level_name = get_ee_safe_name(level) asset_name = f'{asset_name}_{safe_level_name}' compression = 'lzw' predictor = 'NO' if self.use_deflate: compression = 'deflate' # Depending on dtype select predictor value. # Predictor is a method of storing only the difference from the # previous value instead of the actual value. predictor = 2 if np.issubdtype(dtype, np.integer) else 3 # For tiff ingestions. if self.ee_asset_type == 'IMAGE': file_name = f'{asset_name}.tiff' with MemoryFile() as memfile: with memfile.open(driver='COG', dtype=dtype, width=data[0].data.shape[1], height=data[0].data.shape[0], count=len(data), nodata=np.nan, crs=crs, transform=transform, compress=compression, predictor=predictor) as f: for i, da in enumerate(data): f.write(da, i+1) # Making the channel name EE-safe before adding it as a band name. f.set_band_description(i+1, get_ee_safe_name(channel_names[i])) f.update_tags(i+1, band_name=channel_names[i]) f.update_tags(i+1, **da.attrs) # Write attributes as tags in tiff. f.update_tags(**attrs) # Copy in-memory tiff to gcs. target_path = os.path.join(self.asset_location, file_name) with FileSystems().create(target_path) as dst: shutil.copyfileobj(memfile, dst, WRITE_CHUNK_SIZE) child_logger.info(f"Uploaded {uri!r}'s COG to {target_path}") # For feature collection ingestions. elif self.ee_asset_type == 'TABLE': channel_names = [] file_name = f'{asset_name}.csv' shape = math.prod(list(ds.dims.values())) # Names of dimesions, coordinates and data variables. dims = list(ds.dims) coords = [c for c in list(ds.coords) if c not in dims] vars = list(ds.data_vars) header = dims + coords + vars # Data of dimesions, coordinates and data variables. dims_data = [ds[dim].data for dim in dims] coords_data = [np.full((shape,), ds[coord].data) for coord in coords] vars_data = [ds[var].data.flatten() for var in vars] data = coords_data + vars_data dims_shape = [len(ds[dim].data) for dim in dims] def get_dims_data(index: int) -> t.List[t.Any]: """Returns dimensions for the given flattened index.""" return [ dim[int(index/math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data) ] # Copy CSV to gcs. 
target_path = os.path.join(self.asset_location, file_name) with tempfile.NamedTemporaryFile() as temp: with open(temp.name, 'w', newline='') as f: writer = csv.writer(f) writer.writerows([header]) # Write rows in batches. for i in range(0, shape, ROWS_PER_WRITE): writer.writerows( [get_dims_data(i) + list(row) for row in zip( *[d[i:i + ROWS_PER_WRITE] for d in data] )] ) copy(temp.name, target_path) asset_data = AssetData( name=asset_name, target_path=target_path, channel_names=channel_names, start_time=start_time, end_time=end_time, properties=attrs ) self.add_to_queue(queue, asset_data) self.add_to_queue(queue, None) # Indicates end of the subprocess. @timeit('ConvertToAsset') def process(self, uri: str) -> t.Iterator[AssetData]: """Opens grib files and yields AssetData. We observed that the convert-to-cog process increases memory usage over time because xarray (v2022.11.0) is not releasing memory as expected while opening any dataset. So we will perform the convert-to-asset process in an isolated process so that the memory consumed while processing will be cleared after the process is killed. The process puts the asset data into the queue which the main process will consume. Queue buffer size is limited so the process will be able to put another item in a queue only after the main process has consumed the queue item, that way it makes sure that no queue item is dropped due to queue buffer size. """ queue = Queue(maxsize=1) process = Process(target=self.convert_to_asset, args=(queue, uri)) process.start() while True: if not queue.empty(): asset_data = queue.get_nowait() # Not needed now but keeping this check for backwards compatibility. if asset_data is None: break yield asset_data # When the convert-to-asset process terminates unexpectedly... if not process.is_alive(): logger.warning(f'Failed to convert {uri!r} to asset!') break process.terminate() class IngestIntoEETransform(SetupEarthEngine, KwargsFactoryMixin): """Ingests asset into earth engine and yields asset id. Attributes: ee_asset: The asset folder path in earth engine project where the asset files will be pushed. ee_asset_type: The type of asset to ingest in the earth engine. Default: IMAGE. ee_qps: Maximum queries per second allowed by EE for your project. ee_latency: The expected latency per requests, in seconds. ee_max_concurrent: Maximum concurrent api requests to EE allowed for your project. private_key: A private key path to authenticate earth engine using private key. Default: None. service_account: Service account address when using a private key for earth engine authentication. use_personal_account: A flag to authenticate earth engine using personal account. Default: False. 
""" def __init__(self, ee_asset: str, ee_asset_type: str, ee_qps: int, ee_latency: float, ee_max_concurrent: int, private_key: str, service_account: str, use_personal_account: bool, ingest_as_virtual_asset: bool, use_metrics: bool): """Sets up rate limit.""" super().__init__(ee_qps=ee_qps, ee_latency=ee_latency, ee_max_concurrent=ee_max_concurrent, private_key=private_key, service_account=service_account, use_personal_account=use_personal_account, use_metrics=use_metrics) self.ee_asset = ee_asset self.ee_asset_type = ee_asset_type self.ingest_as_virtual_asset = ingest_as_virtual_asset self.use_metrics = use_metrics def get_project_id(self) -> str: return self.ee_asset.split('/')[1] def ee_tasks_remaining(self) -> int: """Returns the remaining number of tasks in the tassk queue of earth engine.""" return len([task for task in ee.data.getTaskList() if task['state'] in ['UNSUBMITTED', 'READY', 'RUNNING']]) def wait_for_task_queue(self) -> None: """Waits until the task queue has space. Ingestion of table in the earth engine creates a task and every project has a limited task queue size. This function checks the task queue size and waits until the task queue has some space. """ while self.ee_tasks_remaining() >= self._num_shards: time.sleep(TASK_QUEUE_WAIT_TIME) @retry.with_exponential_backoff( num_retries=NUM_RETRIES, logger=logger.warning, initial_delay_secs=INITIAL_DELAY, max_delay_secs=MAX_DELAY ) def start_ingestion(self, asset_data: AssetData) -> t.Optional[str]: """Creates COG-backed asset in earth engine. Returns the asset id.""" project_id = self.get_project_id() self.check_setup(project_id) asset_name = os.path.join(self.ee_asset, asset_data.name) asset_data.properties['ingestion_time'] = get_utc_timestamp() try: logger.info(f"Uploading asset {asset_data.target_path} to Asset ID '{asset_name}'.") if self.ee_asset_type == 'IMAGE': # Ingest an image. creds = get_creds(self.use_personal_account, self.service_account, self.private_key) session = AuthorizedSession(creds) image_manifest = { 'name': asset_name, 'tilesets': [ { 'id': '0', 'sources': [{'uris': [asset_data.target_path]}] } ], 'startTime': asset_data.start_time, 'endTime': asset_data.end_time, 'properties': asset_data.properties } headers = { 'Content-Type': 'application/json', 'x-goog-user-project': project_id, } data = json.dumps({'imageManifest': image_manifest, 'overwrite': True}) if self.ingest_as_virtual_asset: # as a virtual image. # Makes an api call to register the virtual asset. url = ( f'https://earthengine-highvolume.googleapis.com/v1/projects/{project_id}/' f'image:import?overwrite=true&mode=VIRTUAL' ) else: # as a COG based image. url = ( f'https://earthengine-highvolume.googleapis.com/v1alpha/projects/{project_id}/' f'image:importExternal' ) # Send API request response = session.post(url=url, data=data, headers=headers) logger.info(f"EE Asset ingestion response for {asset_name}: {response.text}") if response.status_code != 200: logger.info(f"Failed to ingest asset '{asset_name}' in Earth Engine: {response.text}") raise ee.EEException(response.text) if self.ingest_as_virtual_asset: response_json = response.json() ingestion_state = ( response_json .get("metadata", {}) .get("state", "STATE_UNSPECIFIED") .upper() ) if ingestion_state != "SUCCEEDED": raise ee.EEException(response.text) return asset_name elif self.ee_asset_type == 'TABLE': # ingest a feature collection. 
self.wait_for_task_queue() task_id = ee.data.newTaskId(1)[0] ee.data.startTableIngestion(task_id, { 'name': asset_name, 'sources': [{ 'uris': [asset_data.target_path] }], 'startTime': asset_data.start_time, 'endTime': asset_data.end_time, 'properties': asset_data.properties }) return asset_name except ee.EEException as e: if "Could not parse a valid CRS from the first overview of the GeoTIFF" in repr(e): logger.error(f"Failed to create asset '{asset_name}' in earth engine: {e}. Moving on...") return "" if ( "whose type does not match the type of the same property of existing " "assets in the same collection" in repr(e) ): logger.error(f"Failed to ingest asset '{asset_name}' due to property mismatch: {e} Moving on...") return "" if "The metadata of the TIFF could not be read in the first 10000000 bytes." in repr(e): logger.error(f"Faild to ingest asset '{asset_name}', check the tiff file: {e} Moving on...") return "" logger.error(f"Failed to create asset '{asset_name}' in earth engine: {e}") # We do have logic for skipping the already created assets in FilterFilesTransform but # somehow we are observing that streaming pipeline reports "Cannot overwrite ..." error # so this will act as a quick fix for this issue. if f"Cannot overwrite asset '{asset_name}'" in repr(e): ee.data.deleteAsset(asset_name) raise @timeit('IngestIntoEE') def process(self, asset_data: AssetData) -> t.Iterator[t.Tuple[str, float]]: """Uploads an asset into the earth engine.""" asset_name = self.start_ingestion(asset_data) if asset_name: metric.Metrics.counter('Success', 'IngestIntoEE').inc() asset_start_time = asset_data.start_time yield asset_name, asset_start_time ================================================ FILE: weather_mv/loader_pipeline/ee_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
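# Illustrative note (not part of the original tests): a typical `weather-mv ee` invocation that
# exercises the ToEarthEngine sink defined above might look like the following; the bucket,
# asset folder, and project names are hypothetical placeholders.
#
#   weather-mv ee --uris "gs://my-bucket/era5/*.grib" \
#       --asset_location "gs://my-bucket/assets" \
#       --ee_asset "projects/my-project/assets/era5" \
#       --temp_location "gs://my-bucket/tmp"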
import logging import os import tempfile import unittest from .ee import ( get_ee_safe_name, ConvertToAsset ) from .sinks_test import TestDataBase logger = logging.getLogger(__name__) class AssetNameCreationTests(unittest.TestCase): def test_asset_name_creation(self): uri = 'weather_mv/test_data/grib_multiple_edition_single_timestep.bz2' expected = 'grib_multiple_edition_single_timestep' actual = get_ee_safe_name(uri) self.assertEqual(actual, expected) def test_asset_name_creation__with_special_chars(self): uri = 'weather_mv/test_data/grib@2nd-edition×tep#1.bz2' expected = 'grib_2nd-edition_timestep_1' actual = get_ee_safe_name(uri) self.assertEqual(actual, expected) def test_asset_name_creation__with_missing_filename(self): uri = 'weather_mv/test_data/' expected = '' actual = get_ee_safe_name(uri) self.assertEqual(actual, expected) def test_asset_name_creation__with_only_filename(self): uri = 'grib@2nd-edition×tep#1.bz2' expected = 'grib_2nd-edition_timestep_1' actual = get_ee_safe_name(uri) self.assertEqual(actual, expected) class ConvertToAssetTests(TestDataBase): def setUp(self) -> None: super().setUp() self.tmpdir = tempfile.TemporaryDirectory() self.convert_to_image_asset = ConvertToAsset(asset_location=self.tmpdir.name) self.convert_to_table_asset = ConvertToAsset(asset_location=self.tmpdir.name, ee_asset_type='TABLE') def tearDown(self): self.tmpdir.cleanup() def test_convert_to_image_asset(self): data_path = f'{self.test_data_folder}/test_data_grib_single_timestep' asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep.tiff') next(self.convert_to_image_asset.process(data_path)) # The size of tiff is expected to be more than grib. self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path)) def test_convert_to_image_asset__with_multiple_grib_edition(self): data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2' asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_multiple_edition_single_timestep.tiff') next(self.convert_to_image_asset.process(data_path)) # The size of tiff is expected to be more than grib. self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path)) def test_convert_to_table_asset(self): data_path = f'{self.test_data_folder}/test_data_grib_single_timestep' asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep.csv') next(self.convert_to_table_asset.process(data_path)) # The size of tiff is expected to be more than grib. self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path)) def test_convert_to_table_asset__with_multiple_grib_edition(self): data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2' asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_multiple_edition_single_timestep.csv') next(self.convert_to_table_asset.process(data_path)) # The size of tiff is expected to be more than grib. self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path)) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/execution_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest import weather_mv from .pipeline import run, pipeline class ExecutionTest(unittest.TestCase): TEST_DATA_FOLDER = f'{next(iter(weather_mv.__path__))}/test_data' DEFAULT_CMD = 'weather-mv bq --output_table myproject.mydataset.mytable --temp_location "gs://mybucket/tmp" ' \ '--geo_data_parquet_path ./geo_data.parquet --dry-run ' LOCAL_DATA_SOURCE = f'{TEST_DATA_FOLDER}/test_data_single_point.nc' TEST_CASES = [ ('local data source and local execution', f'{DEFAULT_CMD} --uris {LOCAL_DATA_SOURCE} --direct_num_workers 2'), ] def test_run(self): for msg, args in self.TEST_CASES: with self.subTest(msg): pipeline(*run(args.split())) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/metrics.py ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for adding metrics to beam pipeline.""" import copy import dataclasses import datetime import inspect import json import logging import time import typing as t from collections import OrderedDict import apache_beam as beam from apache_beam.metrics import metric from apache_beam.transforms import window, trigger from functools import wraps from google.cloud import monitoring_v3 from .sinks import get_file_time, KwargsFactoryMixin logger = logging.getLogger(__name__) # For Metrics API retry logic. INITIAL_DELAY = 1.0 # Initial delay in seconds. MAX_DELAY = 600 # Maximum delay before giving up in seconds. NUM_RETRIES = 10 # Number of tries with exponential backoff. TASK_QUEUE_WAIT_TIME = 120 # Task queue wait time in seconds. def timeit(func_name: str, keyed_fn: bool = False): """Decorator to add time it takes for an element to be processed by a stage. Args: func_name: A unique name of the stage. keyed_fn (optional): This has to be passed true if the input is adding keys to the element. For example a stage like class Shard(beam.DoFn): @timeit('Sharding', keyed_fn=True) def process(self,element): key = randrange(10) yield key, element We are passing `keyed_fn=True` as we are adding a key to our element. Usually keys are added to later group the element by a `GroupBy` stage. """ def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): # If metrics are turned off, don't do anything. if not hasattr(self, "use_metrics") or ( hasattr(self, "use_metrics") and not self.use_metrics ): for result in func(self, *args, **kwargs): yield result return # Only the first timer wrapper will have no time_dict. # All subsequent wrappers can extract out the dict. # args 0 would be a tuple. 
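            # Illustrative shape of the incoming element here: (payload, OrderedDict([('uri', ...),
            # ('bucket', t0), ('pickup', t1), ('FilterFileTransform', t2), ...])), i.e. the value
            # being processed plus the time_dict built up by AddTimer and earlier @timeit stages.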
if len(args[0]) == 1: raise ValueError("time_dict not found.") element, time_dict = args[0] args = (element,) + args[1:] if not isinstance(time_dict, OrderedDict): raise ValueError("time_dict not found.") # If the function is a generator, yield the output # othewise return it. if inspect.isgeneratorfunction(func): for result in func(self, *args, **kwargs): new_time_dict = copy.deepcopy(time_dict) if func_name in new_time_dict: del new_time_dict[func_name] new_time_dict[func_name] = time.time() if keyed_fn: (key, element) = result yield key, (element, new_time_dict) else: yield result, new_time_dict else: raise ValueError("Function is not a generator.") return wrapper return decorator @dataclasses.dataclass class AddTimer(beam.DoFn, KwargsFactoryMixin): """DoFn to add a time_dict with uri, file time in GCS bucket, when it was picked up in PCollection. This dict will contain each stage_names as keys and the timestamp when it finished that step's execution.""" topic: t.Optional[str] = None def process(self, element) -> t.Iterator[t.Any]: time_dict = OrderedDict( [ ("uri", element), ("bucket", get_file_time(element) if self.topic else time.time()), ("pickup", time.time()), ] ) yield element, time_dict class AddBeamMetrics(beam.DoFn): """DoFn to add Element Processing Time metric to beam. Expects PCollection to contain a time_dict.""" def __init__(self, asset_start_time_format: str): super().__init__() self.element_processing_time = metric.Metrics.distribution( "Time", "element_processing_time_ms" ) self.data_latency_time = metric.Metrics.distribution( "Time", "data_latency_time_ms" ) self.asset_start_time_format = asset_start_time_format def process(self, element): try: if len(element) == 0: raise ValueError("time_dict not found.") (asset_name, asset_start_time), time_dict = element if not isinstance(time_dict, OrderedDict): raise ValueError("time_dict not found.") uri = time_dict.pop("uri") # Time for a file to get ingested into EE from when it appeared in bucket. # When the pipeline is in batch mode, it will be from when the file # was picked up by the pipeline. element_processing_time = ( time_dict["IngestIntoEE"] - time_dict["bucket"] ) * 1000 self.element_processing_time.update(int(element_processing_time)) # Adding data latency. if asset_start_time: current_time = time.time() asset_start_time = datetime.datetime.strptime( asset_start_time, self.asset_start_time_format ).timestamp() # Converting seconds to milli seconds. data_latency_ms = (current_time - asset_start_time) * 1000 self.data_latency_time.update(int(data_latency_ms)) # Logging file init to bucket time as well. time_dict.update({"FileInit": asset_start_time}) time_dict.move_to_end("FileInit", last=False) # Logging time taken by each step... step_intervals = { f"{current_step} -> {next_step}": round(next_time - current_time) for (current_step, current_time), (next_step, next_time) in zip(time_dict.items(), list(time_dict.items())[1:]) } logger.info( f"Step intervals for {uri}:{asset_name} :: {json.dumps(step_intervals, indent=4)}" ) yield ("custom_metrics", (data_latency_ms / 1000, element_processing_time / 1000)) except Exception as e: logger.warning( f"Some error occured while adding metrics. 
Error {e}" ) @dataclasses.dataclass class CreateTimeSeries(beam.DoFn): """DoFn to write metrics TimeSeries data in Google Cloud Monitoring.""" job_name: str project: str region: str def create_time_series_object(self, metric_name: str, metric_value: float): """Returns a Metrics TimeSeries object.""" series = monitoring_v3.TimeSeries() series.metric.type = f"custom.googleapis.com/{metric_name}" series.metric.labels["description"] = metric_name series.resource.type = "dataflow_job" series.resource.labels["job_name"] = self.job_name series.resource.labels["project_id"] = self.project series.resource.labels["region"] = self.region now = time.time() seconds = int(now) nanos = int((now - seconds) * 10**9) interval = monitoring_v3.TimeInterval( {"end_time": {"seconds": seconds, "nanos": nanos}} ) point = monitoring_v3.Point( {"interval": interval, "value": {"double_value": metric_value}} ) series.points = [point] return series def process(self, element: t.Any): _, metric_values = element data_latency_times = [x[0] for x in metric_values] element_processing_times = [x[1] for x in metric_values] logger.info(f"data_latency_time values: {data_latency_times}") data_latency_max_series = self.create_time_series_object( "data_latency_time_max", max(data_latency_times) ) data_latency_mean_series = self.create_time_series_object( "data_latency_time_mean", sum(data_latency_times) / len(data_latency_times), ) logger.info( f"element_processing_time values: {element_processing_times}" ) element_processing_max_series = self.create_time_series_object( "element_processing_time_max", max(element_processing_times) ) element_processing_mean_series = self.create_time_series_object( "element_processing_time_mean", sum(element_processing_times) / len(element_processing_times), ) client = monitoring_v3.MetricServiceClient() client.create_time_series( name=f"projects/{self.project}", time_series=[ data_latency_max_series, data_latency_mean_series, element_processing_max_series, element_processing_mean_series, ], ) @dataclasses.dataclass class AddMetrics(beam.PTransform, KwargsFactoryMixin): """A custom transform to add metrics to the pipeline.""" job_name: str project: str region: str use_monitoring_metrics: bool asset_start_time_format: str = "%Y-%m-%dT%H:%M:%SZ" def expand(self, pcoll: beam.PCollection): metrics = pcoll | "AddBeamMetrics" >> beam.ParDo(AddBeamMetrics(self.asset_start_time_format)) if self.use_monitoring_metrics: ( metrics | "AddTimestamps" >> beam.Map( lambda element: window.TimestampedValue(element, time.time()) ) | "Window" >> beam.WindowInto( window.GlobalWindows(), trigger=trigger.Repeatedly(trigger.AfterProcessingTime(5)), accumulation_mode=trigger.AccumulationMode.DISCARDING, ) | "GroupByKeyAndWindow" >> beam.GroupByKey(lambda element: element) | "CreateTimeSeries" >> beam.ParDo( CreateTimeSeries(self.job_name, self.project, self.region) ) ) ================================================ FILE: weather_mv/loader_pipeline/pipeline.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Pipeline for loading weather data into analysis-ready mediums, like Google BigQuery.""" import argparse import json import logging import typing as t import warnings import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import PipelineOptions from .bq import ToBigQuery from .regrid import Regrid from .ee import ToEarthEngine from .streaming import GroupMessagesByFixedWindows, ParsePaths logger = logging.getLogger(__name__) SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0' def configure_logger(verbosity: int) -> None: """Configures logging from verbosity. Default verbosity will show errors.""" level = (40 - verbosity * 10) logging.getLogger(__package__).setLevel(level) logger.setLevel(level) def pattern_to_uris(match_pattern: str, is_zarr: bool = False) -> t.Iterable[str]: if is_zarr: yield match_pattern return for match in FileSystems().match([match_pattern]): yield from [x.path for x in match.metadata_list] def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None: all_uris = list(pattern_to_uris(known_args.uris, known_args.zarr)) if not all_uris: raise FileNotFoundError(f"File pattern '{known_args.uris}' matched no objects") # First URI is useful to get an example data shard. It also can be a Zarr path. known_args.first_uri = next(iter(all_uris)) with beam.Pipeline(argv=pipeline_args) as p: if known_args.zarr: paths = p elif known_args.topic or known_args.subscription: paths = ( p # Windowing is based on this code sample: # https://cloud.google.com/pubsub/docs/pubsub-dataflow#code_sample | 'ReadUploadEvent' >> beam.io.ReadFromPubSub(known_args.topic, known_args.subscription) | 'WindowInto' >> GroupMessagesByFixedWindows(known_args.window_size, known_args.num_shards) | 'ParsePaths' >> beam.ParDo(ParsePaths(known_args.uris)) ) else: paths = p | 'Create' >> beam.Create(all_uris) if known_args.subcommand == 'bigquery' or known_args.subcommand == 'bq': paths | "MoveToBigQuery" >> ToBigQuery.from_kwargs(**vars(known_args)) elif known_args.subcommand == 'regrid' or known_args.subcommand == 'rg': paths | "Regrid" >> Regrid.from_kwargs(**vars(known_args)) elif known_args.subcommand == 'earthengine' or known_args.subcommand == 'ee': pipeline_options = PipelineOptions(pipeline_args) pipeline_options_dict = pipeline_options.get_all_options() # all_args stores all arguments passed to the pipeline. # This is necessary because pipeline_args are later used by # the CreateTimeSeries DoFn in the AddMetrics transform. all_args = {**pipeline_options_dict, **vars(known_args)} paths | "MoveToEarthEngine" >> ToEarthEngine.from_kwargs(**all_args) else: raise ValueError('invalid subcommand!') logger.info('Pipeline is finished.') def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]: """Main entrypoint & pipeline definition.""" parser = argparse.ArgumentParser( prog='weather-mv', description='Weather Mover loads weather data from cloud storage into analytics engines.' ) # Common arguments to all commands base = argparse.ArgumentParser(add_help=False) base.add_argument('-i', '--uris', type=str, required=True, help="URI glob pattern matching input weather data, e.g. 'gs://ecmwf/era5/era5-2015-*.gb'. Or, " "a path to a Zarr.") base.add_argument('--topic', type=str, help="A Pub/Sub topic for GCS OBJECT_FINALIZE events, or equivalent, of a cloud bucket. " "E.g. 'projects//topics/'. 
Cannot be used with `--subscription`.") base.add_argument('--subscription', type=str, help='A Pub/Sub subscription for GCS OBJECT_FINALIZE events, or equivalent, of a cloud bucket. ' 'Cannot be used with `--topic`.') base.add_argument("--window_size", type=float, default=1.0, help="Output file's window size in minutes. Only used with the `topic` flag. Default: 1.0 " "minute.") base.add_argument('--num_shards', type=int, default=5, help='Number of shards to use when writing windowed elements to cloud storage. Only used with ' 'the `topic` flag. Default: 5 shards.') base.add_argument('--zarr', action='store_true', default=False, help="Treat the input URI as a Zarr. If the URI ends with '.zarr', this will be set to True. " "Default: off") base.add_argument('--zarr_kwargs', type=json.loads, default='{}', help='Keyword arguments to pass into `xarray.open_zarr()`, as a JSON string. ' 'Default: `{"chunks": null, "consolidated": true}`.') base.add_argument('-d', '--dry-run', action='store_true', default=False, help='Preview the weather-mv job. Default: off') base.add_argument('--log-level', type=int, default=2, help='An integer to configure log level. Default: 2(INFO)') base.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.') subparsers = parser.add_subparsers(help='help for subcommand', dest='subcommand') # BigQuery command registration bq_parser = subparsers.add_parser('bigquery', aliases=['bq'], parents=[base], help='Move data into Google BigQuery') ToBigQuery.add_parser_arguments(bq_parser) # Regrid command registration rg_parser = subparsers.add_parser('regrid', aliases=['rg'], parents=[base], help='Copy and regrid grib data with MetView.') Regrid.add_parser_arguments(rg_parser) # EarthEngine command registration ee_parser = subparsers.add_parser('earthengine', aliases=['ee'], parents=[base], help='Move data into Google EarthEngine') ToEarthEngine.add_parser_arguments(ee_parser) known_args, pipeline_args = parser.parse_known_args(argv[1:]) configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug # Validate Zarr arguments if known_args.uris.endswith('.zarr'): known_args.zarr = True if known_args.zarr_kwargs and not known_args.zarr: raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.') if known_args.zarr_kwargs: if not known_args.zarr_kwargs.get('start_date') or not known_args.zarr_kwargs.get('end_date'): warnings.warn('`--zarr_kwargs` not contains both `start_date` and `end_date`' 'so whole zarr-dataset will ingested.') if known_args.zarr: known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None) known_args.zarr_kwargs['consolidated'] = known_args.zarr_kwargs.get('consolidated', True) # Validate subcommand if known_args.subcommand == 'bigquery' or known_args.subcommand == 'bq': ToBigQuery.validate_arguments(known_args, pipeline_args) elif known_args.subcommand == 'regrid' or known_args.subcommand == 'rg': Regrid.validate_arguments(known_args, pipeline_args) elif known_args.subcommand == 'earthengine' or known_args.subcommand == 'ee': ToEarthEngine.validate_arguments(known_args, pipeline_args) # If a Pub/Sub is used, then the pipeline must be a streaming pipeline. 
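    # For example (an illustrative invocation):
    #   weather-mv bq -i 'gs://bucket/*.nc' -o myproject.mydataset.mytable \
    #       --topic projects/myproject/topics/my-topic
    # will have '--streaming true' appended to its Beam pipeline args below.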
if known_args.topic or known_args.subscription: if known_args.topic and known_args.subscription: raise ValueError('only one argument can be provided at a time: `topic` or `subscription`.') if known_args.zarr: raise ValueError('streaming updates to a Zarr file is not (yet) supported.') pipeline_args.extend('--streaming true'.split()) # make sure we re-compute utcnow() every time rows are extracted from a file. known_args.import_time = None # We use the save_main_session option because one or more DoFn's in this # workflow rely on global context (e.g., a module imported at module level). pipeline_args.extend('--save_main_session true'.split()) return known_args, pipeline_args ================================================ FILE: weather_mv/loader_pipeline/pipeline_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import unittest import weather_mv from .pipeline import run, pipeline class CLITests(unittest.TestCase): def setUp(self) -> None: self.test_data_folder = f'{next(iter(weather_mv.__path__))}/test_data' self.base_cli_args = ( 'weather-mv bq ' f'-i {self.test_data_folder}/test_data_2018*.nc ' '-o myproject.mydataset.mytable ' '--import_time 2022-02-04T22:22:12.125893 ' '--geo_data_parquet_path geo_data.parquet ' '-s' ).split() self.tif_base_cli_args = ( 'weather-mv bq ' f'-i {self.test_data_folder}/test_data_tif_time.tif ' '-o myproject.mydataset.mytable ' '--import_time 2022-02-04T22:22:12.125893 ' '-s --geo_data_parquet_path geo_data.parquet' ).split() self.bq_base_cli_args = ( 'weather-mv bq ' f'-i {self.test_data_folder}/test_data_2018*.nc ' '-o myproject.mydataset.mytable ' '--import_time 2022-02-04T22:22:12.125893 ' '-s' ).split() self.ee_cli_args = ( 'weather-mv ee ' '-i weather_mv/test_data/test_data_2018*.nc ' '--asset_location gs://bucket/my-assets/ ' '--ee_asset "projects/my-project/assets/asset_dir' ).split() self.rg_cli_args = ( 'weather-mv rg ' '-i weather_mv/test_data/test_data_2018*.nc ' '-o weather_mv/test_data/output/ ' ).split() self.base_cli_known_args = { 'subcommand': 'bq', 'uris': f'{self.test_data_folder}/test_data_2018*.nc', 'output_table': 'myproject.mydataset.mytable', 'dry_run': False, 'skip_region_validation': True, 'import_time': '2022-02-04T22:22:12.125893', 'infer_schema': False, 'num_shards': 5, 'topic': None, 'subscription': None, 'variables': [], 'window_size': 1.0, 'xarray_open_dataset_kwargs': {}, 'rows_chunk_size': 1_000_000, 'disable_grib_schema_normalization': False, 'tif_metadata_for_start_time': None, 'tif_metadata_for_end_time': None, 'zarr': False, 'zarr_kwargs': {}, 'log_level': 2, 'use_local_code': False, 'skip_creating_polygon': False, 'geo_data_parquet_path': 'geo_data.parquet', 'skip_creating_geo_data_parquet': False } class TestCLI(CLITests): def test_dry_runs_are_allowed(self): known_args, _ = run(self.base_cli_args + '--dry-run'.split()) self.assertEqual(known_args.dry_run, True) def test_log_level_arg(self): known_args, _ = run(self.base_cli_args + 
'--log-level 3'.split()) self.assertEqual(known_args.log_level, 3) def test_tif_metadata_for_datetime_raise_error_for_non_tif_file(self): with self.assertRaisesRegex(RuntimeError, 'can be specified only for tif files.'): run(self.base_cli_args + '--tif_metadata_for_start_time start_time ' '--tif_metadata_for_end_time end_time'.split()) def test_tif_metadata_for_datetime_raise_error_if_flag_is_absent(self): with self.assertRaisesRegex(RuntimeError, 'is required for tif files.'): run(self.tif_base_cli_args) def test_area_only_allows_four(self): with self.assertRaisesRegex(AssertionError, 'Must specify exactly 4 lat/long .* N, W, S, E'): run(self.base_cli_args + '--area 1 2 3'.split()) with self.assertRaisesRegex(AssertionError, 'Must specify exactly 4 lat/long .* N, W, S, E'): run(self.base_cli_args + '--area 1 2 3 4 5'.split()) known_args, pipeline_args = run(self.base_cli_args + '--area 1 2 3 4'.split()) self.assertEqual(pipeline_args, ['--save_main_session', 'true']) self.assertEqual(vars(known_args), { **self.base_cli_known_args, 'area': [1, 2, 3, 4] }) def test_topic_creates_a_streaming_pipeline(self): _, pipeline_args = run(self.base_cli_args + '--topic projects/myproject/topics/my-topic'.split()) self.assertEqual(pipeline_args, ['--streaming', 'true', '--save_main_session', 'true']) def test_subscription_creates_a_streaming_pipeline(self): _, pipeline_args = run(self.base_cli_args + '--subscription projects/myproject/topics/my-topic'.split()) self.assertEqual(pipeline_args, ['--streaming', 'true', '--save_main_session', 'true']) def test_accepts_json_string_for_xarray_open(self): xarray_kwargs = dict(engine='cfgrib', backend_kwargs={'filter_by_keys': {'edition': 1}}) json_kwargs = json.dumps(xarray_kwargs) known_args, _ = run( self.base_cli_args + ["--xarray_open_dataset_kwargs", f"{json_kwargs}"] ) self.assertEqual(known_args.xarray_open_dataset_kwargs, xarray_kwargs) def test_ee_does_not_yet_support_zarr(self): with self.assertRaisesRegex(RuntimeError, 'Reading Zarr'): run(self.ee_cli_args + '--zarr'.split()) def test_rg_zarr_cant_output_netcdf(self): with self.assertRaisesRegex(ValueError, 'only Zarr-to-Zarr'): run(self.rg_cli_args + '--zarr --to_netcdf'.split()) def test_rg_happy_path(self): run(self.rg_cli_args + ['--zarr']) def test_zarr_kwargs_must_come_with_zarr(self): with self.assertRaisesRegex(ValueError, 'allowed with valid Zarr input URI'): run(self.base_cli_args + ['--zarr_kwargs', json.dumps({"time": 100})]) def test_topic_and_subscription__mutually_exclusive(self): with self.assertRaisesRegex(ValueError, '`topic` or `subscription`'): run(self.base_cli_args + '--topic foo --subscription bar'.split()) def test_geo_data_parquet_path_must_be_parquet(self): with self.assertRaisesRegex(RuntimeError, "must end with '.parquet'"): run(self.bq_base_cli_args + '--geo_data_parquet_path test_geo.txt'.split()) class IntegrationTest(CLITests): def test_dry_runs_are_allowed(self): pipeline(*run(self.base_cli_args + '--dry-run'.split())) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/regrid.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import contextlib import dataclasses import glob import json import logging import os.path import shutil import subprocess import tempfile import typing as t import warnings import apache_beam as beam import dask import xarray as xr import xarray_beam as xbeam from apache_beam.io.filesystems import FileSystems from .sinks import ToDataSink, open_local, copy logger = logging.getLogger(__name__) try: import metview as mv Fieldset = mv.bindings.Fieldset except (ModuleNotFoundError, ImportError, FileNotFoundError, ValueError): logger.error('Metview could not be imported.') mv = None # noqa Fieldset = t.Any def _clear_metview(): """Clear the metview temporary directory. By default, caches are cleared when the MetView _process_ ends. This method is necessary to free space sooner than that, namely after invoking MetView functions. """ cache_dirs = glob.glob(f'{tempfile.gettempdir()}/mv.*') for cache_dir in cache_dirs: shutil.rmtree(cache_dir) os.makedirs(cache_dir) @contextlib.contextmanager def _metview_op() -> t.Iterator[None]: """Perform operation with MetView, including error handling and cleanup.""" try: yield except (ModuleNotFoundError, ImportError, FileNotFoundError) as e: raise ImportError('Please install MetView with Anaconda:\n' '`conda install metview-batch -c conda-forge`') from e finally: _clear_metview() class MapChunkAsFieldset(beam.PTransform): """Apply an operation with MetView on a xarray.Dataset as if it's a metview.Fieldset. This transform will handle converting to and from xr.Datasets to mv.Fieldsets. This allows the user to perform any MetView or Fieldset operation within the overridable `apply()` method. > Warning: This cannot process large Datasets without a decent amount of disk space! """ def apply(self, key: xbeam.Key, fs: Fieldset) -> t.Tuple[xbeam.Key, Fieldset]: return key, fs def _apply(self, key: xbeam.Key, ds: xr.Dataset) -> t.Tuple[xbeam.Key, xr.Dataset]: # Clear metadata so ecCodes doesn't mess up the conversion. Instead of default grib fields, # ecCodes will use the parameter ID. Thus, the fields will appear in the final, regridded # dataset. for dv in ds.data_vars: for to_del in ['GRIB_cfName', 'GRIB_shortName', 'GRIB_cfVarName']: if to_del in ds[dv].attrs: del ds[dv].attrs[to_del] with _metview_op(): # mv.dataset_to_fieldset() will error on input where there is only 1 value # in a dimension. ECMWF's cfgrib is in its alpha version. try: fs = mv.dataset_to_fieldset(ds) except ValueError as e: raise ValueError( 'please change `zarr_input_chunk`s so that there are no' 'single element dimensions (e.g. {"time": 1} is not allowed).' ) from e # Apply any & all MetView or FieldSet operations. kout, fs_out = self.apply(key, fs) return kout, fs_out.to_dataset().compute() def expand(self, pcoll): return pcoll | beam.MapTuple(self._apply) @dataclasses.dataclass class RegridChunk(MapChunkAsFieldset): """Regrid a xarray.Dataset with MetView. Attributes: regrid_kwargs: A dictionary of keyword-args to be passed into `mv.regrid()` (excluding the dataset). 
zarr_input_chunks: (Optional) When regridding Zarr data, how the input dataset should be chunked upon open. """ regrid_kwargs: t.Dict zarr_input_chunks: t.Optional[t.Dict] = None def template(self, source_ds: xr.Dataset) -> xr.Dataset: """Calculate the output Zarr template by regridding (a tiny slice of) the input dataset.""" # Silence Dask warning... with dask.config.set(**{'array.slicing.split_large_chunks': False}): zeros = source_ds.chunk().pipe(xr.zeros_like) # If the chunked source dataset is small (less than 10 MB), just regrid it! if (zeros.nbytes / 1024 / 1024) < 10: _, ds = self._apply(xbeam.Key(), zeros) return ds.chunk() # source_ds is probably very big! Let's shrink it by a non-spatial dimension # so calculating the template will be tractable... # Get a single timeslice of the zeros Dataset (or equivalent chunkable dimension). # We don't know for sure that 'time' is in the Zarr dataset, so here we make our # best attempt to find a good slice. t0 = None for dim in ['time', *(self.zarr_input_chunks or {}).keys()]: if dim in zeros: t0 = zeros.isel({dim: 0}, drop=True) break if t0 is None: raise ValueError('cannot infer any dimension when creating a Zarr template. ' 'Please define at least one chunk in `--zarr_input_chunks`.') _, ds = self._apply(xbeam.Key(), t0) # Regrid the single time, then expand the Dataset to span all times. tmpl = ( ds .chunk() .expand_dims({dim: zeros[dim]}, 0) ) return tmpl def apply(self, key: xbeam.Key, fs: Fieldset) -> t.Tuple[xbeam.Key, Fieldset]: return key, mv.regrid(data=fs, **self.regrid_kwargs) @dataclasses.dataclass class Regrid(ToDataSink): """Regrid data using MetView. See https://metview.readthedocs.io/en/latest/metview/using_metview/regrid_intro.html for an in-depth intro on regridding with MetView. Attributes: output_path: URI for regridding target. Can be a glob pattern of NetCDF or Grib files; optionally, it can be a Zarr corpus is supported. regrid_kwargs: A dictionary of keyword-args to be passed into `mv.regrid()` (excluding the dataset). to_netcdf: When set, it raw data output will be written as NetCDF. Cannot use with Zarr datasets. zarr_input_chunks: (Optional) When regridding Zarr data, how the input dataset should be chunked upon open. zarr_output_chunks: (Optional, recommended) When regridding Zarr data, how the output Zarr dataset should be divided into chunks. """ output_path: str regrid_kwargs: t.Dict force_regrid: bool = False to_netcdf: bool = False apply_bz2_compression: bool = False zarr_input_chunks: t.Optional[t.Dict] = None zarr_output_chunks: t.Optional[t.Dict] = None @classmethod def add_parser_arguments(cls, subparser: argparse.ArgumentParser) -> None: subparser.add_argument('-o', '--output_path', type=str, required=True, help='The destination path for the regridded files.') subparser.add_argument('-k', '--regrid_kwargs', type=json.loads, default='{"grid": [0.25, 0.25]}', help="""Keyword-args to pass into `metview.regrid()` in the form of a JSON string. """ """Will default to '{"grid": [0.25, 0.25]}'.""") subparser.add_argument('--force_regrid', action='store_true', default=False, help='Force regrid all files even if file is present at output_path.') subparser.add_argument('--to_netcdf', action='store_true', default=False, help='Write output file in NetCDF via XArray. Default: off') subparser.add_argument('-bz2', '--apply_bz2_compression', action='store_true', default=False, help='Enable bzip2 (.bz2) compression for the regridded file. 
Default: off.') subparser.add_argument('-zi', '--zarr_input_chunks', type=json.loads, default=None, help='When reading a Zarr, break up the data into chunks. Takes a JSON string.') subparser.add_argument('-zo', '--zarr_output_chunks', type=json.loads, default=None, help='When writing a Zarr, write the data with chunks. Takes a JSON string.') @classmethod def validate_arguments(cls, known_args: argparse.Namespace, pipeline_options: t.List[str]) -> None: if known_args.zarr and known_args.to_netcdf: raise ValueError('only Zarr-to-Zarr regridding is allowed!') if not known_args.zarr and (known_args.zarr_input_chunks or known_args.zarr_output_chunks): raise ValueError('chunks can only be set when input URI is a Zarr.') if known_args.zarr: # Encourage use of correct output_path format. _, out_ext = os.path.splitext(known_args.output_path) if out_ext not in ['', '.zarr']: warnings.warn('if input is a Zarr, the output_path must also be a Zarr.', RuntimeWarning) def target_from(self, uri: str) -> str: """Create the target path from the input URI. In the case of Zarr, the output will be treated like a valid path. For NetCDF, this will change the extension to '.nc'. """ if self.zarr: return self.output_path base = os.path.basename(uri) in_dest = os.path.join(self.output_path, base) if not self.to_netcdf: return in_dest # If we convert to NetCDF, change the extension. no_ext, _ = os.path.splitext(in_dest) return f'{no_ext}.nc' def is_grib_file_corrupt(self, local_grib: str) -> bool: try: # Run grib_ls command to check the file subprocess.check_output(['grib_ls', local_grib]) return False except subprocess.CalledProcessError as e: logger.info(f"Encountered error while reading GRIB: {e}.") return True def path_exists(self, path: str, force_regrid: bool = False) -> bool: """Check if path exists. Pass force_regrid to skip checking.""" if force_regrid: return False matches = FileSystems().match([path]) assert len(matches) == 1 return len(matches[0].metadata_list) > 0 def apply(self, uri: str) -> None: logger.info(f'Regridding from {uri!r} to {self.target_from(uri)!r}.') if self.dry_run: return if self.path_exists(self.target_from(uri), self.force_regrid): logger.info(f"Skipping {uri}.") return with _metview_op(): try: logger.info(f'Copying grib from {uri!r} to local disk.') with open_local(uri) as local_grib: logger.info(f"Checking for {uri}'s validity...") if self.is_grib_file_corrupt(local_grib): logger.error(f"Corrupt GRIB file found: {uri}.") return logger.info(f"No issues found with {uri}.") logger.info(f'Regridding {uri!r} using {self.regrid_kwargs}.') fs = mv.bindings.Fieldset(path=local_grib) fieldset = mv.regrid(data=fs, **self.regrid_kwargs) with tempfile.NamedTemporaryFile() as src: logger.info(f'Writing {self.target_from(uri)!r} to local disk.') if self.to_netcdf: fieldset.to_dataset().to_netcdf(src.name) else: mv.write(src.name, fieldset) src.flush() _clear_metview() logger.info(f'Uploading {self.target_from(uri)!r}.') if self.apply_bz2_compression: logger.info( f'Applying bzip2 compression before copying to {self.target_from(uri)!r} ...' ) subprocess.run(f"bzip2 -k {src.name}".split()) copy(src.name + '.bz2', self.target_from(uri)) logger.info(f'Cleaning up {src.name}.bz2 ...') os.unlink(src.name + '.bz2') # Deleting the tempfile.bz2 file. else: copy(src.name, self.target_from(uri)) except Exception as e: logger.info(f'Regrid failed for {uri!r}. 
Error: {str(e)}') def expand(self, paths): if not self.zarr: paths | beam.Map(self.apply) return # Since `chunks=None` here, data will be opened lazily upon access. # This is used to get the Zarr metadata without loading the data. source_ds = xr.open_zarr(self.first_uri, **self.zarr_kwargs) regrid_op = RegridChunk(self.regrid_kwargs, self.zarr_input_chunks) regridded = ( paths | xbeam.DatasetToChunks(source_ds, self.zarr_input_chunks) | 'RegridChunk' >> regrid_op ) tmpl = paths | beam.Create([source_ds]) | 'CalcZarrTemplate' >> beam.Map(regrid_op.template) to_write = regridded if self.zarr_output_chunks: to_write |= xbeam.ConsolidateChunks(self.zarr_output_chunks) to_write | xbeam.ChunksToZarr(self.output_path, beam.pvalue.AsSingleton(tmpl), self.zarr_output_chunks) ================================================ FILE: weather_mv/loader_pipeline/regrid_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import dataclasses import glob import os.path import tempfile import unittest import numpy as np import xarray as xr from apache_beam.testing.test_pipeline import TestPipeline from cfgrib.xarray_to_grib import to_grib from .regrid import Regrid from .sinks_test import TestDataBase try: import metview # noqa except (ModuleNotFoundError, ImportError, FileNotFoundError): raise unittest.SkipTest('MetView dependency is not installed. Skipping tests...') def make_skin_temperature_dataset() -> xr.Dataset: ds = xr.DataArray( np.full((4, 5, 6), 300.), coords=[ np.arange(0, 4), np.linspace(90., -90., 5), np.linspace(0., 360., 6, endpoint=False), ], dims=['time', 'latitude', 'longitude'], ).to_dataset(name='skin_temperature') ds.skin_temperature.attrs['GRIB_shortName'] = 'skt' ds.skin_temperature.attrs['GRIB_gridType'] = 'regular_ll' return ds def metview_cache_exists() -> bool: caches = glob.glob(f'{tempfile.gettempdir()}/mv.*/') return any(os.path.isfile(p) for p in caches) class RegridTest(TestDataBase): # TODO(alxr): Test the quality of the regridding... 
def setUp(self) -> None: super().setUp() self.tmpdir = tempfile.TemporaryDirectory() self.input_dir = os.path.join(self.tmpdir.name, 'input') self.input_grib = os.path.join(self.input_dir, 'test.gb') os.mkdir(self.input_dir) self.Op = Regrid( output_path=self.tmpdir.name, first_uri=self.input_grib, regrid_kwargs={'grid': [0.25, 0.25]}, dry_run=False, zarr=False, zarr_kwargs={}, ) def tearDown(self) -> None: self.tmpdir.cleanup() def test_target_name(self): actual = self.Op.target_from('path/to/data/called/foobar.gb') self.assertEqual(actual, f'{self.tmpdir.name}/foobar.gb') def test_target_name__to_netCDF__changes_ext(self): Op = dataclasses.replace(self.Op, to_netcdf=True) actual = Op.target_from('path/to/data/called/foobar.gb') self.assertEqual(actual, f'{self.tmpdir.name}/foobar.nc') def test_apply__creates_a_file(self): to_grib(make_skin_temperature_dataset(), self.input_grib) self.assertTrue(os.path.exists(self.input_grib)) self.Op.apply(self.input_grib) self.assertTrue(os.path.exists(f'{self.tmpdir.name}/test.gb')) self.assertFalse(metview_cache_exists()) def test_apply__works_when_called_twice(self): for _ in range(2): self.test_apply__creates_a_file() def test_apply__to_netCDF__creates_a_netCDF_file(self): to_grib(make_skin_temperature_dataset(), self.input_grib) self.assertTrue(os.path.exists(self.input_grib)) Op = dataclasses.replace(self.Op, to_netcdf=True) Op.apply(self.input_grib) expected = f'{self.tmpdir.name}/test.nc' self.assertTrue(os.path.exists(expected)) try: xr.open_dataset(expected) except: # noqa self.fail('Cannot open netCDF with Xarray.') def test_zarr__coarsen(self): input_zarr = os.path.join(self.input_dir, 'input.zarr') output_zarr = os.path.join(self.input_dir, 'output.zarr') xr.open_dataset(os.path.join(self.test_data_folder, 'test_data_20180101.nc')).to_zarr(input_zarr) self.assertTrue(os.path.exists(input_zarr)) Op = dataclasses.replace( self.Op, first_uri=input_zarr, output_path=output_zarr, zarr_input_chunks={"time": 25}, zarr=True ) with TestPipeline() as p: p | Op self.assertTrue(os.path.exists(output_zarr)) try: xr.open_zarr(output_zarr) except: # noqa self.fail('Cannot open Zarr with Xarray.') def test_corrupt_grib_file(self): correct_file_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep') corrupt_file_path = os.path.join(self.test_data_folder, 'test_data_corrupt_grib') self.assertFalse(self.Op.is_grib_file_corrupt(correct_file_path)) self.assertTrue(self.Op.is_grib_file_corrupt(corrupt_file_path)) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/sinks.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
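"""Sinks and dataset-opening utilities shared by the weather-mv loader pipeline.

A minimal usage sketch (the path is illustrative):

    with open_dataset('gs://bucket/era5_2015.grib') as ds:
        print(ds.data_vars)
"""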
import abc import argparse import contextlib import dataclasses import datetime import inspect import logging import os import re import shutil import subprocess import tempfile import time import typing as t from urllib.parse import urlparse import apache_beam as beam import cfgrib import numpy as np import rasterio import rioxarray import xarray as xr from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE from apache_beam.utils import retry from google.cloud import storage from pyproj import Transformer TIF_TRANSFORM_CRS_TO = "EPSG:4326" # A constant for all the things in the coords key set that aren't the level name. DEFAULT_COORD_KEYS = frozenset(('latitude', 'time', 'step', 'valid_time', 'longitude', 'number')) DEFAULT_TIME_ORDER_LIST = ['%Y', '%m', '%d', '%H', '%M', '%S'] # For uploading / downloading retry logic. INITIAL_DELAY = 1.0 # Initial delay in seconds. MAX_DELAY = 600 # Maximum delay before giving up in seconds. NUM_RETRIES = 10 # Number of tries with exponential backoff. logger = logging.getLogger(__name__) class KwargsFactoryMixin: """Adds a factory method to classes or dataclasses for key-word args.""" @classmethod def from_kwargs(cls, **kwargs): if dataclasses.is_dataclass(cls): fields = [f.name for f in dataclasses.fields(cls)] else: fields = inspect.signature(cls.__init__).parameters.keys() return cls(**{k: v for k, v, in kwargs.items() if k in fields}) @dataclasses.dataclass class ToDataSink(abc.ABC, beam.PTransform, KwargsFactoryMixin): first_uri: str dry_run: bool zarr: bool zarr_kwargs: t.Dict @classmethod @abc.abstractmethod def add_parser_arguments(cls, subparser: argparse.ArgumentParser) -> None: pass @classmethod @abc.abstractmethod def validate_arguments(cls, known_args: argparse.Namespace, pipeline_options: t.List[str]) -> None: pass def _make_grib_dataset_inmem(grib_ds: xr.Dataset) -> xr.Dataset: """Copies all the vars in-memory to reduce disk seeks every time a single row is processed. This also removes the need to keep the backing temp source file around. """ data_ds = grib_ds.copy(deep=True) for v in grib_ds.variables: if v not in data_ds.coords: data_ds[v].variable.values = grib_ds[v].variable.values return data_ds def match_datetime(file_name: str, regex_expression: str) -> datetime.datetime: """Matches the regex string given and extracts the datetime object. Args: file_name: File name from which you want to extract datetime. regex_expression: Regex expression for extracting datetime from the filename. Returns: A datetime object after extracting from the filename. 
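    Example (mirrors the cases in `DatetimeTest` in sinks_test.py):
        >>> match_datetime('3B-HHR-E_MS_MRG_3IMERG_20220901-S000000-E002959_0000_V06C_30min.tiff',
        ...                '3B-HHR-E_MS_MRG_3IMERG_%Y%m%d-S%H%M%S-*.tiff')
        datetime.datetime(2022, 9, 1, 0, 0)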
""" def rearrange_time_list(order_list: t.List, time_list: t.List) -> t.List: if order_list == DEFAULT_TIME_ORDER_LIST: return time_list new_time_list = [] for i, j in zip(order_list, time_list): dst = DEFAULT_TIME_ORDER_LIST.index(i) new_time_list.insert(dst, j) return new_time_list char_to_replace = { '%Y': ['([0-9]{4})', [0, 1978]], '%m': ['([0-9]{2})', [1, 1]], '%d': ['([0-9]{2})', [2, 1]], '%H': ['([0-9]{2})', [3, 0]], '%M': ['([0-9]{2})', [4, 0]], '%S': ['([0-9]{2})', [5, 0]], '*': ['.*'] } missing_idx_list = [] temp_expression = regex_expression for key, value in char_to_replace.items(): if key != '*' and regex_expression.find(key) == -1: missing_idx_list.append(value[1]) else: temp_expression = temp_expression.replace(key, value[0]) regex_matches = re.findall(temp_expression, file_name)[0] order_list = [f'%{char}' for char in re.findall(r'%(\w{1})', regex_expression)] time_list = list(map(int, regex_matches)) time_list = rearrange_time_list(order_list, time_list) if missing_idx_list: for [idx, val] in missing_idx_list: time_list.insert(idx, val) return datetime.datetime(*time_list) def _preprocess_tif( ds: xr.Dataset, tif_metadata_for_start_time: str, tif_metadata_for_end_time: str, uri: str, initialization_time_regex: str, forecast_time_regex: str ) -> xr.Dataset: """Transforms (y, x) coordinates into (lat, long) and adds bands data in data variables. This also retrieves datetime from tif's metadata and stores it into dataset. """ def _replace_dataarray_names_with_long_names(ds: xr.Dataset): rename_dict = {var_name: ds[var_name].attrs.get('long_name', var_name) for var_name in ds.variables} return ds.rename(rename_dict) y, x = np.meshgrid(ds['y'], ds['x']) transformer = Transformer.from_crs(ds.spatial_ref.crs_wkt, TIF_TRANSFORM_CRS_TO, always_xy=True) lon, lat = transformer.transform(x, y) ds['y'] = lat[0, :] ds['x'] = lon[:, 0] ds = ds.rename({'y': 'latitude', 'x': 'longitude'}) ds = ds.squeeze().drop_vars('spatial_ref') ds = _replace_dataarray_names_with_long_names(ds) end_time = None start_time = None if initialization_time_regex and forecast_time_regex: try: start_time = match_datetime(uri, initialization_time_regex) except Exception: raise RuntimeError("Wrong regex passed in --initialization_time_regex.") try: end_time = match_datetime(uri, forecast_time_regex) except Exception: raise RuntimeError("Wrong regex passed in --forecast_time_regex.") ds.attrs['start_time'] = start_time ds.attrs['end_time'] = end_time init_time = None forecast_time = None coords = {} try: # if start_time/end_time is in integer milliseconds init_time = (int(start_time.timestamp()) if start_time is not None else int(ds.attrs[tif_metadata_for_start_time]) / 1000.0) coords['time'] = datetime.datetime.utcfromtimestamp(init_time) if tif_metadata_for_end_time: forecast_time = (int(end_time.timestamp()) if end_time is not None else int(ds.attrs[tif_metadata_for_end_time]) / 1000.0) coords['valid_time'] = datetime.datetime.utcfromtimestamp(forecast_time) ds = ds.assign_coords(coords) except KeyError as e: raise RuntimeError(f"Invalid datetime metadata of tif: {e}.") except ValueError: try: # if start_time/end_time is in UTC string format init_time = (int(start_time.timestamp()) if start_time is not None else datetime.datetime.strptime(ds.attrs[tif_metadata_for_start_time], '%Y-%m-%dT%H:%M:%SZ')) coords['time'] = init_time if tif_metadata_for_end_time: forecast_time = (int(end_time.timestamp()) if end_time is not None else datetime.datetime.strptime(ds.attrs[tif_metadata_for_end_time], '%Y-%m-%dT%H:%M:%SZ')) 
coords['valid_time'] = forecast_time ds = ds.assign_coords(coords) except ValueError as e: raise RuntimeError(f"Invalid datetime value in tif's metadata: {e}.") return ds def _to_utc_timestring(np_time: np.datetime64) -> str: """Turn a numpy datetime64 into UTC timestring.""" timestamp = float((np_time - np.datetime64(0, 's')) / np.timedelta64(1, 's')) return datetime.datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ') def _add_is_normalized_attr(ds: xr.Dataset, value: bool) -> xr.Dataset: """Adds is_normalized to the attrs of the xarray.Dataset. This attribute represents if the dataset is the merged dataset (i.e. created by combining N datasets, specifically for normalizing grib's schema) or not. """ ds.attrs['is_normalized'] = value return ds def _is_3d_da(da): """Checks whether data array is 3d or not.""" return len(da.shape) == 3 def __normalize_grib_dataset(filename: str, group_common_hypercubes: t.Optional[bool] = False) -> t.Union[xr.Dataset, t.List[xr.Dataset]]: """Reads a list of datasets and merge them into a single dataset.""" _level_data_dict = {} list_ds = cfgrib.open_datasets(filename) ds_attrs = list_ds[0].attrs dv_units_dict = {} for ds in list_ds: coords_set = set(ds.coords.keys()) level_set = coords_set.difference(DEFAULT_COORD_KEYS) level = level_set.pop() # Now look at what data vars are in each level. for key in ds.data_vars.keys(): da = ds[key] # The data array attrs = da.attrs # The metadata for this dataset. # Also figure out the forecast hour for this file. forecast_hour = int(da.step.values / np.timedelta64(1, 'h')) # We are going to treat the time field as start_time and the # valid_time field as the end_time for EE purposes. Also, get the # times into UTC timestrings. start_time = _to_utc_timestring(da.time.values) end_time = _to_utc_timestring(da.valid_time.values) attrs['forecast_hour'] = forecast_hour # Stick the forecast hour in the metadata as well, that's useful. attrs['start_time'] = start_time attrs['end_time'] = end_time if group_common_hypercubes: attrs['level'] = level # Adding the level in the metadata, will remove in further steps. attrs['is_normalized'] = True # Adding the 'is_normalized' attribute in the metadata. if level not in _level_data_dict: _level_data_dict[level] = [] no_of_levels = da.shape[0] if _is_3d_da(da) else 1 # Deal with the randomness that is 3d data interspersed with 2d. # For 3d data, we need to extract ds for each value of level. for sub_c in range(no_of_levels): copied_da = da.copy(deep=True) height = copied_da.coords[level].data.flatten()[sub_c] # Some heights are super small, but we can't have decimal points # in channel names & schema fields for Earth Engine & BigQuery respectively , so mostly cut off the # fractional part, unless we are forced to keep it. If so, # replace the decimal point with yet another underscore. if height >= 10: height_string = f'{height:.0f}' else: height_string = f'{height:.2f}'.replace('.', '_') channel_name = f'{level}_{height_string}_{attrs["GRIB_stepType"]}_{key}' logger.debug('Found channel %s', channel_name) # Add the height as a metadata field, that seems useful. copied_da.attrs['height'] = height_string # Add the units of each band as a metadata field. 
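                # e.g. a 500 hPa temperature band could yield a key like
                # 'unit_isobaricInhPa_500_instant_t' (names here are illustrative).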
dv_units_dict['unit_'+channel_name] = None if 'units' in attrs: dv_units_dict['unit_'+channel_name] = attrs['units'] copied_da.name = channel_name if _is_3d_da(da): copied_da = copied_da.sel({level: height}) copied_da = copied_da.drop_vars(level) _level_data_dict[level].append(copied_da) _data_array_list = [] _data_array_list = [xr.merge(list_da) for list_da in _level_data_dict.values()] if not group_common_hypercubes: # Stick the forecast hour, start_time, end_time, data variables units # in the ds attrs as well, that's useful. ds_attrs['forecast_hour'] = _data_array_list[0].attrs['forecast_hour'] ds_attrs['start_time'] = _data_array_list[0].attrs['start_time'] ds_attrs['end_time'] = _data_array_list[0].attrs['end_time'] ds_attrs.update(**dv_units_dict) merged_dataset = xr.merge(_data_array_list) merged_dataset.attrs.clear() merged_dataset.attrs.update(ds_attrs) return merged_dataset return _data_array_list def __open_dataset_file(filename: str, uri_extension: str, disable_grib_schema_normalization: bool, open_dataset_kwargs: t.Optional[t.Dict] = None, group_common_hypercubes: t.Optional[bool] = False) -> t.Union[xr.Dataset, t.List[xr.Dataset]]: """Opens the dataset at 'uri' and returns a xarray.Dataset.""" # add a flag to group common hypercubes if group_common_hypercubes: return __normalize_grib_dataset(filename, group_common_hypercubes) if open_dataset_kwargs: return _add_is_normalized_attr(xr.open_dataset(filename, **open_dataset_kwargs), False) # If URI extension is .tif, try opening file by specifying engine="rasterio". if uri_extension in ['.tif', '.tiff']: return _add_is_normalized_attr(rioxarray.open_rasterio(filename, band_as_variable=True), False) # If no open kwargs are available and URI extension is other than tif, make educated guesses about the dataset. try: return _add_is_normalized_attr(xr.open_dataset(filename), False) except ValueError as e: e_str = str(e) if not ("Consider explicitly selecting one of the installed engines" in e_str and "cfgrib" in e_str): raise if not disable_grib_schema_normalization: logger.warning("Assuming grib.") logger.info("Normalizing the grib schema, name of the data variables will look like " "'___'.") return _add_is_normalized_attr(__normalize_grib_dataset(filename), True) # Trying with explicit engine for cfgrib. try: return _add_is_normalized_attr( xr.open_dataset(filename, engine='cfgrib', backend_kwargs={'indexpath': ''}), False) except ValueError as e: if "multiple values for key 'edition'" not in str(e): raise logger.warning("Assuming grib edition 1.") # Try with edition 1 # Note: picking edition 1 for now as it seems to get the most data/variables for ECMWF realtime data. 
return _add_is_normalized_attr( xr.open_dataset(filename, engine='cfgrib', backend_kwargs={'filter_by_keys': {'edition': 1}, 'indexpath': ''}), False) @retry.with_exponential_backoff( num_retries=NUM_RETRIES, logger=logger.warning, initial_delay_secs=INITIAL_DELAY, max_delay_secs=MAX_DELAY ) def copy(src: str, dst: str) -> None: """Copy data via `gsutil` or local filesystem.""" is_gs = src.startswith("gs://") or dst.startswith("gs://") try: if is_gs: subprocess.run(['gcloud', 'storage', 'cp', src, dst], check=True, capture_output=True, text=True, input="n/n") else: os.makedirs(os.path.dirname(dst) or '.', exist_ok=True) shutil.copy(src, dst) except Exception as e: error_detail = getattr(e, "stderr", str(e)).strip() msg = f"Failed to copy {src!r} to {dst!r} due to {error_detail}" logger.error(msg) raise EnvironmentError(msg) from e @contextlib.contextmanager def open_local(uri: str) -> t.Iterator[str]: """Copy a cloud object (e.g. a netcdf, grib, or tif file) from cloud storage, like GCS, to local file.""" with tempfile.NamedTemporaryFile() as dest_file: # Transfer data with gsutil. copy(uri, dest_file.name) # Check if data is compressed. Decompress the data using the same methods that beam's # FileSystems interface uses. compression_type = FileSystem._get_compression_type(uri, CompressionTypes.AUTO) if compression_type == CompressionTypes.UNCOMPRESSED: yield dest_file.name return dest_file.seek(0) with tempfile.NamedTemporaryFile() as dest_uncompressed: with CompressedFile(open(dest_file.name, 'rb'), compression_type=compression_type) as dcomp: shutil.copyfileobj(dcomp, dest_uncompressed, DEFAULT_READ_BUFFER_SIZE) dest_uncompressed.seek(0) # Reposition the file pointer to the start. yield dest_uncompressed.name @contextlib.contextmanager def open_dataset(uri: str, open_dataset_kwargs: t.Optional[t.Dict] = None, disable_grib_schema_normalization: bool = False, tif_metadata_for_start_time: t.Optional[str] = None, tif_metadata_for_end_time: t.Optional[str] = None, initialization_time_regex: t.Optional[str] = None, forecast_time_regex: t.Optional[str] = None, group_common_hypercubes: t.Optional[bool] = False, is_zarr: bool = False) -> t.Iterator[xr.Dataset]: """Open the dataset at 'uri' and return a xarray.Dataset.""" try: local_open_dataset_kwargs = start_date = end_date = None if open_dataset_kwargs is not None: local_open_dataset_kwargs = open_dataset_kwargs.copy() start_date = local_open_dataset_kwargs.pop('start_date', None) end_date = local_open_dataset_kwargs.pop('end_date', None) if is_zarr: ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr', **local_open_dataset_kwargs), False) if start_date is not None and end_date is not None: ds = ds.sel(time=slice(start_date, end_date)) beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc() yield ds ds.close() return with open_local(uri) as local_path: _, uri_extension = os.path.splitext(uri) xr_datasets: xr.Dataset = __open_dataset_file(local_path, uri_extension, disable_grib_schema_normalization, local_open_dataset_kwargs, group_common_hypercubes) # Extracting dtype, crs and transform from the dataset. 
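            # (These come from rasterio's dataset profile and are attached to the
            # dataset attrs further below; a RasterioIOError is treated as non-fatal.)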
rasterio_error = False try: with rasterio.open(local_path, 'r') as f: dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform']) except rasterio.errors.RasterioIOError: rasterio_error = True logger.warning('Cannot parse projection and data type information for Dataset %r.', uri) if group_common_hypercubes: total_size_in_bytes = 0 for xr_dataset in xr_datasets: if not rasterio_error: xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform}) total_size_in_bytes += xr_dataset.nbytes logger.info(f'opened dataset size: {total_size_in_bytes}') else: xr_dataset = xr_datasets if start_date is not None and end_date is not None: xr_dataset = xr_datasets.sel(time=slice(start_date, end_date)) if uri_extension in ['.tif', '.tiff']: xr_dataset = _preprocess_tif(xr_dataset, tif_metadata_for_start_time, tif_metadata_for_end_time, uri, initialization_time_regex, forecast_time_regex) if not rasterio_error: # Extracting dtype, crs and transform from the dataset & storing them as attributes. xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform}) logger.info(f'opened dataset size: {xr_dataset.nbytes}') beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc() yield xr_datasets if group_common_hypercubes else xr_dataset # Releasing any resources linked to the object(s). if group_common_hypercubes: for xr_dataset in xr_datasets: xr_dataset.close() else: xr_dataset.close() except Exception as e: beam.metrics.Metrics.counter('Failure', 'ReadNetcdfData').inc() logger.error(f'Unable to open file {uri!r}: {e}') raise def get_file_time(element: t.Any) -> int: """Calculates element file's write timestamp in UTC.""" try: element_parsed = urlparse(element) if element_parsed.scheme == "gs": # For file in Google cloud storage. client = storage.Client() bucket = client.get_bucket(element_parsed.netloc) blob = bucket.get_blob(element_parsed.path[1:]) updated_time = int(blob.updated.timestamp()) else: # For file in local. file_stats = os.stat(element) updated_time = int(time.mktime(time.gmtime(file_stats.st_mtime))) return updated_time except Exception as e: raise ValueError( f"Error fetching raw data file bucket time for {element!r}. Error: {str(e)}" ) ================================================ FILE: weather_mv/loader_pipeline/sinks_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import contextlib import datetime from functools import wraps import numpy as np import os import tempfile import tracemalloc import unittest import xarray as xr import weather_mv from .sinks import match_datetime, open_dataset class TestDataBase(unittest.TestCase): def setUp(self) -> None: self.test_data_folder = f'{next(iter(weather_mv.__path__))}/test_data' def _handle_missing_grib_be(f): @wraps(f) def decorated(*args, **kwargs): try: return f(*args, **kwargs) except ValueError as e: # Some setups may not have Cfgrib installed properly. Ignore tests for these cases. 
e_str = str(e) if "Consider explicitly selecting one of the installed engines" not in e_str or "cfgrib" in e_str: raise return decorated @contextlib.contextmanager def limit_memory(max_memory=30): ''''Measure memory consumption of the function. 'memory limit' in MB ''' try: tracemalloc.start() yield finally: current, peak = tracemalloc.get_traced_memory() tracemalloc.stop() assert peak / 1024 ** 2 <= max_memory, f"Memory usage {peak / 1024 ** 2} exceeded {max_memory} MB limit." @contextlib.contextmanager def write_netcdf(): """Generates temporary netCDF file using xarray.""" lat_dim = 3210 lon_dim = 3440 lat = np.linspace(-90, 90, lat_dim) lon = np.linspace(-180, 180, lon_dim) data_arr = np.random.uniform(low=0, high=0.1, size=(5, lat_dim, lon_dim)) ds = xr.Dataset( {"var_1": (('time', 'lat', 'lon'), data_arr)}, coords={ "lat": lat, "lon": lon, }) with tempfile.NamedTemporaryFile() as fp: ds.to_netcdf(fp.name) yield fp.name class OpenDatasetTest(TestDataBase): def setUp(self) -> None: super().setUp() self.test_data_path = os.path.join(self.test_data_folder, 'test_data_20180101.nc') self.test_grib_path = os.path.join(self.test_data_folder, 'test_data_grib_single_timestep') self.test_tif_path = os.path.join(self.test_data_folder, 'test_data_tif_time.tif') self.test_zarr_path = os.path.join(self.test_data_folder, 'test_data.zarr') def test_opens_grib_files(self): with open_dataset(self.test_grib_path) as ds1: self.assertIsNotNone(ds1) self.assertDictContainsSubset({'is_normalized': True}, ds1.attrs) with open_dataset(self.test_grib_path, disable_grib_schema_normalization=True) as ds2: self.assertIsNotNone(ds2) self.assertDictContainsSubset({'is_normalized': False}, ds2.attrs) def test_accepts_xarray_kwargs(self): with open_dataset(self.test_data_path) as ds1: self.assertIn('d2m', ds1) self.assertDictContainsSubset({'is_normalized': False}, ds1.attrs) with open_dataset(self.test_data_path, {'drop_variables': 'd2m'}) as ds2: self.assertNotIn('d2m', ds2) self.assertDictContainsSubset({'is_normalized': False}, ds2.attrs) def test_opens_tif_files(self): with open_dataset(self.test_tif_path, tif_metadata_for_start_time='start_time', tif_metadata_for_end_time='end_time') as ds: self.assertIsNotNone(ds) self.assertDictContainsSubset({'is_normalized': False}, ds.attrs) def test_opens_zarr(self): with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds: self.assertIsNotNone(ds) self.assertEqual(list(ds.data_vars), ['cape', 'd2m']) def test_open_dataset__fits_memory_bounds(self): with write_netcdf() as test_netcdf_path: with limit_memory(max_memory=30): with open_dataset(test_netcdf_path) as _: pass def test_group_common_hypercubes(self): with open_dataset(self.test_grib_path, group_common_hypercubes=True) as ds: self.assertEqual(isinstance(ds, list), True) class DatetimeTest(unittest.TestCase): def test_datetime_regex_string(self): file_name = '3B-HHR-E_MS_MRG_3IMERG_20220901-S000000-E002959_0000_V06C_30min.tiff' regex_str = '3B-HHR-E_MS_MRG_3IMERG_%Y%m%d-S%H%M%S-*.tiff' expected = datetime.datetime.strptime('2022-09-01 00:00:00', '%Y-%m-%d %H:%M:%S') actual = match_datetime(file_name, regex_str) self.assertEqual(actual, expected) def test_datetime_regex_string_with_missing_parameters(self): file_name = '3B-HHR-E_MS_MRG_3IMERG_0901-S000000-E002959_0000_V06C_30min.tiff' regex_str = '3B-HHR-E_MS_MRG_3IMERG_%m%d-S%H%M%S-*.tiff' expected = datetime.datetime.strptime('1978-09-01 00:00:00', '%Y-%m-%d %H:%M:%S') actual = match_datetime(file_name, regex_str) self.assertEqual(actual, 
expected) def test_datetime_regex_string_with_different_order(self): file_name = '3B-HHR-E_MS_MRG_3IMERG_09012022-S000000-E002959_0000_V06C_30min.tiff' regex_str = '3B-HHR-E_MS_MRG_3IMERG_%m%d%Y-S%H%M%S-*.tiff' expected = datetime.datetime.strptime('2022-09-01 00:00:00', '%Y-%m-%d %H:%M:%S') actual = match_datetime(file_name, regex_str) self.assertEqual(actual, expected) if __name__ == '__main__': unittest.main() ================================================ FILE: weather_mv/loader_pipeline/streaming.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Window and parse Pub/Sub streams of real-time weather data added to cloud storage. Example windowing code borrowed from: https://cloud.google.com/pubsub/docs/pubsub-dataflow#code_sample """ import datetime import fnmatch import json import logging import random import typing as t from urllib.parse import urlparse import apache_beam as beam from apache_beam.transforms.window import FixedWindows logger = logging.getLogger(__name__) class GroupMessagesByFixedWindows(beam.PTransform): """A composite transform that groups Pub/Sub messages based on publish time and outputs a list of tuples, each containing a message and its publish time. """ def __init__(self, window_size: int, num_shards: int = 5): # Set window size to 60 seconds. self.window_size = int(window_size * 60) self.num_shards = num_shards def expand(self, pcoll): return ( pcoll # Bind window info to each element using element timestamp (or publish time). | "Window into fixed intervals" >> beam.WindowInto(FixedWindows(self.window_size)) | "Add timestamp to windowed elements" >> beam.ParDo(AddTimestamp()) # Assign a random key to each windowed element based on the number of shards. | "Add key" >> beam.WithKeys(lambda _: random.randint(0, self.num_shards - 1)) # Group windowed elements by key. All the elements in the same window must fit # memory for this. If not, you need to use `beam.util.BatchElements`. | "Group by key" >> beam.GroupByKey() ) class AddTimestamp(beam.DoFn): """Processes each windowed element by extracting the message body and its publish time into a tuple. 
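    For example, a GCS OBJECT_FINALIZE notification roughly yields
    ('{"bucket": "my-bucket", "name": "tmp/data.nc"}', '2021-10-27 20:29:13.152000')
    (values here are illustrative).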
""" def process(self, element, publish_time=beam.DoFn.TimestampParam) -> t.Iterable[t.Tuple[str, str]]: yield ( element.decode("utf-8"), datetime.datetime.utcfromtimestamp(float(publish_time)).strftime( "%Y-%m-%d %H:%M:%S.%f" ), ) class ParsePaths(beam.DoFn): """Parse paths to real-time weather data from windowed-batches.""" def __init__(self, uri_pattern: str): self.uri_pattern = uri_pattern self.protocol = f'{urlparse(uri_pattern).scheme}://' super().__init__() @classmethod def try_parse_message(cls, message_body: t.Union[str, t.Dict]) -> t.Dict: """Robustly parse message body, which will be JSON in the vast majority of cases, but might be a dictionary.""" try: return json.loads(message_body) except (json.JSONDecodeError, TypeError): if isinstance(message_body, dict): return message_body raise def to_object_path(self, payload: t.Dict) -> str: """Parse cloud object from Pub/Sub topic payload.""" return f'{self.protocol}{payload["bucket"]}/{payload["name"]}' def should_skip(self, message_body: t.Dict) -> bool: """Returns true if Pub/Sub topic does *not* match the target file URI pattern.""" try: return not fnmatch.fnmatch(self.to_object_path(message_body), self.uri_pattern) except KeyError: return True def process(self, key_value, window=beam.DoFn.WindowParam) -> t.Iterable[str]: """Yield paths to real-time weather data in cloud storage.""" shard_id, batch = key_value logger.debug(f'Processing shard {shard_id!r}.') for message_body, publish_time in batch: logger.debug(message_body) parsed_msg = self.try_parse_message(message_body) target = self.to_object_path(parsed_msg) logger.info(f'Parsed path {target!r}...') if self.should_skip(parsed_msg): logger.info(f'skipping {target!r}.') continue yield target ================================================ FILE: weather_mv/loader_pipeline/streaming_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import json import unittest from .streaming import ParsePaths class ParsePathsTests(unittest.TestCase): def setUp(self) -> None: self.parser = ParsePaths('gs://XXXX/tmp/*') self.test_input = """{"bucket": "XXXX", "name": "tmp/T1D10091200101309001"}""" media_link = 'https://www.googleapis.com/download/storage/v1/b/XXXX/o/tmp%2FT1D10091200101309001?generation' \ '=1635366553038121&alt=media' self.real_input = f""" {{ "kind": "storage#object", "id": "XXXX/tmp/T1D10091200101309001/1635366553038121", "selfLink": "https://www.googleapis.com/storage/v1/b/XXXX/o/tmp%2FT1D10091200101309001", "name": "tmp/T1D10091200101309001", "bucket": "XXXX", "generation": "1635366553038121", "metageneration": "1", "contentType": "application/octet-stream", "timeCreated": "2021-10-27T20:29:13.152Z", "updated": "2021-10-27T20:29:13.152Z", "storageClass": "STANDARD", "timeStorageClassUpdated": "2021-10-27T20:29:13.152Z", "size": "9725508", "md5Hash": "qrMcuK4nTr9uCD7aJJqtkA==", "mediaLink": "{media_link}", "crc32c": "zKlm3w==", "etag": "CKna4pO36/MCEAE=" }}""" def test_parse_message(self): actual = ParsePaths.try_parse_message(self.test_input) self.assertEqual(actual, {'bucket': 'XXXX', 'name': 'tmp/T1D10091200101309001'}) def test_parse_message__already_is_dict(self): actual = ParsePaths.try_parse_message({'bucket': 'XXXX', 'name': 'tmp/T1D10091200101309001'}) self.assertEqual(actual, {'bucket': 'XXXX', 'name': 'tmp/T1D10091200101309001'}) def test_parse_message__bad_json(self): with self.assertRaises(json.JSONDecodeError): ParsePaths.try_parse_message("""{"foo": 1, "bar": 2""") def test_parse_message__round_trip(self): parsed = ParsePaths.try_parse_message(self.real_input) converted = json.dumps(parsed) re_parsed = ParsePaths.try_parse_message(converted) self.assertEqual(parsed, re_parsed) def test_should_skip(self): parsed = self.parser.try_parse_message(self.real_input) self.assertFalse(self.parser.should_skip(parsed)) def test_should_skip__missing_values(self): missing_field = """{"name": "tmp/T1D10091200101309001"}""" parsed = self.parser.try_parse_message(missing_field) self.assertTrue(self.parser.should_skip(parsed)) def test_should_skip__mismatch_pattern(self): self.parser = ParsePaths('gs://XXXX/foo/*') parsed = self.parser.try_parse_message(self.test_input) self.assertTrue(self.parser.should_skip(parsed)) ================================================ FILE: weather_mv/loader_pipeline/util.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import abc import datetime import inspect import itertools import json import logging import math import re import signal import sys import tempfile import time import traceback import typing as t import uuid from functools import partial from urllib.parse import urlparse import apache_beam as beam import numpy as np import pandas as pd import xarray as xr from google.api_core.exceptions import BadRequest from google.api_core.exceptions import NotFound from google.cloud import bigquery, storage from xarray.core.utils import ensure_us_time_resolution from .sinks import DEFAULT_COORD_KEYS from .metrics import timeit logger = logging.getLogger(__name__) CANARY_BUCKET_NAME = 'anthromet_canary_bucket' CANARY_RECORD = {'foo': 'bar'} CANARY_RECORD_FILE_NAME = 'canary_record.json' CANARY_OUTPUT_TABLE_SUFFIX = '_anthromet_canary_table' CANARY_TABLE_SCHEMA = [bigquery.SchemaField('name', 'STRING', mode='NULLABLE')] BQ_EXCLUDE_COORDS = {'longitude', 'latitude'} def make_attrs_ee_compatible(attrs: t.Dict) -> t.Dict: """Scans EEE asset attributes and makes them EE compatible. EE asset attribute names must contain only the following characters: A..Z, a..z, 0..9 or '_'. Maximum length is 110 characters. Attribute values must be string or number. If an attribute name is more than 110 characters, it will consider the first 110 characters as the attribute name. """ new_attrs = {} for k, v in attrs.items(): if len(k) > 110: # Truncate attribute name with > 110 characters. k = k[:110] # Replace unaccepted characters with underscores. k = re.sub(r'[^a-zA-Z0-9-_]+', r'_', k) if type(v) not in [int, float]: v = str(v) if len(v) > 1024: v = f'{v[:1021]}...' # Since 1 char = 1 byte. v = to_json_serializable_type(v) new_attrs[k] = v return new_attrs # TODO(#245): Group with common utilities (duplicated) def to_json_serializable_type(value: t.Any) -> t.Any: """Returns the value with a type serializable to JSON""" # Note: The order of processing is significant. logger.debug('Serializing to JSON') # pd.isna() returns ndarray if input is not scalar therefore checking if value is scalar. if (np.isscalar(value) and pd.isna(value)) or value is None: return None elif np.issubdtype(type(value), np.floating): return float(value) elif isinstance(value, set): value = list(value) return np.where(pd.isna(value), None, value).tolist() elif isinstance(value, np.ndarray): # Will return a scaler if array is of size 1, else will return a list. # Replace all NaNs, NaTs with None. return np.where(pd.isna(value), None, value).tolist() elif isinstance(value, datetime.datetime) or isinstance(value, str) or isinstance(value, np.datetime64): # Assume strings are ISO format timestamps... try: value = datetime.datetime.fromisoformat(value) except ValueError: # ... if they are not, assume serialization is already correct. return value except TypeError: # ... maybe value is a numpy datetime ... try: value = ensure_us_time_resolution(value).astype(datetime.datetime) except AttributeError: # ... value is a datetime object, continue. pass # We use a string timestamp representation. if value.tzname(): return value.isoformat() # We assume here that naive timestamps are in UTC timezone. return value.replace(tzinfo=datetime.timezone.utc).isoformat() elif isinstance(value, datetime.timedelta): return value.total_seconds() elif isinstance(value, np.timedelta64): # Return time delta in seconds. return float(value / np.timedelta64(1, 's')) # This check must happen after processing np.timedelta64 and np.datetime64. 
elif np.issubdtype(type(value), np.integer): return int(value) return value def _check_for_coords_vars(ds_data_var: str, target_var: str) -> bool: """Checks if the dataset's data variable matches with the target variables (or coordinates) specified by the user.""" return ds_data_var.endswith('_'+target_var) or ds_data_var.startswith(target_var+'_') def get_utc_timestamp() -> float: """Returns the current UTC Timestamp.""" return datetime.datetime.now().timestamp() def _only_target_coordinate_vars(ds: xr.Dataset, data_vars: t.List[str]) -> t.List[str]: """If the user specifies target fields in the dataset, get all the matching coords & data vars.""" # If the dataset is not the merged dataset (created for normalizing grib's schema), # return the target fields specified by the user as it is. if not ds.attrs['is_normalized']: return data_vars keep_coords_vars = [] for dv in data_vars: keep_coords_vars.extend([v for v in ds.data_vars if _check_for_coords_vars(v, dv)]) keep_coords_vars.extend([v for v in DEFAULT_COORD_KEYS if v in dv]) return keep_coords_vars def _only_target_vars(ds: xr.Dataset, data_vars: t.Optional[t.List[str]] = None) -> xr.Dataset: """If the user specifies target fields in the dataset, create a schema only from those fields.""" # If there are no restrictions on data vars, include the whole dataset. if not data_vars: logger.info(f'target data_vars empty; using whole dataset; size: {ds.nbytes}') return ds if not ds.attrs['is_normalized']: assert all([dv in ds.data_vars or dv in ds.coords for dv in data_vars]), 'Target variable must be in original '\ 'dataset. ' dropped_ds = ds.drop_vars([v for v in ds.data_vars if v not in data_vars]) else: drop_vars = [] check_target_variable = [] for dv in data_vars: searched_data_vars = [_check_for_coords_vars(v, dv) for v in ds.data_vars] searched_coords = [] if dv not in DEFAULT_COORD_KEYS else [dv in ds.coords] check_target_variable.append(any(searched_data_vars) or any(searched_coords)) assert all(check_target_variable), 'Target variable must be in original dataset.' for v in ds.data_vars: searched_data_vars = [_check_for_coords_vars(v, dv) for dv in data_vars] if not any(searched_data_vars): drop_vars.append(v) dropped_ds = ds.drop_vars(drop_vars) logger.info(f'target-only dataset size: {dropped_ds.nbytes}') return dropped_ds def ichunked(iterable: t.Iterable, n: int) -> t.Iterator[t.Iterable]: """Yield evenly-sized chunks from an iterable.""" input_ = iter(iterable) try: while True: it = itertools.islice(input_, n) # peek to check if 'it' has next item. first = next(it) yield itertools.chain([first], it) except StopIteration: pass def get_coordinates(ds: xr.Dataset, uri: str = '') -> t.Iterator[t.Dict]: """Generates normalized coordinate dictionaries that can be used to index Datasets with `.loc[]`.""" # Creates flattened iterator of all coordinate positions in the Dataset. # # Example: (datetime.datetime('2018-01-02T22:00:00+00:00')) # Filter out excluded coordinates from coords_indexes. filtered_coords_indexes = [c for c in ds.coords.indexes if c not in BQ_EXCLUDE_COORDS] # Filter out excluded coordinates from coords_dims. filtered_coords_dims = {c: ds.coords.dims [c] for c in ds.coords.dims if c not in BQ_EXCLUDE_COORDS} coords = itertools.product( *( ( v for v in ensure_us_time_resolution(ds[c].variable.values).tolist() ) for c in filtered_coords_indexes ) ) # Give dictionary keys to a coordinate index. 
# # Example: # {'time': datetime.datetime('2018-01-02T23:00:00+00:00')} idx = 0 total_coords = math.prod(filtered_coords_dims.values()) for idx, it in enumerate(coords): logger.info(f'Processed {idx+1} / {total_coords} coordinates for {uri!r}...') yield dict(zip(filtered_coords_indexes, it)) logger.info(f'Finished processing all {idx+1} coordinates.') def _cleanup_bigquery(bigquery_client: bigquery.Client, canary_output_table: str, sig: t.Optional[t.Any] = None, frame: t.Optional[t.Any] = None) -> None: """Deletes the bigquery table.""" if bigquery_client: bigquery_client.delete_table(canary_output_table, not_found_ok=True) if sig: traceback.print_stack(frame) sys.exit(0) def _cleanup_bucket(storage_client: storage.Client, canary_bucket_name: str, sig: t.Optional[t.Any] = None, frame: t.Optional[t.Any] = None) -> None: """Deletes the bucket.""" try: storage_client.get_bucket(canary_bucket_name).delete(force=True) except NotFound: pass if sig: traceback.print_stack(frame) sys.exit(0) def validate_region(output_table: t.Optional[str] = None, temp_location: t.Optional[str] = None, region: t.Optional[str] = None) -> None: """Validates non-compatible regions scenarios by performing sanity check.""" if not region and not temp_location: raise ValueError('Invalid GCS location: None.') bucket_region = region storage_client = storage.Client() canary_bucket_name = CANARY_BUCKET_NAME + str(uuid.uuid4()) # Doing cleanup if operation get cut off midway. # TODO : Should we handle some other signals ? do_bucket_cleanup = partial(_cleanup_bucket, storage_client, canary_bucket_name) original_sigint_handler = signal.getsignal(signal.SIGINT) original_sigtstp_handler = signal.getsignal(signal.SIGTSTP) signal.signal(signal.SIGINT, do_bucket_cleanup) signal.signal(signal.SIGTSTP, do_bucket_cleanup) if output_table: table_region = None bigquery_client = bigquery.Client() canary_output_table = output_table + CANARY_OUTPUT_TABLE_SUFFIX + str(uuid.uuid4()) do_bigquery_cleanup = partial(_cleanup_bigquery, bigquery_client, canary_output_table) signal.signal(signal.SIGINT, do_bigquery_cleanup) signal.signal(signal.SIGTSTP, do_bigquery_cleanup) if temp_location: parsed_temp_location = urlparse(temp_location) if parsed_temp_location.scheme != 'gs' or parsed_temp_location.netloc == '': raise ValueError(f'Invalid GCS location: {temp_location!r}.') bucket_name = parsed_temp_location.netloc bucket_region = storage_client.get_bucket(bucket_name).location try: bucket = storage_client.create_bucket(canary_bucket_name, location=bucket_region) with tempfile.NamedTemporaryFile(mode='w+') as temp: json.dump(CANARY_RECORD, temp) temp.flush() blob = bucket.blob(CANARY_RECORD_FILE_NAME) blob.upload_from_filename(temp.name) if output_table: table = bigquery.Table(canary_output_table, schema=CANARY_TABLE_SCHEMA) table = bigquery_client.create_table(table, exists_ok=True) table_region = table.location load_job = bigquery_client.load_table_from_uri( f'gs://{canary_bucket_name}/{CANARY_RECORD_FILE_NAME}', canary_output_table, ) load_job.result() except BadRequest: if output_table: raise RuntimeError(f'Can\'t migrate from source: {bucket_region} to destination: {table_region}') raise RuntimeError(f'Can\'t upload to destination: {bucket_region}') finally: _cleanup_bucket(storage_client, canary_bucket_name) if output_table: _cleanup_bigquery(bigquery_client, canary_output_table) signal.signal(signal.SIGINT, original_sigint_handler) signal.signal(signal.SIGINT, original_sigtstp_handler) def _shard(elem, num_shards: int): return 
(np.random.randint(0, num_shards), elem) class Shard(beam.DoFn): """DoFn to shard elements into groups.""" def __init__(self, num_shards: int, use_metrics: bool): super().__init__() self.num_shards = num_shards self.use_metrics = use_metrics @timeit('Sharding', keyed_fn=True) def process(self, element, *args, **kwargs): yield _shard(element, num_shards=self.num_shards) class RateLimit(beam.PTransform, abc.ABC): """PTransform to extend to apply a global rate limit to an operation. The input PCollection and be of any type and the output will be whatever is returned by the `process` method. """ def __init__(self, global_rate_limit_qps: int, latency_per_request: float, max_concurrent_requests: int, use_metrics: bool): """Creates a RateLimit object. global_rate_limit_qps and latency_per_request are used to determine how the data should be sharded via: global_rate_limit_qps * latency_per_request.total_seconds() For example, global_rate_limit_qps = 500 and latency_per_request=.5 seconds. Then the data will be sharded into 500*.5=250 groups. Each group can be processed in parallel and will call the 'process' function at most once every latency_per_request. It is important to note that the max QPS may not be reach based on how many workers are scheduled. Args: global_rate_limit_qps: QPS to rate limit requests across all workers to. latency_per_request: The expected latency per request. max_concurrent_requests: Maximum allowed concurrent api requests to EE. """ self._rate_limit = global_rate_limit_qps self._latency_per_request = datetime.timedelta(seconds=latency_per_request) self._num_shards = max(1, min(int(self._rate_limit * self._latency_per_request.total_seconds()), max_concurrent_requests)) self.use_metrics = use_metrics @abc.abstractmethod def process(self, elem: t.Any): """Process is the operation that will be rate limited. Results will be yielded each time time the process method is called. Args: elem: The individual element to process. Returns: Output can be anything, output will be the output of the RateLimit PTransform. 
""" pass def expand(self, pcol: beam.PCollection): return (pcol | beam.ParDo(Shard(num_shards=self._num_shards, use_metrics=self.use_metrics)) | beam.GroupByKey() | beam.ParDo( _RateLimitDoFn(self.process, self._latency_per_request))) class _RateLimitDoFn(beam.DoFn): """DoFn that ratelimits calls to rate_limit_fn.""" def __init__(self, rate_limit_fn: t.Callable, wait_time: datetime.timedelta): self._rate_limit_fn = rate_limit_fn self._wait_time = wait_time self._is_generator = inspect.isgeneratorfunction(self._rate_limit_fn) # type: ignore def process(self, keyed_elem: t.Tuple[t.Any, t.Iterable[t.Any]]): shard, elems = keyed_elem logger.info(f'processing shard: {shard}') start_time = datetime.datetime.now() end_time = None for elem in elems: if end_time is not None and (end_time - start_time) < self._wait_time: logger.info(f'previous operation took: {(end_time - start_time).total_seconds()}') wait_time = (self._wait_time - (end_time - start_time)) logger.info(f'wating: {wait_time.total_seconds()}') time.sleep(wait_time.total_seconds()) start_time = datetime.datetime.now() if self._is_generator: yield from self._rate_limit_fn(elem) else: yield self._rate_limit_fn(elem) end_time = datetime.datetime.now() ================================================ FILE: weather_mv/loader_pipeline/util_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools import unittest from collections import Counter from datetime import datetime, timezone, timedelta import xarray import xarray as xr import numpy as np from .sinks_test import TestDataBase from .util import ( get_coordinates, ichunked, make_attrs_ee_compatible, to_json_serializable_type, ) class GetCoordinatesTest(TestDataBase): def setUp(self) -> None: super().setUp() self.test_data_path = f'{self.test_data_folder}/test_data_20180101.nc' def test_gets_indexed_coordinates(self): ds = xr.open_dataset(self.test_data_path) self.assertEqual( next(get_coordinates(ds)), {'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None)} ) def test_no_duplicate_coordinates(self): ds = xr.open_dataset(self.test_data_path) # Assert that all the coordinates are unique. 
counts = Counter([tuple(c.values()) for c in get_coordinates(ds)]) self.assertTrue(all((c == 1 for c in counts.values()))) class IChunksTests(TestDataBase): def setUp(self) -> None: super().setUp() test_data_path = f'{self.test_data_folder}/test_data_20180101.nc' self.items = range(20) self.coords = get_coordinates(xarray.open_dataset(test_data_path), test_data_path) def test_even_chunks(self): actual = [] for chunk in ichunked(self.items, 4): actual.append(list(chunk)) self.assertEqual(actual, [ [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], ]) def test_odd_chunks(self): actual = [] for chunk in ichunked(self.items, 7): actual.append(list(chunk)) self.assertEqual(actual, [ [0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12, 13], [14, 15, 16, 17, 18, 19] ]) def test_get_coordinates(self): actual = [] for chunk in ichunked(itertools.islice(self.coords, 4), 3): actual.append(list(chunk)) self.assertEqual( actual, [ [ { 'time': datetime.fromisoformat('2018-01-02T06:00:00+00:00').replace(tzinfo=None) }, { 'time': datetime.fromisoformat('2018-01-02T07:00:00+00:00').replace(tzinfo=None) }, { 'time': datetime.fromisoformat('2018-01-02T08:00:00+00:00').replace(tzinfo=None) }, ], [ { 'time': datetime.fromisoformat('2018-01-02T09:00:00+00:00').replace(tzinfo=None) } ] ] ) class MakeAttrsEeCompatibleTests(TestDataBase): def test_make_attrs_ee_compatible_a1(self): attrs = { 'int_attr': 48, 'float_attr': 48.48, 'str_attr': '48.48', 'str_long_attr': 'Lorem ipsum dolor sit amet, consectetur ' 'adipiscing elit. Fusce bibendum odio ac lorem tristique, sed ' 'tincidunt orci ultricies. Vivamus eu rhoncus metus. Praesent ' 'vitae imperdiet sapien. Donec vel ipsum sapien. Aliquam ' 'suscipit suscipit turpis, a vehicula neque. Maecenas ' 'hendrerit, mauris eu consequat aliquam, nunc elit lacinia ' 'elit, vel accumsan ipsum ex a tellus. Pellentesque habitant ' 'morbi tristique senectus et netus et malesuada fames ac ' 'turpis egestas. Fusce a felis vel dolor lobortis vestibulum ' 'ac ac velit. Etiam vitae nibh sed justo hendrerit feugiat. ' 'Sed vulputate, turpis eget fringilla euismod, urna magna ' 'consequat turpis, at aliquam metus dolor vel tortor. Sed sit ' 'amet dolor quis libero venenatis porttitor a non odio. Morbi ' 'interdum tellus non neque placerat, vel fermentum turpis ' 'bibendum. In efficitur nunc ac leo eleifend commodo. Maecenas ' 'in tincidunt diam. In consectetur eget sapien a suscipit. ' 'Nulla porttitor ullamcorper tellus sit amet ornare. Aliquam ' 'in nibh at mauris tincidunt bibendum a a elit.', 'bool_attr': True, 'none_attr': None, 'key_long_raesent_id_tincidunt_velit_Integer_eget_sapien_tincidunt_' 'iaculis_nulla_vitae_consectetur_metus_Vestibul': 'long_string' } expected = { 'int_attr': 48, 'float_attr': 48.48, 'str_attr': '48.48', 'str_long_attr': 'Lorem ipsum dolor sit amet, consectetur ' 'adipiscing elit. Fusce bibendum odio ac lorem tristique, sed ' 'tincidunt orci ultricies. Vivamus eu rhoncus metus. Praesent ' 'vitae imperdiet sapien. Donec vel ipsum sapien. Aliquam ' 'suscipit suscipit turpis, a vehicula neque. Maecenas ' 'hendrerit, mauris eu consequat aliquam, nunc elit lacinia ' 'elit, vel accumsan ipsum ex a tellus. Pellentesque habitant ' 'morbi tristique senectus et netus et malesuada fames ac ' 'turpis egestas. Fusce a felis vel dolor lobortis vestibulum ' 'ac ac velit. Etiam vitae nibh sed justo hendrerit feugiat. ' 'Sed vulputate, turpis eget fringilla euismod, urna magna ' 'consequat turpis, at aliquam metus dolor vel tortor. 
Sed sit ' 'amet dolor quis libero venenatis porttitor a non odio. Morbi ' 'interdum tellus non neque placerat, vel fermentum turpis ' 'bibendum. In efficitur nunc ac leo eleifend commodo. Maecenas ' 'in tincidunt diam. In consectetur eget sapien a suscipit. ' 'Nulla porttitor ullamcorper tellus sit amet ornare. Aliquam ' 'in nibh at mauris tincidunt bibendum a a ...', 'bool_attr': 'True', 'none_attr': 'None', 'key_long_raesent_id_tincidunt_velit_Integer_eget_sapien_tincidunt_' 'iaculis_nulla_vitae_consectetur_metus_Vestib': 'long_string' } actual = make_attrs_ee_compatible(attrs) self.assertDictEqual(actual, expected) def test_make_attrs_ee_compatible_a2(self): attrs = { 'list_attr': ['attr1', 'attr1'], 'tuple_attr': ('attr1', 'attr2'), 'dict_attr': { 'attr1': 1, 'attr2': 'two', 'attr3': 3.0, 'attr4': True } } expected = { 'list_attr': "['attr1', 'attr1']", 'tuple_attr': "('attr1', 'attr2')", 'dict_attr': "{'attr1': 1, 'attr2': 'two', 'attr3': 3.0, " "'attr4': True}" } actual = make_attrs_ee_compatible(attrs) self.assertDictEqual(actual, expected) class ToJsonSerializableTypeTests(unittest.TestCase): def _convert(self, value): return to_json_serializable_type(value) def test_to_json_serializable_type_none(self): self.assertIsNone(self._convert(None)) self.assertIsNone(self._convert(float('NaN'))) self.assertIsNone(self._convert(np.NaN)) self.assertIsNone(self._convert(np.datetime64('NaT'))) self.assertIsNotNone(self._convert(np.array([]))) def test_to_json_serializable_type_float(self): self.assertIsInstance(self._convert(np.float32('0.1')), float) self.assertIsInstance(self._convert(np.float32('1')), float) self.assertIsInstance(self._convert(np.float16('0.1')), float) self.assertIsInstance(self._convert(np.single('0.1')), float) self.assertIsInstance(self._convert(np.double('0.1')), float) self.assertNotIsInstance(self._convert(1), float) self.assertNotIsInstance(self._convert(np.csingle('0.1')), float) self.assertNotIsInstance(self._convert(np.cdouble('0.1')), float) self.assertNotIsInstance(self._convert(np.intc('1')), float) def test_to_json_serializable_type_int(self): self.assertIsInstance(self._convert(np.int16('1')), int) self.assertEqual(self._convert(np.int16('1')), int(1)) self.assertEqual(self._convert(np.int32('1')), int(1)) self.assertEqual(self._convert(np.int64('1')), int(1)) self.assertEqual(self._convert(np.int64(-10_000)), -10_000) self.assertEqual(self._convert(np.uint16(25)), 25) def test_to_json_serializable_type_set(self): self.assertEqual(self._convert(set({})), []) self.assertEqual(self._convert(set({1, 2, 3})), [1, 2, 3]) self.assertEqual(self._convert(set({1})), [1]) self.assertEqual(self._convert(set({None})), [None]) self.assertEqual(self._convert(set({float('NaN')})), [None]) def test_to_json_serializable_type_ndarray(self): self.assertIsInstance(self._convert(np.array(list(range(10)))), list) self.assertEqual(self._convert(np.array(list(range(10)))), list(range(10))) self.assertEqual(self._convert(np.array([1])), [1]) self.assertEqual(self._convert(np.array([[1, 2, 3], [4, 5, 6]])), [[1, 2, 3], [4, 5, 6]]) self.assertEqual(self._convert(np.array(1)), 1) def test_to_json_serializable_type_datetime(self): input_date = '2000-01-01T00:00:00+00:00' now = datetime.now() self.assertEqual(self._convert(datetime.fromisoformat(input_date)), input_date) self.assertEqual(self._convert(now), now.replace(tzinfo=timezone.utc).isoformat()) self.assertEqual(self._convert(input_date), input_date) self.assertEqual(self._convert(np.datetime64(input_date)), input_date) 
self.assertEqual(self._convert(np.datetime64(1, 'Y')), '1971-01-01T00:00:00+00:00') self.assertEqual(self._convert(np.datetime64(30, 'Y')), input_date) self.assertEqual(self._convert(np.timedelta64(1, 'm')), float(60)) self.assertEqual(self._convert(timedelta(seconds=1)), float(1)) self.assertEqual(self._convert(timedelta(minutes=1)), float(60)) self.assertEqual(self._convert(timedelta(days=1)), float(86400)) ================================================ FILE: weather_mv/setup.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Setup weather-mv. This setup.py script makes use of Apache Beam's recommended way to install non-python dependencies to worker images. This is employed to enable a portable installation of cfgrib, which requires ecCodes. Please see this documentation and example code: - https://beam.apache.org/documentation/sdks/python-pipeline-dependencies/#nonpython - https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/complete/juliaset/setup.py """ from setuptools import setup, find_packages beam_gcp_requirements = [ "google-cloud-bigquery==2.34.4", "google-cloud-bigquery-storage==2.14.1", "google-cloud-bigtable==1.7.2", "google-cloud-core==1.7.3", "google-cloud-datastore==1.15.5", "google-cloud-dlp==3.8.0", "google-cloud-language==1.3.2", "google-cloud-pubsub==2.13.4", "google-cloud-pubsublite==1.4.2", "google-cloud-recommendations-ai==0.2.0", "google-cloud-spanner==1.19.3", "google-cloud-videointelligence==1.16.3", "google-cloud-vision==1.0.2", "apache-beam[gcp]==2.40.0", ] base_requirements = [ "dataclasses", "numpy==1.22.4", "pandas==1.5.1", "xarray==2023.1.0", "xarray-beam==0.6.2", "cfgrib==0.9.10.2", "netcdf4==1.6.1", "geojson==2.5.0", "simplejson==3.17.6", "rioxarray==0.13.4", "metview==1.13.1", "rasterio==1.3.1", "earthengine-api>=0.1.263", "pyproj==3.4.0", # requires separate binary installation! "gdal==3.5.1", # requires separate binary installation! 
"gcsfs==2022.11.0", "zarr==2.15.0", ] setup( name='loader_pipeline', packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', version='0.2.43', url='https://weather-tools.readthedocs.io/en/latest/weather_mv/', description='A tool to load weather data into BigQuery.', install_requires=beam_gcp_requirements + base_requirements, ) ================================================ FILE: weather_mv/test_data/test_data.zarr/.zattrs ================================================ { "Conventions": "CF-1.6", "history": "2022-12-07 15:15:12 GMT by grib_to_netcdf-2.25.1: /opt/ecmwf/mars-client/bin/grib_to_netcdf.bin -S param -o /cache/data1/adaptor.mars.internal-1670426111.8966048-4371-13-ae423cb0-986b-4dc9-abe5-76478aea4cef.nc /cache/tmp/ae423cb0-986b-4dc9-abe5-76478aea4cef-adaptor.mars.internal-1670426109.7807662-4371-20-tmp.grib" } ================================================ FILE: weather_mv/test_data/test_data.zarr/.zgroup ================================================ { "zarr_format": 2 } ================================================ FILE: weather_mv/test_data/test_data.zarr/.zmetadata ================================================ { "metadata": { ".zattrs": { "Conventions": "CF-1.6", "history": "2022-12-07 15:15:12 GMT by grib_to_netcdf-2.25.1: /opt/ecmwf/mars-client/bin/grib_to_netcdf.bin -S param -o /cache/data1/adaptor.mars.internal-1670426111.8966048-4371-13-ae423cb0-986b-4dc9-abe5-76478aea4cef.nc /cache/tmp/ae423cb0-986b-4dc9-abe5-76478aea4cef-adaptor.mars.internal-1670426109.7807662-4371-20-tmp.grib" }, ".zgroup": { "zarr_format": 2 }, "cape/.zarray": { "chunks": [ 24, 721, 1440 ], "compressor": { "blocksize": 0, "clevel": 5, "cname": "lz4", "id": "blosc", "shuffle": 1 }, "dtype": "/topics/'. Cannot be used with --subscription. * `--subscription`: A Pub/Sub subscription for GCS OBJECT_FINALIZE events, or equivalent, of a cloud bucket. Cannot be used with `--topic`. * `--window_size`: Output file's window size in minutes. Only used with the `topic` flag. Default: 1.0 minute. * `--num_shards`: Number of shards to use when writing windowed elements to cloud storage. Only used with the `topic` flag. Default: 5 shards. Invoke with `-h` or `--help` to see the full range of options. 
_Usage examples_:

```bash
weather-sp --input-pattern 'gs://test-tmp/era5/2017/**' \
  --output-dir 'gs://test-tmp/era5/splits' \
  --formatting '.{typeOfLevel}'
```

Preview splits with a dry run:

```bash
weather-sp --input-pattern 'gs://test-tmp/era5/2017/**' \
  --output-dir 'gs://test-tmp/era5/splits' \
  --formatting '.{typeOfLevel}' \
  --dry-run
```

Using DataflowRunner:

```bash
weather-sp --input-pattern 'gs://test-tmp/era5/2015/**' \
  --output-dir 'gs://test-tmp/era5/splits' --formatting '.{typeOfLevel}' \
  --runner DataflowRunner \
  --project $PROJECT \
  --temp_location gs://$BUCKET/tmp \
  --job_name $JOB_NAME
```

Using DataflowRunner with local code for the pipeline:

```bash
weather-sp --input-pattern 'gs://test-tmp/era5/2015/**' \
  --output-dir 'gs://test-tmp/era5/splits' --formatting '.{typeOfLevel}' \
  --runner DataflowRunner \
  --project $PROJECT \
  --temp_location gs://$BUCKET/tmp \
  --job_name $JOB_NAME \
  --use-local-code
```

Using DataflowRunner with the `--where` flag:

```bash
weather-sp --input-pattern 'gs://test-tmp/2015/*.gb' \
  --output-template 'gs://temp/domain_{domain}/class_{class}/stream_{stream}/expver_{expver}/levtype_{levtype}/date_{date}/time_{time}/step_{step}/{shortName}.gb' \
  --runner DataflowRunner \
  --project $PROJECT \
  --temp_location gs://$BUCKET/tmp \
  --job_name $JOB_NAME \
  --where "typeOfLevel=surface,shortName=lcc"
```

Using ecCodes-powered grib splitting on Dataflow (this is often more robust, especially when splitting multiple dimensions at once):

```bash
weather-sp --input-pattern 'gs://test-tmp/era5/2017/**' \
  --output-dir 'gs://test-tmp/era5/splits' \
  --formatting '.{typeOfLevel}' \
  --runner DataflowRunner \
  --project $PROJECT \
  --temp_location gs://$BUCKET/tmp \
  --experiment=use_runner_v2 \
  --sdk_container_image="gcr.io/$PROJECT/$REPO:latest" \
  --job_name $JOB_NAME
```

### Streaming mode example

Run a streaming pipeline that reads file paths from a Pub/Sub topic and splits the files:

```bash
weather-sp --input-pattern 'gs://bucket/data/**' \
  --topic 'projects/my-project/topics/my-topic' \
  --window_size 5 \
  --num_shards 10 \
  --output-template 'gs://output/grid_0p1/domain_{domain}/class_{class}/{shortName}.gb' \
  --runner DataflowRunner \
  --project $PROJECT \
  --region us-west4 \
  --temp_location gs://$BUCKET/tmp \
  --experiment use_runner_v2 \
  --sdk_container_image="gcr.io/$PROJECT/$REPO:latest" \
  --job_name $JOB_NAME
```

_Consult [this documentation](../Runtime-Container.md) for steps on how to create a sufficient image._
_See the `weather-mv` documentation for more information about the container image._

For a full list of how to configure the Dataflow pipeline, please review
[this table](https://cloud.google.com/dataflow/docs/reference/pipeline-options).

## Specifying input files

The input file pattern matching is done by the Apache Beam filesystems module; see
[this documentation](https://beam.apache.org/releases/pydoc/2.12.0/apache_beam.io.filesystems.html#apache_beam.io.filesystems.FileSystems.match)
under 'pattern syntax'. Notably, `**` is equivalent to `.*`, so to match all files in a directory, use

```bash
--input-pattern 'gs://test-tmp/era5/2017/**'
```

On the other hand, to match a more specific pattern, use

```bash
--input-pattern 'gs://test-tmp/era5/2017/*/*.nc'
```

## Output & Split Dimensions

The base output file names are specified using the `--output-template` or `--output-dir` flags. These flags are mutually exclusive, and one of them is required. \
The output formatting also specifies which dimensions to split by, using Python formatting.
For example, adding `{time}` in the output template or formatting will cause the file to be split along the time dimension.

### Available dimensions to split

How a file can be split depends on the file type.

#### GRIB

GRIB files can be split along any dimension that is available in the file's metadata. \
Examples: 'typeOfLevel', 'level', 'step', 'shortName', 'gridType', 'time', 'forecastTime' \
Any available dimensions can be combined when splitting.

#### NetCDF

NetCDF files are already in a hypercube format and can only be split by one of the dimensions and by data variable. Since splitting by latitude or longitude would lead to a large number of small files, this is not supported, and it is recommended to use the `weather-mv` tool instead. \
Supported splits for NetCDF files are thus 'variable' (to split by data variable) and any dimension other than latitude and longitude.

### Output directory

Based on the output directory path, the directory structure of the input pattern is replicated. To create the directory structure, the common path of the input pattern is removed from the input file path and replaced with the output path. \
The formatting specified by the `--formatting` flag is added between the file name and its ending, and it determines along which dimensions to split.

Example:
```bash
--input-pattern 'gs://test-input/era5/2020/**' \
--output-dir 'gs://test-output/splits' --formatting '.{variable}'
```
For a file `gs://test-input/era5/2020/02/01.nc`, the output file pattern is `gs://test-output/splits/2020/02/01.{variable}.nc`; if temperature is a variable in that data, the output file for that split will be `gs://test-output/splits/2020/02/01.t.nc`.

Example:
```bash
--input-pattern 'gs://test-input/era5/2020/**' \
--output-dir 'gs://test-output/splits' --formatting '_{date}:{time}_{level}hPa'
```
For a file `gs://test-input/era5/2020/02/01.grib`, the output file pattern is `gs://test-output/splits/2020/02/01_{date}:{time}_{level}hPa.grib`; for date = 20200101, time = 1200, level = 400, the output file for that split will be `gs://test-output/splits/2020/02/01_20200101:1200_400hPa.grib`.

### Output template with Python-style formatting

Using Python-style substitution (e.g. `{1}`) allows for more flexibility when creating the output files. The substitutions are based on the directory structure of the input file, where each `{}` stands for one directory name, counting backwards from the end, i.e. the file name is `{0}`, the immediate directory in which it is located is `{1}`, and so on. In addition, you need to supply the split dimensions in the output template; these will be filled by values found in each file. A minimal sketch of this substitution is shown after the dry-run notes below.

Example:
```bash
--input-pattern 'gs://test-input/era5/2020/**' \
--output-template 'gs://test-output/splits/{2}.{0}.{1}T00.{variable}.nc'
```
For a file `gs://test-input/era5/2020/02/01.nc`, the output file pattern is `gs://test-output/splits/2020.01.02T00.{variable}.nc`; if temperature is a variable in that data, the output file for that split will be `gs://test-output/splits/2020.01.02T00.t.nc`.

## Dry run

To verify the input file matching and the output naming scheme, `weather-sp` can be run with the `--dry-run` option. This does not read the files, so it will not check whether the files are readable and in the correct format. It will only list the input files with the corresponding output file schemes.
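To make the template substitution concrete, here is a minimal Python sketch using a hypothetical helper `format_output_path` (the tool's actual logic lives in `weather_sp/splitter_pipeline/file_name_utils.py`, in `OutFileInfo.formatted_output_path`). The input path is split into segments counted from the end (`{0}` is the file name, `{1}` its parent directory, and so on), and the named placeholders are filled with the split values found in the file:

```python
import os
import typing as t


def format_output_path(input_path: str, template: str, splits: t.Dict[str, str]) -> str:
    """Hypothetical sketch: fill '{N}' with path segments counted from the end,
    and '{name}' placeholders with split values found in the file."""
    base, _ = os.path.splitext(input_path)  # drop the file ending; the template already has one
    segments = []
    while base:
        base, tail = os.path.split(base)
        segments.append(tail)
    # For the example below, segments == ['01', '02', '2020', 'test-input', 'gs:'].
    return template.format(*segments, **splits)


# Mirrors the README example: {2}='2020', {1}='02', {0}='01', {variable}='t'.
print(format_output_path(
    'gs://test-input/era5/2020/02/01.nc',
    'gs://test-output/splits/{2}.{0}.{1}T00.{variable}.nc',
    {'variable': 't'},
))  # -> gs://test-output/splits/2020.01.02T00.t.nc
```

The real splitter builds the same list of folder names (`template_folders`) while parsing the input path and applies it together with the split values via `str.format`; see `file_name_utils.py` later in this repository.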
================================================ FILE: weather_sp/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: weather_sp/setup.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from setuptools import setup, find_packages beam_gcp_requirements = [ "google-cloud-bigquery==2.34.4", "google-cloud-bigquery-storage==2.14.1", "google-cloud-bigtable==1.7.2", "google-cloud-core==1.7.3", "google-cloud-datastore==1.15.5", "google-cloud-dlp==3.8.0", "google-cloud-language==1.3.2", "google-cloud-pubsub==2.13.4", "google-cloud-pubsublite==1.4.2", "google-cloud-recommendations-ai==0.2.0", "google-cloud-spanner==1.19.3", "google-cloud-videointelligence==1.16.3", "google-cloud-vision==1.0.2", "apache-beam[gcp]==2.40.0", ] base_requirements = [ "pygrib==2.1.4", "eccodes", "numpy>=1.20.3", "xarray==2023.1.0", "scipy==1.9.3", ] setup( name='splitter_pipeline', packages=find_packages(), author='Anthromets', author_email='anthromets-ecmwf@google.com', version='0.3.8', url='https://weather-tools.readthedocs.io/en/latest/weather_sp/', description='A tool to split weather data files into per-variable files.', install_requires=beam_gcp_requirements + base_requirements, ) ================================================ FILE: weather_sp/splitter_pipeline/__init__.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .pipeline import run def cli(extra=[]): import sys run(sys.argv + extra) ================================================ FILE: weather_sp/splitter_pipeline/file_name_utils.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass import logging import os import string import typing as t logger = logging.getLogger(__name__) GRIB_FILE_ENDINGS = ('.grib', '.grb', '.grb2', '.grib2', '.gb') NETCDF_FILE_ENDINGS = ('.nc', '.cd') @dataclass class OutFileInfo: """Holds data required to construct an output file name. Attributes: file_name_template: base output path, may contain python-style formatting marks. This can be a base directory or a full name. formatting: added after file_name_template to add formatting. Only used when using --output-dir. ending: file ending. template_folders: list of input file directory structure. Only used with --output-template """ file_name_template: str formatting: str ending: str template_folders: t.List[str] def __repr__(self): return self.unformatted_output_path() def unformatted_output_path(self): """Construct output file name with formatting marks.""" return self.file_name_template + self.formatting + self.ending def split_dims(self) -> t.List[str]: all_format = list(filter(None, [field[1] for field in string.Formatter().parse( self.unformatted_output_path())])) return [key for key in all_format if not key.isdigit()] def formatted_output_path(self, splits: t.Dict[str, str]) -> str: """Construct output file name with formatting applied""" return self.unformatted_output_path().format(*self.template_folders, **splits) def get_output_file_info(filename: str, input_base_dir: str = '', out_pattern: t.Optional[str] = None, out_dir: t.Optional[str] = None, formatting: str = '') -> OutFileInfo: """Construct the base output file name by applying the out_pattern to the filename. Example: filename = 'gs://my_bucket/data_to_split/2020/01/21.nc' out_pattern = 'gs://my_bucket/splits/{2}-{1}-{0}_old_data.' resulting output base = 'gs://my_bucket/splits/2020-01-21_old_data.' resulting file ending = '.nc' Args: filename: input file to be split out_pattern: pattern to apply when creating output file out_dir: directory to replace input base directory formatting: output formatting of split fields. Required when using out_dir, ignored when using out_pattern. input_base_dir: used if out_pattern does not contain any '{}' substitutions. The output file is then created by replacing this part of the input name with the output pattern. """ split_name, ending = os.path.splitext(filename) if ending in GRIB_FILE_ENDINGS or ending in NETCDF_FILE_ENDINGS: filename = split_name else: ending = '' if out_dir and not formatting: raise ValueError('No formatting specified when using --output-dir.') if out_dir: return OutFileInfo( f'{filename.replace(input_base_dir, out_dir)}', formatting, ending, [] ) if out_pattern: in_sections = [] path = filename while path: path, tail = os.path.split(path) in_sections.append(tail) # setting formatting and ending to empty strings since they are # part of the specified pattern. 
return OutFileInfo(out_pattern, '', '', in_sections) raise ValueError('no output specified.') ================================================ FILE: weather_sp/splitter_pipeline/file_name_utils_test.py ================================================ # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import unittest from .file_name_utils import get_output_file_info, OutFileInfo class FileNameUtilsTest(unittest.TestCase): def test_get_output_file_info_pattern(self): actual = get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.nc', out_pattern='gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}', out_dir='', input_base_dir='ignored') expected = OutFileInfo( file_name_template='gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}', template_folders=['21', '01', '2020', 'data_to_split', 'my_bucket', 'gs:'], ending='', formatting='') self.assertEqual(actual, expected) def test_get_output_file_info_dir(self): actual = get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.nc', out_pattern='', out_dir='gs://my_bucket/splits/', input_base_dir='gs://my_bucket/data_to_split/', formatting='_{foo}') expected = OutFileInfo( file_name_template='gs://my_bucket/splits/2020/01/21', template_folders=[], ending='.nc', formatting='_{foo}') self.assertEqual(actual, expected) def test_get_output_file_info_no_fileending(self): actual = get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21', out_pattern='gs://my_bucket/splits/{2}-{1}-{0}_old_data.', out_dir='', input_base_dir='ignored') expected = OutFileInfo( file_name_template='gs://my_bucket/splits/{2}-{1}-{0}_old_data.', template_folders=['21', '01', '2020', 'data_to_split', 'my_bucket', 'gs:'], ending='', formatting='') self.assertEqual(actual, expected) def test_get_output_file_info_filecontainsdots(self): actual = get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.T00z.stuff', out_dir='gs://my_bucket/splits/', input_base_dir='gs://my_bucket/data_to_split/', formatting='.{foo}') expected = OutFileInfo( file_name_template='gs://my_bucket/splits/2020/01/21.T00z.stuff', template_folders=[], ending='', formatting='.{foo}') self.assertEqual(actual, expected) def test_get_output_file_info_dir_no_formatting(self): with self.assertRaises(ValueError): get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.nc', out_pattern='', out_dir='gs://my_bucket/splits/', input_base_dir='gs://my_bucket/data_to_split/') def test_output_pattern_ignores_formatting(self): actual = get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.nc', out_pattern='gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}', out_dir=None, input_base_dir='ignored', formatting='_{time}_{level}hPa') expected = OutFileInfo( file_name_template='gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}', template_folders=['21', '01', '2020', 'data_to_split', 'my_bucket', 'gs:'], ending='', formatting='') self.assertEqual(actual, expected) def test_split_dims(self): actual = 
get_output_file_info(filename='gs://my_bucket/data_to_split/2020/01/21.nc', out_pattern='gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}', out_dir=None, input_base_dir='ignored') self.assertEqual(actual.split_dims(), ['variable']) ================================================ FILE: weather_sp/splitter_pipeline/file_splitters.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import itertools import logging import os import re import shutil import string import subprocess import tempfile import typing as t from contextlib import contextmanager import apache_beam.metrics as metrics import numpy as np import pygrib import xarray as xr from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE from apache_beam.io.filesystems import FileSystems from apache_beam.utils import retry from .file_name_utils import OutFileInfo # For uploading / downloading retry logic. INITIAL_DELAY = 1.0 # Initial delay in seconds. MAX_DELAY = 600 # Maximum delay before giving up in seconds. NUM_RETRIES = 10 # Number of tries with exponential backoff. logger = logging.getLogger(__name__) # TODO(#245): Group with common utilities (duplicated) @retry.with_exponential_backoff( num_retries=NUM_RETRIES, logger=logger.warning, initial_delay_secs=INITIAL_DELAY, max_delay_secs=MAX_DELAY ) def copy(src: str, dst: str) -> None: """Copy data via `gsutil` or local filesystem.""" is_gs = src.startswith("gs://") or dst.startswith("gs://") try: if is_gs: subprocess.run(['gcloud', 'storage', 'cp', src, dst], check=True, capture_output=True, text=True, input="n/n") else: os.makedirs(os.path.dirname(dst) or '.', exist_ok=True) shutil.copy(src, dst) except Exception as e: error_detail = getattr(e, "stderr", str(e)).strip() msg = f"Failed to copy {src!r} to {dst!r} due to {error_detail}" logger.error(msg) raise EnvironmentError(msg) from e class FileSplitter(abc.ABC): """Base class for weather file splitters.""" def __init__(self, input_path: str, output_info: OutFileInfo, force_split: bool = False, logging_level: int = logging.INFO, grib_filter_expression: t.Optional[str] = None): self.input_path = input_path self.output_info = output_info self.force_split = force_split self.logger = logging.getLogger(f'{__name__}.{type(self).__name__}') self.logger.setLevel(logging_level) self.logger.debug('Splitter for path=%s, output base=%s', self.input_path, self.output_info) self.grib_filter_expression = grib_filter_expression @abc.abstractmethod def split_data(self) -> None: raise NotImplementedError() @contextmanager def _copy_to_local_file(self) -> t.Iterator[t.IO]: self.logger.info(f'Copying {self.input_path!r} locally.') with tempfile.NamedTemporaryFile() as dest_file: copy(self.input_path, dest_file.name) # Check if data is compressed. Decompress the data using the same methods that beam's # FileSystems interface uses. 
compression_type = FileSystem._get_compression_type(self.input_path, CompressionTypes.AUTO) if compression_type == CompressionTypes.UNCOMPRESSED: yield dest_file return dest_file.seek(0) with tempfile.NamedTemporaryFile() as dest_uncompressed: with CompressedFile(open(dest_file.name, 'rb'), compression_type=compression_type) as dcomp: shutil.copyfileobj(dcomp, dest_uncompressed, DEFAULT_READ_BUFFER_SIZE) dest_uncompressed.seek(0) # Reposition the file pointer to the start. yield dest_uncompressed def should_skip(self): """Skip splitting if the data was already split.""" if self.force_split: return False for match in FileSystems().match([ self.output_info.formatted_output_path( {var: '*' for var in self.output_info.split_dims()}), ]): if len(match.metadata_list) > 0: return True return False def should_skip_file(self, output_path: str) -> bool: """Skip splitting if the data file was already split. TODO(#287): Consider making this the default skipping implementation... """ if self.force_split: return False matches = FileSystems().match([output_path]) assert len(matches) == 1 if len(matches[0].metadata_list) > 0: return True return False class GribSplitter(FileSplitter): def split_data(self) -> None: if not self.output_info.split_dims(): raise ValueError('No splitting specified in template.') if self.should_skip(): metrics.Metrics.counter('file_splitters', 'skipped').inc() self.logger.info('Skipping %s, file already split.', repr(self.input_path)) return # Here, we keep a map of open file objects (`outputs`). We need these since # each output grib file (named `key`) will include multiple `grb` messages # each. By writing data to the cache of open file objects, we can keep a # minimal amount of data in memory at a time. outputs = dict() with self._open_grib_locally() as grbs: self.logger.info('Splitting & uploading %r...', self.input_path) try: for grb in grbs: # Iterate through the split dimensions of the grib message in order to # produce the right output file. splits = dict() for dim in self.output_info.split_dims(): try: splits[dim] = getattr(grb, dim) except RuntimeError: self.logger.error( 'Variable not found in grib: %s', dim) key = self.output_info.formatted_output_path(splits) # Append the current grib message to a set number of output files. # If the target shard doesn't exist, create it. if key not in outputs: outputs[key] = FileSystems.create(key) outputs[key].write(grb.tostring()) outputs[key].flush() # Delete the grib message from memory – *and disk* – before moving on to the next # grib message. See the pygrib sources for more info. # https://github.com/jswhit/pygrib/blob/v2.1.4rel/src/pygrib/_pygrib.pyx#L759 del grb finally: for out in outputs.values(): out.close() self.logger.info('Split %s into %d files', self.input_path, len(outputs)) @contextmanager def _open_grib_locally(self) -> t.Iterator[t.Iterator[pygrib.gribmessage]]: with self._copy_to_local_file() as local_file: with pygrib.open(local_file.name) as gb: yield gb class GribSplitterV2(GribSplitter): """Splitter that makes use of `grib_copy` util for high performance splitting. See https://confluence.ecmwf.int/display/ECC/grib_copy. 
""" def replace_non_numeric_bracket(self, match: re.Match) -> str: value = match.group(1) return f"[{value}]" if not value.isdigit() else "{" + value + "}" def split_data(self) -> None: if not self.output_info.split_dims(): raise ValueError('No splitting specified in template.') grib_copy_cmd = shutil.which('grib_copy') grib_get_cmd = shutil.which('grib_get') uniq_cmd = shutil.which('uniq') for cmd, name in [(grib_get_cmd, 'grib_copy'), (grib_get_cmd, 'grib_get'), (uniq_cmd, 'uniq')]: if not cmd: raise EnvironmentError(f'binary {name!r} is not available in the current environment!') unformatted_output_path = self.output_info.unformatted_output_path() prefix, _ = os.path.split(next(iter(string.Formatter().parse(unformatted_output_path)))[0]) _, tail = unformatted_output_path.split(prefix) # Replace { with [ and } with ] only for non-numeric values inside {} of tail output_str = re.sub(r'\{(\w+)\}', self.replace_non_numeric_bracket, tail) output_template = output_str.format(*self.output_info.template_folders) slash = '/' delimiter = 'DELIMITER' flat_output_template = output_template.replace('/', delimiter) split_dims = self.output_info.split_dims() # Construct a string where each split dimension is "dim:s". # This ensures dims like time are represented as 0600 instead of 600. split_dims_arg = ','.join(f'{dim}:s' for dim in split_dims) with self._copy_to_local_file() as local_file: self.logger.info('Skipping as needed...') # Append -w flag to filter GRIB messages matching the given expression if self.grib_filter_expression: grib_get_args = [grib_get_cmd, '-p', split_dims_arg, '-w', self.grib_filter_expression, local_file.name] else: grib_get_args = [grib_get_cmd, '-p', split_dims_arg, local_file.name] grib_get_process = subprocess.Popen(grib_get_args, stdout=subprocess.PIPE) uniq_output = subprocess.check_output((uniq_cmd,), stdin=grib_get_process.stdout) output_paths = [] skipped_paths = [] for line in uniq_output.decode('utf-8').rstrip('\n').split('\n'): splits = dict(zip(split_dims, line.split(' '))) output_path = self.output_info.formatted_output_path(splits) if self.should_skip_file(output_path): skipped_paths.append(output_path) continue output_paths.append(output_path) if not output_paths: metrics.Metrics.counter('file_splitters', 'skipped').inc() self.logger.info('Skipping %s, file already split into: %s', repr(self.input_path), ', '.join(skipped_paths)) return with tempfile.TemporaryDirectory() as tmpdir: self.logger.info('Performing split.') dest = os.path.join(tmpdir, flat_output_template) if self.grib_filter_expression: subprocess.run([grib_copy_cmd, "-w", self.grib_filter_expression, local_file.name, dest], check=True) else: subprocess.run([grib_copy_cmd, local_file.name, dest], check=True) self.logger.info('Uploading %r...', self.input_path) for flat_target in os.listdir(tmpdir): dest_file_path = f'{prefix}{flat_target.replace(delimiter, slash)}' self.logger.info([prefix, dest_file_path, local_file.name, self.output_info.unformatted_output_path()]) copy(os.path.join(tmpdir, flat_target), dest_file_path) self.logger.info('Finished uploading %r', self.input_path) class NetCdfSplitter(FileSplitter): _UNSUPPORTED_DIMENSIONS = ('latitude', 'longitude', 'lat', 'lon') def split_data(self) -> None: if not self.output_info.split_dims(): raise ValueError('No splitting specified in template.') if any(dim in self._UNSUPPORTED_DIMENSIONS for dim in self.output_info.split_dims()): raise ValueError('Unsupported split dimension (lat, lng).') if self.should_skip(): 
metrics.Metrics.counter('file_splitters', 'skipped').inc() self.logger.info('Skipping %s, file already split.', repr(self.input_path)) return with self._open_dataset_locally() as dataset: if any(split not in dataset.dims and split not in ('variable') for split in self.output_info.split_dims()): raise ValueError( 'netcdf split: requested dimension not in dataset') iterlists = [] if 'variable' in self.output_info.split_dims(): iterlists.append([dataset[var].to_dataset() for var in dataset.data_vars]) else: iterlists.append([dataset]) filtered_split_dims = [ x for x in self.output_info.split_dims() if x not in ('variable', self._UNSUPPORTED_DIMENSIONS)] for dim in filtered_split_dims: iterlists.append(dataset[dim]) combinations = itertools.product(*iterlists) self.logger.info('Splitting & uploading %r...', self.input_path) for comb in combinations: selected = comb[0] for da in comb[1:]: for dim in da.coords: selected = selected.sel({dim: getattr(da, dim)}) self._write_dataset(selected, filtered_split_dims) self.logger.info('Finished splitting & uploading %r.', self.input_path) @contextmanager def _open_dataset_locally(self) -> t.Iterator[xr.Dataset]: with self._copy_to_local_file() as local_file: ds = xr.open_dataset(local_file.name, engine='netcdf4') yield ds ds.close() def _write_dataset(self, dataset: xr.Dataset, split_dims: t.List[str]) -> None: """Write destination NetCDF file in NETCDF4 format.""" # Here, we need to write the file locally, since only the scipy engine supports file objects or # returning bytes. Further, the scipy engine does not support NETCDF4 (which is HDF5 compliant). # Storing data in HDF5 is advantageous since it allows opening NetCDF files with buffered readers. with tempfile.NamedTemporaryFile() as tmp: dataset.to_netcdf(path=tmp.name, engine='netcdf4', format='NETCDF4') copy(tmp.name, self._get_output_for_dataset(dataset, split_dims)) def _get_output_for_dataset(self, dataset: xr.Dataset, split_dims: t.List[str]) -> str: splits = {'variable': list(dataset.data_vars.keys())[0]} for dim in split_dims: value = dataset[dim].values if dim == 'time': value = np.datetime_as_string(value, unit='m') splits[dim] = value return self.output_info.formatted_output_path(splits) class DrySplitter(FileSplitter): def split_data(self) -> None: if not self.output_info.split_dims(): raise ValueError('No splitting specified in template.') self.logger.info('input file: %s - output scheme: %s', self.input_path, self.output_info.formatted_output_path(self._get_keys())) def _get_keys(self) -> t.Dict[str, str]: return {name: name for name in self.output_info.split_dims()} def get_splitter(file_path: str, output_info: OutFileInfo, dry_run: bool, force_split: bool = False, logging_level: int = logging.INFO, grib_filter_expression: t.Optional[str] = None) -> FileSplitter: if dry_run: logger.info('Using splitter: DrySplitter') return DrySplitter(file_path, output_info, logging_level=logging_level) with FileSystems.open(file_path) as f: header = f.read(4) if b'GRIB' in header: metrics.Metrics.counter('get_splitter', 'grib').inc() # Decide which version of the grib splitter to use depending on if ecCodes is installed. # Prefer the v2 grib splitter, which should be more robust -- especially when splitting by # multiple dimensions at once. 
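# Note: `grib_copy` is one of the ecCodes command-line tools, so checking whether it
# is available on the PATH (below) is effectively a check for a usable ecCodes install.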
cmd = shutil.which('grib_copy') if cmd: logger.info('Using splitter: GribSplitterV2') return GribSplitterV2(file_path, output_info, force_split, logging_level, grib_filter_expression) else: logger.info('Using splitter: GribSplitter') return GribSplitter(file_path, output_info, force_split, logging_level) # See the NetCDF Spec docs: # https://docs.unidata.ucar.edu/netcdf-c/current/faq.html#How-can-I-tell-which-format-a-netCDF-file-uses if b'CDF' in header or b'HDF' in header: metrics.Metrics.counter('get_splitter', 'netcdf').inc() logger.info('Using splitter: NetCdfSplitter') return NetCdfSplitter(file_path, output_info, force_split) raise ValueError( f'cannot determine if file {file_path!r} is Grib or NetCDF.') ================================================ FILE: weather_sp/splitter_pipeline/file_splitters_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import shutil from collections import defaultdict import h5py import numpy as np import pygrib import pytest import xarray as xr from unittest.mock import patch import weather_sp from .file_name_utils import OutFileInfo from .file_name_utils import get_output_file_info from .file_splitters import ( DrySplitter, GribSplitter, GribSplitterV2, NetCdfSplitter, get_splitter, ) @pytest.fixture() def data_dir(): data_dir = f'{next(iter(weather_sp.__path__))}/test_data' yield data_dir split_dir = f'{data_dir}/split_files/' if os.path.exists(split_dir): shutil.rmtree(split_dir) @pytest.fixture(params=[GribSplitter, GribSplitterV2]) def grib_splitter(request): return request.param class TestGetSplitter: def test_get_splitter_grib(self, data_dir): splitter = get_splitter(f'{data_dir}/era5_sample.grib', OutFileInfo( file_name_template='some_out_{split}', ending='.grib', formatting='', template_folders=[]), dry_run=False) assert isinstance(splitter, GribSplitter) def test_get_splitter_nc(self, data_dir): splitter = get_splitter(f'{data_dir}/era5_sample.nc', OutFileInfo( file_name_template='some_out_{split}', ending='.nc', formatting='', template_folders=[]), dry_run=False) assert isinstance(splitter, NetCdfSplitter) def test_get_splitter_undetermined_grib(self, data_dir): splitter = get_splitter(f'{data_dir}/era5_sample_grib', OutFileInfo( file_name_template='some_out_{split}', ending='', formatting='', template_folders=[]), dry_run=False) assert isinstance(splitter, GribSplitter) def test_get_splitter_dryrun(self): splitter = get_splitter('some/file/path/data.grib', OutFileInfo( file_name_template='some_out_{split}', ending='.grib', formatting='', template_folders=[]), dry_run=True) assert isinstance(splitter, DrySplitter) class TestGribSplitter: def test_get_output_file_path(self, grib_splitter): splitter = grib_splitter( 'path/to/input', OutFileInfo( file_name_template='path/output/file.{typeOfLevel}_{shortName}', ending='.grib', formatting='', template_folders=[]) ) out = splitter.output_info.formatted_output_path( {'typeOfLevel': 'surface', 'shortName': 'cc'}) assert 
out == 'path/output/file.surface_cc.grib' def test_split_data_with_filter(self, data_dir): input_path = f'{data_dir}/era5_sample.grib' output_base = f'{data_dir}/split_files/era5_sample' splitter = GribSplitterV2( input_path, OutFileInfo( output_base, formatting='_{typeOfLevel}_{shortName}', ending='.grib', template_folders=[]), grib_filter_expression="typeOfLevel=isobaricInhPa,level=200" ) splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') is True short_names = ['d', 'cc', 'z', 'r'] input_data = defaultdict(list) split_data = defaultdict(list) input_grbs = pygrib.open(input_path) for grb in input_grbs: if grb.typeOfLevel == 'isobaricInhPa' and grb.level == 200: input_data[grb.shortName].append(grb.values) for sn in short_names: split_file = f'{data_dir}/split_files/era5_sample_isobaricInhPa_{sn}.grib' split_grbs = pygrib.open(split_file) for grb in split_grbs: split_data[sn].append(grb.values) for sn in short_names: orig = np.array(input_data[sn]) split = np.array(split_data[sn]) assert orig.shape == split.shape np.testing.assert_allclose(orig, split) def test_split_data(self, data_dir, grib_splitter): input_path = f'{data_dir}/era5_sample.grib' output_base = f'{data_dir}/split_files/era5_sample' splitter = grib_splitter( input_path, OutFileInfo( output_base, formatting='_{typeOfLevel}_{shortName}', ending='.grib', template_folders=[]) ) splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') is True short_names = ['z', 'r', 'cc', 'd'] input_data = defaultdict(list) split_data = defaultdict(list) input_grbs = pygrib.open(input_path) for grb in input_grbs: input_data[grb.shortName].append(grb.values) for sn in short_names: split_file = f'{data_dir}/split_files/era5_sample_isobaricInhPa_{sn}.grib' split_grbs = pygrib.open(split_file) for grb in split_grbs: split_data[sn].append(grb.values) for sn in short_names: orig = np.array(input_data[sn]) split = np.array(split_data[sn]) assert orig.shape == split.shape np.testing.assert_allclose(orig, split) def test_skips_existing_split(self, data_dir, grib_splitter): input_path = f'{data_dir}/era5_sample.grib' splitter = grib_splitter( input_path, OutFileInfo(f'{data_dir}/split_files/era5_sample', formatting='_{typeOfLevel}_{shortName}', ending='.grib', template_folders=[]) ) assert not splitter.should_skip() splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') assert splitter.should_skip() @patch('weather_sp.splitter_pipeline.file_splitters.FileSplitter.should_skip_file') def test_skips_existing_split_with_filter(self, mock_should_skip_file, data_dir): input_path = f'{data_dir}/era5_sample.grib' splitter = GribSplitterV2( input_path, OutFileInfo(f'{data_dir}/split_files/era5_sample', formatting='_{typeOfLevel}_{shortName}', ending='.grib', template_folders=[]), grib_filter_expression="typeOfLevel=isobaricInhPa,level=200" ) splitter.split_data() assert mock_should_skip_file.call_count == 8 def test_does_not_skip__if_forced(self, data_dir, grib_splitter): input_path = f'{data_dir}/era5_sample.grib' output_base = f'{data_dir}/split_files/era5_sample' splitter = grib_splitter( input_path, OutFileInfo( output_base, formatting='_{levelType}_{shortName}', ending='.grib', template_folders=[]), force_split=True ) assert not splitter.should_skip() splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') assert not splitter.should_skip() @pytest.mark.limit_memory('30 MB') def test_split__fits_memory_bounds(self, data_dir, grib_splitter): input_path = f'{data_dir}/era5_sample.grib' 
output_base = f'{data_dir}/split_files/era5_sample' splitter = grib_splitter( input_path, OutFileInfo( output_base, formatting='_{typeOfLevel}_{shortName}', ending='.grib', template_folders=[]) ) splitter.split_data() class TestNetCdfSplitter: def test_get_output_file_path(self): splitter = NetCdfSplitter( 'path/to/input', OutFileInfo('path/output/file_{variable}', ending='.nc', formatting='', template_folders=[])) out = splitter.output_info.formatted_output_path({'variable': 'cc'}) assert out == 'path/output/file_cc.nc' def test_split_data(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter( input_path, OutFileInfo(output_base, formatting='_{time}_{variable}', ending='.nc', template_folders=[])) splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') input_data = xr.open_dataset(input_path, engine='netcdf4') for time in ['2015-01-15T00:00', '2015-01-15T06:00', '2015-01-15T12:00', '2015-01-15T18:00']: expect = input_data.sel(time=time) for sn in ['d', 'cc', 'z']: split_file = f'{data_dir}/split_files/era5_sample_{time}_{sn}.nc' split_data = xr.open_dataset(split_file, engine='netcdf4') xr.testing.assert_allclose(expect[sn], split_data[sn]) def test_split_data__not_in_dims_raises(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter(input_path, OutFileInfo(output_base, formatting='_{level}', ending='.nc', template_folders=[])) with pytest.raises(ValueError): splitter.split_data() def test_split_data__unsupported_dim_raises(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter(input_path, OutFileInfo(output_base, formatting='_{longitude}', ending='.nc', template_folders=[])) with pytest.raises(ValueError): splitter.split_data() def test_skips_existing_split(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter(input_path, OutFileInfo(output_base, formatting='_{variable}', ending='.nc', template_folders=[])) assert not splitter.should_skip() splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') assert splitter.should_skip() def test_does_not_skip__if_forced(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter(input_path, OutFileInfo( output_base, formatting='_{variable}', ending='.nc', template_folders=[]), force_split=True) assert not splitter.should_skip() splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') assert not splitter.should_skip() @pytest.mark.limit_memory('25 MB') def test_split_data__fits_memory_bounds(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter( input_path, OutFileInfo(output_base, formatting='_{time}_{variable}', ending='.nc', template_folders=[])) splitter.split_data() def test_split_data__is_netcdf4(self, data_dir): input_path = f'{data_dir}/era5_sample.nc' assert not h5py.is_hdf5(input_path) output_base = f'{data_dir}/split_files/era5_sample' splitter = NetCdfSplitter( input_path, OutFileInfo(output_base, formatting='_{time}_{variable}', ending='.nc', template_folders=[])) splitter.split_data() assert os.path.exists(f'{data_dir}/split_files/') for time in ['2015-01-15T00:00', '2015-01-15T06:00', '2015-01-15T12:00', 
'2015-01-15T18:00']: for sn in ['d', 'cc', 'z']: split_file = f'{data_dir}/split_files/era5_sample_{time}_{sn}.nc' assert h5py.is_hdf5(split_file) class TestDrySplitter: def test_path_with_output_pattern(self): input_path = 'a/b/c/d/file.nc' out_pattern = 'gs://my_bucket/splits/{2}-{1}-{0}_old_data.{variable}.cd' out_info = get_output_file_info( filename=input_path, out_pattern=out_pattern) splitter = DrySplitter(input_path, out_info) keys = splitter._get_keys() assert keys == {'variable': 'variable'} out_file = splitter.output_info.formatted_output_path(keys) assert out_file == 'gs://my_bucket/splits/c-d-file_old_data.variable.cd' def test_path_with_output_pattern_no_formatting(self): # OutFileInfo using pattern but without any formatting marks. splitter = DrySplitter("input_path/file.grib", OutFileInfo("some/out/no/formatting", formatting='', ending='', template_folders=[])) with pytest.raises(ValueError): splitter.split_data() ================================================ FILE: weather_sp/splitter_pipeline/pipeline.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import logging import os import typing as t import apache_beam as beam import apache_beam.metrics as metrics from apache_beam.io.fileio import MatchFiles, ReadMatches from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions from apache_beam.io.gcp.pubsub import ReadFromPubSub from .file_name_utils import OutFileInfo, get_output_file_info from .file_splitters import get_splitter from .streaming import GroupMessagesByFixedWindows, ParsePaths logger = logging.getLogger(__name__) SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' def configure_logger(verbosity: int) -> None: """Configures logging from verbosity. Default verbosity will show errors.""" level = 40 - verbosity * 10 logging.getLogger(__package__).setLevel(level) logger.setLevel(level) def split_file(input_file: str, input_base_dir: str, output_template: t.Optional[str], output_dir: t.Optional[str], formatting: str, dry_run: bool, force_split: bool = False, logging_level: int = logging.INFO, grib_filter_expression: t.Optional[str] = None): output_base_name = get_output_base_name(input_path=input_file, input_base=input_base_dir, output_template=output_template, output_dir=output_dir, formatting=formatting) logger.info('Splitting file: %s. Output base name: %s', input_file, output_base_name) metrics.Metrics.counter('pipeline', 'splitting file').inc() level = 40 - logging_level * 10 splitter = get_splitter(input_file, output_base_name, dry_run, force_split, level, grib_filter_expression) splitter.split_data() def _get_base_input_directory(input_pattern: str) -> str: base_dir = input_pattern for x in ['*', '?', '[']: base_dir = base_dir.split(x, maxsplit=1)[0] # Go one directory up to include the last common directory in output path. 
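# For example, the pattern '/path/to/some/wild/*/card/??.nc' is truncated at the first
# wildcard to '/path/to/some/wild/', and the two dirname() calls below yield '/path/to/some'.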
return os.path.dirname(os.path.dirname(base_dir)) def get_output_base_name(input_path: str, input_base: str, output_template: t.Optional[str], output_dir: t.Optional[str], formatting: str) -> OutFileInfo: return get_output_file_info(input_path, input_base_dir=input_base, out_pattern=output_template, out_dir=output_dir, formatting=formatting) def run(argv: t.List[str], save_main_session: bool = True): """Main entrypoint & pipeline definition.""" parser = argparse.ArgumentParser( prog='weather-sp', description='Split weather data file into files by variable or other dimension.' ) parser.add_argument('-i', '--input-pattern', type=str, required=True, help='Pattern for input weather data.') parser.add_argument('--use-local-code', action='store_true', default=False, help='Supply local code to the Runner.') output_options = parser.add_mutually_exclusive_group(required=True) output_options.add_argument( '--output-template', type=str, help='Template specifying path to output files using ' 'python-style formatting substitution of input ' 'directory names. ' 'For `input_pattern a/b/c/**` and file `a/b/c/file.grib`, ' 'a template with formatting `/somewhere/{1}-{0}.{level}_{shortName}.grib` ' 'will give `somewhere/c-file.level_shortName.grib`' ) output_options.add_argument( '--output-dir', type=str, help='Output directory that will replace the common path of the ' 'input_pattern. ' 'For `input_pattern a/b/c/**` and file `a/b/c/file.nc`, ' '`outputdir /x/y/z` will create ' 'output files like `/x/y/z/c/file_variable.nc`' ) parser.add_argument( '--formatting', type=str, default='', help='Used in combination with `output-dir`: specifies the how to ' 'split the data and format the output file. ' 'Example: `_{time}_{level}hPa`' ) parser.add_argument('-d', '--dry-run', action='store_true', default=False, help='Test the input file matching and the output file scheme without splitting.') parser.add_argument('-f', '--force', action='store_true', default=False, help='Force re-splitting of the pipeline. Turns of skipping of already split data.') parser.add_argument('--log-level', type=int, default=2, help='An integer to configure log level. Default: 2(INFO)') parser.add_argument('-w', '--where', type=str, default=None, help='Optional GRIB filter expression to apply during' 'file splitting using grib_copy.' 'This allows filtering GRIB messages based on' 'key-value pairs, such as level, type of level,' 'or date.' 'This flag is only applicable to GRIB files and is' 'specifically supported by the GribSplitterV2' 'implementation.' 'Example: typeOfLevel=isobaricInhPa,level=1000') parser.add_argument('--topic', type=str, default=None, help='Pub/Sub topic to read from for streaming mode.') parser.add_argument('--subscription', type=str, default=None, help='Pub/Sub subscription to read from for streaming mode.') parser.add_argument('--window-size', type=int, default=1, help='Window size in minutes for grouping Pub/Sub messages.') parser.add_argument('--num-shards', type=int, default=5, help='Number of shards for partitioning windowed data.') known_args, pipeline_args = parser.parse_known_args(argv[1:]) # If a Pub/Sub is used, then the pipeline must be a streaming pipeline. 
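# Pub/Sub is an unbounded source, so Beam's `--streaming` flag is appended to the
# pipeline arguments whenever a topic or subscription is supplied.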
if known_args.topic or known_args.subscription: if known_args.topic and known_args.subscription: raise ValueError('only one argument can be provided at a time: `topic` or `subscription`.') pipeline_args.extend('--streaming true'.split()) configure_logger(known_args.log_level) # 0 = error, 1 = warn, 2 = info, 3 = debug pipeline_options = PipelineOptions(pipeline_args) pipeline_options.view_as(SetupOptions).save_main_session = save_main_session input_pattern = known_args.input_pattern input_base_dir = _get_base_input_directory(input_pattern) output_template = known_args.output_template output_dir = known_args.output_dir formatting = known_args.formatting dry_run = known_args.dry_run grib_filter_expression = known_args.where if not output_template and not output_dir: raise ValueError('No output specified') output_file_tmpl = os.path.basename(output_template) if '[' in output_file_tmpl or ']' in output_file_tmpl or '[' in formatting or ']' in formatting: raise ValueError('Tokens `[]` are disallowed in the file output.') logger.debug('input_pattern: %s', input_pattern) logger.debug('input_base_dir: %s', input_base_dir) if output_template: logger.debug('output_template: %s', output_template) if output_dir: logger.debug('output_dir: %s', output_dir) logger.debug('dry_run: %s', known_args.dry_run) with beam.Pipeline(options=pipeline_options) as p: if known_args.topic or known_args.subscription: paths = ( p # Windowing is based on this code sample: # https://cloud.google.com/pubsub/docs/pubsub-dataflow#code_sample | 'ReadUploadEvent' >> ReadFromPubSub( known_args.topic, known_args.subscription ) | 'WindowInto' >> GroupMessagesByFixedWindows( known_args.window_size, known_args.num_shards ) | 'ParsePaths' >> beam.ParDo( ParsePaths(known_args.input_pattern) ) ) else: paths = ( p | 'MatchFiles' >> MatchFiles(input_pattern) | 'ReadMatchedFiles' >> ReadMatches() | 'Shuffle' >> beam.Reshuffle() | 'GetPath' >> beam.Map(lambda x: x.metadata.path) ) ( paths | 'SplitFiles' >> beam.Map( split_file, input_base_dir, output_template, output_dir, formatting, dry_run, known_args.force, known_args.log_level, grib_filter_expression, ) ) logger.info('Pipeline is finished.') ================================================ FILE: weather_sp/splitter_pipeline/pipeline_test.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
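"""Unit tests for the weather-sp pipeline helpers: input-pattern base-directory
parsing, output-name construction, and the `split_file` entrypoint."""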
import unittest from unittest.mock import patch, ANY from .file_name_utils import OutFileInfo from .pipeline import _get_base_input_directory from .pipeline import get_output_base_name from .pipeline import split_file class PipelineTest(unittest.TestCase): def test_get_base_input_directory(self): self.assertEqual( _get_base_input_directory( '/path/to/some/wild/*/card/??[0-1].nc'), '/path/to/some') self.assertEqual( _get_base_input_directory( '/path/to/some/wild/??/card/*[0-1].nc'), '/path/to/some') self.assertEqual( _get_base_input_directory( '/path/to/some/wild/201[8,9]/card/??.nc'), '/path/to/some') def test_get_output_base_name(self): self.assertEqual(get_output_base_name( input_path='somewhere/somefile', input_base='somewhere', output_template=None, formatting='.{shortName}', output_dir='out/there').file_name_template, 'out/there/somefile') @patch('weather_sp.splitter_pipeline.file_splitters.get_splitter') @unittest.skip('bad mocks') def test_split_file(self, mock_get_splitter): split_file(input_file='somewhere/somefile', input_base_dir='somewhere', output_dir='out/there', output_template=None, formatting='_{variable}', dry_run=True) mock_get_splitter.assert_called_with('somewhere/somefile', OutFileInfo('out/there/somefile', formatting='_{variable}', ending='', template_folders=[]), True) @patch('weather_sp.splitter_pipeline.pipeline.get_splitter') def test_split_file_with_filter(self, mock_get_splitter): split_file( input_file='somewhere/somefile', input_base_dir='somewhere', output_dir='out/there', output_template=None, formatting='_{variable}', dry_run=True, grib_filter_expression='typeOfLevel=isobaricInhPa,level=200', ) mock_get_splitter.assert_called_with( 'somewhere/somefile', OutFileInfo('out/there/somefile', formatting='_{variable}', ending='', template_folders=[]), True, False, ANY, 'typeOfLevel=isobaricInhPa,level=200') if __name__ == '__main__': unittest.main() ================================================ FILE: weather_sp/splitter_pipeline/streaming.py ================================================ # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Window and parse Pub/Sub streams of real-time weather data added to cloud storage. Example windowing code borrowed from: https://cloud.google.com/pubsub/docs/pubsub-dataflow#code_sample """ import datetime import fnmatch import json import logging import random import typing as t from urllib.parse import urlparse import apache_beam as beam from apache_beam.transforms.window import FixedWindows logger = logging.getLogger(__name__) class GroupMessagesByFixedWindows(beam.PTransform): """A composite transform that groups Pub/Sub messages based on publish time and outputs a list of tuples, each containing a message and its publish time. """ def __init__(self, window_size: int, num_shards: int = 5): # Set window size to 60 seconds. self.window_size = int(window_size * 60) self.num_shards = num_shards def expand(self, pcoll): return ( pcoll # Bind window info to each element using element timestamp (or publish time). 
| "Window into fixed intervals" >> beam.WindowInto(FixedWindows(self.window_size)) | "Add timestamp to windowed elements" >> beam.ParDo(AddTimestamp()) # Assign a random key to each windowed element based on the number of shards. | "Add key" >> beam.WithKeys(lambda _: random.randint(0, self.num_shards - 1)) # Group windowed elements by key. All the elements in the same window must fit # memory for this. If not, you need to use `beam.util.BatchElements`. | "Group by key" >> beam.GroupByKey() ) class AddTimestamp(beam.DoFn): """Processes each windowed element by extracting the message body and its publish time into a tuple. """ def process(self, element, publish_time=beam.DoFn.TimestampParam) -> t.Iterable[t.Tuple[str, str]]: yield ( element.decode("utf-8"), datetime.datetime.utcfromtimestamp(float(publish_time)).strftime( "%Y-%m-%d %H:%M:%S.%f" ), ) class ParsePaths(beam.DoFn): """Parse paths to real-time weather data from windowed-batches.""" def __init__(self, uri_pattern: str): self.uri_pattern = uri_pattern self.protocol = f'{urlparse(uri_pattern).scheme}://' super().__init__() @classmethod def try_parse_message(cls, message_body: t.Union[str, t.Dict]) -> t.Dict: """Robustly parse message body, which will be JSON in the vast majority of cases, but might be a dictionary.""" try: return json.loads(message_body) except (json.JSONDecodeError, TypeError): if isinstance(message_body, dict): return message_body raise def to_object_path(self, payload: t.Dict) -> str: """Parse cloud object from Pub/Sub topic payload.""" return f'{self.protocol}{payload["bucket"]}/{payload["name"]}' def should_skip(self, message_body: t.Dict) -> bool: """Returns true if Pub/Sub topic does *not* match the target file URI pattern.""" try: return not fnmatch.fnmatch(self.to_object_path(message_body), self.uri_pattern) except KeyError: return True def process(self, key_value, window=beam.DoFn.WindowParam) -> t.Iterable[str]: """Yield paths to real-time weather data in cloud storage.""" shard_id, batch = key_value logger.debug(f'Processing shard {shard_id!r}.') for message_body, publish_time in batch: logger.debug(message_body) parsed_msg = self.try_parse_message(message_body) target = self.to_object_path(parsed_msg) logger.info(f'Parsed path {target!r}...') if self.should_skip(parsed_msg): logger.info(f'skipping {target!r}.') continue yield target ================================================ FILE: weather_sp/test_data/era5_sample.grib ================================================ [File too large to display: 10.4 MB] ================================================ FILE: weather_sp/test_data/era5_sample_grib ================================================ [File too large to display: 10.4 MB] ================================================ FILE: weather_sp/weather-sp ================================================ #!/usr/bin/env python3 # Copyright 2021 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
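"""Thin launcher for the weather-sp pipeline.

Installs the bundled `splitter_pipeline` subpackage on first use, optionally packages
local code as an `--extra_package` tarball for Dataflow runs, and then invokes
`splitter_pipeline.cli` with any extra Beam arguments it has assembled.
"""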
import glob import logging import os import subprocess import sys import tarfile import tempfile import weather_sp SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0' if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) site_pkg = weather_sp.__path__[0] try: from splitter_pipeline import cli except ImportError: # Install the subpackage. subprocess.check_call( f'{sys.executable} -m pip -q install -e {site_pkg}'.split()) # Re-load sys.path import site from importlib import reload reload(site) # Re-attempt import. If this fails, the user probably has an older version of # the package already installed on their machine that breaks this process. # If that's the case, it's best to start from a clean virtual environment. try: from splitter_pipeline import cli except ImportError as e: raise ImportError( 'please re-install package in a clean python environment.') from e args = [] if "DataflowRunner" in sys.argv and "--sdk_container_image" not in sys.argv: args.extend(['--sdk_container_image', os.getenv('SDK_CONTAINER_IMAGE', SDK_CONTAINER_IMAGE), '--experiments', 'use_runner_v2']) if "--use-local-code" in sys.argv: with tempfile.TemporaryDirectory() as tmpdir: original_dir = os.getcwd() # Convert subpackage to a tarball os.chdir(site_pkg) subprocess.check_call( f'{sys.executable} ./setup.py -q sdist --dist-dir {tmpdir}'.split(), ) os.chdir(original_dir) # Set tarball as extra packages for Beam. pkg_archive = glob.glob(os.path.join(tmpdir, '*.tar.gz'))[0] with tarfile.open(pkg_archive, 'r') as tar: assert any([f.endswith('.py') for f in tar.getnames()]), 'extra_package must include python files!' # cleanup memory to prevent pickling error. tar = None weather_sp = None args.extend(['--extra_package', pkg_archive]) cli(args) else: cli(args) ================================================ FILE: xql/README.md ================================================ # `xql` - Querying Xarray Datasets with SQL Running SQL like queries on Xarray Datasets. Consider dataset as a table and data variable as a column. > Note: For now, we support only zarr datasets and earth engine image collections. # Supported Features * **`Select` Variables** - From a large dataset having hundreds of variables select only needed variables. * **Apply `where` clause** - A general where condition like SQL. Applicable for queries which includes data for specific time range or only for specific regions. * **`group by` Functions** - This is supported on the coordinates only. e.g. time, latitude, longitude, etc. * **`aggregate` Functions** - Aggregate functions `AVG()`, `MIN()`, `MAX()`, etc. Only supported on data variables. * **`limit and offset` clause** - Apply limit and offset to filter out the required result. * For more checkout the [road-map](https://github.com/google/weather-tools/tree/xql-init/xql#roadmap). > Note: For now, we support `where` conditions and `groupby` on coordinates only. `orderby` can only be applied either on selected variables or on coordinates. # Quickstart ## Prerequisites Get an access to the dataset you want to query. Here as an example we're going to use the analysis ready era5 public dataset. [full_37-1h-0p25deg-chunk-1.zarr-v3](https://pantheon.corp.google.com/storage/browser/gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3?project=gcp-public-data-signals). For this `gcloud` must be configured in your local environment. Refer [Initializing the gcloud CLI](https://cloud.google.com/sdk/docs/initializing) for configuring the `gcloud` locally. 
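Before running queries, it can help to confirm that your environment can actually reach the store. The sketch below is an illustrative example (assuming `xarray`, `zarr`, `gcsfs`, and `dask` are installed, as they are when `xql` is installed via pip): it lazily opens the public ERA5 Zarr dataset and prints a few of its variables.

```
import xarray as xr

# Lazily open the ARCO ERA5 Zarr store; no data is downloaded at this point.
ds = xr.open_zarr('gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3')

# Listing a few data variables and the coordinate sizes confirms the store is readable.
print(list(ds.data_vars)[:5])
print(dict(ds.sizes))
```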
## Usage

```
# Install required packages
pip install xql

# Jump into xql
python xql/main.py xql
```

---

### Supported meta commands

`.help`: For usage info.

`.exit`: To exit from the xql interpreter.

`.set`: To set the dataset URI as a shortened key.

```
.set era5 gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3
```

`.show`: To list dataset shortened keys. E.g. `.show` or `.show [key]`

```
.show era5
```

`[query]` => Any valid SQL-like query.

---

### Example Queries

1. Apply conditions. A query to get the temperature of the Arctic region in January 2022:

```
SELECT temperature FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2022-01-01' AND time < '2022-02-01' AND latitude >= 66.5
```

> Note: Multiline queries are not yet supported. Convert copied queries into a single line before execution.

2. Aggregate results using GROUP BY and aggregate functions. Daily average temperature of the Arctic region in January 2022, with the table name set as a shortened key:

```
.set era5 gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3
```

```
SELECT AVG(temperature), SUM(charnock), MIN('100m_v_component_of_wind') FROM era5 WHERE time >= '2022-01-01' AND time < '2022-02-01' AND latitude >= 66.5 GROUP BY time_date
```

Replace `time_date` with `time_month` or `time_year` if a monthly or yearly average is needed. `MIN()` and `MAX()` work the same way as `AVG()`.

3. Caveat: the queries above run on the client's local machine and generate a large two-dimensional array, so querying a very large amount of data can run out of memory. For example, a query like the one below will give OOM errors if the client machine doesn't have enough RAM.

```
SELECT evaporation, geopotential_at_surface, temperature FROM era5
```

# Dask Cluster Configuration

Steps to deploy a Dask cluster on GKE.

1. Create a Kubernetes cluster if you don't have one. Follow [Creating a zonal cluster](https://cloud.google.com/kubernetes-engine/docs/how-to/creating-a-zonal-cluster).

2. Get cluster credentials on your local machine:

```
gcloud container clusters get-credentials {cluster_name} --region {cluster_region} --project {project}
```

3. Install `helm`. Follow [Helm | Installing Helm](https://helm.sh/docs/intro/install/).

4. Deploy the Dask cluster using the `helm` commands below.

```
helm repo add dask https://helm.dask.org/
helm repo update
helm install xql-dask dask/dask
```

> Replace `xql-dask` in the command above with the name you want your Dask cluster to have, and set the `DASK_CLUSTER={dask_cluster_name}` environment variable accordingly.

5. Connect to the Dask cluster:

```
from xql.utils import connect_dask_cluster
connect_dask_cluster()
```

# Roadmap

_Updated on 2024-01-08_

1. [x] **Select Variables**
   1. [ ] On Coordinates
   2. [x] On Variables
2. [x] **Where Clause**: `=`, `>`, `>=`, `<`, `<=`, etc.
   1. [x] On Coordinates
   2. [ ] On Variables
3. [x] **Aggregate Functions**: Only `AVG()`, `MIN()`, `MAX()`, `SUM()` are supported.
   1. [x] With Group By
   2. [x] Without Group By
   3. [x] Multiple aggregate functions in a single query
4. [x] **Order By**: Apply sorting on the result.
5. [x] **Limit**: Limit the rows displayed in the result.
6. [ ] **Mathematical Operators** `(+, -, *, /)`: Add support for mathematical operators in the query.
7. [ ] **Aliases**: Add support for aliases while querying.
8. [ ] **Join Operations**: Support joining tables and applying queries.
9. [ ] **Nested Queries**: Add support to write nested queries.
10.
[ ] **Custom Aggregate Functions**: Support custom aggregate functions # `weather-lm` - Querying weather data using Natural Language prompts Querying weather data using Natural Language prompts. This uses a gemini (large language model from Google) to generate SQL like queries and `xql` to execute that query. # Quickstart ## Prerequisites Google API Key is needed to initiate Language Model. Refer [Setup your API key](https://ai.google.dev/tutorials/python_quickstart#setup_your_api_key) to generate that key. Set that key as an environment variable. Run below command. ``` export GOOGLE_API_KEY="generate_key" ``` ## Usage ``` # Install required packages pip install xql # Jump into language model python xql/main.py lm ``` --- ### Examples `Input Prompt`: Daily average temperature of New York for January 2015 Relevant SQL Query: ``` SELECT AVG('temperature') FROM 'gs://darshan-store/ar/2013-2022-full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND city = 'New York' GROUP BY time_date ``` Output Data: ``` time_date avg_temperature 0 2015-01-01 240.978073 1 2015-01-02 243.375031 2 2015-01-03 244.584747 3 2015-01-04 249.673065 4 2015-01-05 245.650833 ... Query took: 00:01:55 ``` ================================================ FILE: xql/main.py ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from src.weather_lm import nl_to_weather_data from src.xql import main from src.xql.utils import connect_dask_cluster from typing import List, Tuple def parse_args() -> Tuple[argparse.Namespace, List[str]]: parser = argparse.ArgumentParser() parser.add_argument('mode', type=str, help='Select one from [xql, lm]') return parser.parse_known_args() if __name__ == '__main__': known_args, _ = parse_args() if known_args.mode not in ["xql", "lm"]: raise RuntimeError("Invalid mode type. Select one from [xql, lm]") prefix = "xql" if known_args.mode == "xql" else "lm" try: # Connect Dask Cluster connect_dask_cluster() while True: query = input(f"{prefix}> ") if query == ".exit": break if known_args.mode == "xql": main(query) else: print(nl_to_weather_data(query)) except ImportError as e: raise ImportError('main function is not imported please try again.') from e ================================================ FILE: xql/setup.py ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from setuptools import find_packages, setup # Please maintain the dask_kubernetes version at 2024.4.2 (https://github.com/dask/dask-kubernetes/releases/tag/2024.4.2). # Subsequent versions have removed code which is essential for the current codebase to function properly. requirements = [ "dask==2024.3.0", "distributed==2024.3.0", "dask_kubernetes==2024.4.2", "fsspec", "gcsfs==2024.2.0", "numpy==1.26.3", "sqlglot", "toolz==0.12.0", "xarray==2024.01.0", "xee", "zarr==2.17.0", "langchain", "langchain-experimental", "langchain-openai", "langchain-google-genai" ] setup( name="xql", packages=find_packages(where='src'), package_dir={'': 'src'}, install_requires=requirements, version="0.0.2", author='anthromet', author_email='anthromets-ecmwf@google.com', description=("Running SQL queries on Xarray Datasets. Consider dataset as a table and data variable as a column."), long_description=open('README.md', 'r', encoding='utf-8').read(), long_description_content_type='text/markdown', python_requires='>=3.9, <3.11', ) ================================================ FILE: xql/src/__init__.py ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: xql/src/weather_lm/__init__.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .gemini import nl_to_sql_query, nl_to_weather_data #noqa ================================================ FILE: xql/src/weather_lm/constant.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ruff: noqa: E501 METADATA_URI = "gs://darshan-store/xql/metadata.json" GENERATE_SQL_TEMPLATE = """You are a SQL expert. Given an input question, first create a syntactically correct SQL query to execute. Never query for all columns from a table. 
You must query only the columns that are needed to answer the question. Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. Everytime wrap table name in single quotes (''). Specify the time range, latitude, longitude as follows: (time >= '2012-12-01' AND time < '2013-04-01'). While accessing the variable name from the table don't use "\" this. Ex. ( SELECT MAX("vertical_velocity") FROM 'table' ) => True syntax ( SELECT MAX(\"vertical_velocity\") FROM 'table' ) => False syntax Avoid using the 'time BETWEEN', 'latitude BETWEEN' syntax, opt for the former style instead. Note: At present, only data variables are supported in the SELECT Clause. Coordinates (latitude, longitude, time) are not supported. Therefore, coordinates should not be used in the SELECT Clause. Example: Some important data details to consider: - Use latitude and longitude ranges for cities and countries. - Standard aggregations are applied to the data. A unique convention for aggregation for daily, monthly and yearly are time_date, time_month and time_year. - The WHERE clause and GROUP BY is specifically applies to coordinates variables. e.g. timestamp, latitude, longitude, and level coordinates. For "timestamp," use time_date for grouping by date and time_month for grouping by month. Standard SQL GROUP BY operations apply only to "latitude", "longitude", and "level" column. - Write time always into 'YYYY-MM-DD' format. i.e. '2021-12-01'. Please use the following format: Question: "Question here" SQLQuery: "SQL Query to run" Use the following information for the database: - Use {table} as table name. - The dataset includes columns like {columns}. Select appropriate columns from these which are most relevant to the Question. - Latitude range is {latitude_range}, and longitude range is {longitude_range}. Generate query accordingly. - {latitude_dim} and {longitude_dim} are my columns for latitude and longitude so use them everywhere in query. Ex. If lat and lon are in the {dims} then instead of latitude > x AND longitude > y use lat > x AND lon > y. - The interpretation of the "organic" soil type is value of soil type is equal to 6. - "Over all locations", "globally" entails iterating through "latitude" & "longitude." Some examples of SQL queries that correspond to questions are: {few_shot_examples} Question: {question}""" few_shots = { "Aggregate precipitation over months over all locations?" : "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude, time_month", "Daily average temperature.":"""SELECT AVG(temperature) FROM "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3" GROUP BY time_date""", "Average temperature of the Antarctic Area during last monsoon over months.":"SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2022-06-01' AND time < '2022-11-01' AND latitude < 66.5 GROUP BY time_month", "Average temperature over years.":"SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY time_year", "Aggregate precipitation globally?" 
: "SELECT SUM(precipitation) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' GROUP BY latitude, longitude", "For January 2000" : "SELECT * from TABLE where time >= '2000-01-01 00:00:00' AND time < '2000-02-01 00:00:00' ", "Daily average temperature of city x for January 2015?": "SELECT AVG(temperature) FROM 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND latitude > 40 AND latitude < 41 AND longitude > 286 AND longitude < 287 GROUP BY time_date", "Daily min reflectivity of city x for January 2015?": "SELECT AVG(reflectivity) FROM 'ee://projects/anthromet-prod/assets/opera/instantaneous_maximum_reflectivity' WHERE time >= '2015-01-01' AND time < '2015-02-01' AND lat > 40.48 AND lat < 41.87 AND lon > -74.25 AND lon < -71.98 GROUP BY time_date" } SELECT_DATASET_TEMPLATE = """ I have some description of tables that stores weather related data. Analyze and give me an table that i need to query for provided question. Sometimes the exact column not be there in the table so select table that contains most relevant columns. Ex. Daily average of precipitation rate asked but exact precipitation column is not there then select the table that contains relevant column like total_precipitation, precipitation_rate, total_precipitation_rate, etc. Below is the description and input qustion {table_map} Question: {question} Please use the following format: {question}:appropriate table """ ================================================ FILE: xql/src/weather_lm/gemini.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from langchain_google_genai import ChatGoogleGenerativeAI from .constant import few_shots from .template import DEFINED_PROMPTS from .utils import get_invocation_steps, get_table_map_prompt from xql import run_query def nl_to_sql_query(input_statement: str) -> str: """ Convert a natural language query to SQL. Parameters: - input_statement (str): The natural language query. Returns: - str: The generated SQL query. 
""" # Check if API key is provided either directly or through environment variable api_key = os.getenv("GOOGLE_API_KEY") if api_key is None: raise RuntimeError("Environment variable GOOGLE_API_KEY is not set.") # Get table map table_map_prompt, table_map = get_table_map_prompt() # Initialize model for natural language processing model = ChatGoogleGenerativeAI(model="gemini-pro") # Get invocation steps for selecting dataset select_dataset_model = get_invocation_steps(DEFINED_PROMPTS['select_dataset'], model) # Get invocation steps for generating SQL generate_sql_model = get_invocation_steps(DEFINED_PROMPTS['generate_sql'], model) # Invoke pipeline to select dataset based on input statement select_dataset_res = select_dataset_model.invoke({ "question": input_statement, "table_map": table_map_prompt }) # Extract dataset key from result dataset_key = select_dataset_res.split(":")[-1].strip() # Retrieve dataset metadata using dataset key dataset_metadata = table_map[dataset_key] # Invoke pipeline to generate SQL query generate_sql_res = generate_sql_model.invoke({ "question": input_statement, "table": dataset_metadata['uri'], "columns": dataset_metadata['columns'], "few_shot_examples": few_shots, "dims": dataset_metadata["dims"], 'latitude_dim': dataset_metadata["latitude_dim"], 'latitude_range': dataset_metadata["latitude_range"], 'longitude_dim': dataset_metadata["longitude_dim"], 'longitude_range': dataset_metadata["longitude_range"] }) # Extract SQL query from result. # The response will look like [SQLQuery: SELECT * FROM {table} WHERE ...]. # So slice the sql query from string. sql_query = generate_sql_res[11:-1] return sql_query def nl_to_weather_data(input_statement: str): """ Convert a natural language query to SQL and fetch weather data. Parameters: - input_statement (str): The natural language query. """ # Generate SQL query sql_query = nl_to_sql_query(input_statement) # Print generated SQL statement for debugging print("Generated SQL Statement:", sql_query) # Execute SQL query to fetch weather data return run_query(sql_query) ================================================ FILE: xql/src/weather_lm/template.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from langchain.prompts import PromptTemplate from .constant import GENERATE_SQL_TEMPLATE, SELECT_DATASET_TEMPLATE DEFINED_PROMPTS = { 'select_dataset': PromptTemplate( input_variables = ["table_map", "question"], template = SELECT_DATASET_TEMPLATE, ), 'generate_sql': PromptTemplate( input_variables = [ "question", "few_shot_examples", "table", "columns", "dims", "latitude_dim", "latitude_range", "longitude_dim", "longitude_range" ], template = GENERATE_SQL_TEMPLATE, ) } ================================================ FILE: xql/src/weather_lm/utils.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import typing as t from gcsfs import GCSFileSystem from langchain.prompts import PromptTemplate from langchain.schema.output_parser import StrOutputParser from langchain_google_genai import ChatGoogleGenerativeAI from .constant import METADATA_URI def get_table_map() -> t.Dict: """ Load and return the table map from dataset-meta.json file. Returns: dict: Dictionary containing table names as keys and their metadata as values. """ fs = GCSFileSystem() table_map = {} with fs.open(METADATA_URI) as f: table_map = json.load(f) return table_map def get_table_map_prompt() -> t.Tuple: """ Generate a prompt containing information about each table in the dataset. Returns: tuple: A tuple containing the prompt string and the table map dictionary. """ table_prompts = [] table_map = get_table_map() for k, v in table_map.items(): data_str = f"""Table name is {k}. It's located at {v['uri']} and containing following columns: {', '.join(v['columns'])}""" table_prompts.append(data_str) return "\n".join(table_prompts), table_map def get_invocation_steps(prompt: PromptTemplate, model: ChatGoogleGenerativeAI): """ Get the invocation steps for a given prompt and model. Parameters: - prompt (PromptTemplate): The prompt template to use. - model (ChatGoogleGenerativeAI): The generative model to use. Returns: - Pipeline: The invocation steps for the given prompt and model. """ chat = ( prompt | model.bind() | StrOutputParser() ) return chat ================================================ FILE: xql/src/xql/__init__.py ================================================ # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
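"""Convenience re-exports for the xql package: `main`, `parse_query`, and `run_query`
from the `apply` module."""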
from .apply import main, parse_query, run_query #noqa ================================================ FILE: xql/src/xql/apply.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import readline # noqa import numpy as np import pandas as pd import typing as t import xarray as xr from sqlglot import parse_one, exp from xarray.core.groupby import DatasetGroupBy from .open import get_chunking, open_dataset from .utils import timing from .where import apply_where command_info = { ".exit": "To exit from the current session.", ".set": "To set the dataset uri as a shortened key. e.g. .set era5 gs://{BUCKET}/dataset-uri", ".show": "To list down dataset shortened key. e.g. .show or .show [key]", "[query]": "Any valid sql like query." } table_dataset_map = {} # To store dataset shortened keys for a single session. operate = { "and" : lambda a, b: a & b, "or" : lambda a, b: a | b, "eq" : lambda a, b: a == b, "gt" : lambda a, b: a > b, "lt" : lambda a, b: a < b, "gte" : lambda a, b: a >= b, "lte" : lambda a, b: a <= b, } aggregate_function_map = { 'avg': lambda x, y: x.mean(dim=y), 'min': lambda x, y: x.min(dim=y), 'max': lambda x, y: x.max(dim=y), 'sum': lambda x, y: x.sum(dim=y), } timestamp_formats = { 'time_date':"%Y-%m-%d", 'time_month':"%Y-%m", 'time_year': "%Y" } def parse(a: t.Union[xr.DataArray, str], b: t.Union[xr.DataArray, str]) -> t.Tuple[t.Union[xr.DataArray, str], t.Union[xr.DataArray, str]]: """ Parse input values 'a' and 'b' into NumPy arrays with compatible types for evaluation. Parameters: - a (Union[xr.DataArray, str]): The first input value. - b (Union[xr.DataArray, str]): The second input value. Returns: - Tuple[xr.DataArray, Union[np.float64, np.float32, np.datetime64]]: Parsed NumPy arrays 'a' and 'b'. """ if isinstance(a, str): a, b = b, a arr_type = a.dtype.name if arr_type == 'float64': b = np.float64(b) elif arr_type == 'float32': b = np.float32(b) elif arr_type == 'datetime64[ns]': b = np.datetime64(b) return a, b def apply_orderby(e: exp.Order, df: pd.DataFrame) -> pd.DataFrame: """ Apply ORDER BY clause to the DataFrame. Args: - e (exp.Order): Parsed ORDER BY expression. - df (pd.DataFrame): DataFrame to be sorted. Returns: - pd.DataFrame: Sorted DataFrame based on the ORDER BY clause. 
""" orderby_columns = [] orderby_columns_order = [] # Extract columns and sorting orders from the parsed ORDER BY expression for el in e.expressions: orderby_column = el.find(exp.Column).find(exp.Identifier).this order = not bool(el.args['desc']) # Descending if desc=False orderby_columns.append(orderby_column) orderby_columns_order.append(order) # Sort the DataFrame based on the extracted columns and orders df = df.sort_values(orderby_columns, ascending=orderby_columns_order) return df def aggregate_variables(agg_funcs: t.List[t.Dict[str, str]], ds: xr.Dataset, time_fields: t.List[str], coords_to_squeeze: t.List[str]) -> xr.Dataset: """ Aggregate variables in an xarray dataset based on aggregation functions. Args: agg_funcs (List[Dict[str, str]]): List of dictionaries specifying aggregation functions for variables. ds (xr.Dataset): The input xarray dataset. time_fields (List[str]): List of time fields to consider for time-based grouping. coords_to_squeeze (List[str]): List of coordinates to be squeezed during aggregation. Returns: xr.Dataset: The aggregated xarray dataset. """ agg_dataset = xr.Dataset(coords=ds.coords, attrs=ds.attrs) # Aggregate based on time fields if len(time_fields): agg_dataset = agg_dataset.groupby(ds['time'].dt.strftime(timestamp_formats[time_fields[0]])) agg_dataset = apply_aggregation(agg_dataset, 'avg', None) agg_dataset = agg_dataset.rename({"strftime": time_fields[0]}) # Aggregate based on other coordinates agg_dataset = apply_aggregation(agg_dataset, 'avg', coords_to_squeeze) # Loop through aggregation functions for agg_func in agg_funcs: variable, function = agg_func['var'], agg_func['func'] grouped_ds = ds[variable] dims = [value for value in coords_to_squeeze if value in ds[variable].coords] if coords_to_squeeze else None # If time fields are specified, group by time if len(time_fields): groups = grouped_ds.groupby(ds['time'].dt.strftime(timestamp_formats[time_fields[0]])) grouped_ds = apply_aggregation(groups, function, None) grouped_ds = grouped_ds.rename({"strftime": time_fields[0]}) # Apply aggregation on dimensions agg_dim_ds = apply_aggregation(grouped_ds, function, dims) agg_dataset = agg_dataset.assign({f"{function}_{variable}": agg_dim_ds}) return agg_dataset def apply_group_by(time_fields: t.List[str], ds: xr.Dataset, agg_funcs: t.Dict[str, str], coords_to_squeeze: t.List[str] = []) -> xr.Dataset: """ Apply group-by and aggregation operations to the dataset based on specified fields and aggregation functions. Parameters: - time_fields (List[str]): List of time_fields(coordinates) to be used for grouping. - ds (xarray.Dataset): The input dataset. - agg_funcs (t.List[t.Dict[str, str]]): Dictionary mapping aggregation function names to their corresponding xarray-compatible string representations. - coords_to_squeeze (t.List[str]): The dimension along which to apply the aggregation. If None, aggregation is applied to the entire dataset. Returns: - xarray.Dataset: The dataset after applying group-by and aggregation operations. """ grouped_ds = ds if len(time_fields) > 1: raise NotImplementedError("GroupBy using multiple time fields is not supported.") elif len(time_fields) == 1: grouped_ds = aggregate_variables(agg_funcs, ds, time_fields, coords_to_squeeze) return grouped_ds def apply_aggregation(groups: t.Union[xr.Dataset, DatasetGroupBy], fun: str, dim: t.List[str] = []) -> xr.DataArray: """ Apply aggregation to the groups based on the specified aggregation function. 
Parameters: - groups (Union[xr.Dataset, xr.core.groupby.DatasetGroupBy]): The input dataset or dataset groupby object. - fun (str): The aggregation function to be applied. - dim (Optional[str]): The dimension along which to apply the aggregation. If None, aggregation is applied to the entire dataset. Returns: - xr.Dataset: The dataset after applying the aggregation. """ return aggregate_function_map[fun](groups, dim) def get_coords_to_squeeze(fields: t.List[str], ds: xr.Dataset) -> t.List[str]: """ Get the coordinates to squeeze from an xarray dataset. The function identifies coordinates in the dataset that are not part of the specified fields and are not the 'time' coordinate. Args: fields (List[str]): List of field names. ds (xr.Dataset): The xarray dataset. Returns: List[str]: List of coordinates to squeeze. """ # Identify coordinates not in fields and not 'time' coords_to_squeeze = [coord for coord in ds.coords if coord not in fields and (coord != "time")] return coords_to_squeeze def get_table(e: exp.Expression) -> str: """ Get the table name from an expression. Args: e (Expression): The expression containing table information. Returns: str: The table name. """ # Extract the table name from the expression table = e.find(exp.Table).args['this'].args['this'] # Check if the table is mapped in table_dataset_map if table in table_dataset_map: table = table_dataset_map[table] return table def parse_query(query: str) -> xr.Dataset: expr = parse_one(query) if not isinstance(expr, exp.Select): return "ERROR: Only select queries are supported." table = get_table(expr) is_star = expr.find(exp.Star) data_vars = [] if is_star is None: data_vars = [var.args['this'].args['this'] if var.key == 'column' else var.args['this'] for var in expr.expressions if (var.key == "column" or (var.key == "literal" and var.args.get("is_string") is True))] where_clause = expr.find(exp.Where) group_by = expr.find(exp.Group) agg_funcs = [ { 'var': var.args['this'].args['this'].args['this'] if var.args['this'].key == 'column' else var.args['this'].args['this'], 'func': var.key } for var in expr.expressions if var.key in aggregate_function_map ] if len(agg_funcs): data_vars = [ agg_var['var'] for agg_var in agg_funcs ] ds, chunkable = open_dataset(table) if is_star is None: ds = ds[data_vars] if where_clause is not None: ds = apply_where(ds, where_clause.args['this']) if chunkable: ds = ds.chunk(chunks=get_chunking(table, list(ds.data_vars))) coords_to_squeeze = None time_fields = [] if group_by: fields = [ e.args['this'].args['this'] for e in group_by.args['expressions'] ] time_fields = list(filter(lambda field: "time" in field, fields)) coords_to_squeeze = get_coords_to_squeeze(fields, ds) ds = apply_group_by(time_fields, ds, agg_funcs, coords_to_squeeze) if len(time_fields) == 0 and len(agg_funcs): if isinstance(coords_to_squeeze, t.List): coords_to_squeeze.append('time') ds = aggregate_variables(agg_funcs, ds, time_fields, coords_to_squeeze) return ds def convert_to_dataframe(ds: xr.Dataset) -> pd.DataFrame: """ Convert xarray Dataset to pandas DataFrame. Args: ds (xr.Dataset): xarray Dataset to be converted. Returns: pd.DataFrame: Pandas DataFrame containing the data from the xarray Dataset. 
""" if len(ds.coords): # If the dataset has coordinates, convert it to DataFrame and reset index df = ds.to_dataframe().reset_index() else: # If the dataset doesn't have coordinates, compute it and convert to dictionary ds = ds.compute().to_dict(data="list") # Create DataFrame from dictionary df = pd.DataFrame({k: [v['data']] for k, v in ds['data_vars'].items()}) return df def filter_records(df: pd.DataFrame, query: str) -> pd.DataFrame: """ Filter records in an xarray Dataset based on a given query. Args: ds (xr.Dataset): The xarray Dataset to filter. query (str): The query string for filtering the dataset. Returns: pd.DataFrame: A pandas DataFrame containing the filtered records. """ # Parse the query expression expr = parse_one(query) # Find Limit, Offset and OrderBy clauses in the query orderby_clause = expr.find(exp.Order) limit_clause = expr.find(exp.Limit) offset_clause = expr.find(exp.Offset) # Apply orderby clause if present if orderby_clause: df = apply_orderby(orderby_clause, df) # Initialize start location for slicing start_loc = 0 # Apply offset clause if present if offset_clause: start_loc = int(offset_clause.expression.args['this']) df = df.iloc[start_loc:] # Apply limit clause if present if limit_clause: limit = int(limit_clause.expression.args['this']) df = df.iloc[:start_loc + limit] # Compute and return the filtered DataFrame return df def set_dataset_table(cmd: str) -> None: """ Set the mapping between a key and a dataset. Args: cmd (str): The command string in the format ".set key val" where key is the identifier and val is the dataset table. """ # Split the command into parts cmd_parts = cmd.split(" ") # Check if the command has the correct number of arguments if len(cmd_parts) == 3: # Extract key and val from the command _, key, val = cmd_parts # Update the dataset table mapping table_dataset_map[key] = val else: # Print an error message for incorrect arguments print("Incorrect args. Run .help .set for usage info.") def list_key_values(input: t.Dict[str, str]) -> None: """ Display key-value pairs from a dictionary. Args: input (Dict[str, str]): The dictionary containing key-value pairs. """ for cmd, desc in input.items(): print(f"{cmd} => {desc}") def display_help(cmd: str) -> None: """ Display help information for commands. Args: cmd (str): The command string. """ cmd_parts = cmd.split(" ") if len(cmd_parts) == 2: if cmd_parts[1] in command_info: print(f"{cmd_parts[1]} => {command_info[cmd_parts[1]]}") else: list_key_values(command_info) elif len(cmd_parts) == 1: list_key_values(command_info) else: print("Incorrect usage. Run .help or .help [cmd] for usage info.") def display_table_dataset_map(cmd: str) -> None: """ Display information from the table_dataset_map. Args: cmd (str): The command string. """ cmd_parts = cmd.split(" ") if len(cmd_parts) == 2: if cmd_parts[1] in table_dataset_map: print(f"{cmd_parts[1]} => {table_dataset_map[cmd_parts[1]]}") else: list_key_values(table_dataset_map) else: list_key_values(table_dataset_map) @timing def run_query(query: str) -> None: """ Run a query and display the result. Args: query (str): The query to be executed. """ try: result = parse_query(query) except Exception as e: result = f"ERROR: {type(e).__name__}: {e.__str__()}." return result return filter_records(convert_to_dataframe(result), query) @timing def main(query: str): """ Main function for runnning this file. 
""" if ".help" in query: display_help(query) elif ".set" in query: set_dataset_table(query) elif ".show" in query: display_table_dataset_map(query) else: result = run_query(query) print(result) ================================================ FILE: xql/src/xql/constant.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. SUPPORTED_CUSTOM_COORDS = ['city', 'country'] COUNTRIES_BOUNDING_BOXES = { 'india': (6.5546079, 35.4940095078, 68.1766451354, 97.4025614766), 'canada': (41.6751050889, 83.23324, -140.99778, -52.6480987209), 'japan': (31.0295791692, 45.5514834662, 129.408463169, 145.543137242), 'united kingdom': (49.959999905, 58.6350001085, -7.57216793459, 1.68153079591), 'south africa': (-34.8191663551, -22.0913127581, 16.3449768409, 32.830120477), 'australia': (-44, -10, 113, 154), 'united states': (24.396308, 49.384358, -125.0, -66.93457) } CITIES_BOUNDING_BOXES = { 'delhi': (28.404, 28.883, 76.838, 77.348), 'new york': (40.4774, 40.9176, -74.2591, -73.7002), 'san francisco': (37.6398, 37.9298, -122.5975, -122.3210), 'los angeles': (33.7036, 34.3373, -118.6682, -118.1553), 'london': (51.3849, 51.6724, -0.3515, 0.1482) } ================================================ FILE: xql/src/xql/open.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ee import logging import zarr import typing as t import xarray as xr logger = logging.getLogger(__name__) OPENER_MAP = { "zarr": "zarr", "ee": "ee" } def get_chunking(uri: str, variables: t.List[str]) -> t.Dict: """ Retrieve chunking information for the specified variables in a Zarr dataset. Parameters: uri (str): The URI of the Zarr dataset. variables (List[str]): A list of variable names. Returns: t.Dict: A dictionary containing chunking information for each variable. 
""" # Initialize dictionary to store chunking information chunks = {} # Open the Zarr dataset zf = zarr.open(uri) # Iterate over each variable for v in variables: # Get the variable object var = zf[v] # Get chunking info for the variable var_chunks = var.chunks # Get variable dimensions var_dims = var.attrs.get('_ARRAY_DIMENSIONS') # Map dimensions to chunk sizes chunk_dict = dict(zip(var_dims, var_chunks)) # Update chunks with array chunk dimensions chunks.update(chunk_dict) # Return chunking information dictionary return chunks def open_dataset(uri: str) -> t.Tuple[xr.Dataset, bool]: """ Open a dataset from the given URI using the appropriate engine. Parameters: - uri (str): The URI of the dataset to open. Returns: - xr.Dataset: The opened dataset. Raises: - RuntimeError: If unable to open the dataset. """ chunkable = False try: # Check if the URI starts with "ee://" if uri.startswith("ee://"): # If yes, initialize Earth Engine ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com') # Open dataset using Earth Engine engine ds = xr.open_dataset(uri, engine=OPENER_MAP["ee"]) else: # If not, open dataset using zarr engine ds = xr.open_zarr(uri, chunks=None) chunkable = True except Exception: # If opening fails, raise RuntimeError raise RuntimeError("Unable to open dataset. [zarr, ee] are the only supported dataset types.") return ds, chunkable ================================================ FILE: xql/src/xql/utils.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from dask_kubernetes import HelmCluster from dask.distributed import Client from functools import wraps from time import gmtime, strftime, time def timing(f): """Measure a time for any function execution.""" @wraps(f) def wrap(*args, **kw): ts = time() result = f(*args, **kw) te = time() print(f"Query took: { strftime('%H:%M:%S', gmtime(te - ts)) }") return result return wrap def connect_dask_cluster() -> None: """ Connects to a Dask cluster. """ # Fetch the cluster name from environment variable, default to "xql-dask" if not set cluster_name = os.getenv('DASK_CLUSTER', "xql-dask") try: # Create a HelmCluster instance with the specified release name cluster = HelmCluster(release_name=cluster_name) # Connect a Dask client to the cluster client = Client(cluster) # noqa: F841 # Print a message indicating successful connection print("Dask cluster connected.") except Exception: # Print a message indicating failure to connect print("Dask cluster not connected.") ================================================ FILE: xql/src/xql/where.py ================================================ #!/usr/bin/env python3 # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import typing as t
import xarray as xr

from sqlglot import exp

from .constant import CITIES_BOUNDING_BOXES, COUNTRIES_BOUNDING_BOXES, SUPPORTED_CUSTOM_COORDS


def parse(a: t.Union[xr.DataArray, str],
          b: t.Union[xr.DataArray, str]) -> t.Tuple[t.Union[xr.DataArray, str], t.Union[xr.DataArray, str]]:
    """
    Parse input values 'a' and 'b' into NumPy arrays with compatible types for evaluation.

    Parameters:
    - a (Union[xr.DataArray, str]): The first input value.
    - b (Union[xr.DataArray, str]): The second input value.

    Returns:
    - Tuple[xr.DataArray, Union[np.float64, np.float32, np.datetime64]]: Parsed NumPy arrays 'a' and 'b'.
    """
    if isinstance(a, str):
        a, b = b, a

    arr_type = a.dtype.name
    if arr_type == 'float64':
        b = np.float64(b)
    elif arr_type == 'float32':
        b = np.float32(b)
    elif arr_type == 'datetime64[ns]':
        b = np.datetime64(b)

    return a, b


def get_sop_terms(expression: exp.Expression):
    """Convert a WHERE expression into sum-of-products terms (an OR of AND terms)."""

    def cross_product(exp1: t.List[exp.Expression], exp2: t.List[exp.Expression]):
        pdt = []
        if len(exp1) == 0:
            return exp2
        if len(exp2) == 0:
            return exp1
        for item1 in exp1:
            for item2 in exp2:
                term = item1.and_(item2)
                pdt.append(term)
        return pdt

    if expression.key == "and":
        children = list(expression.flatten())
        product_terms = []
        for child in children:
            terms = get_sop_terms(child)
            product_terms = cross_product(product_terms, terms)
        return product_terms
    elif expression.key == "or":
        children = list(expression.flatten())
        sum_terms = []
        for child in children:
            if child.key == "and" or child.key == "or":
                sum_terms.extend(get_sop_terms(child))
            else:
                sum_terms.append(child)
        return sum_terms
    else:
        return [expression]


def check_conditional(expression: exp.Expression) -> bool:
    k = expression.key
    if k == 'lt' or k == 'gt' or k == 'gte' or k == 'lte' or k == 'eq':
        return True
    return False


def parse_condition(expression: exp.Expression, condition_dict: dict) -> None:
    # TODO: This function right now assumes that for a condition
    # the LHS is always the column name and the RHS is always the value.
    op = expression.key
    args = expression.args
    left = args.get('this', None)
    right = args.get('expression', None)

    identifier = left.args.get('this')
    coordinate = identifier.args.get('this')

    # Handle negative values.
    if isinstance(right, exp.Neg):
        right = right.args.get('this')
        value = "-" + right.args.get('this')
    else:
        value = right.args.get('this')

    if coordinate not in condition_dict:
        condition_dict[coordinate] = {}

    condition_dict[coordinate][op] = value


def is_ascending_order(da: xr.DataArray) -> bool:
    """Simple check whether a dataarray is in increasing or decreasing order."""
    da = da.drop_duplicates(...)
    if da[0].values < da[1].values:
        return True
    return False


def select_coordinate(da: xr.DataArray, coordinate: str, operator: str, value: t.Any) -> xr.DataArray:
    """Apply the greater-than or less-than condition based on the operator and the
    coordinate's sort order (ascending or descending). The equality condition does not
    depend on the sort order.
""" op = operator da, parsed_value = parse(da, value) if op == 'eq': return da.sel({coordinate: parsed_value}) if(is_ascending_order(da)): if op == 'gt' or op == 'gte': return da.sel({coordinate: slice(parsed_value, None)}) elif op == 'lt' or op == 'lte': return da.sel({coordinate: slice(None, parsed_value)}) else: raise ValueError(f"Unkown operator in select_coordinate op: {op}.") else: if op == 'gt' or op == 'gte': return da.sel({coordinate: slice(None, parsed_value)}) elif op == 'lt' or op == 'lte': return da.sel({coordinate: slice(parsed_value, None)}) else: raise ValueError(f"Unkown operator in select_coordinate op: {op}.") def get_coords_condition(coordinate: str, condition: t.Dict[str, str]): """Generate a bounding box from the defined lat long ranges for cities / countries.""" value = condition['eq'].lower() bounding_box = { 'latitude': { }, 'longitude': { } } if value in COUNTRIES_BOUNDING_BOXES: lat_min, lat_max, lon_min, lon_max = COUNTRIES_BOUNDING_BOXES[value] elif value in CITIES_BOUNDING_BOXES: lat_min, lat_max, lon_min, lon_max = CITIES_BOUNDING_BOXES[value] else: raise NotImplementedError(f"Can not query for {coordinate}:{value}") bounding_box['latitude']['gte'] = lat_min bounding_box['latitude']['lte'] = lat_max bounding_box['longitude']['gte'] = lon_min + 360 if lon_min < 0 else lon_min bounding_box['longitude']['lte'] = lon_max + 360 if lon_max < 0 else lon_max return bounding_box def filter_condition_dict(condition_dict: dict, ds: xr.Dataset): """Filter out custom fields and update them with actual dataset supported coord. { 'country': 'new york' } => { 'latitude': {...}, 'longitude': {...} } """ result = {} for coordinate, conditions in condition_dict.items(): if coordinate in ds.coords: result[coordinate] = conditions elif coordinate in SUPPORTED_CUSTOM_COORDS: simple_coords_dict = get_coords_condition(coordinate, conditions) result.update(simple_coords_dict) else: raise NotImplementedError(f"Dataset can not be queried over {coordinate} field") return result def apply_select_condition(ds: xr.Dataset, condition_dict: dict) -> xr.Dataset: """A condition dict will be in form { 'coord1': {'lt': val1, 'gt': val2}, 'coord2: {'eq': someval}, ...} This function applies above conditions on the dataset. """ absolute_condition_dict = {} condition_dict = filter_condition_dict(condition_dict, ds) for coordinate, conditions in condition_dict.items(): coordinate_array = ds[coordinate] # Iterate each operation and apply on a data array. for operator, value in conditions.items(): coordinate_array = select_coordinate(coordinate_array, coordinate, operator, value) coordinate_values = coordinate_array.values # If the condition mentions lt(less than) or gt(greater than), # then remove the last or first element based on below conditions. if 'gt' in conditions: if is_ascending_order(coordinate_array): coordinate_values = np.delete(coordinate_values, 0) else: coordinate_values = np.delete(coordinate_values, -1) if 'lt' in conditions: if is_ascending_order(coordinate_array): coordinate_values = np.delete(coordinate_values, -1) else: coordinate_values = np.delete(coordinate_values, 0) absolute_condition_dict[coordinate] = coordinate_values # Finally perform the select operation. return ds.sel(absolute_condition_dict) def postorder(expression: exp.Expression, condition_dict: dict): """Performs post order traversal on sqlglot expression and converts it into a dict. The dict in updated in place. 
""" if expression is None: return if expression.key == "literal" or expression.key == "identifier" or expression.key == "column": return args = expression.args left = args.get('this', None) right = args.get('expression', None) postorder(left, condition_dict) postorder(right, condition_dict) if(check_conditional(expression)): parse_condition(expression, condition_dict) def apply_where(ds: xr.Dataset, expression: exp.Expression) -> xr.Dataset: terms = get_sop_terms(expression) or_ds = [] for term in terms: condition_dict = {} postorder(term, condition_dict) reduced_ds = apply_select_condition(ds, condition_dict) or_ds.append(reduced_ds) # TODO: Add support for OR operation. return or_ds[0]