Repository: googledatalab/pydatalab Branch: master Commit: 8bf007da3e43 Files: 438 Total size: 6.3 MB Directory structure: gitextract_nogfp4_v/ ├── .build-bot.json ├── .coveragerc ├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── datalab/ │ ├── README │ ├── __init__.py │ ├── bigquery/ │ │ ├── __init__.py │ │ ├── _api.py │ │ ├── _csv_options.py │ │ ├── _dataset.py │ │ ├── _dialect.py │ │ ├── _federated_table.py │ │ ├── _job.py │ │ ├── _parser.py │ │ ├── _query.py │ │ ├── _query_job.py │ │ ├── _query_results_table.py │ │ ├── _query_stats.py │ │ ├── _sampling.py │ │ ├── _schema.py │ │ ├── _table.py │ │ ├── _udf.py │ │ ├── _utils.py │ │ ├── _view.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _bigquery.py │ ├── context/ │ │ ├── __init__.py │ │ ├── _api.py │ │ ├── _context.py │ │ ├── _project.py │ │ ├── _utils.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _projects.py │ ├── data/ │ │ ├── __init__.py │ │ ├── _csv.py │ │ ├── _sql_module.py │ │ ├── _sql_statement.py │ │ ├── _utils.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _sql.py │ ├── kernel/ │ │ └── __init__.py │ ├── notebook/ │ │ ├── __init__.py │ │ └── static/ │ │ ├── bigquery.css │ │ ├── bigquery.ts │ │ ├── charting.css │ │ ├── charting.ts │ │ ├── element.ts │ │ ├── extern/ │ │ │ ├── d3.parcoords.css │ │ │ ├── d3.parcoords.js │ │ │ ├── lantern-browser.html │ │ │ ├── parcoords-LICENSE.txt │ │ │ ├── sylvester-LICENSE.txt │ │ │ └── sylvester.js │ │ ├── job.css │ │ ├── job.ts │ │ ├── parcoords.ts │ │ ├── style.ts │ │ └── visualization.ts │ ├── stackdriver/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ └── _monitoring.py │ │ └── monitoring/ │ │ ├── __init__.py │ │ ├── _group.py │ │ ├── _metric.py │ │ ├── _query.py │ │ ├── _query_metadata.py │ │ ├── _resource.py │ │ └── _utils.py │ ├── storage/ │ │ ├── __init__.py │ │ ├── _api.py │ │ ├── _bucket.py │ │ ├── _item.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _storage.py │ └── utils/ │ ├── __init__.py │ ├── _async.py │ ├── _dataflow_job.py │ ├── _gcp_job.py │ ├── _http.py │ ├── _iterator.py │ ├── _job.py │ ├── _json_encoder.py │ ├── _lambda_job.py │ ├── _lru_cache.py │ ├── _utils.py │ └── commands/ │ ├── __init__.py │ ├── _chart.py │ ├── _chart_data.py │ ├── _commands.py │ ├── _csv.py │ ├── _extension.py │ ├── _html.py │ ├── _job.py │ ├── _modules.py │ └── _utils.py ├── docs/ │ ├── .nojekyll │ ├── Makefile │ ├── README │ ├── conf.py │ ├── datalab Commands.rst │ ├── datalab.bigquery.rst │ ├── datalab.context.rst │ ├── datalab.data.rst │ ├── datalab.stackdriver.monitoring.rst │ ├── datalab.storage.rst │ ├── gen-magic-rst.ipy │ ├── google.datalab Commands.rst │ ├── google.datalab.bigquery.rst │ ├── google.datalab.data.rst │ ├── google.datalab.ml.rst │ ├── google.datalab.rst │ ├── google.datalab.stackdriver.monitoring.rst │ ├── google.datalab.storage.rst │ ├── index.rst │ ├── make.bat │ ├── mltoolbox.classification.dnn.rst │ ├── mltoolbox.classification.linear.rst │ ├── mltoolbox.image.classification.rst │ ├── mltoolbox.regression.dnn.rst │ └── mltoolbox.regression.linear.rst ├── externs/ │ └── ts/ │ └── require/ │ └── require.d.ts ├── google/ │ ├── __init__.py │ └── datalab/ │ ├── __init__.py │ ├── _context.py │ ├── _job.py │ ├── bigquery/ │ │ ├── __init__.py │ │ ├── _api.py │ │ ├── _csv_options.py │ │ ├── _dataset.py │ │ ├── _external_data_source.py │ │ ├── _job.py │ │ ├── _parser.py │ │ ├── _query.py │ │ ├── _query_job.py │ │ ├── _query_output.py │ │ ├── _query_results_table.py │ │ ├── _query_stats.py │ │ ├── _sampling.py │ │ ├── 
_schema.py │ │ ├── _table.py │ │ ├── _udf.py │ │ ├── _utils.py │ │ ├── _view.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _bigquery.py │ ├── commands/ │ │ ├── __init__.py │ │ └── _datalab.py │ ├── contrib/ │ │ ├── __init__.py │ │ ├── bigquery/ │ │ │ ├── __init__.py │ │ │ ├── commands/ │ │ │ │ ├── __init__.py │ │ │ │ └── _bigquery.py │ │ │ └── operators/ │ │ │ ├── __init__.py │ │ │ ├── _bq_execute_operator.py │ │ │ ├── _bq_extract_operator.py │ │ │ └── _bq_load_operator.py │ │ ├── mlworkbench/ │ │ │ ├── __init__.py │ │ │ ├── _archive.py │ │ │ ├── _local_predict.py │ │ │ ├── _prediction_explainer.py │ │ │ ├── _shell_process.py │ │ │ └── commands/ │ │ │ ├── __init__.py │ │ │ └── _ml.py │ │ └── pipeline/ │ │ ├── __init__.py │ │ ├── _pipeline.py │ │ ├── airflow/ │ │ │ ├── __init__.py │ │ │ └── _airflow.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ └── _pipeline.py │ │ └── composer/ │ │ ├── __init__.py │ │ ├── _api.py │ │ └── _composer.py │ ├── data/ │ │ ├── __init__.py │ │ └── _csv_file.py │ ├── kernel/ │ │ └── __init__.py │ ├── ml/ │ │ ├── __init__.py │ │ ├── _cloud_models.py │ │ ├── _cloud_training_config.py │ │ ├── _confusion_matrix.py │ │ ├── _dataset.py │ │ ├── _fasets.py │ │ ├── _feature_slice_view.py │ │ ├── _job.py │ │ ├── _metrics.py │ │ ├── _summary.py │ │ ├── _tensorboard.py │ │ └── _util.py │ ├── notebook/ │ │ ├── __init__.py │ │ └── static/ │ │ ├── bigquery.css │ │ ├── bigquery.ts │ │ ├── charting.css │ │ ├── charting.ts │ │ ├── element.ts │ │ ├── extern/ │ │ │ ├── d3.parcoords.css │ │ │ ├── d3.parcoords.js │ │ │ ├── facets-jupyter.html │ │ │ ├── lantern-browser.html │ │ │ ├── parcoords-LICENSE.txt │ │ │ ├── sylvester-LICENSE.txt │ │ │ └── sylvester.js │ │ ├── job.css │ │ ├── job.ts │ │ ├── parcoords.ts │ │ ├── style.ts │ │ └── visualization.ts │ ├── stackdriver/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ └── _monitoring.py │ │ └── monitoring/ │ │ ├── __init__.py │ │ ├── _group.py │ │ ├── _metric.py │ │ ├── _query.py │ │ ├── _query_metadata.py │ │ ├── _resource.py │ │ └── _utils.py │ ├── storage/ │ │ ├── __init__.py │ │ ├── _api.py │ │ ├── _bucket.py │ │ ├── _object.py │ │ └── commands/ │ │ ├── __init__.py │ │ └── _storage.py │ └── utils/ │ ├── __init__.py │ ├── _async.py │ ├── _dataflow_job.py │ ├── _gcp_job.py │ ├── _http.py │ ├── _iterator.py │ ├── _json_encoder.py │ ├── _lambda_job.py │ ├── _lru_cache.py │ ├── _utils.py │ ├── commands/ │ │ ├── __init__.py │ │ ├── _chart.py │ │ ├── _chart_data.py │ │ ├── _commands.py │ │ ├── _csv.py │ │ ├── _html.py │ │ ├── _job.py │ │ └── _utils.py │ └── facets/ │ ├── __init__.py │ ├── base_feature_statistics_generator.py │ ├── base_generic_feature_statistics_generator.py │ ├── feature_statistics_generator.py │ ├── feature_statistics_pb2.py │ └── generic_feature_statistics_generator.py ├── install-no-virtualenv.sh ├── install-virtualenv.sh ├── legacy_tests/ │ ├── _util/ │ │ ├── __init__.py │ │ ├── http_tests.py │ │ ├── lru_cache_tests.py │ │ └── util_tests.py │ ├── bigquery/ │ │ ├── __init__.py │ │ ├── api_tests.py │ │ ├── dataset_tests.py │ │ ├── federated_table_tests.py │ │ ├── jobs_tests.py │ │ ├── parser_tests.py │ │ ├── query_tests.py │ │ ├── sampling_tests.py │ │ ├── schema_tests.py │ │ ├── table_tests.py │ │ ├── udf_tests.py │ │ └── view_tests.py │ ├── data/ │ │ ├── __init__.py │ │ └── sql_tests.py │ ├── kernel/ │ │ ├── __init__.py │ │ ├── bigquery_tests.py │ │ ├── chart_data_tests.py │ │ ├── chart_tests.py │ │ ├── commands_tests.py │ │ ├── html_tests.py │ │ ├── module_tests.py │ │ ├── sql_tests.py │ │ ├── 
storage_tests.py │ │ └── utils_tests.py │ ├── main.py │ ├── stackdriver/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ └── monitoring_tests.py │ │ └── monitoring/ │ │ ├── __init__.py │ │ ├── group_tests.py │ │ ├── metric_tests.py │ │ ├── query_metadata_tests.py │ │ ├── query_tests.py │ │ ├── resource_tests.py │ │ └── utils_tests.py │ └── storage/ │ ├── __init__.py │ ├── api_tests.py │ ├── bucket_tests.py │ └── item_tests.py ├── release.sh ├── setup.cfg ├── setup.py ├── solutionbox/ │ ├── image_classification/ │ │ ├── mltoolbox/ │ │ │ ├── __init__.py │ │ │ └── image/ │ │ │ ├── __init__.py │ │ │ └── classification/ │ │ │ ├── __init__.py │ │ │ ├── _api.py │ │ │ ├── _cloud.py │ │ │ ├── _inceptionlib.py │ │ │ ├── _local.py │ │ │ ├── _model.py │ │ │ ├── _predictor.py │ │ │ ├── _preprocess.py │ │ │ ├── _trainer.py │ │ │ ├── _util.py │ │ │ ├── setup.py │ │ │ └── task.py │ │ └── setup.py │ ├── ml_workbench/ │ │ ├── setup.py │ │ ├── tensorflow/ │ │ │ ├── __init__.py │ │ │ ├── analyze.py │ │ │ ├── setup.py │ │ │ ├── trainer/ │ │ │ │ ├── __init__.py │ │ │ │ ├── feature_analysis.py │ │ │ │ ├── feature_transforms.py │ │ │ │ └── task.py │ │ │ └── transform.py │ │ ├── test_tensorflow/ │ │ │ ├── run_all.sh │ │ │ ├── test_analyze.py │ │ │ ├── test_cloud_workflow.py │ │ │ ├── test_feature_transforms.py │ │ │ ├── test_training.py │ │ │ └── test_transform.py │ │ ├── test_xgboost/ │ │ │ ├── run_all.sh │ │ │ ├── test_analyze.py │ │ │ └── test_transform.py │ │ └── xgboost/ │ │ ├── __init__.py │ │ ├── analyze.py │ │ ├── setup.py │ │ ├── trainer/ │ │ │ ├── __init__.py │ │ │ ├── feature_analysis.py │ │ │ ├── feature_transforms.py │ │ │ └── task.py │ │ └── transform.py │ └── structured_data/ │ ├── build.sh │ ├── mltoolbox/ │ │ ├── __init__.py │ │ ├── _structured_data/ │ │ │ ├── __init__.py │ │ │ ├── __version__.py │ │ │ ├── _package.py │ │ │ ├── master_setup.py │ │ │ ├── prediction/ │ │ │ │ ├── __init__.py │ │ │ │ └── predict.py │ │ │ ├── preprocess/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cloud_preprocess.py │ │ │ │ └── local_preprocess.py │ │ │ └── trainer/ │ │ │ ├── __init__.py │ │ │ ├── task.py │ │ │ └── util.py │ │ ├── classification/ │ │ │ ├── __init__.py │ │ │ ├── dnn/ │ │ │ │ ├── __init__.py │ │ │ │ └── _classification_dnn.py │ │ │ └── linear/ │ │ │ ├── __init__.py │ │ │ └── _classification_linear.py │ │ └── regression/ │ │ ├── __init__.py │ │ ├── dnn/ │ │ │ ├── __init__.py │ │ │ └── _regression_dnn.py │ │ └── linear/ │ │ ├── __init__.py │ │ └── _regression_linear.py │ ├── setup.py │ └── test_mltoolbox/ │ ├── __init__.py │ ├── e2e_functions.py │ ├── test_datalab_e2e.py │ ├── test_package_functions.py │ ├── test_sd_preprocess.py │ └── test_sd_trainer.py ├── tests/ │ ├── _util/ │ │ ├── __init__.py │ │ ├── commands_tests.py │ │ ├── feature_statistics_generator_test.py │ │ ├── generic_feature_statistics_generator_test.py │ │ ├── http_tests.py │ │ ├── lru_cache_tests.py │ │ └── util_tests.py │ ├── bigquery/ │ │ ├── __init__.py │ │ ├── api_tests.py │ │ ├── dataset_tests.py │ │ ├── external_data_source_tests.py │ │ ├── jobs_tests.py │ │ ├── operator_tests.py │ │ ├── parser_tests.py │ │ ├── pipeline_tests.py │ │ ├── query_tests.py │ │ ├── sampling_tests.py │ │ ├── schema_tests.py │ │ ├── table_tests.py │ │ ├── udf_tests.py │ │ └── view_tests.py │ ├── context_tests.py │ ├── integration/ │ │ └── storage_test.py │ ├── kernel/ │ │ ├── __init__.py │ │ ├── bigquery_tests.py │ │ ├── chart_data_tests.py │ │ ├── chart_tests.py │ │ ├── html_tests.py │ │ ├── pipeline_tests.py │ │ ├── storage_tests.py │ │ └── 
utils_tests.py │ ├── main.py │ ├── ml/ │ │ ├── __init__.py │ │ ├── confusion_matrix_tests.py │ │ ├── dataset_tests.py │ │ ├── facets_tests.py │ │ ├── metrics_tests.py │ │ ├── summary_tests.py │ │ └── tensorboard_tests.py │ ├── ml_workbench/ │ │ ├── __init__.py │ │ └── all_tests.py │ ├── mltoolbox_structured_data/ │ │ ├── __init__.py │ │ ├── dl_interface_tests.py │ │ ├── sd_e2e_tests.py │ │ └── traininglib_tests.py │ ├── mlworkbench_magic/ │ │ ├── __init__.py │ │ ├── archive_tests.py │ │ ├── explainer_tests.py │ │ ├── local_predict_tests.py │ │ ├── ml_tests.py │ │ └── shell_process_tests.py │ ├── pipeline/ │ │ ├── __init__.py │ │ ├── airflow_tests.py │ │ ├── composer_api_tests.py │ │ ├── composer_tests.py │ │ └── pipeline_tests.py │ ├── stackdriver/ │ │ ├── __init__.py │ │ ├── commands/ │ │ │ ├── __init__.py │ │ │ └── monitoring_tests.py │ │ └── monitoring/ │ │ ├── __init__.py │ │ ├── group_tests.py │ │ ├── metric_tests.py │ │ ├── query_metadata_tests.py │ │ ├── query_tests.py │ │ ├── resource_tests.py │ │ └── utils_tests.py │ └── storage/ │ ├── __init__.py │ ├── api_tests.py │ ├── bucket_tests.py │ └── object_tests.py └── tox.ini ================================================ FILE CONTENTS ================================================ ================================================ FILE: .coveragerc ================================================ # .coveragerc to control coverage.py [run] [report] include = */site-packages/google/datalab/* ================================================ FILE: .gitignore ================================================ *.pyc *.pyi *.map *.egg-info *.iml .idea .DS_Store MANIFEST build .coverage dist datalab.magics.rst datalab/notebook/static/*.js google/datalab/notebook/static/*.js # Test files .tox/ .cache/ ================================================ FILE: .travis.yml ================================================ language: python dist: trusty sudo: false matrix: include: - python: 2.7 env: TOX_ENV=py27 - python: 3.5 env: TOX_ENV=py35 - python: 2.7 env: TOX_ENV=flake8 - python: 2.7 env: TOX_ENV=coveralls before_install: - npm install -g typescript@3.0.3 - tsc --module amd --noImplicitAny --outdir datalab/notebook/static datalab/notebook/static/*.ts # We use tox for actually running tests. - pip install --upgrade pip tox script: # tox reads its configuration from tox.ini. - tox -e $TOX_ENV ================================================ FILE: CONTRIBUTING.md ================================================ Want to contribute? Great! First, read this page (including the small print at the end). ### Before you contribute Before we can use your code, you must sign the [Google Individual Contributor License Agreement] (https://cla.developers.google.com/about/google-individual) (CLA), which you can do online. The CLA is necessary mainly because you own the copyright to your changes, even after your contribution becomes part of our codebase, so we need your permission to use and distribute your code. We also need to be sure of various other things—for instance that you'll tell us if you know that your code infringes on other people's patents. You don't have to sign the CLA until after you've submitted your code for review and a member has approved it, but you must do it before we can put your code into our codebase. Before you start working on a larger contribution, you should get in touch with us first through the issue tracker with your idea so that we can help out and possibly guide you. 
Coordinating up front makes it much easier to avoid frustration later on. ### Code reviews All submissions, including submissions by project members, require review. We use Github pull requests for this purpose. ### Running tests We use [`tox`](https://tox.readthedocs.io/) for running our tests. To run tests before sending out a pull request, just [install tox](https://tox.readthedocs.io/en/latest/install.html) and run ```shell $ tox ``` to run tests under all supported environments. (This will skip any environments for which no interpreter is available.) `tox -l` will provide a list of all supported environments. `tox` will run all tests referenced by `tests/main.py` and `legacy_tests/main.py`. ### The small print Contributions made by corporations are covered by a different agreement than the one above, the [Software Grant and Corporate Contributor License Agreement] (https://cla.developers.google.com/about/google-corporate). ================================================ FILE: LICENSE.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) 
The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Google Cloud DataLab Datalab is deprecated. [Vertex AI Workbench](https://cloud.google.com/vertex-ai/docs/workbench) provides a notebook-based environment that offers capabilities beyond Datalab. We recommend that you use Vertex AI Workbench for new projects and [migrate your Datalab notebooks to Vertex AI Workbench](https://cloud.google.com/datalab/docs/resources/troubleshooting#migrate). For more information, see [Deprecation information](https://cloud.google.com/datalab/docs/resources/deprecation). To get help migrating Datalab projects to Vertex AI Workbench see [Get help](https://cloud.google.com/datalab/docs/resources/support#get-help). ================================================ FILE: datalab/README ================================================ Everything under datalab namespace is actively maintained but no new features are being added. Please use corresponding libraries under google.datalab namespace (source code under google/datalab directory). To migrate existing code that relies on datalab namespace, since most API interfaces are the same between google.datalab and datalab, usually you just need to change the import namespace. The magic interface is different for bigquery though (%%sql --> %%bq). For more details please see https://github.com/googledatalab/pydatalab/wiki/%60datalab%60-to-%60google.datalab%60-Migration-Guide. ================================================ FILE: datalab/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: datalab/bigquery/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery Functionality.""" from __future__ import absolute_import from ._csv_options import CSVOptions from ._dataset import Dataset, Datasets from ._dialect import Dialect from ._federated_table import FederatedTable from ._job import Job from ._query import Query from ._query_job import QueryJob from ._query_results_table import QueryResultsTable from ._query_stats import QueryStats from ._sampling import Sampling from ._schema import Schema from ._table import Table, TableMetadata from ._udf import UDF from ._utils import TableName, DatasetName from ._view import View __all__ = ['CSVOptions', 'Dataset', 'Datasets', 'Dialect', 'FederatedTable', 'Query', 'QueryJob', 'QueryResultsTable', 'QueryStats', 'Sampling', 'Schema', 'Table', 'TableMetadata', 'UDF', 'TableName', 'DatasetName', 'View'] def wait_any(jobs, timeout=None): """ Return when at least one of the specified jobs has completed or timeout expires. Args: jobs: a list of Jobs to wait on. timeout: a timeout in seconds to wait for. None (the default) means no timeout. Returns: Once at least one job completes, a list of all completed jobs. If the call times out then an empty list will be returned. """ return Job.wait_any(jobs, timeout) def wait_all(jobs, timeout=None): """ Return when all of the specified jobs have completed or timeout expires. Args: jobs: a single Job or list of Jobs to wait on. timeout: a timeout in seconds to wait for. None (the default) means no timeout. Returns: A list of completed Jobs. If the call timed out this will be shorter than the list of jobs supplied as a parameter. """ return Job.wait_all(jobs, timeout) ================================================ FILE: datalab/bigquery/_api.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery HTTP API wrapper.""" from __future__ import absolute_import from __future__ import unicode_literals from past.builtins import basestring from builtins import object import datalab.utils import datalab.bigquery class Api(object): """A helper class to issue BigQuery HTTP requests.""" # TODO(nikhilko): Use named placeholders in these string templates. _ENDPOINT = 'https://www.googleapis.com/bigquery/v2' _JOBS_PATH = '/projects/%s/jobs/%s' _QUERIES_PATH = '/projects/%s/queries/%s' _DATASETS_PATH = '/projects/%s/datasets/%s' _TABLES_PATH = '/projects/%s/datasets/%s/tables/%s%s' _TABLEDATA_PATH = '/projects/%s/datasets/%s/tables/%s%s/data' _DEFAULT_TIMEOUT = 60000 def __init__(self, context): """Initializes the BigQuery helper with context information. 
Args: context: a Context object providing project_id and credentials. """ self._credentials = context.credentials self._project_id = context.project_id @property def project_id(self): """The project_id associated with this API client.""" return self._project_id def jobs_insert_load(self, source, table_name, append=False, overwrite=False, create=False, source_format='CSV', field_delimiter=',', allow_jagged_rows=False, allow_quoted_newlines=False, encoding='UTF-8', ignore_unknown_values=False, max_bad_records=0, quote='"', skip_leading_rows=0): """ Issues a request to load data from GCS to a BQ table Args: source: the URL of the source bucket(s). Can include wildcards, and can be a single string argument or a list. table_name: a tuple representing the full name of the destination table. append: if True append onto existing table contents. overwrite: if True overwrite existing table contents. create: if True, create the table if it doesn't exist source_format: the format of the data; default 'CSV'. Other options are DATASTORE_BACKUP or NEWLINE_DELIMITED_JSON. field_delimiter: The separator for fields in a CSV file. BigQuery converts the string to ISO-8859-1 encoding, and then uses the first byte of the encoded string to split the data as raw binary (default ','). allow_jagged_rows: If True, accept rows in CSV files that are missing trailing optional columns; the missing values are treated as nulls (default False). allow_quoted_newlines: If True, allow quoted data sections in CSV files that contain newline characters (default False). encoding: The character encoding of the data, either 'UTF-8' (the default) or 'ISO-8859-1'. ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: The maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). quote: The value used to quote data sections in a CSV file; default '"'. If your data does not contain quoted sections, set the property value to an empty string. If your data contains quoted newline characters, you must also enable allow_quoted_newlines. skip_leading_rows: A number of rows at the top of a CSV file to skip (default 0). Returns: A parsed result object. Raises: Exception if there is an error performing the operation. 
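Example (illustrative sketch; assumes `context` is an existing datalab.context.Context, `table_name` is the destination table's name tuple, and the bucket path is hypothetical):

  api = Api(context)
  # Load CSV files from GCS, creating the table if needed and skipping one header row.
  load_job_info = api.jobs_insert_load('gs://my-bucket/data/*.csv', table_name,
                                       create=True, source_format='CSV',
                                       field_delimiter=',', skip_leading_rows=1)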
""" url = Api._ENDPOINT + (Api._JOBS_PATH % (table_name.project_id, '')) if isinstance(source, basestring): source = [source] write_disposition = 'WRITE_EMPTY' if overwrite: write_disposition = 'WRITE_TRUNCATE' if append: write_disposition = 'WRITE_APPEND' data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': source, 'destinationTable': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id }, 'createDisposition': 'CREATE_IF_NEEDED' if create else 'CREATE_NEVER', 'writeDisposition': write_disposition, 'sourceFormat': source_format, 'ignoreUnknownValues': ignore_unknown_values, 'maxBadRecords': max_bad_records, } } } if source_format == 'CSV': load_config = data['configuration']['load'] load_config.update({ 'fieldDelimiter': field_delimiter, 'allowJaggedRows': allow_jagged_rows, 'allowQuotedNewlines': allow_quoted_newlines, 'quote': quote, 'encoding': encoding, 'skipLeadingRows': skip_leading_rows }) return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def jobs_insert_query(self, sql, code=None, imports=None, table_name=None, append=False, overwrite=False, dry_run=False, use_cache=True, batch=True, allow_large_results=False, table_definitions=None, dialect=None, billing_tier=None): """Issues a request to insert a query job. Args: sql: the SQL string representing the query to execute. code: code for Javascript UDFs, if any. imports: a list of GCS URLs containing additional Javascript UDF support code, if any. table_name: None for an anonymous table, or a name parts tuple for a long-lived table. append: if True, append to the table if it is non-empty; else the request will fail if table is non-empty unless overwrite is True. overwrite: if the table already exists, truncate it instead of appending or raising an Exception. dry_run: whether to actually execute the query or just dry run it. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified. batch: whether to run this as a batch job (lower priority) or as an interactive job (high priority, more expensive). allow_large_results: whether to allow large results (slower with some restrictions but can handle big jobs). table_definitions: a list of JSON external table definitions for any external tables referenced in the query. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. 
""" url = Api._ENDPOINT + (Api._JOBS_PATH % (self._project_id, '')) if dialect is None: dialect = datalab.bigquery.Dialect.default().bq_dialect data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': sql, 'useQueryCache': use_cache, 'allowLargeResults': allow_large_results, 'useLegacySql': dialect == 'legacy' }, 'dryRun': dry_run, 'priority': 'BATCH' if batch else 'INTERACTIVE', }, } query_config = data['configuration']['query'] resources = [] if code: resources.extend([{'inlineCode': fragment} for fragment in code]) if imports: resources.extend([{'resourceUri': uri} for uri in imports]) query_config['userDefinedFunctionResources'] = resources if table_definitions: query_config['tableDefinitions'] = table_definitions if table_name: query_config['destinationTable'] = { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id } if append: query_config['writeDisposition'] = "WRITE_APPEND" elif overwrite: query_config['writeDisposition'] = "WRITE_TRUNCATE" if billing_tier: query_config['maximumBillingTier'] = billing_tier return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def jobs_query_results(self, job_id, project_id, page_size, timeout, start_index=0): """Issues a request to the jobs/getQueryResults method. Args: job_id: the id of job from a previously executed query. project_id: the project id to use to fetch the results; use None for the default project. page_size: limit to the number of rows to fetch. timeout: duration (in milliseconds) to wait for the query to complete. start_index: the index of the row (0-based) at which to start retrieving the page of result rows. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if timeout is None: timeout = Api._DEFAULT_TIMEOUT if project_id is None: project_id = self._project_id args = { 'maxResults': page_size, 'timeoutMs': timeout, 'startIndex': start_index } url = Api._ENDPOINT + (Api._QUERIES_PATH % (project_id, job_id)) return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def jobs_get(self, job_id, project_id=None): """Issues a request to retrieve information about a job. Args: job_id: the id of the job project_id: the project id to use to fetch the results; use None for the default project. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if project_id is None: project_id = self._project_id url = Api._ENDPOINT + (Api._JOBS_PATH % (project_id, job_id)) return datalab.utils.Http.request(url, credentials=self._credentials) def datasets_insert(self, dataset_name, friendly_name=None, description=None): """Issues a request to create a dataset. Args: dataset_name: the name of the dataset to create. friendly_name: (optional) the friendly name for the dataset description: (optional) a description for the dataset Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % (dataset_name.project_id, '')) data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': dataset_name.project_id, 'datasetId': dataset_name.dataset_id }, } if friendly_name: data['friendlyName'] = friendly_name if description: data['description'] = description return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def datasets_delete(self, dataset_name, delete_contents): """Issues a request to delete a dataset. 
Args: dataset_name: the name of the dataset to delete. delete_contents: if True, any tables in the dataset will be deleted. If False and the dataset is non-empty an exception will be raised. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) args = {} if delete_contents: args['deleteContents'] = True return datalab.utils.Http.request(url, method='DELETE', args=args, credentials=self._credentials, raw_response=True) def datasets_update(self, dataset_name, dataset_info): """Updates the Dataset info. Args: dataset_name: the name of the dataset to update as a tuple of components. dataset_info: the Dataset resource with updated fields. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) return datalab.utils.Http.request(url, method='PUT', data=dataset_info, credentials=self._credentials) def datasets_get(self, dataset_name): """Issues a request to retrieve information about a dataset. Args: dataset_name: the name of the dataset Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) return datalab.utils.Http.request(url, credentials=self._credentials) def datasets_list(self, project_id=None, max_results=0, page_token=None): """Issues a request to list the datasets in the project. Args: project_id: the project id to use to fetch the results; use None for the default project. max_results: an optional maximum number of tables to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if project_id is None: project_id = self._project_id url = Api._ENDPOINT + (Api._DATASETS_PATH % (project_id, '')) args = {} if max_results != 0: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def tables_get(self, table_name): """Issues a request to retrieve information about a table. Args: table_name: a tuple representing the full name of the table. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return datalab.utils.Http.request(url, credentials=self._credentials) def tables_list(self, dataset_name, max_results=0, page_token=None): """Issues a request to retrieve a list of tables. Args: dataset_name: the name of the dataset to enumerate. max_results: an optional maximum number of tables to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT +\ (Api._TABLES_PATH % (dataset_name.project_id, dataset_name.dataset_id, '', '')) args = {} if max_results != 0: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def tables_insert(self, table_name, schema=None, query=None, friendly_name=None, description=None): """Issues a request to create a table or view in the specified dataset with the specified id. A schema must be provided to create a Table, or a query must be provided to create a View. Args: table_name: the name of the table as a tuple of components. schema: the schema, if this is a Table creation. 
query: the query, if this is a View creation. friendly_name: an optional friendly name. description: an optional description. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + \ (Api._TABLES_PATH % (table_name.project_id, table_name.dataset_id, '', '')) data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id } } if schema: data['schema'] = {'fields': schema} if query: data['view'] = {'query': query} if friendly_name: data['friendlyName'] = friendly_name if description: data['description'] = description return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def tabledata_insert_all(self, table_name, rows): """Issues a request to insert data into a table. Args: table_name: the name of the table as a tuple of components. rows: the data to populate the table, as a list of dictionaries. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) + "/insertAll" data = { 'kind': 'bigquery#tableDataInsertAllRequest', 'rows': rows } return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def tabledata_list(self, table_name, start_index=None, max_results=None, page_token=None): """ Retrieves the contents of a table. Args: table_name: the name of the table as a tuple of components. start_index: the index of the row at which to start retrieval. max_results: an optional maximum number of rows to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLEDATA_PATH % table_name) args = {} if start_index: args['startIndex'] = start_index if max_results: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def table_delete(self, table_name): """Issues a request to delete a table. Args: table_name: the name of the table as a tuple of components. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return datalab.utils.Http.request(url, method='DELETE', credentials=self._credentials, raw_response=True) def table_extract(self, table_name, destination, format='CSV', compress=True, field_delimiter=',', print_header=True): """Exports the table to GCS. Args: table_name: the name of the table as a tuple of components. destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of CSV, NEWLINE_DELIMITED_JSON or AVRO. Defaults to CSV. compress: whether to compress the data on export. Compression is not supported for AVRO format. Defaults to False. field_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' print_header: for CSV exports, whether to include an initial header line. Default true. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._JOBS_PATH % (table_name.project_id, '')) if isinstance(destination, basestring): destination = [destination] data = { # 'projectId': table_name.project_id, # Code sample shows this but it is not in job # reference spec. 
Filed as b/19235843 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id, }, 'compression': 'GZIP' if compress else 'NONE', 'fieldDelimiter': field_delimiter, 'printHeader': print_header, 'destinationUris': destination, 'destinationFormat': format, } } } return datalab.utils.Http.request(url, data=data, credentials=self._credentials) def table_update(self, table_name, table_info): """Updates the Table info. Args: table_name: the name of the table to update as a tuple of components. table_info: the Table resource with updated fields. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return datalab.utils.Http.request(url, method='PUT', data=table_info, credentials=self._credentials) ================================================ FILE: datalab/bigquery/_csv_options.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements CSV options for External Tables and Table loads from GCS.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class CSVOptions(object): def __init__(self, delimiter=',', skip_leading_rows=0, encoding='utf-8', quote='"', allow_quoted_newlines=False, allow_jagged_rows=False): """ Initialize an instance of CSV options. Args: delimiter: The separator for fields in a CSV file. BigQuery converts the string to ISO-8859-1 encoding, and then uses the first byte of the encoded string to split the data as raw binary (default ','). skip_leading_rows: A number of rows at the top of a CSV file to skip (default 0). encoding: The character encoding of the data, either 'utf-8' (the default) or 'iso-8859-1'. quote: The value used to quote data sections in a CSV file; default '"'. If your data does not contain quoted sections, set the property value to an empty string. If your data contains quoted newline characters, you must also enable allow_quoted_newlines. allow_quoted_newlines: If True, allow quoted data sections in CSV files that contain newline characters (default False). allow_jagged_rows: If True, accept rows in CSV files that are missing trailing optional columns; the missing values are treated as nulls (default False). 
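Example (illustrative sketch; the option values are arbitrary):

  # Tab-delimited data with a single header row, using the default quote and encoding.
  options = CSVOptions(delimiter='\t', skip_leading_rows=1)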
""" encoding_upper = encoding.upper() if encoding_upper != 'UTF-8' and encoding_upper != 'ISO-8859-1': raise Exception("Invalid source encoding %s" % encoding) self._delimiter = delimiter self._skip_leading_rows = skip_leading_rows self._encoding = encoding self._quote = quote self._allow_quoted_newlines = allow_quoted_newlines self._allow_jagged_rows = allow_jagged_rows @property def delimiter(self): return self._delimiter @property def skip_leading_rows(self): return self._skip_leading_rows @property def encoding(self): return self._encoding @property def quote(self): return self._quote @property def allow_quoted_newlines(self): return self._allow_quoted_newlines @property def allow_jagged_rows(self): return self._allow_jagged_rows def _to_query_json(self): """ Return the options as a dictionary to be used as JSON in a query job. """ return { 'quote': self._quote, 'fieldDelimiter': self._delimiter, 'encoding': self._encoding.upper(), 'skipLeadingRows': self._skip_leading_rows, 'allowQuotedNewlines': self._allow_quoted_newlines, 'allowJaggedRows': self._allow_jagged_rows } ================================================ FILE: datalab/bigquery/_dataset.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Dataset, and related Dataset BigQuery APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import datalab.context import datalab.utils from . import _api from . import _table from . import _utils from . import _view class Dataset(object): """Represents a list of BigQuery tables in a dataset.""" def __init__(self, name, context=None): """Initializes an instance of a Dataset. Args: name: the name of the dataset, as a string or (project_id, dataset_id) tuple. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._name_parts = _utils.parse_dataset_name(name, self._api.project_id) self._full_name = '%s:%s' % self._name_parts self._info = None try: self._info = self._get_info() except datalab.utils.RequestException: pass @property def name(self): """The DatasetName named tuple (project_id, dataset_id) for the dataset.""" return self._name_parts @property def description(self): """The description of the dataset, if any. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. """ self._get_info() return self._info['description'] if self._info else None @property def friendly_name(self): """The friendly name of the dataset, if any. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. 
""" self._get_info() return self._info['friendlyName'] if self._info else None def _get_info(self): try: if self._info is None: self._info = self._api.datasets_get(self._name_parts) return self._info except datalab.utils.RequestException as e: if e.status == 404: return None raise e except Exception as e: raise e def exists(self): """ Checks if the dataset exists. Returns: True if the dataset exists; False otherwise. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. """ self._get_info() return self._info is not None def delete(self, delete_contents=False): """Issues a request to delete the dataset. Args: delete_contents: if True, any tables and views in the dataset will be deleted. If False and the dataset is non-empty an exception will be raised. Returns: None on success. Raises: Exception if the delete fails (including if table was nonexistent). """ if not self.exists(): raise Exception('Cannot delete non-existent dataset %s' % self._full_name) try: self._api.datasets_delete(self._name_parts, delete_contents=delete_contents) except Exception as e: raise e self._info = None return None def create(self, friendly_name=None, description=None): """Creates the Dataset with the specified friendly name and description. Args: friendly_name: (optional) the friendly name for the dataset if it is being created. description: (optional) a description for the dataset if it is being created. Returns: The Dataset. Raises: Exception if the Dataset could not be created. """ if not self.exists(): try: response = self._api.datasets_insert(self._name_parts, friendly_name=friendly_name, description=description) except Exception as e: raise e if 'selfLink' not in response: raise Exception("Could not create dataset %s" % self._full_name) return self def update(self, friendly_name=None, description=None): """ Selectively updates Dataset information. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. Returns: """ self._get_info() if self._info: if friendly_name: self._info['friendlyName'] = friendly_name if description: self._info['description'] = description try: self._api.datasets_update(self._name_parts, self._info) except Exception as e: raise e finally: self._info = None # need a refresh def _retrieve_items(self, page_token, item_type): try: list_info = self._api.tables_list(self._name_parts, page_token=page_token) except Exception as e: raise e tables = list_info.get('tables', []) contents = [] if len(tables): try: for info in tables: if info['type'] != item_type: continue if info['type'] == 'TABLE': item = _table.Table((info['tableReference']['projectId'], info['tableReference']['datasetId'], info['tableReference']['tableId']), self._context) else: item = _view.View((info['tableReference']['projectId'], info['tableReference']['datasetId'], info['tableReference']['tableId']), self._context) contents.append(item) except KeyError: raise Exception('Unexpected item list response') page_token = list_info.get('nextPageToken', None) return contents, page_token def _retrieve_tables(self, page_token, _): return self._retrieve_items(page_token=page_token, item_type='TABLE') def _retrieve_views(self, page_token, _): return self._retrieve_items(page_token=page_token, item_type='VIEW') def tables(self): """ Returns an iterator for iterating through the Tables in the dataset. 
""" return iter(datalab.utils.Iterator(self._retrieve_tables)) def views(self): """ Returns an iterator for iterating through the Views in the dataset. """ return iter(datalab.utils.Iterator(self._retrieve_views)) def __iter__(self): """ Returns an iterator for iterating through the Tables in the dataset. """ return self.tables() def __str__(self): """Returns a string representation of the dataset using its specified name. Returns: The string representation of this object. """ return self._full_name def __repr__(self): """Returns a representation for the dataset for showing in the notebook. """ return 'Dataset %s' % self._full_name class Datasets(object): """ Iterator class for enumerating the datasets in a project. """ def __init__(self, project_id=None, context=None): """ Initialize the Datasets object. Args: project_id: the ID of the project whose datasets you want to list. If None defaults to the project in the context. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._project_id = project_id if project_id else self._api.project_id def _retrieve_datasets(self, page_token, count): try: list_info = self._api.datasets_list(self._project_id, max_results=count, page_token=page_token) except Exception as e: raise e datasets = list_info.get('datasets', []) if len(datasets): try: datasets = [Dataset((info['datasetReference']['projectId'], info['datasetReference']['datasetId']), self._context) for info in datasets] except KeyError: raise Exception('Unexpected response from server.') page_token = list_info.get('nextPageToken', None) return datasets, page_token def __iter__(self): """ Returns an iterator for iterating through the Datasets in the project. """ return iter(datalab.utils.Iterator(self._retrieve_datasets)) ================================================ FILE: datalab/bigquery/_dialect.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery SQL Dialect""" from __future__ import absolute_import class Dialect(object): """ Represents the default BigQuery SQL dialect """ _global_dialect = None def __init__(self, bq_dialect): self._global_dialect = bq_dialect @property def bq_dialect(self): """Retrieves the value of the bq_dialect property. Returns: The default BigQuery SQL dialect """ return self._global_dialect def set_bq_dialect(self, bq_dialect): """ Set the default BigQuery SQL dialect""" if bq_dialect in ['legacy', 'standard']: self._global_dialect = bq_dialect @staticmethod def default(): """Retrieves the default BigQuery SQL dialect, creating it if necessary. Returns: An initialized and shared instance of a Dialect object. 
""" if Dialect._global_dialect is None: Dialect._global_dialect = Dialect('legacy') return Dialect._global_dialect ================================================ FILE: datalab/bigquery/_federated_table.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements External Table functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from . import _csv_options class FederatedTable(object): @staticmethod def from_storage(source, source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0, compressed=False, schema=None): """ Create an external table for a GCS object. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: For CSV files, the options such as quote character and delimiter. ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: The maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). compressed: whether the data is GZ compressed or not (default False). Note that compressed data can be used as a federated table but cannot be loaded into a BQ Table. schema: the schema of the data. This is required for this table to be used as a federated table or to be loaded using a Table object that itself has no schema (default None). """ result = FederatedTable() # Do some sanity checking and concert some params from friendly form to form used by BQ. if source_format == 'csv': result._bq_source_format = 'CSV' if csv_options is None: csv_options = _csv_options.CSVOptions() # use defaults elif source_format == 'json': if csv_options: raise Exception('CSV options are not support for JSON tables') result._bq_source_format = 'NEWLINE_DELIMITED_JSON' else: raise Exception("Invalid source format %s" % source_format) result._source = source if isinstance(source, list) else [source] result._source_format = source_format result._csv_options = csv_options result._ignore_unknown_values = ignore_unknown_values result._max_bad_records = max_bad_records result._compressed = compressed result._schema = schema return result def __init__(self): """ Create an external table reference. Do not call this directly; use factory method(s). """ # Do some sanity checking and concert some params from friendly form to form used by BQ. self._bq_source_format = None self._source = None self._source_format = None self._csv_options = None self._ignore_unknown_values = None self._max_bad_records = None self._compressed = None self._schema = None @property def schema(self): return self._schema def _to_query_json(self): """ Return the table as a dictionary to be used as JSON in a query job. 
""" json = { 'compression': 'GZIP' if self._compressed else 'NONE', 'ignoreUnknownValues': self._ignore_unknown_values, 'maxBadRecords': self._max_bad_records, 'sourceFormat': self._bq_source_format, 'sourceUris': self._source, } if self._source_format == 'csv' and self._csv_options: json['csvOptions'] = {} json['csvOptions'].update(self._csv_options._to_query_json()) if self._schema: json['schema'] = {'fields': self._schema._bq_schema} return json ================================================ FILE: datalab/bigquery/_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from __future__ import division import datetime import datalab.utils from . import _api class Job(datalab.utils.GCPJob): """Represents a BigQuery Job. """ def __init__(self, job_id, context): """Initializes an instance of a Job. Args: job_id: the BigQuery job ID corresponding to this job. context: a Context object providing project_id and credentials. """ super(Job, self).__init__(job_id, context) def _create_api(self, context): return _api.Api(context) def _refresh_state(self): """ Get the state of a job. If the job is complete this does nothing; otherwise it gets a refreshed copy of the job resource. """ # TODO(gram): should we put a choke on refreshes? E.g. if the last call was less than # a second ago should we return the cached value? if self._is_complete: return try: response = self._api.jobs_get(self._job_id) except Exception as e: raise e if 'status' in response: status = response['status'] if 'state' in status and status['state'] == 'DONE': self._end_time = datetime.datetime.utcnow() self._is_complete = True self._process_job_status(status) if 'statistics' in response: statistics = response['statistics'] start_time = statistics.get('creationTime', None) end_time = statistics.get('endTime', None) if start_time and end_time and end_time >= start_time: self._start_time = datetime.datetime.fromtimestamp(float(start_time) / 1000.0) self._end_time = datetime.datetime.fromtimestamp(float(end_time) / 1000.0) def _process_job_status(self, status): if 'errorResult' in status: error_result = status['errorResult'] location = error_result.get('location', None) message = error_result.get('message', None) reason = error_result.get('reason', None) self._fatal_error = datalab.utils.JobError(location, message, reason) if 'errors' in status: self._errors = [] for error in status['errors']: location = error.get('location', None) message = error.get('message', None) reason = error.get('reason', None) self._errors.append(datalab.utils.JobError(location, message, reason)) ================================================ FILE: datalab/bigquery/_parser.py ================================================ # Copyright 2015 Google Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery related data parsing helpers.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import zip from builtins import str from builtins import object import datetime class Parser(object): """A set of helper functions to parse data in BigQuery responses.""" def __init__(self): pass @staticmethod def parse_row(schema, data): """Parses a row from query results into an equivalent object. Args: schema: the array of fields defining the schema of the data. data: the JSON row from a query result. Returns: The parsed row object. """ def parse_value(data_type, value): """Parses a value returned from a BigQuery response. Args: data_type: the type of the value as specified by the schema. value: the raw value to return (before casting to data_type). Returns: The value cast to the data_type. """ if value is not None: if value == 'null': value = None elif data_type == 'INTEGER': value = int(value) elif data_type == 'FLOAT': value = float(value) elif data_type == 'TIMESTAMP': value = datetime.datetime.utcfromtimestamp(float(value)) elif data_type == 'BOOLEAN': value = value == 'true' elif (type(value) != str): # TODO(gram): Handle nested JSON records value = str(value) return value row = {} if data is None: return row for i, (field, schema_field) in enumerate(zip(data['f'], schema)): val = field['v'] name = schema_field['name'] data_type = schema_field['type'] repeated = True if 'mode' in schema_field and schema_field['mode'] == 'REPEATED' else False if repeated and val is None: row[name] = [] elif data_type == 'RECORD': sub_schema = schema_field['fields'] if repeated: row[name] = [Parser.parse_row(sub_schema, v['v']) for v in val] else: row[name] = Parser.parse_row(sub_schema, val) elif repeated: row[name] = [parse_value(data_type, v['v']) for v in val] else: row[name] = parse_value(data_type, val) return row @staticmethod def parse_timestamp(value): """Parses a timestamp. Args: value: the number of milliseconds since epoch. """ return datetime.datetime.utcfromtimestamp(float(value) / 1000.0) ================================================ FILE: datalab/bigquery/_query.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
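# Illustrative sketch (not from the original sources) of the input shape that
# Parser.parse_row in _parser.py above expects: a schema field list plus a raw row
# in the {'f': [{'v': ...}, ...]} form returned by the BigQuery API. The field
# names and values below are made up for the example.
from datalab.bigquery._parser import Parser

schema = [{'name': 'word', 'type': 'STRING', 'mode': 'NULLABLE'},
          {'name': 'count', 'type': 'INTEGER', 'mode': 'NULLABLE'},
          {'name': 'when', 'type': 'TIMESTAMP', 'mode': 'NULLABLE'}]
raw_row = {'f': [{'v': 'alpha'}, {'v': '3'}, {'v': '1446076800.0'}]}

row = Parser.parse_row(schema, raw_row)
# row == {'word': 'alpha', 'count': 3, 'when': datetime.datetime(2015, 10, 29, 0, 0)}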
"""Implements Query BigQuery API.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import datalab.context import datalab.data import datalab.utils from . import _api from . import _federated_table from . import _query_job from . import _sampling from . import _udf from . import _utils class Query(object): """Represents a Query object that encapsulates a BigQuery SQL query. This object can be used to execute SQL queries and retrieve results. """ @staticmethod def sampling_query(sql, context, fields=None, count=5, sampling=None, udfs=None, data_sources=None): """Returns a sampling Query for the SQL object. Args: sql: the SQL statement (string) or Query object to sample. context: a Context object providing project_id and credentials. fields: an optional list of field names to retrieve. count: an optional count of rows to retrieve which is used if a specific sampling is not specified. sampling: an optional sampling strategy to apply to the table. udfs: array of UDFs referenced in the SQL. data_sources: dictionary of federated (external) tables referenced in the SQL. Returns: A Query object for sampling the table. """ return Query(_sampling.Sampling.sampling_query(sql, fields, count, sampling), context=context, udfs=udfs, data_sources=data_sources) def __init__(self, sql, context=None, values=None, udfs=None, data_sources=None, **kwargs): """Initializes an instance of a Query object. Note that either values or kwargs may be used, but not both. Args: sql: the BigQuery SQL query string to execute, or a SqlStatement object. The latter will have any variable references replaced before being associated with the Query (i.e. once constructed the SQL associated with a Query is static). It is possible to have variable references in a query string too provided the variables are passed as keyword arguments to this constructor. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. values: a dictionary used to expand variables if passed a SqlStatement or a string with variable references. udfs: array of UDFs referenced in the SQL. data_sources: dictionary of federated (external) tables referenced in the SQL. kwargs: arguments to use when expanding the variables if passed a SqlStatement or a string with variable references. Raises: Exception if expansion of any variables failed. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._data_sources = data_sources self._udfs = udfs if data_sources is None: data_sources = {} self._results = None self._code = None self._imports = [] if values is None: values = kwargs self._sql = datalab.data.SqlModule.expand(sql, values) # We need to take care not to include the same UDF code twice so we use sets. 
udfs = set(udfs if udfs else []) for value in list(values.values()): if isinstance(value, _udf.UDF): udfs.add(value) included_udfs = set([]) tokens = datalab.data.tokenize(self._sql) udf_dict = {udf.name: udf for udf in udfs} for i, token in enumerate(tokens): # Find the preceding and following non-whitespace tokens prior = i - 1 while prior >= 0 and tokens[prior].isspace(): prior -= 1 if prior < 0: continue next = i + 1 while next < len(tokens) and tokens[next].isspace(): next += 1 if next >= len(tokens): continue uprior = tokens[prior].upper() if uprior != 'FROM' and uprior != 'JOIN': continue # Check for external tables. if tokens[next] not in "[('\"": if token not in data_sources: if values and token in values: if isinstance(values[token], _federated_table.FederatedTable): data_sources[token] = values[token] # Now check for UDF calls. if uprior != 'FROM' or tokens[next] != '(': continue # We have a 'FROM token (' sequence. if token in udf_dict: udf = udf_dict[token] if token not in included_udfs: included_udfs.add(token) if self._code is None: self._code = [] self._code.append(udf.code) if udf.imports: self._imports.extend(udf.imports) fields = ', '.join([f[0] for f in udf._outputs]) tokens[i] = '(SELECT %s FROM %s' % (fields, token) # Find the closing parenthesis and add the additional one now needed. num_paren = 0 j = i + 1 while j < len(tokens): if tokens[j] == '(': num_paren += 1 elif tokens[j] == ')': num_paren -= 1 if num_paren == 0: tokens[j] = '))' break j += 1 self._external_tables = None if len(data_sources): self._external_tables = {} for name, table in list(data_sources.items()): if table.schema is None: raise Exception('Referenced external table %s has no known schema' % name) self._external_tables[name] = table._to_query_json() self._sql = ''.join(tokens) def _repr_sql_(self): """Creates a SQL representation of this object. Returns: The SQL representation to use when embedding this object into other SQL. """ return '(%s)' % self._sql def __str__(self): """Creates a string representation of this object. Returns: The string representation of this object (the unmodified SQL). """ return self._sql def __repr__(self): """Creates a friendly representation of this object. Returns: The friendly representation of this object (the unmodified SQL). """ return self._sql @property def sql(self): """ Get the SQL for the query. """ return self._sql @property def scripts(self): """ Get the code for any Javascript UDFs used in the query. """ return self._code def results(self, use_cache=True, dialect=None, billing_tier=None): """Retrieves table of results for the query. May block if the query must be executed first. Args: use_cache: whether to use cached results or not. Ignored if append is specified. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryResultsTable containing the result set. Raises: Exception if the query could not be executed or query response was malformed. 
""" if not use_cache or (self._results is None): self.execute(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) return self._results.results def extract(self, storage_uris, format='csv', csv_delimiter=',', csv_header=True, compress=False, use_cache=True, dialect=None, billing_tier=None): """Exports the query results to GCS. Args: storage_uris: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for csv exports, the field delimiter to use (default ','). csv_header: for csv exports, whether to include an initial header line (default True). compress: whether to compress the data on export. Compression is not supported for AVRO format (default False). use_cache: whether to use cached results or not (default True). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A Job object for the export Job if it was completed successfully; else None. Raises: An Exception if the query or extract failed. """ return self.results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier).extract(storage_uris, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header, compress=compress) @datalab.utils.async_method def extract_async(self, storage_uris, format='csv', csv_delimiter=',', csv_header=True, compress=False, use_cache=True, dialect=None, billing_tier=None): """Exports the query results to GCS. Returns a Job immediately. Note that there are two jobs that may need to be run sequentially, one to run the query, and the second to extract the resulting table. These are wrapped by a single outer Job. If the query has already been executed and you would prefer to get a Job just for the extract, you can can call extract_async on the QueryResultsTable instead; i.e.: query.results().extract_async(...) Args: storage_uris: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use (default ','). csv_header: for CSV exports, whether to include an initial header line (default True). compress: whether to compress the data on export. Compression is not supported for AVRO format (default False). use_cache: whether to use cached results or not (default True). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A Job for the combined (execute, extract) task that will in turn return the Job object for the completed extract task when done; else None. Raises: An Exception if the query failed. 
""" return self.extract(storage_uris, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header, use_cache=use_cache, compress=compress, dialect=dialect, billing_tier=billing_tier) def to_dataframe(self, start_row=0, max_rows=None, use_cache=True, dialect=None, billing_tier=None): """ Exports the query results to a Pandas dataframe. Args: start_row: the row of the table at which to start the export (default 0). max_rows: an upper limit on the number of rows to export (default None). use_cache: whether to use cached results or not (default True). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A Pandas dataframe containing the table data. """ return self.results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) \ .to_dataframe(start_row=start_row, max_rows=max_rows) def to_file(self, path, format='csv', csv_delimiter=',', csv_header=True, use_cache=True, dialect=None, billing_tier=None): """Save the results to a local file in CSV format. Args: path: path on the local filesystem for the saved results. format: the format to use for the exported data; currently only 'csv' is supported. csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. use_cache: whether to use cached results or not. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: The path to the local file. Raises: An Exception if the operation failed. """ self.results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) \ .to_file(path, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header) return path @datalab.utils.async_method def to_file_async(self, path, format='csv', csv_delimiter=',', csv_header=True, use_cache=True, dialect=None, billing_tier=None): """Save the results to a local file in CSV format. Returns a Job immediately. Args: path: path on the local filesystem for the saved results. format: the format to use for the exported data; currently only 'csv' is supported. csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. use_cache: whether to use cached results or not. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. 
This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A Job for the save that returns the path to the local file on completion. Raises: An Exception if the operation failed. """ return self.to_file(path, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header, use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) def sample(self, count=5, fields=None, sampling=None, use_cache=True, dialect=None, billing_tier=None): """Retrieves a sampling of rows for the query. Args: count: an optional count of rows to retrieve which is used if a specific sampling is not specified (default 5). fields: the list of fields to sample (default None implies all). sampling: an optional sampling strategy to apply to the table. use_cache: whether to use cached results or not (default True). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryResultsTable containing a sampling of the result set. Raises: Exception if the query could not be executed or query response was malformed. """ return Query.sampling_query(self._sql, self._context, count=count, fields=fields, sampling=sampling, udfs=self._udfs, data_sources=self._data_sources).results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) def execute_dry_run(self, dialect=None, billing_tier=None): """Dry run a query, to check the validity of the query and return some useful statistics. Args: dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A dict with 'cacheHit' and 'totalBytesProcessed' fields. Raises: An exception if the query was malformed. """ try: query_result = self._api.jobs_insert_query(self._sql, self._code, self._imports, dry_run=True, table_definitions=self._external_tables, dialect=dialect, billing_tier=billing_tier) except Exception as e: raise e return query_result['statistics']['query'] def execute_async(self, table_name=None, table_mode='create', use_cache=True, priority='interactive', allow_large_results=False, dialect=None, billing_tier=None): """ Initiate the query and return a QueryJob. Args: table_name: the result table name as a string or TableName; if None (the default), then a temporary table will be used. table_mode: one of 'create', 'overwrite' or 'append'. If 'create' (the default), the request will fail if the table exists. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified (default True). priority:one of 'batch' or 'interactive' (default). 'interactive' jobs should be scheduled to run quickly but are subject to rate limits; 'batch' jobs could be delayed by as much as three hours but are not rate-limited. 
allow_large_results: whether to allow large results; i.e. compressed data over 100MB. This is slower and requires a table_name to be specified) (default False). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryJob. Raises: Exception if query could not be executed. """ batch = priority == 'low' append = table_mode == 'append' overwrite = table_mode == 'overwrite' if table_name is not None: table_name = _utils.parse_table_name(table_name, self._api.project_id) try: query_result = self._api.jobs_insert_query(self._sql, self._code, self._imports, table_name=table_name, append=append, overwrite=overwrite, use_cache=use_cache, batch=batch, allow_large_results=allow_large_results, table_definitions=self._external_tables, dialect=dialect, billing_tier=billing_tier) except Exception as e: raise e if 'jobReference' not in query_result: raise Exception('Unexpected response from server') job_id = query_result['jobReference']['jobId'] if not table_name: try: destination = query_result['configuration']['query']['destinationTable'] table_name = (destination['projectId'], destination['datasetId'], destination['tableId']) except KeyError: # The query was in error raise Exception(_utils.format_query_errors(query_result['status']['errors'])) return _query_job.QueryJob(job_id, table_name, self._sql, context=self._context) def execute(self, table_name=None, table_mode='create', use_cache=True, priority='interactive', allow_large_results=False, dialect=None, billing_tier=None): """ Initiate the query, blocking until complete and then return the results. Args: table_name: the result table name as a string or TableName; if None (the default), then a temporary table will be used. table_mode: one of 'create', 'overwrite' or 'append'. If 'create' (the default), the request will fail if the table exists. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified (default True). priority:one of 'batch' or 'interactive' (default). 'interactive' jobs should be scheduled to run quickly but are subject to rate limits; 'batch' jobs could be delayed by as much as three hours but are not rate-limited. allow_large_results: whether to allow large results; i.e. compressed data over 100MB. This is slower and requires a table_name to be specified) (default False). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: The QueryResultsTable for the query. Raises: Exception if query could not be executed. 
""" job = self.execute_async(table_name=table_name, table_mode=table_mode, use_cache=use_cache, priority=priority, allow_large_results=allow_large_results, dialect=dialect, billing_tier=billing_tier) self._results = job.wait() return self._results def to_view(self, view_name): """ Create a View from this Query. Args: view_name: the name of the View either as a string or a 3-part tuple (projectid, datasetid, name). Returns: A View for the Query. """ # Do the import here to avoid circular dependencies at top-level. from . import _view return _view.View(view_name, self._context).create(self._sql) ================================================ FILE: datalab/bigquery/_query_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery query job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from . import _job from . import _query_results_table class QueryJob(_job.Job): """ Represents a BigQuery Query Job. """ def __init__(self, job_id, table_name, sql, context): """ Initializes a QueryJob object. Args: job_id: the ID of the query job. table_name: the name of the table where the query results will be stored. sql: the SQL statement that was executed for the query. context: the Context object providing project_id and credentials that was used when executing the query. """ super(QueryJob, self).__init__(job_id, context) self._sql = sql self._table = _query_results_table.QueryResultsTable(table_name, context, self, is_temporary=True) self._bytes_processed = None self._cache_hit = None self._total_rows = None @property def bytes_processed(self): """ The number of bytes processed, or None if the job is not complete. """ return self._bytes_processed @property def total_rows(self): """ The total number of rows in the result, or None if not complete. """ return self._total_rows @property def cache_hit(self): """ Whether the query results were obtained from the cache or not, or None if not complete. """ return self._cache_hit @property def sql(self): """ The SQL statement that was executed for the query. """ return self._sql def wait(self, timeout=None): """ Wait for the job to complete, or a timeout to happen. This is more efficient than the version in the base Job class, in that we can use a call that blocks for the poll duration rather than a sleep. That means we shouldn't block unnecessarily long and can also poll less. Args: timeout: how long to wait (in seconds) before giving up; default None which means no timeout. 
Returns: The QueryJob """ poll = 30 while not self._is_complete: try: query_result = self._api.jobs_query_results(self._job_id, project_id=self._context.project_id, page_size=0, timeout=poll * 1000) except Exception as e: raise e if query_result['jobComplete']: if 'totalBytesProcessed' in query_result: self._bytes_processed = int(query_result['totalBytesProcessed']) self._cache_hit = query_result.get('cacheHit', None) if 'totalRows' in query_result: self._total_rows = int(query_result['totalRows']) break if timeout is not None: timeout -= poll if timeout <= 0: break self._refresh_state() return self @property def results(self): """ Get the table used for the results of the query. If the query is incomplete, this blocks. Raises: Exception if we timed out waiting for results or the query failed. """ self.wait() if self.failed: raise Exception('Query failed: %s' % str(self.errors)) return self._table ================================================ FILE: datalab/bigquery/_query_results_table.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery query job results table functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from . import _table class QueryResultsTable(_table.Table): """ A subclass of Table specifically for Query results. The primary differences are the additional properties job_id and sql. """ def __init__(self, name, context, job, is_temporary=False): """Initializes an instance of a Table object. Args: name: the name of the table either as a string or a 3-part tuple (projectid, datasetid, name). context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. job: the QueryJob associated with these results. is_temporary: if True, this is a short-lived table for intermediate results (default False). """ super(QueryResultsTable, self).__init__(name, context) self._job = job self._is_temporary = is_temporary def __repr__(self): """Returns a representation for the dataset for showing in the notebook. """ if self._is_temporary: return 'QueryResultsTable %s' % self.job_id else: return super(QueryResultsTable, self).__repr__() @property def job(self): """ The QueryJob object that caused the table to be populated. """ return self._job @property def job_id(self): """ The ID of the query job that caused the table to be populated. """ return self._job.id @property def sql(self): """ The SQL statement for the query that populated the table. """ return self._job.sql @property def is_temporary(self): """ Whether this is a short-lived table or not. """ return self._is_temporary ================================================ FILE: datalab/bigquery/_query_stats.py ================================================ # Copyright 2015 Google Inc. All rights reserved. 
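# Illustrative sketch (not from the original sources) tying together the classes
# above: Query.execute_async() returns a QueryJob, QueryJob.wait() polls via
# jobs_query_results until the job completes, and QueryJob.results exposes the
# QueryResultsTable holding the output. The SQL and the public sample table are
# only examples; the sketch assumes Query is re-exported by datalab.bigquery.
import datalab.bigquery as bq

job = bq.Query('SELECT word, word_count FROM [publicdata:samples.shakespeare] LIMIT 10') \
        .execute_async(priority='interactive')
job.wait()
if job.failed:
  raise Exception('Query failed: %s' % job.errors)
print(job.total_rows, job.bytes_processed, job.cache_hit)
results = job.results            # QueryResultsTable backed by the temporary result table
df = results.to_dataframe()      # fetch the rows into a Pandas DataFrame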
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements representation of BigQuery query job dry run results.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class QueryStats(object): """A wrapper for statistics returned by a dry run query. Useful so we can get an HTML representation in a notebook. """ def __init__(self, total_bytes, is_cached): self.total_bytes = float(total_bytes) self.is_cached = is_cached def _repr_html_(self): self.total_bytes = QueryStats._size_formatter(self.total_bytes) return """
    Dry run information: %s to process, results %s
""" % (self.total_bytes, "cached" if self.is_cached else "not cached") @staticmethod def _size_formatter(byte_num, suf='B'): for mag in ['', 'K', 'M', 'G', 'T']: if byte_num < 1000.0: if suf == 'B': # Don't do fractional bytes return "%5d%s%s" % (int(byte_num), mag, suf) return "%3.1f%s%s" % (byte_num, mag, suf) byte_num /= 1000.0 return "%.1f%s%s".format(byte_num, 'P', suf) ================================================ FILE: datalab/bigquery/_sampling.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Sampling for BigQuery.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import object class Sampling(object): """Provides common sampling strategies. Sampling strategies can be used for sampling tables or queries. They are implemented as functions that take in a SQL statement representing the table or query that should be sampled, and return a new SQL statement that limits the result set in some manner. """ def __init__(self): pass @staticmethod def _create_projection(fields): """Creates a projection for use in a SELECT statement. Args: fields: the list of fields to be specified. """ if (fields is None) or (len(fields) == 0): return '*' return ','.join(fields) @staticmethod def default(fields=None, count=5): """Provides a simple default sampling strategy which limits the result set by a count. Args: fields: an optional list of field names to retrieve. count: optional number of rows to limit the sampled results to. Returns: A sampling function that can be applied to get a random sampling. """ projection = Sampling._create_projection(fields) return lambda sql: 'SELECT %s FROM (%s) LIMIT %d' % (projection, sql, count) @staticmethod def sorted(field_name, ascending=True, fields=None, count=5): """Provides a sampling strategy that picks from an ordered set of rows. Args: field_name: the name of the field to sort the rows by. ascending: whether to sort in ascending direction or not. fields: an optional list of field names to retrieve. count: optional number of rows to limit the sampled results to. Returns: A sampling function that can be applied to get the initial few rows. """ direction = '' if ascending else ' DESC' projection = Sampling._create_projection(fields) return lambda sql: 'SELECT %s FROM (%s) ORDER BY %s%s LIMIT %d' % (projection, sql, field_name, direction, count) @staticmethod def sampling_query(sql, fields=None, count=5, sampling=None): """Returns a sampling query for the SQL object. Args: sql: the SQL object to sample fields: an optional list of field names to retrieve. count: an optional count of rows to retrieve which is used if a specific sampling is not specified. sampling: an optional sampling strategy to apply to the table. Returns: A SQL query string for sampling the input sql. 
""" if sampling is None: sampling = Sampling.default(count=count, fields=fields) return sampling(sql) @staticmethod def hashed(field_name, percent, fields=None, count=0): """Provides a sampling strategy based on hashing and selecting a percentage of data. Args: field_name: the name of the field to hash. percent: the percentage of the resulting hashes to select. fields: an optional list of field names to retrieve. count: optional maximum count of rows to pick. Returns: A sampling function that can be applied to get a hash-based sampling. """ def _hashed_sampling(sql): projection = Sampling._create_projection(fields) sql = 'SELECT %s FROM (%s) WHERE ABS(HASH(%s)) %% 100 < %d' % \ (projection, sql, field_name, percent) if count != 0: sql = '%s LIMIT %d' % (sql, count) return sql return _hashed_sampling @staticmethod def random(percent, fields=None, count=0): """Provides a sampling strategy that picks a semi-random set of rows. Args: percent: the percentage of the resulting hashes to select. fields: an optional list of field names to retrieve. count: maximum number of rows to limit the sampled results to (default 5). Returns: A sampling function that can be applied to get some random rows. In order for this to provide a good random sample percent should be chosen to be ~count/#rows where #rows is the number of rows in the object (query, view or table) being sampled. The rows will be returned in order; i.e. the order itself is not randomized. """ def _random_sampling(sql): projection = Sampling._create_projection(fields) sql = 'SELECT %s FROM (%s) WHERE rand() < %f' % (projection, sql, (float(percent) / 100.0)) if count != 0: sql = '%s LIMIT %d' % (sql, count) return sql return _random_sampling ================================================ FILE: datalab/bigquery/_schema.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Table and View Schema APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import range from past.builtins import basestring from builtins import object import datetime import pandas class Schema(list): """Represents the schema of a BigQuery table as a flattened list of objects representing fields. Each field object has name, data_type, mode and description properties. Nested fields get flattened with their full-qualified names. So a Schema that has an object A with nested field B will be represented as [(name: 'A', ...), (name: 'A.b', ...)]. """ class Field(object): """ Represents a single field in a Table schema. This has the properties: - name: the flattened, full-qualified name of the field. - data_type: the type of the field as a string ('INTEGER', 'BOOLEAN', 'FLOAT', 'STRING' or 'TIMESTAMP'). - mode: the mode of the field; 'NULLABLE' by default. - description: a description of the field, if known; empty string by default. """ # TODO(gram): consider renaming data_type member to type. 
Yes, it shadows top-level # name but that is what we are using in __str__ and __getitem__ and is what is used in BQ. # The shadowing is unlikely to cause problems. def __init__(self, name, data_type, mode='NULLABLE', description=''): self.name = name self.data_type = data_type self.mode = mode self.description = description def _repr_sql_(self): """Returns a representation of the field for embedding into a SQL statement. Returns: A formatted field name for use within SQL statements. """ return self.name def __eq__(self, other): """ Compare two schema field objects for equality (ignoring description). """ return self.name == other.name and self.data_type == other.data_type\ and self.mode == other.mode def __str__(self): """ Returns the schema field as a string form of a dictionary. """ return "{ 'name': '%s', 'type': '%s', 'mode':'%s', 'description': '%s' }" % \ (self.name, self.data_type, self.mode, self.description) def __repr__(self): """ Returns the schema field as a string form of a dictionary. """ return str(self) def __getitem__(self, item): # TODO(gram): Currently we need this for a Schema object to work with the Parser object. # Eventually if we change Parser to only work with Schema (and not also with the # schema dictionaries in query results) we can remove this. if item == 'name': return self.name if item == 'type': return self.data_type if item == 'mode': return self.mode if item == 'description': return self.description @staticmethod def _from_dataframe(dataframe, default_type='STRING'): """ Infer a BigQuery table schema from a Pandas dataframe. Note that if you don't explicitly set the types of the columns in the dataframe, they may be of a type that forces coercion to STRING, so even though the fields in the dataframe themselves may be numeric, the type in the derived schema may not be. Hence it is prudent to make sure the Pandas dataframe is typed correctly. Args: dataframe: The DataFrame. default_type : The default big query type in case the type of the column does not exist in the schema. Defaults to 'STRING'. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ type_mapping = { 'i': 'INTEGER', 'b': 'BOOLEAN', 'f': 'FLOAT', 'O': 'STRING', 'S': 'STRING', 'U': 'STRING', 'M': 'TIMESTAMP' } fields = [] for column_name, dtype in dataframe.dtypes.iteritems(): fields.append({'name': column_name, 'type': type_mapping.get(dtype.kind, default_type)}) return fields @staticmethod def from_dataframe(dataframe, default_type='STRING'): """ Infer a BigQuery table schema from a Pandas dataframe. Note that if you don't explicitly set the types of the columns in the dataframe, they may be of a type that forces coercion to STRING, so even though the fields in the dataframe themselves may be numeric, the type in the derived schema may not be. Hence it is prudent to make sure the Pandas dataframe is typed correctly. Args: dataframe: The DataFrame. default_type : The default big query type in case the type of the column does not exist in the schema. Defaults to 'STRING'. Returns: A Schema. 
""" return Schema(Schema._from_dataframe(dataframe, default_type=default_type)) @staticmethod def _get_field_entry(name, value): entry = {'name': name} if isinstance(value, datetime.datetime): _type = 'TIMESTAMP' elif isinstance(value, bool): _type = 'BOOLEAN' elif isinstance(value, float): _type = 'FLOAT' elif isinstance(value, int): _type = 'INTEGER' elif isinstance(value, dict) or isinstance(value, list): _type = 'RECORD' entry['fields'] = Schema._from_record(value) else: _type = 'STRING' entry['type'] = _type return entry @staticmethod def _from_dict_record(data): """ Infer a BigQuery table schema from a dictionary. If the dictionary has entries that are in turn OrderedDicts these will be turned into RECORD types. Ideally this will be an OrderedDict but it is not required. Args: data: The dict to infer a schema from. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ return [Schema._get_field_entry(name, value) for name, value in list(data.items())] @staticmethod def _from_list_record(data): """ Infer a BigQuery table schema from a list of values. Args: data: The list of values. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ return [Schema._get_field_entry('Column%d' % (i + 1), value) for i, value in enumerate(data)] @staticmethod def _from_record(data): """ Infer a BigQuery table schema from a list of fields or a dictionary. The typeof the elements is used. For a list, the field names are simply 'Column1', 'Column2', etc. Args: data: The list of fields or dictionary. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ if isinstance(data, dict): return Schema._from_dict_record(data) elif isinstance(data, list): return Schema._from_list_record(data) else: raise Exception('Cannot create a schema from record %s' % str(data)) @staticmethod def from_record(source): """ Infers a table/view schema from a single record that can contain a list of fields or a dictionary of fields. The type of the elements is used for the types in the schema. For a dict, key names are used for column names while for a list, the field names are simply named 'Column1', 'Column2', etc. Note that if using a dict you may want to use an OrderedDict to ensure column ordering is deterministic. Args: source: The list of field values or dictionary of key/values. Returns: A Schema for the data. """ # TODO(gram): may want to allow an optional second argument which is a list of field # names; could be useful for the record-containing-list case. return Schema(Schema._from_record(source)) @staticmethod def from_data(source): """Infers a table/view schema from its JSON representation, a list of records, or a Pandas dataframe. Args: source: the Pandas Dataframe, a dictionary representing a record, a list of heterogeneous data (record) or homogeneous data (list of records) from which to infer the schema, or a definition of the schema as a list of dictionaries with 'name' and 'type' entries and possibly 'mode' and 'description' entries. Only used if no data argument was provided. 'mode' can be 'NULLABLE', 'REQUIRED' or 'REPEATED'. 
For the allowed types, see: https://cloud.google.com/bigquery/preparing-data-for-bigquery#datatypes Note that there is potential ambiguity when passing a list of lists or a list of dicts between whether that should be treated as a list of records or a single record that is a list. The heuristic used is to check the length of the entries in the list; if they are equal then a list of records is assumed. To avoid this ambiguity you can instead use the Schema.from_record method which assumes a single record, in either list of values or dictionary of key-values form. Returns: A Schema for the data. """ if isinstance(source, pandas.DataFrame): bq_schema = Schema._from_dataframe(source) elif isinstance(source, list): if len(source) == 0: bq_schema = source elif all(isinstance(d, dict) for d in source): if all('name' in d and 'type' in d for d in source): # It looks like a bq_schema; use it as-is. bq_schema = source elif all(len(d) == len(source[0]) for d in source): bq_schema = Schema._from_dict_record(source[0]) else: raise Exception(('Cannot create a schema from heterogeneous list %s; perhaps you meant ' + 'to use Schema.from_record?') % str(source)) elif isinstance(source[0], list) and \ all([isinstance(l, list) and len(l) == len(source[0]) for l in source]): # A list of lists all of the same length; treat first entry as a list record. bq_schema = Schema._from_record(source[0]) else: # A heterogeneous list; treat as a record. raise Exception(('Cannot create a schema from heterogeneous list %s; perhaps you meant ' + 'to use Schema.from_record?') % str(source)) elif isinstance(source, dict): raise Exception(('Cannot create a schema from dict %s; perhaps you meant to use ' + 'Schema.from_record?') % str(source)) else: raise Exception('Cannot create a schema from %s' % str(source)) return Schema(bq_schema) def __init__(self, definition=None): """Initializes a Schema from its raw JSON representation, a Pandas Dataframe, or a list. Args: definition: a definition of the schema as a list of dictionaries with 'name' and 'type' entries and possibly 'mode' and 'description' entries. Only used if no data argument was provided. 'mode' can be 'NULLABLE', 'REQUIRED' or 'REPEATED'. For the allowed types, see: https://cloud.google.com/bigquery/preparing-data-for-bigquery#datatypes """ super(Schema, self).__init__() self._map = {} self._bq_schema = definition self._populate_fields(definition) def __getitem__(self, key): """Provides ability to lookup a schema field by position or by name. """ if isinstance(key, basestring): return self._map.get(key, None) # noinspection PyCallByClass return list.__getitem__(self, key) def _add_field(self, name, data_type, mode='NULLABLE', description=''): field = Schema.Field(name, data_type, mode, description) self.append(field) self._map[name] = field def find(self, name): """ Get the index of a field in the flattened list given its (fully-qualified) name. Args: name: the fully-qualified name of the field. Returns: The index of the field, if found; else -1. """ for i in range(0, len(self)): if self[i].name == name: return i return -1 def _populate_fields(self, data, prefix=''): for field_data in data: name = prefix + field_data['name'] data_type = field_data['type'] self._add_field(name, data_type, field_data.get('mode', None), field_data.get('description', None)) if data_type == 'RECORD': # Recurse into the nested fields, using this field's name as a prefix. 
self._populate_fields(field_data.get('fields'), name + '.') def __str__(self): """ Returns a string representation of the non-flattened form of the schema. """ # TODO(gram): We should probably return the flattened form. There was a reason why this is # not but I don't remember what it was. Figure that out and fix this. return str(self._bq_schema) def __eq__(self, other): """ Compares two schema for equality. """ other_map = other._map if len(self._map) != len(other_map): return False for name in self._map.keys(): if name not in other_map: return False if not self._map[name] == other_map[name]: return False return True def __ne__(self, other): """ Compares two schema for inequality. """ return not(self.__eq__(other)) ================================================ FILE: datalab/bigquery/_table.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Table, and related Table BigQuery APIs.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import str from past.utils import old_div from builtins import object import calendar import codecs import csv import datetime import pandas import time import traceback import uuid import sys import datalab.context import datalab.utils from . import _api from . import _csv_options from . import _job from . import _parser from . import _schema from . import _utils # import of Query is at end of module as we have a circular dependency of # Query.execute().results -> Table and Table.sample() -> Query class TableMetadata(object): """Represents metadata about a BigQuery table.""" def __init__(self, table, info): """Initializes a TableMetadata instance. Args: table: the Table object this belongs to. info: The BigQuery information about this table as a Python dictionary. """ self._table = table self._info = info @property def created_on(self): """The creation timestamp.""" timestamp = self._info.get('creationTime') return _parser.Parser.parse_timestamp(timestamp) @property def description(self): """The description of the table if it exists.""" return self._info.get('description', '') @property def expires_on(self): """The timestamp for when the table will expire, or None if unknown.""" timestamp = self._info.get('expirationTime', None) if timestamp is None: return None return _parser.Parser.parse_timestamp(timestamp) @property def friendly_name(self): """The friendly name of the table if it exists.""" return self._info.get('friendlyName', '') @property def modified_on(self): """The timestamp for when the table was last modified.""" timestamp = self._info.get('lastModifiedTime') return _parser.Parser.parse_timestamp(timestamp) @property def rows(self): """The number of rows within the table, or -1 if unknown. """ return int(self._info['numRows']) if 'numRows' in self._info else -1 @property def size(self): """The size of the table in bytes, or -1 if unknown. 
""" return int(self._info['numBytes']) if 'numBytes' in self._info else -1 def refresh(self): """ Refresh the metadata. """ self._info = self._table._load_info() class Table(object): """Represents a Table object referencing a BigQuery table. """ # Allowed characters in a BigQuery table column name _VALID_COLUMN_NAME_CHARACTERS = '_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' # When fetching table contents, the max number of rows to fetch per HTTP request _DEFAULT_PAGE_SIZE = 1024 # Milliseconds per week _MSEC_PER_WEEK = 7 * 24 * 3600 * 1000 def __init__(self, name, context=None): """Initializes an instance of a Table object. The Table need not exist yet. Args: name: the name of the table either as a string or a 3-part tuple (projectid, datasetid, name). If a string, it must have the form ':.' or '.
'. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._name_parts = _utils.parse_table_name(name, self._api.project_id) self._full_name = '%s:%s.%s%s' % self._name_parts self._info = None self._cached_page = None self._cached_page_index = 0 self._schema = None @property def name(self): """The TableName named tuple (project_id, dataset_id, table_id, decorator) for the table.""" return self._name_parts @property def job(self): """ For tables resulting from executing queries, the job that created the table. Default is None for a Table object; this is overridden by QueryResultsTable. """ return None @property def is_temporary(self): """ Whether this is a short-lived table or not. Always False for non-QueryResultsTables. """ return False def _load_info(self): """Loads metadata about this table.""" if self._info is None: try: self._info = self._api.tables_get(self._name_parts) except Exception as e: raise e @property def metadata(self): """Retrieves metadata about the table. Returns: A TableMetadata object. Raises Exception if the request could not be executed or the response was malformed. """ self._load_info() return TableMetadata(self, self._info) def exists(self): """Checks if the table exists. Returns: True if the table exists; False otherwise. Raises: Exception if there was an error requesting information about the table. """ try: info = self._api.tables_get(self._name_parts) except datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e self._info = info return True def is_listable(self): """ Determine if the table can be listed. Returns: True is the Table can be listed; False otherwise. """ self._load_info() return 'type' not in self._info or 'MODEL' != self._info['type'] def delete(self): """ Delete the table. Returns: True if the Table no longer exists; False otherwise. """ try: self._api.table_delete(self._name_parts) except datalab.utils.RequestException: # TODO(gram): May want to check the error reasons here and if it is not # because the file didn't exist, return an error. pass except Exception as e: raise e return not self.exists() def create(self, schema, overwrite=False): """ Create the table with the specified schema. Args: schema: the schema to use to create the table. Should be a list of dictionaries, each containing at least a pair of entries, 'name' and 'type'. See https://cloud.google.com/bigquery/docs/reference/v2/tables#resource overwrite: if True, delete the table first if it exists. If False and the table exists, creation will fail and raise an Exception. Returns: The Table instance. Raises: Exception if the table couldn't be created or already exists and truncate was False. 
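Example (an illustrative sketch added for clarity, not part of the original docstring; it assumes
    Table is importable from datalab.bigquery and that 'my_dataset.requests' is a hypothetical
    dataset/table in your default project):

      import datalab.bigquery as bq

      schema = [{'name': 'id', 'type': 'INTEGER'},
                {'name': 'label', 'type': 'STRING'}]
      # create() returns the Table instance, so the call can be chained off the constructor.
      table = bq.Table('my_dataset.requests').create(schema, overwrite=True)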
""" if overwrite and self.exists(): self.delete() if not isinstance(schema, _schema.Schema): # Convert to a Schema object schema = _schema.Schema(schema) try: response = self._api.tables_insert(self._name_parts, schema=schema._bq_schema) except Exception as e: raise e if 'selfLink' in response: self._schema = schema return self raise Exception("Table %s could not be created as it already exists" % self._full_name) def sample(self, fields=None, count=5, sampling=None, use_cache=True, dialect=None, billing_tier=None): """Retrieves a sampling of data from the table. Args: fields: an optional list of field names to retrieve. count: an optional count of rows to retrieve which is used if a specific sampling is not specified. sampling: an optional sampling strategy to apply to the table. use_cache: whether to use cached results or not. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryResultsTable object containing the resulting data. Raises: Exception if the sample query could not be executed or query response was malformed. """ # Do import here to avoid top-level circular dependencies. from . import _query sql = self._repr_sql_() return _query.Query.sampling_query(sql, context=self._context, count=count, fields=fields, sampling=sampling).results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) @staticmethod def _encode_dict_as_row(record, column_name_map): """ Encode a dictionary representing a table row in a form suitable for streaming to BQ. This includes encoding timestamps as ISO-compatible strings and removing invalid characters from column names. Args: record: a Python dictionary representing the table row. column_name_map: a dictionary mapping dictionary keys to column names. This is initially empty and built up by this method when it first encounters each column, then used as a cache subsequently. Returns: The sanitized dictionary. """ for k in list(record.keys()): v = record[k] # If the column is a date, convert to ISO string. if isinstance(v, pandas.Timestamp) or isinstance(v, datetime.datetime): v = record[k] = record[k].isoformat() # If k has invalid characters clean it up if k not in column_name_map: column_name_map[k] = ''.join(c for c in k if c in Table._VALID_COLUMN_NAME_CHARACTERS) new_k = column_name_map[k] if k != new_k: record[new_k] = v del record[k] return record def insert_data(self, data, include_index=False, index_name=None): """ Insert the contents of a Pandas DataFrame or a list of dictionaries into the table. The insertion will be performed using at most 500 rows per POST, and at most 10 POSTs per second, as BigQuery has some limits on streaming rates. Args: data: the DataFrame or list to insert. include_index: whether to include the DataFrame or list index as a column in the BQ table. index_name: for a list, if include_index is True, this should be the name for the index. If not specified, 'Index' will be used. Returns: The table. Raises: Exception if the table doesn't exist, the table's schema differs from the data's schema, or the insert failed. 
""" # TODO(gram): we could create the Table here is it doesn't exist using a schema derived # from the data. IIRC we decided not to but doing so seems less unwieldy that having to # create it first and then validate the schema against it itself. # There are BigQuery limits on the streaming API: # # max_rows_per_post = 500 # max_bytes_per_row = 20000 # max_rows_per_second = 10000 # max_bytes_per_post = 1000000 # max_bytes_per_second = 10000000 # # It is non-trivial to enforce these here, and the max bytes per row is not something we # can really control. As an approximation we enforce the 500 row limit # with a 0.05 sec POST interval (to enforce the 10,000 rows per sec limit). max_rows_per_post = 500 post_interval = 0.05 # TODO(gram): add different exception types for each failure case. if not self.exists(): raise Exception('Table %s does not exist.' % self._full_name) data_schema = _schema.Schema.from_data(data) if isinstance(data, list): if include_index: if not index_name: index_name = 'Index' data_schema._add_field(index_name, 'INTEGER') table_schema = self.schema # Do some validation of the two schema to make sure they are compatible. for data_field in data_schema: name = data_field.name table_field = table_schema[name] if table_field is None: raise Exception('Table does not contain field %s' % name) data_type = data_field.data_type table_type = table_field.data_type if table_type != data_type: raise Exception('Field %s in data has type %s but in table has type %s' % (name, data_type, table_type)) total_rows = len(data) total_pushed = 0 job_id = uuid.uuid4().hex rows = [] column_name_map = {} is_dataframe = isinstance(data, pandas.DataFrame) if is_dataframe: # reset_index creates a new dataframe so we don't affect the original. reset_index(drop=True) # drops the original index and uses an integer range. gen = data.reset_index(drop=not include_index).iterrows() else: gen = enumerate(data) for index, row in gen: if is_dataframe: row = row.to_dict() elif include_index: row[index_name] = index rows.append({ 'json': self._encode_dict_as_row(row, column_name_map), 'insertId': job_id + str(index) }) total_pushed += 1 if (total_pushed == total_rows) or (len(rows) == max_rows_per_post): try: response = self._api.tabledata_insert_all(self._name_parts, rows) except Exception as e: raise e if 'insertErrors' in response: raise Exception('insertAll failed: %s' % response['insertErrors']) time.sleep(post_interval) # Streaming API is rate-limited rows = [] # Block until data is ready while True: self._info = self._api.tables_get(self._name_parts) if 'streamingBuffer' not in self._info or \ 'estimatedRows' not in self._info['streamingBuffer'] or \ int(self._info['streamingBuffer']['estimatedRows']) > 0: break time.sleep(2) return self def _init_job_from_response(self, response): """ Helper function to create a Job instance from a response. """ job = None if response and 'jobReference' in response: job = _job.Job(job_id=response['jobReference']['jobId'], context=self._context) return job def extract_async(self, destination, format='csv', csv_delimiter=',', csv_header=True, compress=False): """Starts a job to export the table to GCS. Args: destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. 
compress: whether to compress the data on export. Compression is not supported for AVRO format. Defaults to False. Returns: A Job object for the export Job if it was started successfully; else None. """ format = format.upper() if format == 'JSON': format = 'NEWLINE_DELIMITED_JSON' try: response = self._api.table_extract(self._name_parts, destination, format, compress, csv_delimiter, csv_header) return self._init_job_from_response(response) except Exception as e: raise datalab.utils.JobError(location=traceback.format_exc(), message=str(e), reason=str(type(e))) def extract(self, destination, format='csv', csv_delimiter=',', csv_header=True, compress=False): """Exports the table to GCS; blocks until complete. Args: destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. compress: whether to compress the data on export. Compression is not supported for AVRO format. Defaults to False. Returns: A Job object for the completed export Job if it was started successfully; else None. """ job = self.extract_async(destination, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header, compress=compress) if job is not None: job.wait() return job def load_async(self, source, mode='create', source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0): """ Starts importing a table from GCS and return a Future. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. mode: one of 'create', 'append', or 'overwrite'. 'append' or 'overwrite' will fail if the table does not already exist, while 'create' will fail if it does. The default is 'create'. If 'create' the schema will be inferred if necessary. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: if source format is 'csv', additional options as a CSVOptions object. ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: the maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). Returns: A Job object for the import if it was started successfully or None if not. Raises: Exception if the load job failed to be started or invalid arguments were supplied. 
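Example (illustrative sketch; the GCS path and table name are hypothetical, and CSVOptions is
    assumed to be the options class exported alongside Table and to accept these keyword
    arguments):

      import datalab.bigquery as bq

      options = bq.CSVOptions(delimiter=',', skip_leading_rows=1)
      job = bq.Table('my_dataset.requests').load_async('gs://my-bucket/requests/*.csv',
                                                       mode='append', source_format='csv',
                                                       csv_options=options)
      if job is not None:
        job.wait()   # or use load(), which waits for completion itself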
""" if source_format == 'csv': source_format = 'CSV' elif source_format == 'json': source_format = 'NEWLINE_DELIMITED_JSON' else: raise Exception("Invalid source format %s" % source_format) if not(mode == 'create' or mode == 'append' or mode == 'overwrite'): raise Exception("Invalid mode %s" % mode) if csv_options is None: csv_options = _csv_options.CSVOptions() try: response = self._api.jobs_insert_load(source, self._name_parts, append=(mode == 'append'), overwrite=(mode == 'overwrite'), create=(mode == 'create'), source_format=source_format, field_delimiter=csv_options.delimiter, allow_jagged_rows=csv_options.allow_jagged_rows, allow_quoted_newlines=csv_options.allow_quoted_newlines, encoding=csv_options.encoding.upper(), ignore_unknown_values=ignore_unknown_values, max_bad_records=max_bad_records, quote=csv_options.quote, skip_leading_rows=csv_options.skip_leading_rows) except Exception as e: raise e return self._init_job_from_response(response) def load(self, source, mode='create', source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0): """ Load the table from GCS. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. mode: one of 'create', 'append', or 'overwrite'. 'append' or 'overwrite' will fail if the table does not already exist, while 'create' will fail if it does. The default is 'create'. If 'create' the schema will be inferred if necessary. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: if source format is 'csv', additional options as a CSVOptions object. ignore_unknown_values: if True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: the maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). Returns: A Job object for the completed load Job if it was started successfully; else None. """ job = self.load_async(source, mode=mode, source_format=source_format, csv_options=csv_options, ignore_unknown_values=ignore_unknown_values, max_bad_records=max_bad_records) if job is not None: job.wait() return job def _get_row_fetcher(self, start_row=0, max_rows=None, page_size=_DEFAULT_PAGE_SIZE): """ Get a function that can retrieve a page of rows. The function returned is a closure so that it can have a signature suitable for use by Iterator. Args: start_row: the row to start fetching from; default 0. max_rows: the maximum number of rows to fetch (across all calls, not per-call). Default is None which means no limit. page_size: the maximum number of results to fetch per page; default 1024. Returns: A function that can be called repeatedly with a page token and running count, and that will return an array of rows and a next page token; when the returned page token is None the fetch is complete. 
""" if not start_row: start_row = 0 elif start_row < 0: # We are measuring from the table end if self.length >= 0: start_row += self.length else: raise Exception('Cannot use negative indices for table of unknown length') schema = self.schema._bq_schema name_parts = self._name_parts def _retrieve_rows(page_token, count): page_rows = [] if max_rows and count >= max_rows: page_token = None else: if max_rows and page_size > (max_rows - count): max_results = max_rows - count else: max_results = page_size try: if page_token: response = self._api.tabledata_list(name_parts, page_token=page_token, max_results=max_results) else: response = self._api.tabledata_list(name_parts, start_index=start_row, max_results=max_results) except Exception as e: raise e page_token = response['pageToken'] if 'pageToken' in response else None if 'rows' in response: page_rows = response['rows'] rows = [] for row_dict in page_rows: rows.append(_parser.Parser.parse_row(schema, row_dict)) return rows, page_token return _retrieve_rows def range(self, start_row=0, max_rows=None): """ Get an iterator to iterate through a set of table rows. Args: start_row: the row of the table at which to start the iteration (default 0) max_rows: an upper limit on the number of rows to iterate through (default None) Returns: A row iterator. """ fetcher = self._get_row_fetcher(start_row=start_row, max_rows=max_rows) return iter(datalab.utils.Iterator(fetcher)) def to_dataframe(self, start_row=0, max_rows=None): """ Exports the table to a Pandas dataframe. Args: start_row: the row of the table at which to start the export (default 0) max_rows: an upper limit on the number of rows to export (default None) Returns: A Pandas dataframe containing the table data. """ fetcher = self._get_row_fetcher(start_row=start_row, max_rows=max_rows) count = 0 page_token = None df = None while True: page_rows, page_token = fetcher(page_token, count) if len(page_rows): count += len(page_rows) if df is None: df = pandas.DataFrame.from_records(page_rows) else: df = df.append(page_rows, ignore_index=True) if not page_token: break # Need to reorder the dataframe to preserve column ordering ordered_fields = [field.name for field in self.schema] return df[ordered_fields] if df is not None else pandas.DataFrame() def to_file(self, destination, format='csv', csv_delimiter=',', csv_header=True): """Save the results to a local file in CSV format. Args: destination: path on the local filesystem for the saved results. format: the format to use for the exported data; currently only 'csv' is supported. csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. Raises: An Exception if the operation failed. """ f = codecs.open(destination, 'w', 'utf-8') fieldnames = [] for column in self.schema: fieldnames.append(column.name) if sys.version_info[0] == 2: csv_delimiter = csv_delimiter.encode('unicode_escape') writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter=csv_delimiter) if csv_header: writer.writeheader() for row in self: writer.writerow(row) f.close() @datalab.utils.async_method def to_file_async(self, destination, format='csv', csv_delimiter=',', csv_header=True): """Start saving the results to a local file in CSV format and return a Job for completion. Args: destination: path on the local filesystem for the saved results. format: the format to use for the exported data; currently only 'csv' is supported. 
csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. Returns: A Job for the async save operation. Raises: An Exception if the operation failed. """ self.to_file(destination, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header) @property def schema(self): """Retrieves the schema of the table. Returns: A Schema object containing a list of schema fields and associated metadata. Raises Exception if the request could not be executed or the response was malformed. """ if not self._schema: try: self._load_info() self._schema = _schema.Schema(self._info['schema']['fields']) except KeyError: raise Exception('Unexpected table response: missing schema') return self._schema def update(self, friendly_name=None, description=None, expiry=None, schema=None): """ Selectively updates Table information. Any parameters that are omitted or None are not updated. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. expiry: if not None, the new expiry time, either as a DateTime or milliseconds since epoch. schema: if not None, the new schema: either a list of dictionaries or a Schema. """ self._load_info() if friendly_name is not None: self._info['friendlyName'] = friendly_name if description is not None: self._info['description'] = description if expiry is not None: if isinstance(expiry, datetime.datetime): expiry = calendar.timegm(expiry.utctimetuple()) * 1000 self._info['expirationTime'] = expiry if schema is not None: if isinstance(schema, _schema.Schema): schema = schema._bq_schema self._info['schema'] = {'fields': schema} try: self._api.table_update(self._name_parts, self._info) except datalab.utils.RequestException: # The cached metadata is out of sync now; abandon it. self._info = None except Exception as e: raise e def _repr_sql_(self): """Returns a representation of the table for embedding into a SQL statement. Returns: A formatted table name for use within SQL statements. """ return '[' + self._full_name + ']' def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'Table %s' % self._full_name def __str__(self): """Returns a string representation of the table using its specified name. Returns: The string representation of this object. """ return self._full_name @property def length(self): """ Get the length of the table (number of rows). We don't use __len__ as this may return -1 for 'unknown'. """ return self.metadata.rows def __iter__(self): """ Get an iterator for the table. """ return self.range(start_row=0) def __getitem__(self, item): """ Get an item or a slice of items from the table. This uses a small cache to reduce the number of calls to tabledata.list. Note: this is a useful function to have, and supports some current usage like query.results()[0], but should be used with care. """ if isinstance(item, slice): # Just treat this as a set of calls to __getitem__(int) result = [] i = item.start step = item.step if item.step else 1 while i < item.stop: result.append(self[i]) i += step return result # Handle the integer index case. if item < 0: if self.length >= 0: item += self.length else: raise Exception('Cannot use negative indices for table of unknown length') if not self._cached_page \ or self._cached_page_index > item \ or self._cached_page_index + len(self._cached_page) <= item: # cache a new page. To get the start row we round to the nearest multiple of the page # size. 
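      # (Illustrative numbers, not from the original source: with the default 1024-row page
      # size, item 2500 rounds down to a cached page starting at row 2048.)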
first = old_div(item, self._DEFAULT_PAGE_SIZE) * self._DEFAULT_PAGE_SIZE count = self._DEFAULT_PAGE_SIZE if self.length >= 0: remaining = self.length - first if count > remaining: count = remaining fetcher = self._get_row_fetcher(start_row=first, max_rows=count, page_size=count) self._cached_page_index = first self._cached_page, _ = fetcher(None, 0) return self._cached_page[item - self._cached_page_index] @staticmethod def _convert_decorator_time(when): if isinstance(when, datetime.datetime): value = 1000 * (when - datetime.datetime.utcfromtimestamp(0)).total_seconds() elif isinstance(when, datetime.timedelta): value = when.total_seconds() * 1000 if value > 0: raise Exception("Invalid snapshot relative when argument: %s" % str(when)) else: raise Exception("Invalid snapshot when argument type: %s" % str(when)) if value < -Table._MSEC_PER_WEEK: raise Exception("Invalid snapshot relative when argument: must be within 7 days: %s" % str(when)) if value > 0: now = 1000 * (datetime.datetime.utcnow() - datetime.datetime.utcfromtimestamp(0)).total_seconds() # Check that an abs value is not more than 7 days in the past and is # not in the future if not ((now - Table._MSEC_PER_WEEK) < value < now): raise Exception("Invalid snapshot absolute when argument: %s" % str(when)) return int(value) def snapshot(self, at): """ Return a new Table which is a snapshot of this table at the specified time. Args: at: the time of the snapshot. This can be a Python datetime (absolute) or timedelta (relative to current time). The result must be after the table was created and no more than seven days in the past. Passing None will get a reference the oldest snapshot. Note that using a datetime will get a snapshot at an absolute point in time, while a timedelta will provide a varying snapshot; any queries issued against such a Table will be done against a snapshot that has an age relative to the execution time of the query. Returns: A new Table object referencing the snapshot. Raises: An exception if this Table is already decorated, or if the time specified is invalid. """ if self._name_parts.decorator != '': raise Exception("Cannot use snapshot() on an already decorated table") value = Table._convert_decorator_time(at) return Table("%s@%s" % (self._full_name, str(value)), context=self._context) def window(self, begin, end=None): """ Return a new Table limited to the rows added to this Table during the specified time range. Args: begin: the start time of the window. This can be a Python datetime (absolute) or timedelta (relative to current time). The result must be after the table was created and no more than seven days in the past. Note that using a relative value will provide a varying snapshot, not a fixed snapshot; any queries issued against such a Table will be done against a snapshot that has an age relative to the execution time of the query. end: the end time of the snapshot; if None, then the current time is used. The types and interpretation of values is as for start. Returns: A new Table object referencing the window. Raises: An exception if this Table is already decorated, or if the time specified is invalid. 
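Example (illustrative; the table name is hypothetical, and the decorators are relative, so they
    are expressed as negative timedeltas as _convert_decorator_time requires):

      import datetime
      import datalab.bigquery as bq

      table = bq.Table('my_dataset.requests')
      # Rows streamed in between 24 hours ago and 1 hour ago.
      recent = table.window(datetime.timedelta(hours=-24), datetime.timedelta(hours=-1))
      # A point-in-time snapshot from one hour ago.
      hour_ago = table.snapshot(datetime.timedelta(hours=-1))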
""" if self._name_parts.decorator != '': raise Exception("Cannot use window() on an already decorated table") start = Table._convert_decorator_time(begin) if end is None: if isinstance(begin, datetime.timedelta): end = datetime.timedelta(0) else: end = datetime.datetime.utcnow() stop = Table._convert_decorator_time(end) # Both values must have the same sign if (start > 0 >= stop) or (stop > 0 >= start): raise Exception("window: Between arguments must both be absolute or relative: %s, %s" % (str(begin), str(end))) # start must be less than stop if start > stop: raise Exception("window: Between arguments: begin must be before end: %s, %s" % (str(begin), str(end))) return Table("%s@%s-%s" % (self._full_name, str(start), str(stop)), context=self._context) def to_query(self, fields=None): """ Return a Query for this Table. Args: fields: the fields to return. If None, all fields will be returned. This can be a string which will be injected into the Query after SELECT, or a list of field names. Returns: A Query object that will return the specified fields from the records in the Table. """ # Do import here to avoid top-level circular dependencies. from . import _query if fields is None: fields = '*' elif isinstance(fields, list): fields = ','.join(fields) return _query.Query('SELECT %s FROM %s' % (fields, self._repr_sql_()), context=self._context) ================================================ FILE: datalab/bigquery/_udf.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery UDF Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import json class UDF(object): """Represents a BigQuery UDF declaration. """ @property def name(self): return self._name @property def imports(self): return self._imports @property def code(self): return self._code def __init__(self, inputs, outputs, name, implementation, support_code=None, imports=None): """Initializes a Function object from its pieces. Args: inputs: a list of string field names representing the schema of input. outputs: a list of name/type tuples representing the schema of the output. name: the name of the javascript function implementation: a javascript function implementing the logic. support_code: additional javascript code that the function can use. imports: a list of GCS URLs or files containing further support code. Raises: Exception if the name is invalid. """ self._outputs = outputs self._name = name self._implementation = implementation self._support_code = support_code self._imports = imports self._code = UDF._build_js(inputs, outputs, name, implementation, support_code) @staticmethod def _build_js(inputs, outputs, name, implementation, support_code): """Creates a BigQuery SQL UDF javascript object. Args: inputs: a list of (name, type) tuples representing the schema of input. 
outputs: a list of (name, type) tuples representing the schema of the output. name: the name of the function implementation: a javascript function defining the UDF logic. support_code: additional javascript code that the function can use. """ # Construct a comma-separated list of input field names # For example, field1,field2,... input_fields = json.dumps([f[0] for f in inputs]) # Construct a json representation of the output schema # For example, [{'name':'field1','type':'string'},...] output_fields = [{'name': f[0], 'type': f[1]} for f in outputs] output_fields = json.dumps(output_fields, sort_keys=True) # Build the JS from the individual bits with proper escaping of the implementation if support_code is None: support_code = '' return ('{code}\n{name}={implementation};\nbigquery.defineFunction(\'{name}\', {inputs}, ' '{outputs}, {name});').format(code=support_code, name=name, implementation=implementation, inputs=str(input_fields), outputs=str(output_fields)) ================================================ FILE: datalab/bigquery/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Useful common utility functions.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring import collections import re DatasetName = collections.namedtuple('DatasetName', ['project_id', 'dataset_id']) """ A namedtuple for Dataset names. Args: project_id: the project id for the dataset. dataset_id: the dataset id for the dataset. """ TableName = collections.namedtuple('TableName', ['project_id', 'dataset_id', 'table_id', 'decorator']) """ A namedtuple for Table names. Args: project_id: the project id for the table. dataset_id: the dataset id for the table. table_id: the table id for the table. decorator: the optional decorator for the table (for windowing/snapshot-ing). """ # Absolute project-qualified name pattern: : _ABS_DATASET_NAME_PATTERN = r'^([a-z\d\-_\.:]+)\:(\w+)$' # Relative name pattern: _REL_DATASET_NAME_PATTERN = r'^(\w+)$' # Absolute project-qualified name pattern: :.
_ABS_TABLE_NAME_PATTERN = r'^([a-z\d\-_\.:]+)\:(\w+)\.(\w+)(@[\d\-]+)?$'
# Relative name pattern: <dataset>.<table>
_REL_TABLE_NAME_PATTERN = r'^(\w+)\.(\w+)(@[\d\-]+)?$'
# Table-only name pattern: <table>
. Includes an optional decorator. _TABLE_NAME_PATTERN = r'^(\w+)(@[\d\-]+)$' def parse_dataset_name(name, project_id=None): """Parses a dataset name into its individual parts. Args: name: the name to parse, or a tuple, dictionary or array containing the parts. project_id: the expected project ID. If the name does not contain a project ID, this will be used; if the name does contain a project ID and it does not match this, an exception will be thrown. Returns: A DatasetName named tuple for the dataset. Raises: Exception: raised if the name doesn't match the expected formats or a project_id was specified that does not match that in the name. """ _project_id = _dataset_id = None if isinstance(name, basestring): # Try to parse as absolute name first. m = re.match(_ABS_DATASET_NAME_PATTERN, name, re.IGNORECASE) if m is not None: _project_id, _dataset_id = m.groups() else: # Next try to match as a relative name implicitly scoped within current project. m = re.match(_REL_DATASET_NAME_PATTERN, name) if m is not None: groups = m.groups() _dataset_id = groups[0] elif isinstance(name, dict): try: _dataset_id = name['dataset_id'] _project_id = name['project_id'] except KeyError: pass else: # Try treat as an array or tuple if len(name) == 2: # Treat as a tuple or array. _project_id, _dataset_id = name elif len(name) == 1: _dataset_id = name[0] if not _dataset_id: raise Exception('Invalid dataset name: ' + str(name)) if not _project_id: _project_id = project_id return DatasetName(_project_id, _dataset_id) def parse_table_name(name, project_id=None, dataset_id=None): """Parses a table name into its individual parts. Args: name: the name to parse, or a tuple, dictionary or array containing the parts. project_id: the expected project ID. If the name does not contain a project ID, this will be used; if the name does contain a project ID and it does not match this, an exception will be thrown. dataset_id: the expected dataset ID. If the name does not contain a dataset ID, this will be used; if the name does contain a dataset ID and it does not match this, an exception will be thrown. Returns: A TableName named tuple consisting of the full name and individual name parts. Raises: Exception: raised if the name doesn't match the expected formats, or a project_id and/or dataset_id was provided that does not match that in the name. """ _project_id = _dataset_id = _table_id = _decorator = None if isinstance(name, basestring): # Try to parse as absolute name first. m = re.match(_ABS_TABLE_NAME_PATTERN, name, re.IGNORECASE) if m is not None: _project_id, _dataset_id, _table_id, _decorator = m.groups() else: # Next try to match as a relative name implicitly scoped within current project. m = re.match(_REL_TABLE_NAME_PATTERN, name) if m is not None: groups = m.groups() _project_id, _dataset_id, _table_id, _decorator =\ project_id, groups[0], groups[1], groups[2] else: # Finally try to match as a table name only. 
m = re.match(_TABLE_NAME_PATTERN, name) if m is not None: groups = m.groups() _project_id, _dataset_id, _table_id, _decorator =\ project_id, dataset_id, groups[0], groups[1] elif isinstance(name, dict): try: _table_id = name['table_id'] _dataset_id = name['dataset_id'] _project_id = name['project_id'] except KeyError: pass else: # Try treat as an array or tuple if len(name) == 4: _project_id, _dataset_id, _table_id, _decorator = name elif len(name) == 3: _project_id, _dataset_id, _table_id = name elif len(name) == 2: _dataset_id, _table_id = name if not _table_id: raise Exception('Invalid table name: ' + str(name)) if not _project_id: _project_id = project_id if not _dataset_id: _dataset_id = dataset_id if not _decorator: _decorator = '' return TableName(_project_id, _dataset_id, _table_id, _decorator) def format_query_errors(errors): return '\n'.join(['%s: %s' % (error['reason'], error['message']) for error in errors]) ================================================ FILE: datalab/bigquery/_view.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery Views.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import datalab.context from . import _query from . import _table # Query import is at end to avoid issues with circular dependencies. class View(object): """ An implementation of a BigQuery View. """ # Views in BigQuery are virtual tables, but it is useful to have a mixture of both Table and # Query semantics; our version thus internally has a BaseTable and a Query (for materialization; # not the same as the view query), and exposes a number of the same APIs as Table and Query # through wrapper functions around these. def __init__(self, name, context=None): """Initializes an instance of a View object. Args: name: the name of the view either as a string or a 3-part tuple (projectid, datasetid, name). If a string, it must have the form ':.' or '.'. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. 
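Example (illustrative sketch; the view and backing table names are hypothetical, and the query
    uses the legacy-SQL bracket syntax that _repr_sql_ produces):

      import datalab.bigquery as bq

      view = bq.View('my_dataset.recent_requests')
      if not view.exists():
        view.create('SELECT id, label FROM [my_dataset.requests] LIMIT 1000')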
""" if context is None: context = datalab.context.Context.default() self._context = context self._table = _table.Table(name, context=context) self._materialization = _query.Query('SELECT * FROM %s' % self._repr_sql_(), context=context) def __str__(self): """The full name for the view as a string.""" return str(self._table) @property def name(self): """The name for the view as a named tuple.""" return self._table.name @property def description(self): """The description of the view if it exists.""" return self._table.metadata.description @property def friendly_name(self): """The friendly name of the view if it exists.""" return self._table.metadata.friendly_name @property def query(self): """The Query that defines the view.""" if not self.exists(): return None self._table._load_info() if 'view' in self._table._info and 'query' in self._table._info['view']: return _query.Query(self._table._info['view']['query'], context=self._context) return None def exists(self): """Whether the view's Query has been executed and the view is available or not.""" return self._table.exists() def delete(self): """Removes the view if it exists.""" self._table.delete() def create(self, query): """ Creates the view with the specified query. Args: query: the query to use to for the View; either a string containing a SQL query or a Query object. Returns: The View instance. Raises: Exception if the view couldn't be created or already exists and overwrite was False. """ if isinstance(query, _query.Query): query = query.sql try: response = self._table._api.tables_insert(self._table.name, query=query) except Exception as e: raise e if 'selfLink' in response: return self raise Exception("View %s could not be created as it already exists" % str(self)) def sample(self, fields=None, count=5, sampling=None, use_cache=True, dialect=None, billing_tier=None): """Retrieves a sampling of data from the view. Args: fields: an optional list of field names to retrieve. count: an optional count of rows to retrieve which is used if a specific sampling is not specified. sampling: an optional sampling strategy to apply to the view. use_cache: whether to use cached results or not. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryResultsTable object containing the resulting data. Raises: Exception if the sample query could not be executed or the query response was malformed. """ return self._table.sample(fields=fields, count=count, sampling=sampling, use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) @property def schema(self): """Retrieves the schema of the table. Returns: A Schema object containing a list of schema fields and associated metadata. Raises Exception if the request could not be executed or the response was malformed. """ return self._table.schema def update(self, friendly_name=None, description=None, query=None): """ Selectively updates View information. Any parameters that are None (the default) are not applied in the update. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. 
query: if not None, a new query string for the View. """ self._table._load_info() if query is not None: if isinstance(query, _query.Query): query = query.sql self._table._info['view'] = {'query': query} self._table.update(friendly_name=friendly_name, description=description) def results(self, use_cache=True, dialect=None, billing_tier=None): """Materialize the view synchronously. If you require more control over the execution, use execute() or execute_async(). Args: use_cache: whether to use cached results or not. dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryResultsTable containing the result set. Raises: Exception if the query could not be executed or query response was malformed. """ return self._materialization.results(use_cache=use_cache, dialect=dialect, billing_tier=billing_tier) def execute_async(self, table_name=None, table_mode='create', use_cache=True, priority='high', allow_large_results=False, dialect=None, billing_tier=None): """Materialize the View asynchronously. Args: table_name: the result table name; if None, then a temporary table will be used. table_mode: one of 'create', 'overwrite' or 'append'. If 'create' (the default), the request will fail if the table exists. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified (default True). priority:one of 'low' or 'high' (default). Note that 'high' is more expensive, but is better suited to exploratory analysis. allow_large_results: whether to allow large results; i.e. compressed data over 100MB. This is slower and requires a table_name to be specified) (default False). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryJob for the materialization Raises: Exception (KeyError) if View could not be materialized. """ return self._materialization.execute_async(table_name=table_name, table_mode=table_mode, use_cache=use_cache, priority=priority, allow_large_results=allow_large_results, dialect=dialect, billing_tier=billing_tier) def execute(self, table_name=None, table_mode='create', use_cache=True, priority='high', allow_large_results=False, dialect=None, billing_tier=None): """Materialize the View synchronously. Args: table_name: the result table name; if None, then a temporary table will be used. table_mode: one of 'create', 'overwrite' or 'append'. If 'create' (the default), the request will fail if the table exists. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified (default True). priority:one of 'low' or 'high' (default). Note that 'high' is more expensive, but is better suited to exploratory analysis. 
allow_large_results: whether to allow large results; i.e. compressed data over 100MB. This is slower and requires a table_name to be specified) (default False). dialect : {'legacy', 'standard'}, default 'legacy' 'legacy' : Use BigQuery's legacy SQL dialect. 'standard' : Use BigQuery's standard SQL (beta), which is compliant with the SQL 2011 standard. billing_tier: Limits the billing tier for this job. Queries that have resource usage beyond this tier will fail (without incurring a charge). If unspecified, this will be set to your project default. This can also be used to override your project-wide default billing tier on a per-query basis. Returns: A QueryJob for the materialization Raises: Exception (KeyError) if View could not be materialized. """ return self._materialization.execute(table_name=table_name, table_mode=table_mode, use_cache=use_cache, priority=priority, allow_large_results=allow_large_results, dialect=dialect, billing_tier=billing_tier) def _repr_sql_(self): """Returns a representation of the view for embedding into a SQL statement. Returns: A formatted table name for use within SQL statements. """ return '[' + str(self) + ']' def __repr__(self): """Returns a representation for the view for showing in the notebook. """ return 'View %s\n%s' % (self._table, self.query) ================================================ FILE: datalab/bigquery/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _bigquery __all__ = ['_bigquery'] ================================================ FILE: datalab/bigquery/commands/_bigquery.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
"""Google Cloud Platform library - BigQuery IPython Functionality.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals from builtins import zip from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import fnmatch import json import re import datalab.bigquery import datalab.data import datalab.utils import datalab.utils.commands def _create_create_subparser(parser): create_parser = parser.subcommand('create', 'Create a dataset or table.') sub_commands = create_parser.add_subparsers(dest='command') create_dataset_parser = sub_commands.add_parser('dataset', help='Create a dataset.') create_dataset_parser.add_argument('-n', '--name', help='The name of the dataset to create.', required=True) create_dataset_parser.add_argument('-f', '--friendly', help='The friendly name of the dataset.') create_table_parser = sub_commands.add_parser('table', help='Create a table.') create_table_parser.add_argument('-n', '--name', help='The name of the table to create.', required=True) create_table_parser.add_argument('-o', '--overwrite', help='Overwrite table if it exists.', action='store_true') return create_parser def _create_delete_subparser(parser): delete_parser = parser.subcommand('delete', 'Delete a dataset or table.') sub_commands = delete_parser.add_subparsers(dest='command') delete_dataset_parser = sub_commands.add_parser('dataset', help='Delete a dataset.') delete_dataset_parser.add_argument('-n', '--name', help='The name of the dataset to delete.', required=True) delete_table_parser = sub_commands.add_parser('table', help='Delete a table.') delete_table_parser.add_argument('-n', '--name', help='The name of the table to delete.', required=True) return delete_parser def _create_sample_subparser(parser): sample_parser = parser.subcommand('sample', 'Display a sample of the results of a BigQuery SQL query.\nThe ' 'cell can optionally contain arguments for expanding variables ' 'in the query,\nif -q/--query was used, or it can contain SQL ' 'for a query.') group = sample_parser.add_mutually_exclusive_group() group.add_argument('-q', '--query', help='the name of the query to sample') group.add_argument('-t', '--table', help='the name of the table to sample') group.add_argument('-v', '--view', help='the name of the view to sample') sample_parser.add_argument('-d', '--dialect', help='BigQuery SQL dialect', choices=['legacy', 'standard']) sample_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') sample_parser.add_argument('-c', '--count', type=int, default=10, help='The number of rows to limit to, if sampling') sample_parser.add_argument('-m', '--method', help='The type of sampling to use', choices=['limit', 'random', 'hashed', 'sorted'], default='limit') sample_parser.add_argument('-p', '--percent', type=int, default=1, help='For random or hashed sampling, what percentage to sample from') sample_parser.add_argument('-f', '--field', help='The field to use for sorted or hashed sampling') sample_parser.add_argument('-o', '--order', choices=['ascending', 'descending'], default='ascending', help='The sort order to use for sorted sampling') sample_parser.add_argument('-P', '--profile', action='store_true', default=False, help='Generate an interactive profile of the data') sample_parser.add_argument('--verbose', help='Show the expanded SQL that is being 
executed', action='store_true') return sample_parser def _create_udf_subparser(parser): udf_parser = parser.subcommand('udf', 'Create a named Javascript BigQuery UDF') udf_parser.add_argument('-m', '--module', help='The name for this UDF') return udf_parser def _create_dry_run_subparser(parser): dry_run_parser = parser.subcommand('dryrun', 'Execute a dry run of a BigQuery query and display ' 'approximate usage statistics') dry_run_parser.add_argument('-q', '--query', help='The name of the query to be dry run') dry_run_parser.add_argument('-d', '--dialect', help='BigQuery SQL dialect', choices=['legacy', 'standard']) dry_run_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') dry_run_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') return dry_run_parser def _create_execute_subparser(parser): execute_parser = parser.subcommand('execute', 'Execute a BigQuery SQL query and optionally send the results ' 'to a named table.\nThe cell can optionally contain arguments ' 'for expanding variables in the query.') execute_parser.add_argument('-nc', '--nocache', help='Don\'t use previously cached results', action='store_true') execute_parser.add_argument('-d', '--dialect', help='BigQuery SQL dialect', choices=['legacy', 'standard']) execute_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') execute_parser.add_argument('-m', '--mode', help='The table creation mode', default='create', choices=['create', 'append', 'overwrite']) execute_parser.add_argument('-l', '--large', help='Whether to allow large results', action='store_true') execute_parser.add_argument('-q', '--query', help='The name of query to run') execute_parser.add_argument('-t', '--target', help='target table name') execute_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') return execute_parser def _create_pipeline_subparser(parser): pipeline_parser = parser.subcommand('pipeline', 'Define a deployable pipeline based on a BigQuery query.\n' 'The cell can optionally contain arguments for expanding ' 'variables in the query.') pipeline_parser.add_argument('-n', '--name', help='The pipeline name') pipeline_parser.add_argument('-nc', '--nocache', help='Don\'t use previously cached results', action='store_true') pipeline_parser.add_argument('-d', '--dialect', help='BigQuery SQL dialect', choices=['legacy', 'standard']) pipeline_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') pipeline_parser.add_argument('-m', '--mode', help='The table creation mode', default='create', choices=['create', 'append', 'overwrite']) pipeline_parser.add_argument('-l', '--large', help='Allow large results', action='store_true') pipeline_parser.add_argument('-q', '--query', help='The name of query to run', required=True) pipeline_parser.add_argument('-t', '--target', help='The target table name', nargs='?') pipeline_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') pipeline_parser.add_argument('action', nargs='?', choices=('deploy', 'run', 'dryrun'), default='dryrun', help='Whether to deploy the pipeline, execute it immediately in ' + 'the notebook, or validate it with a dry run') # TODO(gram): we may want to move some command line arguments to the cell body config spec # eventually. 
return pipeline_parser def _create_table_subparser(parser): table_parser = parser.subcommand('table', 'View a BigQuery table.') table_parser.add_argument('-r', '--rows', type=int, default=25, help='Rows to display per page') table_parser.add_argument('-c', '--cols', help='Comma-separated list of column names to restrict to') table_parser.add_argument('table', help='The name of, or a reference to, the table or view') return table_parser def _create_schema_subparser(parser): schema_parser = parser.subcommand('schema', 'View a BigQuery table or view schema.') group = schema_parser.add_mutually_exclusive_group() group.add_argument('-v', '--view', help='the name of the view whose schema should be displayed') group.add_argument('-t', '--table', help='the name of the table whose schema should be displayed') return schema_parser def _create_datasets_subparser(parser): datasets_parser = parser.subcommand('datasets', 'List the datasets in a BigQuery project.') datasets_parser.add_argument('-p', '--project', help='The project whose datasets should be listed') datasets_parser.add_argument('-f', '--filter', help='Optional wildcard filter string used to limit the results') return datasets_parser def _create_tables_subparser(parser): tables_parser = parser.subcommand('tables', 'List the tables in a BigQuery project or dataset.') tables_parser.add_argument('-p', '--project', help='The project whose tables should be listed') tables_parser.add_argument('-d', '--dataset', help='The dataset to restrict to') tables_parser.add_argument('-f', '--filter', help='Optional wildcard filter string used to limit the results') return tables_parser def _create_extract_subparser(parser): extract_parser = parser.subcommand('extract', 'Extract BigQuery query results or table to GCS.') extract_parser.add_argument('-f', '--format', choices=['csv', 'json'], default='csv', help='The format to use for the export') extract_parser.add_argument('-c', '--compress', action='store_true', help='Whether to compress the data') extract_parser.add_argument('-H', '--header', action='store_true', help='Whether to include a header line (CSV only)') extract_parser.add_argument('-d', '--delimiter', default=',', help='The field delimiter to use (CSV only)') extract_parser.add_argument('-S', '--source', help='The name of the query or table to extract') extract_parser.add_argument('-D', '--destination', help='The URL of the destination') return extract_parser def _create_load_subparser(parser): load_parser = parser.subcommand('load', 'Load data from GCS into a BigQuery table.') load_parser.add_argument('-m', '--mode', help='One of create (default), append or overwrite', choices=['create', 'append', 'overwrite'], default='create') load_parser.add_argument('-f', '--format', help='The source format', choices=['json', 'csv'], default='csv') load_parser.add_argument('-n', '--skip', help='The number of initial lines to skip; useful for CSV headers', type=int, default=0) load_parser.add_argument('-s', '--strict', help='Whether to reject bad values and jagged lines', action='store_true') load_parser.add_argument('-d', '--delimiter', default=',', help='The inter-field delimiter for CSV (default ,)') load_parser.add_argument('-q', '--quote', default='"', help='The quoted field delimiter for CSV (default ")') load_parser.add_argument('-i', '--infer', help='Whether to attempt to infer the schema from source; ' 'if false the table must already exist', action='store_true') load_parser.add_argument('-S', '--source', help='The URL of the GCS source(s)') 
load_parser.add_argument('-D', '--destination', help='The destination table name') return load_parser def _get_query_argument(args, cell, env): """ Get a query argument to a cell magic. The query is specified with args['query']. We look that up and if it is a BQ query just return it. If it is instead a SqlModule or SqlStatement it may have variable references. We resolve those using the arg parser for the SqlModule, then override the resulting defaults with either the Python code in cell, or the dictionary in overrides. The latter is for if the overrides are specified with YAML or JSON and eventually we should eliminate code in favor of this. Args: args: the dictionary of magic arguments. cell: the cell contents which can be variable value overrides (if args has a 'query' value) or inline SQL otherwise. env: a dictionary that is used for looking up variable values. Returns: A Query object. """ sql_arg = args.get('query', None) if sql_arg is None: # Assume we have inline SQL in the cell if not isinstance(cell, basestring): raise Exception('Expected a --query argument or inline SQL') return datalab.bigquery.Query(cell, values=env) item = datalab.utils.commands.get_notebook_item(sql_arg) if isinstance(item, datalab.bigquery.Query): # Queries are already expanded. return item # Create an expanded BQ Query. config = datalab.utils.commands.parse_config(cell, env) item, env = datalab.data.SqlModule.get_sql_statement_with_environment(item, config) if cell: env.update(config) # config is both a fallback and an override. return datalab.bigquery.Query(item, values=env) def _sample_cell(args, cell_body): """Implements the bigquery sample cell magic for ipython notebooks. Args: args: the optional arguments following '%%bigquery sample'. cell_body: optional contents of the cell interpreted as SQL, YAML or JSON. Returns: The results of executing the sampling query, or a profile of the sample data. """ env = datalab.utils.commands.notebook_environment() query = None table = None view = None if args['query']: query = _get_query_argument(args, cell_body, env) elif args['table']: table = _get_table(args['table']) elif args['view']: view = datalab.utils.commands.get_notebook_item(args['view']) if not isinstance(view, datalab.bigquery.View): raise Exception('%s is not a view' % args['view']) else: query = datalab.bigquery.Query(cell_body, values=env) count = args['count'] method = args['method'] if method == 'random': sampling = datalab.bigquery.Sampling.random(percent=args['percent'], count=count) elif method == 'hashed': sampling = datalab.bigquery.Sampling.hashed(field_name=args['field'], percent=args['percent'], count=count) elif method == 'sorted': ascending = args['order'] == 'ascending' sampling = datalab.bigquery.Sampling.sorted(args['field'], ascending=ascending, count=count) elif method == 'limit': sampling = datalab.bigquery.Sampling.default(count=count) else: sampling = datalab.bigquery.Sampling.default(count=count) if query: results = query.sample(sampling=sampling, dialect=args['dialect'], billing_tier=args['billing']) elif view: results = view.sample(sampling=sampling) else: results = table.sample(sampling=sampling) if args['verbose']: print(results.sql) if args['profile']: return datalab.utils.commands.profile_df(results.to_dataframe()) else: return results def _create_cell(args, cell_body): """Implements the BigQuery cell magic used to create datasets and tables. 
The supported syntax is: %%bigquery create dataset -n|--name [-f|--friendly ] [] or: %%bigquery create table -n|--name [--overwrite] [] Args: args: the argument following '%bigquery create '. """ if args['command'] == 'dataset': try: datalab.bigquery.Dataset(args['name']).create(friendly_name=args['friendly'], description=cell_body) except Exception as e: print('Failed to create dataset %s: %s' % (args['name'], e)) else: if cell_body is None: print('Failed to create %s: no schema specified' % args['name']) else: try: record = datalab.utils.commands.parse_config(cell_body, datalab.utils.commands.notebook_environment(), as_dict=False) schema = datalab.bigquery.Schema(record) datalab.bigquery.Table(args['name']).create(schema=schema, overwrite=args['overwrite']) except Exception as e: print('Failed to create table %s: %s' % (args['name'], e)) def _delete_cell(args, _): """Implements the BigQuery cell magic used to delete datasets and tables. The supported syntax is: %%bigquery delete dataset -n|--name or: %%bigquery delete table -n|--name Args: args: the argument following '%bigquery delete '. """ # TODO(gram): add support for wildchars and multiple arguments at some point. The latter is # easy, the former a bit more tricky if non-default projects are involved. if args['command'] == 'dataset': try: datalab.bigquery.Dataset(args['name']).delete() except Exception as e: print('Failed to delete dataset %s: %s' % (args['name'], e)) else: try: datalab.bigquery.Table(args['name']).delete() except Exception as e: print('Failed to delete table %s: %s' % (args['name'], e)) def _dryrun_cell(args, cell_body): """Implements the BigQuery cell magic used to dry run BQ queries. The supported syntax is: %%bigquery dryrun [-q|--sql ] [] Args: args: the argument following '%bigquery dryrun'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: The response wrapped in a DryRunStats object """ query = _get_query_argument(args, cell_body, datalab.utils.commands.notebook_environment()) if args['verbose']: print(query.sql) result = query.execute_dry_run(dialect=args['dialect'], billing_tier=args['billing']) return datalab.bigquery._query_stats.QueryStats(total_bytes=result['totalBytesProcessed'], is_cached=result['cacheHit']) def _udf_cell(args, js): """Implements the bigquery_udf cell magic for ipython notebooks. The supported syntax is: %%bigquery udf --module Args: args: the optional arguments following '%%bigquery udf'. js: the UDF declaration (inputs and outputs) and implementation in javascript. Returns: The results of executing the UDF converted to a dataframe if no variable was specified. None otherwise. """ variable_name = args['module'] if not variable_name: raise Exception('Declaration must be of the form %%bigquery udf --module ') # Parse out the input and output specification spec_pattern = r'\{\{([^}]+)\}\}' spec_part_pattern = r'[a-z_][a-z0-9_]*' specs = re.findall(spec_pattern, js) if len(specs) < 2: raise Exception('The JavaScript must declare the input row and output emitter parameters ' 'using valid jsdoc format comments.\n' 'The input row param declaration must be typed as {{field:type, field2:type}} ' 'and the output emitter param declaration must be typed as ' 'function({{field:type, field2:type}}.') inputs = [] input_spec_parts = re.findall(spec_part_pattern, specs[0], flags=re.IGNORECASE) if len(input_spec_parts) % 2 != 0: raise Exception('Invalid input row param declaration. 
The jsdoc type expression must ' 'define an object with field and type pairs.') for n, t in zip(input_spec_parts[0::2], input_spec_parts[1::2]): inputs.append((n, t)) outputs = [] output_spec_parts = re.findall(spec_part_pattern, specs[1], flags=re.IGNORECASE) if len(output_spec_parts) % 2 != 0: raise Exception('Invalid output emitter param declaration. The jsdoc type expression must ' 'define a function accepting an an object with field and type pairs.') for n, t in zip(output_spec_parts[0::2], output_spec_parts[1::2]): outputs.append((n, t)) # Look for imports. We use a non-standard @import keyword; we could alternatively use @requires. # Object names can contain any characters except \r and \n. import_pattern = r'@import[\s]+(gs://[a-z\d][a-z\d_\.\-]*[a-z\d]/[^\n\r]+)' imports = re.findall(import_pattern, js) # Split the cell if necessary. We look for a 'function(' with no name and a header comment # block with @param and assume this is the primary function, up to a closing '}' at the start # of the line. The remaining cell content is used as support code. split_pattern = r'(.*)(/\*.*?@param.*?@param.*?\*/\w*\n\w*function\w*\(.*?^}\n?)(.*)' parts = re.match(split_pattern, js, re.MULTILINE | re.DOTALL) support_code = '' if parts: support_code = (parts.group(1) + parts.group(3)).strip() if len(support_code): js = parts.group(2) # Finally build the UDF object udf = datalab.bigquery.UDF(inputs, outputs, variable_name, js, support_code, imports) datalab.utils.commands.notebook_environment()[variable_name] = udf def _execute_cell(args, cell_body): """Implements the BigQuery cell magic used to execute BQ queries. The supported syntax is: %%bigquery execute [-q|--sql ] [] Args: args: the arguments following '%bigquery execute'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: The QueryResultsTable """ query = _get_query_argument(args, cell_body, datalab.utils.commands.notebook_environment()) if args['verbose']: print(query.sql) return query.execute(args['target'], table_mode=args['mode'], use_cache=not args['nocache'], allow_large_results=args['large'], dialect=args['dialect'], billing_tier=args['billing']).results def _pipeline_cell(args, cell_body): """Implements the BigQuery cell magic used to validate, execute or deploy BQ pipelines. The supported syntax is: %%bigquery pipeline [-q|--sql ] [] Args: args: the arguments following '%bigquery pipeline'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: The QueryResultsTable """ if args['action'] == 'deploy': raise Exception('Deploying a pipeline is not yet supported') env = {} for key, value in datalab.utils.commands.notebook_environment().items(): if isinstance(value, datalab.bigquery._udf.UDF): env[key] = value query = _get_query_argument(args, cell_body, env) if args['verbose']: print(query.sql) if args['action'] == 'dryrun': print(query.sql) result = query.execute_dry_run() return datalab.bigquery._query_stats.QueryStats(total_bytes=result['totalBytesProcessed'], is_cached=result['cacheHit']) if args['action'] == 'run': return query.execute(args['target'], table_mode=args['mode'], use_cache=not args['nocache'], allow_large_results=args['large'], dialect=args['dialect'], billing_tier=args['billing']).results def _table_line(args): """Implements the BigQuery table magic used to display tables. The supported syntax is: %bigquery table -t|--table Args: args: the arguments following '%bigquery table'. Returns: The HTML rendering for the table. 
""" # TODO(gram): It would be good to turn _table_viewer into a class that has a registered # renderer. That would allow this to return a table viewer object which is easier to test. name = args['table'] table = _get_table(name) if table and table.exists(): fields = args['cols'].split(',') if args['cols'] else None html = _table_viewer(table, rows_per_page=args['rows'], fields=fields) return IPython.core.display.HTML(html) else: raise Exception('Table %s does not exist; cannot display' % name) def _get_schema(name): """ Given a variable or table name, get the Schema if it exists. """ item = datalab.utils.commands.get_notebook_item(name) if not item: item = _get_table(name) if isinstance(item, datalab.bigquery.Schema): return item if hasattr(item, 'schema') and isinstance(item.schema, datalab.bigquery._schema.Schema): return item.schema return None # An LRU cache for Tables. This is mostly useful so that when we cross page boundaries # when paging through a table we don't have to re-fetch the schema. _table_cache = datalab.utils.LRUCache(10) def _get_table(name): """ Given a variable or table name, get a Table if it exists. Args: name: the name of the Table or a variable referencing the Table. Returns: The Table, if found. """ # If name is a variable referencing a table, use that. item = datalab.utils.commands.get_notebook_item(name) if isinstance(item, datalab.bigquery.Table): return item # Else treat this as a BQ table name and return the (cached) table if it exists. try: return _table_cache[name] except KeyError: table = datalab.bigquery.Table(name) if table.exists(): _table_cache[name] = table return table return None def _schema_line(args): """Implements the BigQuery schema magic used to display table/view schemas. Args: args: the arguments following '%bigquery schema'. Returns: The HTML rendering for the schema. """ # TODO(gram): surely we could just return the schema itself? name = args['table'] if args['table'] else args['view'] if name is None: raise Exception('No table or view specified; cannot show schema') schema = _get_schema(name) if schema: html = _repr_html_table_schema(schema) return IPython.core.display.HTML(html) else: raise Exception('%s is not a schema and does not appear to have a schema member' % name) def _render_table(data, fields=None): """ Helper to render a list of dictionaries as an HTML display object. """ return IPython.core.display.HTML(datalab.utils.commands.HtmlBuilder.render_table(data, fields)) def _render_list(data): """ Helper to render a list of objects as an HTML list object. """ return IPython.core.display.HTML(datalab.utils.commands.HtmlBuilder.render_list(data)) def _datasets_line(args): """Implements the BigQuery datasets magic used to display datasets in a project. The supported syntax is: %bigquery datasets [-f ] [-p|--project ] Args: args: the arguments following '%bigquery datasets'. Returns: The HTML rendering for the table of datasets. """ filter_ = args['filter'] if args['filter'] else '*' return _render_list([str(dataset) for dataset in datalab.bigquery.Datasets(args['project']) if fnmatch.fnmatch(str(dataset), filter_)]) def _tables_line(args): """Implements the BigQuery tables magic used to display tables in a dataset. The supported syntax is: %bigquery tables -p|--project -d|--dataset Args: args: the arguments following '%bigquery tables'. Returns: The HTML rendering for the list of tables. 
""" filter_ = args['filter'] if args['filter'] else '*' if args['dataset']: if args['project'] is None: datasets = [datalab.bigquery.Dataset(args['dataset'])] else: datasets = [datalab.bigquery.Dataset((args['project'], args['dataset']))] else: datasets = datalab.bigquery.Datasets(args['project']) tables = [] for dataset in datasets: tables.extend([str(table) for table in dataset if fnmatch.fnmatch(str(table), filter_)]) return _render_list(tables) def _extract_line(args): """Implements the BigQuery extract magic used to extract table data to GCS. The supported syntax is: %bigquery extract -S|--source
-D|--destination Args: args: the arguments following '%bigquery extract'. Returns: A message about whether the extract succeeded or failed. """ name = args['source'] source = datalab.utils.commands.get_notebook_item(name) if not source: source = _get_table(name) if not source: raise Exception('No source named %s found' % name) elif isinstance(source, datalab.bigquery.Table) and not source.exists(): raise Exception('Table %s does not exist' % name) else: job = source.extract(args['destination'], format='CSV' if args['format'] == 'csv' else 'NEWLINE_DELIMITED_JSON', compress=args['compress'], csv_delimiter=args['delimiter'], csv_header=args['header']) if job.failed: raise Exception('Extract failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Extract completed with errors: %s' % str(job.errors)) def _load_cell(args, schema): """Implements the BigQuery load magic used to load data from GCS to a table. The supported syntax is: %bigquery load -S|--source -D|--destination
Args: args: the arguments following '%bigquery load'. schema: a JSON schema for the destination table. Returns: A message about whether the load succeeded or failed. """ name = args['destination'] table = _get_table(name) if not table: table = datalab.bigquery.Table(name) if table.exists(): if args['mode'] == 'create': raise Exception('%s already exists; use --append or --overwrite' % name) elif schema: table.create(json.loads(schema)) elif not args['infer']: raise Exception( 'Table does not exist, no schema specified in cell and no --infer flag; cannot load') # TODO(gram): we should probably try do the schema infer ourselves as BQ doesn't really seem # to be able to do it. Alternatively we can drop the --infer argument and force the user # to use a pre-existing table or supply a JSON schema. csv_options = datalab.bigquery.CSVOptions(delimiter=args['delimiter'], skip_leading_rows=args['skip'], allow_jagged_rows=not args['strict'], quote=args['quote']) job = table.load(args['source'], mode=args['mode'], source_format=('CSV' if args['format'] == 'csv' else 'NEWLINE_DELIMITED_JSON'), csv_options=csv_options, ignore_unknown_values=not args['strict']) if job.failed: raise Exception('Load failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Load completed with errors: %s' % str(job.errors)) def _add_command(parser, subparser_fn, handler, cell_required=False, cell_prohibited=False): """ Create and initialize a bigquery subcommand handler. """ sub_parser = subparser_fn(parser) sub_parser.set_defaults(func=lambda args, cell: _dispatch_handler(args, cell, sub_parser, handler, cell_required=cell_required, cell_prohibited=cell_prohibited)) def _create_bigquery_parser(): """ Create the parser for the %bigquery magics. Note that because we use the func default handler dispatch mechanism of argparse, our handlers can take only one argument which is the parsed args. So we must create closures for the handlers that bind the cell contents and thus must recreate this parser for each cell upon execution. """ parser = datalab.utils.commands.CommandParser(prog='bigquery', description=""" Execute various BigQuery-related operations. Use "%bigquery -h" for help on a specific command. """) # This is a bit kludgy because we want to handle some line magics and some cell magics # with the bigquery command. # %%bigquery sample _add_command(parser, _create_sample_subparser, _sample_cell) # %%bigquery create _add_command(parser, _create_create_subparser, _create_cell) # %%bigquery delete _add_command(parser, _create_delete_subparser, _delete_cell) # %%bigquery dryrun _add_command(parser, _create_dry_run_subparser, _dryrun_cell) # %%bigquery udf _add_command(parser, _create_udf_subparser, _udf_cell, cell_required=True) # %%bigquery execute _add_command(parser, _create_execute_subparser, _execute_cell) # %%bigquery pipeline _add_command(parser, _create_pipeline_subparser, _pipeline_cell) # %bigquery table _add_command(parser, _create_table_subparser, _table_line, cell_prohibited=True) # %bigquery schema _add_command(parser, _create_schema_subparser, _schema_line, cell_prohibited=True) # %bigquery datasets _add_command(parser, _create_datasets_subparser, _datasets_line, cell_prohibited=True) # %bigquery tables _add_command(parser, _create_tables_subparser, _tables_line, cell_prohibited=True) # % bigquery extract _add_command(parser, _create_extract_subparser, _extract_line, cell_prohibited=True) # %bigquery load # TODO(gram): need some additional help, esp. 
around the option of specifying schema in # cell body and how schema infer may fail. _add_command(parser, _create_load_subparser, _load_cell) return parser _bigquery_parser = _create_bigquery_parser() @IPython.core.magic.register_line_cell_magic def bigquery(line, cell=None): """Implements the bigquery cell magic for ipython notebooks. The supported syntax is: %%bigquery [] or: %bigquery [] Use %bigquery --help for a list of commands, or %bigquery --help for help on a specific command. """ namespace = {} if line.find('$') >= 0: # We likely have variables to expand; get the appropriate context. namespace = datalab.utils.commands.notebook_environment() return datalab.utils.commands.handle_magic_line(line, cell, _bigquery_parser, namespace=namespace) def _dispatch_handler(args, cell, parser, handler, cell_required=False, cell_prohibited=False): """ Makes sure cell magics include cell and line magics don't, before dispatching to handler. Args: args: the parsed arguments from the magic line. cell: the contents of the cell, if any. parser: the argument parser for ; used for error message. handler: the handler to call if the cell present/absent check passes. cell_required: True for cell magics, False for line magics that can't be cell magics. cell_prohibited: True for line magics, False for cell magics that can't be line magics. Returns: The result of calling the handler. Raises: Exception if the invocation is not valid. """ if cell_prohibited: if cell and len(cell.strip()): parser.print_help() raise Exception('Additional data is not supported with the %s command.' % parser.prog) return handler(args) if cell_required and not cell: parser.print_help() raise Exception('The %s command requires additional data' % parser.prog) return handler(args, cell) def _table_viewer(table, rows_per_page=25, fields=None): """ Return a table viewer. This includes a static rendering of the first page of the table, that gets replaced by the charting code in environments where Javascript is executable and BQ is available. Args: table: the table to view. rows_per_page: how many rows to display at one time. fields: an array of field names to display; default is None which uses the full schema. Returns: A string containing the HTML for the table viewer. """ # TODO(gram): rework this to use datalab.utils.commands.chart_html if not table.exists(): raise Exception('Table %s does not exist' % str(table)) if not table.is_listable(): return "Done" _HTML_TEMPLATE = u"""
{static_table}

{meta_data}
""" if fields is None: fields = datalab.utils.commands.get_field_list(fields, table.schema) div_id = datalab.utils.commands.Html.next_id() meta_count = ('rows: %d' % table.length) if table.length >= 0 else '' meta_name = str(table) if table.job is None else ('job: %s' % table.job.id) if table.job: if table.job.cache_hit: meta_cost = 'cached' else: bytes = datalab.bigquery._query_stats.QueryStats._size_formatter(table.job.bytes_processed) meta_cost = '%s processed' % bytes meta_time = 'time: %.1fs' % table.job.total_time else: meta_cost = '' meta_time = '' data, total_count = datalab.utils.commands.get_data(table, fields, first_row=0, count=rows_per_page) if total_count < 0: # The table doesn't have a length metadata property but may still be small if we fetched less # rows than we asked for. fetched_count = len(data['rows']) if fetched_count < rows_per_page: total_count = fetched_count chart = 'table' if 0 <= total_count <= rows_per_page else 'paged_table' meta_entries = [meta_count, meta_time, meta_cost, meta_name] meta_data = '(%s)' % (', '.join([entry for entry in meta_entries if len(entry)])) return _HTML_TEMPLATE.format(div_id=div_id, static_table=datalab.utils.commands.HtmlBuilder .render_chart_data(data), meta_data=meta_data, chart_style=chart, source_index=datalab.utils.commands .get_data_source_index(str(table)), fields=','.join(fields), total_rows=total_count, rows_per_page=rows_per_page, data=json.dumps(data, cls=datalab.utils.JSONEncoder)) def _repr_html_query(query): # TODO(nikhilko): Pretty print the SQL return datalab.utils.commands.HtmlBuilder.render_text(query.sql, preformatted=True) def _repr_html_query_results_table(results): return _table_viewer(results) def _repr_html_table(results): return _table_viewer(results) def _repr_html_table_schema(schema): _HTML_TEMPLATE = """
""" id = datalab.utils.commands.Html.next_id() return _HTML_TEMPLATE % (id, id, json.dumps(schema._bq_schema)) def _register_html_formatters(): try: ipy = IPython.get_ipython() html_formatter = ipy.display_formatter.formatters['text/html'] html_formatter.for_type_by_name('datalab.bigquery._query', 'Query', _repr_html_query) html_formatter.for_type_by_name('datalab.bigquery._query_results_table', 'QueryResultsTable', _repr_html_query_results_table) html_formatter.for_type_by_name('datalab.bigquery._table', 'Table', _repr_html_table) html_formatter.for_type_by_name('datalab.bigquery._schema', 'Schema', _repr_html_table_schema) except TypeError: # For when running unit tests pass _register_html_formatters() ================================================ FILE: datalab/context/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - authorization Context for Cloud services.""" from ._context import Context from ._project import Project, Projects __all__ = ['Context', 'Project', 'Projects'] ================================================ FILE: datalab/context/_api.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements HTTP API wrapper.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import datalab.utils class Api(object): """A helper class to issue API HTTP requests to resource manager API.""" _ENDPOINT = 'https://cloudresourcemanager.googleapis.com/v1' _PROJECT_PATH = '/projects/%s' _PROJECTS_PATH = '/projects' def __init__(self, credentials): self._credentials = credentials def projects_list(self, max_results=0, page_token=None): url = Api._ENDPOINT + Api._PROJECTS_PATH args = {} if max_results != 0: args['pageSize'] = max_results if page_token is not None: args['pageToken'] = page_token return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def project_get(self, projectId): url = Api._ENDPOINT + (Api._PROJECT_PATH % projectId) return datalab.utils.Http.request(url, credentials=self._credentials) ================================================ FILE: datalab/context/_context.py ================================================ # Copyright 2014 Google Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Implements Context functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from . import _project from . import _utils class Context(object): """Maintains contextual state for connecting to Cloud APIs. """ _global_context = None def __init__(self, project_id, credentials): """Initializes an instance of a Context object. Args: project_id: the current cloud project. credentials: the credentials to use to authorize requests. """ self._project_id = project_id self._credentials = credentials @property def credentials(self): """Retrieves the value of the credentials property. Returns: The current credentials used in authorizing API requests. """ return self._credentials def set_credentials(self, credentials): """ Set the credentials for the context. """ self._credentials = credentials @property def project_id(self): """Retrieves the value of the project_id property. Returns: The current project id to associate with API requests. """ if not self._project_id: raise Exception('No project ID found. Perhaps you should set one by running ' '"%projects set " in a code cell.') return self._project_id def set_project_id(self, project_id): """ Set the project_id for the context. """ self._project_id = project_id if self == Context._global_context: try: from google.datalab import Context as new_context new_context.default().set_project_id(project_id) except ImportError: # If the new library is not loaded, then we have nothing to do. pass @staticmethod def is_signed_in(): """ If the user has signed in or it is on GCE VM with default credential.""" try: _utils.get_credentials() return True except Exception: return False @staticmethod def default(): """Retrieves a default Context object, creating it if necessary. The default Context is a global shared instance used every time the default context is retrieved. Attempting to use a Context with no project_id will raise an exception, so on first use set_project_id must be called. Returns: An initialized and shared instance of a Context object. """ credentials = _utils.get_credentials() if Context._global_context is None: project = _project.Projects.get_default_id(credentials) Context._global_context = Context(project, credentials) else: # Always update the credentials in case the access token is revoked or expired Context._global_context.set_credentials(credentials) return Context._global_context ================================================ FILE: datalab/context/_project.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
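# Illustrative sketch of using the Context defined in _context.py above
# ('my-project' is a hypothetical project id):
#
#   import datalab.context
#   context = datalab.context.Context.default()
#   context.set_project_id('my-project')  # only needed if no default project is configured
#   print(context.project_id)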
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Projects functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import datalab.utils from . import _api from . import _utils # We could do this with the gcloud SDK. However, installing that while locked on oauth2.5 # introduces some ugliness; in particular we are stuck with v0.9 and need to work around: # # https://github.com/GoogleCloudPlatform/gcloud-python/issues/1412 # # and would have to put up with: # # https://github.com/GoogleCloudPlatform/gcloud-python/issues/1570 # # So we use the REST API instead. class Project(object): """ Simple wrapper class for Cloud projects. """ def __init__(self, api, id, number, name): self._api = api self._id = id self._number = number self._name = name @property def id(self): return self._id @property def name(self): return self._name @property def number(self): return self._number def __str__(self): return self._id class Projects(object): """ Iterator class for enumerating the active projects accessible to the account. """ def __init__(self, credentials=None): """ Initialize the Projects object. Args: credentials: the credentials for the account. """ if credentials is None: credentials = _utils.get_credentials() self._api = _api.Api(credentials) def _retrieve_projects(self, page_token, count): try: list_info = self._api.projects_list(max_results=count, page_token=page_token) except Exception as e: raise e projects = list_info.get('projects', []) if len(projects): try: projects = [Project(self._api, info['projectId'], info['projectNumber'], info['name']) for info in projects if info['lifecycleState'] == 'ACTIVE'] except KeyError: raise Exception('Unexpected response from server.') page_token = list_info.get('nextPageToken', None) return projects, page_token def __iter__(self): """ Returns an iterator for iterating through the Datasets in the project. """ return iter(datalab.utils.Iterator(self._retrieve_projects)) @staticmethod def get_default_id(credentials=None): """ Get default project id. Returns: the default project id if there is one, or None. """ project_id = _utils.get_project_id() if project_id is None: projects, _ = Projects(credentials)._retrieve_projects(None, 2) if len(projects) == 1: project_id = projects[0].id return project_id @staticmethod def save_default_id(project_id): """ Save default project id to config so it will persist across kernels and Datalab runs. Args: project_id: the project_id to save. """ _utils.save_project_id(project_id) ================================================ FILE: datalab/context/_utils.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
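# Illustrative sketch of enumerating accessible projects with the Projects iterator
# defined in _project.py above (uses the account's default credentials):
#
#   import datalab.context
#   for project in datalab.context.Projects():
#       print(project.id, project.name)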
# See the License for the specific language governing permissions and # limitations under the License. """ Support for getting gcloud credentials. """ from __future__ import absolute_import from __future__ import unicode_literals import json import os import subprocess import oauth2client.client import google.auth import google.auth.exceptions import google.auth.credentials import google.auth._oauth2client # TODO(ojarjur): This limits the APIs against which Datalab can be called # (when using a service account with a credentials file) to only being those # that are part of the Google Cloud Platform. We should either extend this # to all of the API scopes that Google supports, or make it extensible so # that the user can define for themselves which scopes they want to use. CREDENTIAL_SCOPES = [ 'https://www.googleapis.com/auth/cloud-platform', ] def _in_datalab_docker(): return os.path.exists('/datalab') and os.getenv('DATALAB_ENV') def get_config_dir(): config_dir = os.getenv('CLOUDSDK_CONFIG') if config_dir is None: if os.name == 'nt': try: config_dir = os.path.join(os.environ['APPDATA'], 'gcloud') except KeyError: # This should never happen unless someone is really messing with things. drive = os.environ.get('SystemDrive', 'C:') config_dir = os.path.join(drive, '\\gcloud') else: config_dir = os.path.join(os.path.expanduser('~'), '.config/gcloud') return config_dir def _convert_oauth2client_creds(credentials): new_credentials = google.oauth2.credentials.Credentials( token=credentials.access_token, refresh_token=credentials.refresh_token, token_uri=credentials.token_uri, client_id=credentials.client_id, client_secret=credentials.client_secret, scopes=credentials.scopes) new_credentials._expires = credentials.token_expiry return new_credentials def get_credentials(): """ Get the credentials to use. We try application credentials first, followed by user credentials. The path to the application credentials can be overridden by pointing the GOOGLE_APPLICATION_CREDENTIALS environment variable to some file; the path to the user credentials can be overridden by pointing the CLOUDSDK_CONFIG environment variable to some directory (after which we will look for the file $CLOUDSDK_CONFIG/gcloud/credentials). Unless you have specific reasons for overriding these the defaults should suffice. """ try: credentials, _ = google.auth.default() credentials = google.auth.credentials.with_scopes_if_required(credentials, CREDENTIAL_SCOPES) return credentials except Exception as e: # Try load user creds from file cred_file = get_config_dir() + '/credentials' if os.path.exists(cred_file): with open(cred_file) as f: creds = json.loads(f.read()) # Use the first gcloud one we find for entry in creds['data']: if entry['key']['type'] == 'google-cloud-sdk': creds = oauth2client.client.OAuth2Credentials.from_json(json.dumps(entry['credential'])) return _convert_oauth2client_creds(creds) if type(e) == google.auth.exceptions.DefaultCredentialsError: # If we are in Datalab container, change the message to be about signing in. if _in_datalab_docker(): raise Exception('No application credentials found. Perhaps you should sign in.') raise e def save_project_id(project_id): """ Save project id to config file. Args: project_id: the project_id to save. """ # Try gcloud first. If gcloud fails (probably because it does not exist), then # write to a config file. 
try: subprocess.call(['gcloud', 'config', 'set', 'project', project_id]) except: config_file = os.path.join(get_config_dir(), 'config.json') config = {} if os.path.exists(config_file): with open(config_file) as f: config = json.loads(f.read()) config['project_id'] = project_id with open(config_file, 'w') as f: f.write(json.dumps(config)) def get_project_id(): """ Get default project id from config or environment var. Returns: the project id if available, or None. """ # Try getting default project id from gcloud. If it fails try config.json. try: proc = subprocess.Popen(['gcloud', 'config', 'list', '--format', 'value(core.project)'], stdout=subprocess.PIPE) stdout, _ = proc.communicate() value = stdout.decode().strip() if proc.poll() == 0 and value: return value except: pass config_file = os.path.join(get_config_dir(), 'config.json') if os.path.exists(config_file): with open(config_file) as f: config = json.loads(f.read()) if 'project_id' in config and config['project_id']: return str(config['project_id']) if os.getenv('PROJECT_ID') is not None: return os.getenv('PROJECT_ID') return None ================================================ FILE: datalab/context/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _projects __all__ = ['_projects'] ================================================ FILE: datalab/context/commands/_projects.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
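# Illustrative sketch of the project id helpers defined in _utils.py above; resolution
# falls back from gcloud config to config.json to the PROJECT_ID environment variable
# ('my-project' is a hypothetical id):
#
#   from datalab.context import _utils
#   _utils.save_project_id('my-project')   # via gcloud if available, else config.json
#   print(_utils.get_project_id())         # -> 'my-project', or None if nothing is set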
"""Implements listing projects and setting default project.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.magic import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import fnmatch import datalab.utils.commands import datalab.context @IPython.core.magic.register_line_cell_magic def projects(line, cell=None): parser = datalab.utils.commands.CommandParser.create('projects') list_parser = parser.subcommand('list', 'List available projects.') list_parser.add_argument('-f', '--filter', help='Optional wildcard id filter string used to limit the results') list_parser.set_defaults(func=_list_line, cell_prohibited=True) set_parser = parser.subcommand('set', 'Set the default project.') set_parser.add_argument('id', help='The ID of the project to use') set_parser.set_defaults(func=_set_line, cell_prohibited=True) return datalab.utils.commands.handle_magic_line(line, cell, parser) def _list_line(args, _): # TODO(gram): should we use a paginated table? filter_ = args['filter'] if args['filter'] else '*' data = [{'id': project.id, 'name': project.name} for project in datalab.context.Projects() if fnmatch.fnmatch(project.id, filter_)] return IPython.core.display.HTML(datalab.utils.commands.HtmlBuilder.render_table(data, ['id', 'name'])) def _set_line(args, _): id_ = args['id'] if args['id'] else '' context = datalab.context.Context.default() context.set_project_id(id_) datalab.context.Projects.save_default_id(id_) ================================================ FILE: datalab/data/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Generic SQL Helpers.""" from __future__ import absolute_import from __future__ import unicode_literals from ._csv import Csv from ._sql_module import SqlModule from ._sql_statement import SqlStatement from ._utils import tokenize __all__ = ['Csv', 'SqlModule', 'SqlStatement', 'tokenize'] ================================================ FILE: datalab/data/_csv.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
"""Implements usefule CSV utilities.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import next from builtins import str as newstr from builtins import range from builtins import object import csv import os import pandas as pd import random import sys try: from StringIO import StringIO except ImportError: from io import StringIO import tempfile import datalab.storage import datalab.utils _MAX_CSV_BYTES = 10000000 class Csv(object): """Represents a CSV file in GCS or locally with same schema. """ def __init__(self, path, delimiter=b','): """Initializes an instance of a Csv instance. Args: path: path of the Csv file. delimiter: the separator used to parse a Csv line. """ self._path = path self._delimiter = delimiter @property def path(self): return self._path @staticmethod def _read_gcs_lines(path, max_lines=None): return datalab.storage.Item.from_url(path).read_lines(max_lines) @staticmethod def _read_local_lines(path, max_lines=None): lines = [] for line in open(path): if max_lines is not None and len(lines) >= max_lines: break lines.append(line) return lines def _is_probably_categorical(self, column): if newstr(column.dtype) != 'object': # only string types (represented in DataFrame as object) can potentially be categorical return False if len(max(column, key=lambda p: len(newstr(p)))) > 100: return False # value too long to be a category if len(set(column)) > 100: return False # too many unique values to be a category return True def browse(self, max_lines=None, headers=None): """Try reading specified number of lines from the CSV object. Args: max_lines: max number of lines to read. If None, the whole file is read headers: a list of strings as column names. If None, it will use "col0, col1..." Returns: A pandas DataFrame with the schema inferred from the data. Raises: Exception if the csv object cannot be read or not enough lines to read, or the headers size does not match columns size. """ if self.path.startswith('gs://'): lines = Csv._read_gcs_lines(self.path, max_lines) else: lines = Csv._read_local_lines(self.path, max_lines) if len(lines) == 0: return pd.DataFrame(columns=headers) columns_size = len(next(csv.reader([lines[0]], delimiter=self._delimiter))) if headers is None: headers = ['col' + newstr(e) for e in range(columns_size)] if len(headers) != columns_size: raise Exception('Number of columns in CSV do not match number of headers') buf = StringIO() for line in lines: buf.write(line) buf.write('\n') buf.seek(0) df = pd.read_csv(buf, names=headers, delimiter=self._delimiter) for key, col in df.iteritems(): if self._is_probably_categorical(col): df[key] = df[key].astype('category') return df def _create_federated_table(self, skip_header_rows): import datalab.bigquery as bq df = self.browse(1, None) # read each column as STRING because we only want to sample rows. schema_train = bq.Schema([{'name': name, 'type': 'STRING'} for name in df.keys()]) options = bq.CSVOptions(skip_leading_rows=(1 if skip_header_rows is True else 0)) return bq.FederatedTable.from_storage(self.path, csv_options=options, schema=schema_train, max_bad_records=0) def _get_gcs_csv_row_count(self, federated_table): import datalab.bigquery as bq results = bq.Query('SELECT count(*) from data', data_sources={'data': federated_table}).results() return results[0].values()[0] def sample_to(self, count, skip_header_rows, strategy, target): """Sample rows from GCS or local file and save results to target file. Args: count: number of rows to sample. 
If strategy is "BIGQUERY", it is used as approximate number. skip_header_rows: whether to skip first row when reading from source. strategy: can be "LOCAL" or "BIGQUERY". If local, the sampling happens in local memory, and number of resulting rows matches count. If BigQuery, sampling is done with BigQuery in cloud, and the number of resulting rows will be approximated to count. target: The target file path, can be GCS or local path. Raises: Exception if strategy is "BIGQUERY" but source is not a GCS path. """ # TODO(qimingj) Add unit test # Read data from source into DataFrame. if sys.version_info.major > 2: xrange = range # for python 3 compatibility if strategy == 'BIGQUERY': import datalab.bigquery as bq if not self.path.startswith('gs://'): raise Exception('Cannot use BIGQUERY if data is not in GCS') federated_table = self._create_federated_table(skip_header_rows) row_count = self._get_gcs_csv_row_count(federated_table) query = bq.Query('SELECT * from data', data_sources={'data': federated_table}) sampling = bq.Sampling.random(count * 100 / float(row_count)) sample = query.sample(sampling=sampling) df = sample.to_dataframe() elif strategy == 'LOCAL': local_file = self.path if self.path.startswith('gs://'): local_file = tempfile.mktemp() datalab.utils.gcs_copy_file(self.path, local_file) with open(local_file) as f: row_count = sum(1 for line in f) start_row = 1 if skip_header_rows is True else 0 skip_count = row_count - count - 1 if skip_header_rows is True else row_count - count skip = sorted(random.sample(xrange(start_row, row_count), skip_count)) header_row = 0 if skip_header_rows is True else None df = pd.read_csv(local_file, skiprows=skip, header=header_row, delimiter=self._delimiter) if self.path.startswith('gs://'): os.remove(local_file) else: raise Exception('strategy must be BIGQUERY or LOCAL') # Write to target. if target.startswith('gs://'): with tempfile.NamedTemporaryFile() as f: df.to_csv(f, header=False, index=False) f.flush() datalab.utils.gcs_copy_file(f.name, target) else: with open(target, 'w') as f: df.to_csv(f, header=False, index=False, sep=str(self._delimiter)) ================================================ FILE: datalab/data/_sql_module.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Helper functions for %%sql modules.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring from builtins import object import shlex from . import _sql_statement from . import _utils # It would be nice to be able to inherit from Python module but AFAICT that is not possible. # So this just wraps a bunch of static helpers. class SqlModule(object): """ A container for SqlStatements defined together and able to reference each other. """ @staticmethod def _get_sql_args(parser, args=None): """ Parse a set of %%sql arguments or get the default value of the arguments. Args: parser: the argument parser to use. 
args: the argument flags. May be a string or a list. If omitted the empty string is used so we can get the default values for the arguments. These are all used to override the arg parser. Alternatively args may be a dictionary, in which case it overrides the default values from the arg parser. Returns: A dictionary of argument names and values. """ overrides = None if args is None: tokens = [] elif isinstance(args, basestring): command_line = ' '.join(args.split('\n')) tokens = shlex.split(command_line) elif isinstance(args, dict): overrides = args tokens = [] else: tokens = args args = {} if parser is None else vars(parser.parse_args(tokens)) if overrides: args.update(overrides) # Don't return any args that are None as we don't want to expand to 'None' return {arg: value for arg, value in args.items() if value is not None} @staticmethod def get_default_query_from_module(module): """ Given a %%sql module return the default (last) query for the module. Args: module: the %%sql module. Returns: The default query associated with this module. """ return _utils.get_default_query_from_module(module) @staticmethod def get_sql_statement_with_environment(item, args=None): """ Given a SQLStatement, string or module plus command line args or a dictionary, return a SqlStatement and final dictionary for variable resolution. Args: item: a SqlStatement, %%sql module, or string containing a query. args: a string of command line arguments or a dictionary of values. Returns: A SqlStatement for the query or module, plus a dictionary of variable values to use. """ if isinstance(item, basestring): item = _sql_statement.SqlStatement(item) elif not isinstance(item, _sql_statement.SqlStatement): item = SqlModule.get_default_query_from_module(item) if not item: raise Exception('Expected a SQL statement or module but got %s' % str(item)) env = {} if item.module: env.update(item.module.__dict__) parser = env.get(_utils._SQL_MODULE_ARGPARSE, None) if parser: args = SqlModule._get_sql_args(parser, args=args) else: args = None if isinstance(args, dict): env.update(args) return item, env @staticmethod def expand(sql, args=None): """ Expand a SqlStatement, query string or SqlModule with a set of arguments. Args: sql: a SqlStatement, %%sql module, or string containing a query. args: a string of command line arguments or a dictionary of values. If a string, it is passed to the argument parser for the SqlModule associated with the SqlStatement or SqlModule. If a dictionary, it is used to override any default arguments from the argument parser. If the sql argument is a string then args must be None or a dictionary as in this case there is no associated argument parser. Returns: The expanded SQL, list of referenced scripts, and list of referenced external tables. """ sql, args = SqlModule.get_sql_statement_with_environment(sql, args) return _sql_statement.SqlStatement.format(sql._sql, args) ================================================ FILE: datalab/data/_sql_statement.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """Implements SQL statement helper functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object from past.builtins import basestring import re import types import datalab.utils from . import _utils class SqlStatement(object): """A helper class for wrapping and manipulating SQL statements. """ def __init__(self, sql, module=None): """ Initializes the SqlStatement. Args: sql: a string containing a SQL query with optional variable references. module: if defined in a %%sql cell, the parent SqlModule object for the SqlStatement. """ self._sql = sql self._module = module def __str__(self): """Creates a string representation of this object. Returns: The string representation of this object. """ return self._sql def __repr__(self): """Creates a friendly representation of this object. Returns: The friendly representation of this object. """ return self._sql @property def sql(self): """ The (unexpanded) SQL for the SqlStatement. """ return self._sql @property def module(self): """ The parent SqlModule for the SqlStatement, if any. """ return self._module @staticmethod def _find_recursive_dependencies(sql, values, code, resolved_vars, resolving_vars=None): """ Recursive helper method for expanding variables including transitive dependencies. Placeholders in SQL are represented as $. If '$' must appear within the SQL statement literally, then it can be escaped as '$$'. Args: sql: the raw SQL statement with named placeholders. values: the user-supplied dictionary of name/value pairs to use for placeholder values. code: an array of referenced UDFs found during expansion. resolved_vars: a ref parameter for the variable references completely resolved so far. resolving_vars: a ref parameter for the variable(s) we are currently resolving; if we see a dependency again that is in this set we know we have a circular reference. Returns: The formatted SQL statement with placeholders replaced with their values. Raises: Exception if a placeholder was found in the SQL statement, but did not have a corresponding argument value. """ # Get the set of $var references in this SQL. dependencies = SqlStatement._get_dependencies(sql) for dependency in dependencies: # Now we check each dependency. If it is in complete - i.e., we have an expansion # for it already - we just continue. if dependency in resolved_vars: continue # Look it up in our resolution namespace dictionary. dep = datalab.utils.get_item(values, dependency) # If it is a SQL module, get the main/last query from the module, so users can refer # to $module. Useful especially if final query in module has no DEFINE QUERY part. if isinstance(dep, types.ModuleType): dep = _utils.get_default_query_from_module(dep) # If we can't resolve the $name, give up. if dep is None: raise Exception("Unsatisfied dependency $%s" % dependency) # If it is a SqlStatement, it may have its own $ references in turn; check to make # sure we don't have circular references, and if not, recursively expand it and add # it to the set of complete dependencies. 
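# A hedged illustration of the placeholder rules this class implements (the query, table and
# argument names below are made up): '$name' references are resolved from the args dictionary,
# recursively when the value is itself a SqlStatement, '$$' escapes a literal '$', strings are
# double-quoted, and lists/tuples expand to a parenthesised value list.
#
#   inner = SqlStatement('SELECT id, price FROM [my_dataset.sales]')
#   sql = 'SELECT AVG(price) FROM $inner WHERE region IN $regions AND price > $min_price'
#   SqlStatement.format(sql, {'inner': inner, 'regions': ['EMEA', 'APAC'], 'min_price': 10})
#   # -> SELECT AVG(price) FROM (SELECT id, price FROM [my_dataset.sales])
#   #    WHERE region IN ("EMEA", "APAC") AND price > 10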
if isinstance(dep, SqlStatement): if resolving_vars is None: resolving_vars = [] elif dependency in resolving_vars: # Circular dependency raise Exception("Circular dependency in $%s" % dependency) resolving_vars.append(dependency) SqlStatement._find_recursive_dependencies(dep._sql, values, code, resolved_vars, resolving_vars) resolving_vars.pop() resolved_vars[dependency] = SqlStatement(dep._sql) else: resolved_vars[dependency] = dep @staticmethod def _escape_string(s): return '"' + s.replace('"', '\\"') + '"' @staticmethod def format(sql, args=None): """ Resolve variable references in a query within an environment. This computes and resolves the transitive dependencies in the query and raises an exception if that fails due to either undefined or circular references. Args: sql: query to format. args: a dictionary of values to use in variable expansion. Returns: The resolved SQL text with variables expanded. Raises: Exception on failure. """ resolved_vars = {} code = [] SqlStatement._find_recursive_dependencies(sql, args, code=code, resolved_vars=resolved_vars) # Rebuild the SQL string, substituting just '$' for escaped $ occurrences, # variable references substituted with their values, or literal text copied # over as-is. parts = [] for (escape, placeholder, _, literal) in SqlStatement._get_tokens(sql): if escape: parts.append('$') elif placeholder: variable = placeholder[1:] try: value = resolved_vars[variable] except KeyError as e: raise Exception('Invalid sql. Unable to substitute $%s.' % e.args[0]) if isinstance(value, types.ModuleType): value = _utils.get_default_query_from_module(value) if isinstance(value, SqlStatement): sql = value.format(value._sql, resolved_vars) value = '(%s)' % sql elif '_repr_sql_' in dir(value): # pylint: disable=protected-access value = value._repr_sql_() elif isinstance(value, basestring): value = SqlStatement._escape_string(value) elif isinstance(value, list) or isinstance(value, tuple): if isinstance(value, tuple): value = list(value) expansion = '(' for v in value: if len(expansion) > 1: expansion += ', ' if isinstance(v, basestring): expansion += SqlStatement._escape_string(v) else: expansion += str(v) expansion += ')' value = expansion else: value = str(value) parts.append(value) elif literal: parts.append(literal) expanded = ''.join(parts) return expanded @staticmethod def _get_tokens(sql): # Find escaped '$' characters ($$), "$" variable references, lone '$' characters, or # literal sequences of character without any '$' in them (in that order). return re.findall(r'(\$\$)|(\$[a-zA-Z_][a-zA-Z0-9_\.]*)|(\$)|([^\$]*)', sql) @staticmethod def _get_dependencies(sql): """ Return the list of variables referenced in this SQL. """ dependencies = [] for (_, placeholder, dollar, _) in SqlStatement._get_tokens(sql): if placeholder: variable = placeholder[1:] if variable not in dependencies: dependencies.append(variable) elif dollar: raise Exception('Invalid sql; $ with no following $ or identifier: %s.' % sql) return dependencies ================================================ FILE: datalab/data/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Generic SQL Helpers.""" from __future__ import absolute_import from __future__ import unicode_literals import types # Names used for the arg parser, unnamed (main) query and last query in the module. # Note that every module has a last query, but not every module has a main query. _SQL_MODULE_ARGPARSE = '_sql_module_arg_parser' _SQL_MODULE_MAIN = '_sql_module_main' _SQL_MODULE_LAST = '_sql_module_last' def get_default_query_from_module(module): """ Given a %%sql module return the default (last) query for the module. Args: module: the %%sql module. Returns: The default query associated with this module. """ if isinstance(module, types.ModuleType): return module.__dict__.get(_SQL_MODULE_LAST, None) return None def _next_token(sql): """ This is a basic tokenizer for our limited purposes. It splits a SQL statement up into a series of segments, where a segment is one of: - identifiers - left or right parentheses - multi-line comments - single line comments - white-space sequences - string literals - consecutive strings of characters that are not one of the items above The aim is for us to be able to find function calls (identifiers followed by '('), and the associated closing ')') so we can augment these if needed. Args: sql: a SQL statement as a (possibly multi-line) string. Returns: For each call, the next token in the initial input. """ i = 0 # We use def statements here to make the logic more clear. The start_* functions return # true if i is the index of the start of that construct, while the end_* functions # return true if i point to the first character beyond that construct or the end of the # content. # # We don't currently need numbers so the tokenizer here just does sequences of # digits as a convenience to shrink the total number of tokens. If we needed numbers # later we would need a special handler for these much like strings. def start_multi_line_comment(s, i): return s[i] == '/' and i < len(s) - 1 and s[i + 1] == '*' def end_multi_line_comment(s, i): return s[i - 2] == '*' and s[i - 1] == '/' def start_single_line_comment(s, i): return s[i] == '-' and i < len(s) - 1 and s[i + 1] == '-' def end_single_line_comment(s, i): return s[i - 1] == '\n' def start_whitespace(s, i): return s[i].isspace() def end_whitespace(s, i): return not s[i].isspace() def start_number(s, i): return s[i].isdigit() def end_number(s, i): return not s[i].isdigit() def start_identifier(s, i): return s[i].isalpha() or s[i] == '_' or s[i] == '$' def end_identifier(s, i): return not(s[i].isalnum() or s[i] == '_') def start_string(s, i): return s[i] == '"' or s[i] == "'" def always_true(s, i): return True while i < len(sql): start = i if start_multi_line_comment(sql, i): i += 1 end_checker = end_multi_line_comment elif start_single_line_comment(sql, i): i += 1 end_checker = end_single_line_comment elif start_whitespace(sql, i): end_checker = end_whitespace elif start_identifier(sql, i): end_checker = end_identifier elif start_number(sql, i): end_checker = end_number elif start_string(sql, i): # Special handling here as we need to check for escaped closing quotes. 
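# A hedged example of the segmentation this tokenizer produces via the tokenize() wrapper
# defined below (the query text is made up):
#
#   tokenize('SELECT count(*) FROM t -- trailing comment\n')
#   # -> ['SELECT', ' ', 'count', '(', '*', ')', ' ', 'FROM', ' ', 't', ' ',
#   #     '-- trailing comment\n']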
quote = sql[i] end_checker = always_true i += 1 while i < len(sql) and sql[i] != quote: i += 2 if sql[i] == '\\' else 1 else: # We return single characters for everything else end_checker = always_true i += 1 while i < len(sql) and not end_checker(sql, i): i += 1 (yield sql[start:i]) def tokenize(sql): """ This is a basic tokenizer for our limited purposes. It splits a SQL statement up into a series of segments, where a segment is one of: - identifiers - left or right parentheses - multi-line comments - single line comments - white-space sequences - string literals - consecutive strings of characters that are not one of the items above The aim is for us to be able to find function calls (identifiers followed by '('), and the associated closing ')') so we can augment these if needed. Args: sql: a SQL statement as a (possibly multi-line) string. Returns: A list of strings corresponding to the groups above. """ return list(_next_token(sql)) ================================================ FILE: datalab/data/commands/__init__.py ================================================ from __future__ import absolute_import # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from . import _sql __all__ = ['_sql'] ================================================ FILE: datalab/data/commands/_sql.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - %%arguments IPython Cell Magic Functionality.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import argparse import datetime import imp import re import sys import time import datalab.bigquery import datalab.data from datalab.utils.commands import CommandParser, handle_magic_line def _create_sql_parser(): sql_parser = CommandParser(prog="%%sql", formatter_class=argparse.RawDescriptionHelpFormatter, description=""" Create a named SQL module with one or more queries. The cell body should contain an optional initial part defining the default values for the variables, if any, using Python code, followed by one or more queries. Queries should start with 'DEFINE QUERY ' in order to bind them to . in the notebook (as datalab.data.SqlStament instances). 
The final query can optionally omit 'DEFINE QUERY <name>', as using the module name in places where a SqlStatement is expected will resolve to the final query in the module. Queries can refer to variables with '$<name>', as well as refer to other queries within the same module, making it easy to compose nested queries and test their parts. The Python code defining the variable default values can assign scalar or list/tuple values to variables, or one of the special functions 'datestring' and 'source'. When a variable with a 'datestring' default is expanded it will expand to a formatted string based on the current date, while a 'source' default will expand to a table whose name is based on the current date. datestring() takes two named arguments, 'format' and 'offset'. The former is a format string that is the same as for Python's time.strftime function. The latter is a string containing a comma-separated list of expressions such as -1y, +2m, etc; these are offsets from the time of expansion that are applied in order. The suffixes (y, m, d, h, M) correspond to units of years, months, days, hours and minutes, while the +n or -n prefix is the number of units to add or subtract from the time of expansion. Three special values 'now', 'today' and 'yesterday' are also supported; 'today' and 'yesterday' will be midnight UTC on the current date or the previous day's date. source() can take a 'name' argument for a fixed table name, or 'format' and 'offset' arguments similar to datestring(), but unlike datestring() will resolve to a Table with the specified name. """) sql_parser.add_argument('-m', '--module', help='The name for this SQL module') sql_parser.add_argument('-d', '--dialect', help='BigQuery SQL dialect', choices=['legacy', 'standard']) sql_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') sql_parser.set_defaults(func=lambda args, cell: sql_cell(args, cell)) return sql_parser _sql_parser = _create_sql_parser() # Register the line magic as well as the cell magic so we can at least give people help # without requiring them to enter cell content first. @IPython.core.magic.register_line_cell_magic def sql(line, cell=None): """ Create a SQL module with one or more queries. Use %sql --help for more details. The supported syntax is: %%sql [--module <modulename>] [<optional Python code for default argument values>] [<optional named queries>] [<optional unnamed query>] At least one query should be present. Named queries should start with: DEFINE QUERY <name> on a line by itself. Args: args: the optional arguments following '%%sql'. cell: the contents of the cell; Python code for arguments followed by SQL queries. """ if cell is None: _sql_parser.print_help() else: return handle_magic_line(line, cell, _sql_parser) def _date(val, offset=None): """ A special pseudo-type for pipeline arguments. This allows us to parse dates as Python datetimes, including special values like 'now' and 'today', as well as apply offsets to the datetime. Args: val: a string containing the value for the datetime. This can be 'now', 'today' (midnight at start of day), 'yesterday' (midnight at start of yesterday), or a formatted date that will be passed to the datetime constructor. Note that 'now' etc are assumed to be in UTC. offset: for date arguments a string containing a comma-separated list of relative offsets to apply of the form <n><unit> where <n> is an integer and <unit> is a single character unit (d=day, m=month, y=year, h=hour, M=minute). Returns: A Python datetime resulting from starting at <val> and applying the sequence of deltas specified in <offset>.
""" if val is None: return val if val == '' or val == 'now': when = datetime.datetime.utcnow() elif val == 'today': dt = datetime.datetime.utcnow() when = datetime.datetime(dt.year, dt.month, dt.day) elif val == 'yesterday': dt = datetime.datetime.utcnow() - datetime.timedelta(1) when = datetime.datetime(dt.year, dt.month, dt.day) else: when = datetime.datetime.strptime(val, "%Y%m%d") if offset is not None: for part in offset.split(','): unit = part[-1] quantity = int(part[:-1]) # We can use timedelta for days and under, but not for years and months if unit == 'y': when = datetime.datetime(year=when.year + quantity, month=when.month, day=when.day, hour=when.hour, minute=when.minute) elif unit == 'm': new_year = when.year new_month = when.month + quantity if new_month < 1: new_month = -new_month new_year += 1 + (new_month // 12) new_month = 12 - new_month % 12 elif new_month > 12: new_year += (new_month - 1) // 12 new_month = 1 + (new_month - 1) % 12 when = datetime.datetime(year=new_year, month=new_month, day=when.day, hour=when.hour, minute=when.minute) elif unit == 'd': when += datetime.timedelta(days=quantity) elif unit == 'h': when += datetime.timedelta(hours=quantity) elif unit == 'M': when += datetime.timedelta(minutes=quantity) return when def _resolve_table(v, format, delta): try: when = _date(v, delta) v = time.strftime(format, when.timetuple()) except Exception: pass return datalab.bigquery.Table(v) def _make_string_formatter(f, offset=None): """ A closure-izer for string arguments that include a format and possibly an offset. """ format = f delta = offset return lambda v: time.strftime(format, (_date(v, delta)).timetuple()) def _make_table_formatter(f, offset=None): """ A closure-izer for table arguments that include a format and possibly an offset. """ format = f delta = offset return lambda v: _resolve_table(v, format, delta) def _make_table(v): return datalab.bigquery.Table(v) def _datestring(format, offset=''): return {'type': 'datestring', 'format': format, 'offset': offset} def _table(name=None, format=None, offset=''): return {'type': 'table', 'name': name, 'format': format, 'offset': offset} def _arguments(code, module): """Define pipeline arguments. Args: code: the Python code to execute that defines the arguments. """ arg_parser = CommandParser.create('') try: # Define our special argument 'types' and add them to the environment. builtins = {'source': _table, 'datestring': _datestring} env = {} env.update(builtins) # Execute the cell which should be one or more calls to arg(). exec(code, env) # Iterate through the module dictionary. For any newly defined objects, # add args to the parser. for key in env: # Skip internal/private stuff. if key in builtins or key[0] == '_': continue # If we want to support importing query modules into other query modules, uncomment next 4 # Skip imports but add them to the module # if isinstance(env[key], types.ModuleType): # module.__dict__[key] = env[key] # continue val = env[key] key = '--%s' % key if isinstance(val, bool): if val: arg_parser.add_argument(key, default=val, action='store_true') else: arg_parser.add_argument(key, default=val, action='store_false') elif isinstance(val, basestring) or isinstance(val, int) or isinstance(val, float) \ or isinstance(val, int): arg_parser.add_argument(key, default=val) elif isinstance(val, list): arg_parser.add_argument(key, default=val, nargs='+') elif isinstance(val, tuple): arg_parser.add_argument(key, default=list(val), nargs='+') # Is this one of our pseudo-types for dates/tables? 
elif isinstance(val, dict) and 'type' in val: if val['type'] == 'datestring': arg_parser.add_argument(key, default='', type=_make_string_formatter(val['format'], offset=val['offset'])) elif val['type'] == 'table': if val['format'] is not None: arg_parser.add_argument(key, default='', type=_make_table_formatter(val['format'], offset=val['offset'])) else: arg_parser.add_argument(key, default=val['name'], type=_make_table) else: raise Exception('Cannot generate argument for %s of type %s' % (key, type(val))) else: raise Exception('Cannot generate argument for %s of type %s' % (key, type(val))) except Exception as e: print("%%sql arguments: %s from code '%s'" % (str(e), str(code))) return arg_parser def _split_cell(cell, module): """ Split a hybrid %%sql cell into the Python code and the queries. Populates a module with the queries. Args: cell: the contents of the %%sql cell. module: the module that the contents will populate. Returns: The default (last) query for the module. """ lines = cell.split('\n') code = None last_def = -1 name = None define_wild_re = re.compile('^DEFINE\s+.*$', re.IGNORECASE) define_re = re.compile('^DEFINE\s+QUERY\s+([A-Z]\w*)\s*?(.*)$', re.IGNORECASE) select_re = re.compile('^SELECT\s*.*$', re.IGNORECASE) standard_sql_re = re.compile('^(CREATE|WITH|INSERT|DELETE|UPDATE)\s*.*$', re.IGNORECASE) # TODO(gram): a potential issue with this code is if we have leading Python code followed # by a SQL-style comment before we see SELECT/DEFINE. When switching to the tokenizer see # if we can address this. for i, line in enumerate(lines): define_match = define_re.match(line) select_match = select_re.match(line) standard_sql_match = standard_sql_re.match(line) if i: prior_content = ''.join(lines[:i]).strip() if select_match: # Avoid matching if previous token was '(' or if Standard SQL is found # TODO: handle the possibility of comments immediately preceding SELECT select_match = len(prior_content) == 0 or \ (prior_content[-1] != '(' and not standard_sql_re.match(prior_content)) if standard_sql_match: standard_sql_match = len(prior_content) == 0 or not standard_sql_re.match(prior_content) if define_match or select_match or standard_sql_match: # If this is the first query, get the preceding Python code. if code is None: code = ('\n'.join(lines[:i])).strip() if len(code): code += '\n' elif last_def >= 0: # This is not the first query, so gather the previous query text. query = '\n'.join([line for line in lines[last_def:i] if len(line)]).strip() if select_match and name != datalab.data._utils._SQL_MODULE_MAIN and len(query) == 0: # Avoid DEFINE query name\nSELECT ... being seen as an empty DEFINE followed by SELECT continue # Save the query statement = datalab.data.SqlStatement(query, module) module.__dict__[name] = statement # And set the 'last' query to be this too module.__dict__[datalab.data._utils._SQL_MODULE_LAST] = statement # Get the query name and strip off our syntactic sugar if appropriate. if define_match: name = define_match.group(1) lines[i] = define_match.group(2) else: name = datalab.data._utils._SQL_MODULE_MAIN # Save the starting line index of the new query last_def = i else: define_wild_match = define_wild_re.match(line) if define_wild_match: raise Exception('Expected "DEFINE QUERY "') if last_def >= 0: # We were in a query so save this tail query. 
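# A hedged sketch of a %%sql cell that exercises the pieces above (the module, argument and
# table names are made up): the leading Python assigns argument defaults, including the
# 'datestring' and 'source' pseudo-types, and each DEFINE QUERY becomes a SqlStatement bound
# into the module.
#
#   %%sql --module logs
#   start = datestring(format='%Y%m%d', offset='-7d')
#   logs_table = source(name='my_dataset.logs')
#
#   DEFINE QUERY recent
#   SELECT * FROM $logs_table WHERE date >= $start
#
# After the cell runs, logs.recent is a SqlStatement and 'logs' itself resolves to the last
# query in the module.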
query = '\n'.join([line for line in lines[last_def:] if len(line)]).strip() statement = datalab.data.SqlStatement(query, module) module.__dict__[name] = statement module.__dict__[datalab.data._utils._SQL_MODULE_LAST] = statement if code is None: code = '' module.__dict__[datalab.data._utils._SQL_MODULE_ARGPARSE] = _arguments(code, module) return module.__dict__.get(datalab.data._utils._SQL_MODULE_LAST, None) def sql_cell(args, cell): """Implements the SQL cell magic for ipython notebooks. The supported syntax is: %%sql [--module ] [] [] [] At least one query should be present. Named queries should start with: DEFINE QUERY on a line by itself. Args: args: the optional arguments following '%%sql'. cell: the contents of the cell; Python code for arguments followed by SQL queries. """ name = args['module'] if args['module'] else '_sql_cell' module = imp.new_module(name) query = _split_cell(cell, module) ipy = IPython.get_ipython() if not args['module']: # Execute now if query: return datalab.bigquery.Query(query, values=ipy.user_ns) \ .execute(dialect=args['dialect'], billing_tier=args['billing']).results else: # Add it as a module sys.modules[name] = module exec('import %s' % name, ipy.user_ns) ================================================ FILE: datalab/kernel/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Datalab - notebook functionality.""" import httplib2 as _httplib2 import requests as _requests try: import IPython as _IPython import IPython.core.magic as _magic # noqa import IPython.core.interactiveshell as _shell from IPython import get_ipython # noqa except ImportError: raise Exception('This package requires an IPython notebook installation') import datalab.context as _context # Import the modules that do cell magics. import datalab.bigquery.commands import datalab.context.commands import datalab.data.commands import datalab.stackdriver.commands import datalab.storage.commands import datalab.utils.commands # mlalpha modules require TensorFlow, CloudML SDK, and DataFlow (installed with CloudML SDK). # These are big dependencies and users who want to use Bigquery/Storage features may not # want to install them. # This __init__.py file is called when Jupyter/Datalab loads magics on startup. We don't want # Jupyter+pydatalab fail to start because of missing TensorFlow/DataFlow. So we ignore import # errors on mlalpha commands. try: import datalab.mlalpha.commands except: print('TensorFlow and CloudML SDK are required.') _orig_request = _httplib2.Http.request _orig_init = _requests.Session.__init__ _orig_run_cell_magic = _shell.InteractiveShell.run_cell_magic _orig_run_line_magic = _shell.InteractiveShell.run_line_magic def load_ipython_extension(shell): """ Called when the extension is loaded. Args: shell - (NotebookWebApplication): handle to the Notebook interactive shell instance. 
""" # Inject our user agent on all requests by monkey-patching a wrapper around httplib2.Http.request. def _request(self, uri, method="GET", body=None, headers=None, redirections=_httplib2.DEFAULT_MAX_REDIRECTS, connection_type=None): if headers is None: headers = {} headers['user-agent'] = 'GoogleCloudDataLab/1.0' return _orig_request(self, uri, method=method, body=body, headers=headers, redirections=redirections, connection_type=connection_type) _httplib2.Http.request = _request # Similarly for the requests library. def _init_session(self): _orig_init(self) self.headers['User-Agent'] = 'GoogleCloudDataLab/1.0' _requests.Session.__init__ = _init_session # Be more tolerant with magics. If the user specified a cell magic that doesn't # exist and an empty cell body but a line magic with that name exists, run that # instead. Conversely, if the user specified a line magic that doesn't exist but # a cell magic exists with that name, run the cell magic with an empty body. def _run_line_magic(self, magic_name, line): fn = self.find_line_magic(magic_name) if fn is None: cm = self.find_cell_magic(magic_name) if cm: return _run_cell_magic(self, magic_name, line, None) return _orig_run_line_magic(self, magic_name, line) def _run_cell_magic(self, magic_name, line, cell): if len(cell) == 0 or cell.isspace(): fn = self.find_line_magic(magic_name) if fn: return _orig_run_line_magic(self, magic_name, line) # IPython will complain if cell is empty string but not if it is None cell = None return _orig_run_cell_magic(self, magic_name, line, cell) _shell.InteractiveShell.run_cell_magic = _run_cell_magic _shell.InteractiveShell.run_line_magic = _run_line_magic # Define global 'project_id' and 'set_project_id' functions to manage the default project ID. We # do this conditionally in a try/catch # to avoid the call to Context.default() when running tests # which mock IPython.get_ipython(). def _get_project_id(): try: return _context.Context.default().project_id except Exception: return None def _set_project_id(project_id): context = _context.Context.default() context.set_project_id(project_id) def _get_bq_dialect(): return datalab.bigquery.Dialect.default().bq_dialect def _set_bq_dialect(bq_dialect): datalab.bigquery.Dialect.default().set_bq_dialect(bq_dialect) try: if 'datalab_project_id' not in _IPython.get_ipython().user_ns: _IPython.get_ipython().user_ns['datalab_project_id'] = _get_project_id _IPython.get_ipython().user_ns['set_datalab_project_id'] = _set_project_id if 'datalab_bq_dialect' not in _IPython.get_ipython().user_ns: _IPython.get_ipython().user_ns['datalab_bq_dialect'] = _get_bq_dialect _IPython.get_ipython().user_ns['set_datalab_bq_dialect'] = _set_bq_dialect except TypeError: pass def unload_ipython_extension(shell): _shell.InteractiveShell.run_cell_magic = _orig_run_cell_magic _shell.InteractiveShell.run_line_magic = _orig_run_line_magic _requests.Session.__init__ = _orig_init _httplib2.Http.request = _orig_request try: del _IPython.get_ipython().user_ns['project_id'] del _IPython.get_ipython().user_ns['set_project_id'] except Exception: pass # We mock IPython for tests so we need this. # TODO(gram): unregister imports/magics/etc. ================================================ FILE: datalab/notebook/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Datalab - notebook extension functionality.""" try: import IPython as _ except ImportError: raise Exception('This package requires an IPython notebook installation') __all__ = ['_'] def _jupyter_nbextension_paths(): return [dict(section="notebook", src="static", dest="gcpdatalab")] ================================================ FILE: datalab/notebook/static/bigquery.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ table.bqsv { font-family: inherit; font-size: smaller; } table.bqsv th, table.bqsv td { border: solid 1px #cfcfcf; } th.bqsv_expanded, th.bqsv_collapsed { background-color: #f7f7f7; } th.bqsv_colheader { font-weight: bold; background-color: #e7e7e7; } tbody.bqsv_hidden { display: none; } th.bqsv_expanded:before { content: '\25be ' } th.bqsv_collapsed:before { content: '\25b8 ' } ================================================ FILE: datalab/notebook/static/bigquery.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module BigQuery { // Event handler to toggle visibility of a nested schema table. function _toggleNode(e: any): void { var node = e.target; var expand = node.className == 'bqsv_collapsed'; node.className = expand ? 'bqsv_expanded' : 'bqsv_collapsed'; var tgroup = node.parentNode.nextSibling; tgroup.className = expand ? 'bqsv_visible' : 'bqsv_hidden'; } // Helper function to recursively render a table schema. function _renderSchema(table: any, schema: any, title: string, includeColumnHeaders: boolean, columns: any): void { // Create a tbody element to hold the entities for this level. We group them so // we can easily collapse/expand the level. var tbody = document.createElement('tbody'); for (var i = 0; i < schema.length; i++) { if (i == 0) { if (title.length > 0) { // title.length > 0 implies we are in a nested table. Create a title header row // for this nested table with a click handler and hide the tbody. 
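// A hedged usage sketch of the renderSchema() entry point exported below (the element id and
// fields are made up): it expects a BigQuery-style schema array of {name, type, mode,
// description} entries, and RECORD fields with nested 'fields' are rendered as collapsible
// sub-tables.
//
//   BigQuery.renderSchema(document.getElementById('schema-out'), [
//     {name: 'id', type: 'INTEGER', mode: 'REQUIRED', description: 'row id'},
//     {name: 'address', type: 'RECORD', fields: [
//       {name: 'city', type: 'STRING'},
//       {name: 'zip', type: 'STRING'}
//     ]}
//   ]);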
tbody.className = 'bqsv_hidden'; var th = document.createElement('th'); th.colSpan = columns.length; th.className = 'bqsv_collapsed'; th.textContent = title.substring(1); // skip the leading '.' th.addEventListener('click', _toggleNode); var tr = document.createElement('tr'); tr.appendChild(th); table.appendChild(tr); } else { // We are in the top-level table; add a header row with the column labels. tbody.className = 'bqsv_visible'; if (includeColumnHeaders) { // First line; show column headers. var tr = document.createElement('tr'); for (var j = 0; j < columns.length; j++) { var th = document.createElement('th'); th.textContent = columns[j]; th.className = 'bqsv_colheader'; tr.appendChild(th); } table.appendChild(tr); } } } // Add the details for the current row to the tbody. var field = schema[i]; var tr = document.createElement('tr'); for (var j = 0; j < columns.length; j++) { var td = document.createElement('td'); var v = field[columns[j]]; td.textContent = v == undefined ? '' : v; tr.appendChild(td); } tbody.appendChild(tr); } // Add the tbody with all the rows to the table. table.appendChild(tbody); // Recurse into any nested tables. for (var i = 0; i < schema.length; i++) { var field = schema[i]; if (field.type == 'RECORD') { _renderSchema(table, field.fields, title + '.' + field.name, false, columns); } } } // Top-level public function for schema rendering. export function renderSchema(dom: any, schema: any) { var columns = ['name', 'type', 'mode', 'description']; var table = document.createElement('table'); table.className = 'bqsv'; _renderSchema(table, schema, '', /*includeColumnHeaders*/ true, columns); dom.appendChild(table); } } export = BigQuery; ================================================ FILE: datalab/notebook/static/charting.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ table.google-visualization-table-table, table.dataframe { font-family: inherit; font-size: smaller; } tr.gchart-table-row { } tr.gchart-table-headerrow, table.dataframe thead th { font-weight: bold; background-color: #e7e7e7; } tr.gchart-table-oddrow, table.dataframe tr:nth-child(odd) { background-color: #f7f7f7; } tr.gchart-table-selectedTableRow { background-color: #e3f2fd; } tr.gchart-table-hoverrow, table.dataframe tr:hover { background-color: #bbdefb; } td.gchart-table-cell, table.dataframe td { border: solid 1px #cfcfcf; } td.gchart-table-rownumcell, table.dataframe tr th { border: solid 1px #cfcfcf; color: #999; } th.gchart-table-headercell, table.dataframe th { border: solid 1px #cfcfcf; } div.bqgc { display: flex; justify-content: center; } div.bqgc img { max-width: none; // Fix the conflict with maps and Bootstrap that messes up zoom controls. 
} .gchart-slider { width: 80%; float: left; } .gchart-slider-value { text-align: center; float: left; width: 20%; } .gchart-control { padding-top: 10px; padding-bottom: 10px; } .gchart-controls { font-size: 14px; color: #333333; background: #f4f4f4; padding: 10px; width: 180px; float: left; } .bqgc { padding: 0; max-width: 100%; } .bqgc-controlled { display: flex; flex-direction: row; justify-content:space-between; } .bqgc-container { display: block; } .bqgc-ml-metrics { display: flex; flex-direction: row; justify-content:left; } ================================================ FILE: datalab/notebook/static/charting.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module Charting { declare var IPython:any; declare var datalab:any; // Wrappers for Plotly.js and Google Charts abstract class ChartLibraryDriver { chartModule:any; constructor(protected dom:HTMLElement, protected chartStyle:string) { } abstract requires(url: string, chartStyle:string):Array; init(chartModule:any):void { this.chartModule = chartModule; } abstract draw(data:any, options:any):void; abstract getStaticImage(callback:Function):void; abstract addChartReadyHandler(handler:Function):void; addPageChangedHandler(handler:Function):void { } error(message:string):void { } } class PlotlyDriver extends ChartLibraryDriver { readyHandler:any; constructor(dom:HTMLElement, chartStyle:string) { super(dom, chartStyle) } requires(url: string, chartStyle:string):Array { return ['d3', 'plotly']; } public draw(data:any, options:any):void { /* * TODO(gram): if we start moving more chart types over to Plotly.js we should change the * shape of the data we pass to render so we don't need to reshape it here. Also, a fair * amount of the computation done here could be moved to Python code. We should just be * passing in the mostly complete layout object in JSON, for example. 
*/ var xlabels: Array = []; var points: Array = []; var layout: any = { xaxis: {}, yaxis: {}, height: 300, margin: { b: 60, t: 60, l: 60, r: 60 } }; if (options.title) { layout.title = options.title; } var minX: number = undefined; var maxX: number = undefined; if ('hAxis' in options) { if ('minValue' in options.hAxis) { minX = options.hAxis.minValue; } if ('maxValue' in options.hAxis) { maxX = options.hAxis.maxValue; } if (minX != undefined || maxX != undefined) { layout.xaxis.range = [minX, maxX]; } } var minY: number = undefined; var maxY: number = undefined; if ('vAxis' in options) { if ('minValue' in options.vAxis) { minY = options.vAxis.minValue; } else if ('minValues' in options.vAxis) { minY = options.vAxis.minValues[0]; } if ('maxValue' in options.vAxis) { maxY = options.vAxis.maxValue; } else if ('maxValues' in options.vAxis) { maxY = options.vAxis.maxValues[0]; } if (minY != undefined || maxY != undefined) { layout.yaxis.range = [minY, maxY]; } if ('minValues' in options.vAxis) { minY = options.vAxis.minValues[1]; // for second axis below } if ('maxValues' in options.vAxis) { maxY = options.vAxis.maxValues[1]; // for second axis below } } if (options.xAxisTitle) { layout.xaxis.title = options.xAxisTitle; } if (options.xAxisSide) { layout.xaxis.side = options.xAxisSide; } if (options.yAxisTitle) { layout.yaxis.title = options.yAxisTitle; } if (options.yAxesTitles) { layout.yaxis.title = options.yAxesTitles[0]; layout.yaxis2 = { title: options.yAxesTitles[1], side: 'right', overlaying: 'y' }; if (minY != undefined || maxY != undefined) { layout.yaxis2.range = [minY, maxY]; } } if ('width' in options) { layout.width = options.width; } if ('height' in options) { layout.height = options.height; if ('width' in options) { layout.autosize = false; } } var pdata: Array = []; if (this.chartStyle == 'line' || this.chartStyle == 'scatter') { var hoverCol: number = 0; var x: Array = []; // First col is X, other cols are Y's and optional hover text only column var y: Array = []; var hover: Array = []; for (var c = 1; c < data.cols.length; c++) { x[c - 1] = []; y[c - 1] = []; var line:any = { x: x[c - 1], y: y[c - 1], name: data.cols[c].label, type: 'scatter', mode: this.chartStyle == 'scatter' ? 'markers' : 'lines' }; if (options.hoverOnly) { hover[c - 1] = []; line.text = hover[c - 1]; line.hoverinfo = 'text'; } if (options.yAxesTitles && (c % 2) == 0) { line.yaxis = 'y2'; } pdata.push(line); } for (var c = 1; c < data.cols.length; c++) { if (c == hoverCol) { continue; } for (var r = 0; r < data.rows.length; r++) { var entry:Array = data.rows[r].c; if ('v' in entry[c]) { var xVal = entry[0].v; var yVal = entry[c].v; if (options.hoverOnly) { // Each column is a dict with two values, one for y and one for // hover. Extract these. var hoverVal:any; var yDict:any = yVal; for (var prop in yDict) { var val = yDict[prop]; if (prop == options.hoverOnly) { hoverVal = val; } else { yVal = val; } } // TODO(gram): we may want to add explicit hover text this even without hoverOnly. 
var xlabel:any = options.xAxisTitle || data.cols[0].label; var ylabel:any = options.yAxisTitle || data.cols[c].label; var prefix = ''; if (options.yAxisTitle) { prefix += data.cols[c].label + ': '; } hover[c - 1].push(prefix + options.hoverOnly + '=' + hoverVal + ', ' + xlabel + '=' + xVal + ', ' + ylabel + '=' + yVal); } x[c - 1].push(xVal); y[c - 1].push(yVal); } } } } else if (this.chartStyle == 'heatmap') { var size:number = 200 + data.cols.length * 50; if (size > 800) size = 800; layout.height = size; layout.width = size; layout.autosize = false; for (var i = 0; i < data.cols.length; i++) { xlabels[i] = data.cols[i].label; } var ylabels = [].concat(xlabels); // Plotly draws the first row at the bottom, not the top, so we need // to reverse the y and z array ordering. // We will need to tweak this a bit if we later support non-square maps. ylabels.reverse(); var hovertext: Array> = []; var hoverx:string = options.xAxisTitle || 'x'; var hovery = options.yAxisTitle || 'y'; for (var i = 0; i < data.rows.length; i++) { var entry:Array = data.rows[i].c; var row:Array = []; var hoverrow:Array = []; for (var j = 0; j < data.cols.length; j++) { row[j] = entry[j].v; hoverrow[j] = hoverx + '= ' + xlabels[j] + ', ' + hovery + '= ' + ylabels[i] + ': ' + row[j]; } points[i] = row; hovertext[i] = hoverrow; } points.reverse(); layout.hovermode = 'closest'; pdata = [{ x: xlabels, y: ylabels, z: points, type: 'heatmap', text: hovertext, hoverinfo: 'text' }]; if (options.colorScale) { pdata[0].colorscale = [ [0, options.colorScale.min], [1, options.colorScale.max] ]; } else { pdata[0].colorscale = [ [0, 'red'], [0.5, 'gray'], [1, 'blue'] ]; } if (options.hideScale) { pdata[0].showscale = false; } if (options.annotate) { layout.annotations = []; for (var i = 0; i < pdata[0].y.length; i++) { for (var j = 0; j < pdata[0].x.length; j++) { var currentValue = pdata[0].z[i][j]; var textColor = (currentValue == 0.0) ? 
'black' : 'white'; var result = { xref: 'x1', yref: 'y1', x: pdata[0].x[j], y: pdata[0].y[i], text: pdata[0].z[i][j].toPrecision(3), showarrow: false, font: { color: textColor } }; layout.annotations.push(result); } } } } this.chartModule.newPlot(this.dom.id, pdata, layout, {displayModeBar: false}); if (this.readyHandler) { this.readyHandler(); } } getStaticImage(callback:Function):void { this.chartModule.Snapshot.toImage(document.getElementById(this.dom.id), {format: 'png'}).once('success', function (url:string) { callback(this.model, url); }); } addChartReadyHandler(handler:Function):void { this.readyHandler = handler; } } interface IStringMap { [key: string]: string; } class GChartsDriver extends ChartLibraryDriver { chart:any; nameMap: IStringMap = { annotation: 'AnnotationChart', area: 'AreaChart', columns: 'ColumnChart', bars: 'BarChart', bubbles: 'BubbleChart', calendar: 'Calendar', candlestick: 'CandlestickChart', combo: 'ComboChart', gauge: 'Gauge', geo: 'GeoChart', histogram: 'Histogram', line: 'LineChart', map: 'Map', org: 'OrgChart', paged_table: 'Table', pie: 'PieChart', sankey: 'Sankey', scatter: 'ScatterChart', stepped_area: 'SteppedAreaChart', table: 'Table', timeline: 'Timeline', treemap: 'TreeMap', }; scriptMap: IStringMap = { annotation: 'annotationchart', calendar: 'calendar', gauge: 'gauge', geo: 'geochart', map: 'map', org: 'orgchart', paged_table: 'table', sankey: 'sankey', table: 'table', timeline: 'timeline', treemap: 'treemap' }; constructor(dom:HTMLElement, chartStyle:string) { super(dom, chartStyle); } requires(url: string, chartStyle:string):Array { var chartScript:string = 'corechart'; if (chartStyle in this.scriptMap) { chartScript = this.scriptMap[chartStyle]; } return [url + 'visualization!' + chartScript]; } init(chartModule:any):void { super.init(chartModule); var constructor:Function = this.chartModule[this.nameMap[this.chartStyle]]; this.chart = new (constructor)(this.dom); } error(message:string):void { this.chartModule.errors.addError(this.dom, 'Unable to render the chart', message, {showInTooltip: false}); } draw(data:any, options:any):void { console.log('Drawing with options ' + JSON.stringify(options)); this.chart.draw(new this.chartModule.DataTable(data), options); } getStaticImage(callback:Function):void { if (this.chart.getImageURI) { callback(this.chart.getImageURI()); } } addChartReadyHandler(handler:Function) { this.chartModule.events.addListener(this.chart, 'ready', handler); } addPageChangedHandler(handler:Function) { this.chartModule.events.addListener(this.chart, 'page', function (e:any) { handler(e.page); }); } } class Chart { dataCache:any; // TODO: add interface types for the caches. optionsCache:any; hasIPython:boolean; cellElement:HTMLElement; totalRows:number; constructor(protected driver:ChartLibraryDriver, protected dom:Element, protected controlIds:Array, protected base_options:any, protected refreshData:any, protected refreshInterval:number, totalRows:number) { this.totalRows = totalRows || -1; // Total rows in all (server-side) data. this.dataCache = {}; this.optionsCache = {}; this.hasIPython = false; try { if (IPython) { this.hasIPython = true; } } catch (e) { } (this.dom).innerHTML = ''; this.removeStaticChart(); this.addControls(); // Generate and add a new static chart once chart is ready. var _this = this; this.driver.addChartReadyHandler(function () { _this.addStaticChart(); }); } // Convert any string fields that are date type to JS Dates. 
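// A hedged illustration of the DataTable-style JSON the charting code consumes (the column
// names and values are made up); convertDates below rewrites 'date'/'datetime' column values
// from strings into JS Date objects in place.
//
//   var data = {
//     cols: [{id: 'A', label: 'day', type: 'date'}, {id: 'B', label: 'count', type: 'number'}],
//     rows: [{c: [{v: '2016-01-01'}, {v: 3}]}, {c: [{v: '2016-01-02'}, {v: 5}]}]
//   };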
public static convertDates(data:any):void { for (var i = 0; i < data.cols.length; i++) { if (data.cols[i].type == 'date' || data.cols[i].type == 'datetime') { var rows = data.rows; for (var j = 0; j < rows.length; j++) { rows[j].c[i].v = new Date(rows[j].c[i].v); } } else if (data.cols[i].type == 'timeofday') { var rows = data.rows; for (var j = 0; j < rows.length; j++) { var timeInSeconds = rows[j].c[i].v.split('.')[0]; rows[j].c[i].v = timeInSeconds.split(':').map( function(n:string) { return parseInt(n, 10); }); } } } } // Extend the properties in a 'base' object with the changes in an 'update' object. // We can add properties or override properties but not delete yet. private static extend(base:any, update:any):void { for (var p in update) { if (typeof base[p] !== 'object' || !base.hasOwnProperty(p)) { base[p] = update[p]; } else { this.extend(base[p], update[p]); } } } // Get the IPython cell associated with this chart. private getCell() { if (!this.hasIPython) { return undefined; } var cells = IPython.notebook.get_cells(); for (var cellIndex in cells) { var cell = cells[cellIndex]; if (cell.element && cell.element.length) { var element = cell.element[0]; var chartDivs = element.getElementsByClassName('bqgc'); if (chartDivs && chartDivs.length) { for (var i = 0; i < chartDivs.length; i++) { if (chartDivs[i].id == this.dom.id) { return cell; } } } } } return undefined; } protected getRefreshHandler(useCache:boolean):Function { var _this = this; return function () { _this.refresh(useCache); }; } // Bind event handlers to the chart controls, if any. private addControls():void { if (!this.controlIds) { return; } var controlHandler = this.getRefreshHandler(true); for (var i = 0; i < this.controlIds.length; i++) { var id = this.controlIds[i]; var split = id.indexOf(':'); var control:HTMLInputElement; if (split >= 0) { // Checkbox group. var count = parseInt(id.substring(split + 1)); var base = id.substring(0, split + 1); for (var j = 0; j < count; j++) { control = document.getElementById(base + j); control.disabled = !this.hasIPython; control.addEventListener('change', function() { controlHandler(); }); } continue; } // See if we have an associated control that needs dual binding. control = document.getElementById(id); if (!control) { // Kernel restart? return; } control.disabled = !this.hasIPython; var textControl = document.getElementById(id + '_value'); if (textControl) { textControl.disabled = !this.hasIPython; textControl.addEventListener('change', function () { if (control.value != textControl.value) { control.value = textControl.value; controlHandler(); } }); control.addEventListener('change', function () { textControl.value = control.value; controlHandler(); }); } else { control.addEventListener('change', function() { controlHandler(); }); } } } // Iterate through any widget controls and build up a JSON representation // of their values that can be passed to the Python kernel as part of the // magic to fetch data (also used as part of the cache key). 
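// A hedged sketch of what getControlSettings() below returns (the ids and values are made up):
// control element ids have the form '<chart id>__<variable>', checkbox groups carry a
// ':<count>' suffix, and the result maps variable names to their current widget values.
//
//   // controlIds = ['chart1__threshold', 'chart1__fields:3']
//   // -> { threshold: '0.5', fields: ['price', 'region'] }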
protected getControlSettings():any { var env:any = {}; if (this.controlIds) { for (var i = 0; i < this.controlIds.length; i++) { var id = this.controlIds[i]; var parts = id.split('__'); var varName = parts[1]; var splitPoint = varName.indexOf(':'); if (splitPoint >= 0) { // this is a checkbox group var count = parseInt(varName.substring(splitPoint + 1)); varName = varName.substring(0, splitPoint); var cbBaseId = parts[0] + '__' + varName + ':'; var list:Array = []; env[varName] = list; for (var j = 0; j < count; j++) { var cb = document.getElementById(cbBaseId + j); if (!cb) { // Stale refresh; user re-executed cell. return undefined; } if (cb.checked) { list.push(cb.value); } } } else { var e = document.getElementById(id); if (!e) { // Stale refresh; user re-executed cell. return undefined; } if (e && e.type == 'checkbox') { // boolean env[varName] = e.checked; } else { // picker/slider/text env[varName] = e.value; } } } } return env; } // Get a string representation of the current environment - i.e. control settings and // refresh data. This is used as a cache key. private getEnvironment():string { var controls:any = this.getControlSettings(); if (controls == undefined) { // This means the user has re-executed the cell and our controls are gone. return undefined; } var env:any = {controls: controls}; Chart.extend(env, this.refreshData); return JSON.stringify(env); } protected refresh(useCache:boolean):void { // TODO(gram): remember last cache key and don't redraw chart if cache // key is the same unless this is an ML key and the number of data points has changed. this.removeStaticChart(); var env:string = this.getEnvironment(); if (env == undefined) { // This means the user has re-executed the cell and our controls are gone. console.log('No chart control environment; abandoning refresh'); return; } if (useCache && env in this.dataCache) { this.draw(this.dataCache[env], this.optionsCache[env]); return; } var code = '%_get_chart_data\n' + env; // TODO: hook into the notebook UI to enable/disable 'Running...' while we fetch more data. if (!this.cellElement) { var cell = this.getCell(); if (cell && cell.element && cell.element.length == 1) { this.cellElement = cell.element[0]; } } // Start the cell spinner in the notebook UI. if (this.cellElement) { this.cellElement.classList.remove('completed'); } var _this = this; datalab.session.execute(code, function (error:string, response:any) { _this.handleNewData(env, error, response); }); } private handleNewData(env: any, error:any, response: any) { var data = response.data; // Stop the cell spinner in the notebook UI. if (this.cellElement) { this.cellElement.classList.add('completed'); } if (data == undefined || data.cols == undefined) { error = 'No data'; } if (error) { this.driver.error(error); return; } this.refreshInterval = response.refresh_interval; if (this.refreshInterval == 0) { console.log('No more refreshes for ' + this.refreshData.name); } Chart.convertDates(data); var options = this.base_options; if (response.options) { // update any options. We need to make a copy so we don't break the base options. options = JSON.parse(JSON.stringify(options)); Chart.extend(options, response.options); } // Don't update or keep refreshing this if control settings have changed. 
var newEnv = this.getEnvironment(); if (env == newEnv) { console.log('Got refresh for ' + this.refreshData.name + ', ' + env); this.draw(data, options); } else { console.log('Stopping refresh for ' + env + ' as controls are now ' + newEnv) } } // Remove a static chart (PNG) from the notebook and the DOM. protected removeStaticChart():void { var cell = this.getCell(); if (cell) { var pngDivs = > cell.element[0].getElementsByClassName('output_png'); if (pngDivs) { for (var i = 0; i < pngDivs.length; i++) { pngDivs[i].innerHTML = ''; } } var cell_outputs = cell.output_area.outputs; var changed = true; while (changed) { changed = false; for (var outputIndex in cell_outputs) { var output = cell_outputs[outputIndex]; if (output.output_type == 'display_data' && output.metadata.source_id == this.dom.id) { cell_outputs.splice(outputIndex, 1); changed = true; break; } } } } else { // Not running under IPython; use a different approach and just clear the DOM. // Iterate through the IPython outputs... var outputDivs = document.getElementsByClassName('output_wrapper'); if (outputDivs) { for (var i = 0; i < outputDivs.length; i++) { // ...and any chart outputs in each... var outputDiv = outputDivs[i]; var chartDivs = outputDiv.getElementsByClassName('bqgc'); if (chartDivs) { for (var j = 0; j < chartDivs.length; j++) { // ...until we find the chart div ID we want... if (chartDivs[j].id == this.dom.id) { // ...then get any PNG outputs in that same output group... var pngDivs = >outputDiv. getElementsByClassName('output_png'); if (pngDivs) { for (var k = 0; k < pngDivs.length; k++) { // ... and clear their contents. pngDivs[k].innerHTML = ''; } } return; } } } } } } } // Add a static chart (PNG) to the notebook. The notebook will in turn add it to the DOM when // the notebook is opened. private addStaticChart():void { var _this = this; this.driver.getStaticImage(function (img:string) { _this.handleStaticChart(img); }); } private handleStaticChart(img: string) { if (img) { var cell = this.getCell(); if (cell) { var encoding = img.substr(img.indexOf(',') + 1); // strip leading base64 etc. var static_output = { metadata: { source_id: this.dom.id }, data: { 'image/png': encoding }, output_type: 'display_data' }; cell.output_area.outputs.push(static_output); } } } // Set up a refresh callback if we have a non-zero interval and the DOM element still exists // (i.e. output hasn't been cleared). private configureRefresh(refreshInterval:number):void { if (refreshInterval > 0 && document.getElementById(this.dom.id)) { window.setTimeout(this.getRefreshHandler(false), 1000 * refreshInterval); } } // Cache the current data and options and draw the chart. public draw(data:any, options:any):void { var env:string = this.getEnvironment(); this.dataCache[env] = data; this.optionsCache[env] = options; if ('cols' in data) { this.driver.draw(data, options); } this.configureRefresh(this.refreshInterval); } } //----------------------------------------------------------- // A special version of Chart for supporting paginated data. class PagedTable extends Chart { firstRow:number; pageSize:number; constructor(driver:ChartLibraryDriver, dom:HTMLElement, controlIds:Array, base_options:any, refreshData:any, refreshInterval:number, totalRows:number) { super(driver, dom, controlIds, base_options, refreshData, refreshInterval, totalRows); this.firstRow = 0; // Index of first row being displayed in page. 
this.pageSize = base_options.pageSize || 25; if (this.base_options.showRowNumber == undefined) { this.base_options.showRowNumber = true; } this.base_options.sort = 'disable'; var __this = this; this.driver.addPageChangedHandler(function (page:number) { __this.handlePageEvent(page); }); } // Get control settings for cache key. For paged table we add the first row offset of the table. protected getControlSettings():any { var env = super.getControlSettings(); if (env) { env.first = this.firstRow; } return env; } public draw(data:any, options:any):void { var count = this.pageSize; options.firstRowNumber = this.firstRow + 1; options.page = 'event'; if (this.totalRows < 0) { // We don't know where the end is, so we should have 'next' button. options.pagingButtonsConfiguration = this.firstRow > 0 ? 'both' : 'next'; } else { count = this.totalRows - this.firstRow; if (count > this.pageSize) { count = this.pageSize; } if (this.firstRow + count < this.totalRows) { // We are not on last page, so we should have 'next' button. options.pagingButtonsConfiguration = this.firstRow > 0 ? 'both' : 'next'; } else { // We are on last page if (this.firstRow == 0) { options.pagingButtonsConfiguration = 'none'; options.page = 'disable'; } else { options.pagingButtonsConfiguration = 'prev'; } } } super.draw(data, options); } // Handle page forward/back events. Page will only be 0 or 1. handlePageEvent(page:number):void { var offset = (page == 0) ? -1 : 1; this.firstRow += offset * this.pageSize; this.refreshData.first = this.firstRow; this.refreshData.count = this.pageSize; this.refresh(true); } } function convertListToDataTable(data:any):any { if (!data || !data.length) { return {cols: [], rows: []}; } var firstItem = data[0]; var names = Object.keys(firstItem); var columns = names.map(function (name) { return {id: name, label: name, type: typeof firstItem[name]} }); var rows = data.map(function (item:any) { var cells = names.map(function (name) { return {v: item[name]}; }); return {c: cells}; }); return {cols: columns, rows: rows}; } // The main render method, called from render() wrapper below. dom is the DOM element // for the chart, model is a set of parameters from Python, and options is a JSON // set of options provided by the user in the cell magic body, which takes precedence over // model. An initial set of data can be passed in as a final optional parameter. function _render(driver:ChartLibraryDriver, dom:HTMLElement, chartStyle:string, controlIds:Array, data:any, options:any, refreshData:any, refreshInterval:number, totalRows:number):void { require(["base/js/namespace"], function(Jupyter: any) { var url = "datalab/"; require(driver.requires(url, chartStyle), function (/* ... */) { // chart module should be last dependency in require() call... var chartModule = arguments[arguments.length - 1]; // See if it needs to be a member. driver.init(chartModule); options = options || {}; var chart:Chart; if (chartStyle == 'paged_table') { chart = new PagedTable(driver, dom, controlIds, options, refreshData, refreshInterval, totalRows); } else { chart = new Chart(driver, dom, controlIds, options, refreshData, refreshInterval, totalRows); } Chart.convertDates(data); chart.draw(data, options); // Do we need to do anything to prevent it getting GCed? 
}); }); } export function render(driverName:string, dom:HTMLElement, events:any, chartStyle:string, controlIds:Array, data:any, options:any, refreshData:any, refreshInterval:number, totalRows:number):void { // If this is HTML from nbconvert we can't support paging so add some text making this clear. if (chartStyle == 'paged_table' && document.hasOwnProperty('_in_nbconverted')) { chartStyle = 'table'; var p = document.createElement("div"); p.innerHTML = '
(Truncated to first page of results)'; dom.parentNode.insertBefore(p, dom.nextSibling); } // Allocate an appropriate driver. var driver:ChartLibraryDriver; if (driverName == 'plotly') { driver = new PlotlyDriver(dom, chartStyle); } else if (driverName == 'gcharts') { driver = new GChartsDriver(dom, chartStyle); } else { throw new Error('Unsupported chart driver ' + driverName); } // Get data in form needed for GCharts. // We shouldn't need this; should be handled by caller. if (!data.cols && !data.rows) { data = this.convertListToDataTable(data); } // If we have a datalab session, we can go ahead and draw the chart; if not, add code to do the // drawing to an event handler for when the kernel is ready. if (IPython.notebook.kernel.is_connected()) { _render(driver, dom, chartStyle, controlIds, data, options, refreshData, refreshInterval, totalRows) } else { // If the kernel is not connected, wait for the event. events.on('kernel_ready.Kernel', function (e:any) { _render(driver, dom, chartStyle, controlIds, data, options, refreshData, refreshInterval, totalRows) }); } } } export = Charting; ================================================ FILE: datalab/notebook/static/element.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module Element { // RequireJS plugin to resolve DOM elements. 'use strict'; var pendingCallbacks: any = null; function resolve(cbInfo: any): void { cbInfo.cb(document.getElementById(cbInfo.name)); } function domReadyCallback(): void { if (pendingCallbacks) { // Clear out pendingCallbacks, so any future requests are immediately resolved. 
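// Take a local copy and reset the shared pendingCallbacks list before resolving, so that a
// callback which issues another load() request now that the DOM is ready is resolved
// immediately rather than being appended to the batch currently being drained.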
var callbacks = pendingCallbacks; pendingCallbacks = null; callbacks.forEach(resolve); } } export function load(name: any, req: any, loadCallback: any, config: any): void { if (config.isBuild) { loadCallback(null); } else { var cbInfo = { name: name, cb: loadCallback }; if (document.readyState == 'loading') { if (!pendingCallbacks) { pendingCallbacks = []; document.addEventListener('DOMContentLoaded', domReadyCallback, false); } pendingCallbacks.push(cbInfo); } else { resolve(cbInfo); } } } } export = Element; ================================================ FILE: datalab/notebook/static/extern/d3.parcoords.css ================================================ .parcoords > svg, .parcoords > canvas { /*font: 14px sans-serif;*/ position: absolute; } .parcoords > canvas { pointer-events: none; } .parcoords rect.background { fill: transparent; } .parcoords rect.background:hover { fill: rgba(120,120,120,0.2); } .parcoords .resize rect { fill: rgba(0,0,0,0.1); } .parcoords rect.extent { fill: rgba(255,255,255,0.25); stroke: rgba(0,0,0,0.6); } .parcoords .axis line, .parcoords .axis path { fill: none; stroke: #222; shape-rendering: crispEdges; } .parcoords canvas { opacity: 1; -moz-transition: opacity 0.3s; -webkit-transition: opacity 0.3s; -o-transition: opacity 0.3s; } .parcoords canvas.faded { opacity: 0.25; } .parcoords_grid { text-align: center; } .parcoords_grid .row, .header { clear: left; font-size: 16px; line-height: 18px; height: 18px; } .parcoords_grid .row:nth-child(odd) { background: rgba(0,0,0,0.05); } .parcoords_grid .row:hover { background: green; } .parcoords_grid .header { font-weight: bold; } .parcoords_grid .cell { float: left; overflow: hidden; white-space: nowrap; width: 120px; height: 18px; } .parcoords_grid .col-0 { width: 110px; } ================================================ FILE: datalab/notebook/static/extern/d3.parcoords.js ================================================ d3.parcoords = function(config) { var __ = { data: [], highlighted: [], dimensions: [], dimensionTitles: {}, dimensionTitleRotation: 0, types: {}, brushed: false, mode: "default", rate: 20, width: 600, height: 300, margin: { top: 30, right: 0, bottom: 12, left: 0 }, color: "#069", composite: "source-over", alpha: 0.7, bundlingStrength: 0.5, bundleDimension: null, smoothness: 0.25, showControlPoints: false, hideAxis : [] }; extend(__, config); var pc = function(selection) { selection = pc.selection = d3.select(selection); __.width = selection[0][0].clientWidth; __.height = selection[0][0].clientHeight; // canvas data layers ["shadows", "marks", "foreground", "highlight"].forEach(function(layer) { canvas[layer] = selection .append("canvas") .attr("class", layer)[0][0]; ctx[layer] = canvas[layer].getContext("2d"); }); // svg tick and brush layers pc.svg = selection .append("svg") .attr("width", __.width) .attr("height", __.height) .append("svg:g") .attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); return pc; }; var events = d3.dispatch.apply(this,["render", "resize", "highlight", "brush", "brushend", "axesreorder"].concat(d3.keys(__))), w = function() { return __.width - __.margin.right - __.margin.left; }, h = function() { return __.height - __.margin.top - __.margin.bottom; }, flags = { brushable: false, reorderable: false, axes: false, interactive: false, shadows: false, debug: false }, xscale = d3.scale.ordinal(), yscale = {}, dragging = {}, line = d3.svg.line(), axis = d3.svg.axis().orient("left").ticks(5), g, // groups for axes, brushes ctx = {}, canvas = {}, 
clusterCentroids = []; // side effects for setters var side_effects = d3.dispatch.apply(this,d3.keys(__)) .on("composite", function(d) { ctx.foreground.globalCompositeOperation = d.value; }) .on("alpha", function(d) { ctx.foreground.globalAlpha = d.value; }) .on("width", function(d) { pc.resize(); }) .on("height", function(d) { pc.resize(); }) .on("margin", function(d) { pc.resize(); }) .on("rate", function(d) { rqueue.rate(d.value); }) .on("data", function(d) { if (flags.shadows){paths(__.data, ctx.shadows);} }) .on("dimensions", function(d) { xscale.domain(__.dimensions); if (flags.interactive){pc.render().updateAxes();} }) .on("bundleDimension", function(d) { if (!__.dimensions.length) pc.detectDimensions(); if (!(__.dimensions[0] in yscale)) pc.autoscale(); if (typeof d.value === "number") { if (d.value < __.dimensions.length) { __.bundleDimension = __.dimensions[d.value]; } else if (d.value < __.hideAxis.length) { __.bundleDimension = __.hideAxis[d.value]; } } else { __.bundleDimension = d.value; } __.clusterCentroids = compute_cluster_centroids(__.bundleDimension); }) .on("hideAxis", function(d) { if (!__.dimensions.length) pc.detectDimensions(); pc.dimensions(without(__.dimensions, d.value)); }); // expose the state of the chart pc.state = __; pc.flags = flags; // create getter/setters getset(pc, __, events); // expose events d3.rebind(pc, events, "on"); // tick formatting d3.rebind(pc, axis, "ticks", "orient", "tickValues", "tickSubdivide", "tickSize", "tickPadding", "tickFormat"); // getter/setter with event firing function getset(obj,state,events) { d3.keys(state).forEach(function(key) { obj[key] = function(x) { if (!arguments.length) { return state[key]; } var old = state[key]; state[key] = x; side_effects[key].call(pc,{"value": x, "previous": old}); events[key].call(pc,{"value": x, "previous": old}); return obj; }; }); }; function extend(target, source) { for (key in source) { target[key] = source[key]; } return target; }; function without(arr, item) { return arr.filter(function(elem) { return item.indexOf(elem) === -1; }) }; pc.autoscale = function() { // yscale var defaultScales = { "date": function(k) { return d3.time.scale() .domain(d3.extent(__.data, function(d) { return d[k] ? d[k].getTime() : null; })) .range([h()+1, 1]); }, "number": function(k) { return d3.scale.linear() .domain(d3.extent(__.data, function(d) { return +d[k]; })) .range([h()+1, 1]); }, "string": function(k) { var counts = {}, domain = []; // Let's get the count for each value so that we can sort the domain based // on the number of items for each value. 
__.data.map(function(p) { if (counts[p[k]] === undefined) { counts[p[k]] = 1; } else { counts[p[k]] = counts[p[k]] + 1; } }); domain = Object.getOwnPropertyNames(counts).sort(function(a, b) { return counts[a] - counts[b]; }); return d3.scale.ordinal() .domain(domain) .rangePoints([h()+1, 1]); } }; __.dimensions.forEach(function(k) { yscale[k] = defaultScales[__.types[k]](k); }); __.hideAxis.forEach(function(k) { yscale[k] = defaultScales[__.types[k]](k); }); // hack to remove ordinal dimensions with many values pc.dimensions(pc.dimensions().filter(function(p,i) { var uniques = yscale[p].domain().length; if (__.types[p] == "string" && (uniques > 60 || uniques < 2)) { return false; } return true; })); // xscale xscale.rangePoints([0, w()], 1); // canvas sizes pc.selection.selectAll("canvas") .style("margin-top", __.margin.top + "px") .style("margin-left", __.margin.left + "px") .attr("width", w()+2) .attr("height", h()+2); // default styles, needs to be set when canvas width changes ctx.foreground.strokeStyle = __.color; ctx.foreground.lineWidth = 1.4; ctx.foreground.globalCompositeOperation = __.composite; ctx.foreground.globalAlpha = __.alpha; ctx.highlight.lineWidth = 3; ctx.shadows.strokeStyle = "#dadada"; return this; }; pc.scale = function(d, domain) { yscale[d].domain(domain); return this; }; pc.flip = function(d) { //yscale[d].domain().reverse(); // does not work yscale[d].domain(yscale[d].domain().reverse()); // works return this; }; pc.commonScale = function(global, type) { var t = type || "number"; if (typeof global === 'undefined') { global = true; } // scales of the same type var scales = __.dimensions.concat(__.hideAxis).filter(function(p) { return __.types[p] == t; }); if (global) { var extent = d3.extent(scales.map(function(p,i) { return yscale[p].domain(); }).reduce(function(a,b) { return a.concat(b); })); scales.forEach(function(d) { yscale[d].domain(extent); }); } else { scales.forEach(function(k) { yscale[k].domain(d3.extent(__.data, function(d) { return +d[k]; })); }); } // update centroids if (__.bundleDimension !== null) { pc.bundleDimension(__.bundleDimension); } return this; };pc.detectDimensions = function() { pc.types(pc.detectDimensionTypes(__.data)); pc.dimensions(d3.keys(pc.types())); return this; }; // a better "typeof" from this post: http://stackoverflow.com/questions/7390426/better-way-to-get-type-of-a-javascript-variable pc.toType = function(v) { return ({}).toString.call(v).match(/\s([a-zA-Z]+)/)[1].toLowerCase(); }; // try to coerce to number before returning type pc.toTypeCoerceNumbers = function(v) { if ((parseFloat(v) == v) && (v != null)) { return "number"; } return pc.toType(v); }; // attempt to determine types of each dimension based on first row of data pc.detectDimensionTypes = function(data) { var types = {}; d3.keys(data[0]) .forEach(function(col) { types[col] = pc.toTypeCoerceNumbers(data[0][col]); }); return types; }; pc.render = function() { // try to autodetect dimensions and create scales if (!__.dimensions.length) pc.detectDimensions(); if (!(__.dimensions[0] in yscale)) pc.autoscale(); pc.render[__.mode](); events.render.call(this); return this; }; pc.render['default'] = function() { pc.clear('foreground'); if (__.brushed) { __.brushed.forEach(path_foreground); __.highlighted.forEach(path_highlight); } else { __.data.forEach(path_foreground); __.highlighted.forEach(path_highlight); } }; var rqueue = d3.renderQueue(path_foreground) .rate(50) .clear(function() { pc.clear('foreground'); pc.clear('highlight'); }); pc.render.queue = function() 
{ if (__.brushed) { rqueue(__.brushed); __.highlighted.forEach(path_highlight); } else { rqueue(__.data); __.highlighted.forEach(path_highlight); } }; function compute_cluster_centroids(d) { var clusterCentroids = d3.map(); var clusterCounts = d3.map(); // determine clusterCounts __.data.forEach(function(row) { var scaled = yscale[d](row[d]); if (!clusterCounts.has(scaled)) { clusterCounts.set(scaled, 0); } var count = clusterCounts.get(scaled); clusterCounts.set(scaled, count + 1); }); __.data.forEach(function(row) { __.dimensions.map(function(p, i) { var scaled = yscale[d](row[d]); if (!clusterCentroids.has(scaled)) { var map = d3.map(); clusterCentroids.set(scaled, map); } if (!clusterCentroids.get(scaled).has(p)) { clusterCentroids.get(scaled).set(p, 0); } var value = clusterCentroids.get(scaled).get(p); value += yscale[p](row[p]) / clusterCounts.get(scaled); clusterCentroids.get(scaled).set(p, value); }); }); return clusterCentroids; } function compute_centroids(row) { var centroids = []; var p = __.dimensions; var cols = p.length; var a = 0.5; // center between axes for (var i = 0; i < cols; ++i) { // centroids on 'real' axes var x = position(p[i]); var y = yscale[p[i]](row[p[i]]); centroids.push([x, y]); //centroids.push($V([x, y])); // centroids on 'virtual' axes if (i < cols - 1) { var cx = x + a * (position(p[i+1]) - x); var cy = y + a * (yscale[p[i+1]](row[p[i+1]]) - y); if (__.bundleDimension !== null) { var leftCentroid = __.clusterCentroids.get(yscale[__.bundleDimension](row[__.bundleDimension])).get(p[i]); var rightCentroid = __.clusterCentroids.get(yscale[__.bundleDimension](row[__.bundleDimension])).get(p[i+1]); var centroid = 0.5 * (leftCentroid + rightCentroid); cy = centroid + (1 - __.bundlingStrength) * (cy - centroid); } centroids.push([cx, cy]); //centroids.push($V([cx, cy])); } } return centroids; } pc.compute_centroids = compute_centroids; function compute_control_points(centroids) { var cols = centroids.length; var a = __.smoothness; var cps = []; cps.push(centroids[0]); cps.push($V([centroids[0].e(1) + a*2*(centroids[1].e(1)-centroids[0].e(1)), centroids[0].e(2)])); for (var col = 1; col < cols - 1; ++col) { var mid = centroids[col]; var left = centroids[col - 1]; var right = centroids[col + 1]; var diff = left.subtract(right); cps.push(mid.add(diff.x(a))); cps.push(mid); cps.push(mid.subtract(diff.x(a))); } cps.push($V([centroids[cols-1].e(1) + a*2*(centroids[cols-2].e(1)-centroids[cols-1].e(1)), centroids[cols-1].e(2)])); cps.push(centroids[cols - 1]); return cps; };pc.shadows = function() { flags.shadows = true; if (__.data.length > 0) { paths(__.data, ctx.shadows); } return this; }; // draw little dots on the axis line where data intersects pc.axisDots = function() { var ctx = pc.ctx.marks; ctx.globalAlpha = d3.min([ 1 / Math.pow(data.length, 1 / 2), 1 ]); __.data.forEach(function(d) { __.dimensions.map(function(p, i) { ctx.fillRect(position(p) - 0.75, yscale[p](d[p]) - 0.75, 1.5, 1.5); }); }); return this; }; // draw single cubic bezier curve function single_curve(d, ctx) { var centroids = compute_centroids(d); var cps = compute_control_points(centroids); ctx.moveTo(cps[0].e(1), cps[0].e(2)); for (var i = 1; i < cps.length; i += 3) { if (__.showControlPoints) { for (var j = 0; j < 3; j++) { ctx.fillRect(cps[i+j].e(1), cps[i+j].e(2), 2, 2); } } ctx.bezierCurveTo(cps[i].e(1), cps[i].e(2), cps[i+1].e(1), cps[i+1].e(2), cps[i+2].e(1), cps[i+2].e(2)); } }; // draw single polyline function color_path(d, i, ctx) { ctx.strokeStyle = d3.functor(__.color)(d, i); 
ctx.beginPath(); if (__.bundleDimension === null || (__.bundlingStrength === 0 && __.smoothness == 0)) { single_path(d, ctx); } else { single_curve(d, ctx); } ctx.stroke(); }; // draw many polylines of the same color function paths(data, ctx) { ctx.clearRect(-1, -1, w() + 2, h() + 2); ctx.beginPath(); data.forEach(function(d) { if (__.bundleDimension === null || (__.bundlingStrength === 0 && __.smoothness == 0)) { single_path(d, ctx); } else { single_curve(d, ctx); } }); ctx.stroke(); }; function single_path(d, ctx) { __.dimensions.map(function(p, i) { if (i == 0) { ctx.moveTo(position(p), yscale[p](d[p])); } else { ctx.lineTo(position(p), yscale[p](d[p])); } }); } function path_foreground(d, i) { return color_path(d, i, ctx.foreground); }; function path_highlight(d, i) { return color_path(d, i, ctx.highlight); }; pc.clear = function(layer) { ctx[layer].clearRect(0,0,w()+2,h()+2); return this; }; function flipAxisAndUpdatePCP(dimension, i) { var g = pc.svg.selectAll(".dimension"); pc.flip(dimension); d3.select(g[0][i]) .transition() .duration(1100) .call(axis.scale(yscale[dimension])); pc.render(); if (flags.shadows) paths(__.data, ctx.shadows); } function rotateLabels() { var delta = d3.event.deltaY; delta = delta < 0 ? -5 : delta; delta = delta > 0 ? 5 : delta; __.dimensionTitleRotation += delta; pc.svg.selectAll("text.label") .attr("transform", "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")"); d3.event.preventDefault(); } pc.createAxes = function() { if (g) pc.removeAxes(); // Add a group element for each dimension. g = pc.svg.selectAll(".dimension") .data(__.dimensions, function(d) { return d; }) .enter().append("svg:g") .attr("class", "dimension") .attr("transform", function(d) { return "translate(" + xscale(d) + ")"; }); // Add an axis and title. g.append("svg:g") .attr("class", "axis") .attr("transform", "translate(0,0)") .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }) .append("svg:text") .attr({ "text-anchor": "middle", "y": 0, "transform": "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")", "x": 0, "class": "label" }) .text(function(d) { return d in __.dimensionTitles ? 
__.dimensionTitles[d] : d; // dimension display names }) .on("dblclick", flipAxisAndUpdatePCP) .on("wheel", rotateLabels); flags.axes= true; return this; }; pc.removeAxes = function() { g.remove(); return this; }; pc.updateAxes = function() { var g_data = pc.svg.selectAll(".dimension").data(__.dimensions); // Enter g_data.enter().append("svg:g") .attr("class", "dimension") .attr("transform", function(p) { return "translate(" + position(p) + ")"; }) .style("opacity", 0) .append("svg:g") .attr("class", "axis") .attr("transform", "translate(0,0)") .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }) .append("svg:text") .attr({ "text-anchor": "middle", "y": 0, "transform": "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")", "x": 0, "class": "label" }) .text(String) .on("dblclick", flipAxisAndUpdatePCP) .on("wheel", rotateLabels); // Update g_data.attr("opacity", 0); g_data.select(".axis") .transition() .duration(1100) .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }); g_data.select(".label") .transition() .duration(1100) .text(String) .attr("transform", "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")"); // Exit g_data.exit().remove(); g = pc.svg.selectAll(".dimension"); g.transition().duration(1100) .attr("transform", function(p) { return "translate(" + position(p) + ")"; }) .style("opacity", 1); pc.svg.selectAll(".axis") .transition() .duration(1100) .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }); if (flags.shadows) paths(__.data, ctx.shadows); if (flags.brushable) pc.brushable(); if (flags.reorderable) pc.reorderable(); if (pc.brushMode() !== "None") { var mode = pc.brushMode(); pc.brushMode("None"); pc.brushMode(mode); } return this; }; // Jason Davies, http://bl.ocks.org/1341281 pc.reorderable = function() { if (!g) pc.createAxes(); // Keep track of the order of the axes to verify if the order has actually // changed after a drag ends. Changed order might have consequence (e.g. // strums that need to be reset). var dimsAtDragstart; g.style("cursor", "move") .call(d3.behavior.drag() .on("dragstart", function(d) { dragging[d] = this.__origin__ = xscale(d); dimsAtDragstart = __.dimensions.slice(); }) .on("drag", function(d) { dragging[d] = Math.min(w(), Math.max(0, this.__origin__ += d3.event.dx)); __.dimensions.sort(function(a, b) { return position(a) - position(b); }); xscale.domain(__.dimensions); pc.render(); g.attr("transform", function(d) { return "translate(" + position(d) + ")"; }); }) .on("dragend", function(d, i) { // Let's see if the order has changed and send out an event if so. var j = __.dimensions.indexOf(d), parent = this.parentElement; if (i !== j) { events.axesreorder.call(pc, __.dimensions); // We now also want to reorder the actual dom elements that represent // the axes. That is, the g.dimension elements. If we don't do this, // we get a weird and confusing transition when updateAxes is called. // This is due to the fact that, initially the nth g.dimension element // represents the nth axis. However, after a manual reordering, // without reordering the dom elements, the nth dom elements no longer // necessarily represents the nth axis. 
// // i is the original index of the dom element // j is the new index of the dom element parent.insertBefore(this, parent.children[j + 1]) } delete this.__origin__; delete dragging[d]; d3.select(this).transition().attr("transform", "translate(" + xscale(d) + ")"); pc.render(); if (flags.shadows) paths(__.data, ctx.shadows); })); flags.reorderable = true; return this; }; // pairs of adjacent dimensions pc.adjacent_pairs = function(arr) { var ret = []; for (var i = 0; i < arr.length-1; i++) { ret.push([arr[i],arr[i+1]]); }; return ret; }; var brush = { modes: { "None": { install: function(pc) {}, // Nothing to be done. uninstall: function(pc) {}, // Nothing to be done. selected: function() { return []; } // Nothing to return } }, mode: "None", predicate: "AND", currentMode: function() { return this.modes[this.mode]; } }; // This function can be used for 'live' updates of brushes. That is, during the // specification of a brush, this method can be called to update the view. // // @param newSelection - The new set of data items that is currently contained // by the brushes function brushUpdated(newSelection) { __.brushed = newSelection; events.brush.call(pc,__.brushed); pc.render(); } function brushPredicate(predicate) { if (!arguments.length) { return brush.predicate; } predicate = String(predicate).toUpperCase(); if (predicate !== "AND" && predicate !== "OR") { throw "Invalid predicate " + predicate; } brush.predicate = predicate; __.brushed = brush.currentMode().selected(); pc.render(); return pc; } pc.brushModes = function() { return Object.getOwnPropertyNames(brush.modes); }; pc.brushMode = function(mode) { if (arguments.length === 0) { return brush.mode; } if (pc.brushModes().indexOf(mode) === -1) { throw "pc.brushmode: Unsupported brush mode: " + mode; } // Make sure that we don't trigger unnecessary events by checking if the mode // actually changes. if (mode !== brush.mode) { // When changing brush modes, the first thing we need to do is clearing any // brushes from the current mode, if any. if (brush.mode !== "None") { pc.brushReset(); } // Next, we need to 'uninstall' the current brushMode. brush.modes[brush.mode].uninstall(pc); // Finally, we can install the requested one. brush.mode = mode; brush.modes[brush.mode].install(); if (mode === "None") { delete pc.brushPredicate; } else { pc.brushPredicate = brushPredicate; } } return pc; }; // brush mode: 1D-Axes (function() { var brushes = {}; function is_brushed(p) { return !brushes[p].empty(); } // data within extents function selected() { var actives = __.dimensions.filter(is_brushed), extents = actives.map(function(p) { return brushes[p].extent(); }); // We don't want to return the full data set when there are no axes brushed. // Actually, when there are no axes brushed, by definition, no items are // selected. So, let's avoid the filtering and just return false. //if (actives.length === 0) return false; // Resolves broken examples for now. 
They expect to get the full dataset back from empty brushes if (actives.length === 0) return __.data; // test if within range var within = { "date": function(d,p,dimension) { return extents[dimension][0] <= d[p] && d[p] <= extents[dimension][1] }, "number": function(d,p,dimension) { return extents[dimension][0] <= d[p] && d[p] <= extents[dimension][1] }, "string": function(d,p,dimension) { return extents[dimension][0] <= yscale[p](d[p]) && yscale[p](d[p]) <= extents[dimension][1] } }; return __.data .filter(function(d) { switch(brush.predicate) { case "AND": return actives.every(function(p, dimension) { return within[__.types[p]](d,p,dimension); }); case "OR": return actives.some(function(p, dimension) { return within[__.types[p]](d,p,dimension); }); default: throw "Unknown brush predicate " + __.brushPredicate; } }); }; function brushExtents() { var extents = {}; __.dimensions.forEach(function(d) { var brush = brushes[d]; if (!brush.empty()) { var extent = brush.extent(); extent.sort(d3.ascending); extents[d] = extent; } }); return extents; } function brushFor(axis) { var brush = d3.svg.brush(); brush .y(yscale[axis]) .on("brushstart", function() { d3.event.sourceEvent.stopPropagation() }) .on("brush", function() { brushUpdated(selected()); }) .on("brushend", function() { events.brushend.call(pc, __.brushed); }); brushes[axis] = brush; return brush; } function brushReset(dimension) { __.brushed = false; if (g) { g.selectAll('.brush') .each(function(d) { d3.select(this).call( brushes[d].clear() ); }); pc.render(); } return this; }; function install() { if (!g) pc.createAxes(); // Add and store a brush for each axis. g.append("svg:g") .attr("class", "brush") .each(function(d) { d3.select(this).call(brushFor(d)); }) .selectAll("rect") .style("visibility", null) .attr("x", -15) .attr("width", 30); pc.brushExtents = brushExtents; pc.brushReset = brushReset; return pc; } brush.modes["1D-axes"] = { install: install, uninstall: function() { g.selectAll(".brush").remove(); brushes = {}; delete pc.brushExtents; delete pc.brushReset; }, selected: selected } })(); // brush mode: 2D-strums // bl.ocks.org/syntagmatic/5441022 (function() { var strums = {}, strumRect; function drawStrum(strum, activePoint) { var svg = pc.selection.select("svg").select("g#strums"), id = strum.dims.i, points = [strum.p1, strum.p2], line = svg.selectAll("line#strum-" + id).data([strum]), circles = svg.selectAll("circle#strum-" + id).data(points), drag = d3.behavior.drag(); line.enter() .append("line") .attr("id", "strum-" + id) .attr("class", "strum"); line .attr("x1", function(d) { return d.p1[0]; }) .attr("y1", function(d) { return d.p1[1]; }) .attr("x2", function(d) { return d.p2[0]; }) .attr("y2", function(d) { return d.p2[1]; }) .attr("stroke", "black") .attr("stroke-width", 2); drag .on("drag", function(d, i) { var ev = d3.event; i = i + 1; strum["p" + i][0] = Math.min(Math.max(strum.minX + 1, ev.x), strum.maxX); strum["p" + i][1] = Math.min(Math.max(strum.minY, ev.y), strum.maxY); drawStrum(strum, i - 1); }) .on("dragend", onDragEnd()); circles.enter() .append("circle") .attr("id", "strum-" + id) .attr("class", "strum"); circles .attr("cx", function(d) { return d[0]; }) .attr("cy", function(d) { return d[1]; }) .attr("r", 5) .style("opacity", function(d, i) { return (activePoint !== undefined && i === activePoint) ? 
0.8 : 0; }) .on("mouseover", function() { d3.select(this).style("opacity", 0.8); }) .on("mouseout", function() { d3.select(this).style("opacity", 0); }) .call(drag); } function dimensionsForPoint(p) { var dims = { i: -1, left: undefined, right: undefined }; __.dimensions.some(function(dim, i) { if (xscale(dim) < p[0]) { var next = __.dimensions[i + 1]; dims.i = i; dims.left = dim; dims.right = next; return false; } return true; }); if (dims.left === undefined) { // Event on the left side of the first axis. dims.i = 0; dims.left = __.dimensions[0]; dims.right = __.dimensions[1]; } else if (dims.right === undefined) { // Event on the right side of the last axis dims.i = __.dimensions.length - 1; dims.right = dims.left; dims.left = __.dimensions[__.dimensions.length - 2]; } return dims; } function onDragStart() { // First we need to determine between which two axes the sturm was started. // This will determine the freedom of movement, because a strum can // logically only happen between two axes, so no movement outside these axes // should be allowed. return function() { var p = d3.mouse(strumRect[0][0]), dims = dimensionsForPoint(p), strum = { p1: p, dims: dims, minX: xscale(dims.left), maxX: xscale(dims.right), minY: 0, maxY: h() }; strums[dims.i] = strum; strums.active = dims.i; // Make sure that the point is within the bounds strum.p1[0] = Math.min(Math.max(strum.minX, p[0]), strum.maxX); strum.p1[1] = p[1] - __.margin.top; strum.p2 = strum.p1.slice(); }; } function onDrag() { return function() { var ev = d3.event, strum = strums[strums.active]; // Make sure that the point is within the bounds strum.p2[0] = Math.min(Math.max(strum.minX + 1, ev.x), strum.maxX); strum.p2[1] = Math.min(Math.max(strum.minY, ev.y - __.margin.top), strum.maxY); drawStrum(strum, 1); }; } function containmentTest(strum, width) { var p1 = [strum.p1[0] - strum.minX, strum.p1[1] - strum.minX], p2 = [strum.p2[0] - strum.minX, strum.p2[1] - strum.minX], m1 = 1 - width / p1[0], b1 = p1[1] * (1 - m1), m2 = 1 - width / p2[0], b2 = p2[1] * (1 - m2); // test if point falls between lines return function(p) { var x = p[0], y = p[1], y1 = m1 * x + b1, y2 = m2 * x + b2; if (y > Math.min(y1, y2) && y < Math.max(y1, y2)) { return true; } return false; }; } function selected() { var ids = Object.getOwnPropertyNames(strums), brushed = __.data; // Get the ids of the currently active strums. ids = ids.filter(function(d) { return !isNaN(d); }); function crossesStrum(d, id) { var strum = strums[id], test = containmentTest(strum, strums.width(id)), d1 = strum.dims.left, d2 = strum.dims.right, y1 = yscale[d1], y2 = yscale[d2], point = [y1(d[d1]) - strum.minX, y2(d[d2]) - strum.minX]; return test(point); } if (ids.length === 0) { return brushed; } return brushed.filter(function(d) { switch(brush.predicate) { case "AND": return ids.every(function(id) { return crossesStrum(d, id); }); case "OR": return ids.some(function(id) { return crossesStrum(d, id); }); default: throw "Unknown brush predicate " + __.brushPredicate; } }); } function removeStrum() { var strum = strums[strums.active], svg = pc.selection.select("svg").select("g#strums"); delete strums[strums.active]; strums.active = undefined; svg.selectAll("line#strum-" + strum.dims.i).remove(); svg.selectAll("circle#strum-" + strum.dims.i).remove(); } function onDragEnd() { return function() { var brushed = __.data, strum = strums[strums.active]; // Okay, somewhat unexpected, but not totally unsurprising, a mousclick is // considered a drag without move. 
So we have to deal with that case if (strum && strum.p1[0] === strum.p2[0] && strum.p1[1] === strum.p2[1]) { removeStrum(strums); } brushed = selected(strums); strums.active = undefined; __.brushed = brushed; pc.render(); events.brushend.call(pc, __.brushed); }; } function brushReset(strums) { return function() { var ids = Object.getOwnPropertyNames(strums).filter(function(d) { return !isNaN(d); }); ids.forEach(function(d) { strums.active = d; removeStrum(strums); }); onDragEnd(strums)(); }; } function install() { var drag = d3.behavior.drag(); // Map of current strums. Strums are stored per segment of the PC. A segment, // being the area between two axes. The left most area is indexed at 0. strums.active = undefined; // Returns the width of the PC segment where currently a strum is being // placed. NOTE: even though they are evenly spaced in our current // implementation, we keep for when non-even spaced segments are supported as // well. strums.width = function(id) { var strum = strums[id]; if (strum === undefined) { return undefined; } return strum.maxX - strum.minX; }; pc.on("axesreorder.strums", function() { var ids = Object.getOwnPropertyNames(strums).filter(function(d) { return !isNaN(d); }); // Checks if the first dimension is directly left of the second dimension. function consecutive(first, second) { var length = __.dimensions.length; return __.dimensions.some(function(d, i) { return (d === first) ? i + i < length && __.dimensions[i + 1] === second : false; }); } if (ids.length > 0) { // We have some strums, which might need to be removed. ids.forEach(function(d) { var dims = strums[d].dims; strums.active = d; // If the two dimensions of the current strum are not next to each other // any more, than we'll need to remove the strum. Otherwise we keep it. if (!consecutive(dims.left, dims.right)) { removeStrum(strums); } }); onDragEnd(strums)(); } }); // Add a new svg group in which we draw the strums. pc.selection.select("svg").append("g") .attr("id", "strums") .attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); // Install the required brushReset function pc.brushReset = brushReset(strums); drag .on("dragstart", onDragStart(strums)) .on("drag", onDrag(strums)) .on("dragend", onDragEnd(strums)); // NOTE: The styling needs to be done here and not in the css. This is because // for 1D brushing, the canvas layers should not listen to // pointer-events. 
strumRect = pc.selection.select("svg").insert("rect", "g#strums") .attr("id", "strum-events") .attr("x", __.margin.left) .attr("y", __.margin.top) .attr("width", w()) .attr("height", h() + 2) .style("opacity", 0) .call(drag); } brush.modes["2D-strums"] = { install: install, uninstall: function() { pc.selection.select("svg").select("g#strums").remove(); pc.selection.select("svg").select("rect#strum-events").remove(); pc.on("axesreorder.strums", undefined); delete pc.brushReset; strumRect = undefined; }, selected: selected }; }()); pc.interactive = function() { flags.interactive = true; return this; }; // expose a few objects pc.xscale = xscale; pc.yscale = yscale; pc.ctx = ctx; pc.canvas = canvas; pc.g = function() { return g; }; // rescale for height, width and margins // TODO currently assumes chart is brushable, and destroys old brushes pc.resize = function() { // selection size pc.selection.select("svg") .attr("width", __.width) .attr("height", __.height) pc.svg.attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); // FIXME: the current brush state should pass through if (flags.brushable) pc.brushReset(); // scales pc.autoscale(); // axes, destroys old brushes. if (g) pc.createAxes(); if (flags.shadows) paths(__.data, ctx.shadows); if (flags.brushable) pc.brushable(); if (flags.reorderable) pc.reorderable(); events.resize.call(this, {width: __.width, height: __.height, margin: __.margin}); return this; }; // highlight an array of data pc.highlight = function(data) { if (arguments.length === 0) { return __.highlighted; } __.highlighted = data; pc.clear("highlight"); d3.select(canvas.foreground).classed("faded", true); data.forEach(path_highlight); events.highlight.call(this, data); return this; }; // clear highlighting pc.unhighlight = function() { __.highlighted = []; pc.clear("highlight"); d3.select(canvas.foreground).classed("faded", false); return this; }; // calculate 2d intersection of line a->b with line c->d // points are objects with x and y properties pc.intersection = function(a, b, c, d) { return { x: ((a.x * b.y - a.y * b.x) * (c.x - d.x) - (a.x - b.x) * (c.x * d.y - c.y * d.x)) / ((a.x - b.x) * (c.y - d.y) - (a.y - b.y) * (c.x - d.x)), y: ((a.x * b.y - a.y * b.x) * (c.y - d.y) - (a.y - b.y) * (c.x * d.y - c.y * d.x)) / ((a.x - b.x) * (c.y - d.y) - (a.y - b.y) * (c.x - d.x)) }; }; function position(d) { var v = dragging[d]; return v == null ? xscale(d) : v; } pc.version = "0.5.0"; // this descriptive text should live with other introspective methods pc.toString = function() { return "Parallel Coordinates: " + __.dimensions.length + " dimensions (" + d3.keys(__.data[0]).length + " total) , " + __.data.length + " rows"; }; return pc; }; d3.renderQueue = (function(func) { var _queue = [], // data to be rendered _rate = 10, // number of calls per frame _clear = function() {}, // clearing function _i = 0; // current iteration var rq = function(data) { if (data) rq.data(data); rq.invalidate(); _clear(); rq.render(); }; rq.render = function() { _i = 0; var valid = true; rq.invalidate = function() { valid = false; }; function doFrame() { if (!valid) return true; if (_i > _queue.length) return true; // Typical d3 behavior is to pass a data item *and* its index. As the // render queue splits the original data set, we'll have to be slightly // more carefull about passing the correct index with the data item. 
var end = Math.min(_i + _rate, _queue.length); for (var i = _i; i < end; i++) { func(_queue[i], i); } _i += _rate; } d3.timer(doFrame); }; rq.data = function(data) { rq.invalidate(); _queue = data.slice(0); return rq; }; rq.rate = function(value) { if (!arguments.length) return _rate; _rate = value; return rq; }; rq.remaining = function() { return _queue.length - _i; }; // clear the canvas rq.clear = function(func) { if (!arguments.length) { _clear(); return rq; } _clear = func; return rq; }; rq.invalidate = function() {}; return rq; }); d3.divgrid = function(config) { var columns = []; var dg = function(selection) { if (columns.length == 0) { columns = d3.keys(selection.data()[0][0]); columns = columns.filter( function(item) { return (item.substr(item.length - 5) != "(log)"); }); } // header selection.selectAll(".header") .data([true]) .enter().append("div") .attr("class", "header") var header = selection.select(".header") .selectAll(".cell") .data(columns); header.enter().append("div") .attr("class", function(d,i) { return "col-" + i; }) .classed("cell", true) selection.selectAll(".header .cell") .text(function(d) { return d; }); header.exit().remove(); // rows var rows = selection.selectAll(".row") .data(function(d) { return d; }) rows.enter().append("div") .attr("class", "row") rows.exit().remove(); var cells = selection.selectAll(".row").selectAll(".cell") .data(function(d) { return columns.map(function(col){return d[col];}) }) // cells cells.enter().append("div") .attr("class", function(d,i) { return "col-" + i; }) .classed("cell", true) cells.exit().remove(); selection.selectAll(".cell") .text(function(d) { return d; }); return dg; }; dg.columns = function(_) { if (!arguments.length) return columns; columns = _; return this; }; return dg; }; ================================================ FILE: datalab/notebook/static/extern/lantern-browser.html ================================================ ================================================ FILE: datalab/notebook/static/extern/parcoords-LICENSE.txt ================================================ Copyright (c) 2012, Kai Chang All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * The name Kai Chang may not be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: datalab/notebook/static/extern/sylvester-LICENSE.txt ================================================ (The MIT License) Copyright (c) 2007-2015 James Coglan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the 'Software'), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: datalab/notebook/static/extern/sylvester.js ================================================ // === Sylvester === // Vector and Matrix mathematics modules for JavaScript // Copyright (c) 2007 James Coglan // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included // in all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. var Sylvester = { version: '0.1.3', precision: 1e-6 }; function Vector() {} Vector.prototype = { // Returns element i of the vector e: function(i) { return (i < 1 || i > this.elements.length) ? 
null : this.elements[i-1]; }, // Returns the number of elements the vector has dimensions: function() { return this.elements.length; }, // Returns the modulus ('length') of the vector modulus: function() { return Math.sqrt(this.dot(this)); }, // Returns true iff the vector is equal to the argument eql: function(vector) { var n = this.elements.length; var V = vector.elements || vector; if (n != V.length) { return false; } do { if (Math.abs(this.elements[n-1] - V[n-1]) > Sylvester.precision) { return false; } } while (--n); return true; }, // Returns a copy of the vector dup: function() { return Vector.create(this.elements); }, // Maps the vector to another vector according to the given function map: function(fn) { var elements = []; this.each(function(x, i) { elements.push(fn(x, i)); }); return Vector.create(elements); }, // Calls the iterator for each element of the vector in turn each: function(fn) { var n = this.elements.length, k = n, i; do { i = k - n; fn(this.elements[i], i+1); } while (--n); }, // Returns a new vector created by normalizing the receiver toUnitVector: function() { var r = this.modulus(); if (r === 0) { return this.dup(); } return this.map(function(x) { return x/r; }); }, // Returns the angle between the vector and the argument (also a vector) angleFrom: function(vector) { var V = vector.elements || vector; var n = this.elements.length, k = n, i; if (n != V.length) { return null; } var dot = 0, mod1 = 0, mod2 = 0; // Work things out in parallel to save time this.each(function(x, i) { dot += x * V[i-1]; mod1 += x * x; mod2 += V[i-1] * V[i-1]; }); mod1 = Math.sqrt(mod1); mod2 = Math.sqrt(mod2); if (mod1*mod2 === 0) { return null; } var theta = dot / (mod1*mod2); if (theta < -1) { theta = -1; } if (theta > 1) { theta = 1; } return Math.acos(theta); }, // Returns true iff the vector is parallel to the argument isParallelTo: function(vector) { var angle = this.angleFrom(vector); return (angle === null) ? null : (angle <= Sylvester.precision); }, // Returns true iff the vector is antiparallel to the argument isAntiparallelTo: function(vector) { var angle = this.angleFrom(vector); return (angle === null) ? null : (Math.abs(angle - Math.PI) <= Sylvester.precision); }, // Returns true iff the vector is perpendicular to the argument isPerpendicularTo: function(vector) { var dot = this.dot(vector); return (dot === null) ? 
null : (Math.abs(dot) <= Sylvester.precision); }, // Returns the result of adding the argument to the vector add: function(vector) { var V = vector.elements || vector; if (this.elements.length != V.length) { return null; } return this.map(function(x, i) { return x + V[i-1]; }); }, // Returns the result of subtracting the argument from the vector subtract: function(vector) { var V = vector.elements || vector; if (this.elements.length != V.length) { return null; } return this.map(function(x, i) { return x - V[i-1]; }); }, // Returns the result of multiplying the elements of the vector by the argument multiply: function(k) { return this.map(function(x) { return x*k; }); }, x: function(k) { return this.multiply(k); }, // Returns the scalar product of the vector with the argument // Both vectors must have equal dimensionality dot: function(vector) { var V = vector.elements || vector; var i, product = 0, n = this.elements.length; if (n != V.length) { return null; } do { product += this.elements[n-1] * V[n-1]; } while (--n); return product; }, // Returns the vector product of the vector with the argument // Both vectors must have dimensionality 3 cross: function(vector) { var B = vector.elements || vector; if (this.elements.length != 3 || B.length != 3) { return null; } var A = this.elements; return Vector.create([ (A[1] * B[2]) - (A[2] * B[1]), (A[2] * B[0]) - (A[0] * B[2]), (A[0] * B[1]) - (A[1] * B[0]) ]); }, // Returns the (absolute) largest element of the vector max: function() { var m = 0, n = this.elements.length, k = n, i; do { i = k - n; if (Math.abs(this.elements[i]) > Math.abs(m)) { m = this.elements[i]; } } while (--n); return m; }, // Returns the index of the first match found indexOf: function(x) { var index = null, n = this.elements.length, k = n, i; do { i = k - n; if (index === null && this.elements[i] == x) { index = i + 1; } } while (--n); return index; }, // Returns a diagonal matrix with the vector's elements as its diagonal elements toDiagonalMatrix: function() { return Matrix.Diagonal(this.elements); }, // Returns the result of rounding the elements of the vector round: function() { return this.map(function(x) { return Math.round(x); }); }, // Returns a copy of the vector with elements set to the given value if they // differ from it by less than Sylvester.precision snapTo: function(x) { return this.map(function(y) { return (Math.abs(y - x) <= Sylvester.precision) ? x : y; }); }, // Returns the vector's distance from the argument, when considered as a point in space distanceFrom: function(obj) { if (obj.anchor) { return obj.distanceFrom(this); } var V = obj.elements || obj; if (V.length != this.elements.length) { return null; } var sum = 0, part; this.each(function(x, i) { part = x - V[i-1]; sum += part * part; }); return Math.sqrt(sum); }, // Returns true if the vector is point on the given line liesOn: function(line) { return line.contains(this); }, // Return true iff the vector is a point in the given plane liesIn: function(plane) { return plane.contains(this); }, // Rotates the vector about the given object. The object should be a // point if the vector is 2D, and a line if it is 3D. Be careful with line directions! 
rotate: function(t, obj) { var V, R, x, y, z; switch (this.elements.length) { case 2: V = obj.elements || obj; if (V.length != 2) { return null; } R = Matrix.Rotation(t).elements; x = this.elements[0] - V[0]; y = this.elements[1] - V[1]; return Vector.create([ V[0] + R[0][0] * x + R[0][1] * y, V[1] + R[1][0] * x + R[1][1] * y ]); break; case 3: if (!obj.direction) { return null; } var C = obj.pointClosestTo(this).elements; R = Matrix.Rotation(t, obj.direction).elements; x = this.elements[0] - C[0]; y = this.elements[1] - C[1]; z = this.elements[2] - C[2]; return Vector.create([ C[0] + R[0][0] * x + R[0][1] * y + R[0][2] * z, C[1] + R[1][0] * x + R[1][1] * y + R[1][2] * z, C[2] + R[2][0] * x + R[2][1] * y + R[2][2] * z ]); break; default: return null; } }, // Returns the result of reflecting the point in the given point, line or plane reflectionIn: function(obj) { if (obj.anchor) { // obj is a plane or line var P = this.elements.slice(); var C = obj.pointClosestTo(P).elements; return Vector.create([C[0] + (C[0] - P[0]), C[1] + (C[1] - P[1]), C[2] + (C[2] - (P[2] || 0))]); } else { // obj is a point var Q = obj.elements || obj; if (this.elements.length != Q.length) { return null; } return this.map(function(x, i) { return Q[i-1] + (Q[i-1] - x); }); } }, // Utility to make sure vectors are 3D. If they are 2D, a zero z-component is added to3D: function() { var V = this.dup(); switch (V.elements.length) { case 3: break; case 2: V.elements.push(0); break; default: return null; } return V; }, // Returns a string representation of the vector inspect: function() { return '[' + this.elements.join(', ') + ']'; }, // Set vector's elements from an array setElements: function(els) { this.elements = (els.elements || els).slice(); return this; } }; // Constructor function Vector.create = function(elements) { var V = new Vector(); return V.setElements(elements); }; // i, j, k unit vectors Vector.i = Vector.create([1,0,0]); Vector.j = Vector.create([0,1,0]); Vector.k = Vector.create([0,0,1]); // Random vector of size n Vector.Random = function(n) { var elements = []; do { elements.push(Math.random()); } while (--n); return Vector.create(elements); }; // Vector filled with zeros Vector.Zero = function(n) { var elements = []; do { elements.push(0); } while (--n); return Vector.create(elements); }; function Matrix() {} Matrix.prototype = { // Returns element (i,j) of the matrix e: function(i,j) { if (i < 1 || i > this.elements.length || j < 1 || j > this.elements[0].length) { return null; } return this.elements[i-1][j-1]; }, // Returns row k of the matrix as a vector row: function(i) { if (i > this.elements.length) { return null; } return Vector.create(this.elements[i-1]); }, // Returns column k of the matrix as a vector col: function(j) { if (j > this.elements[0].length) { return null; } var col = [], n = this.elements.length, k = n, i; do { i = k - n; col.push(this.elements[i][j-1]); } while (--n); return Vector.create(col); }, // Returns the number of rows/columns the matrix has dimensions: function() { return {rows: this.elements.length, cols: this.elements[0].length}; }, // Returns the number of rows in the matrix rows: function() { return this.elements.length; }, // Returns the number of columns in the matrix cols: function() { return this.elements[0].length; }, // Returns true iff the matrix is equal to the argument. You can supply // a vector as the argument, in which case the receiver must be a // one-column matrix equal to the vector. 
eql: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (this.elements.length != M.length || this.elements[0].length != M[0].length) { return false; } var ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(this.elements[i][j] - M[i][j]) > Sylvester.precision) { return false; } } while (--nj); } while (--ni); return true; }, // Returns a copy of the matrix dup: function() { return Matrix.create(this.elements); }, // Maps the matrix to another matrix (of the same dimensions) according to the given function map: function(fn) { var els = [], ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; els[i] = []; do { j = kj - nj; els[i][j] = fn(this.elements[i][j], i + 1, j + 1); } while (--nj); } while (--ni); return Matrix.create(els); }, // Returns true iff the argument has the same dimensions as the matrix isSameSizeAs: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } return (this.elements.length == M.length && this.elements[0].length == M[0].length); }, // Returns the result of adding the argument to the matrix add: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.isSameSizeAs(M)) { return null; } return this.map(function(x, i, j) { return x + M[i-1][j-1]; }); }, // Returns the result of subtracting the argument from the matrix subtract: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.isSameSizeAs(M)) { return null; } return this.map(function(x, i, j) { return x - M[i-1][j-1]; }); }, // Returns true iff the matrix can multiply the argument from the left canMultiplyFromLeft: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } // this.columns should equal matrix.rows return (this.elements[0].length == M.length); }, // Returns the result of multiplying the matrix from the right by the argument. // If the argument is a scalar then just multiply all the elements. If the argument is // a vector, a vector is returned, which saves you having to remember calling // col(1) on the result. multiply: function(matrix) { if (!matrix.elements) { return this.map(function(x) { return x * matrix; }); } var returnVector = matrix.modulus ? true : false; var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.canMultiplyFromLeft(M)) { return null; } var ni = this.elements.length, ki = ni, i, nj, kj = M[0].length, j; var cols = this.elements[0].length, elements = [], sum, nc, c; do { i = ki - ni; elements[i] = []; nj = kj; do { j = kj - nj; sum = 0; nc = cols; do { c = cols - nc; sum += this.elements[i][c] * M[c][j]; } while (--nc); elements[i][j] = sum; } while (--nj); } while (--ni); var M = Matrix.create(elements); return returnVector ? M.col(1) : M; }, x: function(matrix) { return this.multiply(matrix); }, // Returns a submatrix taken from the matrix // Argument order is: start row, start col, nrows, ncols // Element selection wraps if the required index is outside the matrix's bounds, so you could // use this to perform row/column cycling or copy-augmenting. 
minor: function(a, b, c, d) { var elements = [], ni = c, i, nj, j; var rows = this.elements.length, cols = this.elements[0].length; do { i = c - ni; elements[i] = []; nj = d; do { j = d - nj; elements[i][j] = this.elements[(a+i-1)%rows][(b+j-1)%cols]; } while (--nj); } while (--ni); return Matrix.create(elements); }, // Returns the transpose of the matrix transpose: function() { var rows = this.elements.length, cols = this.elements[0].length; var elements = [], ni = cols, i, nj, j; do { i = cols - ni; elements[i] = []; nj = rows; do { j = rows - nj; elements[i][j] = this.elements[j][i]; } while (--nj); } while (--ni); return Matrix.create(elements); }, // Returns true iff the matrix is square isSquare: function() { return (this.elements.length == this.elements[0].length); }, // Returns the (absolute) largest element of the matrix max: function() { var m = 0, ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(this.elements[i][j]) > Math.abs(m)) { m = this.elements[i][j]; } } while (--nj); } while (--ni); return m; }, // Returns the indeces of the first match found by reading row-by-row from left to right indexOf: function(x) { var index = null, ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (this.elements[i][j] == x) { return {i: i+1, j: j+1}; } } while (--nj); } while (--ni); return null; }, // If the matrix is square, returns the diagonal elements as a vector. // Otherwise, returns null. diagonal: function() { if (!this.isSquare) { return null; } var els = [], n = this.elements.length, k = n, i; do { i = k - n; els.push(this.elements[i][i]); } while (--n); return Vector.create(els); }, // Make the matrix upper (right) triangular by Gaussian elimination. // This method only adds multiples of rows to other rows. No rows are // scaled up or switched, and the determinant is preserved. toRightTriangular: function() { var M = this.dup(), els; var n = this.elements.length, k = n, i, np, kp = this.elements[0].length, p; do { i = k - n; if (M.elements[i][i] == 0) { for (j = i + 1; j < k; j++) { if (M.elements[j][i] != 0) { els = []; np = kp; do { p = kp - np; els.push(M.elements[i][p] + M.elements[j][p]); } while (--np); M.elements[i] = els; break; } } } if (M.elements[i][i] != 0) { for (j = i + 1; j < k; j++) { var multiplier = M.elements[j][i] / M.elements[i][i]; els = []; np = kp; do { p = kp - np; // Elements with column numbers up to an including the number // of the row that we're subtracting can safely be set straight to // zero, since that's the point of this routine and it avoids having // to loop over and correct rounding errors later els.push(p <= i ? 
0 : M.elements[j][p] - M.elements[i][p] * multiplier); } while (--np); M.elements[j] = els; } } } while (--n); return M; }, toUpperTriangular: function() { return this.toRightTriangular(); }, // Returns the determinant for square matrices determinant: function() { if (!this.isSquare()) { return null; } var M = this.toRightTriangular(); var det = M.elements[0][0], n = M.elements.length - 1, k = n, i; do { i = k - n + 1; det = det * M.elements[i][i]; } while (--n); return det; }, det: function() { return this.determinant(); }, // Returns true iff the matrix is singular isSingular: function() { return (this.isSquare() && this.determinant() === 0); }, // Returns the trace for square matrices trace: function() { if (!this.isSquare()) { return null; } var tr = this.elements[0][0], n = this.elements.length - 1, k = n, i; do { i = k - n + 1; tr += this.elements[i][i]; } while (--n); return tr; }, tr: function() { return this.trace(); }, // Returns the rank of the matrix rank: function() { var M = this.toRightTriangular(), rank = 0; var ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(M.elements[i][j]) > Sylvester.precision) { rank++; break; } } while (--nj); } while (--ni); return rank; }, rk: function() { return this.rank(); }, // Returns the result of attaching the given argument to the right-hand side of the matrix augment: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } var T = this.dup(), cols = T.elements[0].length; var ni = T.elements.length, ki = ni, i, nj, kj = M[0].length, j; if (ni != M.length) { return null; } do { i = ki - ni; nj = kj; do { j = kj - nj; T.elements[i][cols + j] = M[i][j]; } while (--nj); } while (--ni); return T; }, // Returns the inverse (if one exists) using Gauss-Jordan inverse: function() { if (!this.isSquare() || this.isSingular()) { return null; } var ni = this.elements.length, ki = ni, i, j; var M = this.augment(Matrix.I(ni)).toRightTriangular(); var np, kp = M.elements[0].length, p, els, divisor; var inverse_elements = [], new_element; // Matrix is non-singular so there will be no zeros on the diagonal // Cycle through rows from last to first do { i = ni - 1; // First, normalise diagonal elements to 1 els = []; np = kp; inverse_elements[i] = []; divisor = M.elements[i][i]; do { p = kp - np; new_element = M.elements[i][p] / divisor; els.push(new_element); // Shuffle of the current row of the right hand side into the results // array as it will not be modified by later runs through this loop if (p >= ki) { inverse_elements[i].push(new_element); } } while (--np); M.elements[i] = els; // Then, subtract this row from those above it to // give the identity matrix on the left hand side for (j = 0; j < i; j++) { els = []; np = kp; do { p = kp - np; els.push(M.elements[j][p] - M.elements[i][p] * M.elements[j][i]); } while (--np); M.elements[j] = els; } } while (--ni); return Matrix.create(inverse_elements); }, inv: function() { return this.inverse(); }, // Returns the result of rounding all the elements round: function() { return this.map(function(x) { return Math.round(x); }); }, // Returns a copy of the matrix with elements set to the given value if they // differ from it by less than Sylvester.precision snapTo: function(x) { return this.map(function(p) { return (Math.abs(p - x) <= Sylvester.precision) ? 
x : p; }); }, // Returns a string representation of the matrix inspect: function() { var matrix_rows = []; var n = this.elements.length, k = n, i; do { i = k - n; matrix_rows.push(Vector.create(this.elements[i]).inspect()); } while (--n); return matrix_rows.join('\n'); }, // Set the matrix's elements from an array. If the argument passed // is a vector, the resulting matrix will be a single column. setElements: function(els) { var i, elements = els.elements || els; if (typeof(elements[0][0]) != 'undefined') { var ni = elements.length, ki = ni, nj, kj, j; this.elements = []; do { i = ki - ni; nj = elements[i].length; kj = nj; this.elements[i] = []; do { j = kj - nj; this.elements[i][j] = elements[i][j]; } while (--nj); } while(--ni); return this; } var n = elements.length, k = n; this.elements = []; do { i = k - n; this.elements.push([elements[i]]); } while (--n); return this; } }; // Constructor function Matrix.create = function(elements) { var M = new Matrix(); return M.setElements(elements); }; // Identity matrix of size n Matrix.I = function(n) { var els = [], k = n, i, nj, j; do { i = k - n; els[i] = []; nj = k; do { j = k - nj; els[i][j] = (i == j) ? 1 : 0; } while (--nj); } while (--n); return Matrix.create(els); }; // Diagonal matrix - all off-diagonal elements are zero Matrix.Diagonal = function(elements) { var n = elements.length, k = n, i; var M = Matrix.I(n); do { i = k - n; M.elements[i][i] = elements[i]; } while (--n); return M; }; // Rotation matrix about some axis. If no axis is // supplied, assume we're after a 2D transform Matrix.Rotation = function(theta, a) { if (!a) { return Matrix.create([ [Math.cos(theta), -Math.sin(theta)], [Math.sin(theta), Math.cos(theta)] ]); } var axis = a.dup(); if (axis.elements.length != 3) { return null; } var mod = axis.modulus(); var x = axis.elements[0]/mod, y = axis.elements[1]/mod, z = axis.elements[2]/mod; var s = Math.sin(theta), c = Math.cos(theta), t = 1 - c; // Formula derived here: http://www.gamedev.net/reference/articles/article1199.asp // That proof rotates the co-ordinate system so theta // becomes -theta and sin becomes -sin here. 
return Matrix.create([ [ t*x*x + c, t*x*y - s*z, t*x*z + s*y ], [ t*x*y + s*z, t*y*y + c, t*y*z - s*x ], [ t*x*z - s*y, t*y*z + s*x, t*z*z + c ] ]); }; // Special case rotations Matrix.RotationX = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ 1, 0, 0 ], [ 0, c, -s ], [ 0, s, c ] ]); }; Matrix.RotationY = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ c, 0, s ], [ 0, 1, 0 ], [ -s, 0, c ] ]); }; Matrix.RotationZ = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ c, -s, 0 ], [ s, c, 0 ], [ 0, 0, 1 ] ]); }; // Random matrix of n rows, m columns Matrix.Random = function(n, m) { return Matrix.Zero(n, m).map( function() { return Math.random(); } ); }; // Matrix filled with zeros Matrix.Zero = function(n, m) { var els = [], ni = n, i, nj, j; do { i = n - ni; els[i] = []; nj = m; do { j = m - nj; els[i][j] = 0; } while (--nj); } while (--ni); return Matrix.create(els); }; function Line() {} Line.prototype = { // Returns true if the argument occupies the same space as the line eql: function(line) { return (this.isParallelTo(line) && this.contains(line.anchor)); }, // Returns a copy of the line dup: function() { return Line.create(this.anchor, this.direction); }, // Returns the result of translating the line by the given vector/array translate: function(vector) { var V = vector.elements || vector; return Line.create([ this.anchor.elements[0] + V[0], this.anchor.elements[1] + V[1], this.anchor.elements[2] + (V[2] || 0) ], this.direction); }, // Returns true if the line is parallel to the argument. Here, 'parallel to' // means that the argument's direction is either parallel or antiparallel to // the line's own direction. A line is parallel to a plane if the two do not // have a unique intersection. isParallelTo: function(obj) { if (obj.normal) { return obj.isParallelTo(this); } var theta = this.direction.angleFrom(obj.direction); return (Math.abs(theta) <= Sylvester.precision || Math.abs(theta - Math.PI) <= Sylvester.precision); }, // Returns the line's perpendicular distance from the argument, // which can be a point, a line or a plane distanceFrom: function(obj) { if (obj.normal) { return obj.distanceFrom(this); } if (obj.direction) { // obj is a line if (this.isParallelTo(obj)) { return this.distanceFrom(obj.anchor); } var N = this.direction.cross(obj.direction).toUnitVector().elements; var A = this.anchor.elements, B = obj.anchor.elements; return Math.abs((A[0] - B[0]) * N[0] + (A[1] - B[1]) * N[1] + (A[2] - B[2]) * N[2]); } else { // obj is a point var P = obj.elements || obj; var A = this.anchor.elements, D = this.direction.elements; var PA1 = P[0] - A[0], PA2 = P[1] - A[1], PA3 = (P[2] || 0) - A[2]; var modPA = Math.sqrt(PA1*PA1 + PA2*PA2 + PA3*PA3); if (modPA === 0) return 0; // Assumes direction vector is normalized var cosTheta = (PA1 * D[0] + PA2 * D[1] + PA3 * D[2]) / modPA; var sin2 = 1 - cosTheta*cosTheta; return Math.abs(modPA * Math.sqrt(sin2 < 0 ? 
0 : sin2)); } }, // Returns true iff the argument is a point on the line contains: function(point) { var dist = this.distanceFrom(point); return (dist !== null && dist <= Sylvester.precision); }, // Returns true iff the line lies in the given plane liesIn: function(plane) { return plane.contains(this); }, // Returns true iff the line has a unique point of intersection with the argument intersects: function(obj) { if (obj.normal) { return obj.intersects(this); } return (!this.isParallelTo(obj) && this.distanceFrom(obj) <= Sylvester.precision); }, // Returns the unique intersection point with the argument, if one exists intersectionWith: function(obj) { if (obj.normal) { return obj.intersectionWith(this); } if (!this.intersects(obj)) { return null; } var P = this.anchor.elements, X = this.direction.elements, Q = obj.anchor.elements, Y = obj.direction.elements; var X1 = X[0], X2 = X[1], X3 = X[2], Y1 = Y[0], Y2 = Y[1], Y3 = Y[2]; var PsubQ1 = P[0] - Q[0], PsubQ2 = P[1] - Q[1], PsubQ3 = P[2] - Q[2]; var XdotQsubP = - X1*PsubQ1 - X2*PsubQ2 - X3*PsubQ3; var YdotPsubQ = Y1*PsubQ1 + Y2*PsubQ2 + Y3*PsubQ3; var XdotX = X1*X1 + X2*X2 + X3*X3; var YdotY = Y1*Y1 + Y2*Y2 + Y3*Y3; var XdotY = X1*Y1 + X2*Y2 + X3*Y3; var k = (XdotQsubP * YdotY / XdotX + XdotY * YdotPsubQ) / (YdotY - XdotY * XdotY); return Vector.create([P[0] + k*X1, P[1] + k*X2, P[2] + k*X3]); }, // Returns the point on the line that is closest to the given point or line pointClosestTo: function(obj) { if (obj.direction) { // obj is a line if (this.intersects(obj)) { return this.intersectionWith(obj); } if (this.isParallelTo(obj)) { return null; } var D = this.direction.elements, E = obj.direction.elements; var D1 = D[0], D2 = D[1], D3 = D[2], E1 = E[0], E2 = E[1], E3 = E[2]; // Create plane containing obj and the shared normal and intersect this with it // Thank you: http://www.cgafaq.info/wiki/Line-line_distance var x = (D3 * E1 - D1 * E3), y = (D1 * E2 - D2 * E1), z = (D2 * E3 - D3 * E2); var N = Vector.create([x * E3 - y * E2, y * E1 - z * E3, z * E2 - x * E1]); var P = Plane.create(obj.anchor, N); return P.intersectionWith(this); } else { // obj is a point var P = obj.elements || obj; if (this.contains(P)) { return Vector.create(P); } var A = this.anchor.elements, D = this.direction.elements; var D1 = D[0], D2 = D[1], D3 = D[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = D1 * (P[1]-A2) - D2 * (P[0]-A1), y = D2 * ((P[2] || 0) - A3) - D3 * (P[1]-A2), z = D3 * (P[0]-A1) - D1 * ((P[2] || 0) - A3); var V = Vector.create([D2 * x - D3 * z, D3 * y - D1 * x, D1 * z - D2 * y]); var k = this.distanceFrom(P) / V.modulus(); return Vector.create([ P[0] + V.elements[0] * k, P[1] + V.elements[1] * k, (P[2] || 0) + V.elements[2] * k ]); } }, // Returns a copy of the line rotated by t radians about the given line. Works by // finding the argument's closest point to this line's anchor point (call this C) and // rotating the anchor about C. Also rotates the line's direction about the argument's. // Be careful with this - the rotation axis' direction affects the outcome! 
rotate: function(t, line) { // If we're working in 2D if (typeof(line.direction) == 'undefined') { line = Line.create(line.to3D(), Vector.k); } var R = Matrix.Rotation(t, line.direction).elements; var C = line.pointClosestTo(this.anchor).elements; var A = this.anchor.elements, D = this.direction.elements; var C1 = C[0], C2 = C[1], C3 = C[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = A1 - C1, y = A2 - C2, z = A3 - C3; return Line.create([ C1 + R[0][0] * x + R[0][1] * y + R[0][2] * z, C2 + R[1][0] * x + R[1][1] * y + R[1][2] * z, C3 + R[2][0] * x + R[2][1] * y + R[2][2] * z ], [ R[0][0] * D[0] + R[0][1] * D[1] + R[0][2] * D[2], R[1][0] * D[0] + R[1][1] * D[1] + R[1][2] * D[2], R[2][0] * D[0] + R[2][1] * D[1] + R[2][2] * D[2] ]); }, // Returns the line's reflection in the given point or line reflectionIn: function(obj) { if (obj.normal) { // obj is a plane var A = this.anchor.elements, D = this.direction.elements; var A1 = A[0], A2 = A[1], A3 = A[2], D1 = D[0], D2 = D[1], D3 = D[2]; var newA = this.anchor.reflectionIn(obj).elements; // Add the line's direction vector to its anchor, then mirror that in the plane var AD1 = A1 + D1, AD2 = A2 + D2, AD3 = A3 + D3; var Q = obj.pointClosestTo([AD1, AD2, AD3]).elements; var newD = [Q[0] + (Q[0] - AD1) - newA[0], Q[1] + (Q[1] - AD2) - newA[1], Q[2] + (Q[2] - AD3) - newA[2]]; return Line.create(newA, newD); } else if (obj.direction) { // obj is a line - reflection obtained by rotating PI radians about obj return this.rotate(Math.PI, obj); } else { // obj is a point - just reflect the line's anchor in it var P = obj.elements || obj; return Line.create(this.anchor.reflectionIn([P[0], P[1], (P[2] || 0)]), this.direction); } }, // Set the line's anchor point and direction. setVectors: function(anchor, direction) { // Need to do this so that line's properties are not // references to the arguments passed in anchor = Vector.create(anchor); direction = Vector.create(direction); if (anchor.elements.length == 2) {anchor.elements.push(0); } if (direction.elements.length == 2) { direction.elements.push(0); } if (anchor.elements.length > 3 || direction.elements.length > 3) { return null; } var mod = direction.modulus(); if (mod === 0) { return null; } this.anchor = anchor; this.direction = Vector.create([ direction.elements[0] / mod, direction.elements[1] / mod, direction.elements[2] / mod ]); return this; } }; // Constructor function Line.create = function(anchor, direction) { var L = new Line(); return L.setVectors(anchor, direction); }; // Axes Line.X = Line.create(Vector.Zero(3), Vector.i); Line.Y = Line.create(Vector.Zero(3), Vector.j); Line.Z = Line.create(Vector.Zero(3), Vector.k); function Plane() {} Plane.prototype = { // Returns true iff the plane occupies the same space as the argument eql: function(plane) { return (this.contains(plane.anchor) && this.isParallelTo(plane)); }, // Returns a copy of the plane dup: function() { return Plane.create(this.anchor, this.normal); }, // Returns the result of translating the plane by the given vector translate: function(vector) { var V = vector.elements || vector; return Plane.create([ this.anchor.elements[0] + V[0], this.anchor.elements[1] + V[1], this.anchor.elements[2] + (V[2] || 0) ], this.normal); }, // Returns true iff the plane is parallel to the argument. Will return true // if the planes are equal, or if you give a line and it lies in the plane. 
isParallelTo: function(obj) { var theta; if (obj.normal) { // obj is a plane theta = this.normal.angleFrom(obj.normal); return (Math.abs(theta) <= Sylvester.precision || Math.abs(Math.PI - theta) <= Sylvester.precision); } else if (obj.direction) { // obj is a line return this.normal.isPerpendicularTo(obj.direction); } return null; }, // Returns true iff the receiver is perpendicular to the argument isPerpendicularTo: function(plane) { var theta = this.normal.angleFrom(plane.normal); return (Math.abs(Math.PI/2 - theta) <= Sylvester.precision); }, // Returns the plane's distance from the given object (point, line or plane) distanceFrom: function(obj) { if (this.intersects(obj) || this.contains(obj)) { return 0; } if (obj.anchor) { // obj is a plane or line var A = this.anchor.elements, B = obj.anchor.elements, N = this.normal.elements; return Math.abs((A[0] - B[0]) * N[0] + (A[1] - B[1]) * N[1] + (A[2] - B[2]) * N[2]); } else { // obj is a point var P = obj.elements || obj; var A = this.anchor.elements, N = this.normal.elements; return Math.abs((A[0] - P[0]) * N[0] + (A[1] - P[1]) * N[1] + (A[2] - (P[2] || 0)) * N[2]); } }, // Returns true iff the plane contains the given point or line contains: function(obj) { if (obj.normal) { return null; } if (obj.direction) { return (this.contains(obj.anchor) && this.contains(obj.anchor.add(obj.direction))); } else { var P = obj.elements || obj; var A = this.anchor.elements, N = this.normal.elements; var diff = Math.abs(N[0]*(A[0] - P[0]) + N[1]*(A[1] - P[1]) + N[2]*(A[2] - (P[2] || 0))); return (diff <= Sylvester.precision); } }, // Returns true iff the plane has a unique point/line of intersection with the argument intersects: function(obj) { if (typeof(obj.direction) == 'undefined' && typeof(obj.normal) == 'undefined') { return null; } return !this.isParallelTo(obj); }, // Returns the unique intersection with the argument, if one exists. The result // will be a vector if a line is supplied, and a line if a plane is supplied. intersectionWith: function(obj) { if (!this.intersects(obj)) { return null; } if (obj.direction) { // obj is a line var A = obj.anchor.elements, D = obj.direction.elements, P = this.anchor.elements, N = this.normal.elements; var multiplier = (N[0]*(P[0]-A[0]) + N[1]*(P[1]-A[1]) + N[2]*(P[2]-A[2])) / (N[0]*D[0] + N[1]*D[1] + N[2]*D[2]); return Vector.create([A[0] + D[0]*multiplier, A[1] + D[1]*multiplier, A[2] + D[2]*multiplier]); } else if (obj.normal) { // obj is a plane var direction = this.normal.cross(obj.normal).toUnitVector(); // To find an anchor point, we find one co-ordinate that has a value // of zero somewhere on the intersection, and remember which one we picked var N = this.normal.elements, A = this.anchor.elements, O = obj.normal.elements, B = obj.anchor.elements; var solver = Matrix.Zero(2,2), i = 0; while (solver.isSingular()) { i++; solver = Matrix.create([ [ N[i%3], N[(i+1)%3] ], [ O[i%3], O[(i+1)%3] ] ]); } // Then we solve the simultaneous equations in the remaining dimensions var inverse = solver.inverse().elements; var x = N[0]*A[0] + N[1]*A[1] + N[2]*A[2]; var y = O[0]*B[0] + O[1]*B[1] + O[2]*B[2]; var intersection = [ inverse[0][0] * x + inverse[0][1] * y, inverse[1][0] * x + inverse[1][1] * y ]; var anchor = []; for (var j = 1; j <= 3; j++) { // This formula picks the right element from intersection by // cycling depending on which element we set to zero above anchor.push((i == j) ? 
0 : intersection[(j + (5 - i)%3)%3]); } return Line.create(anchor, direction); } }, // Returns the point in the plane closest to the given point pointClosestTo: function(point) { var P = point.elements || point; var A = this.anchor.elements, N = this.normal.elements; var dot = (A[0] - P[0]) * N[0] + (A[1] - P[1]) * N[1] + (A[2] - (P[2] || 0)) * N[2]; return Vector.create([P[0] + N[0] * dot, P[1] + N[1] * dot, (P[2] || 0) + N[2] * dot]); }, // Returns a copy of the plane, rotated by t radians about the given line // See notes on Line#rotate. rotate: function(t, line) { var R = Matrix.Rotation(t, line.direction).elements; var C = line.pointClosestTo(this.anchor).elements; var A = this.anchor.elements, N = this.normal.elements; var C1 = C[0], C2 = C[1], C3 = C[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = A1 - C1, y = A2 - C2, z = A3 - C3; return Plane.create([ C1 + R[0][0] * x + R[0][1] * y + R[0][2] * z, C2 + R[1][0] * x + R[1][1] * y + R[1][2] * z, C3 + R[2][0] * x + R[2][1] * y + R[2][2] * z ], [ R[0][0] * N[0] + R[0][1] * N[1] + R[0][2] * N[2], R[1][0] * N[0] + R[1][1] * N[1] + R[1][2] * N[2], R[2][0] * N[0] + R[2][1] * N[1] + R[2][2] * N[2] ]); }, // Returns the reflection of the plane in the given point, line or plane. reflectionIn: function(obj) { if (obj.normal) { // obj is a plane var A = this.anchor.elements, N = this.normal.elements; var A1 = A[0], A2 = A[1], A3 = A[2], N1 = N[0], N2 = N[1], N3 = N[2]; var newA = this.anchor.reflectionIn(obj).elements; // Add the plane's normal to its anchor, then mirror that in the other plane var AN1 = A1 + N1, AN2 = A2 + N2, AN3 = A3 + N3; var Q = obj.pointClosestTo([AN1, AN2, AN3]).elements; var newN = [Q[0] + (Q[0] - AN1) - newA[0], Q[1] + (Q[1] - AN2) - newA[1], Q[2] + (Q[2] - AN3) - newA[2]]; return Plane.create(newA, newN); } else if (obj.direction) { // obj is a line return this.rotate(Math.PI, obj); } else { // obj is a point var P = obj.elements || obj; return Plane.create(this.anchor.reflectionIn([P[0], P[1], (P[2] || 0)]), this.normal); } }, // Sets the anchor point and normal to the plane. If three arguments are specified, // the normal is calculated by assuming the three points should lie in the same plane. // If only two are sepcified, the second is taken to be the normal. Normal vector is // normalised before storage. 
setVectors: function(anchor, v1, v2) { anchor = Vector.create(anchor); anchor = anchor.to3D(); if (anchor === null) { return null; } v1 = Vector.create(v1); v1 = v1.to3D(); if (v1 === null) { return null; } if (typeof(v2) == 'undefined') { v2 = null; } else { v2 = Vector.create(v2); v2 = v2.to3D(); if (v2 === null) { return null; } } var A1 = anchor.elements[0], A2 = anchor.elements[1], A3 = anchor.elements[2]; var v11 = v1.elements[0], v12 = v1.elements[1], v13 = v1.elements[2]; var normal, mod; if (v2 !== null) { var v21 = v2.elements[0], v22 = v2.elements[1], v23 = v2.elements[2]; normal = Vector.create([ (v12 - A2) * (v23 - A3) - (v13 - A3) * (v22 - A2), (v13 - A3) * (v21 - A1) - (v11 - A1) * (v23 - A3), (v11 - A1) * (v22 - A2) - (v12 - A2) * (v21 - A1) ]); mod = normal.modulus(); if (mod === 0) { return null; } normal = Vector.create([normal.elements[0] / mod, normal.elements[1] / mod, normal.elements[2] / mod]); } else { mod = Math.sqrt(v11*v11 + v12*v12 + v13*v13); if (mod === 0) { return null; } normal = Vector.create([v1.elements[0] / mod, v1.elements[1] / mod, v1.elements[2] / mod]); } this.anchor = anchor; this.normal = normal; return this; } }; // Constructor function Plane.create = function(anchor, v1, v2) { var P = new Plane(); return P.setVectors(anchor, v1, v2); }; // X-Y-Z planes Plane.XY = Plane.create(Vector.Zero(3), Vector.k); Plane.YZ = Plane.create(Vector.Zero(3), Vector.i); Plane.ZX = Plane.create(Vector.Zero(3), Vector.j); Plane.YX = Plane.XY; Plane.ZY = Plane.YZ; Plane.XZ = Plane.ZX; // Utility functions var $V = Vector.create; var $M = Matrix.create; var $L = Line.create; var $P = Plane.create; ================================================ FILE: datalab/notebook/static/job.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ p.jobfail { color: red; } p.jobsucceed { color: green; } p.jobfooter { font-size: smaller; } ================================================ FILE: datalab/notebook/static/job.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. 
*/ /// declare var datalab: any; declare var IPython: any; module Job { function refresh(dom: any, job_name: any, job_type: any, interval: any, html_on_running: string, html_on_success: string): any { var code = '%_get_job_status ' + job_name + ' ' + job_type; datalab.session.execute(code, function (error: any, newData: any) { error = error || newData.error; if (error) { dom.innerHTML = '

<p class="jobfail">Job failed with error: ' + error + '</p>'; return; } if (!newData.exists) { dom.innerHTML = '<p class="jobfail">The job does not exist.</p>'; } else if (newData.done) { dom.innerHTML = '<p class="jobsucceed">Job completed successfully.</p><br>' + html_on_success; } else { dom.innerHTML = 'Running...<br><p class="jobfooter">Updated at ' + new Date().toLocaleTimeString() + '</p>
' + html_on_running; setTimeout(function() { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); }, interval * 1000); } }); } // Render the job view. This is called from Python generated code. export function render(dom: any, events: any, job_name: string, job_type: string, interval: any, html_on_running: string, html_on_success: string) { if (IPython.notebook.kernel.is_connected()) { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); return; } // If the kernel is not connected, wait for the event. events.on('kernel_ready.Kernel', function(e: any) { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); }); } } export = Job; ================================================ FILE: datalab/notebook/static/parcoords.ts ================================================ /* * Copyright 2016 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module ParCoords { function getCentroids(data: any, graph: any): any { var margins = graph.margin(); var graphCentPts: any[] = []; data.forEach(function(d: any){ var initCenPts = graph.compute_centroids(d).filter(function(d: any, i: number) {return i%2==0;}); var cenPts = initCenPts.map(function(d: any){ return [d[0] + margins["left"], d[1]+ margins["top"]]; }); graphCentPts.push(cenPts); }); return graphCentPts; } function getActiveData(graph: any): any{ if (graph.brushed()!=false) return graph.brushed(); return graph.data(); } function findAxes(testPt: any, cenPts: any): number { var x: number = testPt[0]; var y: number = testPt[1]; if (cenPts[0][0] > x) return 0; if (cenPts[cenPts.length-1][0] < x) return 0; for (var i=0; i x) return i; } return 0; } function isOnLine(startPt: any, endPt: any, testPt: any, tol: number){ var x0 = testPt[0]; var y0 = testPt[1]; var x1 = startPt[0]; var y1 = startPt[1]; var x2 = endPt[0]; var y2 = endPt[1]; var Dx = x2 - x1; var Dy = y2 - y1; var delta = Math.abs(Dy*x0 - Dx*y0 - x1*y2+x2*y1)/Math.sqrt(Math.pow(Dx, 2) + Math.pow(Dy, 2)); if (delta <= tol) return true; return false; } function getClickedLines(mouseClick: any, graph: any): any { var clicked: any[] = []; var clickedCenPts: any[] = []; // find which data is activated right now var activeData: any = getActiveData(graph); // find centriod points var graphCentPts: any = getCentroids(activeData, graph); if (graphCentPts.length==0) return false; // find between which axes the point is var axeNum: number = findAxes(mouseClick, graphCentPts[0]); if (!axeNum) return false; graphCentPts.forEach(function(d: any, i: number){ if (isOnLine(d[axeNum-1], d[axeNum], mouseClick, 2)) { clicked.push(activeData[i]); clickedCenPts.push(graphCentPts[i]); // for tooltip } }); return [clicked, clickedCenPts] } function highlightLineOnClick(mouseClick: any, graph: any) { var clicked: any[] = []; var clickedCenPts: any[] = []; var clickedData: any = getClickedLines(mouseClick, graph); if (clickedData && clickedData[0].length!=0){ clicked = clickedData[0]; clickedCenPts = clickedData[1]; // highlight 
clicked line graph.highlight(clicked); } }; export function plot(d3: any, color_domain: number[], maximize: boolean, data: any, graph_html_id: string, grid_html_id: string) { var range = ["green", "gray"]; if (maximize) { range = ["gray", "green"]; } var blue_to_brown = d3.scale.linear().domain(color_domain) .range(range) .interpolate(d3.interpolateLab); var color = function(d: any) { return blue_to_brown(d['Objective']); }; var columns_hide: string[] = ["Trial", "Training Step"]; for (var attr in data) { if (attr.lastIndexOf("(log)") > 0) { columns_hide.push(attr.slice(0, -5)); } } var data_display: any[] = []; for (var i: number =0; i module Style { 'use strict'; // An object containing the set of loaded stylesheets, so as to avoid reloading. var loadedStyleSheets: any = {}; // An object containing stylesheets to load, once the DOM is ready. var pendingStyleSheets: Array = null; function addStyleSheet(url: string): void { loadedStyleSheets[url] = true; var stylesheet = document.createElement('link'); stylesheet.type = 'text/css'; stylesheet.rel = 'stylesheet'; stylesheet.href = url; document.getElementsByTagName('head')[0].appendChild(stylesheet); } function domReadyCallback(): void { if (pendingStyleSheets) { // Clear out pendingStyleSheets, so any future adds are immediately processed. var styleSheets: Array = pendingStyleSheets; pendingStyleSheets = null; styleSheets.forEach(addStyleSheet); } } export function load(url: string, req: any, loadCallback: any, config: any): void { if (config.isBuild) { loadCallback(null); } else { // Go ahead and immediately/optimistically resolve this, since the resolved value of a // stylesheet is never interesting. setTimeout(loadCallback, 0); // Only load a specified stylesheet once for the lifetime of this page. if (loadedStyleSheets[url]) { return; } loadedStyleSheets[url] = true; if (document.readyState == 'loading') { if (!pendingStyleSheets) { pendingStyleSheets = []; document.addEventListener('DOMContentLoaded', domReadyCallback, false); } pendingStyleSheets.push(url); } else { addStyleSheet(url); } } } } export = Style; ================================================ FILE: datalab/notebook/static/visualization.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ // require.js plugin to allow Google Chart API to be loaded. /// declare var google: any; declare var window: any; module Visualization { 'use strict'; // Queued packages to load until the google api loader itself has not been loaded. var queue: any = { packages: [], callbacks: [] }; function loadGoogleApiLoader(callback: any): void { // Visualization packages are loaded using the Google loader. // The loader URL itself must contain a callback (by name) that it invokes when its loaded. 
var callbackName: string = '__googleApiLoaderCallback'; window[callbackName] = callback; var script = document.createElement('script'); script.type = 'text/javascript'; script.async = true; script.src = 'https://www.google.com/jsapi?callback=' + callbackName; document.getElementsByTagName('head')[0].appendChild(script); } function invokeVisualizationCallback(cb: any) { cb(google.visualization); } function loadVisualizationPackages(names: any, callbacks: any): void { if (names.length) { var visualizationOptions = { packages: names, callback: function() { callbacks.forEach(invokeVisualizationCallback); } }; google.load('visualization', '1', visualizationOptions); } } loadGoogleApiLoader(function() { if (queue) { loadVisualizationPackages(queue.packages, queue.callbacks); queue = null; } }); export function load(name: any, req: any, callback: any, config: any) { if (config.isBuild) { callback(null); } else { if (queue) { // Queue the package and associated callback to load, once the loader has been loaded. queue.packages.push(name); queue.callbacks.push(callback); } else { // Loader has already been loaded, so go ahead and load the specified package. loadVisualizationPackages([ name ], [ callback ]); } } } } export = Visualization; ================================================ FILE: datalab/stackdriver/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Stackdriver Functionality.""" ================================================ FILE: datalab/stackdriver/commands/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _monitoring __all__ = ['_monitoring'] ================================================ FILE: datalab/stackdriver/commands/_monitoring.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """IPython Functionality for the Google Monitoring API.""" from __future__ import absolute_import try: import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import datalab.stackdriver.monitoring as gcm import datalab.utils.commands @IPython.core.magic.register_line_cell_magic def monitoring(line, cell=None): """Implements the monitoring cell magic for ipython notebooks. Args: line: the contents of the storage line. Returns: The results of executing the cell. """ parser = datalab.utils.commands.CommandParser(prog='monitoring', description=( 'Execute various Monitoring-related operations. Use "%monitoring ' ' -h" for help on a specific command.')) list_parser = parser.subcommand( 'list', 'List the metrics or resource types in a monitored project.') list_metric_parser = list_parser.subcommand( 'metrics', 'List the metrics that are available through the Monitoring API.') list_metric_parser.add_argument( '-t', '--type', help='The type of metric(s) to list; can include wildchars.') list_metric_parser.add_argument( '-p', '--project', help='The project on which to execute the request.') list_metric_parser.set_defaults(func=_list_metric_descriptors) list_resource_parser = list_parser.subcommand( 'resource_types', ('List the monitored resource types that are available through the ' 'Monitoring API.')) list_resource_parser.add_argument( '-p', '--project', help='The project on which to execute the request.') list_resource_parser.add_argument( '-t', '--type', help='The resource type(s) to list; can include wildchars.') list_resource_parser.set_defaults(func=_list_resource_descriptors) list_group_parser = list_parser.subcommand( 'groups', ('List the Stackdriver groups in this project.')) list_group_parser.add_argument( '-p', '--project', help='The project on which to execute the request.') list_group_parser.add_argument( '-n', '--name', help='The name of the group(s) to list; can include wildchars.') list_group_parser.set_defaults(func=_list_groups) return datalab.utils.commands.handle_magic_line(line, cell, parser) def _list_metric_descriptors(args, _): """Lists the metric descriptors in the project.""" project_id = args['project'] pattern = args['type'] or '*' descriptors = gcm.MetricDescriptors(project_id=project_id) dataframe = descriptors.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _list_resource_descriptors(args, _): """Lists the resource descriptors in the project.""" project_id = args['project'] pattern = args['type'] or '*' descriptors = gcm.ResourceDescriptors(project_id=project_id) dataframe = descriptors.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _list_groups(args, _): """Lists the groups in the project.""" project_id = args['project'] pattern = args['name'] or '*' groups = gcm.Groups(project_id=project_id) dataframe = groups.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _render_dataframe(dataframe): """Helper to render a dataframe as an HTML table.""" data = dataframe.to_dict(orient='records') fields = dataframe.columns.tolist() return IPython.core.display.HTML( datalab.utils.commands.HtmlBuilder.render_table(data, fields)) ================================================ FILE: datalab/stackdriver/monitoring/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. 
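Editorial note on the %monitoring line magic defined above: the notebook lines below are an illustrative sketch of the subcommands it registers. The project id and wildcard patterns are hypothetical, and the datalab extension is assumed to be loaded in an IPython notebook with a default project configured.

    # Illustrative notebook cells for the %monitoring magic defined above.
    # The project id and wildcard patterns are hypothetical.
    %monitoring list metrics --type compute*
    %monitoring list resource_types --project my-project-id
    %monitoring list groups --name Production*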
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Monitoring Functionality.""" from __future__ import absolute_import from google.cloud.monitoring import enums from ._group import Groups from ._metric import MetricDescriptors from ._query import Query from ._query_metadata import QueryMetadata from ._resource import ResourceDescriptors Aligner = enums.Aggregation.Aligner Reducer = enums.Aggregation.Reducer __all__ = ['Aligner', 'Reducer', 'Groups', 'MetricDescriptors', 'Query', 'QueryMetadata', 'ResourceDescriptors'] ================================================ FILE: datalab/stackdriver/monitoring/_group.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Groups for the Google Monitoring API.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import collections import fnmatch import pandas import google.datalab from . import _utils class Groups(object): """Represents a list of Stackdriver groups for a project.""" _DISPLAY_HEADERS = ('Group ID', 'Group name', 'Parent ID', 'Parent name', 'Is cluster', 'Filter') def __init__(self, context=None): """Initializes the Groups for a Stackdriver project. Args: context: An optional Context object to use instead of the global default. """ self._context = context or google.datalab.Context.default() self._client = _utils.make_client(self._context) self._group_dict = None def list(self, pattern='*'): """Returns a list of groups that match the filters. Args: pattern: An optional pattern to filter the groups based on their display name. This can include Unix shell-style wildcards. E.g. ``"Production*"``. Returns: A list of Group objects that match the filters. """ if self._group_dict is None: self._group_dict = collections.OrderedDict( (group.name, group) for group in self._client.list_groups()) return [group for group in self._group_dict.values() if fnmatch.fnmatch(group.display_name, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the groups that match the filters. Args: pattern: An optional pattern to further filter the groups. This can include Unix shell-style wildcards. E.g. ``"Production *"``, ``"*-backend"``. max_rows: The maximum number of groups to return. If None, return all. Returns: A pandas dataframe containing matching groups. 
""" data = [] for i, group in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break parent = self._group_dict.get(group.parent_name) parent_display_name = '' if parent is None else parent.display_name data.append([ group.name, group.display_name, group.parent_name, parent_display_name, group.is_cluster, group.filter]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: datalab/stackdriver/monitoring/_metric.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides the MetricDescriptors in the monitoring API.""" from __future__ import absolute_import from builtins import object from google.cloud.monitoring_v3 import enums import fnmatch import pandas from . import _utils class MetricDescriptors(object): """MetricDescriptors object for retrieving the metric descriptors.""" _DISPLAY_HEADERS = ('Metric type', 'Display name', 'Kind', 'Value', 'Unit', 'Labels') def __init__(self, filter_string=None, type_prefix=None, context=None): """Initializes the MetricDescriptors based on the specified filters. Args: filter_string: An optional filter expression describing the resource descriptors to be returned. type_prefix: An optional prefix constraining the selected metric types. This adds ``metric.type = starts_with("")`` to the filter. context: An optional Context object to use instead of the global default. """ self._client = _utils.make_client(context) self._filter_string = filter_string self._type_prefix = type_prefix self._descriptors = None def list(self, pattern='*'): """Returns a list of metric descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"compute*"``, ``"*cpu/load_??m"``. Returns: A list of MetricDescriptor objects that match the filters. """ if self._descriptors is None: self._descriptors = self._client.list_metric_descriptors( filter_string=self._filter_string, type_prefix=self._type_prefix) return [metric for metric in self._descriptors if fnmatch.fnmatch(metric.type, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"compute*"``, ``"*/cpu/load_??m"``. max_rows: The maximum number of descriptors to return. If None, return all. Returns: A pandas dataframe containing matching metric descriptors. """ data = [] for i, metric in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break labels = ', '. 
join([l.key for l in metric.labels]) data.append([ metric.type, metric.display_name, enums.MetricDescriptor.MetricKind(metric.metric_kind).name, enums.MetricDescriptor.ValueType(metric.value_type).name, metric.unit, labels]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: datalab/stackdriver/monitoring/_query.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides access to metric data as pandas dataframes.""" from __future__ import absolute_import import google.cloud.monitoring_v3.query from . import _query_metadata from . import _utils class Query(google.cloud.monitoring_v3.query.Query): """Query object for retrieving metric data.""" def __init__(self, metric_type=google.cloud.monitoring_v3.query.Query.DEFAULT_METRIC_TYPE, end_time=None, days=0, hours=0, minutes=0, context=None): """Initializes the core query parameters. The start time (exclusive) is determined by combining the values of ``days``, ``hours``, and ``minutes``, and subtracting the resulting duration from the end time. It is also allowed to omit the end time and duration here, in which case :meth:`~google.cloud.monitoring_v3.query.Query.select_interval` must be called before the query is executed. Args: metric_type: The metric type name. The default value is :data:`Query.DEFAULT_METRIC_TYPE `, but please note that this default value is provided only for demonstration purposes and is subject to change. end_time: The end time (inclusive) of the time interval for which results should be returned, as a datetime object. The default is the start of the current minute. days: The number of days in the time interval. hours: The number of hours in the time interval. minutes: The number of minutes in the time interval. context: An optional Context object to use instead of the global default. Raises: ValueError: ``end_time`` was specified but ``days``, ``hours``, and ``minutes`` are all zero. If you really want to specify a point in time, use :meth:`~google.cloud.monitoring_v3.query.Query.select_interval`. """ client = _utils.make_client(context) super(Query, self).__init__(client.metrics_client, project=client.project, metric_type=metric_type, end_time=end_time, days=days, hours=hours, minutes=minutes) def metadata(self): """Retrieves the metadata for the query.""" return _query_metadata.QueryMetadata(self) ================================================ FILE: datalab/stackdriver/monitoring/_query_metadata.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
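Editorial note: a minimal usage sketch for the Query class defined above, assuming a default Context with valid credentials and project; the metric type shown is illustrative only.

    import datalab.stackdriver.monitoring as gcm

    # Query the last two hours of an illustrative metric type against the
    # default project; the interval end defaults to the start of the
    # current minute.
    query = gcm.Query('compute.googleapis.com/instance/cpu/utilization', hours=2)

    # metadata() returns a QueryMetadata object (defined in the next file)
    # that fetches only the time-series headers, not the data points.
    meta = query.metadata()
    print(meta.metric_type)
    print(meta.resource_types)
    headers_df = meta.as_dataframe()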
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """QueryMetadata object that shows the metadata in a query's results.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from google.cloud.monitoring_v3 import _dataframe from google.protobuf.json_format import MessageToDict import pandas class QueryMetadata(object): """QueryMetadata object contains the metadata of a timeseries query.""" def __init__(self, query): """Initializes the QueryMetadata given the query object. Args: query: A Query object. """ self._timeseries_list = list(query.iter(headers_only=True)) # Note: If self._timeseries_list has even one entry, the metric type # can be extracted from there as well. self._metric_type = query.metric_type def __iter__(self): for timeseries in self._timeseries_list: yield timeseries @property def metric_type(self): """Returns the metric type in the underlying query.""" return self._metric_type @property def resource_types(self): """Returns a set containing resource types in the query result.""" return set([ts.resource.type for ts in self._timeseries_list]) def as_dataframe(self, max_rows=None): """Creates a pandas dataframe from the query metadata. Args: max_rows: The maximum number of timeseries metadata to return. If None, return all. Returns: A pandas dataframe containing the resource type, resource labels and metric labels. Each row in this dataframe corresponds to the metadata from one time series. """ max_rows = len(self._timeseries_list) if max_rows is None else max_rows headers = [{ 'resource': MessageToDict(ts.resource), 'metric': MessageToDict(ts.metric) } for ts in self._timeseries_list[:max_rows]] if not headers: return pandas.DataFrame() dataframe = pandas.io.json.json_normalize(headers) # Add a 2 level column header. dataframe.columns = pandas.MultiIndex.from_tuples( [(col, '') if col == 'resource.type' else col.rsplit('.', 1) for col in dataframe.columns]) # Re-order the columns. resource_keys = _dataframe._sorted_resource_labels( dataframe['resource.labels'].columns) sorted_columns = [('resource.type', '')] sorted_columns += [('resource.labels', key) for key in resource_keys] sorted_columns += sorted(col for col in dataframe.columns if col[0] == 'metric.labels') dataframe = dataframe[sorted_columns] # Sort the data, and clean up index values, and NaNs. dataframe = dataframe.sort_values(sorted_columns) dataframe = dataframe.reset_index(drop=True).fillna('') return dataframe ================================================ FILE: datalab/stackdriver/monitoring/_resource.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
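Editorial note: a hedged sketch of how the dataframe helpers in this package are typically driven (Groups and MetricDescriptors are defined above; ResourceDescriptors follows below). The wildcard patterns are hypothetical, and a default Context with credentials is assumed.

    import datalab.stackdriver.monitoring as gcm

    # Patterns use Unix shell-style wildcards (fnmatch), as documented above.
    groups_df = gcm.Groups().as_dataframe(pattern='Production*', max_rows=10)
    metrics_df = gcm.MetricDescriptors().as_dataframe(pattern='compute*')
    resources_df = gcm.ResourceDescriptors().as_dataframe(pattern='gce*')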
See the License for the specific language governing permissions and limitations under # the License. """Provides the ResourceDescriptors in the monitoring API.""" from __future__ import absolute_import from builtins import object import fnmatch import pandas from . import _utils class ResourceDescriptors(object): """ResourceDescriptors object for retrieving the resource descriptors.""" _DISPLAY_HEADERS = ('Resource type', 'Display name', 'Labels') def __init__(self, filter_string=None, context=None): """Initializes the ResourceDescriptors based on the specified filters. Args: filter_string: An optional filter expression describing the resource descriptors to be returned. context: An optional Context object to use instead of the global default. """ self._client = _utils.make_client(context) self._filter_string = filter_string self._descriptors = None def list(self, pattern='*'): """Returns a list of resource descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"aws*"``, ``"*cluster*"``. Returns: A list of ResourceDescriptor objects that match the filters. """ if self._descriptors is None: self._descriptors = self._client.list_resource_descriptors( filter_string=self._filter_string) return [resource for resource in self._descriptors if fnmatch.fnmatch(resource.type, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"aws*"``, ``"*cluster*"``. max_rows: The maximum number of descriptors to return. If None, return all. Returns: A pandas dataframe containing matching resource descriptors. """ data = [] for i, resource in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break labels = ', '. join([l.key for l in resource.labels]) data.append([resource.type, resource.display_name, labels]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: datalab/stackdriver/monitoring/_utils.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides utility methods for the Monitoring API.""" from __future__ import absolute_import from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud.monitoring_v3 import MetricServiceClient from google.cloud.monitoring_v3 import GroupServiceClient import google.datalab # _MonitoringClient holds instances of individual google.cloud.monitoring # clients and translates each call from the old signature, since the prior # client has been updated and has split into multiple client classes. 
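Editorial note: the comment above describes the translation layer; the sketch below shows how it is obtained and exercised, mirroring what MetricDescriptors, ResourceDescriptors and Groups do internally. The type prefix is illustrative, and a default Context is assumed.

    import google.datalab
    from datalab.stackdriver.monitoring import _utils

    # make_client() wraps the Metric and Group service clients defined just
    # below and exposes the older list_* call signatures used by this package.
    client = _utils.make_client(google.datalab.Context.default())
    metric_descriptors = client.list_metric_descriptors(
        type_prefix='compute.googleapis.com')
    resource_descriptors = client.list_resource_descriptors()
    groups = client.list_groups()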
class _MonitoringClient(object): def __init__(self, context): self.project = context.project_id client_info = ClientInfo(user_agent='pydatalab/v0') self.metrics_client = MetricServiceClient( credentials=context.credentials, client_info=client_info ) self.group_client = GroupServiceClient( credentials=context.credentials, client_info=client_info ) def list_metric_descriptors(self, filter_string=None, type_prefix=None): filters = [] if filter_string is not None: filters.append(filter_string) if type_prefix is not None: filters.append('metric.type = starts_with("{prefix}")'.format( prefix=type_prefix)) metric_filter = ' AND '.join(filters) metrics = self.metrics_client.list_metric_descriptors( self.project, filter_=metric_filter) return metrics def list_resource_descriptors(self, filter_string=None): resources = self.metrics_client.list_monitored_resource_descriptors( self.project, filter_=filter_string) return resources def list_groups(self): groups = self.group_client.list_groups(self.project) return groups def make_client(context=None): context = context or google.datalab.Context.default() client = _MonitoringClient(context) return client ================================================ FILE: datalab/storage/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Cloud Storage Functionality.""" from __future__ import absolute_import from ._bucket import Bucket, Buckets from ._item import Item, Items __all__ = ['Bucket', 'Buckets', 'Item', 'Items'] ================================================ FILE: datalab/storage/_api.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Storage HTTP API wrapper.""" from __future__ import absolute_import from __future__ import unicode_literals from future import standard_library standard_library.install_aliases() # noqa from builtins import object import urllib.request import urllib.parse import urllib.error import datalab.context import datalab.utils class Api(object): """A helper class to issue Storage HTTP requests.""" # TODO(nikhilko): Use named placeholders in these string templates. 
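The Api class that begins here is an internal helper; a minimal sketch of how it can be driven directly, using the bucket calls defined below (the bucket name is hypothetical):

import datalab.context
from datalab.storage import _api

api = _api.Api(datalab.context.Context.default())
listing = api.buckets_list(max_results=10)        # parsed JSON dict from the Storage REST API
for info in listing.get('items', []):
    print(info['name'])
info = api.buckets_get('my-sample-bucket')        # hypothetical bucket name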
_ENDPOINT = 'https://www.googleapis.com/storage/v1' _DOWNLOAD_ENDPOINT = 'https://www.googleapis.com/download/storage/v1' _UPLOAD_ENDPOINT = 'https://www.googleapis.com/upload/storage/v1' _BUCKET_PATH = '/b/%s' _OBJECT_PATH = '/b/%s/o/%s' _OBJECT_COPY_PATH = '/b/%s/o/%s/copyTo/b/%s/o/%s' _MAX_RESULTS = 100 def __init__(self, context): """Initializes the Storage helper with context information. Args: context: a Context object providing project_id and credentials. """ self._credentials = context.credentials self._project_id = context.project_id @property def project_id(self): """The project_id associated with this API client.""" return self._project_id def buckets_insert(self, bucket, project_id=None): """Issues a request to create a new bucket. Args: bucket: the name of the bucket. project_id: the project to use when inserting the bucket. Returns: A parsed bucket information dictionary. Raises: Exception if there is an error performing the operation. """ args = {'project': project_id if project_id else self._project_id} data = {'name': bucket} url = Api._ENDPOINT + (Api._BUCKET_PATH % '') return datalab.utils.Http.request(url, args=args, data=data, credentials=self._credentials) def buckets_delete(self, bucket): """Issues a request to delete a bucket. Args: bucket: the name of the bucket. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._BUCKET_PATH % bucket) datalab.utils.Http.request(url, method='DELETE', credentials=self._credentials, raw_response=True) def buckets_get(self, bucket, projection='noAcl'): """Issues a request to retrieve information about a bucket. Args: bucket: the name of the bucket. projection: the projection of the bucket information to retrieve. Returns: A parsed bucket information dictionary. Raises: Exception if there is an error performing the operation. """ args = {'projection': projection} url = Api._ENDPOINT + (Api._BUCKET_PATH % bucket) return datalab.utils.Http.request(url, credentials=self._credentials, args=args) def buckets_list(self, projection='noAcl', max_results=0, page_token=None, project_id=None): """Issues a request to retrieve the list of buckets. Args: projection: the projection of the bucket information to retrieve. max_results: an optional maximum number of objects to retrieve. page_token: an optional token to continue the retrieval. project_id: the project whose buckets should be listed. Returns: A parsed list of bucket information dictionaries. Raises: Exception if there is an error performing the operation. """ if max_results == 0: max_results = Api._MAX_RESULTS args = {'project': project_id if project_id else self._project_id, 'maxResults': max_results} if projection is not None: args['projection'] = projection if page_token is not None: args['pageToken'] = page_token url = Api._ENDPOINT + (Api._BUCKET_PATH % '') return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def object_download(self, bucket, key, start_offset=0, byte_count=None): """Reads the contents of an object as text. Args: bucket: the name of the bucket containing the object. key: the key of the object to be read. start_offset: the start offset of bytes to read. byte_count: the number of bytes to read. If None, it reads to the end. Returns: The text content within the object. Raises: Exception if the object could not be read from. 
""" args = {'alt': 'media'} headers = {} if start_offset > 0 or byte_count is not None: header = 'bytes=%d-' % start_offset if byte_count is not None: header += '%d' % byte_count headers['Range'] = header url = Api._DOWNLOAD_ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return datalab.utils.Http.request(url, args=args, headers=headers, credentials=self._credentials, raw_response=True) def object_upload(self, bucket, key, content, content_type): """Writes text content to the object. Args: bucket: the name of the bucket containing the object. key: the key of the object to be written. content: the text content to be written. content_type: the type of text content. Raises: Exception if the object could not be written to. """ args = {'uploadType': 'media', 'name': key} headers = {'Content-Type': content_type} url = Api._UPLOAD_ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) return datalab.utils.Http.request(url, args=args, data=content, headers=headers, credentials=self._credentials, raw_response=True) def objects_copy(self, source_bucket, source_key, target_bucket, target_key): """Updates the metadata associated with an object. Args: source_bucket: the name of the bucket containing the source object. source_key: the key of the source object being copied. target_bucket: the name of the bucket that will contain the copied object. target_key: the key of the copied object. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_COPY_PATH % (source_bucket, Api._escape_key(source_key), target_bucket, Api._escape_key(target_key))) return datalab.utils.Http.request(url, method='POST', credentials=self._credentials) def objects_delete(self, bucket, key): """Deletes the specified object. Args: bucket: the name of the bucket. key: the key of the object within the bucket. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) datalab.utils.Http.request(url, method='DELETE', credentials=self._credentials, raw_response=True) def objects_get(self, bucket, key, projection='noAcl'): """Issues a request to retrieve information about an object. Args: bucket: the name of the bucket. key: the key of the object within the bucket. projection: the projection of the object to retrieve. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. """ args = {} if projection is not None: args['projection'] = projection url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def objects_list(self, bucket, prefix=None, delimiter=None, projection='noAcl', versions=False, max_results=0, page_token=None): """Issues a request to retrieve information about an object. Args: bucket: the name of the bucket. prefix: an optional key prefix. delimiter: an optional key delimiter. projection: the projection of the objects to retrieve. versions: whether to list each version of a file as a distinct object. max_results: an optional maximum number of objects to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed list of object information dictionaries. Raises: Exception if there is an error performing the operation. 
""" if max_results == 0: max_results = Api._MAX_RESULTS args = {'maxResults': max_results} if prefix is not None: args['prefix'] = prefix if delimiter is not None: args['delimiter'] = delimiter if projection is not None: args['projection'] = projection if versions: args['versions'] = 'true' if page_token is not None: args['pageToken'] = page_token url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) return datalab.utils.Http.request(url, args=args, credentials=self._credentials) def objects_patch(self, bucket, key, info): """Updates the metadata associated with an object. Args: bucket: the name of the bucket containing the object. key: the key of the object being updated. info: the metadata to update. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return datalab.utils.Http.request(url, method='PATCH', data=info, credentials=self._credentials) @staticmethod def _escape_key(key): # Disable the behavior to leave '/' alone by explicitly specifying the safe parameter. return urllib.parse.quote(key, safe='') @staticmethod def verify_permitted_to_read(gs_path): """Check if the user has permissions to read from the given path. Args: gs_path: the GCS path to check if user is permitted to read. Raises: Exception if user has no permissions to read. """ # TODO(qimingj): Storage APIs need to be modified to allow absence of project # or credential on Items. When that happens we can move the function # to Items class. from . import _bucket bucket, prefix = _bucket.parse_name(gs_path) credentials = None if datalab.context.Context.is_signed_in(): credentials = datalab.context._utils.get_credentials() args = { 'maxResults': Api._MAX_RESULTS, 'projection': 'noAcl' } if prefix is not None: args['prefix'] = prefix url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) try: datalab.utils.Http.request(url, args=args, credentials=credentials) except datalab.utils.RequestException as e: if e.status == 401: raise Exception('Not permitted to read from specified path. ' 'Please sign in and make sure you have read access.') raise e ================================================ FILE: datalab/storage/_bucket.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Bucket-related Cloud Storage APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import dateutil.parser import re import datalab.context import datalab.utils from . import _api from . import _item # REs to match bucket names and optionally object names _BUCKET_NAME = '[a-z\d][a-z\d_\.\-]+[a-z\d]' _OBJECT_NAME = '[^\n\r]+' _STORAGE_NAME = 'gs://(' + _BUCKET_NAME + ')(/' + _OBJECT_NAME + ')?' def parse_name(name): """ Parse a gs:// URL into the bucket and item names. 
Args: name: a GCS URL of the form gs://bucket or gs://bucket/item Returns: The bucket name (with no gs:// prefix), and the item name if present. If the name could not be parsed returns None for both. """ bucket = None item = None m = re.match(_STORAGE_NAME, name) if m: # We want to return the last two groups as first group is the optional 'gs://' bucket = m.group(1) item = m.group(2) if item is not None: item = item[1:] # Strip '/' else: m = re.match('(' + _OBJECT_NAME + ')', name) if m: item = m.group(1) return bucket, item class BucketMetadata(object): """Represents metadata about a Cloud Storage bucket.""" def __init__(self, info): """Initializes an instance of a BucketMetadata object. Args: info: a dictionary containing information about an Item. """ self._info = info @property def created_on(self): """The created timestamp of the bucket as a datetime.datetime.""" s = self._info.get('timeCreated', None) return dateutil.parser.parse(s) if s else None @property def etag(self): """The ETag of the bucket, if any.""" return self._info.get('etag', None) @property def name(self): """The name of the bucket.""" return self._info['name'] class Bucket(object): """Represents a Cloud Storage bucket.""" def __init__(self, name, info=None, context=None): """Initializes an instance of a Bucket object. Args: name: the name of the bucket. info: the information about the bucket if available. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._name = name self._info = info @property def name(self): """The name of the bucket.""" return self._name def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'Bucket gs://%s' % self._name @property def metadata(self): """Retrieves metadata about the bucket. Returns: A BucketMetadata instance with information about this bucket. Raises: Exception if there was an error requesting the bucket's metadata. """ if self._info is None: try: self._info = self._api.buckets_get(self._name) except Exception as e: raise e return BucketMetadata(self._info) if self._info else None def item(self, key): """Retrieves an Item object for the specified key in this bucket. The item need not exist. Args: key: the key of the item within the bucket. Returns: An Item instance representing the specified key. """ return _item.Item(self._name, key, context=self._context) def items(self, prefix=None, delimiter=None): """Get an iterator for the items within this bucket. Args: prefix: an optional prefix to match items. delimiter: an optional string to simulate directory-like semantics. The returned items will be those whose names do not contain the delimiter after the prefix. For the remaining items, the names will be returned truncated after the delimiter with duplicates removed (i.e. as pseudo-directories). Returns: An iterable list of items within this bucket. """ return _item.Items(self._name, prefix, delimiter, context=self._context) def exists(self): """ Checks if the bucket exists. """ try: return self.metadata is not None except Exception: return False def create(self, project_id=None): """Creates the bucket. Args: project_id: the project in which to create the bucket. Returns: The bucket. Raises: Exception if there was an error creating the bucket. 
""" if not self.exists(): if project_id is None: project_id = self._api.project_id try: self._info = self._api.buckets_insert(self._name, project_id=project_id) except Exception as e: raise e return self def delete(self): """Deletes the bucket. Raises: Exception if there was an error deleting the bucket. """ if self.exists(): try: self._api.buckets_delete(self._name) except Exception as e: raise e class Buckets(object): """Represents a list of Cloud Storage buckets for a project.""" def __init__(self, project_id=None, context=None): """Initializes an instance of a BucketList. Args: project_id: an optional project whose buckets we want to manipulate. If None this is obtained from the api object. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._project_id = project_id if project_id else self._api.project_id def contains(self, name): """Checks if the specified bucket exists. Args: name: the name of the bucket to lookup. Returns: True if the bucket exists; False otherwise. Raises: Exception if there was an error requesting information about the bucket. """ try: self._api.buckets_get(name) except datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e return True def create(self, name): """Creates a new bucket. Args: name: a unique name for the new bucket. Returns: The newly created bucket. Raises: Exception if there was an error creating the bucket. """ return Bucket(name, context=self._context).create(self._project_id) def _retrieve_buckets(self, page_token, _): try: list_info = self._api.buckets_list(page_token=page_token, project_id=self._project_id) except Exception as e: raise e buckets = list_info.get('items', []) if len(buckets): try: buckets = [Bucket(info['name'], info, context=self._context) for info in buckets] except KeyError: raise Exception('Unexpected response from server') page_token = list_info.get('nextPageToken', None) return buckets, page_token def __iter__(self): return iter(datalab.utils.Iterator(self._retrieve_buckets)) ================================================ FILE: datalab/storage/_item.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Object-related Cloud Storage APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import dateutil.parser import datalab.utils import datalab.context from . import _api # TODO(nikhilko): Read/write operations don't account for larger files, or non-textual content. # Use streaming reads into a buffer or StringIO or into a file handle. 
class ItemMetadata(object): """Represents metadata about a Cloud Storage object.""" def __init__(self, info): """Initializes an instance of a ItemMetadata object. Args: info: a dictionary containing information about an Item. """ self._info = info @property def content_type(self): """The Content-Type associated with the item, if any.""" return self._info.get('contentType', None) @property def etag(self): """The ETag of the item, if any.""" return self._info.get('etag', None) @property def name(self): """The name of the item.""" return self._info['name'] @property def size(self): """The size (in bytes) of the item. 0 for items that don't exist.""" return int(self._info.get('size', 0)) @property def updated_on(self): """The updated timestamp of the item as a datetime.datetime.""" s = self._info.get('updated', None) return dateutil.parser.parse(s) if s else None class Item(object): """Represents a Cloud Storage object within a bucket.""" def __init__(self, bucket, key, info=None, context=None): """Initializes an instance of an Item. Args: bucket: the name of the bucket containing the item. key: the key of the item. info: the information about the item if available. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._bucket = bucket self._key = key self._info = info @staticmethod def from_url(url): from . import _bucket bucket, item = _bucket.parse_name(url) return Item(bucket, item) @property def key(self): """Returns the key of the item.""" return self._key @property def uri(self): """Returns the gs:// URI for the item. """ return 'gs://%s/%s' % (self._bucket, self._key) def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'Item %s' % self.uri def copy_to(self, new_key, bucket=None): """Copies this item to the specified new key. Args: new_key: the new key to copy this item to. bucket: the bucket of the new item; if None (the default) use the same bucket. Returns: An Item corresponding to new key. Raises: Exception if there was an error copying the item. """ if bucket is None: bucket = self._bucket try: new_info = self._api.objects_copy(self._bucket, self._key, bucket, new_key) except Exception as e: raise e return Item(bucket, new_key, new_info, context=self._context) def exists(self): """ Checks if the item exists. """ try: return self.metadata is not None except datalab.utils.RequestException: return False except Exception as e: raise e def delete(self): """Deletes this item from its bucket. Raises: Exception if there was an error deleting the item. """ if self.exists(): try: self._api.objects_delete(self._bucket, self._key) except Exception as e: raise e @property def metadata(self): """Retrieves metadata about the bucket. Returns: A BucketMetadata instance with information about this bucket. Raises: Exception if there was an error requesting the bucket's metadata. """ if self._info is None: try: self._info = self._api.objects_get(self._bucket, self._key) except Exception as e: raise e return ItemMetadata(self._info) if self._info else None def read_from(self, start_offset=0, byte_count=None): """Reads the content of this item as text. Args: start_offset: the start offset of bytes to read. byte_count: the number of bytes to read. If None, it reads to the end. 
Returns: The text content within the item. Raises: Exception if there was an error requesting the item's content. """ try: return self._api.object_download(self._bucket, self._key, start_offset=start_offset, byte_count=byte_count) except Exception as e: raise e def read_lines(self, max_lines=None): """Reads the content of this item as text, and return a list of lines up to some max. Args: max_lines: max number of lines to return. If None, return all lines. Returns: The text content of the item as a list of lines. Raises: Exception if there was an error requesting the item's content. """ if max_lines is None: return self.read_from().split('\n') max_to_read = self.metadata.size bytes_to_read = min(100 * max_lines, self.metadata.size) while True: content = self.read_from(byte_count=bytes_to_read) lines = content.split('\n') if len(lines) > max_lines or bytes_to_read >= max_to_read: break # try 10 times more bytes or max bytes_to_read = min(bytes_to_read * 10, max_to_read) # remove the partial line at last del lines[-1] return lines[0:max_lines] def write_to(self, content, content_type): """Writes text content to this item. Args: content: the text content to be written. content_type: the type of text content. Raises: Exception if there was an error requesting the item's content. """ try: self._api.object_upload(self._bucket, self._key, content, content_type) except Exception as e: raise e class Items(object): """Represents a list of Cloud Storage objects within a bucket.""" def __init__(self, bucket, prefix, delimiter, context=None): """Initializes an instance of an ItemList. Args: bucket: the name of the bucket containing the items. prefix: an optional prefix to match items. delimiter: an optional string to simulate directory-like semantics. The returned items will be those whose names do not contain the delimiter after the prefix. For the remaining items, the names will be returned truncated after the delimiter with duplicates removed (i.e. as pseudo-directories). context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = datalab.context.Context.default() self._context = context self._api = _api.Api(context) self._bucket = bucket self._prefix = prefix self._delimiter = delimiter def contains(self, key): """Checks if the specified item exists. Args: key: the key of the item to lookup. Returns: True if the item exists; False otherwise. Raises: Exception if there was an error requesting information about the item. """ try: self._api.objects_get(self._bucket, key) except datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e return True def _retrieve_items(self, page_token, _): try: list_info = self._api.objects_list(self._bucket, prefix=self._prefix, delimiter=self._delimiter, page_token=page_token) except Exception as e: raise e items = list_info.get('items', []) if len(items): try: items = [Item(self._bucket, info['name'], info, context=self._context) for info in items] except KeyError: raise Exception('Unexpected response from server') page_token = list_info.get('nextPageToken', None) return items, page_token def __iter__(self): return iter(datalab.utils.Iterator(self._retrieve_items)) ================================================ FILE: datalab/storage/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. 
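A hedged sketch of the Item class above in use; the bucket and keys are hypothetical:

import datalab.storage as storage

item = storage.Item('my-sample-bucket', 'notes/readme.txt')   # hypothetical bucket and key
item.write_to('hello world', 'text/plain')
print(item.read_from())                        # full contents as text
print(item.read_lines(max_lines=5))            # first few lines only
copy = item.copy_to('notes/readme-copy.txt')   # same bucket unless one is passed
print(item.metadata.size, item.metadata.updated_on)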
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _storage __all__ = ['_storage'] ================================================ FILE: datalab/storage/commands/_storage.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery IPython Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import fnmatch import json import re import datalab.storage import datalab.utils.commands def _extract_storage_api_response_error(message): """ A helper function to extract user-friendly error messages from service exceptions. Args: message: An error message from an exception. If this is from our HTTP client code, it will actually be a tuple. Returns: A modified version of the message that is less cryptic. """ try: if len(message) == 3: # Try treat the last part as JSON data = json.loads(message[2]) return data['error']['errors'][0]['message'] except Exception: pass return message @IPython.core.magic.register_line_cell_magic def storage(line, cell=None): """Implements the storage cell magic for ipython notebooks. Args: line: the contents of the storage line. Returns: The results of executing the cell. """ parser = datalab.utils.commands.CommandParser(prog='storage', description=""" Execute various storage-related operations. Use "%storage -h" for help on a specific command. """) # TODO(gram): consider adding a move command too. I did try this already using the # objects.patch API to change the object name but that fails with an error: # # Value 'newname' in content does not agree with value 'oldname'. This can happen when a value # set through a parameter is inconsistent with a value set in the request. # # This is despite 'name' being identified as writable in the storage API docs. # The alternative would be to use a copy/delete. copy_parser = parser.subcommand('copy', 'Copy one or more GCS objects to a different location.') copy_parser.add_argument('-s', '--source', help='The name of the object(s) to copy', nargs='+') copy_parser.add_argument('-d', '--destination', required=True, help='The copy destination. 
For multiple source items this must be a ' 'bucket.') copy_parser.set_defaults(func=_storage_copy) create_parser = parser.subcommand('create', 'Create one or more GCS buckets.') create_parser.add_argument('-p', '--project', help='The project associated with the objects') create_parser.add_argument('-b', '--bucket', help='The name of the bucket(s) to create', nargs='+') create_parser.set_defaults(func=_storage_create) delete_parser = parser.subcommand('delete', 'Delete one or more GCS buckets or objects.') delete_parser.add_argument('-b', '--bucket', nargs='*', help='The name of the bucket(s) to remove') delete_parser.add_argument('-o', '--object', nargs='*', help='The name of the object(s) to remove') delete_parser.set_defaults(func=_storage_delete) list_parser = parser.subcommand('list', 'List buckets in a project, or contents of a bucket.') list_parser.add_argument('-p', '--project', help='The project associated with the objects') group = list_parser.add_mutually_exclusive_group() group.add_argument('-o', '--object', help='The name of the objects(s) to list; can include wildchars', nargs='?') group.add_argument('-b', '--bucket', help='The name of the buckets(s) to list; can include wildchars', nargs='?') list_parser.set_defaults(func=_storage_list) read_parser = parser.subcommand('read', 'Read the contents of a storage object into a Python variable.') read_parser.add_argument('-o', '--object', help='The name of the object to read', required=True) read_parser.add_argument('-v', '--variable', required=True, help='The name of the Python variable to set') read_parser.set_defaults(func=_storage_read) view_parser = parser.subcommand('view', 'View the contents of a storage object.') view_parser.add_argument('-n', '--head', type=int, default=20, help='The number of initial lines to view') view_parser.add_argument('-t', '--tail', type=int, default=20, help='The number of lines from end to view') view_parser.add_argument('-o', '--object', help='The name of the object to view', required=True) view_parser.set_defaults(func=_storage_view) write_parser = parser.subcommand('write', 'Write the value of a Python variable to a storage object.') write_parser.add_argument('-v', '--variable', help='The name of the source Python variable', required=True) write_parser.add_argument('-o', '--object', required=True, help='The name of the destination GCS object to write') write_parser.add_argument('-c', '--content_type', help='MIME type', default='text/plain') write_parser.set_defaults(func=_storage_write) return datalab.utils.commands.handle_magic_line(line, cell, parser) def _parser_exit(status=0, message=None): """ Replacement exit method for argument parser. We want to stop processing args but not call sys.exit(), so we raise an exception here and catch it in the call to parse_args. """ raise Exception() def _expand_list(names): """ Do a wildchar name expansion of object names in a list and return expanded list. The items are expected to exist as this is used for copy sources or delete targets. Currently we support wildchars in the key name only. """ if names is None: names = [] elif isinstance(names, basestring): names = [names] results = [] # The expanded list. items = {} # Cached contents of buckets; used for matching. for name in names: bucket, key = datalab.storage._bucket.parse_name(name) results_len = len(results) # If we fail to add any we add name and let caller deal with it. if bucket: if not key: # Just a bucket; add it. 
results.append('gs://%s' % bucket) elif datalab.storage.Item(bucket, key).exists(): results.append('gs://%s/%s' % (bucket, key)) else: # Expand possible key values. if bucket not in items and key[:1] == '*': # We need the full list; cache a copy for efficiency. items[bucket] = [item.metadata.name for item in list(datalab.storage.Bucket(bucket).items())] # If we have a cached copy use it if bucket in items: candidates = items[bucket] # else we have no cached copy but can use prefix matching which is more efficient than # getting the full contents. else: # Get the non-wildchar prefix. match = re.search('\?|\*|\[', key) prefix = key if match: prefix = key[0:match.start()] candidates = [item.metadata.name for item in datalab.storage.Bucket(bucket).items(prefix=prefix)] for item in candidates: if fnmatch.fnmatch(item, key): results.append('gs://%s/%s' % (bucket, item)) # If we added no matches, add the original name and let caller deal with it. if len(results) == results_len: results.append(name) return results def _storage_copy(args, _): target = args['destination'] target_bucket, target_key = datalab.storage._bucket.parse_name(target) if target_bucket is None and target_key is None: raise Exception('Invalid copy target name %s' % target) sources = _expand_list(args['source']) if len(sources) > 1: # Multiple sources; target must be a bucket if target_bucket is None or target_key is not None: raise Exception('More than one source but target %s is not a bucket' % target) errs = [] for source in sources: source_bucket, source_key = datalab.storage._bucket.parse_name(source) if source_bucket is None or source_key is None: raise Exception('Invalid source object name %s' % source) destination_bucket = target_bucket if target_bucket else source_bucket destination_key = target_key if target_key else source_key try: datalab.storage.Item(source_bucket, source_key).copy_to(destination_key, bucket=destination_bucket) except Exception as e: errs.append("Couldn't copy %s to %s: %s" % (source, target, _extract_storage_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _storage_create(args, _): """ Create one or more buckets. """ buckets = datalab.storage.Buckets(project_id=args['project']) errs = [] for name in args['bucket']: try: bucket, key = datalab.storage._bucket.parse_name(name) if bucket and not key: buckets.create(bucket) else: raise Exception("Invalid bucket name %s" % name) except Exception as e: errs.append("Couldn't create %s: %s" % (name, _extract_storage_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _storage_delete(args, _): """ Delete one or more buckets or objects. """ items = _expand_list(args['bucket']) items.extend(_expand_list(args['object'])) errs = [] for item in items: try: bucket, key = datalab.storage._bucket.parse_name(item) if bucket and key: gcs_item = datalab.storage.Item(bucket, key) if gcs_item.exists(): datalab.storage.Item(bucket, key).delete() else: errs.append("%s does not exist" % item) elif bucket: gcs_bucket = datalab.storage.Bucket(bucket) if gcs_bucket.exists(): gcs_bucket.delete() else: errs.append("%s does not exist" % item) else: raise Exception("Can't delete item with invalid name %s" % item) except Exception as e: errs.append("Couldn't delete %s: %s" % (item, _extract_storage_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _storage_list_buckets(project, pattern): """ List all storage buckets that match a pattern. 
""" data = [{'Bucket': 'gs://' + bucket.name, 'Created': bucket.metadata.created_on} for bucket in datalab.storage.Buckets(project_id=project) if fnmatch.fnmatch(bucket.name, pattern)] return datalab.utils.commands.render_dictionary(data, ['Bucket', 'Created']) def _storage_get_keys(bucket, pattern): """ Get names of all storage keys in a specified bucket that match a pattern. """ return [item for item in list(bucket.items()) if fnmatch.fnmatch(item.metadata.name, pattern)] def _storage_get_key_names(bucket, pattern): """ Get names of all storage keys in a specified bucket that match a pattern. """ return [item.metadata.name for item in _storage_get_keys(bucket, pattern)] def _storage_list_keys(bucket, pattern): """ List all storage keys in a specified bucket that match a pattern. """ data = [{'Name': item.metadata.name, 'Type': item.metadata.content_type, 'Size': item.metadata.size, 'Updated': item.metadata.updated_on} for item in _storage_get_keys(bucket, pattern)] return datalab.utils.commands.render_dictionary(data, ['Name', 'Type', 'Size', 'Updated']) def _storage_list(args, _): """ List the buckets or the contents of a bucket. This command is a bit different in that we allow wildchars in the bucket name and will list the buckets that match. """ target = args['object'] if args['object'] else args['bucket'] project = args['project'] if target is None: return _storage_list_buckets(project, '*') # List all buckets. bucket_name, key = datalab.storage._bucket.parse_name(target) if bucket_name is None: raise Exception('Cannot list %s; not a valid bucket name' % target) if key or not re.search('\?|\*|\[', target): # List the contents of the bucket if not key: key = '*' if project: # Only list if the bucket is in the project for bucket in datalab.storage.Buckets(project_id=project): if bucket.name == bucket_name: break else: raise Exception('%s does not exist in project %s' % (target, project)) else: bucket = datalab.storage.Bucket(bucket_name) if bucket.exists(): return _storage_list_keys(bucket, key) else: raise Exception('Bucket %s does not exist' % target) else: # Treat the bucket name as a pattern and show matches. We don't use bucket_name as that # can strip off wildchars and so we need to strip off gs:// here. return _storage_list_buckets(project, target[5:]) def _get_item_contents(source_name): source_bucket, source_key = datalab.storage._bucket.parse_name(source_name) if source_bucket is None: raise Exception('Invalid source object name %s; no bucket specified.' % source_name) if source_key is None: raise Exception('Invalid source object name %si; source cannot be a bucket.' 
% source_name) source = datalab.storage.Item(source_bucket, source_key) if not source.exists(): raise Exception('Source object %s does not exist' % source_name) return source.read_from() def _storage_read(args, _): contents = _get_item_contents(args['object']) ipy = IPython.get_ipython() ipy.push({args['variable']: contents}) def _storage_view(args, _): contents = _get_item_contents(args['object']) if not isinstance(contents, basestring): contents = str(contents) lines = contents.split('\n') head_count = args['head'] tail_count = args['tail'] if len(lines) > head_count + tail_count: head = '\n'.join(lines[:head_count]) tail = '\n'.join(lines[-tail_count:]) return head + '\n...\n' + tail else: return contents def _storage_write(args, _): target_name = args['object'] target_bucket, target_key = datalab.storage._bucket.parse_name(target_name) if target_bucket is None or target_key is None: raise Exception('Invalid target object name %s' % target_name) target = datalab.storage.Item(target_bucket, target_key) ipy = IPython.get_ipython() contents = ipy.user_ns[args['variable']] # TODO(gram): would we want to to do any special handling here; e.g. for DataFrames? target.write_to(str(contents), args['content_type']) ================================================ FILE: datalab/utils/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Internal Helpers.""" from ._async import async_, async_function, async_method from ._gcp_job import GCPJob from ._http import Http, RequestException from ._iterator import Iterator from ._job import Job, JobError from ._json_encoder import JSONEncoder from ._lru_cache import LRUCache from ._lambda_job import LambdaJob from ._dataflow_job import DataflowJob from ._utils import print_exception_with_last_stack, get_item, compare_datetimes, \ pick_unused_port, is_http_running_on, gcs_copy_file __all__ = ['async_', 'async_function', 'async_method', 'GCPJob', 'Http', 'RequestException', 'Iterator', 'Job', 'JobError', 'JSONEncoder', 'LRUCache', 'LambdaJob', 'DataflowJob', 'print_exception_with_last_stack', 'get_item', 'compare_datetimes', 'pick_unused_port', 'is_http_running_on', 'gcs_copy_file'] ================================================ FILE: datalab/utils/_async.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
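A sketch of the %storage magic defined above as it might be invoked from a Datalab notebook cell; the bucket and object names are hypothetical:

# Each line below is a notebook magic invocation, not plain Python.
%storage create --bucket gs://my-sample-bucket
%storage list --bucket gs://my-sample-bucket
%storage read --object gs://my-sample-bucket/notes/readme.txt --variable text
%storage write --variable text --object gs://my-sample-bucket/notes/copy.txt
%storage view --object gs://my-sample-bucket/notes/copy.txt --head 10 --tail 5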
"""Decorators for async methods and functions to dispatch on threads and support chained calls.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import abc import concurrent.futures import functools from . import _job from future.utils import with_metaclass class async_(with_metaclass(abc.ABCMeta, object)): """ Base class for async_function/async_method. Creates a wrapped function/method that will run the original function/method on a thread pool worker thread and return a Job instance for monitoring the status of the thread. """ executor = concurrent.futures.ThreadPoolExecutor(max_workers=50) # Pool for doing the work. def __init__(self, function): self._function = function # Make the wrapper get attributes like docstring from wrapped method. functools.update_wrapper(self, function) @staticmethod def _preprocess_args(*args): # Pre-process arguments - if any are themselves Futures block until they can be resolved. return [arg.result() if isinstance(arg, concurrent.futures.Future) else arg for arg in args] @staticmethod def _preprocess_kwargs(**kwargs): # Pre-process keyword arguments - if any are Futures block until they can be resolved. return {kw: (arg.result() if isinstance(arg, concurrent.futures.Future) else arg) for kw, arg in list(kwargs.items())} @abc.abstractmethod def _call(self, *args, **kwargs): return def __call__(self, *args, **kwargs): # Queue the call up in the thread pool. return _job.Job(future=self.executor.submit(self._call, *args, **kwargs)) class async_function(async_): """ This decorator can be applied to any static function that makes blocking calls to create a modified version that creates a Job and returns immediately; the original method will be called on a thread pool worker thread. """ def _call(self, *args, **kwargs): # Call the wrapped method. return self._function(*async_._preprocess_args(*args), **async_._preprocess_kwargs(**kwargs)) class async_method(async_): """ This decorator can be applied to any class instance method that makes blocking calls to create a modified version that creates a Job and returns immediately; the original method will be called on a thread pool worker thread. """ def _call(self, *args, **kwargs): # Call the wrapped method. return self._function(self.obj, *async_._preprocess_args(*args), **async_._preprocess_kwargs(**kwargs)) def __get__(self, instance, owner): # This is important for attribute inheritance and setting self.obj so it can be # passed as first argument to wrapped method. self.cls = owner self.obj = instance return self ================================================ FILE: datalab/utils/_dataflow_job.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements DataFlow Job functionality.""" from . import _job class DataflowJob(_job.Job): """Represents a DataFlow Job. """ def __init__(self, runner_results): """Initializes an instance of a DataFlow Job. 
Args: runner_results: a DataflowPipelineResult returned from Pipeline.run(). """ super(DataflowJob, self).__init__(runner_results._job.name) self._runner_results = runner_results def _refresh_state(self): """ Refresh the job info. """ # DataFlow's DataflowPipelineResult does not refresh state, so we have to do it ourselves # as a workaround. self._runner_results._job = ( self._runner_results._runner.dataflow_client.get_job(self._runner_results.job_id())) self._is_complete = self._runner_results.state in ['STOPPED', 'DONE', 'FAILED', 'CANCELLED'] self._fatal_error = getattr(self._runner_results._runner, 'last_error_msg', None) ================================================ FILE: datalab/utils/_gcp_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements GCP Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals import datalab.context from . import _job class GCPJob(_job.Job): """Represents a BigQuery Job. """ def __init__(self, job_id, context): """Initializes an instance of a Job. Args: job_id: the BigQuery job ID corresponding to this job. context: a Context object providing project_id and credentials. """ super(GCPJob, self).__init__(job_id) if context is None: context = datalab.context.Context.default() self._context = context self._api = self._create_api(context) def _create_api(self, context): raise Exception('_create_api must be defined in a derived class') def __repr__(self): """Returns a representation for the job for showing in the notebook. """ return 'Job %s/%s %s' % (self._context.project_id, self._job_id, self.state) ================================================ FILE: datalab/utils/_http.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements HTTP client helper functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from future import standard_library standard_library.install_aliases() # noqa from builtins import str from past.builtins import basestring from builtins import object import copy import datetime import json import urllib.request import urllib.parse import urllib.error import httplib2 import google_auth_httplib2 import logging log = logging.getLogger(__name__) # TODO(nikhilko): Start using the requests library instead.
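A minimal sketch of the async_method decorator from _async.py above; the Loader class is hypothetical and only simulates a blocking call:

import time
from datalab.utils import async_method

class Loader(object):
    # Hypothetical class: the decorated call runs on a pool worker thread and
    # immediately returns a datalab.utils Job wrapping the Future.
    @async_method
    def load(self, path):
        time.sleep(2)
        return 'loaded %s' % path

job = Loader().load('gs://my-sample-bucket/data.csv')
print(job.is_complete)   # typically False while the worker thread is still running
print(job.result())      # blocks until completion, then returns 'loaded ...'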
class RequestException(Exception): def __init__(self, status, content): self.status = status self.content = content self.message = 'HTTP request failed' # Try extract a message from the body; swallow possible resulting ValueErrors and KeyErrors. try: error = json.loads(content.decode('utf-8'))['error'] if 'errors' in error: error = error['errors'][0] self.message += ': ' + error['message'] except Exception: lines = content.splitlines() if isinstance(content, basestring) else [] if lines: self.message += ': ' + lines[0] def __str__(self): return self.message class Http(object): """A helper class for making HTTP requests. """ # Reuse one Http object across requests to take advantage of Keep-Alive, e.g. # for BigQuery queries that requires at least ~5 sequential http requests. # # TODO(nikhilko): # SSL cert validation seemingly fails, and workarounds are not amenable # to implementing in library code. So configure the Http object to skip # doing so, in the interim. http = httplib2.Http() http.disable_ssl_certificate_validation = True def __init__(self): pass @staticmethod def request(url, args=None, data=None, headers=None, method=None, credentials=None, raw_response=False, stats=None): """Issues HTTP requests. Args: url: the URL to request. args: optional query string arguments. data: optional data to be sent within the request. headers: optional headers to include in the request. method: optional HTTP method to use. If unspecified this is inferred (GET or POST) based on the existence of request data. credentials: optional set of credentials to authorize the request. raw_response: whether the raw response content should be returned as-is. stats: an optional dictionary that, if provided, will be populated with some useful info about the request, like 'duration' in seconds and 'data_size' in bytes. These may be useful optimizing the access to rate-limited APIs. Returns: The parsed response object. Raises: Exception when the HTTP request fails or the response cannot be processed. """ if headers is None: headers = {} headers['user-agent'] = 'GoogleCloudDataLab/1.0' # Add querystring to the URL if there are any arguments. if args is not None: qs = urllib.parse.urlencode(args) url = url + '?' + qs # Setup method to POST if unspecified, and appropriate request headers # if there is data to be sent within the request. if data is not None: if method is None: method = 'POST' if data != '': # If there is a content type specified, use it (and the data) as-is. # Otherwise, assume JSON, and serialize the data object. if 'Content-Type' not in headers: data = json.dumps(data) headers['Content-Type'] = 'application/json' headers['Content-Length'] = str(len(data)) else: if method == 'POST': headers['Content-Length'] = '0' # If the method is still unset, i.e. it was unspecified, and there # was no data to be POSTed, then default to GET request. if method is None: method = 'GET' http = Http.http # Authorize with credentials if given. if credentials is not None: # Make a copy of the shared http instance before we modify it. 
http = copy.copy(http) http = google_auth_httplib2.AuthorizedHttp(credentials) if stats is not None: stats['duration'] = datetime.datetime.utcnow() response = None try: log.debug('request: method[%(method)s], url[%(url)s], body[%(data)s]' % locals()) response, content = http.request(url, method=method, body=data, headers=headers) if 200 <= response.status < 300: if raw_response: return content if type(content) == str: return json.loads(content) else: return json.loads(str(content, encoding='UTF-8')) else: raise RequestException(response.status, content) except ValueError: raise Exception('Failed to process HTTP response.') except httplib2.HttpLib2Error: raise Exception('Failed to send HTTP request.') finally: if stats is not None: stats['data_size'] = len(data) stats['status'] = response.status stats['duration'] = (datetime.datetime.utcnow() - stats['duration']).total_seconds() ================================================ FILE: datalab/utils/_iterator.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Iterator class for iterable cloud lists.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class Iterator(object): """An iterator implementation that handles paging over a cloud list.""" def __init__(self, retriever): """Initializes an instance of an Iterator. Args: retriever: a function that can retrieve the next page of items. """ self._page_token = None self._first_page = True self._retriever = retriever self._count = 0 def __iter__(self): """Provides iterator functionality.""" while self._first_page or (self._page_token is not None): items, next_page_token = self._retriever(self._page_token, self._count) self._page_token = next_page_token self._first_page = False if self._count == 0: self._count = len(items) for item in items: yield item def reset(self): """Resets the current iteration.""" self._page_token = None self._first_page = True self._count = 0 ================================================ FILE: datalab/utils/_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
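A self-contained sketch of the Iterator class above; the in-memory pages stand in for a paged cloud list API and are purely hypothetical:

from datalab.utils import Iterator

# Hypothetical paging table: page_token -> (items, next_page_token).
_PAGES = {None: (['a', 'b'], 'p2'), 'p2': (['c', 'd'], 'p3'), 'p3': (['e'], None)}

def _retrieve(page_token, _count):
    # Same (items, next_page_token) contract as the retrievers used by Buckets/Items above.
    return _PAGES[page_token]

print(list(Iterator(_retrieve)))   # ['a', 'b', 'c', 'd', 'e']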
"""Implements Job functionality for async tasks.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import concurrent.futures import datetime import time import traceback import uuid class JobError(Exception): """ A helper class to capture multiple components of Job errors. """ def __init__(self, location, message, reason): self.location = location self.message = message self.reason = reason def __str__(self): return self.message class Job(object): """A manager object for async operations. A Job can have a Future in which case it will be able to monitor its own completion state and result, or it may have no Future in which case it must be a derived class that manages this some other way. We do this instead of having an abstract base class in order to make wait_one/wait_all more efficient; instead of just sleeping and polling we can use more reactive ways of monitoring groups of Jobs. """ _POLL_INTERVAL_SECONDS = 5 def __init__(self, job_id=None, future=None): """Initializes an instance of a Job. Args: job_id: a unique ID for the job. If None, a UUID will be generated. future: the Future associated with the Job, if any. """ self._job_id = str(uuid.uuid4()) if job_id is None else job_id self._future = future self._is_complete = False self._errors = None self._fatal_error = None self._result = None self._start_time = datetime.datetime.utcnow() self._end_time = None def __str__(self): return self._job_id @property def id(self): """ Get the Job ID. Returns: The ID of the job. """ return self._job_id @property def is_complete(self): """ Get the completion state of the job. Returns: True if the job is complete; False if it is still running. """ self._refresh_state() return self._is_complete @property def failed(self): """ Get the success state of the job. Returns: True if the job failed; False if it is still running or succeeded (possibly with partial failure). """ self._refresh_state() return self._is_complete and self._fatal_error is not None @property def fatal_error(self): """ Get the job error. Returns: None if the job succeeded or is still running, else the error tuple for the failure. """ self._refresh_state() return self._fatal_error @property def errors(self): """ Get the non-fatal errors in the job. Returns: None if the job is still running, else the list of errors that occurred. """ self._refresh_state() return self._errors def result(self): """ Get the result for a job. This will block if the job is incomplete. Returns: The result for the Job. Raises: An exception if the Job resulted in an exception. """ self.wait() if self._fatal_error: raise self._fatal_error return self._result @property def start_time_utc(self): """ The UTC start time of the job as a Python datetime. """ return self._start_time @property def end_time_utc(self): """ The UTC end time of the job (or None if incomplete) as a Python datetime. """ return self._end_time @property def total_time(self): """ The total time in fractional seconds that the job took, or None if not complete. """ if self._end_time is None: return None return (self._end_time - self._start_time).total_seconds() def _refresh_state(self): """ Get the state of a job. Must be overridden by derived Job classes for Jobs that don't use a Future. 
""" if self._is_complete: return if not self._future: raise Exception('Please implement this in the derived class') if self._future.done(): self._is_complete = True self._end_time = datetime.datetime.utcnow() try: self._result = self._future.result() except Exception as e: message = str(e) self._fatal_error = JobError(location=traceback.format_exc(), message=message, reason=str(type(e))) def _timeout(self): """ Helper for raising timeout errors. """ raise concurrent.futures.TimeoutError('Timed out waiting for Job %s to complete' % self._job_id) def wait(self, timeout=None): """ Wait for the job to complete, or a timeout to happen. Args: timeout: how long to wait before giving up (in seconds); default None which means no timeout. Returns: The Job """ if self._future: try: # Future.exception() will return rather than raise any exception so we use it. self._future.exception(timeout) except concurrent.futures.TimeoutError: self._timeout() self._refresh_state() else: # fall back to polling while not self.is_complete: if timeout is not None: if timeout <= 0: self._timeout() timeout -= Job._POLL_INTERVAL_SECONDS time.sleep(Job._POLL_INTERVAL_SECONDS) return self @property def state(self): """ Describe the state of a Job. Returns: A string describing the job's state. """ state = 'in progress' if self.is_complete: if self.failed: state = 'failed with error: %s' % str(self._fatal_error) elif self._errors: state = 'completed with some non-fatal errors' else: state = 'completed' return state def __repr__(self): """ Get the notebook representation for the job. """ return 'Job %s %s' % (self._job_id, self.state) @staticmethod def _wait(jobs, timeout, return_when): # If a single job is passed in, make it an array for consistency if isinstance(jobs, Job): jobs = [jobs] elif len(jobs) == 0: return jobs wait_on_one = return_when == concurrent.futures.FIRST_COMPLETED completed = [] while True: if timeout is not None: timeout -= Job._POLL_INTERVAL_SECONDS done = [job for job in jobs if job.is_complete] if len(done): completed.extend(done) for job in done: jobs.remove(job) if wait_on_one or len(jobs) == 0: return completed if timeout is not None and timeout < 0: return completed # Need to block for some time. Favor using concurrent.futures.wait if possible # as it can return early if a (thread) job is ready; else fall back to time.sleep. futures = [job._future for job in jobs if job._future] if len(futures) == 0: time.sleep(Job._POLL_INTERVAL_SECONDS) else: concurrent.futures.wait(futures, timeout=Job._POLL_INTERVAL_SECONDS, return_when=return_when) @staticmethod def wait_any(jobs, timeout=None): """ Return when at least one of the specified jobs has completed or timeout expires. Args: jobs: a Job or list of Jobs to wait on. timeout: a timeout in seconds to wait for. None (the default) means no timeout. Returns: A list of the jobs that have now completed or None if there were no jobs. """ return Job._wait(jobs, timeout, concurrent.futures.FIRST_COMPLETED) @staticmethod def wait_all(jobs, timeout=None): """ Return when at all of the specified jobs have completed or timeout expires. Args: jobs: a Job or list of Jobs to wait on. timeout: a timeout in seconds to wait for. None (the default) means no timeout. Returns: A list of the jobs that have now completed or None if there were no jobs. 
""" return Job._wait(jobs, timeout, concurrent.futures.ALL_COMPLETED) ================================================ FILE: datalab/utils/_json_encoder.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """ JSON encoder that can handle Python datetime objects. """ from __future__ import absolute_import from __future__ import unicode_literals import datetime import json class JSONEncoder(json.JSONEncoder): """ A JSON encoder that can handle Python datetime objects. """ def default(self, obj): if isinstance(obj, datetime.date) or isinstance(obj, datetime.datetime): return obj.isoformat() elif isinstance(obj, datetime.timedelta): return (datetime.datetime.min + obj).time().isoformat() else: return super(JSONEncoder, self).default(obj) ================================================ FILE: datalab/utils/_lambda_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements OS shell Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from . import _async from . import _job class LambdaJob(_job.Job): """Represents an lambda function as a Job. """ def __init__(self, fn, job_id, *args, **kwargs): """Initializes an instance of a Job. Args: fn: the lambda function to execute asyncronously job_id: an optional ID for the job. If None, a UUID will be generated. """ super(LambdaJob, self).__init__(job_id) self._future = _async.async_.executor.submit(fn, *args, **kwargs) def __repr__(self): """Returns a representation for the job for showing in the notebook. """ return 'Job %s %s' % (self._job_id, self.state) # TODO: ShellJob, once we need it, should inherit on LambdaJob: # import subprocess # LambdaJob(subprocess.check_output, id, command_line, shell=True) ================================================ FILE: datalab/utils/_lru_cache.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """A simple LRU cache.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring from builtins import object import datetime class LRUCache(object): """A simple LRU cache.""" def __init__(self, cache_size): """ Initialize the cache with the given size. Args: cache_size: the maximum number of items the cache can hold. Attempts to add more items than this will result in the least recently used items being displaced to make room. """ self._cache = {} self._cache_size = cache_size def __getitem__(self, key): """ Get an item from the cache. Args: key: a string used as the lookup key. Returns: The cached item, if any. Raises: Exception if the key is not a string. KeyError if the key is not found. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings (%s has type %s)" % (str(key), str(type(key)))) if key in self._cache: entry = self._cache[key] entry['last_used'] = datetime.datetime.now() return entry['value'] else: raise KeyError(key) def __delitem__(self, key): """ Remove an item from the cache. Args: key: a string key for retrieving the item. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings") del self._cache[key] def __setitem__(self, key, value): """ Put an item in the cache. Args: key: a string key for retrieving the item. value: the item to cache. Raises: Exception if the key is not a string. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings") if key in self._cache: entry = self._cache[key] elif len(self._cache) < self._cache_size: # Cache is not full; append a new entry self._cache[key] = entry = {} else: # Cache is full; displace an entry entry = min(list(self._cache.values()), key=lambda x: x['last_used']) self._cache.pop(entry['key']) self._cache[key] = entry entry['value'] = value entry['key'] = key entry['last_used'] = datetime.datetime.now() def __contains__(self, key): return key in self._cache def get(self, key, value): if key in self._cache: return self._cache[key]['value'] return value ================================================ FILE: datalab/utils/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Miscellaneous simple utility functions.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals from builtins import str try: import http.client as httplib except ImportError: import httplib import pytz import subprocess import socket import traceback import types def print_exception_with_last_stack(e): """ Print the call stack of the last exception, plus print the passed exception. Args: e: the exception to print.
""" traceback.print_exc() print(str(e)) def get_item(env, name, default=None): """ Get an item from a dictionary, handling nested lookups with dotted notation. Args: env: the environment (dictionary) to use to look up the name. name: the name to look up, in dotted notation. default: the value to return if the name if not found. Returns: The result of looking up the name, if found; else the default. """ # TODO: handle attributes for key in name.split('.'): if isinstance(env, dict) and key in env: env = env[key] elif isinstance(env, types.ModuleType) and key in env.__dict__: env = env.__dict__[key] else: return default return env def compare_datetimes(d1, d2): """ Compares two datetimes safely, whether they are timezone-naive or timezone-aware. If either datetime is naive it is converted to an aware datetime assuming UTC. Args: d1: first datetime. d2: second datetime. Returns: -1 if d1 < d2, 0 if they are the same, or +1 is d1 > d2. """ if d1.tzinfo is None or d1.tzinfo.utcoffset(d1) is None: d1 = d1.replace(tzinfo=pytz.UTC) if d2.tzinfo is None or d2.tzinfo.utcoffset(d2) is None: d2 = d2.replace(tzinfo=pytz.UTC) if d1 < d2: return -1 elif d1 > d2: return 1 return 0 def pick_unused_port(): """ get an unused port on the VM. Returns: An unused port. """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(('localhost', 0)) addr, port = s.getsockname() s.close() return port def is_http_running_on(port): """ Check if an http server runs on a given port. Args: The port to check. Returns: True if it is used by an http server. False otherwise. """ try: conn = httplib.HTTPConnection('127.0.0.1:' + str(port)) conn.connect() conn.close() return True except Exception: return False def gcs_copy_file(source, dest): """ Copy file from source to destination. The paths can be GCS or local. Args: source: the source file path. dest: the destination file path. """ subprocess.check_call(['gsutil', '-q', 'cp', source, dest]) ================================================ FILE: datalab/utils/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # flake8: noqa from __future__ import absolute_import from __future__ import unicode_literals # Support functions for magics and display help. from ._commands import CommandParser from ._html import Html, HtmlBuilder from ._utils import * # Magics from . import _chart from . import _chart_data from . import _csv from . import _extension from . import _job from . import _modules ================================================ FILE: datalab/utils/commands/_chart.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Chart cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def chart(line, cell=None): """ Generate charts with Google Charts. Use %chart --help for more details. """ parser = _commands.CommandParser(prog='%%chart', description=""" Generate an inline chart using Google Charts using the data in a Table, Query, dataframe, or list. Numerous types of charts are supported. Options for the charts can be specified in the cell body using YAML or JSON. """) for chart_type in ['annotation', 'area', 'bars', 'bubbles', 'calendar', 'candlestick', 'columns', 'combo', 'gauge', 'geo', 'heatmap', 'histogram', 'line', 'map', 'org', 'paged_table', 'pie', 'sankey', 'scatter', 'stepped_area', 'table', 'timeline', 'treemap']: subparser = parser.subcommand(chart_type, 'Generate a %s chart.' % chart_type) subparser.add_argument('-f', '--fields', help='The field(s) to include in the chart') subparser.add_argument('-d', '--data', help='The name of the variable referencing the Table or Query to chart', required=True) subparser.set_defaults(chart=chart_type) parser.set_defaults(func=_chart_cell) return _utils.handle_magic_line(line, cell, parser) def _chart_cell(args, cell): source = args['data'] ipy = IPython.get_ipython() chart_options = _utils.parse_config(cell, ipy.user_ns) if chart_options is None: chart_options = {} elif not isinstance(chart_options, dict): raise Exception("Could not parse chart options") chart_type = args['chart'] fields = args['fields'] if args['fields'] else '*' return IPython.core.display.HTML(_utils.chart_html('gcharts', chart_type, source=source, chart_options=chart_options, fields=fields)) ================================================ FILE: datalab/utils/commands/_chart_data.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - chart_data cell magic.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import json import datalab.data import datalab.utils from . 
import _utils # Disable the magic here because another one with same name is available under # google.datalab namespace. # @IPython.core.magic.register_cell_magic def _get_chart_data(line, cell_body=''): refresh = 0 options = {} try: metadata = json.loads(cell_body) if cell_body else {} source_index = metadata.get('source_index', None) fields = metadata.get('fields', '*') first_row = int(metadata.get('first', 0)) count = int(metadata.get('count', -1)) source_index = int(source_index) if source_index >= len(_utils._data_sources): # Can happen after e.g. kernel restart # TODO(gram): get kernel restart events in charting.js and disable any refresh timers. print('No source %d' % source_index) return IPython.core.display.JSON({'data': {}}) source = _utils._data_sources[source_index] schema = None controls = metadata['controls'] if 'controls' in metadata else {} data, _ = _utils.get_data(source, fields, controls, first_row, count, schema) except Exception as e: datalab.utils.print_exception_with_last_stack(e) print('Failed with exception %s' % e) data = {} # TODO(gram): The old way - commented out below - has the advantage that it worked # for datetimes, but it is strictly wrong. The correct way below may have issues if the # chart has datetimes though so test this. return IPython.core.display.JSON({'data': data, 'refresh_interval': refresh, 'options': options}) # return IPython.core.display.JSON(json.dumps({'data': data}, cls=datalab.utils.JSONEncoder)) ================================================ FILE: datalab/utils/commands/_commands.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implementation of command parsing and handling within magics.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals try: import IPython except ImportError: raise Exception('This module can only be loaded in ipython.') import argparse import shlex class CommandParser(argparse.ArgumentParser): """An argument parser to parse commands in line/cell magic declarations. """ def __init__(self, *args, **kwargs): """Initializes an instance of a CommandParser. """ super(CommandParser, self).__init__(*args, **kwargs) self._subcommands = None @staticmethod def create(name): """Creates a CommandParser for a specific magic. """ return CommandParser(prog=name) def exit(self, status=0, message=None): """Overridden exit method to stop parsing without calling sys.exit(). """ raise Exception(message) def format_usage(self): """Overridden usage generator to use the full help message. """ return self.format_help() @staticmethod def create_args(line, namespace): """ Expand any meta-variable references in the argument list. """ args = [] # Using shlex.split handles quotes args and escape characters. 
for arg in shlex.split(line): if not arg: continue if arg[0] == '$': var_name = arg[1:] if var_name in namespace: args.append((namespace[var_name])) else: raise Exception('Undefined variable referenced in command line: %s' % arg) else: args.append(arg) return args def parse(self, line, namespace=None): """Parses a line into a dictionary of arguments, expanding meta-variables from a namespace. """ try: if namespace is None: ipy = IPython.get_ipython() namespace = ipy.user_ns args = CommandParser.create_args(line, namespace) return self.parse_args(args) except Exception as e: print(str(e)) return None def subcommand(self, name, help): """Creates a parser for a sub-command. """ if self._subcommands is None: self._subcommands = self.add_subparsers(help='commands') return self._subcommands.add_parser(name, description=help, help=help) ================================================ FILE: datalab/utils/commands/_csv.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements CSV file exploration""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.magic import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import pandas as pd import datalab.data from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def csv(line, cell=None): parser = _commands.CommandParser.create('csv') view_parser = parser.subcommand('view', 'Browse CSV files without providing a schema. ' + 'Each value is considered string type.') view_parser.add_argument('-i', '--input', help='Path of the input CSV data', required=True) view_parser.add_argument('-n', '--count', help='The number of lines to browse from head, default to 5.') view_parser.add_argument('-P', '--profile', action='store_true', default=False, help='Generate an interactive profile of the data') view_parser.set_defaults(func=_view) return _utils.handle_magic_line(line, cell, parser) def _view(args, cell): csv = datalab.data.Csv(args['input']) num_lines = int(args['count'] or 5) headers = None if cell: ipy = IPython.get_ipython() config = _utils.parse_config(cell, ipy.user_ns) if 'columns' in config: headers = [e.strip() for e in config['columns'].split(',')] df = pd.DataFrame(csv.browse(num_lines, headers)) if args['profile']: # TODO(gram): We need to generate a schema and type-convert the columns before this # will be useful for CSV return _utils.profile_df(df) else: return IPython.core.display.HTML(df.to_html(index=False)) ================================================ FILE: datalab/utils/commands/_extension.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
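# --- Usage sketch (added for illustration, not part of the original source) -------------
# A hedged example of the pattern the magics above follow: build a CommandParser, register
# a subcommand whose handler is attached via set_defaults(func=...), then parse a magic
# line and dispatch. The 'greet' command and its handler are illustrative assumptions; the
# commands modules expect to be imported inside an IPython/Datalab environment.
from datalab.utils.commands._commands import CommandParser


def _greet(args, cell):
  return 'Hello, %s!' % args['name']


parser = CommandParser.create('demo')
greet_parser = parser.subcommand('greet', 'Print a greeting.')
greet_parser.add_argument('-n', '--name', required=True)
greet_parser.set_defaults(func=_greet)

args = parser.parse('greet --name Datalab', namespace={})
if args:
  print(args.func(vars(args), None))   # -> Hello, Datalab!
# -----------------------------------------------------------------------------------------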
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Extension cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def extension(line, cell=None): """ Load an extension. Use %extension --help for more details. """ parser = _commands.CommandParser(prog='%extension', description=""" Load an extension into Datalab. Currently only mathjax is supported. """) subparser = parser.subcommand('mathjax', 'Enabled MathJaX support in Datalab.') subparser.set_defaults(ext='mathjax') parser.set_defaults(func=_extension) return _utils.handle_magic_line(line, cell, parser) def _extension(args, cell): ext = args['ext'] if ext == 'mathjax': # TODO: remove this with the next version update # MathJax is now loaded by default for all notebooks return raise Exception('Unsupported extension %s' % ext) ================================================ FILE: datalab/utils/commands/_html.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - IPython HTML display Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import range from past.builtins import basestring from builtins import object import time class Html(object): """A helper to enable generating an HTML representation as display data in a notebook. This object supports the combination of HTML markup and/or associated JavaScript. """ _div_id_counter = 0 @staticmethod def next_id(): """ Return an ID containing a reproducible part (counter) and unique part (timestamp). """ Html._div_id_counter += 1 return '%d_%d' % (Html._div_id_counter, int(round(time.time() * 100))) def __init__(self, markup=None): """Initializes an instance of Html. """ self._id = Html.next_id() Html._div_id_counter += 1 self._class_name = '' self._markup = markup self._dependencies = [('element!hh_%d' % self._id, 'dom')] self._script = '' self._class = None def add_class(self, class_name): """Adds a CSS class to be generated on the output HTML. """ self._class = class_name def add_dependency(self, path, name): """Adds a script dependency to be loaded before any script is executed. """ self._dependencies.append((path, name)) def add_script(self, script): """Adds JavaScript that should be included along-side the HTML. 
""" self._script = script def _repr_html_(self): """Generates the HTML representation. """ parts = [] if self._class: parts.append('
%s
' % (self._id, self._class, self._markup)) else: parts.append('
%s
' % (self._id, self._markup)) if len(self._script) != 0: parts.append('') return ''.join(parts) class HtmlBuilder(object): """A set of helpers to build HTML representations of objects. """ def __init__(self): """Initializes an instance of an HtmlBuilder. """ self._segments = [] def _render_objects(self, items, attributes=None, datatype='object'): """Renders an HTML table with the specified list of objects. Args: items: the iterable collection of objects to render. attributes: the optional list of properties or keys to render. datatype: the type of data; one of 'object' for Python objects, 'dict' for a list of dictionaries, or 'chartdata' for Google chart data. """ if not items: return if datatype == 'chartdata': if not attributes: attributes = [items['cols'][i]['label'] for i in range(0, len(items['cols']))] items = items['rows'] indices = {attributes[i]: i for i in range(0, len(attributes))} num_segments = len(self._segments) self._segments.append('
') first = True for o in items: if first: first = False if datatype == 'dict' and not attributes: attributes = list(o.keys()) if attributes is not None: self._segments.append('') for attr in attributes: self._segments.append('' % attr) self._segments.append('') self._segments.append('') if attributes is None: self._segments.append('' % HtmlBuilder._format(o)) else: for attr in attributes: if datatype == 'dict': self._segments.append('' % HtmlBuilder._format(o.get(attr, None), nbsp=True)) elif datatype == 'chartdata': self._segments.append('' % HtmlBuilder._format(o['c'][indices[attr]]['v'], nbsp=True)) else: self._segments.append('' % HtmlBuilder._format(o.__getattribute__(attr), nbsp=True)) self._segments.append('') self._segments.append('
%s
%s%s%s%s
') if first: # The table was empty; drop it from the segments. self._segments = self._segments[:num_segments] def _render_text(self, text, preformatted=False): """Renders an HTML formatted text block with the specified text. Args: text: the text to render preformatted: whether the text should be rendered as preformatted """ tag = 'pre' if preformatted else 'div' self._segments.append('<%s>%s' % (tag, HtmlBuilder._format(text), tag)) def _render_list(self, items, empty='
<empty>
'): """Renders an HTML list with the specified list of strings. Args: items: the iterable collection of objects to render. empty: what to render if the list is None or empty. """ if not items or len(items) == 0: self._segments.append(empty) return self._segments.append('
    ') for o in items: self._segments.append('
  • ') self._segments.append(str(o)) self._segments.append('
  • ') self._segments.append('
') def _to_html(self): """Returns the HTML that has been rendered. Returns: The HTML string that has been built. """ return ''.join(self._segments) @staticmethod def _format(value, nbsp=False): if value is None: return ' ' if nbsp else '' elif isinstance(value, basestring): return value.replace('&', '&').replace('<', '<').replace('>', '>') else: return str(value) @staticmethod def render_text(text, preformatted=False): """Renders an HTML formatted text block with the specified text. Args: text: the text to render preformatted: whether the text should be rendered as preformatted Returns: The formatted HTML. """ builder = HtmlBuilder() builder._render_text(text, preformatted=preformatted) return builder._to_html() @staticmethod def render_table(data, headers=None): """ Return a dictionary list formatted as a HTML table. Args: data: a list of dictionaries, one per row. headers: the keys in the dictionary to use as table columns, in order. """ builder = HtmlBuilder() builder._render_objects(data, headers, datatype='dict') return builder._to_html() @staticmethod def render_chart_data(data): """ Return a dictionary list formatted as a HTML table. Args: data: data in the form consumed by Google Charts. """ builder = HtmlBuilder() builder._render_objects(data, datatype='chartdata') return builder._to_html() @staticmethod def render_list(data): """ Return a list formatted as a HTML list. Args: data: a list of strings. """ builder = HtmlBuilder() builder._render_list(data) return builder._to_html() ================================================ FILE: datalab/utils/commands/_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements job view""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str try: import IPython import IPython.core.magic import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import datalab.utils from . import _html _local_jobs = {} def html_job_status(job_name, job_type, refresh_interval, html_on_running, html_on_success): """create html representation of status of a job (long running operation). Args: job_name: the full name of the job. job_type: type of job. Can be 'local' or 'cloud'. refresh_interval: how often should the client refresh status. html_on_running: additional html that the job view needs to include on job running. html_on_success: additional html that the job view needs to include on job success. """ _HTML_TEMPLATE = """
""" div_id = _html.Html.next_id() return IPython.core.display.HTML(_HTML_TEMPLATE % (div_id, div_id, job_name, job_type, refresh_interval, html_on_running, html_on_success)) @IPython.core.magic.register_line_magic def _get_job_status(line): """magic used as an endpoint for client to get job status. %_get_job_status Returns: A JSON object of the job status. """ try: args = line.strip().split() job_name = args[0] job = None if job_name in _local_jobs: job = _local_jobs[job_name] else: raise Exception('invalid job %s' % job_name) if job is not None: error = '' if job.fatal_error is None else str(job.fatal_error) data = {'exists': True, 'done': job.is_complete, 'error': error} else: data = {'exists': False} except Exception as e: datalab.utils.print_exception_with_last_stack(e) data = {'done': True, 'error': str(e)} return IPython.core.display.JSON(data) ================================================ FILE: datalab/utils/commands/_modules.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implementation of various module magics""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import sys import types from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def pymodule(line, cell=None): """Creates and subsequently auto-imports a python module. """ parser = _commands.CommandParser.create('pymodule') parser.add_argument('-n', '--name', help='the name of the python module to create and import') parser.set_defaults(func=_pymodule_cell) return _utils.handle_magic_line(line, cell, parser) def _pymodule_cell(args, cell): if cell is None: raise Exception('The code for the module must be included') name = args['name'] module = _create_python_module(name, cell) # Automatically import the newly created module by assigning it to a variable # named the same name as the module name. ipy = IPython.get_ipython() ipy.push({name: module}) def _create_python_module(name, code): # By convention the module is associated with a file name matching the module name module = types.ModuleType(str(name)) module.__file__ = name module.__name__ = name exec(code, module.__dict__) # Hold on to the module if the code executed successfully sys.modules[name] = module return module ================================================ FILE: datalab/utils/commands/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Utility functions.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import json import pandas try: # Pandas profiling is not needed for build/test but will be in the container. import pandas_profiling except ImportError: pass import sys import types import yaml import datalab.data import datalab.bigquery import datalab.storage import datalab.utils import google.datalab.bigquery import google.datalab.utils from . import _html def notebook_environment(): """ Get the IPython user namespace. """ ipy = IPython.get_ipython() return ipy.user_ns def get_notebook_item(name): """ Get an item from the IPython environment. """ env = notebook_environment() return datalab.utils.get_item(env, name) def render_list(data): return IPython.core.display.HTML(_html.HtmlBuilder.render_list(data)) def render_dictionary(data, headers=None): """ Return a dictionary list formatted as a HTML table. Args: data: the dictionary list headers: the keys in the dictionary to use as table columns, in order. """ return IPython.core.display.HTML(_html.HtmlBuilder.render_table(data, headers)) def render_text(text, preformatted=False): """ Return text formatted as a HTML Args: text: the text to render preformatted: whether the text should be rendered as preformatted """ return IPython.core.display.HTML(_html.HtmlBuilder.render_text(text, preformatted)) def get_field_list(fields, schema): """ Convert a field list spec into a real list of field names. For tables, we return only the top-level non-RECORD fields as Google charts can't handle nested data. """ # If the fields weren't supplied get them from the schema. if isinstance(fields, list): return fields if isinstance(fields, basestring) and fields != '*': return fields.split(',') if not schema: return [] return [f['name'] for f in schema._bq_schema if f['type'] != 'RECORD'] def _get_cols(fields, schema): """ Get column metadata for Google Charts based on field list and schema. """ typemap = { 'STRING': 'string', 'INT64': 'number', 'INTEGER': 'number', 'FLOAT': 'number', 'FLOAT64': 'number', 'BOOL': 'boolean', 'BOOLEAN': 'boolean', 'DATE': 'date', 'TIME': 'timeofday', 'DATETIME': 'datetime', 'TIMESTAMP': 'datetime' } cols = [] for col in fields: if schema: f = schema[col] t = 'string' if f.mode == 'REPEATED' else typemap.get(f.data_type, 'string') cols.append({'id': f.name, 'label': f.name, 'type': t}) else: # This will only happen if we had no rows to infer a schema from, so the type # is not really important, except that GCharts will choke if we pass such a schema # to a chart if it is string x string so we default to number. cols.append({'id': col, 'label': col, 'type': 'number'}) return cols def _get_data_from_empty_list(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles empty lists. 
""" fields = get_field_list(fields, schema) return {'cols': _get_cols(fields, schema), 'rows': []}, 0 def _get_data_from_list_of_dicts(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles lists of dicts. """ if schema is None: schema = datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) gen = source[first_row:first_row + count] if count >= 0 else source rows = [{'c': [{'v': row[c]} if c in row else {} for c in fields]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, len(source) def _get_data_from_list_of_lists(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles lists of lists. """ if schema is None: schema = datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) gen = source[first_row:first_row + count] if count >= 0 else source cols = [schema.find(name) for name in fields] rows = [{'c': [{'v': row[i]} for i in cols]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, len(source) def _get_data_from_dataframe(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles Pandas DataFrames. """ if schema is None: schema = datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) rows = [] if count < 0: count = len(source.index) df_slice = source.reset_index(drop=True)[first_row:first_row + count] for index, data_frame_row in df_slice.iterrows(): row = data_frame_row.to_dict() for key in list(row.keys()): val = row[key] if isinstance(val, pandas.Timestamp): row[key] = val.to_pydatetime() rows.append({'c': [{'v': row[c]} if c in row else {} for c in fields]}) cols = _get_cols(fields, schema) return {'cols': cols, 'rows': rows}, len(source) def _get_data_from_table(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles BQ Tables. """ if not source.exists(): return _get_data_from_empty_list(source, fields, first_row, count) if schema is None: schema = source.schema fields = get_field_list(fields, schema) gen = source.range(first_row, count) if count >= 0 else source rows = [{'c': [{'v': row[c]} if c in row else {} for c in fields]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, source.length def get_data(source, fields='*', env=None, first_row=0, count=-1, schema=None): """ A utility function to get a subset of data from a Table, Query, Pandas dataframe or List. Args: source: the source of the data. Can be a Table, Pandas DataFrame, List of dictionaries or lists, or a string, in which case it is expected to be the name of a table in BQ. fields: a list of fields that we want to return as a list of strings, comma-separated string, or '*' for all. env: if the data source is a Query module, this is the set of variable overrides for parameterizing the Query. first_row: the index of the first row to return; default 0. Onl;y used if count is non-negative. count: the number or rows to return. If negative (the default), return all rows. schema: the schema of the data. Optional; if supplied this can be used to help do type-coercion. Returns: A tuple consisting of a dictionary and a count; the dictionary has two entries: 'cols' which is a list of column metadata entries for Google Charts, and 'rows' which is a list of lists of values. The count is the total number of rows in the source (independent of the first_row/count parameters). 
Raises: Exception if the request could not be fulfilled. """ ipy = IPython.get_ipython() if env is None: env = {} env.update(ipy.user_ns) if isinstance(source, basestring): source = datalab.utils.get_item(ipy.user_ns, source, source) if isinstance(source, basestring): source = datalab.bigquery.Table(source) if isinstance(source, types.ModuleType) or isinstance(source, datalab.data.SqlStatement): source = datalab.bigquery.Query(source, values=env) if isinstance(source, list): if len(source) == 0: return _get_data_from_empty_list(source, fields, first_row, count, schema) elif isinstance(source[0], dict): return _get_data_from_list_of_dicts(source, fields, first_row, count, schema) elif isinstance(source[0], list): return _get_data_from_list_of_lists(source, fields, first_row, count, schema) else: raise Exception("To get tabular data from a list it must contain dictionaries or lists.") elif isinstance(source, pandas.DataFrame): return _get_data_from_dataframe(source, fields, first_row, count, schema) elif (isinstance(source, google.datalab.bigquery.Query) or isinstance(source, google.datalab.bigquery.Table)): return google.datalab.utils.commands._utils.get_data( source, fields, env, first_row, count, schema) elif isinstance(source, datalab.bigquery.Query): return _get_data_from_table(source.results(), fields, first_row, count, schema) elif isinstance(source, datalab.bigquery.Table): return _get_data_from_table(source, fields, first_row, count, schema) else: raise Exception("Cannot chart %s; unsupported object type" % source) def handle_magic_line(line, cell, parser, namespace=None): """ Helper function for handling magic command lines given a parser with handlers set. """ args = parser.parse(line, namespace) if args: try: return args.func(vars(args), cell) except Exception as e: sys.stderr.write(str(e)) sys.stderr.write('\n') sys.stderr.flush() return None def expand_var(v, env): """ If v is a variable reference (for example: '$myvar'), replace it using the supplied env dictionary. Args: v: the variable to replace if needed. env: user supplied dictionary. Raises: Exception if v is a variable reference but it is not found in env. """ if len(v) == 0: return v # Using len() and v[0] instead of startswith makes this Unicode-safe. if v[0] == '$': v = v[1:] if len(v) and v[0] != '$': if v in env: v = env[v] else: raise Exception('Cannot expand variable $%s' % v) return v def replace_vars(config, env): """ Replace variable references in config using the supplied env dictionary. Args: config: the config to parse. Can be a tuple, list or dict. env: user supplied dictionary. Raises: Exception if any variable references are not found in env. """ if isinstance(config, dict): for k, v in list(config.items()): if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) elif isinstance(v, basestring): config[k] = expand_var(v, env) elif isinstance(config, list): for i, v in enumerate(config): if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) elif isinstance(v, basestring): config[i] = expand_var(v, env) elif isinstance(config, tuple): # TODO(gram): figure out how to handle these if the tuple elements are scalar for v in config: if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) def parse_config(config, env, as_dict=True): """ Parse a config from a magic cell body. This could be JSON or YAML. 
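# --- Usage sketch (added for illustration, not part of the original source) -------------
# A hedged example of get_data() defined above, which normalizes a Table, Query, DataFrame
# or list into the Google Charts structure used by the charting code. It must run inside a
# Datalab/IPython notebook; the sample rows are illustrative assumptions.
from datalab.utils.commands._utils import get_data

rows = [{'name': 'a', 'value': 1}, {'name': 'b', 'value': 2}]
data, total_rows = get_data(rows, fields='name,value')
# data has the form {'cols': [...column metadata...],
#                    'rows': [{'c': [{'v': 'a'}, {'v': 1}]}, ...]}
# and total_rows is 2 (the full source size, independent of first_row/count paging).
# -----------------------------------------------------------------------------------------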
We turn it into a Python dictionary then recursively replace any variable references using the supplied env dictionary. """ if config is None: return None stripped = config.strip() if len(stripped) == 0: config = {} elif stripped[0] == '{': config = json.loads(config) else: config = yaml.load(config) if as_dict: config = dict(config) # Now we need to walk the config dictionary recursively replacing any '$name' vars. replace_vars(config, env) return config def validate_config(config, required_keys, optional_keys=None): """ Validate a config dictionary to make sure it includes all required keys and does not include any unexpected keys. Args: config: the config to validate. required_keys: the names of the keys that the config must have. optional_keys: the names of the keys that the config can have. Raises: Exception if the config is not a dict or invalid. """ if optional_keys is None: optional_keys = [] if not isinstance(config, dict): raise Exception('config is not dict type') invalid_keys = set(config) - set(required_keys + optional_keys) if len(invalid_keys) > 0: raise Exception('Invalid config with unexpected keys "%s"' % ', '.join(e for e in invalid_keys)) missing_keys = set(required_keys) - set(config) if len(missing_keys) > 0: raise Exception('Invalid config with missing keys "%s"' % ', '.join(missing_keys)) def validate_config_must_have(config, required_keys): """ Validate a config dictionary to make sure it has all of the specified keys Args: config: the config to validate. required_keys: the list of possible keys that config must include. Raises: Exception if the config does not have any of them. """ missing_keys = set(required_keys) - set(config) if len(missing_keys) > 0: raise Exception('Invalid config with missing keys "%s"' % ', '.join(missing_keys)) def validate_config_has_one_of(config, one_of_keys): """ Validate a config dictionary to make sure it has one and only one key in one_of_keys. Args: config: the config to validate. one_of_keys: the list of possible keys that config can have one and only one. Raises: Exception if the config does not have any of them, or multiple of them. """ intersection = set(config).intersection(one_of_keys) if len(intersection) > 1: raise Exception('Only one of the values in "%s" is needed' % ', '.join(intersection)) if len(intersection) == 0: raise Exception('One of the values in "%s" is needed' % ', '.join(one_of_keys)) def validate_config_value(value, possible_values): """ Validate a config value to make sure it is one of the possible values. Args: value: the config value to validate. possible_values: the possible values the value can be Raises: Exception if the value is not one of possible values. """ if value not in possible_values: raise Exception('Invalid config value "%s". Possible values are ' '%s' % (value, ', '.join(e for e in possible_values))) # For chart and table HTML viewers, we use a list of table names and reference # instead the indices in the HTML, so as not to include things like projectID, etc, # in the HTML. _data_sources = [] def get_data_source_index(name): if name not in _data_sources: _data_sources.append(name) return _data_sources.index(name) def validate_gcs_path(path, require_object): """ Check whether a given path is a valid GCS path. Args: path: the config to check. require_object: if True, the path has to be an object path but not bucket path. 
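# --- Usage sketch (added for illustration, not part of the original source) -------------
# A hedged example of turning a magic cell body into a config dict with parse_config
# (JSON or YAML, with $variable expansion) and checking it with validate_config, both
# defined above. It assumes a Datalab/IPython environment; the cell text and keys are
# illustrative assumptions.
from datalab.utils.commands._utils import parse_config, validate_config

cell_body = """
table: $source_table
max_rows: 100
"""
config = parse_config(cell_body, env={'source_table': 'mydataset.mytable'})
# config == {'table': 'mydataset.mytable', 'max_rows': 100}
validate_config(config, required_keys=['table'], optional_keys=['max_rows'])
# -----------------------------------------------------------------------------------------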
Raises: Exception if the path is invalid """ bucket, key = datalab.storage._bucket.parse_name(path) if bucket is None: raise Exception('Invalid GCS path "%s"' % path) if require_object and key is None: raise Exception('It appears the GCS path "%s" is a bucket path but not an object path' % path) def parse_control_options(controls, variable_defaults=None): """ Parse a set of control options. Args: controls: The dictionary of control options. variable_defaults: If the controls are for a Query with variables, then this is the default variable values defined in the Query module. The options in the controls parameter can override these but if a variable has no 'value' property then we fall back to these. Returns: - the HTML for the controls. - the default values for the controls as a dict. - the list of DIV IDs of the controls. """ controls_html = '' control_defaults = {} control_ids = [] div_id = _html.Html.next_id() if variable_defaults is None: variable_defaults = {} for varname, control in list(controls.items()): label = control.get('label', varname) control_id = div_id + '__' + varname control_ids.append(control_id) value = control.get('value', variable_defaults.get(varname, None)) # The user should usually specify the type but we will default to 'textbox' for strings # and 'set' for lists. if isinstance(value, basestring): type = 'textbox' elif isinstance(value, list): type = 'set' else: type = None type = control.get('type', type) if type == 'picker': choices = control.get('choices', value) if not isinstance(choices, list) or len(choices) == 0: raise Exception('picker control must specify a nonempty set of choices') if value is None: value = choices[0] choices_html = '' for i, choice in enumerate(choices): choices_html += "" % \ (choice, ("selected=\"selected\"" if choice == value else ''), choice) control_html = "{label}" \ .format(label=label, id=control_id, choices=choices_html) elif type == 'set': # Multi-picker; implemented as checkboxes. # TODO(gram): consider using "name" property of the control to group checkboxes. That # way we can save the code of constructing and parsing control Ids with sequential # numbers in it. Multiple checkboxes can share the same name. choices = control.get('choices', value) if not isinstance(choices, list) or len(choices) == 0: raise Exception('set control must specify a nonempty set of choices') if value is None: value = choices choices_html = '' control_ids[-1] = '%s:%d' % (control_id, len(choices)) # replace ID to include count. for i, choice in enumerate(choices): checked = choice in value choice_id = '%s:%d' % (control_id, i) # TODO(gram): we may want a 'Submit/Refresh button as we may not want to rerun # query on each checkbox change. choices_html += """
""".format(id=choice_id, choice=choice, checked="checked" if checked else '') control_html = "{label}
{choices}
".format(label=label, choices=choices_html) elif type == 'checkbox': control_html = """ """.format(label=label, id=control_id, checked="checked" if value else '') elif type == 'slider': min_ = control.get('min', None) max_ = control.get('max', None) if min_ is None or max_ is None: raise Exception('slider control must specify a min and max value') if max_ <= min_: raise Exception('slider control must specify a min value less than max value') step = control.get('step', 1 if isinstance(min_, int) and isinstance(max_, int) else (float(max_ - min_) / 10.0)) if value is None: value = min_ control_html = """ {label} """.format(label=label, id=control_id, value=value, min=min_, max=max_, step=step) elif type == 'textbox': if value is None: value = '' control_html = "{label}" \ .format(label=label, value=value, id=control_id) else: raise Exception( 'Unknown control type %s (expected picker, slider, checkbox, textbox or set)' % type) control_defaults[varname] = value controls_html += "
{control}
\n" \ .format(control=control_html) controls_html = "
{controls}
".format(controls=controls_html) return controls_html, control_defaults, control_ids def chart_html(driver_name, chart_type, source, chart_options=None, fields='*', refresh_interval=0, refresh_data=None, control_defaults=None, control_ids=None, schema=None): """ Return HTML for a chart. Args: driver_name: the name of the chart driver. Currently we support 'plotly' or 'gcharts'. chart_type: string specifying type of chart. source: the data source for the chart. Can be actual data (e.g. list) or the name of a data source (e.g. the name of a query module). chart_options: a dictionary of options for the chart. Can contain a 'controls' entry specifying controls. Other entries are passed as JSON to Google Charts. fields: the fields to chart. Can be '*' for all fields (only sensible if the columns are ordered; e.g. a Query or list of lists, but not a list of dictionaries); otherwise a string containing a comma-separated list of field names. refresh_interval: a time in seconds after which the chart data will be refreshed. 0 if the chart should not be refreshed (i.e. the data is static). refresh_data: if the source is a list or other raw data, this is a YAML string containing metadata needed to support calls to refresh (get_chart_data). control_defaults: the default variable values for controls that are shared across charts including this one. control_ids: the DIV IDs for controls that are shared across charts including this one. schema: an optional schema for the data; if not supplied one will be inferred. Returns: A string containing the HTML for the chart. """ div_id = _html.Html.next_id() controls_html = '' if control_defaults is None: control_defaults = {} if control_ids is None: control_ids = [] if chart_options is not None and 'variables' in chart_options: controls = chart_options['variables'] del chart_options['variables'] # Just to make sure GCharts doesn't see them. try: item = get_notebook_item(source) _, variable_defaults = datalab.data.SqlModule.get_sql_statement_with_environment(item, '') except Exception: variable_defaults = {} controls_html, defaults, ids = parse_control_options(controls, variable_defaults) # We augment what we are passed so that in principle we can have controls that are # shared by charts as well as controls that are specific to a chart. control_defaults.update(defaults) control_ids.extend(ids), _HTML_TEMPLATE = """
{controls}
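# ---------------------------------------------------------------------------
# Editor's note: a minimal, illustrative sketch of the `controls` dictionary
# that parse_control_options() above consumes. The key names ('label',
# 'value', 'type', 'choices', 'min', 'max', 'step') and the control types
# ('picker', 'set', 'checkbox', 'slider', 'textbox') are taken from the code
# above; the variable names and values themselves are hypothetical.
controls = {
    'country': {'label': 'Country', 'type': 'picker',
                'choices': ['US', 'CA', 'MX'], 'value': 'US'},
    'metrics': {'type': 'set', 'choices': ['clicks', 'views']},
    'limit':   {'label': 'Row limit', 'type': 'slider',
                'min': 10, 'max': 1000, 'step': 10, 'value': 100},
    'verbose': {'type': 'checkbox', 'value': False},
    'prefix':  {'type': 'textbox', 'value': ''},
}
# html, defaults, ids = parse_control_options(controls)
# `html` is the rendered control markup, `defaults` maps each variable name
# to its default value, and `ids` is the list of DIV IDs used by the
# client-side charting code, as described in the docstring above.
# ---------------------------------------------------------------------------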
""" count = 25 if chart_type == 'paged_table' else -1 data, total_count = get_data(source, fields, control_defaults, 0, count, schema) if refresh_data is None: if isinstance(source, basestring): source_index = get_data_source_index(source) refresh_data = {'source_index': source_index, 'name': source_index} else: refresh_data = {'name': 'raw data'} refresh_data['fields'] = fields # TODO(gram): check if we need to augment env with user_ns return _HTML_TEMPLATE \ .format(driver=driver_name, controls=controls_html, id=div_id, chart_type=chart_type, extra_class=" bqgc-controlled" if len(controls_html) else '', data=json.dumps(data, cls=datalab.utils.JSONEncoder), options=json.dumps(chart_options, cls=datalab.utils.JSONEncoder), refresh_data=json.dumps(refresh_data, cls=datalab.utils.JSONEncoder), refresh_interval=refresh_interval, control_ids=str(control_ids), total_rows=total_count) def profile_df(df): """ Generate a profile of data in a dataframe. Args: df: the Pandas dataframe. """ # The bootstrap CSS messes up the Datalab display so we tweak it to not have an effect. # TODO(gram): strip it out rather than this kludge. return IPython.core.display.HTML( pandas_profiling.ProfileReport(df).html.replace('bootstrap', 'nonexistent')) ================================================ FILE: docs/.nojekyll ================================================ ================================================ FILE: docs/Makefile ================================================ # Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = # The buildir is out of the main repo. It is the location of the gh-pages # branch of Datalab that contains none of the source but just the HTML # output from Sphinx. This is to support GitHub Pages documentation. BUILDDIR = ../../datalab-docs # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " applehelp to make an Apple Help Book" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" @echo " coverage to run coverage check of the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* pre-build: @echo "Generate reST for magic commands:" ipython gen-magic-rst.ipy html: pre-build $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: pre-build $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: pre-build $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: pre-build $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: pre-build $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: pre-build $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: pre-build $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/api.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/api.qhc" applehelp: pre-build $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp @echo @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." @echo "N.B. You won't be able to view it unless you put it in" \ "~/Library/Documentation/Help or install it in your application" \ "bundle." devhelp: pre-build $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." 
@echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/api" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/api" @echo "# devhelp" epub: pre-build $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: pre-build $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: pre-build $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: pre-build $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: pre-build $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." man: pre-build $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: pre-build $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: pre-build $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: pre-build $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: pre-build $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: pre-build $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: pre-build $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." coverage: pre-build $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage @echo "Testing of coverage in the sources finished, look at the " \ "results in $(BUILDDIR)/coverage/python.txt." xml: pre-build $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: pre-build $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." # Publishing requires a cloned repo in the ../../datalab/docs directory. # You can use the prepublish target for this. prepublish: mkdir -p ../../datalab-docs/html cd ../../datalab-docs && git clone https://github.com/GoogleCloudPlatform/datalab.git html && \ git checkout gh-pages publish: pre-build $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html cd ../../datalab-docs/html && git add . 
&& git commit -m "Updated" && git push --force origin gh-pages ================================================ FILE: docs/README ================================================ To use, install the prerequisites and the pydatalab module: pip install sphinx sphinx_rtd_theme sphinxcontrib-napoleon pip install .. # from docs directory then in the docs directory, do 'make html' (or epub, or text, etc). Output will be in $BUILDDIR, defaulting to ../../datalab-docs. ================================================ FILE: docs/conf.py ================================================ # -*- coding: utf-8 -*- # # api documentation build configuration file, created by # sphinx-quickstart on Tue Nov 3 12:10:12 2015. # # This file is execfile()d with the current directory set to its # containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import sys import os import sphinx_rtd_theme # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.append(os.path.abspath('../')) # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinxcontrib.napoleon', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = u'Google Cloud Datalab' copyright = u'2015, Google, Inc.' author = u'Google, Inc.' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = '' # The full version, including alpha/beta/rc tags. release = '' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = 'en' # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build'] # The reST default role (used for this markup: `text`) to use for all # documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). 
#add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. #keep_warnings = False # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True # Combine class pydoc with __init__ pydoc autoclass_content = 'both' # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'sphinx_rtd_theme' html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. #html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". #html_static_path = [] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. #html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. #html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. #html_domain_indices = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. #html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Language to be used for generating the HTML full-text search index. 
# Sphinx supports the following languages: # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' #html_search_language = 'en' # A dictionary with options for the search language support, empty by default. # Now only 'ja' uses this config value #html_search_options = {'type': 'default'} # The name of a javascript file (relative to the configuration directory) that # implements a search results scorer. If empty, the default will be used. #html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. htmlhelp_basename = 'apidoc' # -- Options for LaTeX output --------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', # Latex figure (float) alignment #'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'api.tex', u'api Documentation', u'Google', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'api', u'api Documentation', [author], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'api', u'api Documentation', author, 'api', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. #texinfo_no_detailmenu = False # -- Options for Epub output ---------------------------------------------- # Bibliographic Dublin Core info. epub_title = project epub_author = author epub_publisher = author epub_copyright = copyright # The basename for the epub file. It defaults to the project name. #epub_basename = project # The HTML theme for the epub output. Since the default themes are not optimized # for small screen space, using the same theme for HTML and epub output is # usually not wise. This defaults to 'epub', a theme designed to save visual # space. #epub_theme = 'epub' # The language of the text. It defaults to the language option # or 'en' if the language is not set. #epub_language = '' # The scheme of the identifier. Typical schemes are ISBN or URL. 
#epub_scheme = '' # The unique identifier of the text. This can be a ISBN number # or the project homepage. #epub_identifier = '' # A unique identification for the text. #epub_uid = '' # A tuple containing the cover image and cover page html template filenames. #epub_cover = () # A sequence of (type, uri, title) tuples for the guide element of content.opf. #epub_guide = () # HTML files that should be inserted before the pages created by sphinx. # The format is a list of tuples containing the path and title. #epub_pre_files = [] # HTML files shat should be inserted after the pages created by sphinx. # The format is a list of tuples containing the path and title. #epub_post_files = [] # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] # The depth of the table of contents in toc.ncx. #epub_tocdepth = 3 # Allow duplicate toc entries. #epub_tocdup = True # Choose between 'default' and 'includehidden'. #epub_tocscope = 'default' # Fix unsupported image types using the Pillow. #epub_fix_images = False # Scale large images. #epub_max_image_width = 0 # How to display URL addresses: 'footnote', 'no', or 'inline'. #epub_show_urls = 'inline' # If false, no index is generated. #epub_use_index = True ================================================ FILE: docs/datalab Commands.rst ================================================ datalab Commands ======================= .. attribute:: %bigquery .. parsed-literal:: usage: bigquery [-h] {sample,create,delete,dryrun,udf,execute,pipeline,table,schema,datasets,tables,extract,load} ... Execute various BigQuery-related operations. Use "%bigquery -h" for help on a specific command. positional arguments: {sample,create,delete,dryrun,udf,execute,pipeline,table,schema,datasets,tables,extract,load} commands sample Display a sample of the results of a BigQuery SQL query. The cell can optionally contain arguments for expanding variables in the query, if -q/--query was used, or it can contain SQL for a query. create Create a dataset or table. delete Delete a dataset or table. dryrun Execute a dry run of a BigQuery query and display approximate usage statistics udf Create a named Javascript BigQuery UDF execute Execute a BigQuery SQL query and optionally send the results to a named table. The cell can optionally contain arguments for expanding variables in the query. pipeline Define a deployable pipeline based on a BigQuery query. The cell can optionally contain arguments for expanding variables in the query. table View a BigQuery table. schema View a BigQuery table or view schema. datasets List the datasets in a BigQuery project. tables List the tables in a BigQuery project or dataset. extract Extract BigQuery query results or table to GCS. load Load data from GCS into a BigQuery table. optional arguments: -h, --help show this help message and exit None .. attribute:: %extension .. parsed-literal:: usage: %extension [-h] {mathjax} ... Load an extension into Datalab. Currently only mathjax is supported. positional arguments: {mathjax} commands mathjax Enabled MathJaX support in Datalab. optional arguments: -h, --help show this help message and exit None .. attribute:: %monitoring .. parsed-literal:: usage: monitoring [-h] {list} ... Execute various Monitoring-related operations. Use "%monitoring -h" for help on a specific command. positional arguments: {list} commands list List the metrics or resource types in a monitored project. optional arguments: -h, --help show this help message and exit None .. 
attribute:: %projects .. parsed-literal:: usage: projects [-h] {list,set} ... positional arguments: {list,set} commands list List available projects. set Set the default project. optional arguments: -h, --help show this help message and exit None .. attribute:: %pymodule .. parsed-literal:: usage: pymodule [-h] [-n NAME] optional arguments: -h, --help show this help message and exit -n NAME, --name NAME the name of the python module to create and import None .. attribute:: %sql .. parsed-literal:: usage: %%sql [-h] [-m MODULE] [-d {legacy,standard}] [-b BILLING] Create a named SQL module with one or more queries. The cell body should contain an optional initial part defining the default values for the variables, if any, using Python code, followed by one or more queries. Queries should start with 'DEFINE QUERY ' in order to bind them to . in the notebook (as datalab.data.SqlStament instances). The final query can optionally omit 'DEFINE QUERY ', as using the module name in places where a SqlStatement is expected will resolve to the final query in the module. Queries can refer to variables with '$', as well as refer to other queries within the same module, making it easy to compose nested queries and test their parts. The Python code defining the variable default values can assign scalar or list/tuple values to variables, or one of the special functions 'datestring' and 'source'. When a variable with a 'datestring' default is expanded it will expand to a formatted string based on the current date, while a 'source' default will expand to a table whose name is based on the current date. datestring() takes two named arguments, 'format' and 'offset'. The former is a format string that is the same as for Python's time.strftime function. The latter is a string containing a comma-separated list of expressions such as -1y, +2m, etc; these are offsets from the time of expansion that are applied in order. The suffix (y, m, d, h, M) correspond to units of years, months, days, hours and minutes, while the +n or -n prefix is the number of units to add or subtract from the time of expansion. Three special values 'now', 'today' and 'yesterday' are also supported; 'today' and 'yesterday' will be midnight UTC on the current date or previous days date. source() can take a 'name' argument for a fixed table name, or 'format' and 'offset' arguments similar to datestring(), but unlike datestring() will resolve to a Table with the specified name. optional arguments: -h, --help show this help message and exit -m MODULE, --module MODULE The name for this SQL module -d {legacy,standard}, --dialect {legacy,standard} BigQuery SQL dialect -b BILLING, --billing BILLING BigQuery billing tier .. attribute:: %storage .. parsed-literal:: usage: storage [-h] {copy,create,delete,list,read,view,write} ... Execute various storage-related operations. Use "%storage -h" for help on a specific command. positional arguments: {copy,create,delete,list,read,view,write} commands copy Copy one or more GCS objects to a different location. create Create one or more GCS buckets. delete Delete one or more GCS buckets or objects. list List buckets in a project, or contents of a bucket. read Read the contents of a storage object into a Python variable. view View the contents of a storage object. write Write the value of a Python variable to a storage object. 
optional arguments: -h, --help show this help message and exit None ================================================ FILE: docs/datalab.bigquery.rst ================================================ datalab.bigquery Module ======================= .. automodule:: datalab.bigquery :members: :undoc-members: :show-inheritance: .. autoclass:: datalab.bigquery.CSVOptions :members: .. autoclass:: datalab.bigquery.Dataset :members: .. autoclass:: datalab.bigquery.DatasetName :members: .. autoclass:: datalab.bigquery.Datasets :members: .. autoclass:: datalab.bigquery.FederatedTable :members: .. autoclass:: datalab.bigquery.Job :members: .. autoclass:: datalab.bigquery.Query :members: .. autoclass:: datalab.bigquery.QueryJob :members: .. autoclass:: datalab.bigquery.QueryResultsTable :members: .. autoclass:: datalab.bigquery.QueryStats :members: .. autoclass:: datalab.bigquery.Sampling :members: .. autoclass:: datalab.bigquery.Schema :members: .. autoclass:: datalab.bigquery.Table :members: .. autoclass:: datalab.bigquery.TableMetadata :members: .. autoclass:: datalab.bigquery.TableName :members: .. autoclass:: datalab.bigquery.UDF :members: .. autoclass:: datalab.bigquery.View :members: ================================================ FILE: docs/datalab.context.rst ================================================ datalab.context Module ====================== .. autoclass:: datalab.context.Context :members: .. autoclass:: datalab.context.Project :members: .. autoclass:: datalab.context.Projects :members: ================================================ FILE: docs/datalab.data.rst ================================================ datalab.data Module =================== .. automodule:: datalab.data :members: :undoc-members: :show-inheritance: .. autoclass:: datalab.data.Csv :members: .. autoclass:: datalab.data.SqlModule :members: .. autoclass:: datalab.data.SqlStatement :members: ================================================ FILE: docs/datalab.stackdriver.monitoring.rst ================================================ datalab.stackdriver.monitoring Module ===================================== .. autoclass:: datalab.stackdriver.monitoring.Groups :members: .. autoclass:: datalab.stackdriver.monitoring.MetricDescriptors :members: .. autoclass:: datalab.stackdriver.monitoring.ResourceDescriptors :members: .. autoclass:: datalab.stackdriver.monitoring.Query :members: .. autoclass:: datalab.stackdriver.monitoring.QueryMetadata :members: ================================================ FILE: docs/datalab.storage.rst ================================================ datalab.storage Module ====================== .. automodule:: datalab.storage :members: :undoc-members: :show-inheritance: .. autoclass:: datalab.storage.Bucket :members: .. autoclass:: datalab.storage.Buckets :members: .. autoclass:: datalab.storage.Item :members: .. autoclass:: datalab.storage.Items :members: ================================================ FILE: docs/gen-magic-rst.ipy ================================================ import subprocess, pkgutil, importlib, sys from cStringIO import StringIO # import submodules datalab_submodules = ['datalab.' + s + '.commands' for _,s,_ in pkgutil.iter_modules(['../datalab'])] google_submodules = ['google.datalab.' 
+ s + '.commands' for _,s,_ in pkgutil.iter_modules(['../google/datalab'])] def generate_magic_docs(submodules, header, dir, ignored_magics=None): if not ignored_magics: ignored_magics = [] for m in submodules: try: importlib.import_module(m) except: sys.stderr.write('WARNING, could not find module ' + m + '. Ignoring..\n') magic_regex = "find " + dir + " -name '*.py' -exec perl -e '$f=join(\"\",<>); print \"$1\n\" if $f=~/register_line_cell_magic\ndef ([^\(]+)/m' {} \;" magics = subprocess.check_output(magic_regex, shell=True) reSTfile = open(header + '.rst', 'w') indent = '\n ' reSTfile.write(header + '\n') reSTfile.write('=======================\n\n') for m in sorted(magics.split()): if m in ignored_magics: sys.stderr.write('Ignoring magic ' + m + '\n') else: print('working on magic: '+ m) reSTfile.write('.. attribute:: %' + m + '\n') reSTfile.write('.. parsed-literal::\n') # hijack stdout since the ipython kernel call writes to stdout/err directly # and does not return its output tmpStdout, sys.stdout = sys.stdout, StringIO() get_ipython().magic(m + ' -h') resultout = sys.stdout.getvalue().splitlines() sys.stdout = tmpStdout reSTfile.writelines(indent + indent.join(resultout) + '\n\n') generate_magic_docs(datalab_submodules, 'datalab Commands', '../datalab', ignored_magics=['chart', 'csv']); generate_magic_docs(google_submodules, 'google.datalab Commands', '../google'); ================================================ FILE: docs/google.datalab Commands.rst ================================================ google.datalab Commands ======================= .. attribute:: %bq .. parsed-literal:: usage: %bq [-h] {datasets,tables,query,execute,extract,sample,dryrun,udf,datasource,load} ... Execute various BigQuery-related operations. Use "%bq -h" for help on a specific command. positional arguments: {datasets,tables,query,execute,extract,sample,dryrun,udf,datasource,load} commands datasets Operations on BigQuery datasets tables Operations on BigQuery tables query Create or execute a BigQuery SQL query object, optionally using other SQL objects, UDFs, or external datasources. If a query name is not specified, the query is executed. execute Execute a BigQuery SQL query and optionally send the results to a named table. The cell can optionally contain arguments for expanding variables in the query. extract Extract a query or table into file (local or GCS) sample Display a sample of the results of a BigQuery SQL query. The cell can optionally contain arguments for expanding variables in the query, if -q/--query was used, or it can contain SQL for a query. dryrun Execute a dry run of a BigQuery query and display approximate usage statistics udf Create a named Javascript BigQuery UDF datasource Create a named Javascript BigQuery external data source load Load data from GCS into a BigQuery table. If creating a new table, a schema should be specified in YAML or JSON in the cell body, otherwise the schema is inferred from existing table. optional arguments: -h, --help show this help message and exit None .. attribute:: %chart .. parsed-literal:: usage: %chart [-h] {annotation,area,bars,bubbles,calendar,candlestick,columns,combo,gauge,geo,heatmap,histogram,line,map,org,paged_table,pie,sankey,scatter,stepped_area,table,timeline,treemap} ... Generate an inline chart using Google Charts using the data in a Table, Query, dataframe, or list. Numerous types of charts are supported. Options for the charts can be specified in the cell body using YAML or JSON. 
positional arguments: {annotation,area,bars,bubbles,calendar,candlestick,columns,combo,gauge,geo,heatmap,histogram,line,map,org,paged_table,pie,sankey,scatter,stepped_area,table,timeline,treemap} commands annotation Generate a annotation chart. area Generate a area chart. bars Generate a bars chart. bubbles Generate a bubbles chart. calendar Generate a calendar chart. candlestick Generate a candlestick chart. columns Generate a columns chart. combo Generate a combo chart. gauge Generate a gauge chart. geo Generate a geo chart. heatmap Generate a heatmap chart. histogram Generate a histogram chart. line Generate a line chart. map Generate a map chart. org Generate a org chart. paged_table Generate a paged_table chart. pie Generate a pie chart. sankey Generate a sankey chart. scatter Generate a scatter chart. stepped_area Generate a stepped_area chart. table Generate a table chart. timeline Generate a timeline chart. treemap Generate a treemap chart. optional arguments: -h, --help show this help message and exit None .. attribute:: %csv .. parsed-literal:: usage: csv [-h] {view} ... positional arguments: {view} commands view Browse CSV files without providing a schema. Each value is considered string type. optional arguments: -h, --help show this help message and exit None .. attribute:: %datalab .. parsed-literal:: usage: %datalab [-h] {config,project} ... Execute operations that apply to multiple Datalab APIs. Use "%datalab -h" for help on a specific command. positional arguments: {config,project} commands config List or set API-specific configurations. project Get or set the default project ID optional arguments: -h, --help show this help message and exit None .. attribute:: %gcs .. parsed-literal:: usage: %gcs [-h] {copy,create,delete,list,read,view,write} ... Execute various Google Cloud Storage related operations. Use "%gcs -h" for help on a specific command. positional arguments: {copy,create,delete,list,read,view,write} commands copy Copy one or more Google Cloud Storage objects to a different location. create Create one or more Google Cloud Storage buckets. delete Delete one or more Google Cloud Storage buckets or objects. list List buckets in a project, or contents of a bucket. read Read the contents of a Google Cloud Storage object into a Python variable. view View the contents of a Google Cloud Storage object. write Write the value of a Python variable to a Google Cloud Storage object. optional arguments: -h, --help show this help message and exit None .. attribute:: %sd .. parsed-literal:: usage: %sd [-h] {monitoring} ... Execute various Stackdriver related operations. Use "%sd -h" for help on a specific Stackdriver product. positional arguments: {monitoring} commands monitoring Execute Stackdriver monitoring related operations. Use "sd monitoring -h" for help on a specific command optional arguments: -h, --help show this help message and exit None ================================================ FILE: docs/google.datalab.bigquery.rst ================================================ google.datalab.bigquery Module ============================== .. automodule:: google.datalab.bigquery :members: :undoc-members: :show-inheritance: .. autoclass:: google.datalab.bigquery.CSVOptions :members: .. autoclass:: google.datalab.bigquery.Dataset :members: .. autoclass:: google.datalab.bigquery.DatasetName :members: .. autoclass:: google.datalab.bigquery.Datasets :members: .. autoclass:: google.datalab.bigquery.ExternalDataSource :members: .. 
autoclass:: google.datalab.bigquery.Query :members: .. autoclass:: google.datalab.bigquery.QueryOutput :members: .. autoclass:: google.datalab.bigquery.QueryResultsTable :members: .. autoclass:: google.datalab.bigquery.QueryStats :members: .. autoclass:: google.datalab.bigquery.Sampling :members: .. autoclass:: google.datalab.bigquery.Schema :members: .. autoclass:: google.datalab.bigquery.SchemaField :members: .. autoclass:: google.datalab.bigquery.Table :members: .. autoclass:: google.datalab.bigquery.TableMetadata :members: .. autoclass:: google.datalab.bigquery.TableName :members: .. autoclass:: google.datalab.bigquery.UDF :members: .. autoclass:: google.datalab.bigquery.View :members: ================================================ FILE: docs/google.datalab.data.rst ================================================ google.datalab.data Module ========================== .. automodule:: google.datalab.data :members: :undoc-members: :show-inheritance: .. autoclass:: google.datalab.data.CsvFile :members: ================================================ FILE: docs/google.datalab.ml.rst ================================================ google.datalab.ml Module ======================== .. automodule:: google.datalab.ml :members: :undoc-members: :show-inheritance: .. autoclass:: google.datalab.ml.Job :members: .. autoclass:: google.datalab.ml.Jobs :members: .. autoclass:: google.datalab.ml.Summary :members: .. autoclass:: google.datalab.ml.TensorBoard :members: .. autoclass:: google.datalab.ml.CsvDataSet :members: .. autoclass:: google.datalab.ml.BigQueryDataSet :members: .. autoclass:: google.datalab.ml.Models :members: .. autoclass:: google.datalab.ml.ModelVersions :members: .. autoclass:: google.datalab.ml.ConfusionMatrix :members: .. autoclass:: google.datalab.ml.FeatureSliceView :members: .. autoclass:: google.datalab.ml.CloudTrainingConfig :members: ================================================ FILE: docs/google.datalab.rst ================================================ google.datalab Module ===================== .. automodule:: google.datalab :members: :undoc-members: :show-inheritance: .. autoclass:: google.datalab.Context :members: .. autoclass:: google.datalab.Job :members: ================================================ FILE: docs/google.datalab.stackdriver.monitoring.rst ================================================ google.datalab.stackdriver.monitoring Module ============================================ .. autoclass:: google.datalab.stackdriver.monitoring.Groups :members: .. autoclass:: google.datalab.stackdriver.monitoring.MetricDescriptors :members: .. autoclass:: google.datalab.stackdriver.monitoring.ResourceDescriptors :members: .. autoclass:: google.datalab.stackdriver.monitoring.Query :members: .. autoclass:: google.datalab.stackdriver.monitoring.QueryMetadata :members: ================================================ FILE: docs/google.datalab.storage.rst ================================================ google.datalab.storage Module ============================= .. automodule:: google.datalab.storage :members: :undoc-members: :show-inheritance: .. autoclass:: google.datalab.storage.Bucket :members: .. autoclass:: google.datalab.storage.Buckets :members: .. autoclass:: google.datalab.storage.Object :members: .. 
autoclass:: google.datalab.storage.Objects :members: ================================================ FILE: docs/index.rst ================================================ Welcome to Cloud Datalab's documentation ======================================== google.datalab namespace ######################## Contents: .. toctree:: google.datalab google.datalab.bigquery google.datalab.data google.datalab.ml google.datalab.stackdriver.monitoring google.datalab.storage google.datalab Commands ML Toolbox: .. toctree:: mltoolbox.classification.dnn mltoolbox.classification.linear mltoolbox.regression.dnn mltoolbox.regression.linear mltoolbox.image.classification datalab namespace ################# Please note, this namespace is planned to be phased out. You are strongly encouraged to move to the new google.datalab namespace above. Contents: .. toctree:: datalab.bigquery datalab.context datalab.data datalab.stackdriver.monitoring datalab.storage datalab Commands Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/make.bat ================================================ @ECHO OFF REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set BUILDDIR=_build set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set I18NSPHINXOPTS=%SPHINXOPTS% . if NOT "%PAPER%" == "" ( set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% ) if "%1" == "" goto help if "%1" == "help" ( :help echo.Please use `make ^` where ^ is one of echo. html to make standalone HTML files echo. dirhtml to make HTML files named index.html in directories echo. singlehtml to make a single large HTML file echo. pickle to make pickle files echo. json to make JSON files echo. htmlhelp to make HTML files and a HTML help project echo. qthelp to make HTML files and a qthelp project echo. devhelp to make HTML files and a Devhelp project echo. epub to make an epub echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. text to make text files echo. man to make manual pages echo. texinfo to make Texinfo files echo. gettext to make PO message catalogs echo. changes to make an overview over all changed/added/deprecated items echo. xml to make Docutils-native XML files echo. pseudoxml to make pseudoxml-XML files for display purposes echo. linkcheck to check all external links for integrity echo. doctest to run all doctests embedded in the documentation if enabled echo. coverage to run coverage check of the documentation if enabled goto end ) if "%1" == "clean" ( for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i del /q /s %BUILDDIR%\* goto end ) REM Check if sphinx-build is available and fallback to Python version if any %SPHINXBUILD% 2> nul if errorlevel 9009 goto sphinx_python goto sphinx_ok :sphinx_python set SPHINXBUILD=python -m sphinx.__init__ %SPHINXBUILD% 2> nul if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) :sphinx_ok if "%1" == "html" ( %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/html. 
goto end ) if "%1" == "dirhtml" ( %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. goto end ) if "%1" == "singlehtml" ( %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. goto end ) if "%1" == "pickle" ( %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the pickle files. goto end ) if "%1" == "json" ( %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the JSON files. goto end ) if "%1" == "htmlhelp" ( %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run HTML Help Workshop with the ^ .hhp project file in %BUILDDIR%/htmlhelp. goto end ) if "%1" == "qthelp" ( %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run "qcollectiongenerator" with the ^ .qhcp project file in %BUILDDIR%/qthelp, like this: echo.^> qcollectiongenerator %BUILDDIR%\qthelp\api.qhcp echo.To view the help file: echo.^> assistant -collectionFile %BUILDDIR%\qthelp\api.ghc goto end ) if "%1" == "devhelp" ( %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp if errorlevel 1 exit /b 1 echo. echo.Build finished. goto end ) if "%1" == "epub" ( %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub if errorlevel 1 exit /b 1 echo. echo.Build finished. The epub file is in %BUILDDIR%/epub. goto end ) if "%1" == "latex" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex if errorlevel 1 exit /b 1 echo. echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdf" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf cd %~dp0 echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdfja" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf-ja cd %~dp0 echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "text" ( %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text if errorlevel 1 exit /b 1 echo. echo.Build finished. The text files are in %BUILDDIR%/text. goto end ) if "%1" == "man" ( %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man if errorlevel 1 exit /b 1 echo. echo.Build finished. The manual pages are in %BUILDDIR%/man. goto end ) if "%1" == "texinfo" ( %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo if errorlevel 1 exit /b 1 echo. echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. goto end ) if "%1" == "gettext" ( %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale if errorlevel 1 exit /b 1 echo. echo.Build finished. The message catalogs are in %BUILDDIR%/locale. goto end ) if "%1" == "changes" ( %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes if errorlevel 1 exit /b 1 echo. echo.The overview file is in %BUILDDIR%/changes. goto end ) if "%1" == "linkcheck" ( %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck if errorlevel 1 exit /b 1 echo. echo.Link check complete; look for any errors in the above output ^ or in %BUILDDIR%/linkcheck/output.txt. 
goto end ) if "%1" == "doctest" ( %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest if errorlevel 1 exit /b 1 echo. echo.Testing of doctests in the sources finished, look at the ^ results in %BUILDDIR%/doctest/output.txt. goto end ) if "%1" == "coverage" ( %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage if errorlevel 1 exit /b 1 echo. echo.Testing of coverage in the sources finished, look at the ^ results in %BUILDDIR%/coverage/python.txt. goto end ) if "%1" == "xml" ( %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml if errorlevel 1 exit /b 1 echo. echo.Build finished. The XML files are in %BUILDDIR%/xml. goto end ) if "%1" == "pseudoxml" ( %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml if errorlevel 1 exit /b 1 echo. echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. goto end ) :end ================================================ FILE: docs/mltoolbox.classification.dnn.rst ================================================ mltoolbox.classification.dnn ============================ .. automodule:: mltoolbox.classification.dnn :members: :undoc-members: :show-inheritance: .. autofunction:: mltoolbox.classification.dnn.analyze .. autofunction:: mltoolbox.classification.dnn.analyze_async .. autofunction:: mltoolbox.classification.dnn.batch_predict .. autofunction:: mltoolbox.classification.dnn.batch_predict_async .. autofunction:: mltoolbox.classification.dnn.predict .. autofunction:: mltoolbox.classification.dnn.train .. autofunction:: mltoolbox.classification.dnn.train_async ================================================ FILE: docs/mltoolbox.classification.linear.rst ================================================ mltoolbox.classification.linear =============================== .. automodule:: mltoolbox.classification.linear :members: :undoc-members: :show-inheritance: .. autofunction:: mltoolbox.classification.linear.analyze .. autofunction:: mltoolbox.classification.linear.analyze_async .. autofunction:: mltoolbox.classification.linear.batch_predict .. autofunction:: mltoolbox.classification.linear.batch_predict_async .. autofunction:: mltoolbox.classification.linear.predict .. autofunction:: mltoolbox.classification.linear.train .. autofunction:: mltoolbox.classification.linear.train_async ================================================ FILE: docs/mltoolbox.image.classification.rst ================================================ mltoolbox.image.classification ============================== .. automodule:: mltoolbox.image.classification :members: :undoc-members: :show-inheritance: .. autofunction:: mltoolbox.image.classification.preprocess .. autofunction:: mltoolbox.image.classification.preprocess_async .. autofunction:: mltoolbox.image.classification.train .. autofunction:: mltoolbox.image.classification.train_async .. autofunction:: mltoolbox.image.classification.predict .. autofunction:: mltoolbox.image.classification.batch_predict .. autofunction:: mltoolbox.image.classification.batch_predict_async ================================================ FILE: docs/mltoolbox.regression.dnn.rst ================================================ mltoolbox.regression.dnn ======================== .. automodule:: mltoolbox.regression.dnn :members: :undoc-members: :show-inheritance: .. autofunction:: mltoolbox.regression.dnn.analyze .. autofunction:: mltoolbox.regression.dnn.analyze_async .. autofunction:: mltoolbox.regression.dnn.batch_predict .. autofunction:: mltoolbox.regression.dnn.batch_predict_async .. 
autofunction:: mltoolbox.regression.dnn.predict .. autofunction:: mltoolbox.regression.dnn.train .. autofunction:: mltoolbox.regression.dnn.train_async ================================================ FILE: docs/mltoolbox.regression.linear.rst ================================================ mltoolbox.regression.linear =========================== .. automodule:: mltoolbox.regression.linear :members: :undoc-members: :show-inheritance: .. autofunction:: mltoolbox.regression.linear.analyze .. autofunction:: mltoolbox.regression.linear.analyze_async .. autofunction:: mltoolbox.regression.linear.batch_predict .. autofunction:: mltoolbox.regression.linear.batch_predict_async .. autofunction:: mltoolbox.regression.linear.predict .. autofunction:: mltoolbox.regression.linear.train .. autofunction:: mltoolbox.regression.linear.train_async ================================================ FILE: externs/ts/require/require.d.ts ================================================ // Type definitions for RequireJS 2.1.20 // Project: http://requirejs.org/ // Definitions by: Josh Baldwin // Definitions: https://github.com/DefinitelyTyped/DefinitelyTyped /* require-2.1.8.d.ts may be freely distributed under the MIT license. Copyright (c) 2013 Josh Baldwin https://github.com/jbaldwin/require.d.ts Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ declare module 'module' { var mod: { config: () => any; id: string; uri: string; } export = mod; } interface RequireError extends Error { /** * The error ID that maps to an ID on a web page. **/ requireType: string; /** * Required modules. **/ requireModules: string[]; /** * The original error, if there is one (might be null). **/ originalError: Error; } interface RequireShim { /** * List of dependencies. **/ deps?: string[]; /** * Name the module will be exported as. **/ exports?: string; /** * Initialize function with all dependcies passed in, * if the function returns a value then that value is used * as the module export value instead of the object * found via the 'exports' string. * @param dependencies * @return **/ init?: (...dependencies: any[]) => any; } interface RequireConfig { // The root path to use for all module lookups. baseUrl?: string; // Path mappings for module names not found directly under // baseUrl. paths?: { [key: string]: any; }; // Dictionary of Shim's. // does not cover case of key->string[] shim?: { [key: string]: RequireShim; }; /** * For the given module prefix, instead of loading the * module with the given ID, substitude a different * module ID. 
* * @example * requirejs.config({ * map: { * 'some/newmodule': { * 'foo': 'foo1.2' * }, * 'some/oldmodule': { * 'foo': 'foo1.0' * } * } * }); **/ map?: { [id: string]: { [id: string]: string; }; }; /** * Allows pointing multiple module IDs to a module ID that contains a bundle of modules. * * @example * requirejs.config({ * bundles: { * 'primary': ['main', 'util', 'text', 'text!template.html'], * 'secondary': ['text!secondary.html'] * } * }); **/ bundles?: { [key: string]: string[]; }; /** * AMD configurations, use module.config() to access in * define() functions **/ config?: { [id: string]: {}; }; /** * Configures loading modules from CommonJS packages. **/ packages?: {}; /** * The number of seconds to wait before giving up on loading * a script. The default is 7 seconds. **/ waitSeconds?: number; /** * A name to give to a loading context. This allows require.js * to load multiple versions of modules in a page, as long as * each top-level require call specifies a unique context string. **/ context?: string; /** * An array of dependencies to load. **/ deps?: string[]; /** * A function to pass to require that should be require after * deps have been loaded. * @param modules **/ callback?: (...modules: any[]) => void; /** * If set to true, an error will be thrown if a script loads * that does not call define() or have shim exports string * value that can be checked. **/ enforceDefine?: boolean; /** * If set to true, document.createElementNS() will be used * to create script elements. **/ xhtml?: boolean; /** * Extra query string arguments appended to URLs that RequireJS * uses to fetch resources. Most useful to cache bust when * the browser or server is not configured correctly. * * @example * urlArgs: "bust= + (new Date()).getTime() **/ urlArgs?: string; /** * Specify the value for the type="" attribute used for script * tags inserted into the document by RequireJS. Default is * "text/javascript". To use Firefox's JavasScript 1.8 * features, use "text/javascript;version=1.8". **/ scriptType?: string; /** * If set to true, skips the data-main attribute scanning done * to start module loading. Useful if RequireJS is embedded in * a utility library that may interact with other RequireJS * library on the page, and the embedded version should not do * data-main loading. **/ skipDataMain?: boolean; /** * Allow extending requirejs to support Subresource Integrity * (SRI). **/ onNodeCreated?: (node: HTMLScriptElement, config: RequireConfig, moduleName: string, url: string) => void; } // todo: not sure what to do with this guy interface RequireModule { /** * **/ config(): {}; } /** * **/ interface RequireMap { /** * **/ prefix: string; /** * **/ name: string; /** * **/ parentMap: RequireMap; /** * **/ url: string; /** * **/ originalName: string; /** * **/ fullName: string; } interface Require { /** * Configure require.js **/ config(config: RequireConfig): Require; /** * CommonJS require call * @param module Module to load * @return The loaded module */ (module: string): any; /** * Start the main app logic. * Callback is optional. * Can alternatively use deps and callback. * @param modules Required modules to load. **/ (modules: string[]): void; /** * @see Require() * @param ready Called when required modules are ready. **/ (modules: string[], ready: Function): void; /** * @see http://requirejs.org/docs/api.html#errbacks * @param ready Called when required modules are ready. 
**/ (modules: string[], ready: Function, errback: Function): void; /** * Generate URLs from require module * @param module Module to URL * @return URL string **/ toUrl(module: string): string; /** * Returns true if the module has already been loaded and defined. * @param module Module to check **/ defined(module: string): boolean; /** * Returns true if the module has already been requested or is in the process of loading and should be available at some point. * @param module Module to check **/ specified(module: string): boolean; /** * On Error override * @param err **/ onError(err: RequireError, errback?: (err: RequireError) => void): void; /** * Undefine a module * @param module Module to undefine. **/ undef(module: string): void; /** * Semi-private function, overload in special instance of undef() **/ onResourceLoad(context: Object, map: RequireMap, depArray: RequireMap[]): void; } interface RequireDefine { /** * Define Simple Name/Value Pairs * @param config Dictionary of Named/Value pairs for the config. **/ (config: { [key: string]: any; }): void; /** * Define function. * @param func: The function module. **/ (func: () => any): void; /** * Define function with dependencies. * @param deps List of dependencies module IDs. * @param ready Callback function when the dependencies are loaded. * callback param deps module dependencies * callback return module definition **/ (deps: string[], ready: Function): void; /** * Define module with simplified CommonJS wrapper. * @param ready * callback require requirejs instance * callback exports exports object * callback module module * callback return module definition **/ (ready: (require: Require, exports: { [key: string]: any; }, module: RequireModule) => any): void; /** * Define a module with a name and dependencies. * @param name The name of the module. * @param deps List of dependencies module IDs. * @param ready Callback function when the dependencies are loaded. * callback deps module dependencies * callback return module definition **/ (name: string, deps: string[], ready: Function): void; /** * Define a module with a name. * @param name The name of the module. * @param ready Callback function when the dependencies are loaded. * callback return module definition **/ (name: string, ready: Function): void; /** * Used to allow a clear indicator that a global define function (as needed for script src browser loading) conforms * to the AMD API, any global define function SHOULD have a property called "amd" whose value is an object. * This helps avoid conflict with any other existing JavaScript code that could have defined a define() function * that does not conform to the AMD API. * define.amd.jQuery is specific to jQuery and indicates that the loader is able to account for multiple version * of jQuery being loaded simultaneously. */ amd: Object; } // Ambient declarations for 'require' and 'define' declare var requirejs: Require; declare var require: Require; declare var define: RequireDefine; ================================================ FILE: google/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: google/datalab/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from google.datalab._context import Context from google.datalab._job import Job, JobError import warnings __all__ = ['Context', 'Job', 'JobError'] warnings.warn("Datalab is deprecated. For more information, see https://cloud.google.com/datalab/docs/resources/deprecation.", DeprecationWarning) ================================================ FILE: google/datalab/_context.py ================================================ # Copyright 2014 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Implements Context functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from google.datalab.utils import _utils as du class Context(object): """Maintains contextual state for connecting to Cloud APIs. """ _global_context = None def __init__(self, project_id, credentials, config=None): """Initializes an instance of a Context object. Args: project_id: the current cloud project. credentials: the credentials to use to authorize requests. config: key/value configurations for cloud operations """ self._project_id = project_id self._credentials = credentials self._config = config if config is not None else Context._get_default_config() @property def credentials(self): """Retrieves the value of the credentials property. Returns: The current credentials used in authorizing API requests. """ return self._credentials def set_credentials(self, credentials): """ Set the credentials for the context. """ self._credentials = credentials @property def project_id(self): """Retrieves the value of the project_id property. Returns: The current project id to associate with API requests. """ if not self._project_id: raise Exception('No project ID found. Perhaps you should set one by running' '"%datalab project set -p " in a code cell.') return self._project_id def set_project_id(self, project_id): """ Set the project_id for the context. 
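    Example (illustrative sketch; 'my-project-id' is a placeholder and application-default
    credentials are assumed to be available):

        import google.datalab

        context = google.datalab.Context.default()
        context.set_project_id('my-project-id')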
""" self._project_id = project_id if self == Context._global_context: du.save_project_id(self._project_id) @property def config(self): """ Retrieves the value of the config property. Returns: The current config object used in cloud operations """ return self._config def set_config(self, config): """ Set the config property for the context. """ self._config = config @staticmethod def _is_signed_in(): """ If the user has signed in or it is on GCE VM with default credential.""" try: du.get_credentials() return True except Exception: return False @staticmethod def _get_default_config(): """Return a default config object""" return { 'bigquery_billing_tier': None } @staticmethod def default(): """Retrieves a default Context object, creating it if necessary. The default Context is a global shared instance used every time the default context is retrieved. Attempting to use a Context with no project_id will raise an exception, so on first use set_project_id must be called. Returns: An initialized and shared instance of a Context object. """ credentials = du.get_credentials() project = du.get_default_project_id() if Context._global_context is None: config = Context._get_default_config() Context._global_context = Context(project, credentials, config) else: # Always update everything in case the access token is revoked or expired, config changed, # or project changed. Context._global_context.set_credentials(credentials) Context._global_context.set_project_id(project) return Context._global_context ================================================ FILE: google/datalab/_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Job functionality for async tasks.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import concurrent.futures import datetime import time import traceback import uuid class JobError(Exception): """ A helper class to capture multiple components of Job errors. """ def __init__(self, location, message, reason): self.location = location self.message = message self.reason = reason def __str__(self): return '%s %s %s' % (self.message, self.reason, self.location) class Job(object): """A manager object for async operations. A Job can have a Future in which case it will be able to monitor its own completion state and result, or it may have no Future in which case it must be a derived class that manages this some other way. We do this instead of having an abstract base class in order to make wait_one/wait_all more efficient; instead of just sleeping and polling we can use more reactive ways of monitoring groups of Jobs. """ _POLL_INTERVAL_SECONDS = 5 def __init__(self, job_id=None, future=None): """Initializes an instance of a Job. Args: job_id: a unique ID for the job. If None, a UUID will be generated. future: the Future associated with the Job, if any. 
""" self._job_id = str(uuid.uuid4()) if job_id is None else job_id self._future = future self._is_complete = False self._errors = None self._fatal_error = None self._result = None self._start_time = datetime.datetime.utcnow() self._end_time = None def __str__(self): return self._job_id @property def id(self): """ Get the Job ID. Returns: The ID of the job. """ return self._job_id @property def is_complete(self): """ Get the completion state of the job. Returns: True if the job is complete; False if it is still running. """ self._refresh_state() return self._is_complete @property def failed(self): """ Get the success state of the job. Returns: True if the job failed; False if it is still running or succeeded (possibly with partial failure). """ self._refresh_state() return self._is_complete and self._fatal_error is not None @property def fatal_error(self): """ Get the job error. Returns: None if the job succeeded or is still running, else the error tuple for the failure. """ self._refresh_state() return self._fatal_error @property def errors(self): """ Get the non-fatal errors in the job. Returns: None if the job is still running, else the list of errors that occurred. """ self._refresh_state() return self._errors def result(self): """ Get the result for a job. This will block if the job is incomplete. Returns: The result for the Job. Raises: An exception if the Job resulted in an exception. """ self.wait() if self._fatal_error: raise self._fatal_error return self._result @property def start_time_utc(self): """ The UTC start time of the job as a Python datetime. """ return self._start_time @property def end_time_utc(self): """ The UTC end time of the job (or None if incomplete) as a Python datetime. """ return self._end_time @property def total_time(self): """ The total time in fractional seconds that the job took, or None if not complete. """ if self._end_time is None: return None return (self._end_time - self._start_time).total_seconds() def _refresh_state(self): """ Get the state of a job. Must be overridden by derived Job classes for Jobs that don't use a Future. """ if self._is_complete: return if not self._future: raise Exception('Please implement this in the derived class') if self._future.done(): self._is_complete = True self._end_time = datetime.datetime.utcnow() try: self._result = self._future.result() except Exception as e: message = str(e) self._fatal_error = JobError(location=traceback.format_exc(), message=message, reason=str(type(e))) def _timeout(self): """ Helper for raising timeout errors. """ raise concurrent.futures.TimeoutError('Timed out waiting for Job %s to complete' % self._job_id) def wait(self, timeout=None): """ Wait for the job to complete, or a timeout to happen. Args: timeout: how long to wait before giving up (in seconds); default None which means no timeout. Returns: The Job """ if self._future: try: # Future.exception() will return rather than raise any exception so we use it. self._future.exception(timeout) except concurrent.futures.TimeoutError: self._timeout() self._refresh_state() else: # fall back to polling while not self.is_complete: if timeout is not None: if timeout <= 0: self._timeout() timeout -= Job._POLL_INTERVAL_SECONDS time.sleep(Job._POLL_INTERVAL_SECONDS) return self @property def state(self): """ Describe the state of a Job. Returns: A string describing the job's state. 
""" state = 'in progress' if self.is_complete: if self.failed: state = 'failed with error: %s' % str(self._fatal_error) elif self._errors: state = 'completed with some non-fatal errors' else: state = 'completed' return state def __repr__(self): """ Get the notebook representation for the job. """ return 'Job %s %s' % (self._job_id, self.state) @staticmethod def _wait(jobs, timeout, return_when): # If a single job is passed in, make it an array for consistency if isinstance(jobs, Job): jobs = [jobs] elif len(jobs) == 0: return jobs wait_on_one = return_when == concurrent.futures.FIRST_COMPLETED completed = [] while True: if timeout is not None: timeout -= Job._POLL_INTERVAL_SECONDS done = [job for job in jobs if job.is_complete] if len(done): completed.extend(done) for job in done: jobs.remove(job) if wait_on_one or len(jobs) == 0: return completed if timeout is not None and timeout < 0: return completed # Need to block for some time. Favor using concurrent.futures.wait if possible # as it can return early if a (thread) job is ready; else fall back to time.sleep. futures = [job._future for job in jobs if job._future] if len(futures) == 0: time.sleep(Job._POLL_INTERVAL_SECONDS) else: concurrent.futures.wait(futures, timeout=Job._POLL_INTERVAL_SECONDS, return_when=return_when) ================================================ FILE: google/datalab/bigquery/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery Functionality.""" from __future__ import absolute_import from ._csv_options import CSVOptions from ._dataset import Dataset, Datasets from ._external_data_source import ExternalDataSource from ._query import Query from ._query_output import QueryOutput from ._query_results_table import QueryResultsTable from ._query_stats import QueryStats from ._sampling import Sampling from ._schema import Schema, SchemaField from ._table import Table, TableMetadata from ._udf import UDF from ._utils import TableName, DatasetName from ._view import View __all__ = ['CSVOptions', 'Dataset', 'Datasets', 'ExternalDataSource', 'Query', 'QueryOutput', 'QueryResultsTable', 'QueryStats', 'Sampling', 'Schema', 'SchemaField', 'Table', 'TableMetadata', 'UDF', 'TableName', 'DatasetName', 'View'] ================================================ FILE: google/datalab/bigquery/_api.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery HTTP API wrapper.""" from __future__ import absolute_import from __future__ import unicode_literals from past.builtins import basestring from builtins import object import google.datalab.utils import google.datalab.bigquery class Api(object): """A helper class to issue BigQuery HTTP requests.""" # TODO(nikhilko): Use named placeholders in these string templates. _ENDPOINT = 'https://www.googleapis.com/bigquery/v2' _JOBS_PATH = '/projects/%s/jobs/%s' _QUERIES_PATH = '/projects/%s/queries/%s' _DATASETS_PATH = '/projects/%s/datasets/%s' _TABLES_PATH = '/projects/%s/datasets/%s/tables/%s%s' _TABLEDATA_PATH = '/projects/%s/datasets/%s/tables/%s%s/data' _DEFAULT_TIMEOUT = 60000 def __init__(self, context): """Initializes the BigQuery helper with context information. Args: context: a Context object providing project_id and credentials. """ self._context = context @property def project_id(self): """The project_id associated with this API client.""" return self._context.project_id @property def credentials(self): """The credentials associated with this API client.""" return self._context.credentials @property def bigquery_billing_tier(self): """The BigQuery billing tier associated with this API client.""" return self._context.config.get('bigquery_billing_tier', None) def jobs_insert_load(self, source, table_name, append=False, overwrite=False, create=False, source_format='CSV', field_delimiter=',', allow_jagged_rows=False, allow_quoted_newlines=False, encoding='UTF-8', ignore_unknown_values=False, max_bad_records=0, quote='"', skip_leading_rows=0): """ Issues a request to load data from GCS to a BQ table Args: source: the URL of the source bucket(s). Can include wildcards, and can be a single string argument or a list. table_name: a tuple representing the full name of the destination table. append: if True append onto existing table contents. overwrite: if True overwrite existing table contents. create: if True, create the table if it doesn't exist source_format: the format of the data; default 'CSV'. Other options are DATASTORE_BACKUP or NEWLINE_DELIMITED_JSON. field_delimiter: The separator for fields in a CSV file. BigQuery converts the string to ISO-8859-1 encoding, and then uses the first byte of the encoded string to split the data as raw binary (default ','). allow_jagged_rows: If True, accept rows in CSV files that are missing trailing optional columns; the missing values are treated as nulls (default False). allow_quoted_newlines: If True, allow quoted data sections in CSV files that contain newline characters (default False). encoding: The character encoding of the data, either 'UTF-8' (the default) or 'ISO-8859-1'. ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: The maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). quote: The value used to quote data sections in a CSV file; default '"'. If your data does not contain quoted sections, set the property value to an empty string. If your data contains quoted newline characters, you must also enable allow_quoted_newlines. skip_leading_rows: A number of rows at the top of a CSV file to skip (default 0). Returns: A parsed result object. Raises: Exception if there is an error performing the operation. 
""" url = Api._ENDPOINT + (Api._JOBS_PATH % (table_name.project_id, '')) if isinstance(source, basestring): source = [source] write_disposition = 'WRITE_EMPTY' if overwrite: write_disposition = 'WRITE_TRUNCATE' if append: write_disposition = 'WRITE_APPEND' data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': source, 'destinationTable': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id }, 'createDisposition': 'CREATE_IF_NEEDED' if create else 'CREATE_NEVER', 'writeDisposition': write_disposition, 'sourceFormat': source_format, 'ignoreUnknownValues': ignore_unknown_values, 'maxBadRecords': max_bad_records, } } } if source_format == 'CSV': load_config = data['configuration']['load'] load_config.update({ 'fieldDelimiter': field_delimiter, 'allowJaggedRows': allow_jagged_rows, 'allowQuotedNewlines': allow_quoted_newlines, 'quote': quote, 'encoding': encoding, 'skipLeadingRows': skip_leading_rows }) return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def jobs_insert_query(self, sql, table_name=None, append=False, overwrite=False, dry_run=False, use_cache=True, batch=True, allow_large_results=False, table_definitions=None, query_params=None): """Issues a request to insert a query job. Args: sql: the SQL string representing the query to execute. table_name: None for an anonymous table, or a name parts tuple for a long-lived table. append: if True, append to the table if it is non-empty; else the request will fail if table is non-empty unless overwrite is True. overwrite: if the table already exists, truncate it instead of appending or raising an Exception. dry_run: whether to actually execute the query or just dry run it. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified. batch: whether to run this as a batch job (lower priority) or as an interactive job (high priority, more expensive). allow_large_results: whether to allow large results (slower with some restrictions but can handle big jobs). table_definitions: a dictionary of ExternalDataSource names and objects for any external tables referenced in the query. query_params: a dictionary containing query parameter types and values, passed to BigQuery. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. 
""" url = Api._ENDPOINT + (Api._JOBS_PATH % (self.project_id, '')) data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': sql, 'useQueryCache': use_cache, 'allowLargeResults': allow_large_results, 'useLegacySql': False }, 'dryRun': dry_run, 'priority': 'BATCH' if batch else 'INTERACTIVE', }, } query_config = data['configuration']['query'] if table_definitions: expanded_definitions = {} for td in table_definitions: expanded_definitions[td] = table_definitions[td]._to_query_json() query_config['tableDefinitions'] = expanded_definitions if table_name: query_config['destinationTable'] = { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id } if append: query_config['writeDisposition'] = "WRITE_APPEND" elif overwrite: query_config['writeDisposition'] = "WRITE_TRUNCATE" if self.bigquery_billing_tier: query_config['maximumBillingTier'] = self.bigquery_billing_tier if query_params: query_config['queryParameters'] = query_params return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def jobs_query_results(self, job_id, project_id, page_size, timeout, start_index=0): """Issues a request to the jobs/getQueryResults method. Args: job_id: the id of job from a previously executed query. project_id: the project id to use to fetch the results; use None for the default project. page_size: limit to the number of rows to fetch. timeout: duration (in milliseconds) to wait for the query to complete. start_index: the index of the row (0-based) at which to start retrieving the page of result rows. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if timeout is None: timeout = Api._DEFAULT_TIMEOUT if project_id is None: project_id = self.project_id args = { 'maxResults': page_size, 'timeoutMs': timeout, 'startIndex': start_index } url = Api._ENDPOINT + (Api._QUERIES_PATH % (project_id, job_id)) return google.datalab.utils.Http.request(url, args=args, credentials=self.credentials) def jobs_get(self, job_id, project_id=None): """Issues a request to retrieve information about a job. Args: job_id: the id of the job project_id: the project id to use to fetch the results; use None for the default project. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if project_id is None: project_id = self.project_id url = Api._ENDPOINT + (Api._JOBS_PATH % (project_id, job_id)) return google.datalab.utils.Http.request(url, credentials=self.credentials) def datasets_insert(self, dataset_name, friendly_name=None, description=None): """Issues a request to create a dataset. Args: dataset_name: the name of the dataset to create. friendly_name: (optional) the friendly name for the dataset description: (optional) a description for the dataset Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % (dataset_name.project_id, '')) data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': dataset_name.project_id, 'datasetId': dataset_name.dataset_id }, } if friendly_name: data['friendlyName'] = friendly_name if description: data['description'] = description return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def datasets_delete(self, dataset_name, delete_contents): """Issues a request to delete a dataset. Args: dataset_name: the name of the dataset to delete. 
delete_contents: if True, any tables in the dataset will be deleted. If False and the dataset is non-empty an exception will be raised. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) args = {} if delete_contents: args['deleteContents'] = True return google.datalab.utils.Http.request(url, method='DELETE', args=args, credentials=self.credentials, raw_response=True) def datasets_update(self, dataset_name, dataset_info): """Updates the Dataset info. Args: dataset_name: the name of the dataset to update as a tuple of components. dataset_info: the Dataset resource with updated fields. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) return google.datalab.utils.Http.request(url, method='PUT', data=dataset_info, credentials=self.credentials) def datasets_get(self, dataset_name): """Issues a request to retrieve information about a dataset. Args: dataset_name: the name of the dataset Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._DATASETS_PATH % dataset_name) return google.datalab.utils.Http.request(url, credentials=self.credentials) def datasets_list(self, project_id=None, max_results=0, page_token=None): """Issues a request to list the datasets in the project. Args: project_id: the project id to use to fetch the results; use None for the default project. max_results: an optional maximum number of tables to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ if project_id is None: project_id = self.project_id url = Api._ENDPOINT + (Api._DATASETS_PATH % (project_id, '')) args = {} if max_results != 0: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return google.datalab.utils.Http.request(url, args=args, credentials=self.credentials) def tables_get(self, table_name): """Issues a request to retrieve information about a table. Args: table_name: a tuple representing the full name of the table. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return google.datalab.utils.Http.request(url, credentials=self.credentials) def tables_list(self, dataset_name, max_results=0, page_token=None): """Issues a request to retrieve a list of tables. Args: dataset_name: the name of the dataset to enumerate. max_results: an optional maximum number of tables to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT +\ (Api._TABLES_PATH % (dataset_name.project_id, dataset_name.dataset_id, '', '')) args = {} if max_results != 0: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return google.datalab.utils.Http.request(url, args=args, credentials=self.credentials) def tables_insert(self, table_name, schema=None, query=None, friendly_name=None, description=None): """Issues a request to create a table or view in the specified dataset with the specified id. A schema must be provided to create a Table, or a query must be provided to create a View. Args: table_name: the name of the table as a tuple of components. schema: the schema, if this is a Table creation. 
query: the query, if this is a View creation. friendly_name: an optional friendly name. description: an optional description. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + \ (Api._TABLES_PATH % (table_name.project_id, table_name.dataset_id, '', '')) data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id } } if schema: data['schema'] = {'fields': schema} if query: data['view'] = {'query': query} if friendly_name: data['friendlyName'] = friendly_name if description: data['description'] = description return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def tabledata_insert_all(self, table_name, rows): """Issues a request to insert data into a table. Args: table_name: the name of the table as a tuple of components. rows: the data to populate the table, as a list of dictionaries. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) + "/insertAll" data = { 'kind': 'bigquery#tableDataInsertAllRequest', 'rows': rows } return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def tabledata_list(self, table_name, start_index=None, max_results=None, page_token=None): """ Retrieves the contents of a table. Args: table_name: the name of the table as a tuple of components. start_index: the index of the row at which to start retrieval. max_results: an optional maximum number of rows to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLEDATA_PATH % table_name) args = {} if start_index: args['startIndex'] = start_index if max_results: args['maxResults'] = max_results if page_token is not None: args['pageToken'] = page_token return google.datalab.utils.Http.request(url, args=args, credentials=self.credentials) def table_delete(self, table_name): """Issues a request to delete a table. Args: table_name: the name of the table as a tuple of components. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return google.datalab.utils.Http.request(url, method='DELETE', credentials=self.credentials, raw_response=True) def table_extract(self, table_name, destination, format='CSV', compress=True, field_delimiter=',', print_header=True): """Exports the table to GCS. Args: table_name: the name of the table as a tuple of components. destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of CSV, NEWLINE_DELIMITED_JSON or AVRO. Defaults to CSV. compress: whether to compress the data on export. Compression is not supported for AVRO format. Defaults to False. field_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' print_header: for CSV exports, whether to include an initial header line. Default true. Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._JOBS_PATH % (table_name.project_id, '')) if isinstance(destination, basestring): destination = [destination] data = { # 'projectId': table_name.project_id, # Code sample shows this but it is not in job # reference spec. 
Filed as b/19235843 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': table_name.project_id, 'datasetId': table_name.dataset_id, 'tableId': table_name.table_id, }, 'compression': 'GZIP' if compress else 'NONE', 'fieldDelimiter': field_delimiter, 'printHeader': print_header, 'destinationUris': destination, 'destinationFormat': format, } } } return google.datalab.utils.Http.request(url, data=data, credentials=self.credentials) def table_update(self, table_name, table_info): """Updates the Table info. Args: table_name: the name of the table to update as a tuple of components. table_info: the Table resource with updated fields. """ url = Api._ENDPOINT + (Api._TABLES_PATH % table_name) return google.datalab.utils.Http.request(url, method='PUT', data=table_info, credentials=self.credentials) ================================================ FILE: google/datalab/bigquery/_csv_options.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements CSV options for External Tables and Table loads from GCS.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class CSVOptions(object): def __init__(self, delimiter=',', skip_leading_rows=0, encoding='utf-8', quote='"', allow_quoted_newlines=False, allow_jagged_rows=False): """ Initialize an instance of CSV options. Args: delimiter: The separator for fields in a CSV file. BigQuery converts the string to ISO-8859-1 encoding, and then uses the first byte of the encoded string to split the data as raw binary (default ','). skip_leading_rows: A number of rows at the top of a CSV file to skip (default 0). encoding: The character encoding of the data, either 'utf-8' (the default) or 'iso-8859-1'. quote: The value used to quote data sections in a CSV file; default '"'. If your data does not contain quoted sections, set the property value to an empty string. If your data contains quoted newline characters, you must also enable allow_quoted_newlines. allow_quoted_newlines: If True, allow quoted data sections in CSV files that contain newline characters (default False). allow_jagged_rows: If True, accept rows in CSV files that are missing trailing optional columns; the missing values are treated as nulls (default False). 
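    Example (illustrative sketch; tab-delimited data with a single header row):

        options = CSVOptions(delimiter='\t', skip_leading_rows=1, encoding='utf-8')

    The resulting object can then be passed wherever csv_options is accepted, such as
    when defining an external data source or loading a table from GCS.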
""" encoding_upper = encoding.upper() if encoding_upper != 'UTF-8' and encoding_upper != 'ISO-8859-1': raise Exception("Invalid source encoding %s" % encoding) self._delimiter = delimiter self._skip_leading_rows = skip_leading_rows self._encoding = encoding self._quote = quote self._allow_quoted_newlines = allow_quoted_newlines self._allow_jagged_rows = allow_jagged_rows @property def delimiter(self): return self._delimiter @property def skip_leading_rows(self): return self._skip_leading_rows @property def encoding(self): return self._encoding @property def quote(self): return self._quote @property def allow_quoted_newlines(self): return self._allow_quoted_newlines @property def allow_jagged_rows(self): return self._allow_jagged_rows def _to_query_json(self): """ Return the options as a dictionary to be used as JSON in a query job. """ return { 'quote': self._quote, 'fieldDelimiter': self._delimiter, 'encoding': self._encoding.upper(), 'skipLeadingRows': self._skip_leading_rows, 'allowQuotedNewlines': self._allow_quoted_newlines, 'allowJaggedRows': self._allow_jagged_rows } ================================================ FILE: google/datalab/bigquery/_dataset.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Dataset, and related Dataset BigQuery APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import google.datalab import google.datalab.utils from . import _api from . import _table from . import _utils from . import _view class Dataset(object): """Represents a list of BigQuery tables in a dataset.""" def __init__(self, name, context=None): """Initializes an instance of a Dataset. Args: name: the name of the dataset, as a string or (project_id, dataset_id) tuple. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._name_parts = _utils.parse_dataset_name(name, self._api.project_id) self._full_name = '%s.%s' % self._name_parts self._info = None try: self._info = self._get_info() except google.datalab.utils.RequestException: pass @property def name(self): """The DatasetName named tuple (project_id, dataset_id) for the dataset.""" return self._name_parts @property def description(self): """The description of the dataset, if any. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. """ self._get_info() return self._info['description'] if self._info else None @property def friendly_name(self): """The friendly name of the dataset, if any. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. 
""" self._get_info() return self._info['friendlyName'] if self._info else None def _get_info(self): try: if self._info is None: self._info = self._api.datasets_get(self._name_parts) return self._info except google.datalab.utils.RequestException as e: if e.status == 404: return None raise e except Exception as e: raise e def exists(self): """ Checks if the dataset exists. Returns: True if the dataset exists; False otherwise. Raises: Exception if the dataset exists but the metadata for the dataset could not be retrieved. """ self._get_info() return self._info is not None def delete(self, delete_contents=False): """Issues a request to delete the dataset. Args: delete_contents: if True, any tables and views in the dataset will be deleted. If False and the dataset is non-empty an exception will be raised. Returns: None on success. Raises: Exception if the delete fails (including if table was nonexistent). """ if not self.exists(): raise Exception('Cannot delete non-existent dataset %s' % self._full_name) try: self._api.datasets_delete(self._name_parts, delete_contents=delete_contents) except Exception as e: raise e self._info = None return None def create(self, friendly_name=None, description=None): """Creates the Dataset with the specified friendly name and description. Args: friendly_name: (optional) the friendly name for the dataset if it is being created. description: (optional) a description for the dataset if it is being created. Returns: The Dataset. Raises: Exception if the Dataset could not be created. """ if not self.exists(): try: response = self._api.datasets_insert(self._name_parts, friendly_name=friendly_name, description=description) except Exception as e: raise e if 'selfLink' not in response: raise Exception("Could not create dataset %s" % self._full_name) return self def update(self, friendly_name=None, description=None): """ Selectively updates Dataset information. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. Returns: """ self._get_info() if self._info: if friendly_name: self._info['friendlyName'] = friendly_name if description: self._info['description'] = description try: self._api.datasets_update(self._name_parts, self._info) except Exception as e: raise e finally: self._info = None # need a refresh def _retrieve_items(self, page_token, item_type): try: list_info = self._api.tables_list(self._name_parts, page_token=page_token) except Exception as e: raise e tables = list_info.get('tables', []) contents = [] if len(tables): try: for info in tables: if info['type'] != item_type: continue if info['type'] == 'TABLE': item = _table.Table((info['tableReference']['projectId'], info['tableReference']['datasetId'], info['tableReference']['tableId']), self._context) else: item = _view.View((info['tableReference']['projectId'], info['tableReference']['datasetId'], info['tableReference']['tableId']), self._context) contents.append(item) except KeyError: raise Exception('Unexpected item list response') page_token = list_info.get('nextPageToken', None) return contents, page_token def _retrieve_tables(self, page_token, _): return self._retrieve_items(page_token=page_token, item_type='TABLE') def _retrieve_views(self, page_token, _): return self._retrieve_items(page_token=page_token, item_type='VIEW') def tables(self): """ Returns an iterator for iterating through the Tables in the dataset. 
""" return iter(google.datalab.utils.Iterator(self._retrieve_tables)) def views(self): """ Returns an iterator for iterating through the Views in the dataset. """ return iter(google.datalab.utils.Iterator(self._retrieve_views)) def __iter__(self): """ Returns an iterator for iterating through the Tables in the dataset. """ return self.tables() def __str__(self): """Returns a string representation of the dataset using its specified name. Returns: The string representation of this object. """ return self._full_name def __repr__(self): """Returns a representation for the dataset for showing in the notebook. """ return 'Dataset %s' % self._full_name class Datasets(object): """ Iterator class for enumerating the datasets in a project. """ def __init__(self, context=None): """ Initialize the Datasets object. Args: context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._project_id = context.project_id if context else self._api.project_id self._page_size = 0 def _retrieve_datasets(self, page_token, _): try: list_info = self._api.datasets_list(self._project_id, max_results=self._page_size, page_token=page_token) except Exception as e: raise e datasets = list_info.get('datasets', []) if len(datasets): self._page_size = self._page_size or len(datasets) try: datasets = [Dataset((info['datasetReference']['projectId'], info['datasetReference']['datasetId']), self._context) for info in datasets] except KeyError: raise Exception('Unexpected response from server.') page_token = list_info.get('nextPageToken', None) return datasets, page_token def __iter__(self): """ Returns an iterator for iterating through the Datasets in the project. """ return iter(google.datalab.utils.Iterator(self._retrieve_datasets)) ================================================ FILE: google/datalab/bigquery/_external_data_source.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements External Table functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from . import _csv_options class ExternalDataSource(object): def __init__(self, source, source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0, compressed=False, schema=None): """ Create an external table for a GCS object. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: For CSV files, the options such as quote character and delimiter. 
ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: The maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). compressed: whether the data is GZ compressed or not (default False). Note that compressed data can be used as an external data source but cannot be loaded into a BQ Table. schema: the schema of the data. This is required for this table to be used as an external data source or to be loaded using a Table object that itself has no schema (default None). """ # Do some sanity checking and concert some params from friendly form to form used by BQ. if source_format == 'csv': self._bq_source_format = 'CSV' if csv_options is None: csv_options = _csv_options.CSVOptions() # use defaults elif source_format == 'json': if csv_options: raise Exception('CSV options are not support for JSON tables') self._bq_source_format = 'NEWLINE_DELIMITED_JSON' else: raise Exception("Invalid source format %s" % source_format) self._source = source if isinstance(source, list) else [source] self._source_format = source_format self._csv_options = csv_options self._ignore_unknown_values = ignore_unknown_values self._max_bad_records = max_bad_records self._compressed = compressed self._schema = schema @property def schema(self): return self._schema def __repr__(self): return 'BigQuery External Datasource - paths: %s' % (','.join(self._source)) def _to_query_json(self): """ Return the table as a dictionary to be used as JSON in a query job. """ json = { 'compression': 'GZIP' if self._compressed else 'NONE', 'ignoreUnknownValues': self._ignore_unknown_values, 'maxBadRecords': self._max_bad_records, 'sourceFormat': self._bq_source_format, 'sourceUris': self._source, } if self._source_format == 'csv' and self._csv_options: json['csvOptions'] = {} json['csvOptions'].update(self._csv_options._to_query_json()) if self._schema: json['schema'] = {'fields': self._schema._bq_schema} return json ================================================ FILE: google/datalab/bigquery/_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from __future__ import division import datetime import google.datalab from google.datalab.utils._gcp_job import GCPJob from . import _api class Job(GCPJob): """Represents a BigQuery Job. """ def __init__(self, job_id, context): """Initializes an instance of a Job. Args: job_id: the BigQuery job ID corresponding to this job. context: a Context object providing project_id and credentials. """ super(Job, self).__init__(job_id, context) def _create_api(self, context): return _api.Api(context) def _refresh_state(self): """ Get the state of a job. 
If the job is complete this does nothing; otherwise it gets a refreshed copy of the job resource. """ # TODO(gram): should we put a choke on refreshes? E.g. if the last call was less than # a second ago should we return the cached value? if self._is_complete: return try: response = self._api.jobs_get(self._job_id) except Exception as e: raise e if 'status' in response: status = response['status'] if 'state' in status and status['state'] == 'DONE': self._end_time = datetime.datetime.utcnow() self._is_complete = True self._process_job_status(status) if 'statistics' in response: statistics = response['statistics'] start_time = statistics.get('creationTime', None) end_time = statistics.get('endTime', None) if start_time and end_time and end_time >= start_time: self._start_time = datetime.datetime.fromtimestamp(float(start_time) / 1000.0) self._end_time = datetime.datetime.fromtimestamp(float(end_time) / 1000.0) def _process_job_status(self, status): if 'errorResult' in status: error_result = status['errorResult'] location = error_result.get('location', None) message = error_result.get('message', None) reason = error_result.get('reason', None) self._fatal_error = google.datalab.JobError(location, message, reason) if 'errors' in status: self._errors = [] for error in status['errors']: location = error.get('location', None) message = error.get('message', None) reason = error.get('reason', None) self._errors.append(google.datalab.JobError(location, message, reason)) ================================================ FILE: google/datalab/bigquery/_parser.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery related data parsing helpers.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import zip from builtins import str from builtins import object import datetime class Parser(object): """A set of helper functions to parse data in BigQuery responses.""" def __init__(self): pass @staticmethod def parse_row(schema, data): """Parses a row from query results into an equivalent object. Args: schema: the array of fields defining the schema of the data. data: the JSON row from a query result. Returns: The parsed row object. """ def parse_value(data_type, value): """Parses a value returned from a BigQuery response. Args: data_type: the type of the value as specified by the schema. value: the raw value to return (before casting to data_type). Returns: The value cast to the data_type. 
""" if value is not None: if value == 'null': value = None elif data_type == 'INTEGER': value = int(value) elif data_type == 'FLOAT': value = float(value) elif data_type == 'TIMESTAMP': value = datetime.datetime.utcfromtimestamp(float(value)) elif data_type == 'BOOLEAN': value = value == 'true' elif (type(value) != str): # TODO(gram): Handle nested JSON records value = str(value) return value row = {} if data is None: return row for i, (field, schema_field) in enumerate(zip(data['f'], schema)): val = field['v'] name = schema_field['name'] data_type = schema_field['type'] repeated = True if 'mode' in schema_field and schema_field['mode'] == 'REPEATED' else False if repeated and val is None: row[name] = [] elif data_type == 'RECORD': sub_schema = schema_field['fields'] if repeated: row[name] = [Parser.parse_row(sub_schema, v['v']) for v in val] else: row[name] = Parser.parse_row(sub_schema, val) elif repeated: row[name] = [parse_value(data_type, v['v']) for v in val] else: row[name] = parse_value(data_type, val) return row @staticmethod def parse_timestamp(value): """Parses a timestamp. Args: value: the number of milliseconds since epoch. """ return datetime.datetime.utcfromtimestamp(float(value) / 1000.0) ================================================ FILE: google/datalab/bigquery/_query.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Query BigQuery API.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import datetime import google.datalab import google.datalab.data import google.datalab.utils import six from ._query_output import QueryOutput from . import _api from . import _query_job from . import _udf from . import _utils from . import _external_data_source class Query(object): """Represents a Query object that encapsulates a BigQuery SQL query. This object can be used to execute SQL queries and retrieve results. """ def __init__(self, sql, env=None, udfs=None, data_sources=None, subqueries=None): """Initializes an instance of a Query object. Args: sql: the BigQuery SQL query string to execute env: a dictionary containing objects from the query execution context, used to get references to UDFs, subqueries, and external data sources referenced by the query udfs: list of UDFs names referenced in the SQL, or dictionary of names and UDF objects data_sources: list of external data sources names referenced in the SQL, or dictionary of names and data source objects subqueries: list of subqueries names referenced in the SQL, or dictionary of names and Query objects Raises: Exception if expansion of any variables failed. 
""" self._sql = sql self._udfs = [] self._subqueries = [] self._data_sources = [] self._env = env or {} # Validate given list or dictionary of objects that they are of correct type # and add them to the target dictionary def _expand_objects(obj_container, obj_type, target_list): for item in obj_container: # for a list of objects, we should find these objects in the given environment if isinstance(obj_container, list): value = self._env.get(item) if value is None: raise Exception('Cannot find object %s' % item) # for a dictionary of objects, each pair must be a string and object of the expected type elif isinstance(obj_container, dict): value = obj_container[item] if not isinstance(value, obj_type): raise Exception('Expected type: %s, found: %s.' % (obj_type, type(value))) else: raise Exception('Unexpected container for type %s. Expected a list or dictionary' % obj_type) target_list.append((item, value)) if subqueries: _expand_objects(subqueries, Query, self._subqueries) if udfs: _expand_objects(udfs, _udf.UDF, self._udfs) if data_sources: _expand_objects(data_sources, _external_data_source.ExternalDataSource, self._data_sources) if len(self._data_sources) > 1: raise Exception('Only one temporary external datasource is supported in queries.') @staticmethod def from_view(view): """ Return a Query for the given View object Args: view: the View object to construct a Query out of Returns: A Query object with the same sql as the given View object """ return Query('SELECT * FROM %s' % view._repr_sql_()) @staticmethod def from_table(table, fields=None): """ Return a Query for the given Table object Args: table: the Table object to construct a Query out of fields: the fields to return. If None, all fields will be returned. This can be a string which will be injected into the Query after SELECT, or a list of field names. Returns: A Query object that will return the specified fields from the records in the Table. """ if fields is None: fields = '*' elif isinstance(fields, list): fields = ','.join(fields) return Query('SELECT %s FROM %s' % (fields, table._repr_sql_())) def _expanded_sql(self, sampling=None): """Get the expanded SQL of this object, including all subqueries, UDFs, and external datasources Returns: The expanded SQL string of this object """ # use lists to preserve the order of subqueries, bigquery will not like listing subqueries # out of order if they depend on each other. for example. the following will be rejected: # WITH q2 as (SELECT * FROM q1), # q1 as (SELECT * FROM mytable), # SELECT * FROM q2 # so when we're getting the dependencies, use recursion into a list to maintain the order udfs = [] subqueries = [] expanded_sql = '' def _recurse_subqueries(query): """Recursively scan subqueries and add their pieces to global scope udfs and subqueries """ if query._subqueries: for subquery in query._subqueries: _recurse_subqueries(subquery[1]) subqueries.extend([s for s in query._subqueries if s not in subqueries]) if query._udfs: # query._udfs is a list of (name, UDF) tuples; we just want the UDF. 
udfs.extend([u[1] for u in query._udfs if u[1] not in udfs]) _recurse_subqueries(self) if udfs: expanded_sql += '\n'.join([udf._expanded_sql() for udf in udfs]) expanded_sql += '\n' def _indent_query(subquery): return ' ' + subquery._sql.replace('\n', '\n ') if subqueries: expanded_sql += 'WITH ' + \ '\n),\n'.join(['%s AS (\n%s' % (sq[0], _indent_query(sq[1])) for sq in subqueries]) expanded_sql += '\n)\n\n' expanded_sql += sampling(self._sql) if sampling else self._sql return expanded_sql def _repr_sql_(self): """Creates a SQL representation of this object. Returns: The SQL representation to use when embedding this object into other SQL. """ return '(%s)' % self.sql def __repr__(self): """Creates a friendly representation of this object. Returns: The friendly representation of this object (the unmodified SQL). """ return 'BigQuery Query - %s' % self._sql @property def sql(self): """ Get the SQL for the query. """ return self._expanded_sql() @property def udfs(self): """ Get a dictionary of UDFs referenced by the query.""" return dict(self._udfs) @property def subqueries(self): """ Get a dictionary of subqueries referenced by the query.""" return dict(self._subqueries) @property def data_sources(self): """ Get a dictionary of external data sources referenced by the query.""" return dict(self._data_sources) def dry_run(self, context=None, query_params=None): """Dry run a query, to check the validity of the query and return some useful statistics. Args: context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. query_params: a dictionary containing query parameter types and values, passed to BigQuery. Returns: A dict with 'cacheHit' and 'totalBytesProcessed' fields. Raises: An exception if the query was malformed. """ context = context or google.datalab.Context.default() api = _api.Api(context) try: query_result = api.jobs_insert_query(self.sql, dry_run=True, table_definitions=self.data_sources, query_params=query_params) except Exception as e: raise e return query_result['statistics']['query'] def execute_async(self, output_options=None, sampling=None, context=None, query_params=None): """ Initiate the query and return a QueryJob. Args: output_options: a QueryOutput object describing how to execute the query sampling: sampling function to use. No sampling is done if None. See bigquery.Sampling context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. query_params: a dictionary containing query parameter types and values, passed to BigQuery. Returns: A Job object that can wait on creating a table or exporting to a file If the output is a table, the Job object additionally has run statistics and query results Raises: Exception if query could not be executed. 
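      Example (a sketch; the dataset and table names are hypothetical, and a default Context
      with valid credentials is assumed):
        errors = Query('SELECT * FROM mydataset.logs WHERE status >= 400')
        q = Query('SELECT status, COUNT(*) AS n FROM errors GROUP BY status',
                  subqueries=['errors'], env={'errors': errors})
        job = q.execute_async(QueryOutput.table(name='mydataset.error_counts',
                                                mode='overwrite'))
        results_table = job.result()   # blocks until the query job completes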
""" # Default behavior is to execute to a table if output_options is None: output_options = QueryOutput.table() # First, execute the query into a table, using a temporary one if no name is specified batch = output_options.priority == 'low' append = output_options.table_mode == 'append' overwrite = output_options.table_mode == 'overwrite' table_name = output_options.table_name context = context or google.datalab.Context.default() api = _api.Api(context) if table_name is not None: table_name = _utils.parse_table_name(table_name, api.project_id) sql = self._expanded_sql(sampling) try: query_result = api.jobs_insert_query(sql, table_name=table_name, append=append, overwrite=overwrite, batch=batch, use_cache=output_options.use_cache, allow_large_results=output_options.allow_large_results, table_definitions=self.data_sources, query_params=query_params) except Exception as e: raise e if 'jobReference' not in query_result: raise Exception('Unexpected response from server') job_id = query_result['jobReference']['jobId'] if not table_name: try: destination = query_result['configuration']['query']['destinationTable'] table_name = (destination['projectId'], destination['datasetId'], destination['tableId']) except KeyError: # The query was in error raise Exception(_utils.format_query_errors(query_result['status']['errors'])) execute_job = _query_job.QueryJob(job_id, table_name, sql, context=context) # If all we need is to execute the query to a table, we're done if output_options.type == 'table': return execute_job # Otherwise, build an async Job that waits on the query execution then carries out # the specific export operation else: export_args = export_kwargs = None if output_options.type == 'file': if output_options.file_path.startswith('gs://'): export_func = execute_job.result().extract export_args = [output_options.file_path] export_kwargs = { 'format': output_options.file_format, 'csv_delimiter': output_options.csv_delimiter, 'csv_header': output_options.csv_header, 'compress': output_options.compress_file } else: export_func = execute_job.result().to_file export_args = [output_options.file_path] export_kwargs = { 'format': output_options.file_format, 'csv_delimiter': output_options.csv_delimiter, 'csv_header': output_options.csv_header } elif output_options.type == 'dataframe': export_func = execute_job.result().to_dataframe export_args = [] export_kwargs = { 'start_row': output_options.dataframe_start_row, 'max_rows': output_options.dataframe_max_rows } # Perform the export operation with the specified parameters export_func = google.datalab.utils.async_function(export_func) return export_func(*export_args, **export_kwargs) def execute(self, output_options=None, sampling=None, context=None, query_params=None): """ Initiate the query and return a QueryJob. Args: output_options: a QueryOutput object describing how to execute the query sampling: sampling function to use. No sampling is done if None. See bigquery.Sampling context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Returns: A Job object that can be used to get the query results, or export to a file or dataframe Raises: Exception if query could not be executed. 
""" return self.execute_async(output_options, sampling=sampling, context=context, query_params=query_params).wait() @staticmethod def get_query_parameters(config_parameters, date_time=datetime.datetime.now()): """ Merge the given parameters with the airflow macros. Enables macros (like '@_ds') in sql. Args: config_parameters: The user-specified list of parameters in the cell-body. date_time: The timestamp at which the parameters need to be evaluated. E.g. when the table is ..logs_%(_ds)s, the '_ds' evaluates to the current date-time. Returns: A list of query parameters that are in the format for the BQ service """ merged_parameters = Query.merge_parameters(config_parameters, date_time=date_time, macros=False, types_and_values=True) # We're exposing a simpler schema format than the one actually required by BigQuery to make # magics easier. We need to convert between the two formats parsed_params = [] for key, value in merged_parameters.items(): parsed_params.append({ 'name': key, 'parameterType': { 'type': value['type'] }, 'parameterValue': { 'value': value['value'] } }) return parsed_params @staticmethod def resolve_parameters(value, parameters, date_time=datetime.datetime.now(), macros=False): """ Resolve a format modifier with the corresponding value. Args: value: The string (path, table, or any other artifact in a cell_body) which may have format modifiers. E.g. a table name could be ..logs_%(_ds)s parameters: The user-specified list of parameters in the cell-body. date_time: The timestamp at which the parameters need to be evaluated. E.g. when the table is ..logs_%(_ds)s, the '_ds' evaluates to the current date-time. macros: When true, the format modifers in the value are replaced with the corresponding airflow macro equivalents (like '{{ ds }}'. When false, the actual values are used (like '2015-12-12'. Returns: The resolved value, i.e. the value with the format modifiers replaced with the corresponding parameter-values. E.g. if value is ..logs_%(_ds)s, the returned value is something like ..logs_2017-12-21 """ merged_parameters = Query.merge_parameters(parameters, date_time=date_time, macros=macros, types_and_values=False) return Query._resolve_parameters(value, merged_parameters) @staticmethod def _resolve_parameters(operator_param_value, merged_parameters): """ Resolve a format modifier with the corresponding value. Args: operator_param_value: The object with the format-modifiers that need to be evaluated. This could either be a string, or a more complex type like a list or a dict. This function will recursively replace the format-modifiers from all the string values that it can find. merged_parameters: The full set of parameters that include the user-specified list of parameters from the cell-body, and the built-in airflow macros (either the macros or the evaluated-values). Returns: The resolved value, i.e. the value with the format modifiers replaced with the corresponding parameter-values. E.g. if value is ..logs_%(_ds)s, the returned value could be ..logs_2017-12-21. 
""" if isinstance(operator_param_value, list): return [Query._resolve_parameters(item, merged_parameters) for item in operator_param_value] if isinstance(operator_param_value, dict): return {Query._resolve_parameters(k, merged_parameters): Query._resolve_parameters( v, merged_parameters) for k, v in operator_param_value.items()} if isinstance(operator_param_value, six.string_types) and merged_parameters: return operator_param_value % merged_parameters return operator_param_value @staticmethod def _airflow_macro_formats(date_time, macros, types_and_values): """ Return a mapping from airflow macro names (prefixed with '_') to values Args: date_time: The timestamp at which the macro values need to be evaluated. This is only applicable when types_and_values = True macros: If true, the items in the returned dict are the macro strings (like '_ds': '{{ ds }}') types_and_values: If true, the values in the returned dict are dictionaries of the types and the values of the parameters (i.e like '_ds': {'type': STRING, 'value': 2017-12-21}) Returns: The resolved value, i.e. the value with the format modifiers replaced with the corresponding parameter-values. E.g. if value is ..logs_%(_ds)s, the returned value could be ..logs_2017-12-21. """ day = date_time.date() airflow_macros = { # the datetime formatted as YYYY-MM-DD '_ds': {'type': 'STRING', 'value': day.isoformat(), 'macro': '{{ ds }}'}, # the full ISO-formatted timestamp YYYY-MM-DDTHH:MM:SS.mmmmmm '_ts': {'type': 'STRING', 'value': date_time.isoformat(), 'macro': '{{ ts }}'}, # the datetime formatted as YYYYMMDD (i.e. YYYY-MM-DD with 'no dashes') '_ds_nodash': {'type': 'STRING', 'value': day.strftime('%Y%m%d'), 'macro': '{{ ds_nodash }}'}, # the timestamp formatted as YYYYMMDDTHHMMSSmmmmmm (i.e full ISO-formatted timestamp # YYYY-MM-DDTHH:MM:SS.mmmmmm with no dashes or colons). '_ts_nodash': {'type': 'STRING', 'value': date_time.strftime('%Y%m%d%H%M%S%f'), 'macro': '{{ ts_nodash }}'}, '_ts_year': {'type': 'STRING', 'value': day.strftime('%Y'), 'macro': """{{ '{:04d}'.format(execution_date.year) }}"""}, '_ts_month': {'type': 'STRING', 'value': day.strftime('%m'), 'macro': """{{ '{:02d}'.format(execution_date.month) }}"""}, '_ts_day': {'type': 'STRING', 'value': day.strftime('%d'), 'macro': """{{ '{:02d}'.format(execution_date.day) }}"""}, '_ts_hour': {'type': 'STRING', 'value': date_time.strftime('%H'), 'macro': """{{ '{:02d}'.format(execution_date.hour) }}"""}, '_ts_minute': {'type': 'STRING', 'value': date_time.strftime('%M'), 'macro': """{{ '{:02d}'.format(execution_date.minute) }}"""}, '_ts_second': {'type': 'STRING', 'value': date_time.strftime('%S'), 'macro': """{{ '{:02d}'.format(execution_date.second) }}"""}, } if macros: return {key: value['macro'] for key, value in airflow_macros.items()} if types_and_values: return { key: { 'type': item['type'], 'value': item['value'] } for key, item in airflow_macros.items() } # By default only return values return {key: value['value'] for key, value in airflow_macros.items()} @staticmethod def merge_parameters(parameters, date_time, macros, types_and_values): """ Merge Return a mapping from airflow macro names (prefixed with '_') to values Args: date_time: The timestamp at which the macro values need to be evaluated. This is only applicable when types_and_values = True macros: If true, the values in the returned dict are the macro strings (like '{{ ds }}') Returns: The resolved value, i.e. the value with the format modifiers replaced with the corresponding parameter-values. E.g. 
if value is ..logs_%(_ds)s, the returned value could be ..logs_2017-12-21. """ merged_parameters = Query._airflow_macro_formats(date_time=date_time, macros=macros, types_and_values=types_and_values) if parameters: if types_and_values: parameters = { item['name']: {'value': item['value'], 'type': item['type']} for item in parameters } else: # macros = True, or the default (i.e. just values) parameters = {item['name']: item['value'] for item in parameters} merged_parameters.update(parameters) return merged_parameters ================================================ FILE: google/datalab/bigquery/_query_job.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery query job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from . import _job from . import _query_results_table class QueryJob(_job.Job): """ Represents a BigQuery Query Job. """ def __init__(self, job_id, table_name, sql, context): """ Initializes a QueryJob object. Args: job_id: the ID of the query job. table_name: the name of the table where the query results will be stored. sql: the SQL statement that was executed for the query. context: the Context object providing project_id and credentials that was used when executing the query. """ super(QueryJob, self).__init__(job_id, context) self._sql = sql self._table = _query_results_table.QueryResultsTable(table_name, context, self, is_temporary=True) self._bytes_processed = None self._cache_hit = None self._total_rows = None @property def bytes_processed(self): """ The number of bytes processed, or None if the job is not complete. """ return self._bytes_processed @property def total_rows(self): """ The total number of rows in the result, or None if not complete. """ return self._total_rows @property def cache_hit(self): """ Whether the query results were obtained from the cache or not, or None if not complete. """ return self._cache_hit @property def sql(self): """ The SQL statement that was executed for the query. """ return self._sql def wait(self, timeout=None): """ Wait for the job to complete, or a timeout to happen. This is more efficient than the version in the base Job class, in that we can use a call that blocks for the poll duration rather than a sleep. That means we shouldn't block unnecessarily long and can also poll less. Args: timeout: how long to wait (in seconds) before giving up; default None which means no timeout. 
Returns: The QueryJob """ poll = 30 while not self._is_complete: try: query_result = self._api.jobs_query_results(self._job_id, project_id=self._context.project_id, page_size=0, timeout=poll * 1000) except Exception as e: raise e if query_result['jobComplete']: if 'totalBytesProcessed' in query_result: self._bytes_processed = int(query_result['totalBytesProcessed']) self._cache_hit = query_result.get('cacheHit', None) if 'totalRows' in query_result: self._total_rows = int(query_result['totalRows']) break if timeout is not None: timeout -= poll if timeout <= 0: break self._refresh_state() return self def result(self): """ Get the table used for the results of the query. If the query is incomplete, this blocks. Raises: Exception if we timed out waiting for results or the query failed. """ self.wait() if self.failed: raise Exception('Query failed: %s' % str(self.errors)) return self._table ================================================ FILE: google/datalab/bigquery/_query_output.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery output type functionality.""" class QueryOutput(object): @staticmethod def table(name=None, mode='create', use_cache=True, priority='interactive', allow_large_results=False): """ Construct a query output object where the result is a table Args: name: the result table name as a string or TableName; if None (the default), then a temporary table will be used. table_mode: one of 'create', 'overwrite' or 'append'. If 'create' (the default), the request will fail if the table exists. use_cache: whether to use past query results or ignore cache. Has no effect if destination is specified (default True). priority:one of 'batch' or 'interactive' (default). 'interactive' jobs should be scheduled to run quickly but are subject to rate limits; 'batch' jobs could be delayed by as much as three hours but are not rate-limited. allow_large_results: whether to allow large results; i.e. compressed data over 100MB. This is slower and requires a name to be specified) (default False). """ output = QueryOutput() output._output_type = 'table' output._table_name = name output._table_mode = mode output._use_cache = use_cache output._priority = priority output._allow_large_results = allow_large_results return output @staticmethod def file(path, format='csv', csv_delimiter=',', csv_header=True, compress=False, use_cache=True): """ Construct a query output object where the result is either a local file or a GCS path Note that there are two jobs that may need to be run sequentially, one to run the query, and the second to extract the resulting table. These are wrapped by a single outer Job. If the query has already been executed and you would prefer to get a Job just for the extract, you can can call extract[_async] on the QueryResultsTable returned by the query Args: path: the destination path. 
Can either be a local or GCS URI (starting with gs://) format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use (default ','). csv_header: for CSV exports, whether to include an initial header line (default True). compress: whether to compress the data on export. Compression is not supported for AVRO format (default False). Applies only to GCS URIs. use_cache: whether to use cached results or not (default True). """ output = QueryOutput() output._output_type = 'file' output._file_path = path output._file_format = format output._csv_delimiter = csv_delimiter output._csv_header = csv_header output._compress_file = compress return output @staticmethod def dataframe(start_row=0, max_rows=None, use_cache=True): """ Construct a query output object where the result is a dataframe Args: start_row: the row of the table at which to start the export (default 0). max_rows: an upper limit on the number of rows to export (default None). use_cache: whether to use cached results or not (default True). """ output = QueryOutput() output._output_type = 'dataframe' output._dataframe_start_row = start_row output._dataframe_max_rows = max_rows output._use_cache = use_cache return output def __init__(self): """ Create a BigQuery output type object. Do not call this directly; use factory methods. """ self._output_type = None self._table_name = None self._table_mode = None self._use_cache = None self._priority = None self._allow_large_results = None self._file_path = None self._file_format = None self._csv_delimiter = None self._csv_header = None self._compress_file = None self._dataframe_start_row = None self._dataframe_max_rows = None @property def type(self): return self._output_type @property def table_name(self): return self._table_name @property def table_mode(self): return self._table_mode @property def use_cache(self): return self._use_cache @property def priority(self): return self._priority @property def allow_large_results(self): return self._allow_large_results @property def file_path(self): return self._file_path @property def file_format(self): return self._file_format @property def csv_delimiter(self): return self._csv_delimiter @property def csv_header(self): return self._csv_header @property def compress_file(self): return self._compress_file @property def dataframe_start_row(self): return self._dataframe_start_row @property def dataframe_max_rows(self): return self._dataframe_max_rows ================================================ FILE: google/datalab/bigquery/_query_results_table.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery query job results table functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from . import _table class QueryResultsTable(_table.Table): """ A subclass of Table specifically for Query results. 
The primary differences are the additional properties job_id and sql. """ def __init__(self, name, context, job, is_temporary=False): """Initializes an instance of a Table object. Args: name: the name of the table either as a string or a 3-part tuple (projectid, datasetid, name). context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. job: the QueryJob associated with these results. is_temporary: if True, this is a short-lived table for intermediate results (default False). """ super(QueryResultsTable, self).__init__(name, context) self._job = job self._is_temporary = is_temporary def __repr__(self): """Returns a representation for the dataset for showing in the notebook. """ if self._is_temporary: return 'QueryResultsTable %s' % self.job_id else: return super(QueryResultsTable, self).__repr__() def insert(self, *args, **kwargs): raise Exception('QueryResultsTable object is immutable') @property def job(self): """ The QueryJob object that caused the table to be populated. """ return self._job @property def job_id(self): """ The ID of the query job that caused the table to be populated. """ return self._job.id @property def sql(self): """ The SQL statement for the query that populated the table. """ return self._job.sql @property def is_temporary(self): """ Whether this is a short-lived table or not. """ return self._is_temporary ================================================ FILE: google/datalab/bigquery/_query_stats.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements representation of BigQuery query job dry run results.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class QueryStats(object): """A wrapper for statistics returned by a dry run query. Useful so we can get an HTML representation in a notebook. """ def __init__(self, total_bytes, is_cached): self.total_bytes = float(total_bytes) self.is_cached = is_cached def _repr_html_(self): self.total_bytes = QueryStats._size_formatter(self.total_bytes) return """

    <p>Dry run information: %s to process, results %s</p>

""" % (self.total_bytes, "cached" if self.is_cached else "not cached") @staticmethod def _size_formatter(byte_num, suf='B'): for mag in ['', 'K', 'M', 'G', 'T']: if byte_num < 1000.0: if suf == 'B': # Don't do fractional bytes return "%5d%s%s" % (int(byte_num), mag, suf) return "%3.1f%s%s" % (byte_num, mag, suf) byte_num /= 1000.0 return "%.1f%s%s".format(byte_num, 'P', suf) ================================================ FILE: google/datalab/bigquery/_sampling.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Sampling for BigQuery.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import object class Sampling(object): """Provides common sampling strategies. Sampling strategies can be used for sampling tables or queries. They are implemented as functions that take in a SQL statement representing the table or query that should be sampled, and return a new SQL statement that limits the result set in some manner. """ def __init__(self): pass @staticmethod def _create_projection(fields): """Creates a projection for use in a SELECT statement. Args: fields: the list of fields to be specified. """ if (fields is None) or (len(fields) == 0): return '*' return ','.join(fields) @staticmethod def default(fields=None, count=5): """Provides a simple default sampling strategy which limits the result set by a count. Args: fields: an optional list of field names to retrieve. count: optional number of rows to limit the sampled results to. Returns: A sampling function that can be applied to get a random sampling. """ projection = Sampling._create_projection(fields) return lambda sql: 'SELECT %s FROM (%s) LIMIT %d' % (projection, sql, count) @staticmethod def sorted(field_name, ascending=True, fields=None, count=5): """Provides a sampling strategy that picks from an ordered set of rows. Args: field_name: the name of the field to sort the rows by. ascending: whether to sort in ascending direction or not. fields: an optional list of field names to retrieve. count: optional number of rows to limit the sampled results to. Returns: A sampling function that can be applied to get the initial few rows. """ if field_name is None: raise Exception('Sort field must be specified') direction = '' if ascending else ' DESC' projection = Sampling._create_projection(fields) return lambda sql: 'SELECT %s FROM (%s) ORDER BY %s%s LIMIT %d' % (projection, sql, field_name, direction, count) @staticmethod def hashed(field_name, percent, fields=None, count=0): """Provides a sampling strategy based on hashing and selecting a percentage of data. Args: field_name: the name of the field to hash. percent: the percentage of the resulting hashes to select. fields: an optional list of field names to retrieve. count: optional maximum count of rows to pick. Returns: A sampling function that can be applied to get a hash-based sampling. 
""" if field_name is None: raise Exception('Hash field must be specified') def _hashed_sampling(sql): projection = Sampling._create_projection(fields) sql = 'SELECT %s FROM (%s) WHERE MOD(ABS(FARM_FINGERPRINT(CAST(%s AS STRING))), 100) < %d' % \ (projection, sql, field_name, percent) if count != 0: sql = '%s LIMIT %d' % (sql, count) return sql return _hashed_sampling @staticmethod def random(percent, fields=None, count=0): """Provides a sampling strategy that picks a semi-random set of rows. Args: percent: the percentage of the resulting hashes to select. fields: an optional list of field names to retrieve. count: maximum number of rows to limit the sampled results to (default 5). Returns: A sampling function that can be applied to get some random rows. In order for this to provide a good random sample percent should be chosen to be ~count/#rows where #rows is the number of rows in the object (query, view or table) being sampled. The rows will be returned in order; i.e. the order itself is not randomized. """ def _random_sampling(sql): projection = Sampling._create_projection(fields) sql = 'SELECT %s FROM (%s) WHERE rand() < %f' % (projection, sql, (float(percent) / 100.0)) if count != 0: sql = '%s LIMIT %d' % (sql, count) return sql return _random_sampling @staticmethod def _auto(method, fields, count, percent, key_field, ascending): """Construct a sampling function according to the provided sampling technique, provided all its needed fields are passed as arguments Args: method: one of the supported sampling methods: {limit,random,hashed,sorted} fields: an optional list of field names to retrieve. count: maximum number of rows to limit the sampled results to. percent: the percentage of the resulting hashes to select if using hashed sampling key_field: the name of the field to sort the rows by or use for hashing ascending: whether to sort in ascending direction or not. Returns: A sampling function using the provided arguments Raises: Exception if an unsupported mathod name is passed """ if method == 'limit': return Sampling.default(fields=fields, count=count) elif method == 'random': return Sampling.random(fields=fields, percent=percent, count=count) elif method == 'hashed': return Sampling.hashed(fields=fields, field_name=key_field, percent=percent, count=count) elif method == 'sorted': return Sampling.sorted(fields=fields, field_name=key_field, ascending=ascending, count=count) else: raise Exception('Unsupported sampling method: %s' % method) ================================================ FILE: google/datalab/bigquery/_schema.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
"""Implements Table and View Schema APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import range from past.builtins import basestring from builtins import object import datetime import pandas import pprint class SchemaField(object): """ Represents a single field in a Table schema. This has the properties: - name: the flattened, full-qualified name of the field. - type: the type of the field as a string ('INTEGER', 'BOOLEAN', 'FLOAT', 'STRING' or 'TIMESTAMP'). - mode: the mode of the field; 'NULLABLE' by default. - description: a description of the field, if known; empty string by default. """ def __init__(self, name, type, mode='NULLABLE', description=''): self.name = name self.type = type self.mode = mode self.description = description def _repr_sql_(self): """Returns a representation of the field for embedding into a SQL statement. Returns: A formatted field name for use within SQL statements. """ return self.name def __eq__(self, other): """ Compare two schema field objects for equality (ignoring description). """ return self.name == other.name and self.type == other.type\ and self.mode == other.mode def __repr__(self): """ Returns the schema field as a string form of a dictionary. """ return 'BigQuery Schema Field:\n%s' % pprint.pformat(vars(self), width=1) def __getitem__(self, item): # TODO(gram): Currently we need this for a Schema object to work with the Parser object. # Eventually if we change Parser to only work with Schema (and not also with the # schema dictionaries in query results) we can remove this. if item == 'name': return self.name if item == 'type': return self.type if item == 'mode': return self.mode if item == 'description': return self.description class Schema(list): """Represents the schema of a BigQuery table as a flattened list of objects representing fields. Each field object has name, type, mode and description properties. Nested fields get flattened with their full-qualified names. So a Schema that has an object A with nested field B will be represented as [(name: 'A', ...), (name: 'A.b', ...)]. """ @staticmethod def _from_dataframe(dataframe, default_type='STRING'): """ Infer a BigQuery table schema from a Pandas dataframe. Note that if you don't explicitly set the types of the columns in the dataframe, they may be of a type that forces coercion to STRING, so even though the fields in the dataframe themselves may be numeric, the type in the derived schema may not be. Hence it is prudent to make sure the Pandas dataframe is typed correctly. Args: dataframe: The DataFrame. default_type : The default big query type in case the type of the column does not exist in the schema. Defaults to 'STRING'. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. 
""" type_mapping = { 'i': 'INTEGER', 'b': 'BOOLEAN', 'f': 'FLOAT', 'O': 'STRING', 'S': 'STRING', 'U': 'STRING', 'M': 'TIMESTAMP' } fields = [] for column_name, dtype in dataframe.dtypes.iteritems(): fields.append({'name': column_name, 'type': type_mapping.get(dtype.kind, default_type)}) return fields @staticmethod def _get_field_entry(name, value): entry = {'name': name} if isinstance(value, datetime.datetime): _type = 'TIMESTAMP' elif isinstance(value, datetime.date): _type = 'DATE' elif isinstance(value, datetime.time): _type = 'TIME' elif isinstance(value, bool): _type = 'BOOLEAN' elif isinstance(value, float): _type = 'FLOAT' elif isinstance(value, int): _type = 'INTEGER' elif isinstance(value, dict) or isinstance(value, list): _type = 'RECORD' entry['fields'] = Schema._from_record(value) else: _type = 'STRING' entry['type'] = _type return entry @staticmethod def _from_dict_record(data): """ Infer a BigQuery table schema from a dictionary. If the dictionary has entries that are in turn OrderedDicts these will be turned into RECORD types. Ideally this will be an OrderedDict but it is not required. Args: data: The dict to infer a schema from. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ return [Schema._get_field_entry(name, value) for name, value in list(data.items())] @staticmethod def _from_list_record(data): """ Infer a BigQuery table schema from a list of values. Args: data: The list of values. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ return [Schema._get_field_entry('Column%d' % (i + 1), value) for i, value in enumerate(data)] @staticmethod def _from_record(data): """ Infer a BigQuery table schema from a list of fields or a dictionary. The typeof the elements is used. For a list, the field names are simply 'Column1', 'Column2', etc. Args: data: The list of fields or dictionary. Returns: A list of dictionaries containing field 'name' and 'type' entries, suitable for use in a BigQuery Tables resource schema. """ if isinstance(data, dict): return Schema._from_dict_record(data) elif isinstance(data, list): return Schema._from_list_record(data) else: raise Exception('Cannot create a schema from record %s' % str(data)) @staticmethod def from_record(source): """ Infers a table/view schema from a single record that can contain a list of fields or a dictionary of fields. The type of the elements is used for the types in the schema. For a dict, key names are used for column names while for a list, the field names are simply named 'Column1', 'Column2', etc. Note that if using a dict you may want to use an OrderedDict to ensure column ordering is deterministic. Args: source: The list of field values or dictionary of key/values. Returns: A Schema for the data. """ # TODO(gram): may want to allow an optional second argument which is a list of field # names; could be useful for the record-containing-list case. return Schema(Schema._from_record(source)) @staticmethod def from_data(source): """Infers a table/view schema from its JSON representation, a list of records, or a Pandas dataframe. Args: source: the Pandas Dataframe, a dictionary representing a record, a list of heterogeneous data (record) or homogeneous data (list of records) from which to infer the schema, or a definition of the schema as a list of dictionaries with 'name' and 'type' entries and possibly 'mode' and 'description' entries. 
Only used if no data argument was provided. 'mode' can be 'NULLABLE', 'REQUIRED' or 'REPEATED'. For the allowed types, see: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types Note that there is potential ambiguity when passing a list of lists or a list of dicts between whether that should be treated as a list of records or a single record that is a list. The heuristic used is to check the length of the entries in the list; if they are equal then a list of records is assumed. To avoid this ambiguity you can instead use the Schema.from_record method which assumes a single record, in either list of values or dictionary of key-values form. Returns: A Schema for the data. """ if isinstance(source, pandas.DataFrame): bq_schema = Schema._from_dataframe(source) elif isinstance(source, list): if len(source) == 0: bq_schema = source elif all(isinstance(d, dict) for d in source): if all('name' in d and 'type' in d for d in source): # It looks like a bq_schema; use it as-is. bq_schema = source elif all(len(d) == len(source[0]) for d in source): bq_schema = Schema._from_dict_record(source[0]) else: raise Exception(('Cannot create a schema from heterogeneous list %s; perhaps you meant ' + 'to use Schema.from_record?') % str(source)) elif isinstance(source[0], list) and \ all([isinstance(l, list) and len(l) == len(source[0]) for l in source]): # A list of lists all of the same length; treat first entry as a list record. bq_schema = Schema._from_record(source[0]) else: # A heterogeneous list; treat as a record. raise Exception(('Cannot create a schema from heterogeneous list %s; perhaps you meant ' + 'to use Schema.from_record?') % str(source)) elif isinstance(source, dict): bq_schema = Schema._from_record(source) else: raise Exception('Cannot create a schema from %s' % str(source)) return Schema(bq_schema) def __init__(self, definition=None): """Initializes a Schema from its raw JSON representation, a Pandas Dataframe, or a list. Args: definition: a definition of the schema as a list of dictionaries with 'name' and 'type' entries and possibly 'mode' and 'description' entries. Only used if no data argument was provided. 'mode' can be 'NULLABLE', 'REQUIRED' or 'REPEATED'. For the allowed types, see: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types """ super(Schema, self).__init__() self._map = {} self._bq_schema = definition self._populate_fields(definition) def __getitem__(self, key): """Provides ability to lookup a schema field by position or by name. """ if isinstance(key, basestring): return self._map.get(key, None) # noinspection PyCallByClass return list.__getitem__(self, key) def _add_field(self, name, type, mode='NULLABLE', description=''): field = SchemaField(name, type, mode, description) self.append(field) self._map[name] = field def find(self, name): """ Get the index of a field in the flattened list given its (fully-qualified) name. Args: name: the fully-qualified name of the field. Returns: The index of the field, if found; else -1. """ for i in range(0, len(self)): if self[i].name == name: return i return -1 def _populate_fields(self, data, prefix=''): for field_data in data: name = prefix + field_data['name'] type = field_data['type'] self._add_field(name, type, field_data.get('mode', None), field_data.get('description', None)) if type == 'RECORD': # Recurse into the nested fields, using this field's name as a prefix. 
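      # e.g. a RECORD field 'address' with a nested field 'city' yields the flattened
      # entries 'address' and 'address.city' in this Schema.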
self._populate_fields(field_data.get('fields'), name + '.') def __repr__(self): """ Returns a string representation of the schema for notebooks.""" return 'BigQuery Schema - Fields:\n%s' % pprint.pformat(self._bq_schema, width=1) def __eq__(self, other): """ Compares two schema for equality. """ other_map = other._map if len(self._map) != len(other_map): return False for name in self._map.keys(): if name not in other_map: return False if not self._map[name] == other_map[name]: return False return True def __ne__(self, other): """ Compares two schema for inequality. """ return not(self.__eq__(other)) ================================================ FILE: google/datalab/bigquery/_table.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Table, and related Table BigQuery APIs.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import str from past.utils import old_div from builtins import object import calendar import codecs import csv import datetime import pandas import time import traceback import uuid import sys import google.datalab import google.datalab.utils from . import _api from . import _csv_options from . import _job from . import _parser from . import _schema from . import _utils class TableMetadata(object): """Represents metadata about a BigQuery table.""" def __init__(self, table, info): """Initializes a TableMetadata instance. Args: table: the Table object this belongs to. info: The BigQuery information about this table as a Python dictionary. """ self._table = table self._info = info @property def created_on(self): """The creation timestamp.""" timestamp = self._info.get('creationTime') return _parser.Parser.parse_timestamp(timestamp) @property def description(self): """The description of the table if it exists.""" return self._info.get('description', '') @property def expires_on(self): """The timestamp for when the table will expire, or None if unknown.""" timestamp = self._info.get('expirationTime', None) if timestamp is None: return None return _parser.Parser.parse_timestamp(timestamp) @property def friendly_name(self): """The friendly name of the table if it exists.""" return self._info.get('friendlyName', '') @property def modified_on(self): """The timestamp for when the table was last modified.""" timestamp = self._info.get('lastModifiedTime') return _parser.Parser.parse_timestamp(timestamp) @property def rows(self): """The number of rows within the table, or -1 if unknown. """ return int(self._info['numRows']) if 'numRows' in self._info else -1 @property def size(self): """The size of the table in bytes, or -1 if unknown. """ return int(self._info['numBytes']) if 'numBytes' in self._info else -1 def refresh(self): """ Refresh the metadata. """ self._info = self._table._load_info() class Table(object): """Represents a Table object referencing a BigQuery table. 
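  Example (a sketch; 'mydataset.mytable' is a hypothetical table and a default Context with
  valid credentials is assumed):
    t = Table('mydataset.mytable')
    if not t.exists():
      t.create([{'name': 'id', 'type': 'INTEGER'}, {'name': 'name', 'type': 'STRING'}])
    t.insert([{'id': 1, 'name': 'alice'}])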
""" # Allowed characters in a BigQuery table column name _VALID_COLUMN_NAME_CHARACTERS = '_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' # When fetching table contents for a range or iteration, use a small page size per request _DEFAULT_PAGE_SIZE = 1024 # When fetching the entire table, use the maximum number of rows. The BigQuery service # will always return fewer rows than this if their encoded JSON size is larger than 10MB _MAX_PAGE_SIZE = 100000 # Milliseconds per week _MSEC_PER_WEEK = 7 * 24 * 3600 * 1000 def __init__(self, name, context=None): """Initializes an instance of a Table object. The Table need not exist yet. Args: name: the name of the table either as a string or a 3-part tuple (projectid, datasetid, name). If a string, it must have the form '..' or '.
'. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._name_parts = _utils.parse_table_name(name, self._api.project_id) self._full_name = '%s.%s.%s%s' % self._name_parts self._info = None self._cached_page = None self._cached_page_index = 0 self._schema = None @property def name(self): """The TableName named tuple (project_id, dataset_id, table_id, decorator) for the table.""" return self._name_parts @property def full_name(self): """The full name of the table in the form of project.dataset.table.""" return self._full_name @property def job(self): """ For tables resulting from executing queries, the job that created the table. Default is None for a Table object; this is overridden by QueryResultsTable. """ return None @property def is_temporary(self): """ Whether this is a short-lived table or not. Always False for non-QueryResultsTables. """ return False def _load_info(self): """Loads metadata about this table.""" if self._info is None: try: self._info = self._api.tables_get(self._name_parts) except Exception as e: raise e @property def metadata(self): """Retrieves metadata about the table. Returns: A TableMetadata object. Raises Exception if the request could not be executed or the response was malformed. """ self._load_info() return TableMetadata(self, self._info) def exists(self): """Checks if the table exists. Returns: True if the table exists; False otherwise. Raises: Exception if there was an error requesting information about the table. """ try: info = self._api.tables_get(self._name_parts) except google.datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e self._info = info return True def is_listable(self): """ Determine if the table can be listed. Returns: True is the Table can be listed; False otherwise. """ self._load_info() return 'type' not in self._info or 'MODEL' != self._info['type'] def delete(self): """ Delete the table. Returns: True if the Table no longer exists; False otherwise. """ try: self._api.table_delete(self._name_parts) except google.datalab.utils.RequestException: # TODO(gram): May want to check the error reasons here and if it is not # because the file didn't exist, return an error. pass except Exception as e: raise e return not self.exists() def create(self, schema, overwrite=False): """ Create the table with the specified schema. Args: schema: the schema to use to create the table. Should be a list of dictionaries, each containing at least a pair of entries, 'name' and 'type'. See https://cloud.google.com/bigquery/docs/reference/v2/tables#resource overwrite: if True, delete the table first if it exists. If False and the table exists, creation will fail and raise an Exception. Returns: The Table instance. Raises: Exception if the table couldn't be created or already exists and truncate was False. 
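      Example (illustrative; 'mydataset.words' is a hypothetical table name):
        schema = [{'name': 'word', 'type': 'STRING'},
                  {'name': 'count', 'type': 'INTEGER', 'mode': 'REQUIRED'}]
        Table('mydataset.words').create(schema, overwrite=True)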
""" if overwrite and self.exists(): self.delete() if not isinstance(schema, _schema.Schema): # Convert to a Schema object schema = _schema.Schema(schema) try: response = self._api.tables_insert(self._name_parts, schema=schema._bq_schema) except Exception as e: raise e if 'selfLink' in response: self._schema = schema return self raise Exception("Table %s could not be created as it already exists" % self._full_name) @staticmethod def _encode_dict_as_row(record, column_name_map): """ Encode a dictionary representing a table row in a form suitable for streaming to BQ. This includes encoding timestamps as ISO-compatible strings and removing invalid characters from column names. Args: record: a Python dictionary representing the table row. column_name_map: a dictionary mapping dictionary keys to column names. This is initially empty and built up by this method when it first encounters each column, then used as a cache subsequently. Returns: The sanitized dictionary. """ for k in list(record.keys()): v = record[k] # If the column is a date, convert to ISO string. if isinstance(v, (pandas.Timestamp, datetime.datetime, datetime.date, datetime.time)): v = record[k] = record[k].isoformat() # If k has invalid characters clean it up if k not in column_name_map: column_name_map[k] = ''.join(c for c in k if c in Table._VALID_COLUMN_NAME_CHARACTERS) new_k = column_name_map[k] if k != new_k: record[new_k] = v del record[k] return record def insert(self, data, include_index=False, index_name=None): """ Insert the contents of a Pandas DataFrame or a list of dictionaries into the table. The insertion will be performed using at most 500 rows per POST, and at most 10 POSTs per second, as BigQuery has some limits on streaming rates. Args: data: the DataFrame or list to insert. include_index: whether to include the DataFrame or list index as a column in the BQ table. index_name: for a list, if include_index is True, this should be the name for the index. If not specified, 'Index' will be used. Returns: The table. Raises: Exception if the table doesn't exist, the table's schema differs from the data's schema, or the insert failed. """ # TODO(gram): we could create the Table here is it doesn't exist using a schema derived # from the data. IIRC we decided not to but doing so seems less unwieldy that having to # create it first and then validate the schema against it itself. # There are BigQuery limits on the streaming API: # # max_rows_per_post = 500 # max_bytes_per_row = 20000 # max_rows_per_second = 10000 # max_bytes_per_post = 1000000 # max_bytes_per_second = 10000000 # # It is non-trivial to enforce these here, and the max bytes per row is not something we # can really control. As an approximation we enforce the 500 row limit # with a 0.05 sec POST interval (to enforce the 10,000 rows per sec limit). max_rows_per_post = 500 post_interval = 0.05 # TODO(gram): add different exception types for each failure case. if not self.exists(): raise Exception('Table %s does not exist.' % self._full_name) data_schema = _schema.Schema.from_data(data) if isinstance(data, list): if include_index: if not index_name: index_name = 'Index' data_schema._add_field(index_name, 'INTEGER') table_schema = self.schema # Do some validation of the two schema to make sure they are compatible. 
for data_field in data_schema: name = data_field.name table_field = table_schema[name] if table_field is None: raise Exception('Table does not contain field %s' % name) data_type = data_field.type table_type = table_field.type if table_type != data_type: raise Exception('Field %s in data has type %s but in table has type %s' % (name, data_type, table_type)) total_rows = len(data) total_pushed = 0 job_id = uuid.uuid4().hex rows = [] column_name_map = {} is_dataframe = isinstance(data, pandas.DataFrame) if is_dataframe: # reset_index creates a new dataframe so we don't affect the original. reset_index(drop=True) # drops the original index and uses an integer range. gen = data.reset_index(drop=not include_index).iterrows() else: gen = enumerate(data) for index, row in gen: if is_dataframe: row = row.to_dict() elif include_index: row[index_name] = index rows.append({ 'json': self._encode_dict_as_row(row, column_name_map), 'insertId': job_id + str(index) }) total_pushed += 1 if (total_pushed == total_rows) or (len(rows) == max_rows_per_post): try: response = self._api.tabledata_insert_all(self._name_parts, rows) except Exception as e: raise e if 'insertErrors' in response: raise Exception('insertAll failed: %s' % response['insertErrors']) time.sleep(post_interval) # Streaming API is rate-limited rows = [] # Block until data is ready while True: self._info = self._api.tables_get(self._name_parts) if 'streamingBuffer' not in self._info or \ 'estimatedRows' not in self._info['streamingBuffer'] or \ int(self._info['streamingBuffer']['estimatedRows']) > 0: break time.sleep(2) return self def _init_job_from_response(self, response): """ Helper function to create a Job instance from a response. """ job = None if response and 'jobReference' in response: job = _job.Job(job_id=response['jobReference']['jobId'], context=self._context) return job def extract_async(self, destination, format='csv', csv_delimiter=None, csv_header=True, compress=False): """Starts a job to export the table to GCS. Args: destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. compress: whether to compress the data on export. Compression is not supported for AVRO format. Defaults to False. Returns: A Job object for the export Job if it was started successfully; else None. """ format = format.upper() if format == 'JSON': format = 'NEWLINE_DELIMITED_JSON' if format == 'CSV' and csv_delimiter is None: csv_delimiter = ',' try: response = self._api.table_extract(self._name_parts, destination, format, compress, csv_delimiter, csv_header) return self._init_job_from_response(response) except Exception as e: raise google.datalab.JobError(location=traceback.format_exc(), message=str(e), reason=str(type(e))) def extract(self, destination, format='csv', csv_delimiter=None, csv_header=True, compress=False): """Exports the table to GCS; blocks until complete. Args: destination: the destination URI(s). Can be a single URI or a list. format: the format to use for the exported data; one of 'csv', 'json', or 'avro' (default 'csv'). csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. compress: whether to compress the data on export. 
Compression is not supported for AVRO format. Defaults to False. Returns: A Job object for the completed export Job if it was started successfully; else None. """ job = self.extract_async(destination, format=format, csv_delimiter=csv_delimiter, csv_header=csv_header, compress=compress) if job is not None: job.wait() return job def load_async(self, source, mode='create', source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0): """ Starts importing a table from GCS and return a Future. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. mode: one of 'create', 'append', or 'overwrite'. 'append' or 'overwrite' will fail if the table does not already exist, while 'create' will fail if it does. The default is 'create'. If 'create' the schema will be inferred if necessary. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: if source format is 'csv', additional options as a CSVOptions object. ignore_unknown_values: If True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: the maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). Returns: A Job object for the import if it was started successfully or None if not. Raises: Exception if the load job failed to be started or invalid arguments were supplied. """ if source_format == 'csv': source_format = 'CSV' elif source_format == 'json': source_format = 'NEWLINE_DELIMITED_JSON' else: raise Exception("Invalid source format %s" % source_format) if not(mode == 'create' or mode == 'append' or mode == 'overwrite'): raise Exception("Invalid mode %s" % mode) if csv_options is None: csv_options = _csv_options.CSVOptions() try: response = self._api.jobs_insert_load(source, self._name_parts, append=(mode == 'append'), overwrite=(mode == 'overwrite'), create=(mode == 'create'), source_format=source_format, field_delimiter=csv_options.delimiter, allow_jagged_rows=csv_options.allow_jagged_rows, allow_quoted_newlines=csv_options.allow_quoted_newlines, encoding=csv_options.encoding.upper(), ignore_unknown_values=ignore_unknown_values, max_bad_records=max_bad_records, quote=csv_options.quote, skip_leading_rows=csv_options.skip_leading_rows) except Exception as e: raise e return self._init_job_from_response(response) def load(self, source, mode='create', source_format='csv', csv_options=None, ignore_unknown_values=False, max_bad_records=0): """ Load the table from GCS. Args: source: the URL of the source objects(s). Can include a wildcard '*' at the end of the item name. Can be a single source or a list. mode: one of 'create', 'append', or 'overwrite'. 'append' or 'overwrite' will fail if the table does not already exist, while 'create' will fail if it does. The default is 'create'. If 'create' the schema will be inferred if necessary. source_format: the format of the data, 'csv' or 'json'; default 'csv'. csv_options: if source format is 'csv', additional options as a CSVOptions object. ignore_unknown_values: if True, accept rows that contain values that do not match the schema; the unknown values are ignored (default False). max_bad_records: the maximum number of bad records that are allowed (and ignored) before returning an 'invalid' error in the Job result (default 0). Returns: A Job object for the completed load Job if it was started successfully; else None. 
""" job = self.load_async(source, mode=mode, source_format=source_format, csv_options=csv_options, ignore_unknown_values=ignore_unknown_values, max_bad_records=max_bad_records) if job is not None: job.wait() return job def _get_row_fetcher(self, start_row=0, max_rows=None, page_size=_DEFAULT_PAGE_SIZE): """ Get a function that can retrieve a page of rows. The function returned is a closure so that it can have a signature suitable for use by Iterator. Args: start_row: the row to start fetching from; default 0. max_rows: the maximum number of rows to fetch (across all calls, not per-call). Default is None which means no limit. page_size: the maximum number of results to fetch per page; default 1024. Returns: A function that can be called repeatedly with a page token and running count, and that will return an array of rows and a next page token; when the returned page token is None the fetch is complete. """ if not start_row: start_row = 0 elif start_row < 0: # We are measuring from the table end if self.length >= 0: start_row += self.length else: raise Exception('Cannot use negative indices for table of unknown length') schema = self.schema._bq_schema name_parts = self._name_parts def _retrieve_rows(page_token, count): page_rows = [] if max_rows and count >= max_rows: page_token = None else: if max_rows and page_size > (max_rows - count): max_results = max_rows - count else: max_results = page_size try: if page_token: response = self._api.tabledata_list(name_parts, page_token=page_token, max_results=max_results) else: response = self._api.tabledata_list(name_parts, start_index=start_row, max_results=max_results) except Exception as e: raise e page_token = response['pageToken'] if 'pageToken' in response else None if 'rows' in response: page_rows = response['rows'] rows = [] for row_dict in page_rows: rows.append(_parser.Parser.parse_row(schema, row_dict)) return rows, page_token return _retrieve_rows def range(self, start_row=0, max_rows=None): """ Get an iterator to iterate through a set of table rows. Args: start_row: the row of the table at which to start the iteration (default 0) max_rows: an upper limit on the number of rows to iterate through (default None) Returns: A row iterator. """ fetcher = self._get_row_fetcher(start_row=start_row, max_rows=max_rows) return iter(google.datalab.utils.Iterator(fetcher)) def to_dataframe(self, start_row=0, max_rows=None): """ Exports the table to a Pandas dataframe. Args: start_row: the row of the table at which to start the export (default 0) max_rows: an upper limit on the number of rows to export (default None) Returns: A Pandas dataframe containing the table data. """ fetcher = self._get_row_fetcher(start_row=start_row, max_rows=max_rows, page_size=self._MAX_PAGE_SIZE) count = 0 page_token = None # Collect results of page fetcher in separate dataframe objects, then # concatenate them to reduce the amount of copying df_list = [] df = None while True: page_rows, page_token = fetcher(page_token, count) if len(page_rows): count += len(page_rows) df_list.append(pandas.DataFrame.from_records(page_rows)) if not page_token: break if df_list: df = pandas.concat(df_list, ignore_index=True, copy=False) # Need to reorder the dataframe to preserve column ordering ordered_fields = [field.name for field in self.schema] return df[ordered_fields] if df is not None else pandas.DataFrame() def to_file(self, destination, format='csv', csv_delimiter=',', csv_header=True): """Save the results to a local file in CSV format. 
Args: destination: path on the local filesystem for the saved results. format: the format to use for the exported data; currently only 'csv' is supported. csv_delimiter: for CSV exports, the field delimiter to use. Defaults to ',' csv_header: for CSV exports, whether to include an initial header line. Default true. Raises: An Exception if the operation failed. """ f = codecs.open(destination, 'w', 'utf-8') fieldnames = [] for column in self.schema: fieldnames.append(column.name) if sys.version_info[0] == 2: csv_delimiter = csv_delimiter.encode('unicode_escape') writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter=csv_delimiter) if csv_header: writer.writeheader() for row in self: writer.writerow(row) f.close() @property def schema(self): """Retrieves the schema of the table. Returns: A Schema object containing a list of schema fields and associated metadata. Raises Exception if the request could not be executed or the response was malformed. """ if not self._schema: try: self._load_info() self._schema = _schema.Schema(self._info['schema']['fields']) except KeyError: raise Exception('Unexpected table response: missing schema') return self._schema def update(self, friendly_name=None, description=None, expiry=None, schema=None): """ Selectively updates Table information. Any parameters that are omitted or None are not updated. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. expiry: if not None, the new expiry time, either as a DateTime or milliseconds since epoch. schema: if not None, the new schema: either a list of dictionaries or a Schema. """ self._load_info() if friendly_name is not None: self._info['friendlyName'] = friendly_name if description is not None: self._info['description'] = description if expiry is not None: if isinstance(expiry, datetime.datetime): expiry = calendar.timegm(expiry.utctimetuple()) * 1000 self._info['expirationTime'] = expiry if schema is not None: if isinstance(schema, _schema.Schema): schema = schema._bq_schema self._info['schema'] = {'fields': schema} try: self._api.table_update(self._name_parts, self._info) except google.datalab.utils.RequestException: # The cached metadata is out of sync now; abandon it. self._info = None except Exception as e: raise e def _repr_sql_(self): """Returns a representation of the table for embedding into a SQL statement. Returns: A formatted table name for use within SQL statements. """ return '`' + self._full_name + '`' def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'BigQuery Table - name: %s' % self._full_name @property def length(self): """ Get the length of the table (number of rows). We don't use __len__ as this may return -1 for 'unknown'. """ return self.metadata.rows def __iter__(self): """ Get an iterator for the table. """ return self.range(start_row=0) def __getitem__(self, item): """ Get an item or a slice of items from the table. This uses a small cache to reduce the number of calls to tabledata.list. Note: this is a useful function to have, and supports some current usage like query.execute().result()[0], but should be used with care. """ if isinstance(item, slice): # Just treat this as a set of calls to __getitem__(int) result = [] i = item.start step = item.step if item.step else 1 while i < item.stop: result.append(self[i]) i += step return result # Handle the integer index case. 
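# A negative index counts back from the end of the table, so the total row count must be known.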
if item < 0: if self.length >= 0: item += self.length else: raise Exception('Cannot use negative indices for table of unknown length') if not self._cached_page \ or self._cached_page_index > item \ or self._cached_page_index + len(self._cached_page) <= item: # cache a new page. To get the start row we round to the nearest multiple of the page # size. first = old_div(item, self._DEFAULT_PAGE_SIZE) * self._DEFAULT_PAGE_SIZE count = self._DEFAULT_PAGE_SIZE if self.length >= 0: remaining = self.length - first if count > remaining: count = remaining fetcher = self._get_row_fetcher(start_row=first, max_rows=count, page_size=count) self._cached_page_index = first self._cached_page, _ = fetcher(None, 0) return self._cached_page[item - self._cached_page_index] @staticmethod def _convert_decorator_time(when): if isinstance(when, datetime.datetime): value = 1000 * (when - datetime.datetime.utcfromtimestamp(0)).total_seconds() elif isinstance(when, datetime.timedelta): value = when.total_seconds() * 1000 if value > 0: raise Exception("Invalid snapshot relative when argument: %s" % str(when)) else: raise Exception("Invalid snapshot when argument type: %s" % str(when)) if value < -Table._MSEC_PER_WEEK: raise Exception("Invalid snapshot relative when argument: must be within 7 days: %s" % str(when)) if value > 0: now = 1000 * (datetime.datetime.utcnow() - datetime.datetime.utcfromtimestamp(0)).total_seconds() # Check that an abs value is not more than 7 days in the past and is # not in the future if not ((now - Table._MSEC_PER_WEEK) < value < now): raise Exception("Invalid snapshot absolute when argument: %s" % str(when)) return int(value) def snapshot(self, at): """ Return a new Table which is a snapshot of this table at the specified time. Args: at: the time of the snapshot. This can be a Python datetime (absolute) or timedelta (relative to current time). The result must be after the table was created and no more than seven days in the past. Passing None will get a reference the oldest snapshot. Note that using a datetime will get a snapshot at an absolute point in time, while a timedelta will provide a varying snapshot; any queries issued against such a Table will be done against a snapshot that has an age relative to the execution time of the query. Returns: A new Table object referencing the snapshot. Raises: An exception if this Table is already decorated, or if the time specified is invalid. """ if self._name_parts.decorator != '': raise Exception("Cannot use snapshot() on an already decorated table") value = Table._convert_decorator_time(at) return Table("%s@%s" % (self._full_name, str(value)), context=self._context) def window(self, begin, end=None): """ Return a new Table limited to the rows added to this Table during the specified time range. Args: begin: the start time of the window. This can be a Python datetime (absolute) or timedelta (relative to current time). The result must be after the table was created and no more than seven days in the past. Note that using a relative value will provide a varying snapshot, not a fixed snapshot; any queries issued against such a Table will be done against a snapshot that has an age relative to the execution time of the query. end: the end time of the snapshot; if None, then the current time is used. The types and interpretation of values is as for start. Returns: A new Table object referencing the window. Raises: An exception if this Table is already decorated, or if the time specified is invalid. 
""" if self._name_parts.decorator != '': raise Exception("Cannot use window() on an already decorated table") start = Table._convert_decorator_time(begin) if end is None: if isinstance(begin, datetime.timedelta): end = datetime.timedelta(0) else: end = datetime.datetime.utcnow() stop = Table._convert_decorator_time(end) # Both values must have the same sign if (start > 0 >= stop) or (stop > 0 >= start): raise Exception("window: Between arguments must both be absolute or relative: %s, %s" % (str(begin), str(end))) # start must be less than stop if start > stop: raise Exception("window: Between arguments: begin must be before end: %s, %s" % (str(begin), str(end))) return Table("%s@%s-%s" % (self._full_name, str(start), str(stop)), context=self._context) ================================================ FILE: google/datalab/bigquery/_udf.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery UDF Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from past.builtins import basestring from builtins import object class UDF(object): """Represents a BigQuery UDF declaration. """ @property def name(self): return self._name @property def imports(self): return self._imports @property def code(self): return self._code def __init__(self, name, code, return_type, params=None, language='js', imports=None): """Initializes a UDF object from its pieces. Args: name: the name of the javascript function code: function body implementing the logic. return_type: BigQuery data type of the function return. See supported data types in the BigQuery docs params: list of parameter tuples: (name, type) language: see list of supported languages in the BigQuery docs imports: a list of GCS paths containing further support code. """ if not isinstance(return_type, basestring): raise TypeError('Argument return_type should be a string. 
Instead got: ', type(return_type)) if params and not isinstance(params, list): raise TypeError('Argument params should be a list of parameter names and types') if imports and not isinstance(imports, list): raise TypeError('Argument imports should be a list of GCS string paths') if imports and language != 'js': raise Exception('Imports are available for Javascript UDFs only') self._name = name self._code = code self._return_type = return_type self._params = params or [] self._language = language self._imports = imports or [] self._sql = None def _expanded_sql(self): """Get the expanded BigQuery SQL string of this UDF Returns The expanded SQL string of this UDF """ if not self._sql: self._sql = UDF._build_udf(self._name, self._code, self._return_type, self._params, self._language, self._imports) return self._sql def _repr_sql_(self): return self._expanded_sql() def __repr__(self): return 'BigQuery UDF - code:\n%s' % self._code @staticmethod def _build_udf(name, code, return_type, params, language, imports): """Creates the UDF part of a BigQuery query using its pieces Args: name: the name of the javascript function code: function body implementing the logic. return_type: BigQuery data type of the function return. See supported data types in the BigQuery docs params: dictionary of parameter names and types language: see list of supported languages in the BigQuery docs imports: a list of GCS paths containing further support code. """ params = ','.join(['%s %s' % named_param for named_param in params]) imports = ','.join(['library="%s"' % i for i in imports]) if language.lower() == 'sql': udf = 'CREATE TEMPORARY FUNCTION {name} ({params})\n' + \ 'RETURNS {return_type}\n' + \ 'AS (\n' + \ '{code}\n' + \ ');' else: udf = 'CREATE TEMPORARY FUNCTION {name} ({params})\n' +\ 'RETURNS {return_type}\n' + \ 'LANGUAGE {language}\n' + \ 'AS """\n' +\ '{code}\n' +\ '"""\n' +\ 'OPTIONS (\n' +\ '{imports}\n' +\ ');' return udf.format(name=name, params=params, return_type=return_type, language=language, code=code, imports=imports) ================================================ FILE: google/datalab/bigquery/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Useful common utility functions.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring import collections import re DatasetName = collections.namedtuple('DatasetName', ['project_id', 'dataset_id']) """ A namedtuple for Dataset names. Args: project_id: the project id for the dataset. dataset_id: the dataset id for the dataset. """ TableName = collections.namedtuple('TableName', ['project_id', 'dataset_id', 'table_id', 'decorator']) """ A namedtuple for Table names. Args: project_id: the project id for the table. dataset_id: the dataset id for the table. table_id: the table id for the table. decorator: the optional decorator for the table (for windowing/snapshot-ing). 
""" # Absolute project-qualified name pattern: . _ABS_DATASET_NAME_PATTERN = r'^([a-z\d\-_\.:]+)\.(\w+)$' # Relative name pattern: _REL_DATASET_NAME_PATTERN = r'^(\w+)$' # Absolute project-qualified name pattern: ..
_ABS_TABLE_NAME_PATTERN = r'^([a-z\d\-_\.:]+)\.(\w+)\.(\w+)(@[\d\-]+)?$' # Relative name pattern: .
_REL_TABLE_NAME_PATTERN = r'^(\w+)\.(\w+)(@[\d\-]+)?$' # Table-only name pattern:
. Includes an optional decorator. _TABLE_NAME_PATTERN = r'^(\w+)(@[\d\-]+)$' def parse_dataset_name(name, project_id=None): """Parses a dataset name into its individual parts. Args: name: the name to parse, or a tuple, dictionary or array containing the parts. project_id: the expected project ID. If the name does not contain a project ID, this will be used; if the name does contain a project ID and it does not match this, an exception will be thrown. Returns: A DatasetName named tuple for the dataset. Raises: Exception: raised if the name doesn't match the expected formats or a project_id was specified that does not match that in the name. """ _project_id = _dataset_id = None if isinstance(name, basestring): # Try to parse as absolute name first. m = re.match(_ABS_DATASET_NAME_PATTERN, name, re.IGNORECASE) if m is not None: _project_id, _dataset_id = m.groups() else: # Next try to match as a relative name implicitly scoped within current project. m = re.match(_REL_DATASET_NAME_PATTERN, name) if m is not None: groups = m.groups() _dataset_id = groups[0] elif isinstance(name, dict): try: _dataset_id = name['dataset_id'] _project_id = name['project_id'] except KeyError: pass else: # Try treat as an array or tuple if len(name) == 2: # Treat as a tuple or array. _project_id, _dataset_id = name elif len(name) == 1: _dataset_id = name[0] if not _dataset_id: raise Exception('Invalid dataset name: ' + str(name)) if not _project_id: _project_id = project_id return DatasetName(_project_id, _dataset_id) def parse_table_name(name, project_id=None, dataset_id=None): """Parses a table name into its individual parts. Args: name: the name to parse, or a tuple, dictionary or array containing the parts. project_id: the expected project ID. If the name does not contain a project ID, this will be used; if the name does contain a project ID and it does not match this, an exception will be thrown. dataset_id: the expected dataset ID. If the name does not contain a dataset ID, this will be used; if the name does contain a dataset ID and it does not match this, an exception will be thrown. Returns: A TableName named tuple consisting of the full name and individual name parts. Raises: Exception: raised if the name doesn't match the expected formats, or a project_id and/or dataset_id was provided that does not match that in the name. """ _project_id = _dataset_id = _table_id = _decorator = None if isinstance(name, basestring): # Try to parse as absolute name first. m = re.match(_ABS_TABLE_NAME_PATTERN, name, re.IGNORECASE) if m is not None: _project_id, _dataset_id, _table_id, _decorator = m.groups() else: # Next try to match as a relative name implicitly scoped within current project. m = re.match(_REL_TABLE_NAME_PATTERN, name) if m is not None: groups = m.groups() _project_id, _dataset_id, _table_id, _decorator =\ project_id, groups[0], groups[1], groups[2] else: # Finally try to match as a table name only. 
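# A bare table id has no project or dataset of its own; it is scoped to the supplied
# project_id and dataset_id arguments.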
m = re.match(_TABLE_NAME_PATTERN, name) if m is not None: groups = m.groups() _project_id, _dataset_id, _table_id, _decorator =\ project_id, dataset_id, groups[0], groups[1] elif isinstance(name, dict): try: _table_id = name['table_id'] _dataset_id = name['dataset_id'] _project_id = name['project_id'] except KeyError: pass else: # Try treat as an array or tuple if len(name) == 4: _project_id, _dataset_id, _table_id, _decorator = name elif len(name) == 3: _project_id, _dataset_id, _table_id = name elif len(name) == 2: _dataset_id, _table_id = name if not _table_id: raise Exception('Invalid table name: ' + str(name)) if not _project_id: _project_id = project_id if not _dataset_id: _dataset_id = dataset_id if not _decorator: _decorator = '' return TableName(_project_id, _dataset_id, _table_id, _decorator) def format_query_errors(errors): return '\n'.join(['%s: %s' % (error['reason'], error['message']) for error in errors]) ================================================ FILE: google/datalab/bigquery/_view.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements BigQuery Views.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import google.datalab from . import _query from . import _table # Query import is at end to avoid issues with circular dependencies. class View(object): """ An implementation of a BigQuery View. """ # Views in BigQuery are virtual tables, but it is useful to have a mixture of both Table and # Query semantics; our version thus internally has a BaseTable and a Query (for materialization; # not the same as the view query), and exposes a number of the same APIs as Table and Query # through wrapper functions around these. def __init__(self, name, context=None): """Initializes an instance of a View object. Args: name: the name of the view either as a string or a 3-part tuple (projectid, datasetid, name). If a string, it must have the form '..' or '.'. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. Raises: Exception if the name is invalid. 
""" if context is None: context = google.datalab.Context.default() self._context = context self._table = _table.Table(name, context=context) self._materialization = _query.Query('SELECT * FROM %s' % self._repr_sql_()) @property def name(self): """The name for the view as a named tuple.""" return self._table.name @property def description(self): """The description of the view if it exists.""" return self._table.metadata.description @property def friendly_name(self): """The friendly name of the view if it exists.""" return self._table.metadata.friendly_name @property def query(self): """The Query that defines the view.""" if not self.exists(): return None self._table._load_info() if 'view' in self._table._info and 'query' in self._table._info['view']: return _query.Query(self._table._info['view']['query']) return None def exists(self): """Whether the view's Query has been executed and the view is available or not.""" return self._table.exists() def delete(self): """Removes the view if it exists.""" self._table.delete() def create(self, query): """ Creates the view with the specified query. Args: query: the query to use to for the View; either a string containing a SQL query or a Query object. Returns: The View instance. Raises: Exception if the view couldn't be created or already exists and overwrite was False. """ if isinstance(query, _query.Query): query = query.sql try: response = self._table._api.tables_insert(self._table.name, query=query) except Exception as e: raise e if 'selfLink' in response: return self raise Exception("View %s could not be created as it already exists" % str(self)) @property def schema(self): """Retrieves the schema of the table. Returns: A Schema object containing a list of schema fields and associated metadata. Raises Exception if the request could not be executed or the response was malformed. """ return self._table.schema def update(self, friendly_name=None, description=None, query=None): """ Selectively updates View information. Any parameters that are None (the default) are not applied in the update. Args: friendly_name: if not None, the new friendly name. description: if not None, the new description. query: if not None, a new query string for the View. """ self._table._load_info() if query is not None: if isinstance(query, _query.Query): query = query.sql self._table._info['view'] = {'query': query} self._table.update(friendly_name=friendly_name, description=description) def _repr_sql_(self): """Returns a representation of the view for embedding into a SQL statement. Returns: A formatted table name for use within SQL statements. """ return self._table._repr_sql_() def __repr__(self): """Returns a representation for the view for showing in the notebook. """ return 'BigQuery View - table: %s, sql: %s' % (self._table, self.query) ================================================ FILE: google/datalab/bigquery/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import from . import _bigquery __all__ = ['_bigquery'] ================================================ FILE: google/datalab/bigquery/commands/_bigquery.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery IPython Functionality.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import datetime import jsonschema import fnmatch import json import re import google.datalab.bigquery as bigquery import google.datalab.data import google.datalab.utils import google.datalab.utils.commands from google.datalab.bigquery._query_output import QueryOutput from google.datalab.bigquery._sampling import Sampling class BigQuerySchema(object): """A container class for commonly used BQ-related constants.""" DATATYPES = ['STRING', 'BYTES', 'INTEGER', 'INT64', 'FLOAT', 'FLOAT64', 'BOOLEAN', 'BOOL', 'TIMESTAMP', 'DATE', 'TIME', 'DATETIME', 'RECORD'] DATATYPES_LOWER = [t.lower() for t in DATATYPES] MODES = ['NULLABLE', 'REQUIRED', 'REPEATED'] MODES_LOWER = [m.lower() for m in MODES] TABLE_SCHEMA_SCHEMA = { 'definitions': { 'field': { 'title': 'field', 'type': 'object', 'properties': { 'name': {'type': 'string'}, 'type': {'type': 'string', 'enum': DATATYPES + DATATYPES_LOWER}, 'mode': {'type': 'string', 'enum': MODES + MODES_LOWER}, 'description': {'type': 'string'}, 'fields': { 'type': 'array', 'items': { 'allOf': [{'$ref': '#/definitions/field'}] } } }, 'required': ['name', 'type'], 'additionalProperties': False } }, 'type': 'object', 'properties': { 'schema': { 'type': 'array', 'items': { 'allOf': [{'$ref': '#/definitions/field'}] } } }, 'required': ['schema'], 'additionalProperties': False } QUERY_PARAMS_SCHEMA = { 'type': 'object', 'properties': { 'parameters': { 'type': 'array', 'items': [ { 'type': 'object', 'properties': { 'name': {'type': 'string'}, 'type': {'type': 'string', 'enum': DATATYPES + DATATYPES_LOWER}, 'value': {'type': ['string', 'integer', 'number']} }, 'required': ['name', 'type', 'value'], 'additionalProperties': False } ] } }, 'required': ['parameters'], 'additionalProperties': False } def _create_dataset_subparser(parser): dataset_parser = parser.subcommand('datasets', 'Operations on BigQuery datasets') sub_commands = dataset_parser.add_subparsers(dest='command') # %%bq datasets list list_parser = sub_commands.add_parser('list', help='List datasets') list_parser.add_argument('-p', '--project', help='The project whose datasets should be listed') list_parser.add_argument('-f', '--filter', help='Optional wildcard filter string used to limit the results') # %%bq datasets create create_parser = sub_commands.add_parser('create', help='Create a dataset.') 
create_parser.add_argument('-n', '--name', help='The name of the dataset to create.', required=True) create_parser.add_argument('-f', '--friendly', help='The friendly name of the dataset.') # %%bq datasets delete delete_dataset_parser = sub_commands.add_parser('delete', help='Delete a dataset.') delete_dataset_parser.add_argument('-n', '--name', help='The name of the dataset to delete.', required=True) return dataset_parser def _create_table_subparser(parser): table_parser = parser.subcommand('tables', 'Operations on BigQuery tables') sub_commands = table_parser.add_subparsers(dest='command') # %%bq tables list list_parser = sub_commands.add_parser('list', help='List the tables in a BigQuery project or dataset.') list_parser.add_argument('-p', '--project', help='The project whose tables should be listed') list_parser.add_argument('-d', '--dataset', help='The dataset to restrict to') list_parser.add_argument('-f', '--filter', help='Optional wildcard filter string used to limit the results') # %%bq tables create create_parser = sub_commands.add_parser('create', help='Create a table.') create_parser.add_argument('-n', '--name', help='The name of the table to create.', required=True) create_parser.add_argument('-o', '--overwrite', help='Overwrite table if it exists.', action='store_true') # %%bq tables describe describe_parser = sub_commands.add_parser('describe', help='View a table\'s schema') describe_parser.add_argument('-n', '--name', help='Name of table to show', required=True) # %%bq tables delete delete_parser = sub_commands.add_parser('delete', help='Delete a table.') delete_parser.add_argument('-n', '--name', help='The name of the table to delete.', required=True) # %%bq tables view delete_parser = sub_commands.add_parser('view', help='View a table.') delete_parser.add_argument('-n', '--name', help='The name of the table to view.', required=True) return table_parser def _create_sample_subparser(parser): sample_parser = parser.subcommand('sample', help='Display a sample of the results of a BigQuery SQL query. 
' 'The cell can optionally contain arguments for expanding ' 'variables in the query, if -q/--query was used, or it ' 'can contain SQL for a query.') group = sample_parser.add_mutually_exclusive_group() group.add_argument('-q', '--query', help='the name of the query object to sample') group.add_argument('-t', '--table', help='the name of the table object to sample') group.add_argument('-v', '--view', help='the name of the view object to sample') sample_parser.add_argument('-nc', '--nocache', help='Don\'t use previously cached results', action='store_true') sample_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') sample_parser.add_argument('-m', '--method', help='The type of sampling to use', choices=['limit', 'random', 'hashed', 'sorted'], default='limit') sample_parser.add_argument('--fields', help='Comma separated field names for projection') sample_parser.add_argument('-c', '--count', type=int, default=10, help='The number of rows to limit to, if sampling') sample_parser.add_argument('-p', '--percent', type=int, default=1, help='For random or hashed sampling, what percentage to sample from') sample_parser.add_argument('--key-field', help='The field to use for sorted or hashed sampling') sample_parser.add_argument('-o', '--order', choices=['ascending', 'descending'], default='ascending', help='The sort order to use for sorted sampling') sample_parser.add_argument('-P', '--profile', action='store_true', default=False, help='Generate an interactive profile of the data') sample_parser.add_argument('--verbose', help='Show the expanded SQL that is being executed', action='store_true') return sample_parser def _create_udf_subparser(parser): udf_parser = parser.subcommand('udf', 'Create a named Javascript BigQuery UDF') udf_parser.add_argument('-n', '--name', help='The name for this UDF', required=True) udf_parser.add_argument('-l', '--language', help='The language of the function', required=True, choices=['sql', 'js']) return udf_parser def _create_datasource_subparser(parser): datasource_parser = parser.subcommand('datasource', 'Create a named Javascript BigQuery external data source') datasource_parser.add_argument('-n', '--name', help='The name for this data source', required=True) datasource_parser.add_argument('-p', '--paths', help='URL(s) of the data objects, can include a wildcard "*" at ' 'the end', required=True, nargs='+') datasource_parser.add_argument('-f', '--format', help='The format of the table\'s data. CSV or JSON, default CSV', default='CSV') datasource_parser.add_argument('-c', '--compressed', help='Whether the data is compressed', action='store_true') return datasource_parser def _create_dryrun_subparser(parser): dryrun_parser = parser.subcommand('dryrun', 'Execute a dry run of a BigQuery query and display ' 'approximate usage statistics') dryrun_parser.add_argument('-q', '--query', help='The name of the query to be dry run') dryrun_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') dryrun_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') return dryrun_parser def _create_query_subparser(parser): query_parser = parser.subcommand('query', 'Create or execute a BigQuery SQL query object, ' 'optionally using other SQL objects, UDFs, or external ' 'datasources. 
If a query name is not specified, the ' 'query is executed.') query_parser.add_argument('-n', '--name', help='The name of this SQL query object') query_parser.add_argument('--udfs', help='List of UDFs to reference in the query body', nargs='+') query_parser.add_argument('--datasources', help='List of external datasources to reference in the query body', nargs='+') query_parser.add_argument('--subqueries', help='List of subqueries to reference in the query body', nargs='+') query_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') return query_parser def _create_execute_subparser(parser): execute_parser = parser.subcommand('execute', 'Execute a BigQuery SQL query and optionally send ' 'the results to a named table.\nThe cell can ' 'optionally contain arguments for expanding ' 'variables in the query.') execute_parser.add_argument('-nc', '--nocache', help='Don\'t use previously cached results', action='store_true') execute_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') execute_parser.add_argument('-m', '--mode', help='The table creation mode', default='create', choices=['create', 'append', 'overwrite']) execute_parser.add_argument('-l', '--large', help='Whether to allow large results', action='store_true') execute_parser.add_argument('-q', '--query', help='The name of query to run', required=True) execute_parser.add_argument('-t', '--table', help='Target table name') execute_parser.add_argument('--to-dataframe', help='Convert the result into a dataframe', action='store_true') execute_parser.add_argument('--dataframe-start-row', help='Row of the table to start the ' + 'dataframe export') execute_parser.add_argument('--dataframe-max-rows', help='Upper limit on number of rows ' + 'to export to the dataframe', default=None) execute_parser.add_argument('-v', '--verbose', help='Show the expanded SQL that is being executed', action='store_true') return execute_parser def _create_extract_subparser(parser): extract_parser = parser.subcommand('extract', 'Extract a query or table into file (local or GCS)') extract_parser.add_argument('-nc', '--nocache', help='Don\'t use previously cached results', action='store_true') extract_parser.add_argument('-f', '--format', choices=['csv', 'json'], default='csv', help='The format to use for the export') extract_parser.add_argument('-b', '--billing', type=int, help='BigQuery billing tier') extract_parser.add_argument('-c', '--compress', action='store_true', help='Whether to compress the data') extract_parser.add_argument('-H', '--header', action='store_true', help='Whether to include a header line (CSV only)') extract_parser.add_argument('-D', '--delimiter', default=',', help='The field delimiter to use (CSV only)') group = extract_parser.add_mutually_exclusive_group() group.add_argument('-q', '--query', help='The name of query to extract') group.add_argument('-t', '--table', help='The name of the table to extract') group.add_argument('-v', '--view', help='The name of the view to extract') extract_parser.add_argument('-p', '--path', help='The path of the destination') extract_parser.add_argument('--verbose', help='Show the expanded SQL that is being executed', action='store_true') return extract_parser def _create_load_subparser(parser): load_parser = parser.subcommand('load', 'Load data from GCS into a BigQuery table. 
If creating a ' 'new table, a schema should be specified in YAML or JSON ' 'in the cell body, otherwise the schema is inferred from ' 'existing table.') load_parser.add_argument('-m', '--mode', help='One of create (default), append or overwrite', choices=['create', 'append', 'overwrite'], default='create') load_parser.add_argument('-f', '--format', help='The source format', choices=['json', 'csv'], default='csv') load_parser.add_argument('--skip', help='The number of initial lines to skip; useful for CSV headers', type=int, default=0) load_parser.add_argument('-s', '--strict', help='Whether to reject bad values and jagged lines', action='store_true') load_parser.add_argument('-d', '--delimiter', default=',', help='The inter-field delimiter for CVS (default ,)') load_parser.add_argument('-q', '--quote', default='"', help='The quoted field delimiter for CVS (default ")') load_parser.add_argument('-p', '--path', help='The path URL of the GCS source(s)') load_parser.add_argument('-t', '--table', help='The destination table name') return load_parser def _get_query_argument(args, cell, env): """ Get a query argument to a cell magic. The query is specified with args['query']. We look that up and if it is a BQ query object, just return it. If it is a string, build a query object out of it and return that Args: args: the dictionary of magic arguments. cell: the cell contents which can be variable value overrides (if args has a 'query' value) or inline SQL otherwise. env: a dictionary that is used for looking up variable values. Returns: A Query object. """ sql_arg = args.get('query', None) if sql_arg is None: # Assume we have inline SQL in the cell if not isinstance(cell, basestring): raise Exception('Expected a --query argument or inline SQL') return bigquery.Query(cell, env=env) item = google.datalab.utils.commands.get_notebook_item(sql_arg) if isinstance(item, bigquery.Query): return item else: raise Exception('Expected a query object, got %s.' % type(item)) def get_query_parameters(args, cell_body, date_time=datetime.datetime.now()): """Extract query parameters from cell body if provided Also validates the cell body schema using jsonschema to catch errors before sending the http request. This validation isn't complete, however; it does not validate recursive schemas, but it acts as a good filter against most simple schemas Args: args: arguments passed to the magic cell cell_body: body of the magic cell date_time: The timestamp at which the date-time related parameters need to be resolved. Returns: Validated object containing query parameters """ env = google.datalab.utils.commands.notebook_environment() config = google.datalab.utils.commands.parse_config(cell_body, env=env, as_dict=False) sql = args['query'] if sql is None: raise Exception('Cannot extract query parameters in non-query cell') # Validate query_params if config: jsonschema.validate(config, BigQuerySchema.QUERY_PARAMS_SCHEMA) config = config or {} config_parameters = config.get('parameters', []) return bigquery.Query.get_query_parameters(config_parameters, date_time=date_time) def _sample_cell(args, cell_body): """Implements the BigQuery sample magic for sampling queries The supported sytanx is: %%bq sample [] Args: args: the optional arguments following '%%bq sample'. cell_body: optional contents of the cell Returns: The results of executing the sampling query, or a profile of the sample data. 
""" env = google.datalab.utils.commands.notebook_environment() config = google.datalab.utils.commands.parse_config(cell_body, env, False) or {} parameters = config.get('parameters') or [] if parameters: jsonschema.validate({'parameters': parameters}, BigQuerySchema.QUERY_PARAMS_SCHEMA) query = None table = None view = None query_params = None if args['query']: query = google.datalab.utils.commands.get_notebook_item(args['query']) if query is None: raise Exception('Cannot find query %s.' % args['query']) query_params = get_query_parameters(args, cell_body) elif args['table']: table_name = google.datalab.bigquery.Query.resolve_parameters(args['table'], parameters) table = _get_table(table_name) if not table: raise Exception('Could not find table %s' % args['table']) elif args['view']: view = google.datalab.utils.commands.get_notebook_item(args['view']) if not isinstance(view, bigquery.View): raise Exception('Could not find view %s' % args['view']) else: raise Exception('A query, table, or view is neede to sample') # parse comma-separated list of fields fields = args['fields'].split(',') if args['fields'] else None count = int(args['count']) if args['count'] else None percent = int(args['percent']) if args['percent'] else None sampling = Sampling._auto(method=args['method'], fields=fields, count=count, percent=percent, key_field=args['key_field'], ascending=(args['order'] == 'ascending')) context = google.datalab.utils._utils._construct_context_for_args(args) if view: query = bigquery.Query.from_view(view) elif table: query = bigquery.Query.from_table(table) if args['profile']: results = query.execute(QueryOutput.dataframe(), sampling=sampling, context=context, query_params=query_params).result() else: results = query.execute(QueryOutput.table(), sampling=sampling, context=context, query_params=query_params).result() if args['verbose']: print(query.sql) if args['profile']: return google.datalab.utils.commands.profile_df(results) else: return results def _dryrun_cell(args, cell_body): """Implements the BigQuery cell magic used to dry run BQ queries. The supported syntax is: %%bq dryrun [-q|--sql ] [] Args: args: the argument following '%bq dryrun'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: The response wrapped in a DryRunStats object """ query = _get_query_argument(args, cell_body, google.datalab.utils.commands.notebook_environment()) if args['verbose']: print(query.sql) context = google.datalab.utils._utils._construct_context_for_args(args) result = query.dry_run(context=context) return bigquery._query_stats.QueryStats( total_bytes=result['totalBytesProcessed'], is_cached=result['cacheHit']) def _udf_cell(args, cell_body): """Implements the Bigquery udf cell magic for ipython notebooks. The supported syntax is: %%bq udf --name --language // @param // @returns // @import Args: args: the optional arguments following '%%bq udf'. cell_body: the UDF declaration (inputs and outputs) and implementation in javascript. 
""" udf_name = args['name'] if not udf_name: raise Exception('Declaration must be of the form %%bq udf --name ') # Parse out parameters, return type, and imports param_pattern = r'^\s*\/\/\s*@param\s+([<>\w]+)\s+([<>\w,\s]+)\s*$' returns_pattern = r'^\s*\/\/\s*@returns\s+([<>\w,\s]+)\s*$' import_pattern = r'^\s*\/\/\s*@import\s+(\S+)\s*$' params = re.findall(param_pattern, cell_body, re.MULTILINE) return_type = re.findall(returns_pattern, cell_body, re.MULTILINE) imports = re.findall(import_pattern, cell_body, re.MULTILINE) if len(return_type) < 1: raise Exception('UDF return type must be defined using // @returns ') if len(return_type) > 1: raise Exception('Found more than one return type definition') return_type = return_type[0] # Finally build the UDF object udf = bigquery.UDF(udf_name, cell_body, return_type, params, args['language'], imports) google.datalab.utils.commands.notebook_environment()[udf_name] = udf def _datasource_cell(args, cell_body): """Implements the BigQuery datasource cell magic for ipython notebooks. The supported syntax is %%bq datasource --name --paths [--format ] Args: args: the optional arguments following '%%bq datasource' cell_body: the datasource's schema in json/yaml """ name = args['name'] paths = args['paths'] data_format = (args['format'] or 'CSV').lower() compressed = args['compressed'] or False # Get the source schema from the cell body record = google.datalab.utils.commands.parse_config( cell_body, google.datalab.utils.commands.notebook_environment(), as_dict=False) jsonschema.validate(record, BigQuerySchema.TABLE_SCHEMA_SCHEMA) schema = bigquery.Schema(record['schema']) # Finally build the datasource object datasource = bigquery.ExternalDataSource(source=paths, source_format=data_format, compressed=compressed, schema=schema) google.datalab.utils.commands.notebook_environment()[name] = datasource def _query_cell(args, cell_body): """Implements the BigQuery cell magic for used to build SQL objects. The supported syntax is: %%bq query [] Args: args: the optional arguments following '%%bql query'. cell_body: the contents of the cell """ name = args['name'] udfs = args['udfs'] datasources = args['datasources'] subqueries = args['subqueries'] # Finally build the query object query = bigquery.Query(cell_body, env=IPython.get_ipython().user_ns, udfs=udfs, data_sources=datasources, subqueries=subqueries) # if no name is specified, execute this query instead of defining it if name is None: return query.execute().result() else: google.datalab.utils.commands.notebook_environment()[name] = query def _execute_cell(args, cell_body): """Implements the BigQuery cell magic used to execute BQ queries. The supported syntax is: %%bq execute [] Args: args: the optional arguments following '%%bq execute'. 
cell_body: optional contents of the cell Returns: QueryResultsTable containing query result """ env = google.datalab.utils.commands.notebook_environment() config = google.datalab.utils.commands.parse_config(cell_body, env, False) or {} parameters = config.get('parameters') or [] if parameters: jsonschema.validate({'parameters': parameters}, BigQuerySchema.QUERY_PARAMS_SCHEMA) table_name = google.datalab.bigquery.Query.resolve_parameters(args['table'], parameters) query = google.datalab.utils.commands.get_notebook_item(args['query']) if args['verbose']: print(query.sql) query_params = get_query_parameters(args, cell_body) if args['to_dataframe']: # re-parse the int arguments because they're passed as strings start_row = int(args['dataframe_start_row']) if args['dataframe_start_row'] else None max_rows = int(args['dataframe_max_rows']) if args['dataframe_max_rows'] else None output_options = QueryOutput.dataframe(start_row=start_row, max_rows=max_rows, use_cache=not args['nocache']) else: output_options = QueryOutput.table( name=table_name, mode=args['mode'], use_cache=not args['nocache'], allow_large_results=args['large']) context = google.datalab.utils._utils._construct_context_for_args(args) r = query.execute(output_options, context=context, query_params=query_params) return r.result() # An LRU cache for Tables. This is mostly useful so that when we cross page boundaries # when paging through a table we don't have to re-fetch the schema. _existing_table_cache = google.datalab.utils.LRUCache(10) def _get_table(name): """ Given a variable or table name, get a Table if it exists. Args: name: the name of the Table or a variable referencing the Table. Returns: The Table, if found. """ # If name is a variable referencing a table, use that. item = google.datalab.utils.commands.get_notebook_item(name) if isinstance(item, bigquery.Table): return item # Else treat this as a BQ table name and return the (cached) table if it exists. try: return _existing_table_cache[name] except KeyError: table = bigquery.Table(name) if table.exists(): _existing_table_cache[name] = table return table return None def _render_list(data): """ Helper to render a list of objects as an HTML list object. """ return IPython.core.display.HTML(google.datalab.utils.commands.HtmlBuilder.render_list(data)) def _dataset_line(args): """Implements the BigQuery dataset magic subcommand used to operate on datasets The supported syntax is: %bq datasets Commands: {list, create, delete} Args: args: the optional arguments following '%bq datasets command'. """ if args['command'] == 'list': filter_ = args['filter'] if args['filter'] else '*' context = google.datalab.Context.default() if args['project']: context = google.datalab.Context(args['project'], context.credentials) return _render_list([str(dataset) for dataset in bigquery.Datasets(context) if fnmatch.fnmatch(str(dataset), filter_)]) elif args['command'] == 'create': try: bigquery.Dataset(args['name']).create(friendly_name=args['friendly']) except Exception as e: print('Failed to create dataset %s: %s' % (args['name'], e)) elif args['command'] == 'delete': try: bigquery.Dataset(args['name']).delete() except Exception as e: print('Failed to delete dataset %s: %s' % (args['name'], e)) def _table_cell(args, cell_body): """Implements the BigQuery table magic subcommand used to operate on tables The supported syntax is: %%bq tables Commands: {list, create, delete, describe, view} Args: args: the optional arguments following '%%bq tables command'. 
cell_body: optional contents of the cell interpreted as SQL, YAML or JSON. Returns: The HTML rendering for the table of datasets. """ if args['command'] == 'list': filter_ = args['filter'] if args['filter'] else '*' if args['dataset']: if args['project'] is None: datasets = [bigquery.Dataset(args['dataset'])] else: context = google.datalab.Context(args['project'], google.datalab.Context.default().credentials) datasets = [bigquery.Dataset(args['dataset'], context)] else: default_context = google.datalab.Context.default() context = google.datalab.Context(default_context.project_id, default_context.credentials) if args['project']: context.set_project_id(args['project']) datasets = bigquery.Datasets(context) tables = [] for dataset in datasets: tables.extend([table.full_name for table in dataset if fnmatch.fnmatch(table.full_name, filter_)]) return _render_list(tables) elif args['command'] == 'create': if cell_body is None: print('Failed to create %s: no schema specified' % args['name']) else: try: record = google.datalab.utils.commands.parse_config( cell_body, google.datalab.utils.commands.notebook_environment(), as_dict=False) jsonschema.validate(record, BigQuerySchema.TABLE_SCHEMA_SCHEMA) schema = bigquery.Schema(record['schema']) bigquery.Table(args['name']).create(schema=schema, overwrite=args['overwrite']) except Exception as e: print('Failed to create table %s: %s' % (args['name'], e)) elif args['command'] == 'describe': name = args['name'] table = _get_table(name) if not table: raise Exception('Could not find table %s' % name) html = _repr_html_table_schema(table.schema) return IPython.core.display.HTML(html) elif args['command'] == 'delete': try: bigquery.Table(args['name']).delete() except Exception as e: print('Failed to delete table %s: %s' % (args['name'], e)) elif args['command'] == 'view': name = args['name'] table = _get_table(name) if not table: raise Exception('Could not find table %s' % name) return table def _extract_cell(args, cell_body): """Implements the BigQuery extract magic used to extract query or table data to GCS. The supported syntax is: %bq extract Args: args: the arguments following '%bigquery extract'. 
""" env = google.datalab.utils.commands.notebook_environment() config = google.datalab.utils.commands.parse_config(cell_body, env, False) or {} parameters = config.get('parameters') if args['table']: table = google.datalab.bigquery.Query.resolve_parameters(args['table'], parameters) source = _get_table(table) if not source: raise Exception('Could not find table %s' % table) csv_delimiter = args['delimiter'] if args['format'] == 'csv' else None path = google.datalab.bigquery.Query.resolve_parameters(args['path'], parameters) job = source.extract(path, format=args['format'], csv_delimiter=csv_delimiter, csv_header=args['header'], compress=args['compress']) elif args['query'] or args['view']: source_name = args['view'] or args['query'] source = google.datalab.utils.commands.get_notebook_item(source_name) if not source: raise Exception('Could not find ' + ('view ' + args['view'] if args['view'] else 'query ' + args['query'])) query = source if args['query'] else bigquery.Query.from_view(source) query_params = get_query_parameters(args, cell_body) if args['query'] else None output_options = QueryOutput.file(path=args['path'], format=args['format'], csv_delimiter=args['delimiter'], csv_header=args['header'], compress=args['compress'], use_cache=not args['nocache']) context = google.datalab.utils._utils._construct_context_for_args(args) job = query.execute(output_options, context=context, query_params=query_params) else: raise Exception('A query, table, or view is needed to extract') if job.failed: raise Exception('Extract failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Extract completed with errors: %s' % str(job.errors)) return job.result() def _load_cell(args, cell_body): """Implements the BigQuery load magic used to load data from GCS to a table. The supported syntax is: %bq load Args: args: the arguments following '%bq load'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: A message about whether the load succeeded or failed. """ env = google.datalab.utils.commands.notebook_environment() config = google.datalab.utils.commands.parse_config(cell_body, env, False) or {} parameters = config.get('parameters') or [] if parameters: jsonschema.validate({'parameters': parameters}, BigQuerySchema.QUERY_PARAMS_SCHEMA) name = google.datalab.bigquery.Query.resolve_parameters(args['table'], parameters) table = _get_table(name) if not table: table = bigquery.Table(name) if args['mode'] == 'create': if table.exists(): raise Exception('table %s already exists; use "append" or "overwrite" as mode.' % name) if not cell_body or 'schema' not in cell_body: raise Exception('Table does not exist, and no schema specified in cell; cannot load.') schema = config['schema'] # schema can be an instance of bigquery.Schema. # For example, user can run "my_schema = bigquery.Schema.from_data(df)" in a previous cell and # specify "schema: $my_schema" in cell input. if not isinstance(schema, bigquery.Schema): jsonschema.validate({'schema': schema}, BigQuerySchema.TABLE_SCHEMA_SCHEMA) schema = bigquery.Schema(schema) table.create(schema=schema) elif not table.exists(): raise Exception('table %s does not exist; use "create" as mode.' 
% name) csv_options = bigquery.CSVOptions(delimiter=args['delimiter'], skip_leading_rows=args['skip'], allow_jagged_rows=not args['strict'], quote=args['quote']) path = google.datalab.bigquery.Query.resolve_parameters(args['path'], parameters) job = table.load(path, mode=args['mode'], source_format=args['format'], csv_options=csv_options, ignore_unknown_values=not args['strict']) if job.failed: raise Exception('Load failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Load completed with errors: %s' % str(job.errors)) def _create_pipeline_subparser(parser): import argparse pipeline_parser = parser.subcommand( 'pipeline', formatter_class=argparse.RawTextHelpFormatter, help=""" Creates a GCS/BigQuery ETL pipeline. The cell-body is specified as follows: input: table | path: table load is also required> schema: format: {csv (default) | json} csv: delimiter: skip: strict: <{True | False (default)}; whether to accept rows with missing trailing (or optional) columns> quote: mode: <{append (default) | overwrite}; applicable if path->table load> transformation: query: output: table | path: path extract is required> mode: <{append | overwrite | create (default)}; applicable only when table is specified. format: <{csv (default) | json}> csv: delimiter: header: <{True (default) | False}; Whether to include an initial header line> compress: <{True | False (default) }; Whether to compress the data on export> schedule: start: end: interval: <{@once (default) | @hourly | @daily | @weekly | @ monthly | @yearly | }> catchup: <{True | False (default)}; when True, backfill is performed for start and end times. retries: Number of attempts to run the pipeline; default is 0 retry_delay_seconds: Number of seconds to wait before retrying the task emails: parameters: """) # noqa pipeline_parser.add_argument('-n', '--name', type=str, help='BigQuery pipeline name', required=True) pipeline_parser.add_argument('-d', '--gcs_dag_bucket', type=str, help='The Google Cloud Storage bucket for the Airflow dags.') pipeline_parser.add_argument('-f', '--gcs_dag_file_path', type=str, help='The file path suffix for the Airflow dags.') pipeline_parser.add_argument('-e', '--environment', type=str, help='The name of the Google Cloud Composer environment.') pipeline_parser.add_argument('-l', '--location', type=str, help='The location of the Google Cloud Composer environment. ' 'Refer https://cloud.google.com/about/locations/ for further ' 'details.') pipeline_parser.add_argument('-g', '--debug', type=str, help='Debug output with the airflow spec.') return pipeline_parser def _pipeline_cell(args, cell_body): """Implements the pipeline subcommand in the %%bq magic. Args: args: the arguments following '%%bq pipeline'. cell_body: Cell contents. """ name = args.get('name') if name is None: raise Exception('Pipeline name was not specified.') import google.datalab.utils as utils bq_pipeline_config = utils.commands.parse_config( cell_body, utils.commands.notebook_environment()) try: airflow_spec = \ google.datalab.contrib.bigquery.commands.get_airflow_spec_from_config(name, bq_pipeline_config) except AttributeError: return "Perhaps you're missing: import google.datalab.contrib.bigquery.commands" # If a gcs_dag_bucket is specified, we deploy to it so that the Airflow VM rsyncs it. 
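  # --- Illustrative example (not part of the original source) ---
  # A minimal sketch of a %%bq pipeline cell, assuming a query object named `my_query`
  # defined earlier in the notebook and hypothetical project/dataset names; the field
  # names follow the help text of _create_pipeline_subparser above:
  #
  #   %%bq pipeline --name my_pipeline --environment my_env --location us-central1
  #   input:
  #     table: my_project.my_dataset.input_table
  #   transformation:
  #     query: my_query
  #   output:
  #     table: my_project.my_dataset.output_table
  #     mode: overwrite
  #   schedule:
  #     start: 2018-01-01T00:00:00
  #     end: 2018-01-08T00:00:00
  #     interval: '@daily'
  #     catchup: True
  #
  # The parsed config becomes an Airflow spec via get_airflow_spec_from_config and is
  # then deployed to the GCS DAG bucket and/or Cloud Composer below.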
error_message = '' gcs_dag_bucket = args.get('gcs_dag_bucket') gcs_dag_file_path = args.get('gcs_dag_file_path') if gcs_dag_bucket: try: airflow = google.datalab.contrib.pipeline.airflow.Airflow(gcs_dag_bucket, gcs_dag_file_path) airflow.deploy(name, airflow_spec) error_message += ("Airflow pipeline successfully deployed! View dashboard for more " "details.\n") except AttributeError: return "Perhaps you're missing: import google.datalab.contrib.pipeline.airflow" location = args.get('location') environment = args.get('environment') if location and environment: try: composer = google.datalab.contrib.pipeline.composer.Composer(location, environment) composer.deploy(name, airflow_spec) error_message += ("Composer pipeline successfully deployed! View dashboard for more " "details.\n") except AttributeError: return "Perhaps you're missing: import google.datalab.contrib.pipeline.composer" if args.get('debug'): error_message += '\n\n' + airflow_spec return error_message def _add_command(parser, subparser_fn, handler, cell_required=False, cell_prohibited=False): """ Create and initialize a bigquery subcommand handler. """ sub_parser = subparser_fn(parser) sub_parser.set_defaults(func=lambda args, cell: _dispatch_handler(args, cell, sub_parser, handler, cell_required=cell_required, cell_prohibited=cell_prohibited)) def _create_bigquery_parser(): """ Create the parser for the %bq magics. Note that because we use the func default handler dispatch mechanism of argparse, our handlers can take only one argument which is the parsed args. So we must create closures for the handlers that bind the cell contents and thus must recreate this parser for each cell upon execution. """ parser = google.datalab.utils.commands.CommandParser(prog='%bq', description=""" Execute various BigQuery-related operations. Use "%bq -h" for help on a specific command. """) # This is a bit kludgy because we want to handle some line magics and some cell magics # with the bq command. # %bq datasets _add_command(parser, _create_dataset_subparser, _dataset_line, cell_prohibited=True) # %bq tables _add_command(parser, _create_table_subparser, _table_cell) # %%bq query _add_command(parser, _create_query_subparser, _query_cell) # %%bq execute _add_command(parser, _create_execute_subparser, _execute_cell) # %bq extract _add_command(parser, _create_extract_subparser, _extract_cell) # %%bq sample _add_command(parser, _create_sample_subparser, _sample_cell) # %%bq dryrun _add_command(parser, _create_dryrun_subparser, _dryrun_cell) # %%bq udf _add_command(parser, _create_udf_subparser, _udf_cell, cell_required=True) # %%bq datasource _add_command(parser, _create_datasource_subparser, _datasource_cell, cell_required=True) # %bq load _add_command(parser, _create_load_subparser, _load_cell) # %bq pipeline _add_command(parser, _create_pipeline_subparser, _pipeline_cell) return parser _bigquery_parser = _create_bigquery_parser() @IPython.core.magic.register_line_cell_magic def bq(line, cell=None): """Implements the bq cell magic for ipython notebooks. The supported syntax is: %%bq [] or: %bq [] Use %bq --help for a list of commands, or %bq --help for help on a specific command. """ return google.datalab.utils.commands.handle_magic_line(line, cell, _bigquery_parser) def _dispatch_handler(args, cell, parser, handler, cell_required=False, cell_prohibited=False): """ Makes sure cell magics include cell and line magics don't, before dispatching to handler. Args: args: the parsed arguments from the magic line. cell: the contents of the cell, if any. 
parser: the argument parser for ; used for error message. handler: the handler to call if the cell present/absent check passes. cell_required: True for cell magics, False for line magics that can't be cell magics. cell_prohibited: True for line magics, False for cell magics that can't be line magics. Returns: The result of calling the handler. Raises: Exception if the invocation is not valid. """ if cell_prohibited: if cell and len(cell.strip()): parser.print_help() raise Exception('Additional data is not supported with the %s command.' % parser.prog) return handler(args) if cell_required and not cell: parser.print_help() raise Exception('The %s command requires additional data' % parser.prog) return handler(args, cell) def _table_viewer(table, rows_per_page=25, fields=None): """ Return a table viewer. This includes a static rendering of the first page of the table, that gets replaced by the charting code in environments where Javascript is executable and BQ is available. Args: table: the table to view. rows_per_page: how many rows to display at one time. fields: an array of field names to display; default is None which uses the full schema. Returns: A string containing the HTML for the table viewer. """ # TODO(gram): rework this to use google.datalab.utils.commands.chart_html if not table.exists(): raise Exception('Table %s does not exist' % table.full_name) if not table.is_listable(): return "Done" _HTML_TEMPLATE = u"""
{static_table}

{meta_data}
""" if fields is None: fields = google.datalab.utils.commands.get_field_list(fields, table.schema) div_id = google.datalab.utils.commands.Html.next_id() meta_count = ('rows: %d' % table.length) if table.length >= 0 else '' meta_name = table.full_name if table.job is None else ('job: %s' % table.job.id) if table.job: if table.job.cache_hit: meta_cost = 'cached' else: bytes = bigquery._query_stats.QueryStats._size_formatter(table.job.bytes_processed) meta_cost = '%s processed' % bytes meta_time = 'time: %.1fs' % table.job.total_time else: meta_cost = '' meta_time = '' data, total_count = google.datalab.utils.commands.get_data(table, fields, first_row=0, count=rows_per_page) if total_count < 0: # The table doesn't have a length metadata property but may still be small if we fetched less # rows than we asked for. fetched_count = len(data['rows']) if fetched_count < rows_per_page: total_count = fetched_count chart = 'table' if 0 <= total_count <= rows_per_page else 'paged_table' meta_entries = [meta_count, meta_time, meta_cost, meta_name] meta_data = '(%s)' % (', '.join([entry for entry in meta_entries if len(entry)])) return _HTML_TEMPLATE.format(div_id=div_id, static_table=google.datalab.utils.commands.HtmlBuilder .render_chart_data(data), meta_data=meta_data, chart_style=chart, source_index=google.datalab.utils.commands .get_data_source_index(table.full_name), fields=','.join(fields), total_rows=total_count, rows_per_page=rows_per_page, data=json.dumps(data, cls=google.datalab.utils.JSONEncoder)) def _repr_html_query(query): # TODO(nikhilko): Pretty print the SQL return google.datalab.utils.commands.HtmlBuilder.render_text(query.sql, preformatted=True) def _repr_html_query_results_table(results): return _table_viewer(results) def _repr_html_table(results): return _table_viewer(results) def _repr_html_table_schema(schema): _HTML_TEMPLATE = """
""" id = google.datalab.utils.commands.Html.next_id() return _HTML_TEMPLATE % (id, id, json.dumps(schema._bq_schema)) def _register_html_formatters(): try: # The full module paths need to be specified in the type name lookup ipy = IPython.get_ipython() html_formatter = ipy.display_formatter.formatters['text/html'] html_formatter.for_type_by_name('google.datalab.bigquery._query', 'Query', _repr_html_query) html_formatter.for_type_by_name('google.datalab.bigquery._query_results_table', 'QueryResultsTable', _repr_html_query_results_table) html_formatter.for_type_by_name('google.datalab.bigquery._table', 'Table', _repr_html_table) html_formatter.for_type_by_name('google.datalab.bigquery._schema', 'Schema', _repr_html_table_schema) except TypeError: # For when running unit tests pass _register_html_formatters() ================================================ FILE: google/datalab/commands/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """This module defines the `datalab` magics""" from __future__ import absolute_import from . import _datalab __all__ = ['_datalab'] ================================================ FILE: google/datalab/commands/_datalab.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Google Cloud Platform library - datalab cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import google.datalab.utils.commands @IPython.core.magic.register_line_cell_magic def datalab(line, cell=None): """Implements the datalab cell magic for ipython notebooks. Args: line: the contents of the datalab line. Returns: The results of executing the cell. """ parser = google.datalab.utils.commands.CommandParser( prog='%datalab', description=""" Execute operations that apply to multiple Datalab APIs. Use "%datalab -h" for help on a specific command. 
""") config_parser = parser.subcommand( 'config', help='List or set API-specific configurations.') config_sub_commands = config_parser.add_subparsers(dest='command') # %%datalab config list config_list_parser = config_sub_commands.add_parser( 'list', help='List configurations') config_list_parser.set_defaults(func=_config_list_fn) # %%datalab config set -n -v config_set_parser = config_sub_commands.add_parser( 'set', help='Set configurations') config_set_parser.add_argument( '-n', '--name', help='The name of the configuration value', required=True) config_set_parser.add_argument( '-v', '--value', help='The value to set', required=True) config_set_parser.set_defaults(func=_config_set_fn) project_parser = parser.subcommand( 'project', help='Get or set the default project ID') project_sub_commands = project_parser.add_subparsers(dest='command') # %%datalab project get project_get_parser = project_sub_commands.add_parser( 'get', help='Get the default project ID') project_get_parser.set_defaults(func=_project_get_fn) # %%datalab project set -p project_set_parser = project_sub_commands.add_parser( 'set', help='Set the default project ID') project_set_parser.add_argument( '-p', '--project', help='The default project ID', required=True) project_set_parser.set_defaults(func=_project_set_fn) return google.datalab.utils.commands.handle_magic_line(line, cell, parser) def _config_list_fn(args, cell): ctx = google.datalab.Context.default() return google.datalab.utils.commands.render_dictionary([ctx.config]) def _config_set_fn(args, cell): name = args['name'] value = args['value'] ctx = google.datalab.Context.default() ctx.config[name] = value return google.datalab.utils.commands.render_dictionary([ctx.config]) def _project_get_fn(args, cell): ctx = google.datalab.Context.default() return google.datalab.utils.commands.render_text(ctx.project_id) def _project_set_fn(args, cell): project = args['project'] ctx = google.datalab.Context.default() ctx.set_project_id(project) return ================================================ FILE: google/datalab/contrib/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: google/datalab/contrib/bigquery/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
================================================ FILE: google/datalab/contrib/bigquery/commands/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from ._bigquery import get_airflow_spec_from_config # noqa ================================================ FILE: google/datalab/contrib/bigquery/commands/_bigquery.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - BigQuery IPython Functionality.""" import google import google.datalab.utils as utils import google.datalab.contrib.pipeline._pipeline import jsonschema def get_airflow_spec_from_config(name, bq_pipeline_config): pipeline_spec = google.datalab.contrib.bigquery.commands._bigquery._get_pipeline_spec_from_config( bq_pipeline_config) return google.datalab.contrib.pipeline._pipeline.PipelineGenerator.generate_airflow_spec( name, pipeline_spec) def _get_pipeline_spec_from_config(bq_pipeline_config): pipeline_spec = {} schedule_config = bq_pipeline_config.get('schedule') if schedule_config: pipeline_spec['schedule'] = schedule_config email_config = bq_pipeline_config.get('emails') if email_config: pipeline_spec['emails'] = email_config input_config = bq_pipeline_config.get('input') or bq_pipeline_config.get('load') transformation_config = bq_pipeline_config.get('transformation') output_config = bq_pipeline_config.get('output') or bq_pipeline_config.get('extract') parameters_config = bq_pipeline_config.get('parameters') if parameters_config: jsonschema.validate( {'parameters': parameters_config}, google.datalab.bigquery.commands._bigquery.BigQuerySchema.QUERY_PARAMS_SCHEMA) pipeline_spec['parameters'] = parameters_config pipeline_spec['tasks'] = {} load_task_id = None load_task_config = _get_load_parameters(input_config, transformation_config, output_config) if load_task_config: load_task_id = 'bq_pipeline_load_task' pipeline_spec['tasks'][load_task_id] = load_task_config execute_task_config = _get_execute_parameters(load_task_id, input_config, transformation_config, output_config, parameters_config) execute_task_id = None if execute_task_config: execute_task_id = 'bq_pipeline_execute_task' pipeline_spec['tasks'][execute_task_id] = execute_task_config extract_task_config = _get_extract_parameters(execute_task_id, input_config, transformation_config, output_config) if extract_task_config: pipeline_spec['tasks']['bq_pipeline_extract_task'] = 
extract_task_config if not load_task_config and not execute_task_config and not extract_task_config: raise Exception('Pipeline has no tasks to execute.') return pipeline_spec def _get_load_parameters(bq_pipeline_input_config, bq_pipeline_transformation_config, bq_pipeline_output_config): if bq_pipeline_input_config is None: return None load_task_config = {'type': 'pydatalab.bq.load'} # The path URL of the GCS load file(s). if 'path' not in bq_pipeline_input_config: return None # The path URL of the GCS load file(s), and associated parameters load_task_config['path'] = bq_pipeline_input_config.get('path') if 'format' in bq_pipeline_input_config: load_task_config['format'] = bq_pipeline_input_config['format'] if 'csv' in bq_pipeline_input_config: load_task_config['csv_options'] = bq_pipeline_input_config['csv'] # The destination BQ table name for loading source_of_table = bq_pipeline_input_config if ('table' not in bq_pipeline_input_config and not bq_pipeline_transformation_config and bq_pipeline_output_config and 'table' in bq_pipeline_output_config and 'path' not in bq_pipeline_output_config): # If we're here it means that there was no transformation config, but there was an output # config with only a table (and no path). We assume that the user was just trying to do a # gcs->table (or load) step, so we take that as the input table (and emit a load # operator). source_of_table = bq_pipeline_output_config # If a table or path are absent, there is no load to be done so we return None if 'table' not in source_of_table: return None load_task_config['table'] = source_of_table.get('table') if 'schema' in source_of_table: load_task_config['schema'] = source_of_table['schema'] if 'mode' in source_of_table: load_task_config['mode'] = source_of_table['mode'] return load_task_config def _get_execute_parameters(load_task_id, bq_pipeline_input_config, bq_pipeline_transformation_config, bq_pipeline_output_config, bq_pipeline_parameters_config): if bq_pipeline_transformation_config is None: return None # The name of query for execution; if absent, we return None as we assume that there is # no query to execute if 'query' not in bq_pipeline_transformation_config: return None execute_task_config = { 'type': 'pydatalab.bq.execute', } if load_task_id: execute_task_config['up_stream'] = [load_task_id] # If the input config has a path but no table, we assume that the user has specified an # external data_source either explicitly (i.e. via specifying a "data_source" key in the input # config, or implicitly (i.e. 
by letting us assume that this is called "input") if (bq_pipeline_input_config and 'path' in bq_pipeline_input_config and 'table' not in bq_pipeline_input_config): execute_task_config['data_source'] = bq_pipeline_input_config.get('data_source', 'input') if 'path' in bq_pipeline_input_config: # We format the path since this could contain format modifiers execute_task_config['path'] = bq_pipeline_input_config['path'] if 'schema' in bq_pipeline_input_config: execute_task_config['schema'] = bq_pipeline_input_config['schema'] if 'max_bad_records' in bq_pipeline_input_config: execute_task_config['max_bad_records'] = bq_pipeline_input_config['max_bad_records'] if 'format' in bq_pipeline_input_config: execute_task_config['source_format'] = bq_pipeline_input_config.get('format') if 'csv' in bq_pipeline_input_config: execute_task_config['csv_options'] = bq_pipeline_input_config.get('csv') query = utils.commands.get_notebook_item(bq_pipeline_transformation_config['query']) # If there is a table in the input config, we allow the user to reference table with the name # 'input' in their sql, i.e. via something like 'SELECT col1 FROM input WHERE ...'. To enable # this, we include the input table as a subquery with the query object. If the user's sql does # not reference an 'input' table, BigQuery will just ignore it. Things get interesting if the # user's sql specifies a subquery named 'input' - that should override the subquery that we use. # TODO(rajivpb): Verify this. if (bq_pipeline_input_config and 'table' in bq_pipeline_input_config): table_name = google.datalab.bigquery.Query.resolve_parameters( bq_pipeline_input_config.get('table'), bq_pipeline_parameters_config, macros=True) input_subquery_sql = 'SELECT * FROM `{0}`'.format(table_name) input_subquery = google.datalab.bigquery.Query(input_subquery_sql) # We artificially create an env with just the 'input' key, and the new input_query value to # fool the Query object into using the subquery correctly. query = google.datalab.bigquery.Query(query.sql, env={'input': input_subquery}, subqueries=['input']) execute_task_config['sql'] = query.sql execute_task_config['parameters'] = bq_pipeline_parameters_config if bq_pipeline_output_config: if 'table' in bq_pipeline_output_config: execute_task_config['table'] = bq_pipeline_output_config['table'] if 'mode' in bq_pipeline_output_config: execute_task_config['mode'] = bq_pipeline_output_config['mode'] return execute_task_config def _get_extract_parameters(execute_task_id, bq_pipeline_input_config, bq_pipeline_transformation_config, bq_pipeline_output_config): if bq_pipeline_output_config is None: return None extract_task_config = { 'type': 'pydatalab.bq.extract', } if execute_task_id: extract_task_config['up_stream'] = [execute_task_id] extract_task_config['table'] = """{{{{ ti.xcom_pull(task_ids='{0}_id').get('table') }}}}"""\ .format(execute_task_id) # If a path is not specified, there is no extract to be done, so we return None if 'path' not in bq_pipeline_output_config: return None extract_task_config['path'] = bq_pipeline_output_config.get('path') if 'format' in bq_pipeline_output_config: extract_task_config['format'] = bq_pipeline_output_config.get('format') if 'csv' in bq_pipeline_output_config: extract_task_config['csv_options'] = bq_pipeline_output_config.get('csv') # If a temporary table from the bigquery results is being used, this will not be present in the # output section. 
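  # --- Illustrative example (not part of the original source) ---
  # A sketch of the 'input' subquery wrapping done in _get_execute_parameters above,
  # with a hypothetical table name. Given
  #
  #   user SQL:  SELECT field1, COUNT(*) AS n FROM input GROUP BY field1
  #   input:     table: my_project.my_dataset.input_table
  #
  # the query is rebuilt with an 'input' subquery, so the SQL that executes is roughly
  #
  #   WITH input AS (SELECT * FROM `my_project.my_dataset.input_table`)
  #   SELECT field1, COUNT(*) AS n FROM input GROUP BY field1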
source_of_table = None if 'table' in bq_pipeline_output_config: source_of_table = bq_pipeline_output_config elif (bq_pipeline_input_config and not bq_pipeline_transformation_config and 'table' in bq_pipeline_input_config and 'path' not in bq_pipeline_input_config): # If we're here it means that there was no transformation config, but there was an input # config with only a table and no path. We assume that the user was just trying to do a # table->gcs (or extract) step, so we take that as the input table (and emit an extract # operator). source_of_table = bq_pipeline_input_config if source_of_table: extract_task_config['table'] = source_of_table['table'] return extract_task_config ================================================ FILE: google/datalab/contrib/bigquery/operators/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: google/datalab/contrib/bigquery/operators/_bq_execute_operator.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
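# --- Illustrative example (not part of the original source) ---
# Rough shape of the pipeline_spec produced by _get_pipeline_spec_from_config (in
# google/datalab/contrib/bigquery/commands/_bigquery.py above) when input,
# transformation and output sections are all present; values are elided:
#
#   {
#     'schedule': {...},
#     'parameters': [...],
#     'tasks': {
#       'bq_pipeline_load_task':    {'type': 'pydatalab.bq.load', 'path': ..., 'table': ...},
#       'bq_pipeline_execute_task': {'type': 'pydatalab.bq.execute', 'sql': ...,
#                                    'up_stream': ['bq_pipeline_load_task'], ...},
#       'bq_pipeline_extract_task': {'type': 'pydatalab.bq.extract', 'path': ...,
#                                    'up_stream': ['bq_pipeline_execute_task'], ...},
#     },
#   }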
import google.datalab.bigquery as bq from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults class ExecuteOperator(BaseOperator): template_fields = ('table', 'parameters', 'path', 'sql') @apply_defaults def __init__(self, sql, parameters=None, table=None, mode=None, data_source=None, path=None, format=None, csv_options=None, schema=None, max_bad_records=None, *args, **kwargs): super(ExecuteOperator, self).__init__(*args, **kwargs) self.sql = sql self.table = table self.mode = mode self.parameters = parameters self.data_source = data_source self.path = path self.format = format self.csv_options = csv_options self.schema = schema self.max_bad_records = max_bad_records def execute(self, context): if self.data_source: kwargs = {} if self.csv_options: csv_kwargs = {} if 'delimiter' in self.csv_options: csv_kwargs['delimiter'] = self.csv_options['delimiter'] if 'skip' in self.csv_options: csv_kwargs['skip_leading_rows'] = self.csv_options['skip'] if 'strict' in self.csv_options: csv_kwargs['allow_jagged_rows'] = self.csv_options['strict'] if 'quote' in self.csv_options: csv_kwargs['quote'] = self.csv_options['quote'] kwargs['csv_options'] = bq.CSVOptions(**csv_kwargs) if self.format: kwargs['source_format'] = self.format if self.max_bad_records: kwargs['max_bad_records'] = self.max_bad_records external_data_source = bq.ExternalDataSource( source=self.path, schema=bq.Schema(self.schema), **kwargs) query = bq.Query(sql=self.sql, data_sources={self.data_source: external_data_source}) else: query = bq.Query(sql=self.sql) # use_cache is False since this is most likely the case in pipeline scenarios # allow_large_results can be True only if table is specified (i.e. when it's not None) kwargs = {} if self.mode is not None: kwargs['mode'] = self.mode output_options = bq.QueryOutput.table(name=self.table, use_cache=False, allow_large_results=self.table is not None, **kwargs) query_params = bq.Query.get_query_parameters(self.parameters) job = query.execute(output_options=output_options, query_params=query_params) # Returning the table-name here makes it available for downstream task instances. return { 'table': job.result().full_name } ================================================ FILE: google/datalab/contrib/bigquery/operators/_bq_extract_operator.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
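# --- Illustrative example (not part of the original source) ---
# A minimal sketch of wiring the ExecuteOperator defined above into an Airflow DAG;
# the DAG name, table names and default_args are hypothetical:
#
#   from airflow import DAG
#   from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator
#
#   with DAG('bq_pipeline', default_args=default_args) as dag:
#     execute = ExecuteOperator(
#         task_id='bq_pipeline_execute_task',
#         sql='SELECT * FROM `my_project.my_dataset.input_table`',
#         table='my_project.my_dataset.output_table',
#         mode='overwrite')
#
# ExecuteOperator.execute() returns {'table': <full table name>}, which downstream
# tasks (e.g. the extract task) can read back through XCom.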
import google from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults class ExtractOperator(BaseOperator): template_fields = ('table', 'path') @apply_defaults def __init__(self, path, table, format='csv', csv_options=None, *args, **kwargs): super(ExtractOperator, self).__init__(*args, **kwargs) self.table = table self.path = path self.format = format self.csv_options = csv_options or {} def execute(self, context): source_table = google.datalab.bigquery.Table(self.table, context=None) csv_kwargs = {} if 'delimiter' in self.csv_options: csv_kwargs['csv_delimiter'] = self.csv_options['delimiter'] if 'header' in self.csv_options: csv_kwargs['csv_header'] = self.csv_options['header'] if 'compress' in self.csv_options: csv_kwargs['compress'] = self.csv_options['compress'] job = source_table.extract( self.path, format='CSV' if self.format == 'csv' else 'NEWLINE_DELIMITED_JSON', **csv_kwargs) if job.failed: raise Exception('Extract failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Extract completed with errors: %s' % str(job.errors)) return { 'result': job.result() } ================================================ FILE: google/datalab/contrib/bigquery/operators/_bq_load_operator.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import google.datalab.bigquery as bq from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults class LoadOperator(BaseOperator): """Implements the BigQuery load magic used to load data from GCS to a table. The supported syntax is: %bq load Args: args: the arguments following '%bq load'. cell_body: optional contents of the cell interpreted as YAML or JSON. Returns: A message about whether the load succeeded or failed. 
""" template_fields = ('table', 'path') @apply_defaults def __init__(self, table, path, mode='append', format='csv', schema=None, csv_options=None, *args, **kwargs): super(LoadOperator, self).__init__(*args, **kwargs) self.table = table self.path = path self.mode = mode self.format = format self.csv_options = csv_options or {} self.schema = schema # TODO(rajipb): In schema validation, make sure that mode is either 'append' or 'create' def execute(self, context): table = bq.Table(self.table, context=None) if not table.exists(): table.create(schema=self.schema) kwargs = {} if 'delimiter' in self.csv_options: kwargs['delimiter'] = self.csv_options['delimiter'] if 'skip' in self.csv_options: kwargs['skip_leading_rows'] = self.csv_options['skip'] if 'strict' in self.csv_options: kwargs['allow_jagged_rows'] = self.csv_options['strict'] if 'quote' in self.csv_options: kwargs['quote'] = self.csv_options['quote'] csv_options = bq.CSVOptions(**kwargs) job = table.load(self.path, mode=self.mode, source_format=('csv' if self.format == 'csv' else 'NEWLINE_DELIMITED_JSON'), csv_options=csv_options, ignore_unknown_values=not self.csv_options.get('strict')) if job.failed: raise Exception('Load failed: %s' % str(job.fatal_error)) elif job.errors: raise Exception('Load completed with errors: %s' % str(job.errors)) return { 'result': job.result() } ================================================ FILE: google/datalab/contrib/mlworkbench/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from ._local_predict import get_prediction_results, get_probs_for_labels, local_batch_predict from ._prediction_explainer import PredictionExplainer __all__ = ['get_prediction_results', 'get_probs_for_labels', 'local_batch_predict', 'PredictionExplainer'] ================================================ FILE: google/datalab/contrib/mlworkbench/_archive.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Google Cloud Platform library - ml cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals import os import shutil import tempfile import tensorflow as tf import google.datalab.contrib.mlworkbench._shell_process as _shell_process def extract_archive(archive_path, dest): """Extract a local or GCS archive file to a folder. 
Args: archive_path: local or gcs path to a *.tar.gz or *.tar file dest: local folder the archive will be extracted to """ # Make the dest folder if it does not exist if not os.path.isdir(dest): os.makedirs(dest) try: tmpfolder = None if (not tf.gfile.Exists(archive_path)) or tf.gfile.IsDirectory(archive_path): raise ValueError('archive path %s is not a file' % archive_path) if archive_path.startswith('gs://'): # Copy the file to a local temp folder tmpfolder = tempfile.mkdtemp() cmd_args = ['gsutil', 'cp', archive_path, tmpfolder] _shell_process.run_and_monitor(cmd_args, os.getpid()) archive_path = os.path.join(tmpfolder, os.path.name(archive_path)) if archive_path.lower().endswith('.tar.gz'): flags = '-xzf' elif archive_path.lower().endswith('.tar'): flags = '-xf' else: raise ValueError('Only tar.gz or tar.Z files are supported.') cmd_args = ['tar', flags, archive_path, '-C', dest] _shell_process.run_and_monitor(cmd_args, os.getpid()) finally: if tmpfolder: shutil.rmtree(tmpfolder) ================================================ FILE: google/datalab/contrib/mlworkbench/_local_predict.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Google Cloud Platform library - ml cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals import base64 import collections import copy import csv from io import BytesIO import json import logging import numpy as np import os import pandas as pd from PIL import Image import six import tensorflow as tf from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import signature_constants import google.datalab.ml as ml def _tf_load_model(sess, model_dir): """Load a tf model from model_dir, and return input/output alias maps.""" meta_graph_pb = tf.saved_model.loader.load( sess=sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=model_dir) signature = meta_graph_pb.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] input_alias_map = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.inputs.items()} output_alias_map = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.outputs.items()} return input_alias_map, output_alias_map def _tf_predict(model_dir, input_csvlines): """Prediction with a tf savedmodel. Args: model_dir: directory that contains a saved model input_csvlines: list of csv strings Returns: Dict in the form tensor_name:prediction_list. Note that the value is always a list, even if there was only 1 row in input_csvlines. """ with tf.Graph().as_default(), tf.Session() as sess: input_alias_map, output_alias_map = _tf_load_model(sess, model_dir) csv_tensor_name = list(input_alias_map.values())[0] results = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: input_csvlines}) # convert any scalar values to a list. 
This may happen when there is one # example in input_csvlines and the model uses tf.squeeze on the output # tensor. if len(input_csvlines) == 1: for k, v in six.iteritems(results): if not isinstance(v, (list, np.ndarray)): results[k] = [v] # Convert bytes to string. In python3 the results may be bytes. for k, v in six.iteritems(results): if any(isinstance(x, bytes) for x in v): results[k] = [x.decode('utf-8') for x in v] return results def _download_images(data, img_cols): """Download images given image columns.""" images = collections.defaultdict(list) for d in data: for img_col in img_cols: if d.get(img_col, None): if isinstance(d[img_col], Image.Image): # If it is already an Image, just copy and continue. images[img_col].append(d[img_col]) else: # Otherwise it is image url. Load the image. with file_io.FileIO(d[img_col], 'rb') as fi: im = Image.open(fi) images[img_col].append(im) else: images[img_col].append('') return images def _get_predicton_csv_lines(data, headers, images): """Create CSV lines from list-of-dict data.""" if images: data = copy.deepcopy(data) for img_col in images: for d, im in zip(data, images[img_col]): if im == '': continue im = im.copy() im.thumbnail((299, 299), Image.ANTIALIAS) buf = BytesIO() im.save(buf, "JPEG") content = base64.urlsafe_b64encode(buf.getvalue()).decode('ascii') d[img_col] = content csv_lines = [] for d in data: buf = six.StringIO() writer = csv.DictWriter(buf, fieldnames=headers, lineterminator='') writer.writerow(d) csv_lines.append(buf.getvalue()) return csv_lines def _get_display_data_with_images(data, images): """Create display data by converting image urls to base64 strings.""" if not images: return data display_data = copy.deepcopy(data) for img_col in images: for d, im in zip(display_data, images[img_col]): if im == '': d[img_col + '_image'] = '' else: im = im.copy() im.thumbnail((128, 128), Image.ANTIALIAS) buf = BytesIO() im.save(buf, "PNG") content = base64.b64encode(buf.getvalue()).decode('ascii') d[img_col + '_image'] = content return display_data def get_model_schema_and_features(model_dir): """Get a local model's schema and features config. Args: model_dir: local or GCS path of a model. Returns: A tuple of schema (list) and features config (dict). """ schema_file = os.path.join(model_dir, 'assets.extra', 'schema.json') schema = json.loads(file_io.read_file_to_string(schema_file)) features_file = os.path.join(model_dir, 'assets.extra', 'features.json') features_config = json.loads(file_io.read_file_to_string(features_file)) return schema, features_config def get_prediction_results(model_dir_or_id, data, headers, img_cols=None, cloud=False, with_source=True, show_image=True): """ Predict with a specified model. It predicts with the model, join source data with prediction results, and formats the results so they can be displayed nicely in Datalab. Args: model_dir_or_id: The model directory if cloud is False, or model.version if cloud is True. data: Can be a list of dictionaries, a list of csv lines, or a Pandas DataFrame. If it is not a list of csv lines, data will be converted to csv lines first, using the orders specified by headers and then send to model. For images, it can be image gs urls or in-memory PIL images. Images will be converted to base64 encoded strings before prediction. headers: the column names of data. It specifies the order of the columns when serializing to csv lines for prediction. img_cols: The image url columns. If specified, the img_urls will be converted to base64 encoded image bytes. 
with_source: Whether return a joined prediction source and prediction results, or prediction results only. show_image: When displaying prediction source, whether to add a column of image bytes for each image url column. Returns: A dataframe of joined prediction source and prediction results, or prediction results only. """ if img_cols is None: img_cols = [] if isinstance(data, pd.DataFrame): data = list(data.T.to_dict().values()) elif isinstance(data[0], six.string_types): data = list(csv.DictReader(data, fieldnames=headers)) images = _download_images(data, img_cols) predict_data = _get_predicton_csv_lines(data, headers, images) if cloud: parts = model_dir_or_id.split('.') if len(parts) != 2: raise ValueError('Invalid model name for cloud prediction. Use "model.version".') predict_results = ml.ModelVersions(parts[0]).predict(parts[1], predict_data) else: tf_logging_level = logging.getLogger("tensorflow").level logging.getLogger("tensorflow").setLevel(logging.WARNING) try: predict_results = _tf_predict(model_dir_or_id, predict_data) finally: logging.getLogger("tensorflow").setLevel(tf_logging_level) df_r = pd.DataFrame(predict_results) if not with_source: return df_r display_data = data if show_image: display_data = _get_display_data_with_images(data, images) df_s = pd.DataFrame(display_data) df = pd.concat([df_r, df_s], axis=1) # Remove duplicate columns. All 'key' columns are duplicate here. df = df.loc[:, ~df.columns.duplicated()] return df def get_probs_for_labels(labels, prediction_results): """ Given ML Workbench prediction results, get probs of each label for each instance. The prediction results are like: [ {'predicted': 'daisy', 'probability': 0.8, 'predicted_2': 'rose', 'probability_2': 0.1}, {'predicted': 'sunflower', 'probability': 0.9, 'predicted_2': 'daisy', 'probability_2': 0.01}, ... ] Each instance is ordered by prob. But in some cases probs are needed for fixed order of labels. For example, given labels = ['daisy', 'rose', 'sunflower'], the results of above is expected to be: [ [0.8, 0.1, 0.0], [0.01, 0.0, 0.9], ... ] Note that the sum of each instance may not be always 1. If model's top_n is set to none-zero, and is less than number of labels, then prediction results may not contain probs for all labels. Args: labels: a list of labels specifying the order of the labels. prediction_results: a pandas DataFrame containing prediction results, usually returned by get_prediction_results() call. Returns: A list of list of probs for each class. """ probs = [] if 'probability' in prediction_results: # 'probability' exists so top-n is set to none zero, and results are like # "predicted, predicted_2,...,probability,probability_2,... for i, r in prediction_results.iterrows(): probs_one = [0.0] * len(labels) for k, v in six.iteritems(r): if v in labels and k.startswith('predicted'): if k == 'predict': prob_name = 'probability' else: prob_name = 'probability' + k[9:] probs_one[labels.index(v)] = r[prob_name] probs.append(probs_one) return probs else: # 'probability' does not exist, so top-n is set to zero. Results are like # "predicted, class_name1, class_name2,... 
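    # --- Illustrative example (not part of the original source) ---
    # In this branch a prediction row looks roughly like (hypothetical values):
    #   {'predicted': 'daisy', 'daisy': 0.8, 'rose': 0.1, 'sunflower': 0.1}
    # so with labels = ['daisy', 'rose', 'sunflower'] the loop below yields
    # probs_one = [0.8, 0.1, 0.1].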
for i, r in prediction_results.iterrows(): probs_one = [0.0] * len(labels) for k, v in six.iteritems(r): if k in labels: probs_one[labels.index(k)] = v probs.append(probs_one) return probs def _batch_csv_reader(csv_file, n): with file_io.FileIO(csv_file, 'r') as f: args = [f] * n return six.moves.zip_longest(*args) def _get_output_schema(session, output_alias_map): schema = [] for name in sorted(six.iterkeys(output_alias_map)): tensor_name = output_alias_map[name] dtype = session.graph.get_tensor_by_name(tensor_name).dtype if dtype == tf.int32 or dtype == tf.int64: schema.append({'name': name, 'type': 'INTEGER'}) elif dtype == tf.float32 or dtype == tf.float64: schema.append({'name': name, 'type': 'FLOAT'}) else: schema.append({'name': name, 'type': 'STRING'}) return schema def _format_results(output_format, output_schema, batched_results): # Convert a dict of list to a list of dict. # Note that results from session.run may contain scaler value instead of lists # if batch size is 1. if (isinstance(next(iter(batched_results.values())), list) or isinstance(next(iter(batched_results.values())), np.ndarray)): batched_results = [dict(zip(batched_results, t)) for t in zip(*batched_results.values())] else: batched_results = [batched_results] if output_format == 'csv': results = [] for r in batched_results: values = [str(r[schema['name']]) for schema in output_schema] results.append(','.join(values)) elif output_format == 'json': # Default json encoder cannot handle numpy types. class _JSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) elif isinstance(obj, bytes): return obj.decode('utf-8') elif isinstance(obj, np.ndarray): return obj.tolist() else: return super(_JSONEncoder, self).default(obj) results = [json.dumps(x, cls=_JSONEncoder) for x in batched_results] else: raise ValueError('Unknown output_format %s' % output_format) return results def local_batch_predict(model_dir, csv_file_pattern, output_dir, output_format, batch_size=100): """ Batch Predict with a specified model. It does batch prediction, saves results to output files and also creates an output schema file. The output file names are input file names prepended by 'predict_results_'. Args: model_dir: The model directory containing a SavedModel (usually saved_model.pb). csv_file_pattern: a pattern of csv files as batch prediction source. output_dir: the path of the output directory. output_format: csv or json. batch_size: Larger batch_size improves performance but may cause more memory usage. """ file_io.recursive_create_dir(output_dir) csv_files = file_io.get_matching_files(csv_file_pattern) if len(csv_files) == 0: raise ValueError('No files found given ' + csv_file_pattern) with tf.Graph().as_default(), tf.Session() as sess: input_alias_map, output_alias_map = _tf_load_model(sess, model_dir) csv_tensor_name = list(input_alias_map.values())[0] output_schema = _get_output_schema(sess, output_alias_map) for csv_file in csv_files: output_file = os.path.join( output_dir, 'predict_results_' + os.path.splitext(os.path.basename(csv_file))[0] + '.' 
+ output_format) with file_io.FileIO(output_file, 'w') as f: prediction_source = _batch_csv_reader(csv_file, batch_size) for batch in prediction_source: batch = [l.rstrip() for l in batch if l] predict_results = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: batch}) formatted_results = _format_results(output_format, output_schema, predict_results) f.write('\n'.join(formatted_results) + '\n') file_io.write_string_to_file(os.path.join(output_dir, 'predict_results_schema.json'), json.dumps(output_schema, indent=2)) ================================================ FILE: google/datalab/contrib/mlworkbench/_prediction_explainer.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Google Cloud Platform library - ML Workbench Model Prediction Explainer.""" from __future__ import absolute_import from __future__ import unicode_literals import base64 import csv import io import numpy as np import pandas as pd from PIL import Image import six import tensorflow as tf from tensorflow.python.lib.io import file_io from . import _local_predict class PredictionExplainer(object): """An explainer that explains text and image predictions based on LIME.""" def __init__(self, model_dir): """ Args: model_dir: the directory of the model to use for prediction. """ self._model_dir = model_dir schema, features = _local_predict.get_model_schema_and_features(model_dir) self._headers = [x['name'] for x in schema] self._text_columns, self._image_columns = [], [] self._categorical_columns, self._numeric_columns = [], [] for k, v in six.iteritems(features): if v['transform'] in ['image_to_vec']: self._image_columns.append(v['source_column']) elif v['transform'] in ['bag_of_words', 'tfidf']: self._text_columns.append(v['source_column']) elif v['transform'] in ['one_hot', 'embedding']: self._categorical_columns.append(v['source_column']) elif v['transform'] in ['identity', 'scale']: self._numeric_columns.append(v['source_column']) def _make_text_predict_fn(self, labels, instance, column_to_explain): """Create a predict_fn that can be used by LIME text explainer. """ def _predict_fn(perturbed_text): predict_input = [] for x in perturbed_text: instance_copy = dict(instance) instance_copy[column_to_explain] = x predict_input.append(instance_copy) df = _local_predict.get_prediction_results(self._model_dir, predict_input, self._headers, with_source=False) probs = _local_predict.get_probs_for_labels(labels, df) return np.asarray(probs) return _predict_fn def _make_image_predict_fn(self, labels, instance, column_to_explain): """Create a predict_fn that can be used by LIME image explainer. 
""" def _predict_fn(perturbed_image): predict_input = [] for x in perturbed_image: instance_copy = dict(instance) instance_copy[column_to_explain] = Image.fromarray(x) predict_input.append(instance_copy) df = _local_predict.get_prediction_results( self._model_dir, predict_input, self._headers, img_cols=self._image_columns, with_source=False) probs = _local_predict.get_probs_for_labels(labels, df) return np.asarray(probs) return _predict_fn def _get_unique_categories(self, df): """Get all categories for each categorical columns from training data.""" categories = [] for col in self._categorical_columns: categocial = pd.Categorical(df[col]) col_categories = list(map(str, categocial.categories)) col_categories.append('_UNKNOWN') categories.append(col_categories) return categories def _preprocess_data_for_tabular_explain(self, df, categories): """Get preprocessed training set in numpy array, and categorical names from raw training data. LIME tabular explainer requires a training set to know the distribution of numeric and categorical values. The training set has to be numpy arrays, with all categorical values converted to indices. It also requires list of names for each category. """ df = df.copy() # Remove non tabular columns (text, image). for col in list(df.columns): if col not in (self._categorical_columns + self._numeric_columns): del df[col] # Convert categorical values into indices. for col_name, col_categories in zip(self._categorical_columns, categories): df[col_name] = df[col_name].apply( lambda x: col_categories.index(str(x)) if str(x) in col_categories else len(col_categories) - 1) # Make sure numeric values are really numeric for numeric_col in self._numeric_columns: df[numeric_col] = df[numeric_col].apply(lambda x: float(x)) return df.as_matrix(self._categorical_columns + self._numeric_columns) def _make_tabular_predict_fn(self, labels, instance, categories): """Create a predict_fn that can be used by LIME tabular explainer. """ def _predict_fn(np_instance): df = pd.DataFrame( np_instance, columns=(self._categorical_columns + self._numeric_columns)) # Convert categorical indices back to categories. for col_name, col_categories in zip(self._categorical_columns, categories): df[col_name] = df[col_name].apply(lambda x: col_categories[int(x)]) # Add columns that do not exist in the perturbed data, # such as key, text, and image data. for col_name in self._headers: if col_name not in (self._categorical_columns + self._numeric_columns): df[col_name] = instance[col_name] r = _local_predict.get_prediction_results( self._model_dir, df, self._headers, with_source=False) probs = _local_predict.get_probs_for_labels(labels, r) probs = np.asarray(probs) return probs return _predict_fn def explain_tabular(self, trainset, labels, instance, num_features=5, kernel_width=3): """Explain categorical and numeric features for a prediction. It analyze the prediction by LIME, and returns a report of the most impactful tabular features contributing to certain labels. Args: trainset: a DataFrame representing the training features that LIME can use to decide value distributions. labels: a list of labels to explain. instance: the prediction instance. It needs to conform to model's input. Can be a csv line string, or a dict. num_features: maximum number of features to show. kernel_width: Passed to LIME LimeTabularExplainer directly. Returns: A LIME's lime.explanation.Explanation. 
""" from lime.lime_tabular import LimeTabularExplainer if isinstance(instance, six.string_types): instance = next(csv.DictReader([instance], fieldnames=self._headers)) categories = self._get_unique_categories(trainset) np_trainset = self._preprocess_data_for_tabular_explain(trainset, categories) predict_fn = self._make_tabular_predict_fn(labels, instance, categories) prediction_df = pd.DataFrame([instance]) prediction_instance = self._preprocess_data_for_tabular_explain(prediction_df, categories) explainer = LimeTabularExplainer( np_trainset, feature_names=(self._categorical_columns + self._numeric_columns), class_names=labels, categorical_features=range(len(categories)), categorical_names={i: v for i, v in enumerate(categories)}, kernel_width=kernel_width) exp = explainer.explain_instance( prediction_instance[0], predict_fn, num_features=num_features, labels=range(len(labels))) return exp def explain_text(self, labels, instance, column_name=None, num_features=10, num_samples=5000): """Explain a text field of a prediction. It analyze the prediction by LIME, and returns a report of which words are most impactful in contributing to certain labels. Args: labels: a list of labels to explain. instance: the prediction instance. It needs to conform to model's input. Can be a csv line string, or a dict. column_name: which text column to explain. Can be None if there is only one text column in the model input. num_features: maximum number of words (features) to analyze. Passed to LIME LimeTextExplainer directly. num_samples: size of the neighborhood to learn the linear model. Passed to LIME LimeTextExplainer directly. Returns: A LIME's lime.explanation.Explanation. Throws: ValueError if the given text column is not found in model input or column_name is None but there are multiple text columns in model input. """ from lime.lime_text import LimeTextExplainer if len(self._text_columns) > 1 and not column_name: raise ValueError('There are multiple text columns in the input of the model. ' + 'Please specify "column_name".') elif column_name and column_name not in self._text_columns: raise ValueError('Specified column_name "%s" not found in the model input.' % column_name) text_column_name = column_name if column_name else self._text_columns[0] if isinstance(instance, six.string_types): instance = next(csv.DictReader([instance], fieldnames=self._headers)) predict_fn = self._make_text_predict_fn(labels, instance, text_column_name) explainer = LimeTextExplainer(class_names=labels) exp = explainer.explain_instance( instance[text_column_name], predict_fn, labels=range(len(labels)), num_features=num_features, num_samples=num_samples) return exp def explain_image(self, labels, instance, column_name=None, num_features=100000, num_samples=300, batch_size=200, hide_color=0): """Explain an image of a prediction. It analyze the prediction by LIME, and returns a report of which words are most impactful in contributing to certain labels. Args: labels: a list of labels to explain. instance: the prediction instance. It needs to conform to model's input. Can be a csv line string, or a dict. column_name: which image column to explain. Can be None if there is only one image column in the model input. num_features: maximum number of areas (features) to analyze. Passed to LIME LimeImageExplainer directly. num_samples: size of the neighborhood to learn the linear model. Passed to LIME LimeImageExplainer directly. batch_size: size of batches passed to predict_fn. Passed to LIME LimeImageExplainer directly. 
hide_color: the color used to perturb images. Passed to LIME LimeImageExplainer directly. Returns: A LIME's lime.explanation.Explanation. Throws: ValueError if the given image column is not found in model input or column_name is None but there are multiple image columns in model input. """ from lime.lime_image import LimeImageExplainer if len(self._image_columns) > 1 and not column_name: raise ValueError('There are multiple image columns in the input of the model. ' + 'Please specify "column_name".') elif column_name and column_name not in self._image_columns: raise ValueError('Specified column_name "%s" not found in the model input.' % column_name) image_column_name = column_name if column_name else self._image_columns[0] if isinstance(instance, six.string_types): instance = next(csv.DictReader([instance], fieldnames=self._headers)) predict_fn = self._make_image_predict_fn(labels, instance, image_column_name) explainer = LimeImageExplainer() with file_io.FileIO(instance[image_column_name], 'rb') as fi: im = Image.open(fi) im.thumbnail((299, 299), Image.ANTIALIAS) rgb_im = np.asarray(im.convert('RGB')) exp = explainer.explain_instance( rgb_im, predict_fn, labels=range(len(labels)), top_labels=None, hide_color=hide_color, num_features=num_features, num_samples=num_samples, batch_size=batch_size) return exp def _image_gradients(self, input_csvlines, label, image_column_name): """Compute gradients from prob of label to image. Used by integrated gradients (probe).""" with tf.Graph().as_default() as g, tf.Session() as sess: logging_level = tf.logging.get_verbosity() try: tf.logging.set_verbosity(tf.logging.ERROR) meta_graph_pb = tf.saved_model.loader.load( sess=sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=self._model_dir) finally: tf.logging.set_verbosity(logging_level) signature = meta_graph_pb.signature_def['serving_default'] input_alias_map = {name: tensor_info_proto.name for (name, tensor_info_proto) in signature.inputs.items()} output_alias_map = {name: tensor_info_proto.name for (name, tensor_info_proto) in signature.outputs.items()} csv_tensor_name = list(input_alias_map.values())[0] # The image tensor is already built into ML Workbench graph. float_image = g.get_tensor_by_name("import/gradients_%s:0" % image_column_name) if label not in output_alias_map: raise ValueError('The label "%s" does not exist in output map.' % label) prob = g.get_tensor_by_name(output_alias_map[label]) grads = tf.gradients(prob, float_image)[0] grads_values = sess.run(fetches=grads, feed_dict={csv_tensor_name: input_csvlines}) return grads_values def probe_image(self, labels, instance, column_name=None, num_scaled_images=50, top_percent=10): """ Get pixel importance of the image. It performs pixel sensitivity analysis by showing only the most important pixels to a certain label in the image. It uses integrated gradients to measure the importance of each pixel. Args: labels: labels to compute gradients from. instance: the prediction instance. It needs to conform to model's input. Can be a csv line string, or a dict. img_column_name: the name of the image column to probe. If there is only one image column it can be None. num_scaled_images: Number of scaled images to get grads from. For example, if 10, the image will be scaled by 0.1, 0.2, ..., 0,9, 1.0 and it will produce 10 images for grads computation. top_percent: The percentile of pixels to show only. for example, if 10, only top 10% impactful pixels will be shown and rest of the pixels will be black. Returns: A tuple. 
First is the resized original image (299x299x3). Second is a list of the visualization with same size that highlights the most important pixels, one per each label. """ if len(self._image_columns) > 1 and not column_name: raise ValueError('There are multiple image columns in the input of the model. ' + 'Please specify "column_name".') elif column_name and column_name not in self._image_columns: raise ValueError('Specified column_name "%s" not found in the model input.' % column_name) image_column_name = column_name if column_name else self._image_columns[0] if isinstance(instance, six.string_types): instance = next(csv.DictReader([instance], fieldnames=self._headers)) image_path = instance[image_column_name] with file_io.FileIO(image_path, 'rb') as fi: im = Image.open(fi) resized_image = im.resize((299, 299)) # Produce a list of scaled images, create instances (csv lines) from these images. step = 1. / num_scaled_images scales = np.arange(0.0, 1.0, step) + step csv_lines = [] for s in scales: pixels = (np.asarray(resized_image) * s).astype('uint8') scaled_image = Image.fromarray(pixels) buf = io.BytesIO() scaled_image.save(buf, "JPEG") encoded_image = base64.urlsafe_b64encode(buf.getvalue()).decode('ascii') instance_copy = dict(instance) instance_copy[image_column_name] = encoded_image buf = six.StringIO() writer = csv.DictWriter(buf, fieldnames=self._headers, lineterminator='') writer.writerow(instance_copy) csv_lines.append(buf.getvalue()) integrated_gradients_images = [] for label in labels: # Send to tf model to get gradients. grads = self._image_gradients(csv_lines, label, image_column_name) integrated_grads = resized_image * np.average(grads, axis=0) # Gray scale the grads by removing color dimension. # abs() is for getting the most impactful pixels regardless positive or negative. grayed = np.average(abs(integrated_grads), axis=2) grayed = np.transpose([grayed, grayed, grayed], axes=[1, 2, 0]) # Only show the most impactful pixels. p = np.percentile(grayed, 100 - top_percent) viz_window = np.where(grayed > p, 1, 0) vis = resized_image * viz_window im_vis = Image.fromarray(np.uint8(vis)) integrated_gradients_images.append(im_vis) return resized_image, integrated_gradients_images ================================================ FILE: google/datalab/contrib/mlworkbench/_shell_process.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Google Cloud Platform library - ml cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals import os import psutil import six import subprocess import sys def _wait_and_kill(pid_to_wait, pids_to_kill): """ Wait for a process to finish if it exists, and then kill a list of processes. Args: pid_to_wait: the process to wait for. pids_to_kill: a list of processes to kill after the process of pid_to_wait finishes. 
""" if psutil.pid_exists(pid_to_wait): psutil.Process(pid=pid_to_wait).wait() for pid_to_kill in pids_to_kill: if psutil.pid_exists(pid_to_kill): p = psutil.Process(pid=pid_to_kill) p.kill() def run_and_monitor(args, pid_to_wait, std_out_filter_fn=None, cwd=None): """ Start a process, and have it depend on another specified process. Args: args: the args of the process to start and monitor. pid_to_wait: the process to wait on. If the process ends, also kill the started process. std_out_filter_fn: a filter function which takes a string content from the stdout of the started process, and returns True if the string should be redirected to console stdout. cwd: the current working directory for the process to start. """ monitor_process = None try: p = subprocess.Popen(args, cwd=cwd, env=os.environ, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) pids_to_kill = [p.pid] script = ('import %s;%s._wait_and_kill(%s, %s)' % (__name__, __name__, str(pid_to_wait), str(pids_to_kill))) monitor_process = subprocess.Popen(['python', '-c', script], env=os.environ) while p.poll() is None: line = p.stdout.readline() if not six.PY2: line = line.decode() if std_out_filter_fn is None or std_out_filter_fn(line): sys.stdout.write(line) # Cannot do sys.stdout.flush(). It appears that too many flush() calls will hang browser. finally: if monitor_process: monitor_process.kill() ================================================ FILE: google/datalab/contrib/mlworkbench/commands/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _ml __all__ = ['_ml'] ================================================ FILE: google/datalab/contrib/mlworkbench/commands/_ml.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Google Cloud Platform library - ml cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import argparse import collections import json import os import pandas as pd import matplotlib.pyplot as plt import numpy as np import shutil import six from skimage.segmentation import mark_boundaries import subprocess import tempfile import textwrap import tensorflow as tf from tensorflow.python.lib.io import file_io import urllib import google.datalab from google.datalab import Context import google.datalab.ml as datalab_ml import google.datalab.utils.commands import google.datalab.contrib.mlworkbench._local_predict as _local_predict import google.datalab.contrib.mlworkbench._shell_process as _shell_process import google.datalab.contrib.mlworkbench._archive as _archive import google.datalab.contrib.mlworkbench._prediction_explainer as _prediction_explainer DEFAULT_PACKAGE_PATH = '/datalab/lib/pydatalab/solutionbox/ml_workbench/tensorflow/' @IPython.core.magic.register_line_cell_magic def ml(line, cell=None): """Implements the datalab cell magic for MLWorkbench operations. Args: line: the contents of the ml command line. Returns: The results of executing the cell. """ parser = google.datalab.utils.commands.CommandParser( prog='%ml', description=textwrap.dedent("""\ Execute MLWorkbench operations Use "%ml -h" for help on a specific command. """)) dataset_parser = parser.subcommand( 'dataset', formatter_class=argparse.RawTextHelpFormatter, help='Create or explore datasets.') dataset_sub_commands = dataset_parser.add_subparsers(dest='command') dataset_create_parser = dataset_sub_commands.add_parser( 'create', help='Create datasets', formatter_class=argparse.RawTextHelpFormatter, epilog=textwrap.dedent("""\ Example usage: %%ml dataset name: mydata format: csv train: path/to/train.csv eval: path/to/eval.csv schema: - name: news_label type: STRING - name: text type: STRING""")) dataset_create_parser.add_argument('--name', required=True, help='the name of the dataset to define. ') dataset_create_parser.add_argument('--format', required=True, choices=['csv', 'bigquery', 'transformed'], help='The format of the data.') dataset_create_parser.add_argument('--train', required=True, help='The path of the training file pattern if format ' + 'is csv or transformed, or table name if format ' + 'is bigquery.') dataset_create_parser.add_argument('--eval', required=True, help='The path of the eval file pattern if format ' + 'is csv or transformed, or table name if format ' + 'is bigquery.') dataset_create_parser.add_cell_argument('schema', help='yaml representation of CSV schema, or path to ' + 'schema file. Only needed if format is csv.') dataset_create_parser.set_defaults(func=_dataset_create) dataset_explore_parser = dataset_sub_commands.add_parser( 'explore', help='Explore training data.') dataset_explore_parser.add_argument('--name', required=True, help='The name of the dataset to explore.') dataset_explore_parser.add_argument('--overview', action='store_true', default=False, help='Plot overview of sampled data. Set "sample_size" ' + 'to change the default sample size.') dataset_explore_parser.add_argument('--facets', action='store_true', default=False, help='Plot facets view of sampled data. 
Set ' + '"sample_size" to change the default sample size.') dataset_explore_parser.add_argument('--sample_size', type=int, default=1000, help='sample size for overview or facets view. Only ' + 'used if either --overview or --facets is set.') dataset_explore_parser.set_defaults(func=_dataset_explore) analyze_parser = parser.subcommand( 'analyze', formatter_class=argparse.RawTextHelpFormatter, help='Analyze training data and generate stats, such as min/max/mean ' 'for numeric values, vocabulary for text columns.', epilog=textwrap.dedent("""\ Example usage: %%ml analyze [--cloud] output: path/to/dir data: $mydataset features: serialId: transform: key num1: transform: scale value: 1 num2: transform: identity text1: transform: bag_of_words Also supports in-notebook variables, such as: %%ml analyze --output path/to/dir training_data: $my_csv_dataset features: $features_def""")) analyze_parser.add_argument('--output', required=True, help='path of output directory.') analyze_parser.add_argument('--cloud', action='store_true', default=False, help='whether to run analysis in cloud or local.') analyze_parser.add_argument('--package', required=False, help='A local or GCS tarball path to use as the source. ' 'If not set, the default source package will be used.') analyze_parser.add_cell_argument( 'data', required=True, help="""Training data. A dataset defined by "%%ml dataset".""") analyze_parser.add_cell_argument( 'features', required=True, help=textwrap.dedent("""\ features config indicating how to transform data into features. The list of supported transforms: "transform: identity" does nothing (for numerical columns). "transform: scale value: x" scale a numerical column to [-a, a]. If value is missing, x defaults to 1. "transform: one_hot" treats the string column as categorical and makes one-hot encoding of it. "transform: embedding embedding_dim: d" treats the string column as categorical and makes embeddings of it with specified dimension size. "transform: bag_of_words" treats the string column as text and make bag of words transform of it. "transform: tfidf" treats the string column as text and make TFIDF transform of it. "transform: image_to_vec checkpoint: gs://b/o" from image gs url to embeddings. "checkpoint" is a inception v3 checkpoint. If absent, a default checkpoint is used. "transform: target" denotes the column is the target. If the schema type of this column is string, a one_hot encoding is automatically applied. If numerical, an identity transform is automatically applied. 
"transform: key" column contains metadata-like information and will be output as-is in prediction.""")) analyze_parser.set_defaults(func=_analyze) transform_parser = parser.subcommand( 'transform', formatter_class=argparse.RawTextHelpFormatter, help='Transform the data into tf.example which is more efficient in training.', epilog=textwrap.dedent("""\ Example usage: %%ml transform [--cloud] [--shuffle] analysis: path/to/analysis_output_folder output: path/to/dir batch_size: 100 data: $mydataset cloud: num_workers: 3 worker_machine_type: n1-standard-1 project_id: my_project_id""")) transform_parser.add_argument('--analysis', required=True, help='path of analysis output directory.') transform_parser.add_argument('--output', required=True, help='path of output directory.') transform_parser.add_argument('--cloud', action='store_true', default=False, help='whether to run transform in cloud or local.') transform_parser.add_argument('--shuffle', action='store_true', default=False, help='whether to shuffle the training data in output.') transform_parser.add_argument('--batch_size', type=int, default=100, help='number of instances in a batch to process once. ' 'Larger batch is more efficient but may consume more memory.') transform_parser.add_argument('--package', required=False, help='A local or GCS tarball path to use as the source. ' 'If not set, the default source package will be used.') transform_parser.add_cell_argument( 'data', required=True, help="""Training data. A dataset defined by "%%ml dataset".""") transform_parser.add_cell_argument( 'cloud_config', help=textwrap.dedent("""\ A dictionary of cloud config. All of them are optional. num_workers: Dataflow number of workers. If not set, DataFlow service will determine the number. worker_machine_type: a machine name from https://cloud.google.com/compute/docs/machine-types If not given, the service uses the default machine type. project_id: id of the project to use for DataFlow service. If not set, Datalab's default project (set by %%datalab project set) is used. job_name: Unique name for a Dataflow job to use. If not set, a random name will be used.""")) transform_parser.set_defaults(func=_transform) train_parser = parser.subcommand( 'train', formatter_class=argparse.RawTextHelpFormatter, help='Train a model.', epilog=textwrap.dedent("""\ Example usage: %%ml train [--cloud] analysis: path/to/analysis_output output: path/to/dir data: $mydataset model_args: model: linear_regression cloud_config: region: us-central1""")) train_parser.add_argument('--analysis', required=True, help='path of analysis output directory.') train_parser.add_argument('--output', required=True, help='path of trained model directory.') train_parser.add_argument('--cloud', action='store_true', default=False, help='whether to run training in cloud or local.') train_parser.add_argument('--notb', action='store_true', default=False, help='If set, tensorboard is not automatically started.') train_parser.add_argument('--package', required=False, help='A local or GCS tarball path to use as the source. ' 'If not set, the default source package will be used.') train_parser.add_cell_argument( 'data', required=True, help="""Training data. 
A dataset defined by "%%ml dataset".""") package_model_help = subprocess.Popen( ['python', '-m', 'trainer.task', '--datalab-help'], cwd=DEFAULT_PACKAGE_PATH, stdout=subprocess.PIPE).communicate()[0] package_model_help = ('model_args: a dictionary of model specific args, including:\n\n' + package_model_help.decode()) train_parser.add_cell_argument('model_args', help=package_model_help) train_parser.add_cell_argument( 'cloud_config', help=textwrap.dedent("""\ A dictionary of cloud training config, including: job_id: the name of the job. If not provided, a default job name is created. region: see {url} runtime_version: see "region". Must be a string like '1.2'. scale_tier: see "region".""".format( url='https://cloud.google.com/sdk/gcloud/reference/ml-engine/jobs/submit/training'))) train_parser.set_defaults(func=_train) predict_parser = parser.subcommand( 'predict', formatter_class=argparse.RawTextHelpFormatter, help='Predict with local or deployed models. (Good for small datasets).', epilog=textwrap.dedent("""\ Example usage: %%ml predict headers: key,num model: path/to/model data: - key1,value1 - key2,value2 Or, in another cell, define a list of dict: my_data = [{'key': 1, 'num': 1.2}, {'key': 2, 'num': 2.8}] Then: %%ml predict headers: key,num model: path/to/model data: $my_data""")) predict_parser.add_argument('--model', required=True, help='The model path.') predict_parser.add_argument('--no_show_image', action='store_true', default=False, help='If not set, add a column of images in output.') predict_parser.add_cell_argument( 'data', required=True, help=textwrap.dedent("""\ Prediction data can be 1) CSV lines in the input cell in yaml format or 2) a local variable which is one of a) list of dict b) list of strings of csv lines c) a Pandas DataFrame""")) predict_parser.set_defaults(func=_predict) batch_predict_parser = parser.subcommand( 'batch_predict', formatter_class=argparse.RawTextHelpFormatter, help='Batch prediction with local or deployed models. (Good for large datasets)', epilog=textwrap.dedent("""\ Example usage: %%ml batch_predict [--cloud] model: path/to/model output: path/to/output format: csv data: csv: path/to/file_pattern""")) batch_predict_parser.add_argument('--model', required=True, help='The model path if not --cloud, or the id in ' 'the form of model.version if --cloud.') batch_predict_parser.add_argument('--output', required=True, help='The path of output directory with prediction results. ' 'If --cloud, it has to be GCS path.') batch_predict_parser.add_argument('--format', help='csv or json. For cloud run, ' 'the only supported format is json.') batch_predict_parser.add_argument('--batch_size', type=int, default=100, help='number of instances in a batch to process once. ' 'Larger batch is more efficient but may consume ' 'more memory. Only used in local run.') batch_predict_parser.add_argument('--cloud', action='store_true', default=False, help='whether to run prediction in cloud or local.') batch_predict_parser.add_cell_argument( 'data', required=True, help='Data to predict with. Only csv is supported.') batch_predict_parser.add_cell_argument( 'cloud_config', help=textwrap.dedent("""\ A dictionary of cloud batch prediction config. job_id: the name of the job. If not provided, a default job name is created. 
region: see {url} max_worker_count: see reference in "region".""".format( url='https://cloud.google.com/sdk/gcloud/reference/ml-engine/jobs/submit/prediction'))) # noqa batch_predict_parser.set_defaults(func=_batch_predict) explain_parser = parser.subcommand( 'explain', formatter_class=argparse.RawTextHelpFormatter, help='Explain a prediction with LIME tool.') explain_parser.add_argument('--type', default='all', choices=['text', 'image', 'tabular', 'all'], help='the type of column to explain.') explain_parser.add_argument('--algorithm', choices=['lime', 'ig'], default='lime', help='"lime" is the open sourced project for prediction explainer.' + '"ig" means integrated gradients and currently only applies ' + 'to image.') explain_parser.add_argument('--model', required=True, help='path of the model directory used for prediction.') explain_parser.add_argument('--labels', required=True, help='comma separated labels to explain.') explain_parser.add_argument('--column_name', help='the name of the column to explain. Optional if text type ' + 'and there is only one text column, or image type and ' + 'there is only one image column.') explain_parser.add_cell_argument('data', required=True, help='Prediction Data. Can be a csv line, or a dict.') explain_parser.add_cell_argument('training_data', help='A csv or bigquery dataset defined by %%ml dataset. ' + 'Used by tabular explainer only to determine the ' + 'distribution of numeric and categorical values. ' + 'Suggest using original training dataset.') # options specific for lime explain_parser.add_argument('--num_features', type=int, help='number of features to analyze. In text, it is number of ' + 'words. In image, it is number of areas. For lime only.') explain_parser.add_argument('--num_samples', type=int, help='size of the neighborhood to learn the linear model. ' + 'For lime only.') explain_parser.add_argument('--hide_color', type=int, default=0, help='the color to use for perturbed area. If -1, average of ' + 'each channel is used for each channel. For image only.') explain_parser.add_argument('--include_negative', action='store_true', default=False, help='whether to show only positive areas. For lime image only.') explain_parser.add_argument('--overview', action='store_true', default=False, help='whether to show overview instead of details view.' + 'For lime text and tabular only.') explain_parser.add_argument('--batch_size', type=int, default=100, help='size of batches passed to prediction. For lime only.') # options specific for integrated gradients explain_parser.add_argument('--num_gradients', type=int, default=50, help='the number of scaled images to get gradients from. 
Larger ' + 'number usually produces better results but slower.') explain_parser.add_argument('--percent_show', type=int, default=10, help='the percentage of top impactful pixels to show.') explain_parser.set_defaults(func=_explain) tensorboard_parser = parser.subcommand( 'tensorboard', formatter_class=argparse.RawTextHelpFormatter, help='Start/stop/list TensorBoard instances.') tensorboard_sub_commands = tensorboard_parser.add_subparsers(dest='command') tensorboard_start_parser = tensorboard_sub_commands.add_parser( 'start', help='Start a tensorboard instance.') tensorboard_start_parser.add_argument('--logdir', required=True, help='The local or GCS logdir path.') tensorboard_start_parser.set_defaults(func=_tensorboard_start) tensorboard_stop_parser = tensorboard_sub_commands.add_parser( 'stop', help='Stop a tensorboard instance.') tensorboard_stop_parser.add_argument('--pid', required=True, type=int, help='The pid of the tensorboard instance.') tensorboard_stop_parser.set_defaults(func=_tensorboard_stop) tensorboard_list_parser = tensorboard_sub_commands.add_parser( 'list', help='List tensorboard instances.') tensorboard_list_parser.set_defaults(func=_tensorboard_list) evaluate_parser = parser.subcommand( 'evaluate', formatter_class=argparse.RawTextHelpFormatter, help='Analyze model evaluation results, such as confusion matrix, ROC, RMSE.') evaluate_sub_commands = evaluate_parser.add_subparsers(dest='command') def _add_data_params_for_evaluate(parser): parser.add_argument('--csv', help='csv file path patterns.') parser.add_argument('--headers', help='csv file headers. Required if csv is specified and ' + 'predict_results_schema.json does not exist in the same directory.') parser.add_argument('--bigquery', help='can be bigquery table, query as a string, or ' + 'a pre-defined query (%%bq query --name).') evaluate_cm_parser = evaluate_sub_commands.add_parser( 'confusion_matrix', help='Get confusion matrix from evaluation results.') _add_data_params_for_evaluate(evaluate_cm_parser) evaluate_cm_parser.add_argument('--plot', action='store_true', default=False, help='Whether to plot confusion matrix as graph.') evaluate_cm_parser.add_argument('--size', type=int, default=10, help='The size of the confusion matrix.') evaluate_cm_parser.set_defaults(func=_evaluate_cm) evaluate_accuracy_parser = evaluate_sub_commands.add_parser( 'accuracy', help='Get accuracy results from classification evaluation results.') _add_data_params_for_evaluate(evaluate_accuracy_parser) evaluate_accuracy_parser.set_defaults(func=_evaluate_accuracy) evaluate_pr_parser = evaluate_sub_commands.add_parser( 'precision_recall', help='Get precision recall metrics from evaluation results.') _add_data_params_for_evaluate(evaluate_pr_parser) evaluate_pr_parser.add_argument('--plot', action='store_true', default=False, help='Whether to plot precision recall as graph.') evaluate_pr_parser.add_argument('--num_thresholds', type=int, default=20, help='Number of thresholds which determines how many ' + 'points in the graph.') evaluate_pr_parser.add_argument('--target_class', required=True, help='The target class to determine correctness of ' + 'a prediction.') evaluate_pr_parser.add_argument('--probability_column', help='The name of the column holding the probability ' + 'value of the target class. 
If absent, the value ' + 'of target class is used.') evaluate_pr_parser.set_defaults(func=_evaluate_pr) evaluate_roc_parser = evaluate_sub_commands.add_parser( 'roc', help='Get ROC metrics from evaluation results.') _add_data_params_for_evaluate(evaluate_roc_parser) evaluate_roc_parser.add_argument('--plot', action='store_true', default=False, help='Whether to plot ROC as graph.') evaluate_roc_parser.add_argument('--num_thresholds', type=int, default=20, help='Number of thresholds which determines how many ' + 'points in the graph.') evaluate_roc_parser.add_argument('--target_class', required=True, help='The target class to determine correctness of ' + 'a prediction.') evaluate_roc_parser.add_argument('--probability_column', help='The name of the column holding the probability ' + 'value of the target class. If absent, the value ' + 'of target class is used.') evaluate_roc_parser.set_defaults(func=_evaluate_roc) evaluate_regression_parser = evaluate_sub_commands.add_parser( 'regression', help='Get regression metrics from evaluation results.') _add_data_params_for_evaluate(evaluate_regression_parser) evaluate_regression_parser.set_defaults(func=_evaluate_regression) model_parser = parser.subcommand( 'model', help='Models and versions management such as deployment, deletion, listing.') model_sub_commands = model_parser.add_subparsers(dest='command') model_list_parser = model_sub_commands.add_parser( 'list', help='List models and versions.') model_list_parser.add_argument('--name', help='If absent, list all models of specified or current ' + 'project. If provided, list all versions of the ' + 'model.') model_list_parser.add_argument('--project', help='The project to list model(s) or version(s). If absent, ' + 'use Datalab\'s default project.') model_list_parser.set_defaults(func=_model_list) model_delete_parser = model_sub_commands.add_parser( 'delete', help='Delete models or versions.') model_delete_parser.add_argument('--name', required=True, help='If no "." in the name, try deleting the specified ' + 'model. If "model.version" is provided, try deleting ' + 'the specified version.') model_delete_parser.add_argument('--project', help='The project to delete model or version. If absent, ' + 'use Datalab\'s default project.') model_delete_parser.set_defaults(func=_model_delete) model_deploy_parser = model_sub_commands.add_parser( 'deploy', help='Deploy a model version.') model_deploy_parser.add_argument('--name', required=True, help='Must be model.version to indicate the model ' + 'and version name to deploy.') model_deploy_parser.add_argument('--path', required=True, help='The GCS path of the model to be deployed.') model_deploy_parser.add_argument('--runtime_version', help='The TensorFlow version to use for this model. ' + 'For example, "1.2.1". If absent, the current ' + 'TensorFlow version installed in Datalab will be used.') model_deploy_parser.add_argument('--project', help='The project to deploy a model version. If absent, ' + 'use Datalab\'s default project.') model_deploy_parser.set_defaults(func=_model_deploy) return google.datalab.utils.commands.handle_magic_line(line, cell, parser) DataSet = collections.namedtuple('DataSet', ['train', 'eval']) def _abs_path(path): """Convert a non-GCS path to its absolute path. path can contain special filepath characters like '..', '*' and '.'. Example: If the current folder is /content/datalab/folder1 and path is '../folder2/files*', then this function returns the string '/content/datalab/folder2/files*'. 
  This function is needed if using _shell_process.run_and_monitor() as that function runs
  a command in a different folder.

  Args:
    path: string.
  """
  if path.startswith('gs://'):
    return path
  return os.path.abspath(path)


def _create_json_file(tmpdir, data, filename):
  json_file = os.path.join(tmpdir, filename)
  with file_io.FileIO(json_file, 'w') as f:
    json.dump(data, f)
  return json_file


def _show_job_link(job):
  log_url_query_strings = {
      'project': Context.default().project_id,
      'resource': 'ml.googleapis.com/job_id/' + job.info['jobId']
  }
  log_url = 'https://console.developers.google.com/logs/viewer?' + \
      urllib.urlencode(log_url_query_strings)
  html = 'Job "%s" submitted.' % job.info['jobId']
  html += 'Click <a href="%s" target="_blank">here</a> to view cloud log. 
' % log_url IPython.display.display_html(html, raw=True) def get_dataset_from_arg(dataset_arg): if isinstance(dataset_arg, DataSet): return dataset_arg if isinstance(dataset_arg, six.string_types): return google.datalab.utils.commands.notebook_environment()[dataset_arg] raise ValueError('Invalid dataset reference "%s". ' % dataset_arg + 'Expect a dataset defined with "%%ml dataset create".') def _analyze(args, cell): # For now, always run python2. If needed we can run python3 when the current kernel # is py3. Since now our transform cannot work on py3 anyway, I would rather run # everything with python2. cmd_args = ['python', 'analyze.py', '--output', _abs_path(args['output'])] if args['cloud']: cmd_args.append('--cloud') training_data = get_dataset_from_arg(args['data']) if args['cloud']: tmpdir = os.path.join(args['output'], 'tmp') else: tmpdir = tempfile.mkdtemp() try: if isinstance(training_data.train, datalab_ml.CsvDataSet): csv_data = training_data.train schema_file = _create_json_file(tmpdir, csv_data.schema, 'schema.json') for file_name in csv_data.input_files: cmd_args.append('--csv=' + _abs_path(file_name)) cmd_args.extend(['--schema', schema_file]) elif isinstance(training_data.train, datalab_ml.BigQueryDataSet): bq_data = training_data.train cmd_args.extend(['--bigquery', bq_data.table]) else: raise ValueError('Unexpected training data type. Only csv or bigquery are supported.') features = args['features'] features_file = _create_json_file(tmpdir, features, 'features.json') cmd_args.extend(['--features', features_file]) if args['package']: code_path = os.path.join(tmpdir, 'package') _archive.extract_archive(args['package'], code_path) else: code_path = DEFAULT_PACKAGE_PATH _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path) finally: file_io.delete_recursively(tmpdir) def _transform(args, cell): if args['cloud_config'] and not args['cloud']: raise ValueError('"cloud_config" is provided but no "--cloud". 
' 'Do you want local run or cloud run?') cmd_args = ['python', 'transform.py', '--output', _abs_path(args['output']), '--analysis', _abs_path(args['analysis'])] if args['cloud']: cmd_args.append('--cloud') cmd_args.append('--async') if args['shuffle']: cmd_args.append('--shuffle') if args['batch_size']: cmd_args.extend(['--batch-size', str(args['batch_size'])]) cloud_config = args['cloud_config'] if cloud_config: google.datalab.utils.commands.validate_config( cloud_config, required_keys=[], optional_keys=['num_workers', 'worker_machine_type', 'project_id', 'job_name']) if 'num_workers' in cloud_config: cmd_args.extend(['--num-workers', str(cloud_config['num_workers'])]) if 'worker_machine_type' in cloud_config: cmd_args.extend(['--worker-machine-type', cloud_config['worker_machine_type']]) if 'project_id' in cloud_config: cmd_args.extend(['--project-id', cloud_config['project_id']]) if 'job_name' in cloud_config: cmd_args.extend(['--job-name', cloud_config['job_name']]) if args['cloud'] and (not cloud_config or 'project_id' not in cloud_config): cmd_args.extend(['--project-id', google.datalab.Context.default().project_id]) training_data = get_dataset_from_arg(args['data']) data_names = ('train', 'eval') for name in data_names: cmd_args_copy = list(cmd_args) if isinstance(getattr(training_data, name), datalab_ml.CsvDataSet): for file_name in getattr(training_data, name).input_files: cmd_args_copy.append('--csv=' + _abs_path(file_name)) elif isinstance(getattr(training_data, name), datalab_ml.BigQueryDataSet): cmd_args_copy.extend(['--bigquery', getattr(training_data, name).table]) else: raise ValueError('Unexpected training data type. Only csv or bigquery are supported.') cmd_args_copy.extend(['--prefix', name]) try: tmpdir = None if args['package']: tmpdir = tempfile.mkdtemp() code_path = os.path.join(tmpdir, 'package') _archive.extract_archive(args['package'], code_path) else: code_path = DEFAULT_PACKAGE_PATH _shell_process.run_and_monitor(cmd_args_copy, os.getpid(), cwd=code_path) finally: if tmpdir: shutil.rmtree(tmpdir) def _train(args, cell): if args['cloud_config'] and not args['cloud']: raise ValueError('"cloud_config" is provided but no "--cloud". ' 'Do you want local run or cloud run?') job_args = ['--job-dir', _abs_path(args['output']), '--analysis', _abs_path(args['analysis'])] training_data = get_dataset_from_arg(args['data']) data_names = ('train', 'eval') for name in data_names: if (isinstance(getattr(training_data, name), datalab_ml.CsvDataSet) or isinstance(getattr(training_data, name), datalab_ml.TransformedDataSet)): for file_name in getattr(training_data, name).input_files: job_args.append('--%s=%s' % (name, _abs_path(file_name))) else: raise ValueError('Unexpected training data type. 
' + 'Only csv and transformed type are supported.') if isinstance(training_data.train, datalab_ml.CsvDataSet): job_args.append('--transform') # TODO(brandondutra) document that any model_args that are file paths must # be given as an absolute path if args['model_args']: for k, v in six.iteritems(args['model_args']): job_args.extend(['--' + k, str(v)]) try: tmpdir = None if args['package']: tmpdir = tempfile.mkdtemp() code_path = os.path.join(tmpdir, 'package') _archive.extract_archive(args['package'], code_path) else: code_path = DEFAULT_PACKAGE_PATH if args['cloud']: cloud_config = args['cloud_config'] if not args['output'].startswith('gs://'): raise ValueError('Cloud training requires a GCS (starting with "gs://") output.') staging_tarball = os.path.join(args['output'], 'staging', 'trainer.tar.gz') datalab_ml.package_and_copy(code_path, os.path.join(code_path, 'setup.py'), staging_tarball) job_request = { 'package_uris': [staging_tarball], 'python_module': 'trainer.task', 'job_dir': args['output'], 'args': job_args, } job_request.update(cloud_config) job_id = cloud_config.get('job_id', None) job = datalab_ml.Job.submit_training(job_request, job_id) _show_job_link(job) if not args['notb']: datalab_ml.TensorBoard.start(args['output']) else: cmd_args = ['python', '-m', 'trainer.task'] + job_args if not args['notb']: datalab_ml.TensorBoard.start(args['output']) _shell_process.run_and_monitor(cmd_args, os.getpid(), cwd=code_path) finally: if tmpdir: shutil.rmtree(tmpdir) def _predict(args, cell): schema, features = _local_predict.get_model_schema_and_features(args['model']) headers = [x['name'] for x in schema] img_cols = [] for k, v in six.iteritems(features): if v['transform'] in ['image_to_vec']: img_cols.append(v['source_column']) data = args['data'] df = _local_predict.get_prediction_results( args['model'], data, headers, img_cols=img_cols, cloud=False, show_image=not args['no_show_image']) def _show_img(img_bytes): return '' def _truncate_text(text): return (text[:37] + '...') if isinstance(text, six.string_types) and len(text) > 40 else text # Truncate text explicitly here because we will set display.max_colwidth to -1. # This applies to images to but images will be overriden with "_show_img()" later. formatters = {x: _truncate_text for x in df.columns if df[x].dtype == np.object} if not args['no_show_image'] and img_cols: formatters.update({x + '_image': _show_img for x in img_cols}) # Set display.max_colwidth to -1 so we can display images. old_width = pd.get_option('display.max_colwidth') pd.set_option('display.max_colwidth', -1) try: IPython.display.display(IPython.display.HTML( df.to_html(formatters=formatters, escape=False, index=False))) finally: pd.set_option('display.max_colwidth', old_width) def _batch_predict(args, cell): if args['cloud_config'] and not args['cloud']: raise ValueError('"cloud_config" is provided but no "--cloud". ' 'Do you want local run or cloud run?') if args['cloud']: job_request = { 'data_format': 'TEXT', 'input_paths': file_io.get_matching_files(args['data']['csv']), 'output_path': args['output'], } if args['model'].startswith('gs://'): job_request['uri'] = args['model'] else: parts = args['model'].split('.') if len(parts) != 2: raise ValueError('Invalid model name for cloud prediction. 
Use "model.version".') version_name = ('projects/%s/models/%s/versions/%s' % (Context.default().project_id, parts[0], parts[1])) job_request['version_name'] = version_name cloud_config = args['cloud_config'] or {} job_id = cloud_config.pop('job_id', None) job_request.update(cloud_config) job = datalab_ml.Job.submit_batch_prediction(job_request, job_id) _show_job_link(job) else: print('local prediction...') _local_predict.local_batch_predict(args['model'], args['data']['csv'], args['output'], args['format'], args['batch_size']) print('done.') # Helper classes for explainer. Each for is for a combination # of algorithm (LIME, IG) and type (text, image, tabular) # =========================================================== class _TextLimeExplainerInstance(object): def __init__(self, explainer, labels, args): num_features = args['num_features'] if args['num_features'] else 10 num_samples = args['num_samples'] if args['num_samples'] else 5000 self._exp = explainer.explain_text( labels, args['data'], column_name=args['column_name'], num_features=num_features, num_samples=num_samples) self._col_name = args['column_name'] if args['column_name'] else explainer._text_columns[0] self._show_overview = args['overview'] def visualize(self, label_index): if self._show_overview: IPython.display.display( IPython.display.HTML('
Text Column "%s"
' % self._col_name)) self._exp.show_in_notebook(labels=[label_index]) else: fig = self._exp.as_pyplot_figure(label=label_index) # Clear original title set by lime. plt.title('') fig.suptitle('Text Column "%s"' % self._col_name, fontsize=16) plt.close(fig) IPython.display.display(fig) class _ImageLimeExplainerInstance(object): def __init__(self, explainer, labels, args): num_samples = args['num_samples'] if args['num_samples'] else 300 hide_color = None if args['hide_color'] == -1 else args['hide_color'] self._exp = explainer.explain_image( labels, args['data'], column_name=args['column_name'], num_samples=num_samples, batch_size=args['batch_size'], hide_color=hide_color) self._labels = labels self._positive_only = not args['include_negative'] self._num_features = args['num_features'] if args['num_features'] else 3 self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0] def visualize(self, label_index): image, mask = self._exp.get_image_and_mask( label_index, positive_only=self._positive_only, num_features=self._num_features, hide_rest=False) fig = plt.figure() fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16) plt.grid(False) plt.imshow(mark_boundaries(image, mask)) plt.close(fig) IPython.display.display(fig) class _ImageIgExplainerInstance(object): def __init__(self, explainer, labels, args): self._raw_image, self._analysis_images = explainer.probe_image( labels, args['data'], column_name=args['column_name'], num_scaled_images=args['num_gradients'], top_percent=args['percent_show']) self._labels = labels self._col_name = args['column_name'] if args['column_name'] else explainer._image_columns[0] def visualize(self, label_index): # Show both resized raw image and analyzed image. fig = plt.figure() fig.suptitle('Image Column "%s"' % self._col_name, fontsize=16) plt.grid(False) plt.imshow(self._analysis_images[label_index]) plt.close(fig) IPython.display.display(fig) class _TabularLimeExplainerInstance(object): def __init__(self, explainer, labels, args): if not args['training_data']: raise ValueError('tabular explanation requires training_data to determine ' + 'values distribution.') training_data = get_dataset_from_arg(args['training_data']) if (not isinstance(training_data.train, datalab_ml.CsvDataSet) and not isinstance(training_data.train, datalab_ml.BigQueryDataSet)): raise ValueError('Require csv or bigquery dataset.') sample_size = min(training_data.train.size, 10000) training_df = training_data.train.sample(sample_size) num_features = args['num_features'] if args['num_features'] else 5 self._exp = explainer.explain_tabular(training_df, labels, args['data'], num_features=num_features) self._show_overview = args['overview'] def visualize(self, label_index): if self._show_overview: IPython.display.display( IPython.display.HTML('
All Categorical and Numeric Columns
')) self._exp.show_in_notebook(labels=[label_index]) else: fig = self._exp.as_pyplot_figure(label=label_index) # Clear original title set by lime. plt.title('') fig.suptitle(' All Categorical and Numeric Columns', fontsize=16) plt.close(fig) IPython.display.display(fig) # End of Explainer Helper Classes # =================================================== def _explain(args, cell): explainer = _prediction_explainer.PredictionExplainer(args['model']) labels = args['labels'].split(',') instances = [] if args['type'] == 'all': if explainer._numeric_columns or explainer._categorical_columns: instances.append(_TabularLimeExplainerInstance(explainer, labels, args)) for col_name in explainer._text_columns: args['column_name'] = col_name instances.append(_TextLimeExplainerInstance(explainer, labels, args)) for col_name in explainer._image_columns: args['column_name'] = col_name if args['algorithm'] == 'lime': instances.append(_ImageLimeExplainerInstance(explainer, labels, args)) elif args['algorithm'] == 'ig': instances.append(_ImageIgExplainerInstance(explainer, labels, args)) elif args['type'] == 'text': instances.append(_TextLimeExplainerInstance(explainer, labels, args)) elif args['type'] == 'image' and args['algorithm'] == 'lime': instances.append(_ImageLimeExplainerInstance(explainer, labels, args)) elif args['type'] == 'image' and args['algorithm'] == 'ig': instances.append(_ImageIgExplainerInstance(explainer, labels, args)) elif args['type'] == 'tabular': instances.append(_TabularLimeExplainerInstance(explainer, labels, args)) for i, label in enumerate(labels): IPython.display.display( IPython.display.HTML('
Explaining features for label "%s"
' % label)) for instance in instances: instance.visualize(i) def _tensorboard_start(args, cell): datalab_ml.TensorBoard.start(args['logdir']) def _tensorboard_stop(args, cell): datalab_ml.TensorBoard.stop(args['pid']) def _tensorboard_list(args, cell): return datalab_ml.TensorBoard.list() def _get_evaluation_csv_schema(csv_file): # ML Workbench produces predict_results_schema.json in local batch prediction. schema_file = os.path.join(os.path.dirname(csv_file), 'predict_results_schema.json') if not file_io.file_exists(schema_file): raise ValueError('csv data requires headers.') return schema_file def _evaluate_cm(args, cell): if args['csv']: if args['headers']: headers = args['headers'].split(',') cm = datalab_ml.ConfusionMatrix.from_csv(args['csv'], headers=headers) else: schema_file = _get_evaluation_csv_schema(args['csv']) cm = datalab_ml.ConfusionMatrix.from_csv(args['csv'], schema_file=schema_file) elif args['bigquery']: cm = datalab_ml.ConfusionMatrix.from_bigquery(args['bigquery']) else: raise ValueError('Either csv or bigquery is needed.') if args['plot']: return cm.plot(figsize=(args['size'], args['size']), rotation=90) else: return cm.to_dataframe() def _create_metrics(args): if args['csv']: if args['headers']: headers = args['headers'].split(',') metrics = datalab_ml.Metrics.from_csv(args['csv'], headers=headers) else: schema_file = _get_evaluation_csv_schema(args['csv']) metrics = datalab_ml.Metrics.from_csv(args['csv'], schema_file=schema_file) elif args['bigquery']: metrics = datalab_ml.Metrics.from_bigquery(args['bigquery']) else: raise ValueError('Either csv or bigquery is needed.') return metrics def _evaluate_accuracy(args, cell): metrics = _create_metrics(args) return metrics.accuracy() def _evaluate_regression(args, cell): metrics = _create_metrics(args) metrics_dict = [] metrics_dict.append({ 'metric': 'Root Mean Square Error', 'value': metrics.rmse() }) metrics_dict.append({ 'metric': 'Mean Absolute Error', 'value': metrics.mae() }) metrics_dict.append({ 'metric': '50 Percentile Absolute Error', 'value': metrics.percentile_nearest(50) }) metrics_dict.append({ 'metric': '90 Percentile Absolute Error', 'value': metrics.percentile_nearest(90) }) metrics_dict.append({ 'metric': '99 Percentile Absolute Error', 'value': metrics.percentile_nearest(99) }) return pd.DataFrame(metrics_dict) def _evaluate_pr(args, cell): metrics = _create_metrics(args) df = metrics.precision_recall(args['num_thresholds'], args['target_class'], probability_column=args['probability_column']) if args['plot']: plt.plot(df['recall'], df['precision'], label='Precision-Recall curve for class ' + args['target_class']) plt.xlabel('Recall') plt.ylabel('Precision') plt.ylim([0.0, 1.05]) plt.xlim([0.0, 1.0]) plt.title('Precision-Recall') plt.legend(loc="lower left") plt.show() else: return df def _evaluate_roc(args, cell): metrics = _create_metrics(args) df = metrics.roc(args['num_thresholds'], args['target_class'], probability_column=args['probability_column']) if args['plot']: plt.plot(df['fpr'], df['tpr'], label='ROC curve for class ' + args['target_class']) plt.xlabel('fpr') plt.ylabel('tpr') plt.ylim([0.0, 1.05]) plt.xlim([0.0, 1.0]) plt.title('ROC') plt.legend(loc="lower left") plt.show() else: return df def _model_list(args, cell): if args['name']: # model name provided. List versions of that model. 
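# Each returned version's 'name' is a full resource path
# ('projects/<project>/models/<model>/versions/<version>'); the code below keeps only the
# trailing segment for display and replaces NaN cells with empty strings before returning
# the DataFrame.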
versions = datalab_ml.ModelVersions(args['name'], project_id=args['project']) versions = list(versions.get_iterator()) df = pd.DataFrame(versions) df['name'] = df['name'].apply(lambda x: x.split('/')[-1]) df = df.replace(np.nan, '', regex=True) return df else: # List all models. models = list(datalab_ml.Models(project_id=args['project']).get_iterator()) if len(models) > 0: df = pd.DataFrame(models) df['name'] = df['name'].apply(lambda x: x.split('/')[-1]) df['defaultVersion'] = df['defaultVersion'].apply(lambda x: x['name'].split('/')[-1]) df = df.replace(np.nan, '', regex=True) return df else: print('No models found.') def _model_delete(args, cell): parts = args['name'].split('.') if len(parts) == 1: models = datalab_ml.Models(project_id=args['project']) models.delete(parts[0]) elif len(parts) == 2: versions = datalab_ml.ModelVersions(parts[0], project_id=args['project']) versions.delete(parts[1]) else: raise ValueError('Too many "." in name. Use "model" or "model.version".') def _model_deploy(args, cell): parts = args['name'].split('.') if len(parts) == 2: model_name, version_name = parts[0], parts[1] model_exists = False try: # If describe() works, the model already exists. datalab_ml.Models(project_id=args['project']).get_model_details(model_name) model_exists = True except: pass if not model_exists: datalab_ml.Models(project_id=args['project']).create(model_name) versions = datalab_ml.ModelVersions(model_name, project_id=args['project']) runtime_version = args['runtime_version'] if not runtime_version: runtime_version = tf.__version__ versions.deploy(version_name, args['path'], runtime_version=runtime_version) else: raise ValueError('Name must be like "model.version".') def _dataset_create(args, cell): if args['format'] == 'csv': if not args['schema']: raise ValueError('schema is required if format is csv.') schema, schema_file = None, None if isinstance(args['schema'], six.string_types): schema_file = args['schema'] elif isinstance(args['schema'], list): schema = args['schema'] else: raise ValueError('schema should either be a file path, or a dictionary.') train_dataset = datalab_ml.CsvDataSet(args['train'], schema=schema, schema_file=schema_file) eval_dataset = datalab_ml.CsvDataSet(args['eval'], schema=schema, schema_file=schema_file) elif args['format'] == 'bigquery': train_dataset = datalab_ml.BigQueryDataSet(table=args['train']) eval_dataset = datalab_ml.BigQueryDataSet(table=args['eval']) elif args['format'] == 'transformed': train_dataset = datalab_ml.TransformedDataSet(args['train']) eval_dataset = datalab_ml.TransformedDataSet(args['eval']) else: raise ValueError('Invalid data format.') dataset = DataSet(train_dataset, eval_dataset) google.datalab.utils.commands.notebook_environment()[args['name']] = dataset def _dataset_explore(args, cell): dataset = get_dataset_from_arg(args['name']) print('train data instances: %d' % dataset.train.size) print('eval data instances: %d' % dataset.eval.size) if args['overview'] or args['facets']: if isinstance(dataset.train, datalab_ml.TransformedDataSet): raise ValueError('transformed data does not support overview or facets.') print('Sampled %s instances for each.' 
% args['sample_size']) sample_train_df = dataset.train.sample(args['sample_size']) sample_eval_df = dataset.eval.sample(args['sample_size']) if args['overview']: overview = datalab_ml.FacetsOverview().plot({'train': sample_train_df, 'eval': sample_eval_df}) IPython.display.display(overview) if args['facets']: sample_train_df['_source'] = pd.Series(['train'] * len(sample_train_df), index=sample_train_df.index) sample_eval_df['_source'] = pd.Series(['eval'] * len(sample_eval_df), index=sample_eval_df.index) df_merged = pd.concat([sample_train_df, sample_eval_df]) diveview = datalab_ml.FacetsDiveview().plot(df_merged) IPython.display.display(diveview) ================================================ FILE: google/datalab/contrib/pipeline/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: google/datalab/contrib/pipeline/_pipeline.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import datetime import google import google.datalab.bigquery as bigquery import sys # Any operators need to be imported here. This is required for dynamically getting the list of # templated fields from the operators. Static code-analysis will report that this is not # necessary, hence the '# noqa' annotations from google.datalab.contrib.bigquery.operators._bq_load_operator import LoadOperator # noqa from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator # noqa from google.datalab.contrib.bigquery.operators._bq_extract_operator import ExtractOperator # noqa class PipelineGenerator(object): """ Represents a Pipeline object that encapsulates an Airflow pipeline spec. This object can be used to generate the python airflow spec. 
""" _imports = """ import datetime from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.contrib.operators.bigquery_operator import BigQueryOperator from airflow.contrib.operators.bigquery_table_delete_operator import BigQueryTableDeleteOperator from airflow.contrib.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator from airflow.contrib.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator from google.datalab.contrib.bigquery.operators._bq_load_operator import LoadOperator from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator from google.datalab.contrib.bigquery.operators._bq_extract_operator import ExtractOperator from datetime import timedelta """ @staticmethod def generate_airflow_spec(name, pipeline_spec): """ Gets the airflow python spec for the Pipeline object. """ task_definitions = '' up_steam_statements = '' parameters = pipeline_spec.get('parameters') for (task_id, task_details) in sorted(pipeline_spec['tasks'].items()): task_def = PipelineGenerator._get_operator_definition(task_id, task_details, parameters) task_definitions = task_definitions + task_def dependency_def = PipelineGenerator._get_dependency_definition( task_id, task_details.get('up_stream', [])) up_steam_statements = up_steam_statements + dependency_def schedule_config = pipeline_spec.get('schedule', {}) default_args = PipelineGenerator._get_default_args(schedule_config, pipeline_spec.get('emails', {})) dag_definition = PipelineGenerator._get_dag_definition( name, schedule_config.get('interval', '@once'), schedule_config.get('catchup', False)) return PipelineGenerator._imports + default_args + dag_definition + task_definitions + \ up_steam_statements @staticmethod def _get_default_args(schedule_config, emails): start_datetime_obj = schedule_config.get('start', datetime.datetime.now()) end_datetime_obj = schedule_config.get('end') start_date_str = PipelineGenerator._get_datetime_expr_str(start_datetime_obj) end_date_str = PipelineGenerator._get_datetime_expr_str(end_datetime_obj) default_arg_literals = """ 'owner': 'Google Cloud Datalab', 'email': {0}, 'start_date': {1}, 'end_date': {2}, """.format(emails.split(',') if emails else [], start_date_str, end_date_str) configurable_keys = ['email_on_retry', 'email_on_failure', 'retries', 'retry_exponential_backoff'] for configurable_key in configurable_keys: if configurable_key in schedule_config: default_arg_literals = default_arg_literals + """ \'{0}\': {1}, """.format(configurable_key, schedule_config.get(configurable_key)) # We deal with these separately as they need to be timedelta literals. 
retry_delay_keys = ['retry_delay_seconds', 'max_retry_delay_seconds'] for retry_delay_key in retry_delay_keys: if retry_delay_key in schedule_config: default_arg_literals = default_arg_literals + """ \'{0}\': timedelta(seconds={1}), """.format(retry_delay_key[:-8], schedule_config.get(retry_delay_key)) return """ default_args = {{{0}}} """.format(default_arg_literals) @staticmethod def _get_datetime_expr_str(datetime_obj): if not datetime_obj: return None # Apache Airflow assumes that all times are timezone-unaware, and are in UTC: # https: // issues.apache.org / jira / browse / AIRFLOW - 1710 # Somewhat conveniently, yaml.load() recognizes and parses strings that look like datetimes # into timezone unaware datetime objects (if the user input specifies the timezone, it's # corrected and the result is assumed to be in UTC). # Here, we serialize this object into the format laid down by ISO 8601, and generate python code # that parses this format into a datetime object for Airflow. datetime_format = '%Y-%m-%dT%H:%M:%S' # ISO 8601, timezone unaware expr_format = 'datetime.datetime.strptime(\'{0}\', \'{1}\')' return expr_format.format(datetime_obj.strftime(datetime_format), datetime_format) @staticmethod def _get_operator_definition(task_id, task_details, parameters): """ Internal helper that gets the definition of the airflow operator for the task with the python parameters. All the parameters are also expanded with the airflow macros. :param parameters: """ operator_type = task_details['type'] full_param_string = 'task_id=\'{0}_id\''.format(task_id) operator_class_name, module = PipelineGenerator._get_operator_class_name(operator_type) operator_class_instance = getattr(sys.modules[module], operator_class_name, None) templated_fields = operator_class_instance.template_fields if operator_class_instance else () operator_param_values = PipelineGenerator._get_operator_param_name_and_values( operator_class_name, task_details) # This loop resolves all the macros and builds up the final string merged_parameters = google.datalab.bigquery.Query.merge_parameters( parameters, date_time=datetime.datetime.now(), macros=True, types_and_values=False) for (operator_param_name, operator_param_value) in sorted(operator_param_values.items()): # We replace modifiers in the parameter values with either the user-defined values, or with # with the airflow macros, as applicable. # An important assumption that this makes is that the operators parameters have the same names # as the templated_fields. TODO(rajivpb): There may be a better way to do this. if operator_param_name in templated_fields: operator_param_value = google.datalab.bigquery.Query._resolve_parameters( operator_param_value, merged_parameters) param_format_string = PipelineGenerator._get_param_format_string(operator_param_value) param_string = param_format_string.format(operator_param_name, operator_param_value) full_param_string = full_param_string + param_string return '{0} = {1}({2}, dag=dag)\n'.format(task_id, operator_class_name, full_param_string) @staticmethod def _get_param_format_string(param_value): # If the type is a python non-string (best guess), we don't quote it. 
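# A small illustrative check of the datetime serialization above (the helper is
# private, so this is only a sketch of its behavior):
import datetime
from google.datalab.contrib.pipeline._pipeline import PipelineGenerator

expr = PipelineGenerator._get_datetime_expr_str(datetime.datetime(2018, 5, 1, 12, 30))
# expr == "datetime.datetime.strptime('2018-05-01T12:30:00', '%Y-%m-%dT%H:%M:%S')"
print(expr)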
if type(param_value) in [int, bool, float, type(None), list, dict]: return ', {0}={1}' return ', {0}="""{1}"""' @staticmethod def _get_dag_definition(name, schedule_interval, catchup=False): dag_definition = 'dag = DAG(dag_id=\'{0}\', schedule_interval=\'{1}\', ' \ 'catchup={2}, default_args=default_args)\n\n'.format(name, schedule_interval, catchup) return dag_definition @staticmethod def _get_dependency_definition(task_id, dependencies): """ Internal helper collects all the dependencies of the task, and returns the Airflow equivalent python sytax for specifying them. """ set_upstream_statements = '' for dependency in dependencies: set_upstream_statements = set_upstream_statements + \ '{0}.set_upstream({1})'.format(task_id, dependency) + '\n' return set_upstream_statements @staticmethod def _get_operator_class_name(task_detail_type): """ Internal helper gets the name of the Airflow operator class. We maintain this in a map, so this method really returns the enum name, concatenated with the string "Operator". """ # TODO(rajivpb): Rename this var correctly. task_type_to_operator_prefix_mapping = { 'pydatalab.bq.execute': ('Execute', 'google.datalab.contrib.bigquery.operators._bq_execute_operator'), 'pydatalab.bq.extract': ('Extract', 'google.datalab.contrib.bigquery.operators._bq_extract_operator'), 'pydatalab.bq.load': ('Load', 'google.datalab.contrib.bigquery.operators._bq_load_operator'), 'Bash': ('Bash', 'airflow.operators.bash_operator') } (operator_class_prefix, module) = task_type_to_operator_prefix_mapping.get( task_detail_type, (None, __name__)) format_string = '{0}Operator' operator_class_name = format_string.format(operator_class_prefix) if operator_class_prefix is None: return format_string.format(task_detail_type), module return operator_class_name, module @staticmethod def _get_operator_param_name_and_values(operator_class_name, task_details): """ Internal helper gets the name of the python parameter for the Airflow operator class. In some cases, we do not expose the airflow parameter name in its native form, but choose to expose a name that's more standard for Datalab, or one that's more friendly. For example, Airflow's BigQueryOperator uses 'bql' for the query string, but we want %%bq users in Datalab to use 'query'. Hence, a few substitutions that are specific to the Airflow operator need to be made. Similarly, we the parameter value could come from the notebook's context. All that happens here. Returns: Dict containing _only_ the keys and values that are required in Airflow operator definition. This requires a substituting existing keys in the dictionary with their Airflow equivalents ( i.e. by adding new keys, and removing the existing ones). """ # We make a clone and then remove 'type' and 'up_stream' since these aren't needed for the # the operator's parameters. operator_task_details = task_details.copy() if 'type' in operator_task_details.keys(): del operator_task_details['type'] if 'up_stream' in operator_task_details.keys(): del operator_task_details['up_stream'] # We special-case certain operators if we do some translation of the parameter names. This is # usually the case when we use syntactic sugar to expose the functionality. # TODO(rajivpb): It should be possible to make this a lookup from the modules mapping via # getattr() or equivalent. Avoid hard-coding these class-names here. 
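# A quick illustration of the class-name lookup above; an unknown task type falls
# back to '<type>Operator' resolved against this module:
from google.datalab.contrib.pipeline._pipeline import PipelineGenerator

print(PipelineGenerator._get_operator_class_name('pydatalab.bq.execute'))
# ('ExecuteOperator', 'google.datalab.contrib.bigquery.operators._bq_execute_operator')
print(PipelineGenerator._get_operator_class_name('Bash'))
# ('BashOperator', 'airflow.operators.bash_operator')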
if (operator_class_name == 'BigQueryOperator'): return PipelineGenerator._get_bq_execute_params(operator_task_details) if (operator_class_name == 'BigQueryToCloudStorageOperator'): return PipelineGenerator._get_bq_extract_params(operator_task_details) if (operator_class_name == 'GoogleCloudStorageToBigQueryOperator'): return PipelineGenerator._get_bq_load_params(operator_task_details) return operator_task_details @staticmethod def _get_bq_execute_params(operator_task_details): if 'query' in operator_task_details: operator_task_details['bql'] = operator_task_details['query'].sql del operator_task_details['query'] if 'parameters' in operator_task_details: operator_task_details['query_params'] = bigquery.Query.get_query_parameters( operator_task_details['parameters']) del operator_task_details['parameters'] # Add over-rides of Airflow defaults here. if 'use_legacy_sql' not in operator_task_details: operator_task_details['use_legacy_sql'] = False return operator_task_details @staticmethod def _get_bq_extract_params(operator_task_details): if 'table' in operator_task_details: table = bigquery.commands._bigquery._get_table(operator_task_details['table']) operator_task_details['source_project_dataset_table'] = table.full_name del operator_task_details['table'] if 'path' in operator_task_details: operator_task_details['destination_cloud_storage_uris'] = [operator_task_details['path']] del operator_task_details['path'] if 'format' in operator_task_details: operator_task_details['export_format'] = 'CSV' if operator_task_details['format'] == 'csv' \ else 'NEWLINE_DELIMITED_JSON' del operator_task_details['format'] if 'delimiter' in operator_task_details: operator_task_details['field_delimiter'] = operator_task_details['delimiter'] del operator_task_details['delimiter'] if 'compress' in operator_task_details: operator_task_details['compression'] = 'GZIP' if operator_task_details['compress'] else 'NONE' del operator_task_details['compress'] if 'header' in operator_task_details: operator_task_details['print_header'] = operator_task_details['header'] del operator_task_details['header'] return operator_task_details @staticmethod def _get_bq_load_params(operator_task_details): if 'table' in operator_task_details: table = bigquery.commands._bigquery._get_table(operator_task_details['table']) if not table: table = bigquery.Table(operator_task_details['table']) # TODO(rajivpb): Ensure that mode == create here. 
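# A small sketch of the extract-parameter translation above, using only keys that
# do not require a table lookup (the values are hypothetical):
from google.datalab.contrib.pipeline._pipeline import PipelineGenerator

task = {'path': 'gs://my-bucket/out/part-*.csv', 'format': 'csv', 'compress': True, 'header': True}
print(PipelineGenerator._get_bq_extract_params(task))
# -> destination_cloud_storage_uris=['gs://my-bucket/out/part-*.csv'], export_format='CSV',
#    compression='GZIP', print_header=True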
operator_task_details['destination_project_dataset_table'] = table.full_name del operator_task_details['table'] if 'format' in operator_task_details: operator_task_details['export_format'] = 'CSV' if operator_task_details['format'] == 'csv' \ else 'NEWLINE_DELIMITED_JSON' del operator_task_details['format'] if 'delimiter' in operator_task_details: operator_task_details['field_delimiter'] = operator_task_details['delimiter'] del operator_task_details['delimiter'] if 'skip' in operator_task_details: operator_task_details['skip_leading_rows'] = operator_task_details['skip'] del operator_task_details['skip'] if 'path' in operator_task_details: bucket, source_object = PipelineGenerator._get_bucket_and_source_object( operator_task_details['path']) operator_task_details['bucket'] = bucket operator_task_details['source_objects'] = source_object del operator_task_details['path'] return operator_task_details @staticmethod def _get_bucket_and_source_object(gcs_path): return gcs_path.split('/')[2], '/'.join(gcs_path.split('/')[3:]) ================================================ FILE: google/datalab/contrib/pipeline/airflow/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from ._airflow import Airflow # noqa ================================================ FILE: google/datalab/contrib/pipeline/airflow/_airflow.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import google.datalab.storage as storage class Airflow(object): """ Represents a Airflow object that encapsulates a set of functionality relating to the Cloud Airflow service. This object can be used to generate the python airflow spec. """ def __init__(self, gcs_dag_bucket, gcs_dag_file_path=None): """ Initializes an instance of a Airflow object. Args: gcs_dag_bucket: Bucket where Airflow expects dag files to be uploaded. gcs_dag_file_path: File path of the Airflow dag files. 
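# A minimal usage sketch for this class; the bucket name and path are hypothetical,
# deploy() is defined just below, and dag_py would typically be the string returned
# by PipelineGenerator.generate_airflow_spec():
from google.datalab.contrib.pipeline.airflow import Airflow

dag_py = '# ...generated Airflow DAG source...'
airflow = Airflow('my-airflow-dag-bucket', gcs_dag_file_path='dags')
airflow.deploy('demo_pipeline', dag_py)  # writes gs://my-airflow-dag-bucket/dags/demo_pipeline.py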
""" self._gcs_dag_bucket = gcs_dag_bucket self._gcs_dag_file_path = gcs_dag_file_path or '' def deploy(self, name, dag_string): if self._gcs_dag_file_path is not '' and self._gcs_dag_file_path.endswith('/') is False: self._gcs_dag_file_path = self._gcs_dag_file_path + '/' file_name = '{0}{1}.py'.format(self._gcs_dag_file_path, name) bucket = storage.Bucket(self._gcs_dag_bucket) file_object = bucket.object(file_name) file_object.write_stream(dag_string, 'text/plain') ================================================ FILE: google/datalab/contrib/pipeline/commands/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: google/datalab/contrib/pipeline/commands/_pipeline.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Pipeline IPython Functionality.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals import google def _create_cell(args, cell_body): """Implements the pipeline cell create magic used to create Pipeline objects. The supported syntax is: %%pipeline create [] Args: args: the arguments following '%%pipeline create'. cell_body: the contents of the cell """ name = args.get('name') if name is None: raise Exception("Pipeline name was not specified.") pipeline_spec = google.datalab.utils.commands.parse_config( cell_body, google.datalab.utils.commands.notebook_environment()) airflow_spec = google.datalab.contrib.pipeline._pipeline.PipelineGenerator.generate_airflow_spec( name, pipeline_spec) debug = args.get('debug') if debug is True: return airflow_spec def _create_create_subparser(parser): create_parser = parser.subcommand('create', 'Create and/or execute a ' 'Pipeline object. If a pipeline ' 'name is not specified, the ' 'pipeline is scheduled.') create_parser.add_argument('-n', '--name', type=str, help='The name of this Pipeline object.') create_parser.add_argument('-d', '--debug', action='store_true', default=False, help='Print the airflow python spec.') return create_parser def _add_command(parser, subparser_fn, handler, cell_required=False, cell_prohibited=False): """ Create and initialize a pipeline subcommand handler. 
""" sub_parser = subparser_fn(parser) sub_parser.set_defaults(func=lambda args, cell: _dispatch_handler( args, cell, sub_parser, handler, cell_required=cell_required, cell_prohibited=cell_prohibited)) def _create_pipeline_parser(): """ Create the parser for the %pipeline magics. Note that because we use the func default handler dispatch mechanism of argparse, our handlers can take only one argument which is the parsed args. So we must create closures for the handlers that bind the cell contents and thus must recreate this parser for each cell upon execution. """ parser = google.datalab.utils.commands.CommandParser( prog='%pipeline', description=""" Execute various pipeline-related operations. Use "%pipeline -h" for help on a specific command. """) # %%pipeline create _add_command(parser, _create_create_subparser, _create_cell) return parser _pipeline_parser = _create_pipeline_parser() # TODO(rajivpb): Decorate this with '@IPython.core.magic.register_line_cell_magic' def pipeline(line, cell=None): """Implements the pipeline cell magic for ipython notebooks. The supported syntax is: %%pipeline [] or: %pipeline [] Use %pipeline --help for a list of commands, or %pipeline --help for help on a specific command. """ return google.datalab.utils.commands.handle_magic_line(line, cell, _pipeline_parser) def _dispatch_handler(args, cell, parser, handler, cell_required=False, cell_prohibited=False): """ Makes sure cell magics include cell and line magics don't, before dispatching to handler. Args: args: the parsed arguments from the magic line. cell: the contents of the cell, if any. parser: the argument parser for ; used for error message. handler: the handler to call if the cell present/absent check passes. cell_required: True for cell magics, False for line magics that can't be cell magics. cell_prohibited: True for line magics, False for cell magics that can't be line magics. Returns: The result of calling the handler. Raises: Exception if the invocation is not valid. """ if cell_prohibited: if cell and len(cell.strip()): parser.print_help() raise Exception( 'Additional data is not supported with the %s command.' % parser.prog) return handler(args) if cell_required and not cell: parser.print_help() raise Exception('The %s command requires additional data' % parser.prog) return handler(args, cell) ================================================ FILE: google/datalab/contrib/pipeline/composer/__init__.py ================================================ # Copyright 2018 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from ._composer import Composer # noqa ================================================ FILE: google/datalab/contrib/pipeline/composer/_api.py ================================================ # Copyright 2018 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Composer HTTP API wrapper.""" import google.datalab.utils class Api(object): """A helper class to issue Composer HTTP requests.""" _ENDPOINT = 'https://composer.googleapis.com/v1alpha1' _ENVIRONMENTS_PATH_FORMAT = '/projects/%s/locations/%s/environments/%s' @staticmethod def get_environment_details(zone, environment): """ Issues a request to Composer to get the environment details. Args: zone: GCP zone of the composer environment environment: name of the Composer environment Returns: A parsed result object. Raises: Exception if there is an error performing the operation. """ default_context = google.datalab.Context.default() url = (Api._ENDPOINT + (Api._ENVIRONMENTS_PATH_FORMAT % (default_context.project_id, zone, environment))) return google.datalab.utils.Http.request(url, credentials=default_context.credentials) ================================================ FILE: google/datalab/contrib/pipeline/composer/_composer.py ================================================ # Copyright 2018 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import google.datalab.storage as storage from google.datalab.contrib.pipeline.composer._api import Api import re class Composer(object): """ Represents a Composer object that encapsulates a set of functionality relating to the Cloud Composer service. This object can be used to generate the python airflow spec. """ gcs_file_regexp = re.compile('gs://.*') def __init__(self, zone, environment): """ Initializes an instance of a Composer object. Args: zone: Zone in which Composer environment has been created. environment: Name of the Composer environment. 
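# A minimal usage sketch for this class; the zone and environment names are
# hypothetical, gcs_dag_location and deploy() are defined just below, and the
# calls require access to the Composer API and the environment's GCS bucket:
from google.datalab.contrib.pipeline.composer import Composer

dag_py = '# ...generated Airflow DAG source...'
composer = Composer('us-central1-a', 'my-composer-env')
print(composer.gcs_dag_location)          # e.g. 'gs://<environment-bucket>/dags/'
composer.deploy('demo_pipeline', dag_py)  # copies the DAG file into that location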
""" self._zone = zone self._environment = environment self._gcs_dag_location = None def deploy(self, name, dag_string): bucket_name, file_path = self.gcs_dag_location.split('/', 3)[2:] # setting maxsplit to 3 file_name = '{0}{1}.py'.format(file_path, name) bucket = storage.Bucket(bucket_name) file_object = bucket.object(file_name) file_object.write_stream(dag_string, 'text/plain') @property def gcs_dag_location(self): if not self._gcs_dag_location: environment_details = Api.get_environment_details(self._zone, self._environment) if ('config' not in environment_details or 'gcsDagLocation' not in environment_details.get('config')): raise ValueError('Dag location unavailable from Composer environment {0}'.format( self._environment)) gcs_dag_location = environment_details['config']['gcsDagLocation'] if gcs_dag_location is None or not self.gcs_file_regexp.match(gcs_dag_location): raise ValueError( 'Dag location {0} from Composer environment {1} is in incorrect format'.format( gcs_dag_location, self._environment)) self._gcs_dag_location = gcs_dag_location if gcs_dag_location.endswith('/') is False: self._gcs_dag_location = self._gcs_dag_location + '/' return self._gcs_dag_location ================================================ FILE: google/datalab/data/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Generic SQL Helpers.""" from __future__ import absolute_import from __future__ import unicode_literals from ._csv_file import CsvFile __all__ = ['CsvFile'] ================================================ FILE: google/datalab/data/_csv_file.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements usefule CSV utilities.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import next from builtins import str as newstr from builtins import range from builtins import object import csv import os import pandas as pd import random import sys try: from StringIO import StringIO except ImportError: from io import StringIO import tempfile import google.datalab.storage import google.datalab.utils _MAX_CSV_BYTES = 10000000 class CsvFile(object): """Represents a CSV file in GCS or locally with same schema. """ def __init__(self, path, delimiter=b','): """Initializes an instance of a Csv instance. Args: path: path of the Csv file. 
delimiter: the separator used to parse a Csv line. """ self._path = path self._delimiter = delimiter @property def path(self): return self._path @staticmethod def _read_gcs_lines(path, max_lines=None): return google.datalab.storage.Object.from_url(path).read_lines(max_lines) @staticmethod def _read_local_lines(path, max_lines=None): lines = [] for line in open(path): if max_lines is not None and len(lines) >= max_lines: break lines.append(line) return lines def _is_probably_categorical(self, column): if newstr(column.dtype) != 'object': # only string types (represented in DataFrame as object) can potentially be categorical return False if len(max(column, key=lambda p: len(newstr(p)))) > 100: return False # value too long to be a category if len(set(column)) > 100: return False # too many unique values to be a category return True def browse(self, max_lines=None, headers=None): """Try reading specified number of lines from the CSV object. Args: max_lines: max number of lines to read. If None, the whole file is read headers: a list of strings as column names. If None, it will use "col0, col1..." Returns: A pandas DataFrame with the schema inferred from the data. Raises: Exception if the csv object cannot be read or not enough lines to read, or the headers size does not match columns size. """ if self.path.startswith('gs://'): lines = CsvFile._read_gcs_lines(self.path, max_lines) else: lines = CsvFile._read_local_lines(self.path, max_lines) if len(lines) == 0: return pd.DataFrame(columns=headers) columns_size = len(next(csv.reader([lines[0]], delimiter=self._delimiter))) if headers is None: headers = ['col' + newstr(e) for e in range(columns_size)] if len(headers) != columns_size: raise Exception('Number of columns in CSV do not match number of headers') buf = StringIO() for line in lines: buf.write(line) buf.write('\n') buf.seek(0) df = pd.read_csv(buf, names=headers, delimiter=self._delimiter) for key, col in df.iteritems(): if self._is_probably_categorical(col): df[key] = df[key].astype('category') return df def _create_external_data_source(self, skip_header_rows): import google.datalab.bigquery as bq df = self.browse(1, None) # read each column as STRING because we only want to sample rows. schema_train = bq.Schema([{'name': name, 'type': 'STRING'} for name in df.keys()]) options = bq.CSVOptions(skip_leading_rows=(1 if skip_header_rows is True else 0)) return bq.ExternalDataSource(self.path, csv_options=options, schema=schema_train, max_bad_records=0) def _get_gcs_csv_row_count(self, external_data_source): import google.datalab.bigquery as bq results = bq.Query('SELECT count(*) from data', data_sources={'data': external_data_source}).execute().result() return results[0].values()[0] def sample_to(self, count, skip_header_rows, strategy, target): """Sample rows from GCS or local file and save results to target file. Args: count: number of rows to sample. If strategy is "BIGQUERY", it is used as approximate number. skip_header_rows: whether to skip first row when reading from source. strategy: can be "LOCAL" or "BIGQUERY". If local, the sampling happens in local memory, and number of resulting rows matches count. If BigQuery, sampling is done with BigQuery in cloud, and the number of resulting rows will be approximated to count. target: The target file path, can be GCS or local path. Raises: Exception if strategy is "BIGQUERY" but source is not a GCS path. 
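# A minimal usage sketch for CsvFile; the paths are hypothetical and GCS access is
# assumed for the gs:// source:
from google.datalab.data import CsvFile

csv_file = CsvFile('gs://my-bucket/data/train.csv')
df = csv_file.browse(max_lines=100)                          # peek at the data as a DataFrame
csv_file.sample_to(1000, True, 'LOCAL', '/tmp/sample.csv')   # sample 1000 rows to a local file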
""" if sys.version_info.major > 2: xrange = range # for python 3 compatibility # TODO(qimingj) Add unit test # Read data from source into DataFrame. if strategy == 'BIGQUERY': import google.datalab.bigquery as bq if not self.path.startswith('gs://'): raise Exception('Cannot use BIGQUERY if data is not in GCS') external_data_source = self._create_external_data_source(skip_header_rows) row_count = self._get_gcs_csv_row_count(external_data_source) query = bq.Query('SELECT * from data', data_sources={'data': external_data_source}) sampling = bq.Sampling.random(count * 100 / float(row_count)) sample = query.sample(sampling=sampling) df = sample.to_dataframe() elif strategy == 'LOCAL': local_file = self.path if self.path.startswith('gs://'): local_file = tempfile.mktemp() google.datalab.utils.gcs_copy_file(self.path, local_file) with open(local_file) as f: row_count = sum(1 for line in f) start_row = 1 if skip_header_rows is True else 0 skip_count = row_count - count - 1 if skip_header_rows is True else row_count - count skip = sorted(random.sample(xrange(start_row, row_count), skip_count)) header_row = 0 if skip_header_rows is True else None df = pd.read_csv(local_file, skiprows=skip, header=header_row, delimiter=self._delimiter) if self.path.startswith('gs://'): os.remove(local_file) else: raise Exception('strategy must be BIGQUERY or LOCAL') # Write to target. if target.startswith('gs://'): with tempfile.NamedTemporaryFile() as f: df.to_csv(f, header=False, index=False) f.flush() google.datalab.utils.gcs_copy_file(f.name, target) else: with open(target, 'w') as f: df.to_csv(f, header=False, index=False, sep=str(self._delimiter)) ================================================ FILE: google/datalab/kernel/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Datalab - notebook functionality.""" import httplib2 as _httplib2 import requests as _requests try: import IPython as _IPython import IPython.core.magic as _magic # noqa import IPython.core.interactiveshell as _shell from IPython import get_ipython # noqa except ImportError: raise Exception('This package requires an IPython notebook installation') import google.datalab # Import the modules that do cell magics. import google.datalab.bigquery.commands import google.datalab.commands import google.datalab.stackdriver.commands import google.datalab.storage.commands import google.datalab.utils import google.datalab.utils.commands _orig_request = _httplib2.Http.request _orig_init = _requests.Session.__init__ _orig_run_cell_magic = _shell.InteractiveShell.run_cell_magic _orig_run_line_magic = _shell.InteractiveShell.run_line_magic def load_ipython_extension(shell): """ Called when the extension is loaded. Args: shell - (NotebookWebApplication): handle to the Notebook interactive shell instance. """ # Inject our user agent on all requests by monkey-patching a wrapper around httplib2.Http.request. 
def _request(self, uri, method="GET", body=None, headers=None, redirections=_httplib2.DEFAULT_MAX_REDIRECTS, connection_type=None): if headers is None: headers = {} headers['user-agent'] = 'GoogleCloudDataLab/1.0' return _orig_request(self, uri, method=method, body=body, headers=headers, redirections=redirections, connection_type=connection_type) _httplib2.Http.request = _request # Similarly for the requests library. def _init_session(self): _orig_init(self) self.headers['User-Agent'] = 'GoogleCloudDataLab/1.0' _requests.Session.__init__ = _init_session # Be more tolerant with magics. If the user specified a cell magic that doesn't # exist and an empty cell body but a line magic with that name exists, run that # instead. Conversely, if the user specified a line magic that doesn't exist but # a cell magic exists with that name, run the cell magic with an empty body. def _run_line_magic(self, magic_name, line): fn = self.find_line_magic(magic_name) if fn is None: cm = self.find_cell_magic(magic_name) if cm: return _run_cell_magic(self, magic_name, line, None) return _orig_run_line_magic(self, magic_name, line) def _run_cell_magic(self, magic_name, line, cell): if cell is None or len(cell) == 0 or cell.isspace(): fn = self.find_line_magic(magic_name) if fn: return _orig_run_line_magic(self, magic_name, line) # IPython will complain if cell is empty string but not if it is None cell = None return _orig_run_cell_magic(self, magic_name, line, cell) _shell.InteractiveShell.run_cell_magic = _run_cell_magic _shell.InteractiveShell.run_line_magic = _run_line_magic # Define global 'project_id' and 'set_project_id' functions to manage the default project ID. We # do this conditionally in a try/catch # to avoid the call to Context.default() when running tests # which mock IPython.get_ipython(). def _get_project_id(): try: return google.datalab.Context.default().project_id except Exception: return None def _set_project_id(project_id): context = google.datalab.Context.default() context.set_project_id(project_id) try: from datalab.context import Context as _old_context _old_context.default().set_project_id(project_id) except ImportError: # If the old library is not loaded, then we don't have to do anything pass try: if 'datalab_project_id' not in _IPython.get_ipython().user_ns: _IPython.get_ipython().user_ns['datalab_project_id'] = _get_project_id _IPython.get_ipython().user_ns['set_datalab_project_id'] = _set_project_id except TypeError: pass def unload_ipython_extension(shell): _shell.InteractiveShell.run_cell_magic = _orig_run_cell_magic _shell.InteractiveShell.run_line_magic = _orig_run_line_magic _requests.Session.__init__ = _orig_init _httplib2.Http.request = _orig_request try: del _IPython.get_ipython().user_ns['project_id'] del _IPython.get_ipython().user_ns['set_project_id'] except Exception: pass # We mock IPython for tests so we need this. # TODO(gram): unregister imports/magics/etc. ================================================ FILE: google/datalab/ml/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. # flake8: noqa """CloudML Helper Library.""" from __future__ import absolute_import from ._job import Jobs, Job from ._summary import Summary from ._tensorboard import TensorBoard from ._dataset import CsvDataSet, BigQueryDataSet, TransformedDataSet from ._cloud_models import Models, ModelVersions from ._confusion_matrix import ConfusionMatrix from ._feature_slice_view import FeatureSliceView from ._cloud_training_config import CloudTrainingConfig from ._fasets import FacetsOverview, FacetsDiveview from ._metrics import Metrics from ._util import * ================================================ FILE: google/datalab/ml/_cloud_models.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Cloud ML Model Operations""" from googleapiclient import discovery import os import yaml import google.datalab as datalab from . import _util class Models(object): """Represents a list of Cloud ML models for a project.""" def __init__(self, project_id=None): """ Args: project_id: project_id of the models. If not provided, default project_id will be used. """ if project_id is None: project_id = datalab.Context.default().project_id self._project_id = project_id self._credentials = datalab.Context.default().credentials self._api = discovery.build('ml', 'v1', credentials=self._credentials) self._page_size = 0 def _retrieve_models(self, page_token, _): list_info = self._api.projects().models().list( parent='projects/' + self._project_id, pageToken=page_token, pageSize=self._page_size).execute() models = list_info.get('models', []) self._page_size = self._page_size or len(models) page_token = list_info.get('nextPageToken', None) return models, page_token def get_iterator(self): """Get iterator of models so it can be used as "for model in Models().get_iterator()". """ return iter(datalab.utils.Iterator(self._retrieve_models)) def get_model_details(self, model_name): """Get details of the specified model from CloudML Service. Args: model_name: the name of the model. It can be a model full name ("projects/[project_id]/models/[model_name]") or just [model_name]. Returns: a dictionary of the model details. """ full_name = model_name if not model_name.startswith('projects/'): full_name = ('projects/%s/models/%s' % (self._project_id, model_name)) return self._api.projects().models().get(name=full_name).execute() def create(self, model_name): """Create a model. Args: model_name: the short name of the model, such as "iris". Returns: If successful, returns informaiton of the model, such as {u'regions': [u'us-central1'], u'name': u'projects/myproject/models/mymodel'} Raises: If the model creation failed. """ body = {'name': model_name} parent = 'projects/' + self._project_id # Model creation is instant. If anything goes wrong, Exception will be thrown. 
return self._api.projects().models().create(body=body, parent=parent).execute() def delete(self, model_name): """Delete a model. Args: model_name: the name of the model. It can be a model full name ("projects/[project_id]/models/[model_name]") or just [model_name]. """ full_name = model_name if not model_name.startswith('projects/'): full_name = ('projects/%s/models/%s' % (self._project_id, model_name)) response = self._api.projects().models().delete(name=full_name).execute() if 'name' not in response: raise Exception('Invalid response from service. "name" is not found.') _util.wait_for_long_running_operation(response['name']) def list(self, count=10): """List models under the current project in a table view. Args: count: upper limit of the number of models to list. Raises: Exception if it is called in a non-IPython environment. """ import IPython data = [] # Add range(count) to loop so it will stop either it reaches count, or iteration # on self is exhausted. "self" is iterable (see __iter__() method). for _, model in zip(range(count), self.get_iterator()): element = {'name': model['name']} if 'defaultVersion' in model: version_short_name = model['defaultVersion']['name'].split('/')[-1] element['defaultVersion'] = version_short_name data.append(element) IPython.display.display( datalab.utils.commands.render_dictionary(data, ['name', 'defaultVersion'])) def describe(self, model_name): """Print information of a specified model. Args: model_name: the name of the model to print details on. """ model_yaml = yaml.safe_dump(self.get_model_details(model_name), default_flow_style=False) print(model_yaml) class ModelVersions(object): """Represents a list of versions for a Cloud ML model.""" def __init__(self, model_name, project_id=None): """ Args: model_name: the name of the model. It can be a model full name ("projects/[project_id]/models/[model_name]") or just [model_name]. project_id: project_id of the models. If not provided and model_name is not a full name (not including project_id), default project_id will be used. """ if project_id is None: self._project_id = datalab.Context.default().project_id self._credentials = datalab.Context.default().credentials self._api = discovery.build('ml', 'v1', credentials=self._credentials) if not model_name.startswith('projects/'): model_name = ('projects/%s/models/%s' % (self._project_id, model_name)) self._full_model_name = model_name self._model_name = self._full_model_name.split('/')[-1] self._page_size = 0 def _retrieve_versions(self, page_token, _): parent = self._full_model_name list_info = self._api.projects().models().versions().list(parent=parent, pageToken=page_token, pageSize=self._page_size).execute() versions = list_info.get('versions', []) self._page_size = self._page_size or len(versions) page_token = list_info.get('nextPageToken', None) return versions, page_token def get_iterator(self): """Get iterator of versions so it can be used as "for v in ModelVersions(model_name).get_iterator()". """ return iter(datalab.utils.Iterator(self._retrieve_versions)) def get_version_details(self, version_name): """Get details of a version. Args: version: the name of the version in short form, such as "v1". Returns: a dictionary containing the version details. """ name = ('%s/versions/%s' % (self._full_model_name, version_name)) return self._api.projects().models().versions().get(name=name).execute() def deploy(self, version_name, path, runtime_version=None): """Deploy a model version to the cloud. 
Args: version_name: the name of the version in short form, such as "v1". path: the Google Cloud Storage path (gs://...) which contains the model files. runtime_version: the ML Engine runtime version as a string, example '1.2'. See https://cloud.google.com/ml-engine/docs/concepts/runtime-version-list for a list of runtimes. If None, the ML Engine service will pick one. Raises: Exception if the path is invalid or does not contain expected files. Exception if the service returns invalid response. """ if not path.startswith('gs://'): raise Exception('Invalid path. Only Google Cloud Storage path (gs://...) is accepted.') # If there is no "export.meta" or"saved_model.pb" under path but there is # path/model/export.meta or path/model/saved_model.pb, then append /model to the path. if not datalab.storage.Object.from_url(os.path.join(path, 'export.meta')).exists() and not \ datalab.storage.Object.from_url(os.path.join(path, 'saved_model.pb')).exists(): if datalab.storage.Object.from_url(os.path.join(path, 'model', 'export.meta')).exists() or \ datalab.storage.Object.from_url(os.path.join(path, 'model', 'saved_model.pb')).exists(): path = os.path.join(path, 'model') else: print('Cannot find export.meta or saved_model.pb, but continue with deployment anyway.') body = {'name': self._model_name} parent = 'projects/' + self._project_id try: self._api.projects().models().create(body=body, parent=parent).execute() except: # Trying to create an already existing model gets an error. Ignore it. pass body = { 'name': version_name, 'deployment_uri': path, } if runtime_version: body['runtime_version'] = runtime_version response = self._api.projects().models().versions().create( body=body, parent=self._full_model_name).execute() if 'name' not in response: raise Exception('Invalid response from service. "name" is not found.') _util.wait_for_long_running_operation(response['name']) def delete(self, version_name): """Delete a version of model. Args: version_name: the name of the version in short form, such as "v1". """ name = ('%s/versions/%s' % (self._full_model_name, version_name)) response = self._api.projects().models().versions().delete(name=name).execute() if 'name' not in response: raise Exception('Invalid response from service. "name" is not found.') _util.wait_for_long_running_operation(response['name']) def predict(self, version_name, data): """Get prediction results from features instances. Args: version_name: the name of the version used for prediction. data: typically a list of instance to be submitted for prediction. The format of the instance depends on the model. For example, structured data model may require a csv line for each instance. Note that online prediction only works on models that take one placeholder value, such as a string encoding a csv line. Returns: A list of prediction results for given instances. Each element is a dictionary representing output mapping from the graph. An example: [{"predictions": 1, "score": [0.00078, 0.71406, 0.28515]}, {"predictions": 1, "score": [0.00244, 0.99634, 0.00121]}] """ full_version_name = ('%s/versions/%s' % (self._full_model_name, version_name)) request = self._api.projects().predict(body={'instances': data}, name=full_version_name) request.headers['user-agent'] = 'GoogleCloudDataLab/1.0' result = request.execute() if 'predictions' not in result: raise Exception('Invalid response from service. Cannot find "predictions" in response.') return result['predictions'] def describe(self, version_name): """Print information of a specified model. 
Args: version: the name of the version in short form, such as "v1". """ version_yaml = yaml.safe_dump(self.get_version_details(version_name), default_flow_style=False) print(version_yaml) def list(self): """List versions under the current model in a table view. Raises: Exception if it is called in a non-IPython environment. """ import IPython # "self" is iterable (see __iter__() method). data = [{'name': version['name'].split()[-1], 'deploymentUri': version['deploymentUri'], 'createTime': version['createTime']} for version in self.get_iterator()] IPython.display.display( datalab.utils.commands.render_dictionary(data, ['name', 'deploymentUri', 'createTime'])) ================================================ FILE: google/datalab/ml/_cloud_training_config.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple _CloudTrainingConfig = namedtuple("CloudConfig", ['region', 'scale_tier', 'master_type', 'worker_type', 'parameter_server_type', 'worker_count', 'parameter_server_count']) _CloudTrainingConfig.__new__.__defaults__ = ('BASIC', None, None, None, None, None) class CloudTrainingConfig(_CloudTrainingConfig): """A config namedtuple containing cloud specific configurations for CloudML training. Fields: region: the region of the training job to be submitted. For example, "us-central1". Run "gcloud compute regions list" to get a list of regions. scale_tier: Specifies the machine types, the number of replicas for workers and parameter servers. For example, "STANDARD_1". See https://cloud.google.com/ml/reference/rest/v1beta1/projects.jobs#scaletier for list of accepted values. master_type: specifies the type of virtual machine to use for your training job's master worker. Must set this value when scale_tier is set to CUSTOM. See the link in "scale_tier". worker_type: specifies the type of virtual machine to use for your training job's worker nodes. Must set this value when scale_tier is set to CUSTOM. parameter_server_type: specifies the type of virtual machine to use for your training job's parameter server. Must set this value when scale_tier is set to CUSTOM. worker_count: the number of worker replicas to use for the training job. Each replica in the cluster will be of the type specified in "worker_type". Must set this value when scale_tier is set to CUSTOM. parameter_server_count: the number of parameter server replicas to use. Each replica in the cluster will be of the type specified in "parameter_server_type". Must set this value when scale_tier is set to CUSTOM. """ pass ================================================ FILE: google/datalab/ml/_confusion_matrix.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import numpy as np import itertools import json import matplotlib.pyplot as plt import pandas as pd from sklearn.metrics import confusion_matrix import google.datalab.bigquery as bq from . import _util class ConfusionMatrix(object): """Represents a confusion matrix.""" def __init__(self, cm, labels): """ Args: cm: a 2-dimensional matrix with row index being target, column index being predicted, and values being count. labels: the labels whose order matches the row/column indexes. """ self._cm = cm self._labels = labels @staticmethod def from_csv(input_csv, headers=None, schema_file=None): """Create a ConfusionMatrix from a csv file. Args: input_csv: Path to a Csv file (with no header). Can be local or GCS path. headers: Csv headers. If present, it must include 'target' and 'predicted'. schema_file: Path to a JSON file containing BigQuery schema. Used if "headers" is None. If present, it must include 'target' and 'predicted' columns. Returns: A ConfusionMatrix that can be plotted. Raises: ValueError if both headers and schema_file are None, or it does not include 'target' or 'predicted' columns. """ if headers is not None: names = headers elif schema_file is not None: with _util.open_local_or_gcs(schema_file, mode='r') as f: schema = json.load(f) names = [x['name'] for x in schema] else: raise ValueError('Either headers or schema_file is needed') all_files = _util.glob_files(input_csv) all_df = [] for file_name in all_files: with _util.open_local_or_gcs(file_name, mode='r') as f: all_df.append(pd.read_csv(f, names=names)) df = pd.concat(all_df, ignore_index=True) if 'target' not in df or 'predicted' not in df: raise ValueError('Cannot find "target" or "predicted" column') labels = sorted(set(df['target']) | set(df['predicted'])) cm = confusion_matrix(df['target'], df['predicted'], labels=labels) return ConfusionMatrix(cm, labels) @staticmethod def from_bigquery(sql): """Create a ConfusionMatrix from a BigQuery table or query. Args: sql: Can be one of: A SQL query string. A Bigquery table string. A Query object defined with '%%bq query --name [query_name]'. The query results or table must include "target", "predicted" columns. Returns: A ConfusionMatrix that can be plotted. Raises: ValueError if query results or table does not include 'target' or 'predicted' columns. """ if isinstance(sql, bq.Query): sql = sql._expanded_sql() parts = sql.split('.') if len(parts) == 1 or len(parts) > 3 or any(' ' in x for x in parts): sql = '(' + sql + ')' # query, not a table name else: sql = '`' + sql + '`' # table name query = bq.Query( 'SELECT target, predicted, count(*) as count FROM %s group by target, predicted' % sql) df = query.execute().result().to_dataframe() labels = sorted(set(df['target']) | set(df['predicted'])) labels_count = len(labels) df['target'] = [labels.index(x) for x in df['target']] df['predicted'] = [labels.index(x) for x in df['predicted']] cm = [[0] * labels_count for i in range(labels_count)] for index, row in df.iterrows(): cm[row['target']][row['predicted']] = row['count'] return ConfusionMatrix(cm, labels) def to_dataframe(self): """Convert the confusion matrix to a dataframe. 
Returns: A DataFrame with "target", "predicted", "count" columns. """ data = [] for target_index, target_row in enumerate(self._cm): for predicted_index, count in enumerate(target_row): data.append((self._labels[target_index], self._labels[predicted_index], count)) return pd.DataFrame(data, columns=['target', 'predicted', 'count']) def plot(self, figsize=None, rotation=45): """Plot the confusion matrix. Args: figsize: tuple (x, y) of ints. Sets the size of the figure rotation: the rotation angle of the labels on the x-axis. """ fig, ax = plt.subplots(figsize=figsize) plt.imshow(self._cm, interpolation='nearest', cmap=plt.cm.Blues, aspect='auto') plt.title('Confusion matrix') plt.colorbar() tick_marks = np.arange(len(self._labels)) plt.xticks(tick_marks, self._labels, rotation=rotation) plt.yticks(tick_marks, self._labels) if isinstance(self._cm, list): # If cm is created from BigQuery then it is a list. thresh = max(max(self._cm)) / 2. for i, j in itertools.product(range(len(self._labels)), range(len(self._labels))): plt.text(j, i, self._cm[i][j], horizontalalignment="center", color="white" if self._cm[i][j] > thresh else "black") else: # If cm is created from csv then it is a sklearn's confusion_matrix. thresh = self._cm.max() / 2. for i, j in itertools.product(range(len(self._labels)), range(len(self._labels))): plt.text(j, i, self._cm[i, j], horizontalalignment="center", color="white" if self._cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') ================================================ FILE: google/datalab/ml/_dataset.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements DataSets that serve two purposes: 1. Recommended way to pass data source to ML packages. 2. All DataSets can be sampled into dataframe for analysis/visualization. """ import json import numpy as np import pandas as pd import random import six import google.datalab.bigquery as bq from . import _util class CsvDataSet(object): """DataSet based on CSV files and schema.""" def __init__(self, file_pattern, schema=None, schema_file=None): """ Args: file_pattern: A list of CSV files. or a string. Can contain wildcards in file names. Can be local or GCS path. schema: A google.datalab.bigquery.Schema object, or a json schema in the form of [{'name': 'col1', 'type': 'STRING'}, {'name': 'col2', 'type': 'INTEGER'}] or a single string in of the form 'col1:STRING,col2:INTEGER,col3:FLOAT'. schema_file: A JSON serialized schema file. If schema is None, it will try to load from schema_file if not None. Raise: ValueError if both schema and schema_file are None. 
""" if schema is None and schema_file is None: raise ValueError('schema and schema_file cannot both be None.') if schema is not None: # This check needs to come before list check, because Schema # is a subclass of list if isinstance(schema, bq.Schema): self._schema = schema._bq_schema elif isinstance(schema, list): self._schema = schema else: self._schema = [] for x in schema.split(','): parts = x.split(':') if len(parts) != 2: raise ValueError('invalid schema string "%s"' % x) self._schema.append({'name': parts[0].strip(), 'type': parts[1].strip()}) else: self._schema = json.loads(_util.read_file_to_string(schema_file)) if isinstance(file_pattern, six.string_types): file_pattern = [file_pattern] self._input_files = file_pattern self._glob_files = [] self._size = None @property def input_files(self): """Returns the file list that was given to this class without globing files.""" return self._input_files @property def files(self): if not self._glob_files: for file in self._input_files: # glob_files() returns unicode strings which doesn't make DataFlow happy. So str(). self._glob_files += [str(x) for x in _util.glob_files(file)] return self._glob_files @property def schema(self): return self._schema @property def size(self): """The size of the schema. If the underlying data source changes, it may be outdated. """ if self._size is None: self._size = 0 for csv_file in self.files: self._size += sum(1 if line else 0 for line in _util.open_local_or_gcs(csv_file, 'r')) return self._size def sample(self, n): """ Samples data into a Pandas DataFrame. Args: n: number of sampled counts. Returns: A dataframe containing sampled data. Raises: Exception if n is larger than number of rows. """ row_total_count = 0 row_counts = [] for file in self.files: with _util.open_local_or_gcs(file, 'r') as f: num_lines = sum(1 for line in f) row_total_count += num_lines row_counts.append(num_lines) names = None dtype = None if self._schema: _MAPPINGS = { 'FLOAT': np.float64, 'INTEGER': np.int64, 'TIMESTAMP': np.datetime64, 'BOOLEAN': np.bool, } names = [x['name'] for x in self._schema] dtype = {x['name']: _MAPPINGS.get(x['type'], object) for x in self._schema} skip_count = row_total_count - n # Get all skipped indexes. These will be distributed into each file. # Note that random.sample will raise Exception if skip_count is greater than rows count. skip_all = sorted(random.sample(range(0, row_total_count), skip_count)) dfs = [] for file, row_count in zip(self.files, row_counts): skip = [x for x in skip_all if x < row_count] skip_all = [x - row_count for x in skip_all if x >= row_count] with _util.open_local_or_gcs(file, 'r') as f: dfs.append(pd.read_csv(f, skiprows=skip, names=names, dtype=dtype, header=None)) return pd.concat(dfs, axis=0, ignore_index=True) class BigQueryDataSet(object): """DataSet based on BigQuery table or query.""" def __init__(self, sql=None, table=None): """ Args: sql: A SQL query string, or a SQL Query module defined with '%%bq query --name [query_name]' table: A table name in the form of 'dataset.table or project.dataset.table'. Raises: ValueError if both sql and table are set, or both are None. 
""" if (sql is None and table is None) or (sql is not None and table is not None): raise ValueError('One and only one of sql and table should be set.') self._query = sql._expanded_sql() if isinstance(sql, bq.Query) else sql self._table = table self._schema = None self._size = None @property def query(self): return self._query @property def table(self): return self._table def _get_source(self): if self._query is not None: return '(' + self._query + ')' return '`' + self._table + '`' @property def schema(self): if self._schema is None: self._schema = bq.Query('SELECT * FROM %s LIMIT 1' % self._get_source()).execute().result().schema return self._schema._bq_schema @property def size(self): """The size of the schema. If the underlying data source changes, it may be outdated. """ if self._size is None: self._size = bq.Query('SELECT COUNT(*) FROM %s' % self._get_source()).execute().result()[0].values()[0] return self._size def sample(self, n): """Samples data into a Pandas DataFrame. Note that it calls BigQuery so it will incur cost. Args: n: number of sampled counts. Note that the number of counts returned is approximated. Returns: A dataframe containing sampled data. Raises: Exception if n is larger than number of rows. """ total = bq.Query('select count(*) from %s' % self._get_source()).execute().result()[0].values()[0] if n > total: raise ValueError('sample larger than population') sampling = bq.Sampling.random(percent=n * 100.0 / float(total)) if self._query is not None: source = self._query else: source = 'SELECT * FROM `%s`' % self._table sample = bq.Query(source).execute(sampling=sampling).result() df = sample.to_dataframe() return df class TransformedDataSet(object): """DataSet based on tf.example.""" def __init__(self, file_pattern): """ Args: file_pattern: A list of gzip TF Example files. or a string. Can contain wildcards in file names. Can be local or GCS path. """ if isinstance(file_pattern, six.string_types): file_pattern = [file_pattern] self._input_files = file_pattern self._glob_files = [] self._size = None @property def input_files(self): """Returns the file list that was given to this class without globing files.""" return self._input_files @property def files(self): if not self._glob_files: for file in self._input_files: # glob_files() returns unicode strings which doesn't make DataFlow happy. So str(). self._glob_files += [str(x) for x in _util.glob_files(file)] return self._glob_files @property def size(self): """The number of instances in the data. If the underlying data source changes, it may be outdated. """ import tensorflow as tf if self._size is None: self._size = 0 options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP) for tfexample_file in self.files: self._size += sum(1 for x in tf.python_io.tf_record_iterator(tfexample_file, options=options)) return self._size ================================================ FILE: google/datalab/ml/_fasets.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. import base64 import google.datalab as datalab from google.datalab.utils.facets.generic_feature_statistics_generator \ import GenericFeatureStatisticsGenerator import numpy as np import pandas as pd import re import six class FacetsOverview(object): """Represents A facets overview. """ def _remove_nonascii(self, df): """Make copy and remove non-ascii characters from it.""" df_copy = df.copy(deep=True) for col in df_copy.columns: if (df_copy[col].dtype == np.dtype('O')): df_copy[col] = df[col].apply( lambda x: re.sub(r'[^\x00-\x7f]', r'', x) if isinstance(x, six.string_types) else x) return df_copy def plot(self, data): """ Plots an overview in a list of dataframes Args: data: a dictionary with key the name, and value the dataframe. """ import IPython if not isinstance(data, dict) or not all(isinstance(v, pd.DataFrame) for v in data.values()): raise ValueError('Expect a dictionary where the values are all dataframes.') gfsg = GenericFeatureStatisticsGenerator() data = [{'name': k, 'table': self._remove_nonascii(v)} for k, v in six.iteritems(data)] data_proto = gfsg.ProtoFromDataFrames(data) protostr = base64.b64encode(data_proto.SerializeToString()).decode("utf-8") html_id = 'f' + datalab.utils.commands.Html.next_id() HTML_TEMPLATE = """ """ html = HTML_TEMPLATE.format(html_id=html_id, protostr=protostr) return IPython.core.display.HTML(html) class FacetsDiveview(object): """Represents A facets overview. """ def plot(self, data, height=1000, render_large_data=False): """ Plots a detail view of data. Args: data: a Pandas dataframe. height: the height of the output. """ import IPython if not isinstance(data, pd.DataFrame): raise ValueError('Expect a DataFrame.') if (len(data) > 10000 and not render_large_data): raise ValueError('Facets dive may not work well with more than 10000 rows. ' + 'Reduce data or set "render_large_data" to True.') jsonstr = data.to_json(orient='records') html_id = 'f' + datalab.utils.commands.Html.next_id() HTML_TEMPLATE = """ """ html = HTML_TEMPLATE.format(html_id=html_id, jsonstr=jsonstr, height=height) return IPython.core.display.HTML(html) ================================================ FILE: google/datalab/ml/_feature_slice_view.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import json import pandas as pd import sys import google.datalab as datalab import google.datalab.bigquery as bq class FeatureSliceView(object): """Represents A feature slice view.""" def _get_lantern_format(self, df): """ Feature slice view browser expects data in the format of: {"metricValues": {"count": 12, "accuracy": 1.0}, "feature": "species:Iris-setosa"} {"metricValues": {"count": 11, "accuracy": 0.72}, "feature": "species:Iris-versicolor"} ... This function converts a DataFrame to such format. 
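# --- Illustrative usage sketch (not part of the original file): plotting a Facets
# overview for two small in-memory DataFrames, since plot() above expects a dict
# mapping a display name to a DataFrame. The import assumes FacetsOverview is
# re-exported by google.datalab.ml.
import pandas as pd
from google.datalab.ml import FacetsOverview

train_df = pd.DataFrame({'species': ['setosa', 'versicolor'], 'petal_length': [1.4, 4.5]})
eval_df = pd.DataFrame({'species': ['virginica'], 'petal_length': [6.0]})

FacetsOverview().plot({'train': train_df, 'eval': eval_df})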
""" if ('count' not in df) or ('feature' not in df): raise Exception('No "count" or "feature" found in data.') if len(df.columns) < 3: raise Exception('Need at least one metrics column.') if len(df) == 0: raise Exception('Data is empty') data = [] for _, row in df.iterrows(): metric_values = dict(row) feature = metric_values.pop('feature') data.append({'feature': feature, 'metricValues': metric_values}) return data def plot(self, data): """ Plots a featire slice view on given data. Args: data: Can be one of: A string of sql query. A sql query module defined by "%%sql --module module_name". A pandas DataFrame. Regardless of data type, it must include the following columns: "feature": identifies a slice of features. For example: "petal_length:4.0-4.2". "count": number of instances in that slice of features. All other columns are viewed as metrics for its feature slice. At least one is required. """ import IPython if ((sys.version_info.major > 2 and isinstance(data, str)) or (sys.version_info.major <= 2 and isinstance(data, basestring))): data = bq.Query(data) if isinstance(data, bq.Query): df = data.execute().result().to_dataframe() data = self._get_lantern_format(df) elif isinstance(data, pd.core.frame.DataFrame): data = self._get_lantern_format(data) else: raise Exception('data needs to be a sql query, or a pandas DataFrame.') HTML_TEMPLATE = """ """ # Serialize the data and list of metrics names to JSON string. metrics_str = str(map(str, data[0]['metricValues'].keys())) data_str = str([{str(k): json.dumps(v) for k, v in elem.iteritems()} for elem in data]) html_id = 'l' + datalab.utils.commands.Html.next_id() html = HTML_TEMPLATE.format(html_id=html_id, metrics=metrics_str, data=data_str) IPython.display.display(IPython.display.HTML(html)) ================================================ FILE: google/datalab/ml/_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Cloud ML Operation wrapper.""" import six import google.datalab as datalab from googleapiclient import discovery import yaml import datetime class Job(datalab.Job): """Represents a Cloud ML job.""" def __init__(self, name, context=None): """Initializes an instance of a CloudML Job. Args: name: the name of the job. It can be an operation full name ("projects/[project_id]/jobs/[operation_name]") or just [operation_name]. context: an optional Context object providing project_id and credentials. """ super(Job, self).__init__(name) if context is None: context = datalab.Context.default() self._context = context self._api = discovery.build('ml', 'v1', credentials=self._context.credentials) if not name.startswith('projects/'): name = 'projects/' + self._context.project_id + '/jobs/' + name self._name = name self._refresh_state() def _refresh_state(self): """ Refresh the job info. 
""" self._info = self._api.projects().jobs().get(name=self._name).execute() self._fatal_error = self._info.get('errorMessage', None) state = str(self._info.get('state')) self._is_complete = (state == 'SUCCEEDED' or state == 'FAILED') @property def info(self): self._refresh_state() return self._info def describe(self): self._refresh_state() job_yaml = yaml.safe_dump(self._info, default_flow_style=False) print(job_yaml) @staticmethod def submit_training(job_request, job_id=None): """Submit a training job. Args: job_request: the arguments of the training job in a dict. For example, { 'package_uris': 'gs://my-bucket/iris/trainer-0.1.tar.gz', 'python_module': 'trainer.task', 'scale_tier': 'BASIC', 'region': 'us-central1', 'args': { 'train_data_paths': ['gs://mubucket/data/features_train'], 'eval_data_paths': ['gs://mubucket/data/features_eval'], 'metadata_path': 'gs://mubucket/data/metadata.yaml', 'output_path': 'gs://mubucket/data/mymodel/', } } If 'args' is present in job_request and is a dict, it will be expanded to --key value or --key list_item_0 --key list_item_1, ... job_id: id for the training job. If None, an id based on timestamp will be generated. Returns: A Job object representing the cloud training job. """ new_job_request = dict(job_request) # convert job_args from dict to list as service required. if 'args' in job_request and isinstance(job_request['args'], dict): job_args = job_request['args'] args = [] for k, v in six.iteritems(job_args): if isinstance(v, list): for item in v: args.append('--' + str(k)) args.append(str(item)) else: args.append('--' + str(k)) args.append(str(v)) new_job_request['args'] = args if job_id is None: job_id = datetime.datetime.now().strftime('%y%m%d_%H%M%S') if 'python_module' in new_job_request: job_id = new_job_request['python_module'].replace('.', '_') + \ '_' + job_id job = { 'job_id': job_id, 'training_input': new_job_request, } context = datalab.Context.default() cloudml = discovery.build('ml', 'v1', credentials=context.credentials) request = cloudml.projects().jobs().create(body=job, parent='projects/' + context.project_id) request.headers['user-agent'] = 'GoogleCloudDataLab/1.0' request.execute() return Job(job_id) @staticmethod def submit_batch_prediction(job_request, job_id=None): """Submit a batch prediction job. Args: job_request: the arguments of the training job in a dict. For example, { 'version_name': 'projects/my-project/models/my-model/versions/my-version', 'data_format': 'TEXT', 'input_paths': ['gs://my_bucket/my_file.csv'], 'output_path': 'gs://my_bucket/predict_output', 'region': 'us-central1', 'max_worker_count': 1, } job_id: id for the training job. If None, an id based on timestamp will be generated. Returns: A Job object representing the batch prediction job. """ if job_id is None: job_id = 'prediction_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S') job = { 'job_id': job_id, 'prediction_input': job_request, } context = datalab.Context.default() cloudml = discovery.build('ml', 'v1', credentials=context.credentials) request = cloudml.projects().jobs().create(body=job, parent='projects/' + context.project_id) request.headers['user-agent'] = 'GoogleCloudDataLab/1.0' request.execute() return Job(job_id) class Jobs(object): """Represents a list of Cloud ML jobs for a project.""" def __init__(self, filter=None): """Initializes an instance of a CloudML Job list that is iteratable ("for job in jobs()"). 
Args: filter: filter string for retrieving jobs, such as "state=FAILED" context: an optional Context object providing project_id and credentials. api: an optional CloudML API client. """ self._filter = filter self._context = datalab.Context.default() self._api = discovery.build('ml', 'v1', credentials=self._context.credentials) self._page_size = 0 def _retrieve_jobs(self, page_token, _): list_info = self._api.projects().jobs().list(parent='projects/' + self._context.project_id, pageToken=page_token, pageSize=self._page_size, filter=self._filter).execute() jobs = list_info.get('jobs', []) self._page_size = self._page_size or len(jobs) page_token = list_info.get('nextPageToken', None) return jobs, page_token def get_iterator(self): """Get iterator of jobs so it can be used as "for model in Jobs().get_iterator()". """ return iter(datalab.utils.Iterator(self._retrieve_jobs)) def list(self, count=10): import IPython data = [{'Id': job['jobId'], 'State': job.get('state', 'UNKNOWN'), 'createTime': job['createTime']} for _, job in zip(range(count), self)] IPython.display.display( datalab.utils.commands.render_dictionary(data, ['Id', 'State', 'createTime'])) ================================================ FILE: google/datalab/ml/_metrics.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import json import math import numpy as np import pandas as pd from sklearn.metrics import mean_squared_error, mean_absolute_error import google.datalab.bigquery as bq from . import _util class Metrics(object): """Represents a Metrics object that computes metrics from raw evaluation results.""" def __init__(self, input_csv_pattern=None, headers=None, bigquery=None): """ Args: input_csv_pattern: Path to Csv file pattern (with no header). Can be local or GCS path. headers: Csv headers. Required if input_csv_pattern is not None. bigquery: Can be one of: A BigQuery query string. A Bigquery table string. A Query object defined with '%%bq query --name [query_name]'. Raises: ValueError if input_csv_pattern is provided but both headers and schema_file are None. ValueError if but both input_csv_pattern and bigquery are None. """ self._input_csv_files = None self._bigquery = None if input_csv_pattern: self._input_csv_files = _util.glob_files(input_csv_pattern) if not headers: raise ValueError('csv requires headers.') self._headers = headers elif bigquery: self._bigquery = bigquery else: raise ValueError('Either input_csv_pattern or bigquery needs to be provided.') @staticmethod def from_csv(input_csv_pattern, headers=None, schema_file=None): """Create a Metrics instance from csv file pattern. Args: input_csv_pattern: Path to Csv file pattern (with no header). Can be local or GCS path. headers: Csv headers. schema_file: Path to a JSON file containing BigQuery schema. Used if "headers" is None. Returns: a Metrics instance. Raises: ValueError if both headers and schema_file are None. 
""" if headers is not None: names = headers elif schema_file is not None: with _util.open_local_or_gcs(schema_file, mode='r') as f: schema = json.load(f) names = [x['name'] for x in schema] else: raise ValueError('Either headers or schema_file is needed') metrics = Metrics(input_csv_pattern=input_csv_pattern, headers=names) return metrics @staticmethod def from_bigquery(sql): """Create a Metrics instance from a bigquery query or table. Returns: a Metrics instance. Args: sql: A BigQuery table name or a query. """ if isinstance(sql, bq.Query): sql = sql._expanded_sql() parts = sql.split('.') if len(parts) == 1 or len(parts) > 3 or any(' ' in x for x in parts): sql = '(' + sql + ')' # query, not a table name else: sql = '`' + sql + '`' # table name metrics = Metrics(bigquery=sql) return metrics def _get_data_from_csv_files(self): """Get data from input csv files.""" all_df = [] for file_name in self._input_csv_files: with _util.open_local_or_gcs(file_name, mode='r') as f: all_df.append(pd.read_csv(f, names=self._headers)) df = pd.concat(all_df, ignore_index=True) return df def _get_data_from_bigquery(self, queries): """Get data from bigquery table or query.""" all_df = [] for query in queries: all_df.append(query.execute().result().to_dataframe()) df = pd.concat(all_df, ignore_index=True) return df def accuracy(self): """Get accuracy numbers for each target and overall. Returns: A DataFrame with two columns: 'class' and 'accuracy'. It also contains the overall accuracy with class being '_all'. Raises: Exception if the CSV headers do not include 'target' or 'predicted', or BigQuery does not return 'target' or 'predicted' column. """ if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or 'predicted' not in df: raise ValueError('Cannot find "target" or "predicted" column') labels = sorted(set(df['target']) | set(df['predicted'])) accuracy_results = [] for label in labels: correct_count = len(df[(df['target'] == df['predicted']) & (df['target'] == label)]) total_count = len(df[(df['target'] == label)]) accuracy_results.append({ 'target': label, 'accuracy': float(correct_count) / total_count if total_count > 0 else 0, 'count': total_count }) total_correct_count = len(df[(df['target'] == df['predicted'])]) if len(df) > 0: total_accuracy = float(total_correct_count) / len(df) accuracy_results.append({'target': '_all', 'accuracy': total_accuracy, 'count': len(df)}) return pd.DataFrame(accuracy_results) elif self._bigquery: query = bq.Query(""" SELECT target, SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END)/COUNT(*) as accuracy, COUNT(*) as count FROM %s GROUP BY target""" % self._bigquery) query_all = bq.Query(""" SELECT "_all" as target, SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END)/COUNT(*) as accuracy, COUNT(*) as count FROM %s""" % self._bigquery) df = self._get_data_from_bigquery([query, query_all]) return df def roc(self, num_thresholds, target_class, probability_column=None): """Get true positive rate, false positive rate values from evaluation results. Args: num_thresholds: an integer. Number of thresholds. target_class: a string indciating the target class, i.e. "daisy" in flower classification. probability_column: the name of the probability column. If None, defaults to value of target_class. Returns: A DataFrame with columns: 'tpr', 'fpr', 'threshold' with number of rows equal to num_thresholds. Raises: Exception if the CSV headers do not include 'target' or probability_column, or BigQuery does not return 'target' or probability_column column. 
""" if not probability_column: probability_column = target_class thresholds = np.linspace(0, 1, num_thresholds + 1) if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or probability_column not in df: raise ValueError('Cannot find "target" or "%s" column' % probability_column) total_positive = sum(1 for x in df['target'] if x == target_class) total_negative = len(df) - total_positive true_positives, false_positives = [], [] for threshold in thresholds: true_positive_count = len(df[(df[probability_column] > threshold) & (df['target'] == target_class)]) false_positive_count = len(df[(df[probability_column] > threshold) & (df['target'] != target_class)]) true_positives.append(true_positive_count) false_positives.append(false_positive_count) data = [] for tp, fp, t in zip(true_positives, false_positives, thresholds): tpr = (float)(tp) / total_positive if total_positive > 0. else 0. fpr = (float)(fp) / total_negative if total_negative > 0. else 0. data.append({'tpr': tpr, 'fpr': fpr, 'threshold': t}) return pd.DataFrame(data) elif self._bigquery: true_positive_query = bq.Query(""" SELECT COUNT(*) as true_positive FROM %s CROSS JOIN (SELECT * FROM UNNEST ([%s]) as t) WHERE %s > t AND target = '%s' GROUP BY t ORDER BY t """ % (self._bigquery, ','.join(map(str, thresholds)), probability_column, target_class)) false_positive_query = bq.Query(""" SELECT COUNT(*) as false_positive FROM %s CROSS JOIN (SELECT * FROM UNNEST ([%s]) as t) WHERE %s > t AND target != '%s' GROUP BY t ORDER BY t """ % (self._bigquery, ','.join(map(str, thresholds)), probability_column, target_class)) total_positive_query = bq.Query(""" SELECT COUNT(*) as total_positive FROM %s WHERE target = '%s' """ % (self._bigquery, target_class)) total_negative_query = bq.Query(""" SELECT COUNT(*) as total_negative FROM %s WHERE target != '%s' """ % (self._bigquery, target_class)) true_positives = true_positive_query.execute().result() false_positives = false_positive_query.execute().result() total_positive = total_positive_query.execute().result()[0]['total_positive'] total_negative = total_negative_query.execute().result()[0]['total_negative'] data = [] for tp, fp, t in zip(true_positives, false_positives, thresholds): tpr = (float)(tp['true_positive']) / total_positive if total_positive > 0. else 0. fpr = (float)(fp['false_positive']) / total_negative if total_negative > 0. else 0. data.append({'tpr': tpr, 'fpr': fpr, 'threshold': t}) data.append({'tpr': 0., 'fpr': 0., 'threshold': 1.0}) return pd.DataFrame(data) def precision_recall(self, num_thresholds, target_class, probability_column=None): """Get precision, recall values from evaluation results. Args: num_thresholds: an integer. Number of thresholds. target_class: a string indciating the target class, i.e. "daisy" in flower classification. probability_column: the name of the probability column. If None, defaults to value of target_class. Returns: A DataFrame with columns: 'threshold', 'precision', 'recall' with number of rows equal to num_thresholds. Raises: Exception if the CSV headers do not include 'target' or probability_column, or BigQuery does not return 'target' or probability_column column. """ if not probability_column: probability_column = target_class # threshold = 1.0 is excluded. 
thresholds = np.linspace(0, 1, num_thresholds + 1)[0:-1] if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or probability_column not in df: raise ValueError('Cannot find "target" or "%s" column' % probability_column) total_target = sum(1 for x in df['target'] if x == target_class) total_predicted = [] correct_predicted = [] for threshold in thresholds: predicted_count = sum(1 for x in df[probability_column] if x > threshold) total_predicted.append(predicted_count) correct_count = len(df[(df[probability_column] > threshold) & (df['target'] == target_class)]) correct_predicted.append(correct_count) data = [] for p, c, t in zip(total_predicted, correct_predicted, thresholds): precision = (float)(c) / p if p > 0. else 0. recall = (float)(c) / total_target if total_target > 0. else 0. data.append({'precision': precision, 'recall': recall, 'threshold': t}) return pd.DataFrame(data) elif self._bigquery: total_predicted_query = bq.Query(""" SELECT COUNT(*) as total_predicted FROM %s CROSS JOIN (SELECT * FROM UNNEST ([%s]) as t) WHERE %s > t GROUP BY t ORDER BY t """ % (self._bigquery, ','.join(map(str, thresholds)), probability_column)) correct_predicted_query = bq.Query(""" SELECT COUNT(*) as correct_predicted FROM %s CROSS JOIN (SELECT * FROM UNNEST ([%s]) as t) WHERE %s > t AND target='%s' GROUP BY t ORDER BY t """ % (self._bigquery, ','.join(map(str, thresholds)), probability_column, target_class)) total_target_query = bq.Query(""" SELECT COUNT(*) as total_target FROM %s WHERE target='%s' """ % (self._bigquery, target_class)) total_predicted = total_predicted_query.execute().result() correct_predicted = correct_predicted_query.execute().result() total_target = total_target_query.execute().result()[0]['total_target'] data = [] for p, c, t in zip(total_predicted, correct_predicted, thresholds): precision = ((float)(c['correct_predicted']) / p['total_predicted'] if p['total_predicted'] > 0. else 0.) recall = (float)(c['correct_predicted']) / total_target if total_target > 0. else 0. data.append({'precision': precision, 'recall': recall, 'threshold': t}) return pd.DataFrame(data) def rmse(self): """Get RMSE for regression model evaluation results. Returns: the RMSE float number. Raises: Exception if the CSV headers do not include 'target' or 'predicted', or BigQuery does not return 'target' or 'predicted' column, or if target or predicted is not number. """ if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or 'predicted' not in df: raise ValueError('Cannot find "target" or "predicted" column') df = df[['target', 'predicted']].apply(pd.to_numeric) # if df is empty or contains non-numeric, scikit learn will raise error. mse = mean_squared_error(df['target'], df['predicted']) return math.sqrt(mse) elif self._bigquery: query = bq.Query(""" SELECT SQRT(SUM(ABS(predicted-target) * ABS(predicted-target)) / COUNT(*)) as rmse FROM %s""" % self._bigquery) df = self._get_data_from_bigquery([query]) if df.empty: return None return df['rmse'][0] def mae(self): """Get MAE (Mean Absolute Error) for regression model evaluation results. Returns: the MAE float number. Raises: Exception if the CSV headers do not include 'target' or 'predicted', or BigQuery does not return 'target' or 'predicted' column, or if target or predicted is not number. 
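# --- Illustrative usage sketch (not part of the original file): regression metrics with
# the same Metrics class, assuming a hypothetical CSV of numeric 'target'/'predicted' pairs.
from google.datalab.ml import Metrics

reg_metrics = Metrics.from_csv('gs://my-bucket/eval/regression*.csv',
                               headers=['key', 'target', 'predicted'])
print(reg_metrics.rmse())                  # square root of the mean squared error
print(reg_metrics.mae())                   # mean absolute error
print(reg_metrics.percentile_nearest(90))  # 90th percentile of absolute error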
""" if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or 'predicted' not in df: raise ValueError('Cannot find "target" or "predicted" column') df = df[['target', 'predicted']].apply(pd.to_numeric) mae = mean_absolute_error(df['target'], df['predicted']) return mae elif self._bigquery: query = bq.Query(""" SELECT SUM(ABS(predicted-target)) / COUNT(*) as mae FROM %s""" % self._bigquery) df = self._get_data_from_bigquery([query]) if df.empty: return None return df['mae'][0] def percentile_nearest(self, percentile): """Get nearest percentile from regression model evaluation results. Args: percentile: a 0~100 float number. Returns: the percentile float number. Raises: Exception if the CSV headers do not include 'target' or 'predicted', or BigQuery does not return 'target' or 'predicted' column, or if target or predicted is not number. """ if self._input_csv_files: df = self._get_data_from_csv_files() if 'target' not in df or 'predicted' not in df: raise ValueError('Cannot find "target" or "predicted" column') df = df[['target', 'predicted']].apply(pd.to_numeric) abs_errors = np.array((df['target'] - df['predicted']).apply(abs)) return np.percentile(abs_errors, percentile, interpolation='nearest') elif self._bigquery: query = bq.Query(""" SELECT PERCENTILE_DISC(ABS(predicted-target), %f) OVER() AS percentile FROM %s LIMIT 1""" % (float(percentile) / 100, self._bigquery)) df = self._get_data_from_bigquery([query]) if df.empty: return None return df['percentile'][0] ================================================ FILE: google/datalab/ml/_summary.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import datetime import fnmatch import matplotlib.pyplot as plt import os import pandas as pd import six import tensorflow as tf from tensorflow.core.util import event_pb2 from tensorflow.python.lib.io import tf_record class Summary(object): """Represents TensorFlow summary events from files under specified directories.""" def __init__(self, paths): """Initializes an instance of a Summary. Args: path: a path or a list of paths to directories which hold TensorFlow events files. Can be local path or GCS paths. Wild cards allowed. """ if isinstance(paths, six.string_types): self._paths = [paths] else: self._paths = paths def _glob_events_files(self, paths, recursive): """Find all tf events files under a list of paths recursively. 
""" event_files = [] for path in paths: dirs = tf.gfile.Glob(path) dirs = filter(lambda x: tf.gfile.IsDirectory(x), dirs) for dir in dirs: if recursive: dir_files_pair = [(root, filenames) for root, _, filenames in tf.gfile.Walk(dir)] else: dir_files_pair = [(dir, tf.gfile.ListDirectory(dir))] for root, filenames in dir_files_pair: file_names = fnmatch.filter(filenames, '*.tfevents.*') file_paths = [os.path.join(root, x) for x in file_names] file_paths = filter(lambda x: not tf.gfile.IsDirectory(x), file_paths) event_files += file_paths return event_files def list_events(self): """List all scalar events in the directory. Returns: A dictionary. Key is the name of a event. Value is a set of dirs that contain that event. """ event_dir_dict = collections.defaultdict(set) for event_file in self._glob_events_files(self._paths, recursive=True): dir = os.path.dirname(event_file) try: for record in tf_record.tf_record_iterator(event_file): event = event_pb2.Event.FromString(record) if event.summary is None or event.summary.value is None: continue for value in event.summary.value: if value.simple_value is None or value.tag is None: continue event_dir_dict[value.tag].add(dir) except tf.errors.DataLossError: # DataLossError seems to happen sometimes for small logs. # We want to show good records regardless. continue return dict(event_dir_dict) def get_events(self, event_names): """Get all events as pandas DataFrames given a list of names. Args: event_names: A list of events to get. Returns: A list with the same length and order as event_names. Each element is a dictionary {dir1: DataFrame1, dir2: DataFrame2, ...}. Multiple directories may contain events with the same name, but they are different events (i.e. 'loss' under trains_set/, and 'loss' under eval_set/.) """ if isinstance(event_names, six.string_types): event_names = [event_names] all_events = self.list_events() dirs_to_look = set() for event, dirs in six.iteritems(all_events): if event in event_names: dirs_to_look.update(dirs) ret_events = [collections.defaultdict(lambda: pd.DataFrame(columns=['time', 'step', 'value'])) for i in range(len(event_names))] for event_file in self._glob_events_files(dirs_to_look, recursive=False): try: for record in tf_record.tf_record_iterator(event_file): event = event_pb2.Event.FromString(record) if event.summary is None or event.wall_time is None or event.summary.value is None: continue event_time = datetime.datetime.fromtimestamp(event.wall_time) for value in event.summary.value: if value.tag not in event_names or value.simple_value is None: continue index = event_names.index(value.tag) dir_event_dict = ret_events[index] dir = os.path.dirname(event_file) # Append a row. df = dir_event_dict[dir] df.loc[len(df)] = [event_time, event.step, value.simple_value] except tf.errors.DataLossError: # DataLossError seems to happen sometimes for small logs. # We want to show good records regardless. continue for idx, dir_event_dict in enumerate(ret_events): for df in dir_event_dict.values(): df.sort_values(by=['time'], inplace=True) ret_events[idx] = dict(dir_event_dict) return ret_events def plot(self, event_names, x_axis='step'): """Plots a list of events. Each event (a dir+event_name) is represetented as a line in the graph. Args: event_names: A list of events to plot. Each event_name may correspond to multiple events, each in a different directory. x_axis: whether to use step or time as x axis. 
""" if isinstance(event_names, six.string_types): event_names = [event_names] events_list = self.get_events(event_names) for event_name, dir_event_dict in zip(event_names, events_list): for dir, df in six.iteritems(dir_event_dict): label = event_name + ':' + dir x_column = df['step'] if x_axis == 'step' else df['time'] plt.plot(x_column, df['value'], label=label) plt.legend(loc='best') plt.show() ================================================ FILE: google/datalab/ml/_tensorboard.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. try: import IPython except ImportError: raise Exception('This module can only be loaded in ipython.') import argparse import os import pandas as pd import psutil import subprocess import time import google.datalab as datalab class TensorBoard(object): """Start, shutdown, and list TensorBoard instances.""" @staticmethod def list(): """List running TensorBoard instances.""" running_list = [] parser = argparse.ArgumentParser() parser.add_argument('--logdir') parser.add_argument('--port') for p in psutil.process_iter(): if p.name() != 'tensorboard' or p.status() == psutil.STATUS_ZOMBIE: continue cmd_args = p.cmdline() del cmd_args[0:2] # remove 'python' and 'tensorboard' args = parser.parse_args(cmd_args) running_list.append({'pid': p.pid, 'logdir': args.logdir, 'port': args.port}) return pd.DataFrame(running_list) @staticmethod def start(logdir): """Start a TensorBoard instance. Args: logdir: the logdir to run TensorBoard on. Raises: Exception if the instance cannot be started. """ if logdir.startswith('gs://'): # Check user does have access. TensorBoard will start successfully regardless # the user has read permissions or not so we check permissions here to # give user alerts if needed. datalab.storage._api.Api.verify_permitted_to_read(logdir) port = datalab.utils.pick_unused_port() args = ['tensorboard', '--logdir=' + logdir, '--port=' + str(port)] p = subprocess.Popen(args) retry = 10 while (retry > 0): if datalab.utils.is_http_running_on(port): basepath = os.environ.get('DATALAB_ENDPOINT_URL', '') url = '%s/_proxy/%d/' % (basepath.rstrip('/'), port) html = '

<p>TensorBoard was started successfully with pid %d. ' % p.pid html += 'Click <a href="%s" target="_blank">here</a> to access it.</p>
' % url IPython.display.display_html(html, raw=True) return p.pid time.sleep(1) retry -= 1 raise Exception('Cannot start TensorBoard.') @staticmethod def stop(pid): """Shut down a specific process. Args: pid: the pid of the process to shutdown. """ if psutil.pid_exists(pid): try: p = psutil.Process(pid) p.kill() except Exception: pass ================================================ FILE: google/datalab/ml/_util.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from googleapiclient import discovery import os import shutil import subprocess import tempfile import tensorflow as tf import time import datalab.context import google.datalab.utils as dlutils # TODO: Create an Operation class. def wait_for_long_running_operation(operation_full_name): print('Waiting for operation "%s"' % operation_full_name) api = discovery.build('ml', 'v1', credentials=datalab.context.Context.default().credentials) while True: response = api.projects().operations().get(name=operation_full_name).execute() if 'done' not in response or response['done'] is not True: time.sleep(3) else: if 'error' in response: print(response['error']) else: print('Done.') break def package_and_copy(package_root_dir, setup_py, output_tar_path): """Repackage an CloudML package and copy it to a staging dir. Args: package_root_dir: the root dir to install package from. Usually you can get the path from inside your module using a relative path to __file__. setup_py: the path to setup.py. output_tar_path: the GCS path of the output tarball package. Raises: ValueError if output_tar_path is not a GCS path, or setup_py does not exist. """ if not output_tar_path.startswith('gs://'): raise ValueError('output_tar_path needs to be a GCS path.') if not os.path.isfile(setup_py): raise ValueError('Supplied file "%s" does not exist.' % setup_py) dest_setup_py = os.path.join(package_root_dir, 'setup.py') if dest_setup_py != setup_py: # setuptools requires a "setup.py" in the current dir, so copy setup.py there. # Also check if there is an existing setup.py. If so, back it up. if os.path.isfile(dest_setup_py): os.rename(dest_setup_py, dest_setup_py + '._bak_') shutil.copyfile(setup_py, dest_setup_py) tempdir = tempfile.mkdtemp() previous_cwd = os.getcwd() os.chdir(package_root_dir) try: # Repackage. sdist = ['python', dest_setup_py, 'sdist', '--format=gztar', '-d', tempdir] subprocess.check_call(sdist) # Copy to GCS. 
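# --- Illustrative usage sketch (not part of the original file): managing TensorBoard
# instances with the class defined earlier in _tensorboard.py, from inside a
# Datalab/IPython session. The log directory is a hypothetical training output path.
from google.datalab.ml import TensorBoard

pid = TensorBoard.start('gs://my-bucket/training/output')  # returns the process pid
print(TensorBoard.list())   # DataFrame of running instances (pid, logdir, port)
TensorBoard.stop(pid)       # shut the instance down again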
source = os.path.join(tempdir, '*.tar.gz') gscopy = ['gsutil', 'cp', source, output_tar_path] subprocess.check_call(gscopy) return finally: os.chdir(previous_cwd) if dest_setup_py != setup_py: os.remove(dest_setup_py) if os.path.isfile(dest_setup_py + '._bak_'): os.rename(dest_setup_py + '._bak_', dest_setup_py) shutil.rmtree(tempdir) def read_file_to_string(path): """Read a file into a string.""" bytes_string = tf.gfile.Open(path, 'r').read() return dlutils.python_portable_string(bytes_string) def open_local_or_gcs(path, mode): """Opens the given path.""" return tf.gfile.Open(path, mode) def glob_files(path): """Glob the given path.""" return tf.gfile.Glob(path) ================================================ FILE: google/datalab/notebook/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Datalab - notebook extension functionality.""" try: import IPython as _ except ImportError: raise Exception('This package requires an IPython notebook installation') __all__ = ['_'] def _jupyter_nbextension_paths(): return [dict(section="notebook", src="static", dest="gcpdatalab")] ================================================ FILE: google/datalab/notebook/static/bigquery.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ table.bqsv { font-family: inherit; font-size: smaller; } table.bqsv th, table.bqsv td { border: solid 1px #cfcfcf; } th.bqsv_expanded, th.bqsv_collapsed { background-color: #f7f7f7; } th.bqsv_colheader { font-weight: bold; background-color: #e7e7e7; } tbody.bqsv_hidden { display: none; } th.bqsv_expanded:before { content: '\25be ' } th.bqsv_collapsed:before { content: '\25b8 ' } ================================================ FILE: google/datalab/notebook/static/bigquery.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. 
See the License for the specific language governing permissions and limitations under * the License. */ /// module BigQuery { // Event handler to toggle visibility of a nested schema table. function _toggleNode(e: any): void { var node = e.target; var expand = node.className == 'bqsv_collapsed'; node.className = expand ? 'bqsv_expanded' : 'bqsv_collapsed'; var tgroup = node.parentNode.nextSibling; tgroup.className = expand ? 'bqsv_visible' : 'bqsv_hidden'; } // Helper function to recursively render a table schema. function _renderSchema(table: any, schema: any, title: string, includeColumnHeaders: boolean, columns: any): void { // Create a tbody element to hold the entities for this level. We group them so // we can easily collapse/expand the level. var tbody = document.createElement('tbody'); for (var i = 0; i < schema.length; i++) { if (i == 0) { if (title.length > 0) { // title.length > 0 implies we are in a nested table. Create a title header row // for this nested table with a click handler and hide the tbody. tbody.className = 'bqsv_hidden'; var th = document.createElement('th'); th.colSpan = columns.length; th.className = 'bqsv_collapsed'; th.textContent = title.substring(1); // skip the leading '.' th.addEventListener('click', _toggleNode); var tr = document.createElement('tr'); tr.appendChild(th); table.appendChild(tr); } else { // We are in the top-level table; add a header row with the column labels. tbody.className = 'bqsv_visible'; if (includeColumnHeaders) { // First line; show column headers. var tr = document.createElement('tr'); for (var j = 0; j < columns.length; j++) { var th = document.createElement('th'); th.textContent = columns[j]; th.className = 'bqsv_colheader'; tr.appendChild(th); } table.appendChild(tr); } } } // Add the details for the current row to the tbody. var field = schema[i]; var tr = document.createElement('tr'); for (var j = 0; j < columns.length; j++) { var td = document.createElement('td'); var v = field[columns[j]]; td.textContent = v == undefined ? '' : v; tr.appendChild(td); } tbody.appendChild(tr); } // Add the tbody with all the rows to the table. table.appendChild(tbody); // Recurse into any nested tables. for (var i = 0; i < schema.length; i++) { var field = schema[i]; if (field.type == 'RECORD') { _renderSchema(table, field.fields, title + '.' + field.name, false, columns); } } } // Top-level public function for schema rendering. export function renderSchema(dom: any, schema: any) { var columns = ['name', 'type', 'mode', 'description']; var table = document.createElement('table'); table.className = 'bqsv'; _renderSchema(table, schema, '', /*includeColumnHeaders*/ true, columns); dom.appendChild(table); } } export = BigQuery; ================================================ FILE: google/datalab/notebook/static/charting.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. 
*/ table.google-visualization-table-table, table.dataframe { font-family: inherit; font-size: smaller; } tr.gchart-table-row { } tr.gchart-table-headerrow, table.dataframe thead th { font-weight: bold; background-color: #e7e7e7; } tr.gchart-table-oddrow, table.dataframe tr:nth-child(odd) { background-color: #f7f7f7; } tr.gchart-table-selectedTableRow { background-color: #e3f2fd; } tr.gchart-table-hoverrow, table.dataframe tr:hover { background-color: #bbdefb; } td.gchart-table-cell, table.dataframe td { border: solid 1px #cfcfcf; } td.gchart-table-rownumcell, table.dataframe tr th { border: solid 1px #cfcfcf; color: #999; } th.gchart-table-headercell, table.dataframe th { border: solid 1px #cfcfcf; } div.bqgc { display: flex; justify-content: center; } div.bqgc img { max-width: none; // Fix the conflict with maps and Bootstrap that messes up zoom controls. } .gchart-slider { width: 80%; float: left; } .gchart-slider-value { text-align: center; float: left; width: 20%; } .gchart-control { padding-top: 10px; padding-bottom: 10px; } .gchart-controls { font-size: 14px; color: #333333; background: #f4f4f4; padding: 10px; width: 180px; float: left; } .bqgc { padding: 0; max-width: 100%; } .bqgc-controlled { display: flex; flex-direction: row; justify-content:space-between; } .bqgc-container { display: block; } .bqgc-ml-metrics { display: flex; flex-direction: row; justify-content:left; } ================================================ FILE: google/datalab/notebook/static/charting.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module Charting { declare var IPython:any; declare var datalab:any; // Wrappers for Plotly.js and Google Charts abstract class ChartLibraryDriver { chartModule:any; constructor(protected dom:HTMLElement, protected chartStyle:string) { } abstract requires(url: string, chartStyle:string):Array; init(chartModule:any):void { this.chartModule = chartModule; } abstract draw(data:any, options:any):void; abstract getStaticImage(callback:Function):void; abstract addChartReadyHandler(handler:Function):void; addPageChangedHandler(handler:Function):void { } error(message:string):void { } } class PlotlyDriver extends ChartLibraryDriver { readyHandler:any; constructor(dom:HTMLElement, chartStyle:string) { super(dom, chartStyle) } requires(url: string, chartStyle:string):Array { return ['d3', 'plotly']; } public draw(data:any, options:any):void { /* * TODO(gram): if we start moving more chart types over to Plotly.js we should change the * shape of the data we pass to render so we don't need to reshape it here. Also, a fair * amount of the computation done here could be moved to Python code. We should just be * passing in the mostly complete layout object in JSON, for example. 
*/ var xlabels: Array = []; var points: Array = []; var layout: any = { xaxis: {}, yaxis: {}, height: 300, margin: { b: 60, t: 60, l: 60, r: 60 } }; if (options.title) { layout.title = options.title; } var minX: number = undefined; var maxX: number = undefined; if ('hAxis' in options) { if ('minValue' in options.hAxis) { minX = options.hAxis.minValue; } if ('maxValue' in options.hAxis) { maxX = options.hAxis.maxValue; } if (minX != undefined || maxX != undefined) { layout.xaxis.range = [minX, maxX]; } } var minY: number = undefined; var maxY: number = undefined; if ('vAxis' in options) { if ('minValue' in options.vAxis) { minY = options.vAxis.minValue; } else if ('minValues' in options.vAxis) { minY = options.vAxis.minValues[0]; } if ('maxValue' in options.vAxis) { maxY = options.vAxis.maxValue; } else if ('maxValues' in options.vAxis) { maxY = options.vAxis.maxValues[0]; } if (minY != undefined || maxY != undefined) { layout.yaxis.range = [minY, maxY]; } if ('minValues' in options.vAxis) { minY = options.vAxis.minValues[1]; // for second axis below } if ('maxValues' in options.vAxis) { maxY = options.vAxis.maxValues[1]; // for second axis below } } if (options.xAxisTitle) { layout.xaxis.title = options.xAxisTitle; } if (options.xAxisSide) { layout.xaxis.side = options.xAxisSide; } if (options.yAxisTitle) { layout.yaxis.title = options.yAxisTitle; } if (options.yAxesTitles) { layout.yaxis.title = options.yAxesTitles[0]; layout.yaxis2 = { title: options.yAxesTitles[1], side: 'right', overlaying: 'y' }; if (minY != undefined || maxY != undefined) { layout.yaxis2.range = [minY, maxY]; } } if ('width' in options) { layout.width = options.width; } if ('height' in options) { layout.height = options.height; if ('width' in options) { layout.autosize = false; } } var pdata: Array = []; if (this.chartStyle == 'line' || this.chartStyle == 'scatter') { var hoverCol: number = 0; var x: Array = []; // First col is X, other cols are Y's and optional hover text only column var y: Array = []; var hover: Array = []; for (var c = 1; c < data.cols.length; c++) { x[c - 1] = []; y[c - 1] = []; var line:any = { x: x[c - 1], y: y[c - 1], name: data.cols[c].label, type: 'scatter', mode: this.chartStyle == 'scatter' ? 'markers' : 'lines' }; if (options.hoverOnly) { hover[c - 1] = []; line.text = hover[c - 1]; line.hoverinfo = 'text'; } if (options.yAxesTitles && (c % 2) == 0) { line.yaxis = 'y2'; } pdata.push(line); } for (var c = 1; c < data.cols.length; c++) { if (c == hoverCol) { continue; } for (var r = 0; r < data.rows.length; r++) { var entry:Array = data.rows[r].c; if ('v' in entry[c]) { var xVal = entry[0].v; var yVal = entry[c].v; if (options.hoverOnly) { // Each column is a dict with two values, one for y and one for // hover. Extract these. var hoverVal:any; var yDict:any = yVal; for (var prop in yDict) { var val = yDict[prop]; if (prop == options.hoverOnly) { hoverVal = val; } else { yVal = val; } } // TODO(gram): we may want to add explicit hover text this even without hoverOnly. 
var xlabel:any = options.xAxisTitle || data.cols[0].label; var ylabel:any = options.yAxisTitle || data.cols[c].label; var prefix = ''; if (options.yAxisTitle) { prefix += data.cols[c].label + ': '; } hover[c - 1].push(prefix + options.hoverOnly + '=' + hoverVal + ', ' + xlabel + '=' + xVal + ', ' + ylabel + '=' + yVal); } x[c - 1].push(xVal); y[c - 1].push(yVal); } } } } else if (this.chartStyle == 'heatmap') { var size:number = 200 + data.cols.length * 50; if (size > 800) size = 800; layout.height = size; layout.width = size; layout.autosize = false; for (var i = 0; i < data.cols.length; i++) { xlabels[i] = data.cols[i].label; } var ylabels = [].concat(xlabels); // Plotly draws the first row at the bottom, not the top, so we need // to reverse the y and z array ordering. // We will need to tweak this a bit if we later support non-square maps. ylabels.reverse(); var hovertext: Array> = []; var hoverx:string = options.xAxisTitle || 'x'; var hovery = options.yAxisTitle || 'y'; for (var i = 0; i < data.rows.length; i++) { var entry:Array = data.rows[i].c; var row:Array = []; var hoverrow:Array = []; for (var j = 0; j < data.cols.length; j++) { row[j] = entry[j].v; hoverrow[j] = hoverx + '= ' + xlabels[j] + ', ' + hovery + '= ' + ylabels[i] + ': ' + row[j]; } points[i] = row; hovertext[i] = hoverrow; } points.reverse(); layout.hovermode = 'closest'; pdata = [{ x: xlabels, y: ylabels, z: points, type: 'heatmap', text: hovertext, hoverinfo: 'text' }]; if (options.colorScale) { pdata[0].colorscale = [ [0, options.colorScale.min], [1, options.colorScale.max] ]; } else { pdata[0].colorscale = [ [0, 'red'], [0.5, 'gray'], [1, 'blue'] ]; } if (options.hideScale) { pdata[0].showscale = false; } if (options.annotate) { layout.annotations = []; for (var i = 0; i < pdata[0].y.length; i++) { for (var j = 0; j < pdata[0].x.length; j++) { var currentValue = pdata[0].z[i][j]; var textColor = (currentValue == 0.0) ? 
'black' : 'white'; var result = { xref: 'x1', yref: 'y1', x: pdata[0].x[j], y: pdata[0].y[i], text: pdata[0].z[i][j].toPrecision(3), showarrow: false, font: { color: textColor } }; layout.annotations.push(result); } } } } this.chartModule.newPlot(this.dom.id, pdata, layout, {displayModeBar: false}); if (this.readyHandler) { this.readyHandler(); } } getStaticImage(callback:Function):void { this.chartModule.Snapshot.toImage(document.getElementById(this.dom.id), {format: 'png'}).once('success', function (url:string) { callback(this.model, url); }); } addChartReadyHandler(handler:Function):void { this.readyHandler = handler; } } interface IStringMap { [key: string]: string; } class GChartsDriver extends ChartLibraryDriver { chart:any; nameMap: IStringMap = { annotation: 'AnnotationChart', area: 'AreaChart', columns: 'ColumnChart', bars: 'BarChart', bubbles: 'BubbleChart', calendar: 'Calendar', candlestick: 'CandlestickChart', combo: 'ComboChart', gauge: 'Gauge', geo: 'GeoChart', histogram: 'Histogram', line: 'LineChart', map: 'Map', org: 'OrgChart', paged_table: 'Table', pie: 'PieChart', sankey: 'Sankey', scatter: 'ScatterChart', stepped_area: 'SteppedAreaChart', table: 'Table', timeline: 'Timeline', treemap: 'TreeMap', }; scriptMap: IStringMap = { annotation: 'annotationchart', calendar: 'calendar', gauge: 'gauge', geo: 'geochart', map: 'map', org: 'orgchart', paged_table: 'table', sankey: 'sankey', table: 'table', timeline: 'timeline', treemap: 'treemap' }; constructor(dom:HTMLElement, chartStyle:string) { super(dom, chartStyle); } requires(url: string, chartStyle:string):Array { var chartScript:string = 'corechart'; if (chartStyle in this.scriptMap) { chartScript = this.scriptMap[chartStyle]; } return [url + 'visualization!' + chartScript]; } init(chartModule:any):void { super.init(chartModule); var constructor:Function = this.chartModule[this.nameMap[this.chartStyle]]; this.chart = new (constructor)(this.dom); } error(message:string):void { this.chartModule.errors.addError(this.dom, 'Unable to render the chart', message, {showInTooltip: false}); } draw(data:any, options:any):void { console.log('Drawing with options ' + JSON.stringify(options)); this.chart.draw(new this.chartModule.DataTable(data), options); } getStaticImage(callback:Function):void { if (this.chart.getImageURI) { callback(this.chart.getImageURI()); } } addChartReadyHandler(handler:Function) { this.chartModule.events.addListener(this.chart, 'ready', handler); } addPageChangedHandler(handler:Function) { this.chartModule.events.addListener(this.chart, 'page', function (e:any) { handler(e.page); }); } } class Chart { dataCache:any; // TODO: add interface types for the caches. optionsCache:any; hasIPython:boolean; cellElement:HTMLElement; totalRows:number; constructor(protected driver:ChartLibraryDriver, protected dom:Element, protected controlIds:Array, protected base_options:any, protected refreshData:any, protected refreshInterval:number, totalRows:number) { this.totalRows = totalRows || -1; // Total rows in all (server-side) data. this.dataCache = {}; this.optionsCache = {}; this.hasIPython = false; try { if (IPython && IPython.notebook) { this.hasIPython = true; } } catch (e) { } (this.dom).innerHTML = ''; this.removeStaticChart(); this.addControls(); // Generate and add a new static chart once chart is ready. var _this = this; this.driver.addChartReadyHandler(function () { _this.addStaticChart(); }); } // Convert any string fields that are date type to JS Dates. 
public convertDates(data:any):void { // Format timestamps in the same way as in dataframes. const timestampFormatter = new this.driver.chartModule.DateFormat({ 'pattern' : 'yyyy-MM-dd HH:mm:ss', 'valueType' : 'datetime', 'timeZone' : -new Date().getTimezoneOffset() / 60, }); // Timestamp formatter with fractional seconds. // BQ and python store time down to the microsecond, but javascript Date // only stores it to the millisecond. const timestampWithFractionalSecondsFormatter = new this.driver.chartModule.DateFormat({ 'pattern' : 'yyyy-MM-dd HH:mm:ss.SSS', 'valueType' : 'datetime', 'timeZone' : -new Date().getTimezoneOffset() / 60, }); // Javascript has terrible support for timezones. When Date objects get converted to // strings, it always applies the local timezone. But we want dates and times to be // printed in UTC so that they match the output of dataframes and other conversions that // are happening in the kernel, which we assume is running in UTC in a docker container. // In order to make this work, we add an offset to our Date objects in an amount equal // to the local timezone offset from UTC so that when those Dates get output as a local // time they will appear as the right UTC time. This is made more confusing by the fact // that date, datetime, and timeofday data types are civil time for which timezone // should not even apply - but since we are passing them along as Date objects, we // pull the same trick with them. We add the 'f' field, for use by Google Charts when // displaying tables, to ensure we have the right string there, but when doing things // like line graphs, that field is not used, so we have to use the Date-offset trick // in order to get dates and times to display correctly as UTC in graphs. function dateAsUtc(localDate:Date):Date { const year = localDate.getUTCFullYear(); const month = localDate.getUTCMonth(); const day = localDate.getUTCDate(); const hours = localDate.getUTCHours(); const minutes = localDate.getUTCMinutes(); const seconds = localDate.getUTCSeconds(); const millis = localDate.getUTCMilliseconds(); return new Date(year, month, day, hours, minutes, seconds, millis); } const rows = data.rows; for (let col = 0; col < data.cols.length; col++) { // date, datetime, and timeofday are civil times that are independent of timezone if (data.cols[col].type == 'date' || data.cols[col].type == 'datetime') { for (let row = 0; row < rows.length; row++) { const v = rows[row].c[col].v; rows[row].c[col].v = dateAsUtc(new Date(v)); rows[row].c[col].f = v; // Display the string as-is to avoid timezone problems. } } else if (data.cols[col].type == 'timeofday') { for (let row = 0; row < rows.length; row++) { const v = rows[row].c[col].v; rows[row].c[col].f = v; // Display the string as-is to avoid timezone problems. const timeInSeconds = v.split('.')[0]; rows[row].c[col].v = timeInSeconds.split(':').map( function(n:string) { return parseInt(n, 10); }); } } else if (data.cols[col].type == 'timestamp') { data.cols[col].type = 'datetime'; // Run through all the dates to determine how to format them. let formatter = timestampFormatter; for (let row = 0; row < rows.length; row++) { const v = new Date(rows[row].c[col].v); if (v.getTime() % 1000 != 0) { formatter = timestampWithFractionalSecondsFormatter; break; } } for (let row = 0; row < rows.length; row++) { const v = new Date(rows[row].c[col].v); // Timestamp is sent back as UTC time string. 
rows[row].c[col].f = formatter.formatValue(v); rows[row].c[col].v = dateAsUtc(v); } } } } // Extend the properties in a 'base' object with the changes in an 'update' object. // We can add properties or override properties but not delete yet. private static extend(base:any, update:any):void { for (var p in update) { if (typeof base[p] !== 'object' || !base.hasOwnProperty(p)) { base[p] = update[p]; } else { this.extend(base[p], update[p]); } } } // Get the IPython cell associated with this chart. private getCell() { if (!this.hasIPython) { return undefined; } var cells = IPython.notebook.get_cells(); for (var cellIndex in cells) { var cell = cells[cellIndex]; if (cell.element && cell.element.length) { var element = cell.element[0]; var chartDivs = element.getElementsByClassName('bqgc'); if (chartDivs && chartDivs.length) { for (var i = 0; i < chartDivs.length; i++) { if (chartDivs[i].id == this.dom.id) { return cell; } } } } } return undefined; } protected getRefreshHandler(useCache:boolean):Function { var _this = this; return function () { _this.refresh(useCache); }; } // Bind event handlers to the chart controls, if any. private addControls():void { if (!this.controlIds) { return; } var controlHandler = this.getRefreshHandler(true); for (var i = 0; i < this.controlIds.length; i++) { var id = this.controlIds[i]; var split = id.indexOf(':'); var control:HTMLInputElement; if (split >= 0) { // Checkbox group. var count = parseInt(id.substring(split + 1)); var base = id.substring(0, split + 1); for (var j = 0; j < count; j++) { control = document.getElementById(base + j); control.disabled = !this.hasIPython; control.addEventListener('change', function() { controlHandler(); }); } continue; } // See if we have an associated control that needs dual binding. control = document.getElementById(id); if (!control) { // Kernel restart? return; } control.disabled = !this.hasIPython; var textControl = document.getElementById(id + '_value'); if (textControl) { textControl.disabled = !this.hasIPython; textControl.addEventListener('change', function () { if (control.value != textControl.value) { control.value = textControl.value; controlHandler(); } }); control.addEventListener('change', function () { textControl.value = control.value; controlHandler(); }); } else { control.addEventListener('change', function() { controlHandler(); }); } } } // Iterate through any widget controls and build up a JSON representation // of their values that can be passed to the Python kernel as part of the // magic to fetch data (also used as part of the cache key). protected getControlSettings():any { var env:any = {}; if (this.controlIds) { for (var i = 0; i < this.controlIds.length; i++) { var id = this.controlIds[i]; var parts = id.split('__'); var varName = parts[1]; var splitPoint = varName.indexOf(':'); if (splitPoint >= 0) { // this is a checkbox group var count = parseInt(varName.substring(splitPoint + 1)); varName = varName.substring(0, splitPoint); var cbBaseId = parts[0] + '__' + varName + ':'; var list:Array = []; env[varName] = list; for (var j = 0; j < count; j++) { var cb = document.getElementById(cbBaseId + j); if (!cb) { // Stale refresh; user re-executed cell. return undefined; } if (cb.checked) { list.push(cb.value); } } } else { var e = document.getElementById(id); if (!e) { // Stale refresh; user re-executed cell. 
return undefined; } if (e && e.type == 'checkbox') { // boolean env[varName] = e.checked; } else { // picker/slider/text env[varName] = e.value; } } } } return env; } // Get a string representation of the current environment - i.e. control settings and // refresh data. This is used as a cache key. private getEnvironment():string { var controls:any = this.getControlSettings(); if (controls == undefined) { // This means the user has re-executed the cell and our controls are gone. return undefined; } var env:any = {controls: controls}; Chart.extend(env, this.refreshData); return JSON.stringify(env); } protected refresh(useCache:boolean):void { // TODO(gram): remember last cache key and don't redraw chart if cache // key is the same unless this is an ML key and the number of data points has changed. this.removeStaticChart(); var env:string = this.getEnvironment(); if (env == undefined) { // This means the user has re-executed the cell and our controls are gone. console.log('No chart control environment; abandoning refresh'); return; } if (useCache && env in this.dataCache) { this.draw(this.dataCache[env], this.optionsCache[env]); return; } var code = '%_get_chart_data\n' + env; // TODO: hook into the notebook UI to enable/disable 'Running...' while we fetch more data. if (!this.cellElement) { var cell = this.getCell(); if (cell && cell.element && cell.element.length == 1) { this.cellElement = cell.element[0]; } } // Start the cell spinner in the notebook UI. if (this.cellElement) { this.cellElement.classList.remove('completed'); } var _this = this; datalab.session.execute(code, function (error:string, response:any) { _this.handleNewData(env, error, response); }); } private handleNewData(env: any, error:any, response: any) { var data = response.data; // Stop the cell spinner in the notebook UI. if (this.cellElement) { this.cellElement.classList.add('completed'); } if (data == undefined || data.cols == undefined) { error = 'No data'; } if (error) { this.driver.error(error); return; } this.refreshInterval = response.refresh_interval; if (this.refreshInterval == 0) { console.log('No more refreshes for ' + this.refreshData.name); } this.convertDates(data); var options = this.base_options; if (response.options) { // update any options. We need to make a copy so we don't break the base options. options = JSON.parse(JSON.stringify(options)); Chart.extend(options, response.options); } // Don't update or keep refreshing this if control settings have changed. var newEnv = this.getEnvironment(); if (env == newEnv) { console.log('Got refresh for ' + this.refreshData.name + ', ' + env); this.draw(data, options); } else { console.log('Stopping refresh for ' + env + ' as controls are now ' + newEnv) } } // Remove a static chart (PNG) from the notebook and the DOM. protected removeStaticChart():void { var cell = this.getCell(); if (cell) { var pngDivs = > cell.element[0].getElementsByClassName('output_png'); if (pngDivs) { for (var i = 0; i < pngDivs.length; i++) { pngDivs[i].innerHTML = ''; } } var cell_outputs = cell.output_area.outputs; var changed = true; while (changed) { changed = false; for (var outputIndex in cell_outputs) { var output = cell_outputs[outputIndex]; if (output.output_type == 'display_data' && output.metadata.source_id == this.dom.id) { cell_outputs.splice(outputIndex, 1); changed = true; break; } } } } else { // Not running under IPython; use a different approach and just clear the DOM. // Iterate through the IPython outputs... 
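// (Outside IPython the chart lives in static rendered HTML, so walk the .output_wrapper divs in the page directly.)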
var outputDivs = document.getElementsByClassName('output_wrapper'); if (outputDivs) { for (var i = 0; i < outputDivs.length; i++) { // ...and any chart outputs in each... var outputDiv = outputDivs[i]; var chartDivs = outputDiv.getElementsByClassName('bqgc'); if (chartDivs) { for (var j = 0; j < chartDivs.length; j++) { // ...until we find the chart div ID we want... if (chartDivs[j].id == this.dom.id) { // ...then get any PNG outputs in that same output group... var pngDivs = >outputDiv. getElementsByClassName('output_png'); if (pngDivs) { for (var k = 0; k < pngDivs.length; k++) { // ... and clear their contents. pngDivs[k].innerHTML = ''; } } return; } } } } } } } // Add a static chart (PNG) to the notebook. The notebook will in turn add it to the DOM when // the notebook is opened. private addStaticChart():void { var _this = this; this.driver.getStaticImage(function (img:string) { _this.handleStaticChart(img); }); } private handleStaticChart(img: string) { if (img) { var cell = this.getCell(); if (cell) { var encoding = img.substr(img.indexOf(',') + 1); // strip leading base64 etc. var static_output = { metadata: { source_id: this.dom.id }, data: { 'image/png': encoding }, output_type: 'display_data' }; cell.output_area.outputs.push(static_output); } } } // Set up a refresh callback if we have a non-zero interval and the DOM element still exists // (i.e. output hasn't been cleared). private configureRefresh(refreshInterval:number):void { if (refreshInterval > 0 && document.getElementById(this.dom.id)) { window.setTimeout(this.getRefreshHandler(false), 1000 * refreshInterval); } } // Cache the current data and options and draw the chart. public draw(data:any, options:any):void { var env:string = this.getEnvironment(); this.dataCache[env] = data; this.optionsCache[env] = options; if ('cols' in data) { this.driver.draw(data, options); } this.configureRefresh(this.refreshInterval); } } //----------------------------------------------------------- // A special version of Chart for supporting paginated data. class PagedTable extends Chart { firstRow:number; pageSize:number; constructor(driver:ChartLibraryDriver, dom:HTMLElement, controlIds:Array, base_options:any, refreshData:any, refreshInterval:number, totalRows:number) { super(driver, dom, controlIds, base_options, refreshData, refreshInterval, totalRows); this.firstRow = 0; // Index of first row being displayed in page. this.pageSize = base_options.pageSize || 25; if (this.base_options.showRowNumber == undefined) { this.base_options.showRowNumber = true; } this.base_options.sort = 'disable'; var __this = this; this.driver.addPageChangedHandler(function (page:number) { __this.handlePageEvent(page); }); } // Get control settings for cache key. For paged table we add the first row offset of the table. protected getControlSettings():any { var env = super.getControlSettings(); if (env) { env.first = this.firstRow; } return env; } public draw(data:any, options:any):void { var count = this.pageSize; options.firstRowNumber = this.firstRow + 1; options.page = 'event'; if (this.totalRows < 0) { // We don't know where the end is, so we should have 'next' button. options.pagingButtonsConfiguration = this.firstRow > 0 ? 'both' : 'next'; } else { count = this.totalRows - this.firstRow; if (count > this.pageSize) { count = this.pageSize; } if (this.firstRow + count < this.totalRows) { // We are not on last page, so we should have 'next' button. options.pagingButtonsConfiguration = this.firstRow > 0 ? 
'both' : 'next'; } else { // We are on last page if (this.firstRow == 0) { options.pagingButtonsConfiguration = 'none'; options.page = 'disable'; } else { options.pagingButtonsConfiguration = 'prev'; } } } super.draw(data, options); } // Handle page forward/back events. Page will only be 0 or 1. handlePageEvent(page:number):void { var offset = (page == 0) ? -1 : 1; this.firstRow += offset * this.pageSize; this.refreshData.first = this.firstRow; this.refreshData.count = this.pageSize; this.refresh(true); } } function convertListToDataTable(data:any):any { if (!data || !data.length) { return {cols: [], rows: []}; } var firstItem = data[0]; var names = Object.keys(firstItem); var columns = names.map(function (name) { return {id: name, label: name, type: typeof firstItem[name]} }); var rows = data.map(function (item:any) { var cells = names.map(function (name) { return {v: item[name]}; }); return {c: cells}; }); return {cols: columns, rows: rows}; } // The main render method, called from render() wrapper below. dom is the DOM element // for the chart, model is a set of parameters from Python, and options is a JSON // set of options provided by the user in the cell magic body, which takes precedence over // model. An initial set of data can be passed in as a final optional parameter. function _render(driver:ChartLibraryDriver, dom:HTMLElement, chartStyle:string, controlIds:Array, data:any, options:any, refreshData:any, refreshInterval:number, totalRows:number):void { require(["base/js/namespace"], function(Jupyter: any) { var url = "datalab/"; require(driver.requires(url, chartStyle), function (/* ... */) { // chart module should be last dependency in require() call... var chartModule = arguments[arguments.length - 1]; // See if it needs to be a member. driver.init(chartModule); options = options || {}; var chart:Chart; if (chartStyle == 'paged_table') { chart = new PagedTable(driver, dom, controlIds, options, refreshData, refreshInterval, totalRows); } else { chart = new Chart(driver, dom, controlIds, options, refreshData, refreshInterval, totalRows); } chart.convertDates(data); chart.draw(data, options); // Do we need to do anything to prevent it getting GCed? }); }); } export function render(driverName:string, dom:HTMLElement, events:any, chartStyle:string, controlIds:Array, data:any, options:any, refreshData:any, refreshInterval:number, totalRows:number):void { // If this is HTML from nbconvert we can't support paging so add some text making this clear. if (chartStyle == 'paged_table' && document.hasOwnProperty('_in_nbconverted')) { chartStyle = 'table'; var p = document.createElement("div"); p.innerHTML = '
(Truncated to first page of results)'; dom.parentNode.insertBefore(p, dom.nextSibling); } // Allocate an appropriate driver. var driver:ChartLibraryDriver; if (driverName == 'plotly') { driver = new PlotlyDriver(dom, chartStyle); } else if (driverName == 'gcharts') { driver = new GChartsDriver(dom, chartStyle); } else { throw new Error('Unsupported chart driver ' + driverName); } // Get data in form needed for GCharts. // We shouldn't need this; should be handled by caller. if (!data.cols && !data.rows) { data = this.convertListToDataTable(data); } // If there is no IPython instance, assume that this is being executed in a sandboxed output // environment and render immediately. // If we have a datalab session, we can go ahead and draw the chart; if not, add code to do the // drawing to an event handler for when the kernel is ready. if (!this.hasIPython || IPython.notebook.kernel.is_connected()) { _render(driver, dom, chartStyle, controlIds, data, options, refreshData, refreshInterval, totalRows) } else { // If the kernel is not connected, wait for the event. events.on('kernel_ready.Kernel', function (e:any) { _render(driver, dom, chartStyle, controlIds, data, options, refreshData, refreshInterval, totalRows) }); } } } export = Charting; ================================================ FILE: google/datalab/notebook/static/element.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module Element { // RequireJS plugin to resolve DOM elements. 'use strict'; var pendingCallbacks: any = null; function resolve(cbInfo: any): void { cbInfo.cb(document.getElementById(cbInfo.name)); } function domReadyCallback(): void { if (pendingCallbacks) { // Clear out pendingCallbacks, so any future requests are immediately resolved. 
var callbacks = pendingCallbacks; pendingCallbacks = null; callbacks.forEach(resolve); } } export function load(name: any, req: any, loadCallback: any, config: any): void { if (config.isBuild) { loadCallback(null); } else { var cbInfo = { name: name, cb: loadCallback }; if (document.readyState == 'loading') { if (!pendingCallbacks) { pendingCallbacks = []; document.addEventListener('DOMContentLoaded', domReadyCallback, false); } pendingCallbacks.push(cbInfo); } else { resolve(cbInfo); } } } } export = Element; ================================================ FILE: google/datalab/notebook/static/extern/d3.parcoords.css ================================================ .parcoords > svg, .parcoords > canvas { /*font: 14px sans-serif;*/ position: absolute; } .parcoords > canvas { pointer-events: none; } .parcoords rect.background { fill: transparent; } .parcoords rect.background:hover { fill: rgba(120,120,120,0.2); } .parcoords .resize rect { fill: rgba(0,0,0,0.1); } .parcoords rect.extent { fill: rgba(255,255,255,0.25); stroke: rgba(0,0,0,0.6); } .parcoords .axis line, .parcoords .axis path { fill: none; stroke: #222; shape-rendering: crispEdges; } .parcoords canvas { opacity: 1; -moz-transition: opacity 0.3s; -webkit-transition: opacity 0.3s; -o-transition: opacity 0.3s; } .parcoords canvas.faded { opacity: 0.25; } .parcoords_grid { text-align: center; } .parcoords_grid .row, .header { clear: left; font-size: 16px; line-height: 18px; height: 18px; } .parcoords_grid .row:nth-child(odd) { background: rgba(0,0,0,0.05); } .parcoords_grid .row:hover { background: green; } .parcoords_grid .header { font-weight: bold; } .parcoords_grid .cell { float: left; overflow: hidden; white-space: nowrap; width: 120px; height: 18px; } .parcoords_grid .col-0 { width: 110px; } ================================================ FILE: google/datalab/notebook/static/extern/d3.parcoords.js ================================================ d3.parcoords = function(config) { var __ = { data: [], highlighted: [], dimensions: [], dimensionTitles: {}, dimensionTitleRotation: 0, types: {}, brushed: false, mode: "default", rate: 20, width: 600, height: 300, margin: { top: 30, right: 0, bottom: 12, left: 0 }, color: "#069", composite: "source-over", alpha: 0.7, bundlingStrength: 0.5, bundleDimension: null, smoothness: 0.25, showControlPoints: false, hideAxis : [] }; extend(__, config); var pc = function(selection) { selection = pc.selection = d3.select(selection); __.width = selection[0][0].clientWidth; __.height = selection[0][0].clientHeight; // canvas data layers ["shadows", "marks", "foreground", "highlight"].forEach(function(layer) { canvas[layer] = selection .append("canvas") .attr("class", layer)[0][0]; ctx[layer] = canvas[layer].getContext("2d"); }); // svg tick and brush layers pc.svg = selection .append("svg") .attr("width", __.width) .attr("height", __.height) .append("svg:g") .attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); return pc; }; var events = d3.dispatch.apply(this,["render", "resize", "highlight", "brush", "brushend", "axesreorder"].concat(d3.keys(__))), w = function() { return __.width - __.margin.right - __.margin.left; }, h = function() { return __.height - __.margin.top - __.margin.bottom; }, flags = { brushable: false, reorderable: false, axes: false, interactive: false, shadows: false, debug: false }, xscale = d3.scale.ordinal(), yscale = {}, dragging = {}, line = d3.svg.line(), axis = d3.svg.axis().orient("left").ticks(5), g, // groups for axes, brushes ctx = {}, canvas 
= {}, clusterCentroids = []; // side effects for setters var side_effects = d3.dispatch.apply(this,d3.keys(__)) .on("composite", function(d) { ctx.foreground.globalCompositeOperation = d.value; }) .on("alpha", function(d) { ctx.foreground.globalAlpha = d.value; }) .on("width", function(d) { pc.resize(); }) .on("height", function(d) { pc.resize(); }) .on("margin", function(d) { pc.resize(); }) .on("rate", function(d) { rqueue.rate(d.value); }) .on("data", function(d) { if (flags.shadows){paths(__.data, ctx.shadows);} }) .on("dimensions", function(d) { xscale.domain(__.dimensions); if (flags.interactive){pc.render().updateAxes();} }) .on("bundleDimension", function(d) { if (!__.dimensions.length) pc.detectDimensions(); if (!(__.dimensions[0] in yscale)) pc.autoscale(); if (typeof d.value === "number") { if (d.value < __.dimensions.length) { __.bundleDimension = __.dimensions[d.value]; } else if (d.value < __.hideAxis.length) { __.bundleDimension = __.hideAxis[d.value]; } } else { __.bundleDimension = d.value; } __.clusterCentroids = compute_cluster_centroids(__.bundleDimension); }) .on("hideAxis", function(d) { if (!__.dimensions.length) pc.detectDimensions(); pc.dimensions(without(__.dimensions, d.value)); }); // expose the state of the chart pc.state = __; pc.flags = flags; // create getter/setters getset(pc, __, events); // expose events d3.rebind(pc, events, "on"); // tick formatting d3.rebind(pc, axis, "ticks", "orient", "tickValues", "tickSubdivide", "tickSize", "tickPadding", "tickFormat"); // getter/setter with event firing function getset(obj,state,events) { d3.keys(state).forEach(function(key) { obj[key] = function(x) { if (!arguments.length) { return state[key]; } var old = state[key]; state[key] = x; side_effects[key].call(pc,{"value": x, "previous": old}); events[key].call(pc,{"value": x, "previous": old}); return obj; }; }); }; function extend(target, source) { for (key in source) { target[key] = source[key]; } return target; }; function without(arr, item) { return arr.filter(function(elem) { return item.indexOf(elem) === -1; }) }; pc.autoscale = function() { // yscale var defaultScales = { "date": function(k) { return d3.time.scale() .domain(d3.extent(__.data, function(d) { return d[k] ? d[k].getTime() : null; })) .range([h()+1, 1]); }, "number": function(k) { return d3.scale.linear() .domain(d3.extent(__.data, function(d) { return +d[k]; })) .range([h()+1, 1]); }, "string": function(k) { var counts = {}, domain = []; // Let's get the count for each value so that we can sort the domain based // on the number of items for each value. 
__.data.map(function(p) { if (counts[p[k]] === undefined) { counts[p[k]] = 1; } else { counts[p[k]] = counts[p[k]] + 1; } }); domain = Object.getOwnPropertyNames(counts).sort(function(a, b) { return counts[a] - counts[b]; }); return d3.scale.ordinal() .domain(domain) .rangePoints([h()+1, 1]); } }; __.dimensions.forEach(function(k) { yscale[k] = defaultScales[__.types[k]](k); }); __.hideAxis.forEach(function(k) { yscale[k] = defaultScales[__.types[k]](k); }); // hack to remove ordinal dimensions with many values pc.dimensions(pc.dimensions().filter(function(p,i) { var uniques = yscale[p].domain().length; if (__.types[p] == "string" && (uniques > 60 || uniques < 2)) { return false; } return true; })); // xscale xscale.rangePoints([0, w()], 1); // canvas sizes pc.selection.selectAll("canvas") .style("margin-top", __.margin.top + "px") .style("margin-left", __.margin.left + "px") .attr("width", w()+2) .attr("height", h()+2); // default styles, needs to be set when canvas width changes ctx.foreground.strokeStyle = __.color; ctx.foreground.lineWidth = 1.4; ctx.foreground.globalCompositeOperation = __.composite; ctx.foreground.globalAlpha = __.alpha; ctx.highlight.lineWidth = 3; ctx.shadows.strokeStyle = "#dadada"; return this; }; pc.scale = function(d, domain) { yscale[d].domain(domain); return this; }; pc.flip = function(d) { //yscale[d].domain().reverse(); // does not work yscale[d].domain(yscale[d].domain().reverse()); // works return this; }; pc.commonScale = function(global, type) { var t = type || "number"; if (typeof global === 'undefined') { global = true; } // scales of the same type var scales = __.dimensions.concat(__.hideAxis).filter(function(p) { return __.types[p] == t; }); if (global) { var extent = d3.extent(scales.map(function(p,i) { return yscale[p].domain(); }).reduce(function(a,b) { return a.concat(b); })); scales.forEach(function(d) { yscale[d].domain(extent); }); } else { scales.forEach(function(k) { yscale[k].domain(d3.extent(__.data, function(d) { return +d[k]; })); }); } // update centroids if (__.bundleDimension !== null) { pc.bundleDimension(__.bundleDimension); } return this; };pc.detectDimensions = function() { pc.types(pc.detectDimensionTypes(__.data)); pc.dimensions(d3.keys(pc.types())); return this; }; // a better "typeof" from this post: http://stackoverflow.com/questions/7390426/better-way-to-get-type-of-a-javascript-variable pc.toType = function(v) { return ({}).toString.call(v).match(/\s([a-zA-Z]+)/)[1].toLowerCase(); }; // try to coerce to number before returning type pc.toTypeCoerceNumbers = function(v) { if ((parseFloat(v) == v) && (v != null)) { return "number"; } return pc.toType(v); }; // attempt to determine types of each dimension based on first row of data pc.detectDimensionTypes = function(data) { var types = {}; d3.keys(data[0]) .forEach(function(col) { types[col] = pc.toTypeCoerceNumbers(data[0][col]); }); return types; }; pc.render = function() { // try to autodetect dimensions and create scales if (!__.dimensions.length) pc.detectDimensions(); if (!(__.dimensions[0] in yscale)) pc.autoscale(); pc.render[__.mode](); events.render.call(this); return this; }; pc.render['default'] = function() { pc.clear('foreground'); if (__.brushed) { __.brushed.forEach(path_foreground); __.highlighted.forEach(path_highlight); } else { __.data.forEach(path_foreground); __.highlighted.forEach(path_highlight); } }; var rqueue = d3.renderQueue(path_foreground) .rate(50) .clear(function() { pc.clear('foreground'); pc.clear('highlight'); }); pc.render.queue = function() 
{ if (__.brushed) { rqueue(__.brushed); __.highlighted.forEach(path_highlight); } else { rqueue(__.data); __.highlighted.forEach(path_highlight); } }; function compute_cluster_centroids(d) { var clusterCentroids = d3.map(); var clusterCounts = d3.map(); // determine clusterCounts __.data.forEach(function(row) { var scaled = yscale[d](row[d]); if (!clusterCounts.has(scaled)) { clusterCounts.set(scaled, 0); } var count = clusterCounts.get(scaled); clusterCounts.set(scaled, count + 1); }); __.data.forEach(function(row) { __.dimensions.map(function(p, i) { var scaled = yscale[d](row[d]); if (!clusterCentroids.has(scaled)) { var map = d3.map(); clusterCentroids.set(scaled, map); } if (!clusterCentroids.get(scaled).has(p)) { clusterCentroids.get(scaled).set(p, 0); } var value = clusterCentroids.get(scaled).get(p); value += yscale[p](row[p]) / clusterCounts.get(scaled); clusterCentroids.get(scaled).set(p, value); }); }); return clusterCentroids; } function compute_centroids(row) { var centroids = []; var p = __.dimensions; var cols = p.length; var a = 0.5; // center between axes for (var i = 0; i < cols; ++i) { // centroids on 'real' axes var x = position(p[i]); var y = yscale[p[i]](row[p[i]]); centroids.push([x, y]); //centroids.push($V([x, y])); // centroids on 'virtual' axes if (i < cols - 1) { var cx = x + a * (position(p[i+1]) - x); var cy = y + a * (yscale[p[i+1]](row[p[i+1]]) - y); if (__.bundleDimension !== null) { var leftCentroid = __.clusterCentroids.get(yscale[__.bundleDimension](row[__.bundleDimension])).get(p[i]); var rightCentroid = __.clusterCentroids.get(yscale[__.bundleDimension](row[__.bundleDimension])).get(p[i+1]); var centroid = 0.5 * (leftCentroid + rightCentroid); cy = centroid + (1 - __.bundlingStrength) * (cy - centroid); } centroids.push([cx, cy]); //centroids.push($V([cx, cy])); } } return centroids; } pc.compute_centroids = compute_centroids; function compute_control_points(centroids) { var cols = centroids.length; var a = __.smoothness; var cps = []; cps.push(centroids[0]); cps.push($V([centroids[0].e(1) + a*2*(centroids[1].e(1)-centroids[0].e(1)), centroids[0].e(2)])); for (var col = 1; col < cols - 1; ++col) { var mid = centroids[col]; var left = centroids[col - 1]; var right = centroids[col + 1]; var diff = left.subtract(right); cps.push(mid.add(diff.x(a))); cps.push(mid); cps.push(mid.subtract(diff.x(a))); } cps.push($V([centroids[cols-1].e(1) + a*2*(centroids[cols-2].e(1)-centroids[cols-1].e(1)), centroids[cols-1].e(2)])); cps.push(centroids[cols - 1]); return cps; };pc.shadows = function() { flags.shadows = true; if (__.data.length > 0) { paths(__.data, ctx.shadows); } return this; }; // draw little dots on the axis line where data intersects pc.axisDots = function() { var ctx = pc.ctx.marks; ctx.globalAlpha = d3.min([ 1 / Math.pow(data.length, 1 / 2), 1 ]); __.data.forEach(function(d) { __.dimensions.map(function(p, i) { ctx.fillRect(position(p) - 0.75, yscale[p](d[p]) - 0.75, 1.5, 1.5); }); }); return this; }; // draw single cubic bezier curve function single_curve(d, ctx) { var centroids = compute_centroids(d); var cps = compute_control_points(centroids); ctx.moveTo(cps[0].e(1), cps[0].e(2)); for (var i = 1; i < cps.length; i += 3) { if (__.showControlPoints) { for (var j = 0; j < 3; j++) { ctx.fillRect(cps[i+j].e(1), cps[i+j].e(2), 2, 2); } } ctx.bezierCurveTo(cps[i].e(1), cps[i].e(2), cps[i+1].e(1), cps[i+1].e(2), cps[i+2].e(1), cps[i+2].e(2)); } }; // draw single polyline function color_path(d, i, ctx) { ctx.strokeStyle = d3.functor(__.color)(d, i); 
ctx.beginPath(); if (__.bundleDimension === null || (__.bundlingStrength === 0 && __.smoothness == 0)) { single_path(d, ctx); } else { single_curve(d, ctx); } ctx.stroke(); }; // draw many polylines of the same color function paths(data, ctx) { ctx.clearRect(-1, -1, w() + 2, h() + 2); ctx.beginPath(); data.forEach(function(d) { if (__.bundleDimension === null || (__.bundlingStrength === 0 && __.smoothness == 0)) { single_path(d, ctx); } else { single_curve(d, ctx); } }); ctx.stroke(); }; function single_path(d, ctx) { __.dimensions.map(function(p, i) { if (i == 0) { ctx.moveTo(position(p), yscale[p](d[p])); } else { ctx.lineTo(position(p), yscale[p](d[p])); } }); } function path_foreground(d, i) { return color_path(d, i, ctx.foreground); }; function path_highlight(d, i) { return color_path(d, i, ctx.highlight); }; pc.clear = function(layer) { ctx[layer].clearRect(0,0,w()+2,h()+2); return this; }; function flipAxisAndUpdatePCP(dimension, i) { var g = pc.svg.selectAll(".dimension"); pc.flip(dimension); d3.select(g[0][i]) .transition() .duration(1100) .call(axis.scale(yscale[dimension])); pc.render(); if (flags.shadows) paths(__.data, ctx.shadows); } function rotateLabels() { var delta = d3.event.deltaY; delta = delta < 0 ? -5 : delta; delta = delta > 0 ? 5 : delta; __.dimensionTitleRotation += delta; pc.svg.selectAll("text.label") .attr("transform", "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")"); d3.event.preventDefault(); } pc.createAxes = function() { if (g) pc.removeAxes(); // Add a group element for each dimension. g = pc.svg.selectAll(".dimension") .data(__.dimensions, function(d) { return d; }) .enter().append("svg:g") .attr("class", "dimension") .attr("transform", function(d) { return "translate(" + xscale(d) + ")"; }); // Add an axis and title. g.append("svg:g") .attr("class", "axis") .attr("transform", "translate(0,0)") .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }) .append("svg:text") .attr({ "text-anchor": "middle", "y": 0, "transform": "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")", "x": 0, "class": "label" }) .text(function(d) { return d in __.dimensionTitles ? 
__.dimensionTitles[d] : d; // dimension display names }) .on("dblclick", flipAxisAndUpdatePCP) .on("wheel", rotateLabels); flags.axes= true; return this; }; pc.removeAxes = function() { g.remove(); return this; }; pc.updateAxes = function() { var g_data = pc.svg.selectAll(".dimension").data(__.dimensions); // Enter g_data.enter().append("svg:g") .attr("class", "dimension") .attr("transform", function(p) { return "translate(" + position(p) + ")"; }) .style("opacity", 0) .append("svg:g") .attr("class", "axis") .attr("transform", "translate(0,0)") .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }) .append("svg:text") .attr({ "text-anchor": "middle", "y": 0, "transform": "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")", "x": 0, "class": "label" }) .text(String) .on("dblclick", flipAxisAndUpdatePCP) .on("wheel", rotateLabels); // Update g_data.attr("opacity", 0); g_data.select(".axis") .transition() .duration(1100) .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }); g_data.select(".label") .transition() .duration(1100) .text(String) .attr("transform", "translate(0,-5) rotate(" + __.dimensionTitleRotation + ")"); // Exit g_data.exit().remove(); g = pc.svg.selectAll(".dimension"); g.transition().duration(1100) .attr("transform", function(p) { return "translate(" + position(p) + ")"; }) .style("opacity", 1); pc.svg.selectAll(".axis") .transition() .duration(1100) .each(function(d) { d3.select(this).call(axis.scale(yscale[d])); }); if (flags.shadows) paths(__.data, ctx.shadows); if (flags.brushable) pc.brushable(); if (flags.reorderable) pc.reorderable(); if (pc.brushMode() !== "None") { var mode = pc.brushMode(); pc.brushMode("None"); pc.brushMode(mode); } return this; }; // Jason Davies, http://bl.ocks.org/1341281 pc.reorderable = function() { if (!g) pc.createAxes(); // Keep track of the order of the axes to verify if the order has actually // changed after a drag ends. Changed order might have consequence (e.g. // strums that need to be reset). var dimsAtDragstart; g.style("cursor", "move") .call(d3.behavior.drag() .on("dragstart", function(d) { dragging[d] = this.__origin__ = xscale(d); dimsAtDragstart = __.dimensions.slice(); }) .on("drag", function(d) { dragging[d] = Math.min(w(), Math.max(0, this.__origin__ += d3.event.dx)); __.dimensions.sort(function(a, b) { return position(a) - position(b); }); xscale.domain(__.dimensions); pc.render(); g.attr("transform", function(d) { return "translate(" + position(d) + ")"; }); }) .on("dragend", function(d, i) { // Let's see if the order has changed and send out an event if so. var j = __.dimensions.indexOf(d), parent = this.parentElement; if (i !== j) { events.axesreorder.call(pc, __.dimensions); // We now also want to reorder the actual dom elements that represent // the axes. That is, the g.dimension elements. If we don't do this, // we get a weird and confusing transition when updateAxes is called. // This is due to the fact that, initially the nth g.dimension element // represents the nth axis. However, after a manual reordering, // without reordering the dom elements, the nth dom elements no longer // necessarily represents the nth axis. 
// // i is the original index of the dom element // j is the new index of the dom element parent.insertBefore(this, parent.children[j + 1]) } delete this.__origin__; delete dragging[d]; d3.select(this).transition().attr("transform", "translate(" + xscale(d) + ")"); pc.render(); if (flags.shadows) paths(__.data, ctx.shadows); })); flags.reorderable = true; return this; }; // pairs of adjacent dimensions pc.adjacent_pairs = function(arr) { var ret = []; for (var i = 0; i < arr.length-1; i++) { ret.push([arr[i],arr[i+1]]); }; return ret; }; var brush = { modes: { "None": { install: function(pc) {}, // Nothing to be done. uninstall: function(pc) {}, // Nothing to be done. selected: function() { return []; } // Nothing to return } }, mode: "None", predicate: "AND", currentMode: function() { return this.modes[this.mode]; } }; // This function can be used for 'live' updates of brushes. That is, during the // specification of a brush, this method can be called to update the view. // // @param newSelection - The new set of data items that is currently contained // by the brushes function brushUpdated(newSelection) { __.brushed = newSelection; events.brush.call(pc,__.brushed); pc.render(); } function brushPredicate(predicate) { if (!arguments.length) { return brush.predicate; } predicate = String(predicate).toUpperCase(); if (predicate !== "AND" && predicate !== "OR") { throw "Invalid predicate " + predicate; } brush.predicate = predicate; __.brushed = brush.currentMode().selected(); pc.render(); return pc; } pc.brushModes = function() { return Object.getOwnPropertyNames(brush.modes); }; pc.brushMode = function(mode) { if (arguments.length === 0) { return brush.mode; } if (pc.brushModes().indexOf(mode) === -1) { throw "pc.brushmode: Unsupported brush mode: " + mode; } // Make sure that we don't trigger unnecessary events by checking if the mode // actually changes. if (mode !== brush.mode) { // When changing brush modes, the first thing we need to do is clearing any // brushes from the current mode, if any. if (brush.mode !== "None") { pc.brushReset(); } // Next, we need to 'uninstall' the current brushMode. brush.modes[brush.mode].uninstall(pc); // Finally, we can install the requested one. brush.mode = mode; brush.modes[brush.mode].install(); if (mode === "None") { delete pc.brushPredicate; } else { pc.brushPredicate = brushPredicate; } } return pc; }; // brush mode: 1D-Axes (function() { var brushes = {}; function is_brushed(p) { return !brushes[p].empty(); } // data within extents function selected() { var actives = __.dimensions.filter(is_brushed), extents = actives.map(function(p) { return brushes[p].extent(); }); // We don't want to return the full data set when there are no axes brushed. // Actually, when there are no axes brushed, by definition, no items are // selected. So, let's avoid the filtering and just return false. //if (actives.length === 0) return false; // Resolves broken examples for now. 
They expect to get the full dataset back from empty brushes if (actives.length === 0) return __.data; // test if within range var within = { "date": function(d,p,dimension) { return extents[dimension][0] <= d[p] && d[p] <= extents[dimension][1] }, "number": function(d,p,dimension) { return extents[dimension][0] <= d[p] && d[p] <= extents[dimension][1] }, "string": function(d,p,dimension) { return extents[dimension][0] <= yscale[p](d[p]) && yscale[p](d[p]) <= extents[dimension][1] } }; return __.data .filter(function(d) { switch(brush.predicate) { case "AND": return actives.every(function(p, dimension) { return within[__.types[p]](d,p,dimension); }); case "OR": return actives.some(function(p, dimension) { return within[__.types[p]](d,p,dimension); }); default: throw "Unknown brush predicate " + __.brushPredicate; } }); }; function brushExtents() { var extents = {}; __.dimensions.forEach(function(d) { var brush = brushes[d]; if (!brush.empty()) { var extent = brush.extent(); extent.sort(d3.ascending); extents[d] = extent; } }); return extents; } function brushFor(axis) { var brush = d3.svg.brush(); brush .y(yscale[axis]) .on("brushstart", function() { d3.event.sourceEvent.stopPropagation() }) .on("brush", function() { brushUpdated(selected()); }) .on("brushend", function() { events.brushend.call(pc, __.brushed); }); brushes[axis] = brush; return brush; } function brushReset(dimension) { __.brushed = false; if (g) { g.selectAll('.brush') .each(function(d) { d3.select(this).call( brushes[d].clear() ); }); pc.render(); } return this; }; function install() { if (!g) pc.createAxes(); // Add and store a brush for each axis. g.append("svg:g") .attr("class", "brush") .each(function(d) { d3.select(this).call(brushFor(d)); }) .selectAll("rect") .style("visibility", null) .attr("x", -15) .attr("width", 30); pc.brushExtents = brushExtents; pc.brushReset = brushReset; return pc; } brush.modes["1D-axes"] = { install: install, uninstall: function() { g.selectAll(".brush").remove(); brushes = {}; delete pc.brushExtents; delete pc.brushReset; }, selected: selected } })(); // brush mode: 2D-strums // bl.ocks.org/syntagmatic/5441022 (function() { var strums = {}, strumRect; function drawStrum(strum, activePoint) { var svg = pc.selection.select("svg").select("g#strums"), id = strum.dims.i, points = [strum.p1, strum.p2], line = svg.selectAll("line#strum-" + id).data([strum]), circles = svg.selectAll("circle#strum-" + id).data(points), drag = d3.behavior.drag(); line.enter() .append("line") .attr("id", "strum-" + id) .attr("class", "strum"); line .attr("x1", function(d) { return d.p1[0]; }) .attr("y1", function(d) { return d.p1[1]; }) .attr("x2", function(d) { return d.p2[0]; }) .attr("y2", function(d) { return d.p2[1]; }) .attr("stroke", "black") .attr("stroke-width", 2); drag .on("drag", function(d, i) { var ev = d3.event; i = i + 1; strum["p" + i][0] = Math.min(Math.max(strum.minX + 1, ev.x), strum.maxX); strum["p" + i][1] = Math.min(Math.max(strum.minY, ev.y), strum.maxY); drawStrum(strum, i - 1); }) .on("dragend", onDragEnd()); circles.enter() .append("circle") .attr("id", "strum-" + id) .attr("class", "strum"); circles .attr("cx", function(d) { return d[0]; }) .attr("cy", function(d) { return d[1]; }) .attr("r", 5) .style("opacity", function(d, i) { return (activePoint !== undefined && i === activePoint) ? 
0.8 : 0; }) .on("mouseover", function() { d3.select(this).style("opacity", 0.8); }) .on("mouseout", function() { d3.select(this).style("opacity", 0); }) .call(drag); } function dimensionsForPoint(p) { var dims = { i: -1, left: undefined, right: undefined }; __.dimensions.some(function(dim, i) { if (xscale(dim) < p[0]) { var next = __.dimensions[i + 1]; dims.i = i; dims.left = dim; dims.right = next; return false; } return true; }); if (dims.left === undefined) { // Event on the left side of the first axis. dims.i = 0; dims.left = __.dimensions[0]; dims.right = __.dimensions[1]; } else if (dims.right === undefined) { // Event on the right side of the last axis dims.i = __.dimensions.length - 1; dims.right = dims.left; dims.left = __.dimensions[__.dimensions.length - 2]; } return dims; } function onDragStart() { // First we need to determine between which two axes the sturm was started. // This will determine the freedom of movement, because a strum can // logically only happen between two axes, so no movement outside these axes // should be allowed. return function() { var p = d3.mouse(strumRect[0][0]), dims = dimensionsForPoint(p), strum = { p1: p, dims: dims, minX: xscale(dims.left), maxX: xscale(dims.right), minY: 0, maxY: h() }; strums[dims.i] = strum; strums.active = dims.i; // Make sure that the point is within the bounds strum.p1[0] = Math.min(Math.max(strum.minX, p[0]), strum.maxX); strum.p1[1] = p[1] - __.margin.top; strum.p2 = strum.p1.slice(); }; } function onDrag() { return function() { var ev = d3.event, strum = strums[strums.active]; // Make sure that the point is within the bounds strum.p2[0] = Math.min(Math.max(strum.minX + 1, ev.x), strum.maxX); strum.p2[1] = Math.min(Math.max(strum.minY, ev.y - __.margin.top), strum.maxY); drawStrum(strum, 1); }; } function containmentTest(strum, width) { var p1 = [strum.p1[0] - strum.minX, strum.p1[1] - strum.minX], p2 = [strum.p2[0] - strum.minX, strum.p2[1] - strum.minX], m1 = 1 - width / p1[0], b1 = p1[1] * (1 - m1), m2 = 1 - width / p2[0], b2 = p2[1] * (1 - m2); // test if point falls between lines return function(p) { var x = p[0], y = p[1], y1 = m1 * x + b1, y2 = m2 * x + b2; if (y > Math.min(y1, y2) && y < Math.max(y1, y2)) { return true; } return false; }; } function selected() { var ids = Object.getOwnPropertyNames(strums), brushed = __.data; // Get the ids of the currently active strums. ids = ids.filter(function(d) { return !isNaN(d); }); function crossesStrum(d, id) { var strum = strums[id], test = containmentTest(strum, strums.width(id)), d1 = strum.dims.left, d2 = strum.dims.right, y1 = yscale[d1], y2 = yscale[d2], point = [y1(d[d1]) - strum.minX, y2(d[d2]) - strum.minX]; return test(point); } if (ids.length === 0) { return brushed; } return brushed.filter(function(d) { switch(brush.predicate) { case "AND": return ids.every(function(id) { return crossesStrum(d, id); }); case "OR": return ids.some(function(id) { return crossesStrum(d, id); }); default: throw "Unknown brush predicate " + __.brushPredicate; } }); } function removeStrum() { var strum = strums[strums.active], svg = pc.selection.select("svg").select("g#strums"); delete strums[strums.active]; strums.active = undefined; svg.selectAll("line#strum-" + strum.dims.i).remove(); svg.selectAll("circle#strum-" + strum.dims.i).remove(); } function onDragEnd() { return function() { var brushed = __.data, strum = strums[strums.active]; // Okay, somewhat unexpected, but not totally unsurprising, a mousclick is // considered a drag without move. 
So we have to deal with that case if (strum && strum.p1[0] === strum.p2[0] && strum.p1[1] === strum.p2[1]) { removeStrum(strums); } brushed = selected(strums); strums.active = undefined; __.brushed = brushed; pc.render(); events.brushend.call(pc, __.brushed); }; } function brushReset(strums) { return function() { var ids = Object.getOwnPropertyNames(strums).filter(function(d) { return !isNaN(d); }); ids.forEach(function(d) { strums.active = d; removeStrum(strums); }); onDragEnd(strums)(); }; } function install() { var drag = d3.behavior.drag(); // Map of current strums. Strums are stored per segment of the PC. A segment, // being the area between two axes. The left most area is indexed at 0. strums.active = undefined; // Returns the width of the PC segment where currently a strum is being // placed. NOTE: even though they are evenly spaced in our current // implementation, we keep for when non-even spaced segments are supported as // well. strums.width = function(id) { var strum = strums[id]; if (strum === undefined) { return undefined; } return strum.maxX - strum.minX; }; pc.on("axesreorder.strums", function() { var ids = Object.getOwnPropertyNames(strums).filter(function(d) { return !isNaN(d); }); // Checks if the first dimension is directly left of the second dimension. function consecutive(first, second) { var length = __.dimensions.length; return __.dimensions.some(function(d, i) { return (d === first) ? i + i < length && __.dimensions[i + 1] === second : false; }); } if (ids.length > 0) { // We have some strums, which might need to be removed. ids.forEach(function(d) { var dims = strums[d].dims; strums.active = d; // If the two dimensions of the current strum are not next to each other // any more, than we'll need to remove the strum. Otherwise we keep it. if (!consecutive(dims.left, dims.right)) { removeStrum(strums); } }); onDragEnd(strums)(); } }); // Add a new svg group in which we draw the strums. pc.selection.select("svg").append("g") .attr("id", "strums") .attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); // Install the required brushReset function pc.brushReset = brushReset(strums); drag .on("dragstart", onDragStart(strums)) .on("drag", onDrag(strums)) .on("dragend", onDragEnd(strums)); // NOTE: The styling needs to be done here and not in the css. This is because // for 1D brushing, the canvas layers should not listen to // pointer-events. 
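// strumRect (created below) is an invisible rect spanning the plot area; it is the drag target used to draw strums.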
strumRect = pc.selection.select("svg").insert("rect", "g#strums") .attr("id", "strum-events") .attr("x", __.margin.left) .attr("y", __.margin.top) .attr("width", w()) .attr("height", h() + 2) .style("opacity", 0) .call(drag); } brush.modes["2D-strums"] = { install: install, uninstall: function() { pc.selection.select("svg").select("g#strums").remove(); pc.selection.select("svg").select("rect#strum-events").remove(); pc.on("axesreorder.strums", undefined); delete pc.brushReset; strumRect = undefined; }, selected: selected }; }()); pc.interactive = function() { flags.interactive = true; return this; }; // expose a few objects pc.xscale = xscale; pc.yscale = yscale; pc.ctx = ctx; pc.canvas = canvas; pc.g = function() { return g; }; // rescale for height, width and margins // TODO currently assumes chart is brushable, and destroys old brushes pc.resize = function() { // selection size pc.selection.select("svg") .attr("width", __.width) .attr("height", __.height) pc.svg.attr("transform", "translate(" + __.margin.left + "," + __.margin.top + ")"); // FIXME: the current brush state should pass through if (flags.brushable) pc.brushReset(); // scales pc.autoscale(); // axes, destroys old brushes. if (g) pc.createAxes(); if (flags.shadows) paths(__.data, ctx.shadows); if (flags.brushable) pc.brushable(); if (flags.reorderable) pc.reorderable(); events.resize.call(this, {width: __.width, height: __.height, margin: __.margin}); return this; }; // highlight an array of data pc.highlight = function(data) { if (arguments.length === 0) { return __.highlighted; } __.highlighted = data; pc.clear("highlight"); d3.select(canvas.foreground).classed("faded", true); data.forEach(path_highlight); events.highlight.call(this, data); return this; }; // clear highlighting pc.unhighlight = function() { __.highlighted = []; pc.clear("highlight"); d3.select(canvas.foreground).classed("faded", false); return this; }; // calculate 2d intersection of line a->b with line c->d // points are objects with x and y properties pc.intersection = function(a, b, c, d) { return { x: ((a.x * b.y - a.y * b.x) * (c.x - d.x) - (a.x - b.x) * (c.x * d.y - c.y * d.x)) / ((a.x - b.x) * (c.y - d.y) - (a.y - b.y) * (c.x - d.x)), y: ((a.x * b.y - a.y * b.x) * (c.y - d.y) - (a.y - b.y) * (c.x * d.y - c.y * d.x)) / ((a.x - b.x) * (c.y - d.y) - (a.y - b.y) * (c.x - d.x)) }; }; function position(d) { var v = dragging[d]; return v == null ? xscale(d) : v; } pc.version = "0.5.0"; // this descriptive text should live with other introspective methods pc.toString = function() { return "Parallel Coordinates: " + __.dimensions.length + " dimensions (" + d3.keys(__.data[0]).length + " total) , " + __.data.length + " rows"; }; return pc; }; d3.renderQueue = (function(func) { var _queue = [], // data to be rendered _rate = 10, // number of calls per frame _clear = function() {}, // clearing function _i = 0; // current iteration var rq = function(data) { if (data) rq.data(data); rq.invalidate(); _clear(); rq.render(); }; rq.render = function() { _i = 0; var valid = true; rq.invalidate = function() { valid = false; }; function doFrame() { if (!valid) return true; if (_i > _queue.length) return true; // Typical d3 behavior is to pass a data item *and* its index. As the // render queue splits the original data set, we'll have to be slightly // more carefull about passing the correct index with the data item. 
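// Each doFrame call renders at most _rate queued items; d3.timer keeps re-invoking it until it returns true
// (queue drained or invalidated).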
var end = Math.min(_i + _rate, _queue.length); for (var i = _i; i < end; i++) { func(_queue[i], i); } _i += _rate; } d3.timer(doFrame); }; rq.data = function(data) { rq.invalidate(); _queue = data.slice(0); return rq; }; rq.rate = function(value) { if (!arguments.length) return _rate; _rate = value; return rq; }; rq.remaining = function() { return _queue.length - _i; }; // clear the canvas rq.clear = function(func) { if (!arguments.length) { _clear(); return rq; } _clear = func; return rq; }; rq.invalidate = function() {}; return rq; }); d3.divgrid = function(config) { var columns = []; var dg = function(selection) { if (columns.length == 0) { columns = d3.keys(selection.data()[0][0]); columns = columns.filter( function(item) { return (item.substr(item.length - 5) != "(log)"); }); } // header selection.selectAll(".header") .data([true]) .enter().append("div") .attr("class", "header") var header = selection.select(".header") .selectAll(".cell") .data(columns); header.enter().append("div") .attr("class", function(d,i) { return "col-" + i; }) .classed("cell", true) selection.selectAll(".header .cell") .text(function(d) { return d; }); header.exit().remove(); // rows var rows = selection.selectAll(".row") .data(function(d) { return d; }) rows.enter().append("div") .attr("class", "row") rows.exit().remove(); var cells = selection.selectAll(".row").selectAll(".cell") .data(function(d) { return columns.map(function(col){return d[col];}) }) // cells cells.enter().append("div") .attr("class", function(d,i) { return "col-" + i; }) .classed("cell", true) cells.exit().remove(); selection.selectAll(".cell") .text(function(d) { return d; }); return dg; }; dg.columns = function(_) { if (!arguments.length) return columns; columns = _; return this; }; return dg; }; ================================================ FILE: google/datalab/notebook/static/extern/facets-jupyter.html ================================================ ================================================ FILE: google/datalab/notebook/static/extern/lantern-browser.html ================================================ ================================================ FILE: google/datalab/notebook/static/extern/parcoords-LICENSE.txt ================================================ Copyright (c) 2012, Kai Chang All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * The name Kai Chang may not be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: google/datalab/notebook/static/extern/sylvester-LICENSE.txt ================================================ (The MIT License) Copyright (c) 2007-2015 James Coglan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the 'Software'), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: google/datalab/notebook/static/extern/sylvester.js ================================================ // === Sylvester === // Vector and Matrix mathematics modules for JavaScript // Copyright (c) 2007 James Coglan // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the "Software"), // to deal in the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included // in all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS // OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. var Sylvester = { version: '0.1.3', precision: 1e-6 }; function Vector() {} Vector.prototype = { // Returns element i of the vector e: function(i) { return (i < 1 || i > this.elements.length) ? 
null : this.elements[i-1]; }, // Returns the number of elements the vector has dimensions: function() { return this.elements.length; }, // Returns the modulus ('length') of the vector modulus: function() { return Math.sqrt(this.dot(this)); }, // Returns true iff the vector is equal to the argument eql: function(vector) { var n = this.elements.length; var V = vector.elements || vector; if (n != V.length) { return false; } do { if (Math.abs(this.elements[n-1] - V[n-1]) > Sylvester.precision) { return false; } } while (--n); return true; }, // Returns a copy of the vector dup: function() { return Vector.create(this.elements); }, // Maps the vector to another vector according to the given function map: function(fn) { var elements = []; this.each(function(x, i) { elements.push(fn(x, i)); }); return Vector.create(elements); }, // Calls the iterator for each element of the vector in turn each: function(fn) { var n = this.elements.length, k = n, i; do { i = k - n; fn(this.elements[i], i+1); } while (--n); }, // Returns a new vector created by normalizing the receiver toUnitVector: function() { var r = this.modulus(); if (r === 0) { return this.dup(); } return this.map(function(x) { return x/r; }); }, // Returns the angle between the vector and the argument (also a vector) angleFrom: function(vector) { var V = vector.elements || vector; var n = this.elements.length, k = n, i; if (n != V.length) { return null; } var dot = 0, mod1 = 0, mod2 = 0; // Work things out in parallel to save time this.each(function(x, i) { dot += x * V[i-1]; mod1 += x * x; mod2 += V[i-1] * V[i-1]; }); mod1 = Math.sqrt(mod1); mod2 = Math.sqrt(mod2); if (mod1*mod2 === 0) { return null; } var theta = dot / (mod1*mod2); if (theta < -1) { theta = -1; } if (theta > 1) { theta = 1; } return Math.acos(theta); }, // Returns true iff the vector is parallel to the argument isParallelTo: function(vector) { var angle = this.angleFrom(vector); return (angle === null) ? null : (angle <= Sylvester.precision); }, // Returns true iff the vector is antiparallel to the argument isAntiparallelTo: function(vector) { var angle = this.angleFrom(vector); return (angle === null) ? null : (Math.abs(angle - Math.PI) <= Sylvester.precision); }, // Returns true iff the vector is perpendicular to the argument isPerpendicularTo: function(vector) { var dot = this.dot(vector); return (dot === null) ? 
null : (Math.abs(dot) <= Sylvester.precision); }, // Returns the result of adding the argument to the vector add: function(vector) { var V = vector.elements || vector; if (this.elements.length != V.length) { return null; } return this.map(function(x, i) { return x + V[i-1]; }); }, // Returns the result of subtracting the argument from the vector subtract: function(vector) { var V = vector.elements || vector; if (this.elements.length != V.length) { return null; } return this.map(function(x, i) { return x - V[i-1]; }); }, // Returns the result of multiplying the elements of the vector by the argument multiply: function(k) { return this.map(function(x) { return x*k; }); }, x: function(k) { return this.multiply(k); }, // Returns the scalar product of the vector with the argument // Both vectors must have equal dimensionality dot: function(vector) { var V = vector.elements || vector; var i, product = 0, n = this.elements.length; if (n != V.length) { return null; } do { product += this.elements[n-1] * V[n-1]; } while (--n); return product; }, // Returns the vector product of the vector with the argument // Both vectors must have dimensionality 3 cross: function(vector) { var B = vector.elements || vector; if (this.elements.length != 3 || B.length != 3) { return null; } var A = this.elements; return Vector.create([ (A[1] * B[2]) - (A[2] * B[1]), (A[2] * B[0]) - (A[0] * B[2]), (A[0] * B[1]) - (A[1] * B[0]) ]); }, // Returns the (absolute) largest element of the vector max: function() { var m = 0, n = this.elements.length, k = n, i; do { i = k - n; if (Math.abs(this.elements[i]) > Math.abs(m)) { m = this.elements[i]; } } while (--n); return m; }, // Returns the index of the first match found indexOf: function(x) { var index = null, n = this.elements.length, k = n, i; do { i = k - n; if (index === null && this.elements[i] == x) { index = i + 1; } } while (--n); return index; }, // Returns a diagonal matrix with the vector's elements as its diagonal elements toDiagonalMatrix: function() { return Matrix.Diagonal(this.elements); }, // Returns the result of rounding the elements of the vector round: function() { return this.map(function(x) { return Math.round(x); }); }, // Returns a copy of the vector with elements set to the given value if they // differ from it by less than Sylvester.precision snapTo: function(x) { return this.map(function(y) { return (Math.abs(y - x) <= Sylvester.precision) ? x : y; }); }, // Returns the vector's distance from the argument, when considered as a point in space distanceFrom: function(obj) { if (obj.anchor) { return obj.distanceFrom(this); } var V = obj.elements || obj; if (V.length != this.elements.length) { return null; } var sum = 0, part; this.each(function(x, i) { part = x - V[i-1]; sum += part * part; }); return Math.sqrt(sum); }, // Returns true if the vector is point on the given line liesOn: function(line) { return line.contains(this); }, // Return true iff the vector is a point in the given plane liesIn: function(plane) { return plane.contains(this); }, // Rotates the vector about the given object. The object should be a // point if the vector is 2D, and a line if it is 3D. Be careful with line directions! 
rotate: function(t, obj) { var V, R, x, y, z; switch (this.elements.length) { case 2: V = obj.elements || obj; if (V.length != 2) { return null; } R = Matrix.Rotation(t).elements; x = this.elements[0] - V[0]; y = this.elements[1] - V[1]; return Vector.create([ V[0] + R[0][0] * x + R[0][1] * y, V[1] + R[1][0] * x + R[1][1] * y ]); break; case 3: if (!obj.direction) { return null; } var C = obj.pointClosestTo(this).elements; R = Matrix.Rotation(t, obj.direction).elements; x = this.elements[0] - C[0]; y = this.elements[1] - C[1]; z = this.elements[2] - C[2]; return Vector.create([ C[0] + R[0][0] * x + R[0][1] * y + R[0][2] * z, C[1] + R[1][0] * x + R[1][1] * y + R[1][2] * z, C[2] + R[2][0] * x + R[2][1] * y + R[2][2] * z ]); break; default: return null; } }, // Returns the result of reflecting the point in the given point, line or plane reflectionIn: function(obj) { if (obj.anchor) { // obj is a plane or line var P = this.elements.slice(); var C = obj.pointClosestTo(P).elements; return Vector.create([C[0] + (C[0] - P[0]), C[1] + (C[1] - P[1]), C[2] + (C[2] - (P[2] || 0))]); } else { // obj is a point var Q = obj.elements || obj; if (this.elements.length != Q.length) { return null; } return this.map(function(x, i) { return Q[i-1] + (Q[i-1] - x); }); } }, // Utility to make sure vectors are 3D. If they are 2D, a zero z-component is added to3D: function() { var V = this.dup(); switch (V.elements.length) { case 3: break; case 2: V.elements.push(0); break; default: return null; } return V; }, // Returns a string representation of the vector inspect: function() { return '[' + this.elements.join(', ') + ']'; }, // Set vector's elements from an array setElements: function(els) { this.elements = (els.elements || els).slice(); return this; } }; // Constructor function Vector.create = function(elements) { var V = new Vector(); return V.setElements(elements); }; // i, j, k unit vectors Vector.i = Vector.create([1,0,0]); Vector.j = Vector.create([0,1,0]); Vector.k = Vector.create([0,0,1]); // Random vector of size n Vector.Random = function(n) { var elements = []; do { elements.push(Math.random()); } while (--n); return Vector.create(elements); }; // Vector filled with zeros Vector.Zero = function(n) { var elements = []; do { elements.push(0); } while (--n); return Vector.create(elements); }; function Matrix() {} Matrix.prototype = { // Returns element (i,j) of the matrix e: function(i,j) { if (i < 1 || i > this.elements.length || j < 1 || j > this.elements[0].length) { return null; } return this.elements[i-1][j-1]; }, // Returns row k of the matrix as a vector row: function(i) { if (i > this.elements.length) { return null; } return Vector.create(this.elements[i-1]); }, // Returns column k of the matrix as a vector col: function(j) { if (j > this.elements[0].length) { return null; } var col = [], n = this.elements.length, k = n, i; do { i = k - n; col.push(this.elements[i][j-1]); } while (--n); return Vector.create(col); }, // Returns the number of rows/columns the matrix has dimensions: function() { return {rows: this.elements.length, cols: this.elements[0].length}; }, // Returns the number of rows in the matrix rows: function() { return this.elements.length; }, // Returns the number of columns in the matrix cols: function() { return this.elements[0].length; }, // Returns true iff the matrix is equal to the argument. You can supply // a vector as the argument, in which case the receiver must be a // one-column matrix equal to the vector. 
eql: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (this.elements.length != M.length || this.elements[0].length != M[0].length) { return false; } var ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(this.elements[i][j] - M[i][j]) > Sylvester.precision) { return false; } } while (--nj); } while (--ni); return true; }, // Returns a copy of the matrix dup: function() { return Matrix.create(this.elements); }, // Maps the matrix to another matrix (of the same dimensions) according to the given function map: function(fn) { var els = [], ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; els[i] = []; do { j = kj - nj; els[i][j] = fn(this.elements[i][j], i + 1, j + 1); } while (--nj); } while (--ni); return Matrix.create(els); }, // Returns true iff the argument has the same dimensions as the matrix isSameSizeAs: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } return (this.elements.length == M.length && this.elements[0].length == M[0].length); }, // Returns the result of adding the argument to the matrix add: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.isSameSizeAs(M)) { return null; } return this.map(function(x, i, j) { return x + M[i-1][j-1]; }); }, // Returns the result of subtracting the argument from the matrix subtract: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.isSameSizeAs(M)) { return null; } return this.map(function(x, i, j) { return x - M[i-1][j-1]; }); }, // Returns true iff the matrix can multiply the argument from the left canMultiplyFromLeft: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } // this.columns should equal matrix.rows return (this.elements[0].length == M.length); }, // Returns the result of multiplying the matrix from the right by the argument. // If the argument is a scalar then just multiply all the elements. If the argument is // a vector, a vector is returned, which saves you having to remember calling // col(1) on the result. multiply: function(matrix) { if (!matrix.elements) { return this.map(function(x) { return x * matrix; }); } var returnVector = matrix.modulus ? true : false; var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } if (!this.canMultiplyFromLeft(M)) { return null; } var ni = this.elements.length, ki = ni, i, nj, kj = M[0].length, j; var cols = this.elements[0].length, elements = [], sum, nc, c; do { i = ki - ni; elements[i] = []; nj = kj; do { j = kj - nj; sum = 0; nc = cols; do { c = cols - nc; sum += this.elements[i][c] * M[c][j]; } while (--nc); elements[i][j] = sum; } while (--nj); } while (--ni); var M = Matrix.create(elements); return returnVector ? M.col(1) : M; }, x: function(matrix) { return this.multiply(matrix); }, // Returns a submatrix taken from the matrix // Argument order is: start row, start col, nrows, ncols // Element selection wraps if the required index is outside the matrix's bounds, so you could // use this to perform row/column cycling or copy-augmenting. 
minor: function(a, b, c, d) { var elements = [], ni = c, i, nj, j; var rows = this.elements.length, cols = this.elements[0].length; do { i = c - ni; elements[i] = []; nj = d; do { j = d - nj; elements[i][j] = this.elements[(a+i-1)%rows][(b+j-1)%cols]; } while (--nj); } while (--ni); return Matrix.create(elements); }, // Returns the transpose of the matrix transpose: function() { var rows = this.elements.length, cols = this.elements[0].length; var elements = [], ni = cols, i, nj, j; do { i = cols - ni; elements[i] = []; nj = rows; do { j = rows - nj; elements[i][j] = this.elements[j][i]; } while (--nj); } while (--ni); return Matrix.create(elements); }, // Returns true iff the matrix is square isSquare: function() { return (this.elements.length == this.elements[0].length); }, // Returns the (absolute) largest element of the matrix max: function() { var m = 0, ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(this.elements[i][j]) > Math.abs(m)) { m = this.elements[i][j]; } } while (--nj); } while (--ni); return m; }, // Returns the indeces of the first match found by reading row-by-row from left to right indexOf: function(x) { var index = null, ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (this.elements[i][j] == x) { return {i: i+1, j: j+1}; } } while (--nj); } while (--ni); return null; }, // If the matrix is square, returns the diagonal elements as a vector. // Otherwise, returns null. diagonal: function() { if (!this.isSquare) { return null; } var els = [], n = this.elements.length, k = n, i; do { i = k - n; els.push(this.elements[i][i]); } while (--n); return Vector.create(els); }, // Make the matrix upper (right) triangular by Gaussian elimination. // This method only adds multiples of rows to other rows. No rows are // scaled up or switched, and the determinant is preserved. toRightTriangular: function() { var M = this.dup(), els; var n = this.elements.length, k = n, i, np, kp = this.elements[0].length, p; do { i = k - n; if (M.elements[i][i] == 0) { for (j = i + 1; j < k; j++) { if (M.elements[j][i] != 0) { els = []; np = kp; do { p = kp - np; els.push(M.elements[i][p] + M.elements[j][p]); } while (--np); M.elements[i] = els; break; } } } if (M.elements[i][i] != 0) { for (j = i + 1; j < k; j++) { var multiplier = M.elements[j][i] / M.elements[i][i]; els = []; np = kp; do { p = kp - np; // Elements with column numbers up to an including the number // of the row that we're subtracting can safely be set straight to // zero, since that's the point of this routine and it avoids having // to loop over and correct rounding errors later els.push(p <= i ? 
0 : M.elements[j][p] - M.elements[i][p] * multiplier); } while (--np); M.elements[j] = els; } } } while (--n); return M; }, toUpperTriangular: function() { return this.toRightTriangular(); }, // Returns the determinant for square matrices determinant: function() { if (!this.isSquare()) { return null; } var M = this.toRightTriangular(); var det = M.elements[0][0], n = M.elements.length - 1, k = n, i; do { i = k - n + 1; det = det * M.elements[i][i]; } while (--n); return det; }, det: function() { return this.determinant(); }, // Returns true iff the matrix is singular isSingular: function() { return (this.isSquare() && this.determinant() === 0); }, // Returns the trace for square matrices trace: function() { if (!this.isSquare()) { return null; } var tr = this.elements[0][0], n = this.elements.length - 1, k = n, i; do { i = k - n + 1; tr += this.elements[i][i]; } while (--n); return tr; }, tr: function() { return this.trace(); }, // Returns the rank of the matrix rank: function() { var M = this.toRightTriangular(), rank = 0; var ni = this.elements.length, ki = ni, i, nj, kj = this.elements[0].length, j; do { i = ki - ni; nj = kj; do { j = kj - nj; if (Math.abs(M.elements[i][j]) > Sylvester.precision) { rank++; break; } } while (--nj); } while (--ni); return rank; }, rk: function() { return this.rank(); }, // Returns the result of attaching the given argument to the right-hand side of the matrix augment: function(matrix) { var M = matrix.elements || matrix; if (typeof(M[0][0]) == 'undefined') { M = Matrix.create(M).elements; } var T = this.dup(), cols = T.elements[0].length; var ni = T.elements.length, ki = ni, i, nj, kj = M[0].length, j; if (ni != M.length) { return null; } do { i = ki - ni; nj = kj; do { j = kj - nj; T.elements[i][cols + j] = M[i][j]; } while (--nj); } while (--ni); return T; }, // Returns the inverse (if one exists) using Gauss-Jordan inverse: function() { if (!this.isSquare() || this.isSingular()) { return null; } var ni = this.elements.length, ki = ni, i, j; var M = this.augment(Matrix.I(ni)).toRightTriangular(); var np, kp = M.elements[0].length, p, els, divisor; var inverse_elements = [], new_element; // Matrix is non-singular so there will be no zeros on the diagonal // Cycle through rows from last to first do { i = ni - 1; // First, normalise diagonal elements to 1 els = []; np = kp; inverse_elements[i] = []; divisor = M.elements[i][i]; do { p = kp - np; new_element = M.elements[i][p] / divisor; els.push(new_element); // Shuffle of the current row of the right hand side into the results // array as it will not be modified by later runs through this loop if (p >= ki) { inverse_elements[i].push(new_element); } } while (--np); M.elements[i] = els; // Then, subtract this row from those above it to // give the identity matrix on the left hand side for (j = 0; j < i; j++) { els = []; np = kp; do { p = kp - np; els.push(M.elements[j][p] - M.elements[i][p] * M.elements[j][i]); } while (--np); M.elements[j] = els; } } while (--ni); return Matrix.create(inverse_elements); }, inv: function() { return this.inverse(); }, // Returns the result of rounding all the elements round: function() { return this.map(function(x) { return Math.round(x); }); }, // Returns a copy of the matrix with elements set to the given value if they // differ from it by less than Sylvester.precision snapTo: function(x) { return this.map(function(p) { return (Math.abs(p - x) <= Sylvester.precision) ? 
x : p; }); }, // Returns a string representation of the matrix inspect: function() { var matrix_rows = []; var n = this.elements.length, k = n, i; do { i = k - n; matrix_rows.push(Vector.create(this.elements[i]).inspect()); } while (--n); return matrix_rows.join('\n'); }, // Set the matrix's elements from an array. If the argument passed // is a vector, the resulting matrix will be a single column. setElements: function(els) { var i, elements = els.elements || els; if (typeof(elements[0][0]) != 'undefined') { var ni = elements.length, ki = ni, nj, kj, j; this.elements = []; do { i = ki - ni; nj = elements[i].length; kj = nj; this.elements[i] = []; do { j = kj - nj; this.elements[i][j] = elements[i][j]; } while (--nj); } while(--ni); return this; } var n = elements.length, k = n; this.elements = []; do { i = k - n; this.elements.push([elements[i]]); } while (--n); return this; } }; // Constructor function Matrix.create = function(elements) { var M = new Matrix(); return M.setElements(elements); }; // Identity matrix of size n Matrix.I = function(n) { var els = [], k = n, i, nj, j; do { i = k - n; els[i] = []; nj = k; do { j = k - nj; els[i][j] = (i == j) ? 1 : 0; } while (--nj); } while (--n); return Matrix.create(els); }; // Diagonal matrix - all off-diagonal elements are zero Matrix.Diagonal = function(elements) { var n = elements.length, k = n, i; var M = Matrix.I(n); do { i = k - n; M.elements[i][i] = elements[i]; } while (--n); return M; }; // Rotation matrix about some axis. If no axis is // supplied, assume we're after a 2D transform Matrix.Rotation = function(theta, a) { if (!a) { return Matrix.create([ [Math.cos(theta), -Math.sin(theta)], [Math.sin(theta), Math.cos(theta)] ]); } var axis = a.dup(); if (axis.elements.length != 3) { return null; } var mod = axis.modulus(); var x = axis.elements[0]/mod, y = axis.elements[1]/mod, z = axis.elements[2]/mod; var s = Math.sin(theta), c = Math.cos(theta), t = 1 - c; // Formula derived here: http://www.gamedev.net/reference/articles/article1199.asp // That proof rotates the co-ordinate system so theta // becomes -theta and sin becomes -sin here. 
return Matrix.create([ [ t*x*x + c, t*x*y - s*z, t*x*z + s*y ], [ t*x*y + s*z, t*y*y + c, t*y*z - s*x ], [ t*x*z - s*y, t*y*z + s*x, t*z*z + c ] ]); }; // Special case rotations Matrix.RotationX = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ 1, 0, 0 ], [ 0, c, -s ], [ 0, s, c ] ]); }; Matrix.RotationY = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ c, 0, s ], [ 0, 1, 0 ], [ -s, 0, c ] ]); }; Matrix.RotationZ = function(t) { var c = Math.cos(t), s = Math.sin(t); return Matrix.create([ [ c, -s, 0 ], [ s, c, 0 ], [ 0, 0, 1 ] ]); }; // Random matrix of n rows, m columns Matrix.Random = function(n, m) { return Matrix.Zero(n, m).map( function() { return Math.random(); } ); }; // Matrix filled with zeros Matrix.Zero = function(n, m) { var els = [], ni = n, i, nj, j; do { i = n - ni; els[i] = []; nj = m; do { j = m - nj; els[i][j] = 0; } while (--nj); } while (--ni); return Matrix.create(els); }; function Line() {} Line.prototype = { // Returns true if the argument occupies the same space as the line eql: function(line) { return (this.isParallelTo(line) && this.contains(line.anchor)); }, // Returns a copy of the line dup: function() { return Line.create(this.anchor, this.direction); }, // Returns the result of translating the line by the given vector/array translate: function(vector) { var V = vector.elements || vector; return Line.create([ this.anchor.elements[0] + V[0], this.anchor.elements[1] + V[1], this.anchor.elements[2] + (V[2] || 0) ], this.direction); }, // Returns true if the line is parallel to the argument. Here, 'parallel to' // means that the argument's direction is either parallel or antiparallel to // the line's own direction. A line is parallel to a plane if the two do not // have a unique intersection. isParallelTo: function(obj) { if (obj.normal) { return obj.isParallelTo(this); } var theta = this.direction.angleFrom(obj.direction); return (Math.abs(theta) <= Sylvester.precision || Math.abs(theta - Math.PI) <= Sylvester.precision); }, // Returns the line's perpendicular distance from the argument, // which can be a point, a line or a plane distanceFrom: function(obj) { if (obj.normal) { return obj.distanceFrom(this); } if (obj.direction) { // obj is a line if (this.isParallelTo(obj)) { return this.distanceFrom(obj.anchor); } var N = this.direction.cross(obj.direction).toUnitVector().elements; var A = this.anchor.elements, B = obj.anchor.elements; return Math.abs((A[0] - B[0]) * N[0] + (A[1] - B[1]) * N[1] + (A[2] - B[2]) * N[2]); } else { // obj is a point var P = obj.elements || obj; var A = this.anchor.elements, D = this.direction.elements; var PA1 = P[0] - A[0], PA2 = P[1] - A[1], PA3 = (P[2] || 0) - A[2]; var modPA = Math.sqrt(PA1*PA1 + PA2*PA2 + PA3*PA3); if (modPA === 0) return 0; // Assumes direction vector is normalized var cosTheta = (PA1 * D[0] + PA2 * D[1] + PA3 * D[2]) / modPA; var sin2 = 1 - cosTheta*cosTheta; return Math.abs(modPA * Math.sqrt(sin2 < 0 ? 
0 : sin2)); } }, // Returns true iff the argument is a point on the line contains: function(point) { var dist = this.distanceFrom(point); return (dist !== null && dist <= Sylvester.precision); }, // Returns true iff the line lies in the given plane liesIn: function(plane) { return plane.contains(this); }, // Returns true iff the line has a unique point of intersection with the argument intersects: function(obj) { if (obj.normal) { return obj.intersects(this); } return (!this.isParallelTo(obj) && this.distanceFrom(obj) <= Sylvester.precision); }, // Returns the unique intersection point with the argument, if one exists intersectionWith: function(obj) { if (obj.normal) { return obj.intersectionWith(this); } if (!this.intersects(obj)) { return null; } var P = this.anchor.elements, X = this.direction.elements, Q = obj.anchor.elements, Y = obj.direction.elements; var X1 = X[0], X2 = X[1], X3 = X[2], Y1 = Y[0], Y2 = Y[1], Y3 = Y[2]; var PsubQ1 = P[0] - Q[0], PsubQ2 = P[1] - Q[1], PsubQ3 = P[2] - Q[2]; var XdotQsubP = - X1*PsubQ1 - X2*PsubQ2 - X3*PsubQ3; var YdotPsubQ = Y1*PsubQ1 + Y2*PsubQ2 + Y3*PsubQ3; var XdotX = X1*X1 + X2*X2 + X3*X3; var YdotY = Y1*Y1 + Y2*Y2 + Y3*Y3; var XdotY = X1*Y1 + X2*Y2 + X3*Y3; var k = (XdotQsubP * YdotY / XdotX + XdotY * YdotPsubQ) / (YdotY - XdotY * XdotY); return Vector.create([P[0] + k*X1, P[1] + k*X2, P[2] + k*X3]); }, // Returns the point on the line that is closest to the given point or line pointClosestTo: function(obj) { if (obj.direction) { // obj is a line if (this.intersects(obj)) { return this.intersectionWith(obj); } if (this.isParallelTo(obj)) { return null; } var D = this.direction.elements, E = obj.direction.elements; var D1 = D[0], D2 = D[1], D3 = D[2], E1 = E[0], E2 = E[1], E3 = E[2]; // Create plane containing obj and the shared normal and intersect this with it // Thank you: http://www.cgafaq.info/wiki/Line-line_distance var x = (D3 * E1 - D1 * E3), y = (D1 * E2 - D2 * E1), z = (D2 * E3 - D3 * E2); var N = Vector.create([x * E3 - y * E2, y * E1 - z * E3, z * E2 - x * E1]); var P = Plane.create(obj.anchor, N); return P.intersectionWith(this); } else { // obj is a point var P = obj.elements || obj; if (this.contains(P)) { return Vector.create(P); } var A = this.anchor.elements, D = this.direction.elements; var D1 = D[0], D2 = D[1], D3 = D[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = D1 * (P[1]-A2) - D2 * (P[0]-A1), y = D2 * ((P[2] || 0) - A3) - D3 * (P[1]-A2), z = D3 * (P[0]-A1) - D1 * ((P[2] || 0) - A3); var V = Vector.create([D2 * x - D3 * z, D3 * y - D1 * x, D1 * z - D2 * y]); var k = this.distanceFrom(P) / V.modulus(); return Vector.create([ P[0] + V.elements[0] * k, P[1] + V.elements[1] * k, (P[2] || 0) + V.elements[2] * k ]); } }, // Returns a copy of the line rotated by t radians about the given line. Works by // finding the argument's closest point to this line's anchor point (call this C) and // rotating the anchor about C. Also rotates the line's direction about the argument's. // Be careful with this - the rotation axis' direction affects the outcome! 
rotate: function(t, line) { // If we're working in 2D if (typeof(line.direction) == 'undefined') { line = Line.create(line.to3D(), Vector.k); } var R = Matrix.Rotation(t, line.direction).elements; var C = line.pointClosestTo(this.anchor).elements; var A = this.anchor.elements, D = this.direction.elements; var C1 = C[0], C2 = C[1], C3 = C[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = A1 - C1, y = A2 - C2, z = A3 - C3; return Line.create([ C1 + R[0][0] * x + R[0][1] * y + R[0][2] * z, C2 + R[1][0] * x + R[1][1] * y + R[1][2] * z, C3 + R[2][0] * x + R[2][1] * y + R[2][2] * z ], [ R[0][0] * D[0] + R[0][1] * D[1] + R[0][2] * D[2], R[1][0] * D[0] + R[1][1] * D[1] + R[1][2] * D[2], R[2][0] * D[0] + R[2][1] * D[1] + R[2][2] * D[2] ]); }, // Returns the line's reflection in the given point or line reflectionIn: function(obj) { if (obj.normal) { // obj is a plane var A = this.anchor.elements, D = this.direction.elements; var A1 = A[0], A2 = A[1], A3 = A[2], D1 = D[0], D2 = D[1], D3 = D[2]; var newA = this.anchor.reflectionIn(obj).elements; // Add the line's direction vector to its anchor, then mirror that in the plane var AD1 = A1 + D1, AD2 = A2 + D2, AD3 = A3 + D3; var Q = obj.pointClosestTo([AD1, AD2, AD3]).elements; var newD = [Q[0] + (Q[0] - AD1) - newA[0], Q[1] + (Q[1] - AD2) - newA[1], Q[2] + (Q[2] - AD3) - newA[2]]; return Line.create(newA, newD); } else if (obj.direction) { // obj is a line - reflection obtained by rotating PI radians about obj return this.rotate(Math.PI, obj); } else { // obj is a point - just reflect the line's anchor in it var P = obj.elements || obj; return Line.create(this.anchor.reflectionIn([P[0], P[1], (P[2] || 0)]), this.direction); } }, // Set the line's anchor point and direction. setVectors: function(anchor, direction) { // Need to do this so that line's properties are not // references to the arguments passed in anchor = Vector.create(anchor); direction = Vector.create(direction); if (anchor.elements.length == 2) {anchor.elements.push(0); } if (direction.elements.length == 2) { direction.elements.push(0); } if (anchor.elements.length > 3 || direction.elements.length > 3) { return null; } var mod = direction.modulus(); if (mod === 0) { return null; } this.anchor = anchor; this.direction = Vector.create([ direction.elements[0] / mod, direction.elements[1] / mod, direction.elements[2] / mod ]); return this; } }; // Constructor function Line.create = function(anchor, direction) { var L = new Line(); return L.setVectors(anchor, direction); }; // Axes Line.X = Line.create(Vector.Zero(3), Vector.i); Line.Y = Line.create(Vector.Zero(3), Vector.j); Line.Z = Line.create(Vector.Zero(3), Vector.k); function Plane() {} Plane.prototype = { // Returns true iff the plane occupies the same space as the argument eql: function(plane) { return (this.contains(plane.anchor) && this.isParallelTo(plane)); }, // Returns a copy of the plane dup: function() { return Plane.create(this.anchor, this.normal); }, // Returns the result of translating the plane by the given vector translate: function(vector) { var V = vector.elements || vector; return Plane.create([ this.anchor.elements[0] + V[0], this.anchor.elements[1] + V[1], this.anchor.elements[2] + (V[2] || 0) ], this.normal); }, // Returns true iff the plane is parallel to the argument. Will return true // if the planes are equal, or if you give a line and it lies in the plane. 
isParallelTo: function(obj) { var theta; if (obj.normal) { // obj is a plane theta = this.normal.angleFrom(obj.normal); return (Math.abs(theta) <= Sylvester.precision || Math.abs(Math.PI - theta) <= Sylvester.precision); } else if (obj.direction) { // obj is a line return this.normal.isPerpendicularTo(obj.direction); } return null; }, // Returns true iff the receiver is perpendicular to the argument isPerpendicularTo: function(plane) { var theta = this.normal.angleFrom(plane.normal); return (Math.abs(Math.PI/2 - theta) <= Sylvester.precision); }, // Returns the plane's distance from the given object (point, line or plane) distanceFrom: function(obj) { if (this.intersects(obj) || this.contains(obj)) { return 0; } if (obj.anchor) { // obj is a plane or line var A = this.anchor.elements, B = obj.anchor.elements, N = this.normal.elements; return Math.abs((A[0] - B[0]) * N[0] + (A[1] - B[1]) * N[1] + (A[2] - B[2]) * N[2]); } else { // obj is a point var P = obj.elements || obj; var A = this.anchor.elements, N = this.normal.elements; return Math.abs((A[0] - P[0]) * N[0] + (A[1] - P[1]) * N[1] + (A[2] - (P[2] || 0)) * N[2]); } }, // Returns true iff the plane contains the given point or line contains: function(obj) { if (obj.normal) { return null; } if (obj.direction) { return (this.contains(obj.anchor) && this.contains(obj.anchor.add(obj.direction))); } else { var P = obj.elements || obj; var A = this.anchor.elements, N = this.normal.elements; var diff = Math.abs(N[0]*(A[0] - P[0]) + N[1]*(A[1] - P[1]) + N[2]*(A[2] - (P[2] || 0))); return (diff <= Sylvester.precision); } }, // Returns true iff the plane has a unique point/line of intersection with the argument intersects: function(obj) { if (typeof(obj.direction) == 'undefined' && typeof(obj.normal) == 'undefined') { return null; } return !this.isParallelTo(obj); }, // Returns the unique intersection with the argument, if one exists. The result // will be a vector if a line is supplied, and a line if a plane is supplied. intersectionWith: function(obj) { if (!this.intersects(obj)) { return null; } if (obj.direction) { // obj is a line var A = obj.anchor.elements, D = obj.direction.elements, P = this.anchor.elements, N = this.normal.elements; var multiplier = (N[0]*(P[0]-A[0]) + N[1]*(P[1]-A[1]) + N[2]*(P[2]-A[2])) / (N[0]*D[0] + N[1]*D[1] + N[2]*D[2]); return Vector.create([A[0] + D[0]*multiplier, A[1] + D[1]*multiplier, A[2] + D[2]*multiplier]); } else if (obj.normal) { // obj is a plane var direction = this.normal.cross(obj.normal).toUnitVector(); // To find an anchor point, we find one co-ordinate that has a value // of zero somewhere on the intersection, and remember which one we picked var N = this.normal.elements, A = this.anchor.elements, O = obj.normal.elements, B = obj.anchor.elements; var solver = Matrix.Zero(2,2), i = 0; while (solver.isSingular()) { i++; solver = Matrix.create([ [ N[i%3], N[(i+1)%3] ], [ O[i%3], O[(i+1)%3] ] ]); } // Then we solve the simultaneous equations in the remaining dimensions var inverse = solver.inverse().elements; var x = N[0]*A[0] + N[1]*A[1] + N[2]*A[2]; var y = O[0]*B[0] + O[1]*B[1] + O[2]*B[2]; var intersection = [ inverse[0][0] * x + inverse[0][1] * y, inverse[1][0] * x + inverse[1][1] * y ]; var anchor = []; for (var j = 1; j <= 3; j++) { // This formula picks the right element from intersection by // cycling depending on which element we set to zero above anchor.push((i == j) ? 
0 : intersection[(j + (5 - i)%3)%3]); } return Line.create(anchor, direction); } }, // Returns the point in the plane closest to the given point pointClosestTo: function(point) { var P = point.elements || point; var A = this.anchor.elements, N = this.normal.elements; var dot = (A[0] - P[0]) * N[0] + (A[1] - P[1]) * N[1] + (A[2] - (P[2] || 0)) * N[2]; return Vector.create([P[0] + N[0] * dot, P[1] + N[1] * dot, (P[2] || 0) + N[2] * dot]); }, // Returns a copy of the plane, rotated by t radians about the given line // See notes on Line#rotate. rotate: function(t, line) { var R = Matrix.Rotation(t, line.direction).elements; var C = line.pointClosestTo(this.anchor).elements; var A = this.anchor.elements, N = this.normal.elements; var C1 = C[0], C2 = C[1], C3 = C[2], A1 = A[0], A2 = A[1], A3 = A[2]; var x = A1 - C1, y = A2 - C2, z = A3 - C3; return Plane.create([ C1 + R[0][0] * x + R[0][1] * y + R[0][2] * z, C2 + R[1][0] * x + R[1][1] * y + R[1][2] * z, C3 + R[2][0] * x + R[2][1] * y + R[2][2] * z ], [ R[0][0] * N[0] + R[0][1] * N[1] + R[0][2] * N[2], R[1][0] * N[0] + R[1][1] * N[1] + R[1][2] * N[2], R[2][0] * N[0] + R[2][1] * N[1] + R[2][2] * N[2] ]); }, // Returns the reflection of the plane in the given point, line or plane. reflectionIn: function(obj) { if (obj.normal) { // obj is a plane var A = this.anchor.elements, N = this.normal.elements; var A1 = A[0], A2 = A[1], A3 = A[2], N1 = N[0], N2 = N[1], N3 = N[2]; var newA = this.anchor.reflectionIn(obj).elements; // Add the plane's normal to its anchor, then mirror that in the other plane var AN1 = A1 + N1, AN2 = A2 + N2, AN3 = A3 + N3; var Q = obj.pointClosestTo([AN1, AN2, AN3]).elements; var newN = [Q[0] + (Q[0] - AN1) - newA[0], Q[1] + (Q[1] - AN2) - newA[1], Q[2] + (Q[2] - AN3) - newA[2]]; return Plane.create(newA, newN); } else if (obj.direction) { // obj is a line return this.rotate(Math.PI, obj); } else { // obj is a point var P = obj.elements || obj; return Plane.create(this.anchor.reflectionIn([P[0], P[1], (P[2] || 0)]), this.normal); } }, // Sets the anchor point and normal to the plane. If three arguments are specified, // the normal is calculated by assuming the three points should lie in the same plane. // If only two are sepcified, the second is taken to be the normal. Normal vector is // normalised before storage. 
setVectors: function(anchor, v1, v2) { anchor = Vector.create(anchor); anchor = anchor.to3D(); if (anchor === null) { return null; } v1 = Vector.create(v1); v1 = v1.to3D(); if (v1 === null) { return null; } if (typeof(v2) == 'undefined') { v2 = null; } else { v2 = Vector.create(v2); v2 = v2.to3D(); if (v2 === null) { return null; } } var A1 = anchor.elements[0], A2 = anchor.elements[1], A3 = anchor.elements[2]; var v11 = v1.elements[0], v12 = v1.elements[1], v13 = v1.elements[2]; var normal, mod; if (v2 !== null) { var v21 = v2.elements[0], v22 = v2.elements[1], v23 = v2.elements[2]; normal = Vector.create([ (v12 - A2) * (v23 - A3) - (v13 - A3) * (v22 - A2), (v13 - A3) * (v21 - A1) - (v11 - A1) * (v23 - A3), (v11 - A1) * (v22 - A2) - (v12 - A2) * (v21 - A1) ]); mod = normal.modulus(); if (mod === 0) { return null; } normal = Vector.create([normal.elements[0] / mod, normal.elements[1] / mod, normal.elements[2] / mod]); } else { mod = Math.sqrt(v11*v11 + v12*v12 + v13*v13); if (mod === 0) { return null; } normal = Vector.create([v1.elements[0] / mod, v1.elements[1] / mod, v1.elements[2] / mod]); } this.anchor = anchor; this.normal = normal; return this; } }; // Constructor function Plane.create = function(anchor, v1, v2) { var P = new Plane(); return P.setVectors(anchor, v1, v2); }; // X-Y-Z planes Plane.XY = Plane.create(Vector.Zero(3), Vector.k); Plane.YZ = Plane.create(Vector.Zero(3), Vector.i); Plane.ZX = Plane.create(Vector.Zero(3), Vector.j); Plane.YX = Plane.XY; Plane.ZY = Plane.YZ; Plane.XZ = Plane.ZX; // Utility functions var $V = Vector.create; var $M = Matrix.create; var $L = Line.create; var $P = Plane.create; ================================================ FILE: google/datalab/notebook/static/job.css ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ p.jobfail { color: red; } p.jobsucceed { color: green; } p.jobfooter { font-size: smaller; } ================================================ FILE: google/datalab/notebook/static/job.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. 
*/

///

declare var datalab: any;
declare var IPython: any;

module Job {

  function refresh(dom: any, job_name: any, job_type: any, interval: any,
                   html_on_running: string, html_on_success: string): any {
    var code = '%_get_job_status ' + job_name + ' ' + job_type;
    datalab.session.execute(code, function (error: any, newData: any) {
      error = error || newData.error;
      if (error) {
        dom.innerHTML = '<p class="jobfail">Job failed with error: ' + error + '</p>';
        return;
      }
      if (!newData.exists) {
        dom.innerHTML = '<p class="jobfail">The job does not exist.</p>';
      } else if (newData.done) {
        dom.innerHTML = '<p class="jobsucceed">Job completed successfully.</p>' + html_on_success;
      } else {
        dom.innerHTML = 'Running...' +
            '<p class="jobfooter">Updated at ' + new Date().toLocaleTimeString() + '</p>
' + html_on_running; setTimeout(function() { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); }, interval * 1000); } }); } // Render the job view. This is called from Python generated code. export function render(dom: any, events: any, job_name: string, job_type: string, interval: any, html_on_running: string, html_on_success: string) { if (IPython.notebook.kernel.is_connected()) { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); return; } // If the kernel is not connected, wait for the event. events.on('kernel_ready.Kernel', function(e: any) { refresh(dom, job_name, job_type, interval, html_on_running, html_on_success); }); } } export = Job; ================================================ FILE: google/datalab/notebook/static/parcoords.ts ================================================ /* * Copyright 2016 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ /// module ParCoords { function getCentroids(data: any, graph: any): any { var margins = graph.margin(); var graphCentPts: any[] = []; data.forEach(function(d: any){ var initCenPts = graph.compute_centroids(d).filter(function(d: any, i: number) {return i%2==0;}); var cenPts = initCenPts.map(function(d: any){ return [d[0] + margins["left"], d[1]+ margins["top"]]; }); graphCentPts.push(cenPts); }); return graphCentPts; } function getActiveData(graph: any): any{ if (graph.brushed()!=false) return graph.brushed(); return graph.data(); } function findAxes(testPt: any, cenPts: any): number { var x: number = testPt[0]; var y: number = testPt[1]; if (cenPts[0][0] > x) return 0; if (cenPts[cenPts.length-1][0] < x) return 0; for (var i=0; i x) return i; } return 0; } function isOnLine(startPt: any, endPt: any, testPt: any, tol: number){ var x0 = testPt[0]; var y0 = testPt[1]; var x1 = startPt[0]; var y1 = startPt[1]; var x2 = endPt[0]; var y2 = endPt[1]; var Dx = x2 - x1; var Dy = y2 - y1; var delta = Math.abs(Dy*x0 - Dx*y0 - x1*y2+x2*y1)/Math.sqrt(Math.pow(Dx, 2) + Math.pow(Dy, 2)); if (delta <= tol) return true; return false; } function getClickedLines(mouseClick: any, graph: any): any { var clicked: any[] = []; var clickedCenPts: any[] = []; // find which data is activated right now var activeData: any = getActiveData(graph); // find centriod points var graphCentPts: any = getCentroids(activeData, graph); if (graphCentPts.length==0) return false; // find between which axes the point is var axeNum: number = findAxes(mouseClick, graphCentPts[0]); if (!axeNum) return false; graphCentPts.forEach(function(d: any, i: number){ if (isOnLine(d[axeNum-1], d[axeNum], mouseClick, 2)) { clicked.push(activeData[i]); clickedCenPts.push(graphCentPts[i]); // for tooltip } }); return [clicked, clickedCenPts] } function highlightLineOnClick(mouseClick: any, graph: any) { var clicked: any[] = []; var clickedCenPts: any[] = []; var clickedData: any = getClickedLines(mouseClick, graph); if (clickedData && clickedData[0].length!=0){ clicked = clickedData[0]; clickedCenPts = clickedData[1]; // 
highlight clicked line graph.highlight(clicked); } }; export function plot(d3: any, color_domain: number[], maximize: boolean, data: any, graph_html_id: string, grid_html_id: string) { var range = ["green", "gray"]; if (maximize) { range = ["gray", "green"]; } var blue_to_brown = d3.scale.linear().domain(color_domain) .range(range) .interpolate(d3.interpolateLab); var color = function(d: any) { return blue_to_brown(d['Objective']); }; var columns_hide: string[] = ["Trial", "Training Step"]; for (var attr in data) { if (attr.lastIndexOf("(log)") > 0) { columns_hide.push(attr.slice(0, -5)); } } var data_display: any[] = []; for (var i: number =0; i module Style { 'use strict'; // An object containing the set of loaded stylesheets, so as to avoid reloading. var loadedStyleSheets: any = {}; // An object containing stylesheets to load, once the DOM is ready. var pendingStyleSheets: Array = null; function addStyleSheet(url: string): void { loadedStyleSheets[url] = true; var stylesheet = document.createElement('link'); stylesheet.type = 'text/css'; stylesheet.rel = 'stylesheet'; stylesheet.href = url; document.getElementsByTagName('head')[0].appendChild(stylesheet); } function domReadyCallback(): void { if (pendingStyleSheets) { // Clear out pendingStyleSheets, so any future adds are immediately processed. var styleSheets: Array = pendingStyleSheets; pendingStyleSheets = null; styleSheets.forEach(addStyleSheet); } } export function load(url: string, req: any, loadCallback: any, config: any): void { if (config.isBuild) { loadCallback(null); } else { // Go ahead and immediately/optimistically resolve this, since the resolved value of a // stylesheet is never interesting. setTimeout(loadCallback, 0); // Only load a specified stylesheet once for the lifetime of this page. if (loadedStyleSheets[url]) { return; } loadedStyleSheets[url] = true; if (document.readyState == 'loading') { if (!pendingStyleSheets) { pendingStyleSheets = []; document.addEventListener('DOMContentLoaded', domReadyCallback, false); } pendingStyleSheets.push(url); } else { addStyleSheet(url); } } } } export = Style; ================================================ FILE: google/datalab/notebook/static/visualization.ts ================================================ /* * Copyright 2015 Google Inc. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except * in compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ // require.js plugin to allow Google Chart API to be loaded. /// declare var google: any; declare var window: any; module Visualization { 'use strict'; // Queued packages to load until the google api loader itself has not been loaded. var queue: any = { packages: [], callbacks: [] }; function loadGoogleApiLoader(callback: any): void { // Visualization packages are loaded using the Google loader. // The loader URL itself must contain a callback (by name) that it invokes when its loaded. 
var callbackName: string = '__googleApiLoaderCallback'; window[callbackName] = callback; var script = document.createElement('script'); script.type = 'text/javascript'; script.async = true; script.src = 'https://www.google.com/jsapi?callback=' + callbackName; document.getElementsByTagName('head')[0].appendChild(script); } function invokeVisualizationCallback(cb: any) { cb(google.visualization); } function loadVisualizationPackages(names: any, callbacks: any): void { if (names.length) { var visualizationOptions = { packages: names, callback: function() { callbacks.forEach(invokeVisualizationCallback); } }; google.load('visualization', '1', visualizationOptions); } } loadGoogleApiLoader(function() { if (queue) { loadVisualizationPackages(queue.packages, queue.callbacks); queue = null; } }); export function load(name: any, req: any, callback: any, config: any) { if (config.isBuild) { callback(null); } else { if (queue) { // Queue the package and associated callback to load, once the loader has been loaded. queue.packages.push(name); queue.callbacks.push(callback); } else { // Loader has already been loaded, so go ahead and load the specified package. loadVisualizationPackages([ name ], [ callback ]); } } } } export = Visualization; ================================================ FILE: google/datalab/stackdriver/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Stackdriver Functionality.""" ================================================ FILE: google/datalab/stackdriver/commands/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _monitoring __all__ = ['_monitoring'] ================================================ FILE: google/datalab/stackdriver/commands/_monitoring.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """IPython Functionality for the Google Monitoring API.""" from __future__ import absolute_import try: import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import google.datalab import google.datalab.stackdriver.monitoring as gcm import google.datalab.utils.commands @IPython.core.magic.register_line_cell_magic def sd(line, cell=None): """Implements the stackdriver cell magic for ipython notebooks. Args: line: the contents of the storage line. Returns: The results of executing the cell. """ parser = google.datalab.utils.commands.CommandParser(prog='%sd', description=( 'Execute various Stackdriver related operations. Use "%sd ' ' -h" for help on a specific Stackdriver product.')) # %%sd monitoring _create_monitoring_subparser(parser) return google.datalab.utils.commands.handle_magic_line(line, cell, parser) def _create_monitoring_subparser(parser): monitoring_parser = parser.subcommand( 'monitoring', 'Execute Stackdriver monitoring related operations. Use ' '"sd monitoring -h" for help on a specific command') metric_parser = monitoring_parser.subcommand( 'metrics', 'Operations on Stackdriver Monitoring metrics') metric_list_parser = metric_parser.subcommand('list', 'List metrics') metric_list_parser.add_argument( '-p', '--project', help='The project whose metrics should be listed.') metric_list_parser.add_argument( '-t', '--type', help='The type of metric(s) to list; can include wildchars.') metric_list_parser.set_defaults(func=_monitoring_metrics_list) resource_parser = monitoring_parser.subcommand( 'resource_types', 'Operations on Stackdriver Monitoring resource types') resource_list_parser = resource_parser.subcommand('list', 'List resource types') resource_list_parser.add_argument( '-p', '--project', help='The project whose resource types should be listed.') resource_list_parser.add_argument( '-t', '--type', help='The resource type(s) to list; can include wildchars.') resource_list_parser.set_defaults(func=_monitoring_resource_types_list) group_parser = monitoring_parser.subcommand( 'groups', 'Operations on Stackdriver groups') group_list_parser = group_parser.subcommand('list', 'List groups') group_list_parser.add_argument( '-p', '--project', help='The project whose groups should be listed.') group_list_parser.add_argument( '-n', '--name', help='The name of the group(s) to list; can include wildchars.') group_list_parser.set_defaults(func=_monitoring_groups_list) def _monitoring_metrics_list(args, _): """Lists the metric descriptors in the project.""" project_id = args['project'] pattern = args['type'] or '*' descriptors = gcm.MetricDescriptors(context=_make_context(project_id)) dataframe = descriptors.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _monitoring_resource_types_list(args, _): """Lists the resource descriptors in the project.""" project_id = args['project'] pattern = args['type'] or '*' descriptors = gcm.ResourceDescriptors(context=_make_context(project_id)) dataframe = descriptors.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _monitoring_groups_list(args, _): """Lists the groups in the project.""" project_id = args['project'] pattern = args['name'] or '*' groups = gcm.Groups(context=_make_context(project_id)) dataframe = groups.as_dataframe(pattern=pattern) return _render_dataframe(dataframe) def _render_dataframe(dataframe): """Helper to render 
a dataframe as an HTML table.""" data = dataframe.to_dict(orient='records') fields = dataframe.columns.tolist() return IPython.core.display.HTML( google.datalab.utils.commands.HtmlBuilder.render_table(data, fields)) def _make_context(project_id): default_context = google.datalab.Context.default() if project_id: return google.datalab.Context(project_id, default_context.credentials) else: return default_context ================================================ FILE: google/datalab/stackdriver/monitoring/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Monitoring Functionality.""" from __future__ import absolute_import from google.cloud.monitoring import enums from ._group import Groups from ._metric import MetricDescriptors from ._query import Query from ._query_metadata import QueryMetadata from ._resource import ResourceDescriptors Aligner = enums.Aggregation.Aligner Reducer = enums.Aggregation.Reducer __all__ = ['Aligner', 'Reducer', 'Groups', 'MetricDescriptors', 'Query', 'QueryMetadata', 'ResourceDescriptors'] ================================================ FILE: google/datalab/stackdriver/monitoring/_group.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Groups for the Google Monitoring API.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import collections import fnmatch import pandas import google.datalab from . import _utils class Groups(object): """Represents a list of Stackdriver groups for a project.""" _DISPLAY_HEADERS = ('Group ID', 'Group name', 'Parent ID', 'Parent name', 'Is cluster', 'Filter') def __init__(self, context=None): """Initializes the Groups for a Stackdriver project. Args: context: An optional Context object to use instead of the global default. """ self._context = context or google.datalab.Context.default() self._client = _utils.make_client(self._context) self._group_dict = None def list(self, pattern='*'): """Returns a list of groups that match the filters. Args: pattern: An optional pattern to filter the groups based on their display name. This can include Unix shell-style wildcards. E.g. ``"Production*"``. Returns: A list of Group objects that match the filters. 
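Example (an illustrative sketch only; assumes a default Context with valid
credentials, and the display names shown are hypothetical)::

    import google.datalab.stackdriver.monitoring as gcm

    groups = gcm.Groups()
    # Groups whose display name starts with "Production".
    production_groups = groups.list('Production*')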
""" if self._group_dict is None: self._group_dict = collections.OrderedDict( (group.name, group) for group in self._client.list_groups()) return [group for group in self._group_dict.values() if fnmatch.fnmatch(group.display_name, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the groups that match the filters. Args: pattern: An optional pattern to further filter the groups. This can include Unix shell-style wildcards. E.g. ``"Production *"``, ``"*-backend"``. max_rows: The maximum number of groups to return. If None, return all. Returns: A pandas dataframe containing matching groups. """ data = [] for i, group in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break parent = self._group_dict.get(group.parent_name) parent_display_name = '' if parent is None else parent.display_name data.append([ group.name, group.display_name, group.parent_name, parent_display_name, group.is_cluster, group.filter]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: google/datalab/stackdriver/monitoring/_metric.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides the MetricDescriptors in the monitoring API.""" from __future__ import absolute_import from builtins import object from google.cloud.monitoring_v3 import enums import fnmatch import pandas from . import _utils class MetricDescriptors(object): """MetricDescriptors object for retrieving the metric descriptors.""" _DISPLAY_HEADERS = ('Metric type', 'Display name', 'Kind', 'Value', 'Unit', 'Labels') def __init__(self, filter_string=None, type_prefix=None, context=None): """Initializes the MetricDescriptors based on the specified filters. Args: filter_string: An optional filter expression describing the resource descriptors to be returned. type_prefix: An optional prefix constraining the selected metric types. This adds ``metric.type = starts_with("")`` to the filter. context: An optional Context object to use instead of the global default. """ self._client = _utils.make_client(context) self._filter_string = filter_string self._type_prefix = type_prefix self._descriptors = None def list(self, pattern='*'): """Returns a list of metric descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"compute*"``, ``"*cpu/load_??m"``. Returns: A list of MetricDescriptor objects that match the filters. """ if self._descriptors is None: self._descriptors = self._client.list_metric_descriptors( filter_string=self._filter_string, type_prefix=self._type_prefix) return [metric for metric in self._descriptors if fnmatch.fnmatch(metric.type, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. 
This can include Unix shell-style wildcards. E.g. ``"compute*"``, ``"*/cpu/load_??m"``. max_rows: The maximum number of descriptors to return. If None, return all. Returns: A pandas dataframe containing matching metric descriptors. """ data = [] for i, metric in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break labels = ', '. join([l.key for l in metric.labels]) data.append([ metric.type, metric.display_name, enums.MetricDescriptor.MetricKind(metric.metric_kind).name, enums.MetricDescriptor.ValueType(metric.value_type).name, metric.unit, labels]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: google/datalab/stackdriver/monitoring/_query.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides access to metric data as pandas dataframes.""" from __future__ import absolute_import import google.cloud.monitoring_v3.query from . import _query_metadata from . import _utils class Query(google.cloud.monitoring_v3.query.Query): """Query object for retrieving metric data.""" def __init__(self, metric_type=google.cloud.monitoring_v3.query.Query.DEFAULT_METRIC_TYPE, end_time=None, days=0, hours=0, minutes=0, context=None): """Initializes the core query parameters. The start time (exclusive) is determined by combining the values of ``days``, ``hours``, and ``minutes``, and subtracting the resulting duration from the end time. It is also allowed to omit the end time and duration here, in which case :meth:`~google.cloud.monitoring_v3.query.Query.select_interval` must be called before the query is executed. Args: metric_type: The metric type name. The default value is :data:`Query.DEFAULT_METRIC_TYPE `, but please note that this default value is provided only for demonstration purposes and is subject to change. end_time: The end time (inclusive) of the time interval for which results should be returned, as a datetime object. The default is the start of the current minute. days: The number of days in the time interval. hours: The number of hours in the time interval. minutes: The number of minutes in the time interval. context: An optional Context object to use instead of the global default. Raises: ValueError: ``end_time`` was specified but ``days``, ``hours``, and ``minutes`` are all zero. If you really want to specify a point in time, use :meth:`~google.cloud.monitoring_v3.query.Query.select_interval`. """ client = _utils.make_client(context) super(Query, self).__init__(client.metrics_client, project=client.project, metric_type=metric_type, end_time=end_time, days=days, hours=hours, minutes=minutes) def metadata(self): """Retrieves the metadata for the query.""" return _query_metadata.QueryMetadata(self) ================================================ FILE: google/datalab/stackdriver/monitoring/_query_metadata.py ================================================ # Copyright 2016 Google Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """QueryMetadata object that shows the metadata in a query's results.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object from google.cloud.monitoring_v3 import _dataframe from google.protobuf.json_format import MessageToDict import pandas class QueryMetadata(object): """QueryMetadata object contains the metadata of a timeseries query.""" def __init__(self, query): """Initializes the QueryMetadata given the query object. Args: query: A Query object. """ self._timeseries_list = list(query.iter(headers_only=True)) # Note: If self._timeseries_list has even one entry, the metric type # can be extracted from there as well. self._metric_type = query.metric_type def __iter__(self): for timeseries in self._timeseries_list: yield timeseries @property def metric_type(self): """Returns the metric type in the underlying query.""" return self._metric_type @property def resource_types(self): """Returns a set containing resource types in the query result.""" return set([ts.resource.type for ts in self._timeseries_list]) def as_dataframe(self, max_rows=None): """Creates a pandas dataframe from the query metadata. Args: max_rows: The maximum number of timeseries metadata to return. If None, return all. Returns: A pandas dataframe containing the resource type, resource labels and metric labels. Each row in this dataframe corresponds to the metadata from one time series. """ max_rows = len(self._timeseries_list) if max_rows is None else max_rows headers = [{ 'resource': MessageToDict(ts.resource), 'metric': MessageToDict(ts.metric) } for ts in self._timeseries_list[:max_rows]] if not headers: return pandas.DataFrame() dataframe = pandas.io.json.json_normalize(headers) # Add a 2 level column header. dataframe.columns = pandas.MultiIndex.from_tuples( [(col, '') if col == 'resource.type' else col.rsplit('.', 1) for col in dataframe.columns]) # Re-order the columns. resource_keys = _dataframe._sorted_resource_labels( dataframe['resource.labels'].columns) sorted_columns = [('resource.type', '')] sorted_columns += [('resource.labels', key) for key in resource_keys] sorted_columns += sorted(col for col in dataframe.columns if col[0] == 'metric.labels') dataframe = dataframe[sorted_columns] # Sort the data, and clean up index values, and NaNs. dataframe = dataframe.sort_values(sorted_columns) dataframe = dataframe.reset_index(drop=True).fillna('') return dataframe ================================================ FILE: google/datalab/stackdriver/monitoring/_resource.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Provides the ResourceDescriptors in the monitoring API.""" from __future__ import absolute_import from builtins import object import fnmatch import pandas from . import _utils class ResourceDescriptors(object): """ResourceDescriptors object for retrieving the resource descriptors.""" _DISPLAY_HEADERS = ('Resource type', 'Display name', 'Labels') def __init__(self, filter_string=None, context=None): """Initializes the ResourceDescriptors based on the specified filters. Args: filter_string: An optional filter expression describing the resource descriptors to be returned. context: An optional Context object to use instead of the global default. """ self._client = _utils.make_client(context) self._filter_string = filter_string self._descriptors = None def list(self, pattern='*'): """Returns a list of resource descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"aws*"``, ``"*cluster*"``. Returns: A list of ResourceDescriptor objects that match the filters. """ if self._descriptors is None: self._descriptors = self._client.list_resource_descriptors( filter_string=self._filter_string) return [resource for resource in self._descriptors if fnmatch.fnmatch(resource.type, pattern)] def as_dataframe(self, pattern='*', max_rows=None): """Creates a pandas dataframe from the descriptors that match the filters. Args: pattern: An optional pattern to further filter the descriptors. This can include Unix shell-style wildcards. E.g. ``"aws*"``, ``"*cluster*"``. max_rows: The maximum number of descriptors to return. If None, return all. Returns: A pandas dataframe containing matching resource descriptors. """ data = [] for i, resource in enumerate(self.list(pattern)): if max_rows is not None and i >= max_rows: break labels = ', '. join([l.key for l in resource.labels]) data.append([resource.type, resource.display_name, labels]) return pandas.DataFrame(data, columns=self._DISPLAY_HEADERS) ================================================ FILE: google/datalab/stackdriver/monitoring/_utils.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
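# The descriptor, group and query classes in this package (MetricDescriptors,
# ResourceDescriptors, Groups, Query) all obtain their API clients through the
# make_client() helper defined below. A rough usage sketch (assumes a default
# Context with credentials; the metric type is the Compute Engine CPU
# utilization metric, other values are illustrative):
#
#   import google.datalab.stackdriver.monitoring as gcm
#
#   metrics = gcm.MetricDescriptors(type_prefix='compute.googleapis.com')
#   metrics_df = metrics.as_dataframe(pattern='*cpu*')
#
#   query = gcm.Query('compute.googleapis.com/instance/cpu/utilization', hours=1)
#   metadata_df = query.metadata().as_dataframe(max_rows=10)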
"""Provides utility methods for the Monitoring API.""" from __future__ import absolute_import from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud.monitoring_v3 import MetricServiceClient from google.cloud.monitoring_v3 import GroupServiceClient import google.datalab # _MonitoringClient holds instances of individual google.cloud.monitoring # clients and translates each call from the old signature, since the prior # client has been updated and has split into multiple client classes. class _MonitoringClient(object): def __init__(self, context): self.project = context.project_id client_info = ClientInfo(user_agent='pydatalab/v0') self.metrics_client = MetricServiceClient( credentials=context.credentials, client_info=client_info ) self.group_client = GroupServiceClient( credentials=context.credentials, client_info=client_info ) def list_metric_descriptors(self, filter_string=None, type_prefix=None): filters = [] if filter_string is not None: filters.append(filter_string) if type_prefix is not None: filters.append('metric.type = starts_with("{prefix}")'.format( prefix=type_prefix)) metric_filter = ' AND '.join(filters) metrics = self.metrics_client.list_metric_descriptors( self.project, filter_=metric_filter) return metrics def list_resource_descriptors(self, filter_string=None): resources = self.metrics_client.list_monitored_resource_descriptors( self.project, filter_=filter_string) return resources def list_groups(self): groups = self.group_client.list_groups(self.project) return groups def make_client(context=None): context = context or google.datalab.Context.default() client = _MonitoringClient(context) return client ================================================ FILE: google/datalab/storage/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Cloud Storage Functionality.""" from __future__ import absolute_import from ._bucket import Bucket, Buckets from ._object import Object, Objects __all__ = ['Bucket', 'Buckets', 'Object', 'Objects'] ================================================ FILE: google/datalab/storage/_api.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
"""Implements Storage HTTP API wrapper.""" from __future__ import absolute_import from __future__ import unicode_literals from future import standard_library standard_library.install_aliases() # noqa from builtins import object import google.datalab import urllib.request import urllib.parse import urllib.error import google.datalab.utils class Api(object): """A helper class to issue Storage HTTP requests.""" # TODO(nikhilko): Use named placeholders in these string templates. _ENDPOINT = 'https://www.googleapis.com/storage/v1' _DOWNLOAD_ENDPOINT = 'https://www.googleapis.com/download/storage/v1' _UPLOAD_ENDPOINT = 'https://www.googleapis.com/upload/storage/v1' _BUCKET_PATH = '/b/%s' _OBJECT_PATH = '/b/%s/o/%s' _OBJECT_COPY_PATH = '/b/%s/o/%s/copyTo/b/%s/o/%s' _MAX_RESULTS = 100 def __init__(self, context): """Initializes the Storage helper with context information. Args: context: a Context object providing project_id and credentials. """ self._credentials = context.credentials self._project_id = context.project_id @property def project_id(self): """The project_id associated with this API client.""" return self._project_id def buckets_insert(self, bucket, project_id=None): """Issues a request to create a new bucket. Args: bucket: the name of the bucket. project_id: the project to use when inserting the bucket. Returns: A parsed bucket information dictionary. Raises: Exception if there is an error performing the operation. """ args = {'project': project_id if project_id else self._project_id} data = {'name': bucket} url = Api._ENDPOINT + (Api._BUCKET_PATH % '') return google.datalab.utils.Http.request(url, args=args, data=data, credentials=self._credentials) def buckets_delete(self, bucket): """Issues a request to delete a bucket. Args: bucket: the name of the bucket. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._BUCKET_PATH % bucket) google.datalab.utils.Http.request(url, method='DELETE', credentials=self._credentials, raw_response=True) def buckets_get(self, bucket, projection='noAcl'): """Issues a request to retrieve information about a bucket. Args: bucket: the name of the bucket. projection: the projection of the bucket information to retrieve. Returns: A parsed bucket information dictionary. Raises: Exception if there is an error performing the operation. """ args = {'projection': projection} url = Api._ENDPOINT + (Api._BUCKET_PATH % bucket) return google.datalab.utils.Http.request(url, credentials=self._credentials, args=args) def buckets_list(self, projection='noAcl', max_results=0, page_token=None, project_id=None): """Issues a request to retrieve the list of buckets. Args: projection: the projection of the bucket information to retrieve. max_results: an optional maximum number of objects to retrieve. page_token: an optional token to continue the retrieval. project_id: the project whose buckets should be listed. Returns: A parsed list of bucket information dictionaries. Raises: Exception if there is an error performing the operation. 
""" if max_results == 0: max_results = Api._MAX_RESULTS args = {'project': project_id if project_id else self._project_id, 'maxResults': max_results} if projection is not None: args['projection'] = projection if page_token is not None: args['pageToken'] = page_token url = Api._ENDPOINT + (Api._BUCKET_PATH % '') return google.datalab.utils.Http.request(url, args=args, credentials=self._credentials) def object_download(self, bucket, key, start_offset=0, byte_count=None): """Reads the contents of an object as text. Args: bucket: the name of the bucket containing the object. key: the key of the object to be read. start_offset: the start offset of bytes to read. byte_count: the number of bytes to read. If None, it reads to the end. Returns: The text content within the object. Raises: Exception if the object could not be read from. """ args = {'alt': 'media'} headers = {} if start_offset > 0 or byte_count is not None: header = 'bytes=%d-' % start_offset if byte_count is not None: header += '%d' % byte_count headers['Range'] = header url = Api._DOWNLOAD_ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return google.datalab.utils.Http.request(url, args=args, headers=headers, credentials=self._credentials, raw_response=True) def object_upload(self, bucket, key, content, content_type): """Writes text content to the object. Args: bucket: the name of the bucket containing the object. key: the key of the object to be written. content: the text content to be written. content_type: the type of text content. Raises: Exception if the object could not be written to. """ args = {'uploadType': 'media', 'name': key} headers = {'Content-Type': content_type} url = Api._UPLOAD_ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) return google.datalab.utils.Http.request(url, args=args, data=content, headers=headers, credentials=self._credentials, raw_response=True) def objects_copy(self, source_bucket, source_key, target_bucket, target_key): """Updates the metadata associated with an object. Args: source_bucket: the name of the bucket containing the source object. source_key: the key of the source object being copied. target_bucket: the name of the bucket that will contain the copied object. target_key: the key of the copied object. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_COPY_PATH % (source_bucket, Api._escape_key(source_key), target_bucket, Api._escape_key(target_key))) return google.datalab.utils.Http.request(url, method='POST', credentials=self._credentials) def objects_delete(self, bucket, key): """Deletes the specified object. Args: bucket: the name of the bucket. key: the key of the object within the bucket. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) google.datalab.utils.Http.request(url, method='DELETE', credentials=self._credentials, raw_response=True) def objects_get(self, bucket, key, projection='noAcl'): """Issues a request to retrieve information about an object. Args: bucket: the name of the bucket. key: the key of the object within the bucket. projection: the projection of the object to retrieve. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. 
""" args = {} if projection is not None: args['projection'] = projection url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return google.datalab.utils.Http.request(url, args=args, credentials=self._credentials) def objects_list(self, bucket, prefix=None, delimiter=None, projection='noAcl', versions=False, max_results=0, page_token=None): """Issues a request to retrieve information about an object. Args: bucket: the name of the bucket. prefix: an optional key prefix. delimiter: an optional key delimiter. projection: the projection of the objects to retrieve. versions: whether to list each version of a file as a distinct object. max_results: an optional maximum number of objects to retrieve. page_token: an optional token to continue the retrieval. Returns: A parsed list of object information dictionaries. Raises: Exception if there is an error performing the operation. """ if max_results == 0: max_results = Api._MAX_RESULTS args = {'maxResults': max_results} if prefix is not None: args['prefix'] = prefix if delimiter is not None: args['delimiter'] = delimiter if projection is not None: args['projection'] = projection if versions: args['versions'] = 'true' if page_token is not None: args['pageToken'] = page_token url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) return google.datalab.utils.Http.request(url, args=args, credentials=self._credentials) def objects_patch(self, bucket, key, info): """Updates the metadata associated with an object. Args: bucket: the name of the bucket containing the object. key: the key of the object being updated. info: the metadata to update. Returns: A parsed object information dictionary. Raises: Exception if there is an error performing the operation. """ url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, Api._escape_key(key))) return google.datalab.utils.Http.request(url, method='PATCH', data=info, credentials=self._credentials) @staticmethod def _escape_key(key): # Disable the behavior to leave '/' alone by explicitly specifying the safe parameter. return urllib.parse.quote(key, safe='') @staticmethod def verify_permitted_to_read(gs_path): """Check if the user has permissions to read from the given path. Args: gs_path: the GCS path to check if user is permitted to read. Raises: Exception if user has no permissions to read. """ # TODO(qimingj): Storage APIs need to be modified to allow absence of project # or credential on Objects. When that happens we can move the function # to Objects class. from . import _bucket bucket, prefix = _bucket.parse_name(gs_path) credentials = None if google.datalab.Context._is_signed_in(): credentials = google.datalab.Context.default().credentials args = { 'maxResults': Api._MAX_RESULTS, 'projection': 'noAcl' } if prefix is not None: args['prefix'] = prefix url = Api._ENDPOINT + (Api._OBJECT_PATH % (bucket, '')) try: google.datalab.utils.Http.request(url, args=args, credentials=credentials) except google.datalab.utils.RequestException as e: if e.status == 401: raise Exception('Not permitted to read from specified path. ' 'Please sign in and make sure you have read access.') raise e ================================================ FILE: google/datalab/storage/_bucket.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements Bucket-related Cloud Storage APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import dateutil.parser import re import google.datalab import google.datalab.utils from . import _api from . import _object # REs to match bucket names and optionally object names _BUCKET_NAME = '[a-z\d][a-z\d_\.\-]+[a-z\d]' _OBJECT_NAME = '[^\n\r]+' _STORAGE_NAME = 'gs://(' + _BUCKET_NAME + ')(/' + _OBJECT_NAME + ')?' def parse_name(name): """ Parse a gs:// URL into the bucket and object names. Args: name: a GCS URL of the form gs://bucket or gs://bucket/object Returns: The bucket name (with no gs:// prefix), and the object name if present. If the name could not be parsed returns None for both. """ bucket = None obj = None m = re.match(_STORAGE_NAME, name) if m: # We want to return the last two groups as first group is the optional 'gs://' bucket = m.group(1) obj = m.group(2) if obj is not None: obj = obj[1:] # Strip '/' else: m = re.match('(' + _OBJECT_NAME + ')', name) if m: obj = m.group(1) return bucket, obj class BucketMetadata(object): """Represents metadata about a Cloud Storage bucket.""" def __init__(self, info): """Initializes an instance of a BucketMetadata object. Args: info: a dictionary containing information about an Bucket. """ self._info = info @property def created_on(self): """The created timestamp of the bucket as a datetime.datetime.""" s = self._info.get('timeCreated', None) return dateutil.parser.parse(s) if s else None @property def etag(self): """The ETag of the bucket, if any.""" return self._info.get('etag', None) @property def name(self): """The name of the bucket.""" return self._info['name'] class Bucket(object): """Represents a Cloud Storage bucket.""" def __init__(self, name, info=None, context=None): """Initializes an instance of a Bucket object. Args: name: the name of the bucket. info: the information about the bucket if available. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._name = name self._info = info @property def name(self): """The name of the bucket.""" return self._name def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'Google Cloud Storage Bucket gs://%s' % self._name @property def metadata(self): """Retrieves metadata about the bucket. Returns: A BucketMetadata instance with information about this bucket. Raises: Exception if there was an error requesting the bucket's metadata. """ if self._info is None: try: self._info = self._api.buckets_get(self._name) except Exception as e: raise e return BucketMetadata(self._info) if self._info else None def object(self, key): """Retrieves a Storage Object for the specified key in this bucket. The object need not exist. Args: key: the key of the object within the bucket. Returns: An Object instance representing the specified key. 
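Example (a sketch with hypothetical names; assumes the default Context can
read the bucket)::

    import google.datalab.storage as storage

    bucket = storage.Bucket('my-bucket')
    obj = bucket.object('logs/2016-01-01.txt')
    if obj.exists():
      text = obj.read_stream()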
""" return _object.Object(self._name, key, context=self._context) def objects(self, prefix=None, delimiter=None): """Get an iterator for the objects within this bucket. Args: prefix: an optional prefix to match objects. delimiter: an optional string to simulate directory-like semantics. The returned objects will be those whose names do not contain the delimiter after the prefix. For the remaining objects, the names will be returned truncated after the delimiter with duplicates removed (i.e. as pseudo-directories). Returns: An iterable list of objects within this bucket. """ return _object.Objects(self._name, prefix, delimiter, context=self._context) def exists(self): """ Checks if the bucket exists. """ try: return self.metadata is not None except Exception: return False def create(self, context=None): """Creates the bucket. Args: context: the context object to use when creating the bucket. Returns: The bucket. Raises: Exception if there was an error creating the bucket. """ if not self.exists(): project_id = context.project_id if context else self._api.project_id try: self._info = self._api.buckets_insert(self._name, project_id=project_id) except Exception as e: raise e return self def delete(self): """Deletes the bucket. Raises: Exception if there was an error deleting the bucket. """ if self.exists(): try: self._api.buckets_delete(self._name) except Exception as e: raise e class Buckets(object): """Represents a list of Cloud Storage buckets for a project.""" def __init__(self, context=None): """Initializes an instance of a BucketList. Args: context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._project_id = context.project_id if context else self._api.project_id def contains(self, name): """Checks if the specified bucket exists. Args: name: the name of the bucket to lookup. Returns: True if the bucket exists; False otherwise. Raises: Exception if there was an error requesting information about the bucket. """ try: self._api.buckets_get(name) except google.datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e return True def _retrieve_buckets(self, page_token, _): try: list_info = self._api.buckets_list(page_token=page_token, project_id=self._project_id) except Exception as e: raise e buckets = list_info.get('items', []) if len(buckets): try: buckets = [Bucket(info['name'], info, context=self._context) for info in buckets] except KeyError: raise Exception('Unexpected response from server') page_token = list_info.get('nextPageToken', None) return buckets, page_token def __iter__(self): return iter(google.datalab.utils.Iterator(self._retrieve_buckets)) ================================================ FILE: google/datalab/storage/_object.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. """Implements Object-related Cloud Storage APIs.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import dateutil.parser import logging import time import google.datalab import google.datalab.utils from . import _api # TODO(nikhilko): Read/write operations don't account for larger files, or non-textual content. # Use streaming reads into a buffer or StringIO or into a file handle. # In some polling operations, we sleep between API calls to avoid hammering the # server. This argument controls how long we sleep between API calls. _POLLING_SLEEP = 1 # This argument controls how many times we'll poll before giving up. _MAX_POLL_ATTEMPTS = 30 class ObjectMetadata(object): """Represents metadata about a Cloud Storage object.""" def __init__(self, info): """Initializes an instance of a ObjectMetadata object. Args: info: a dictionary containing information about an Object. """ self._info = info @property def content_type(self): """The Content-Type associated with the object, if any.""" return self._info.get('contentType', None) @property def etag(self): """The ETag of the object, if any.""" return self._info.get('etag', None) @property def name(self): """The name of the object.""" return self._info['name'] @property def size(self): """The size (in bytes) of the object. 0 for objects that don't exist.""" return int(self._info.get('size', 0)) @property def updated_on(self): """The updated timestamp of the object as a datetime.datetime.""" s = self._info.get('updated', None) return dateutil.parser.parse(s) if s else None class Object(object): """Represents a Cloud Storage object within a bucket.""" def __init__(self, bucket, key, info=None, context=None): """Initializes an instance of an Object. Args: bucket: the name of the bucket containing the object. key: the key of the object. info: the information about the object if available. context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._bucket = bucket self._key = key self._info = info @staticmethod def from_url(url): from . import _bucket bucket, object = _bucket.parse_name(url) return Object(bucket, object) @property def key(self): """Returns the key of the object.""" return self._key @property def uri(self): """Returns the gs:// URI for the object. """ return 'gs://%s/%s' % (self._bucket, self._key) def __repr__(self): """Returns a representation for the table for showing in the notebook. """ return 'Google Cloud Storage Object %s' % self.uri def copy_to(self, new_key, bucket=None): """Copies this object to the specified new key. Args: new_key: the new key to copy this object to. bucket: the bucket of the new object; if None (the default) use the same bucket. Returns: An Object corresponding to new key. Raises: Exception if there was an error copying the object. """ if bucket is None: bucket = self._bucket try: new_info = self._api.objects_copy(self._bucket, self._key, bucket, new_key) except Exception as e: raise e return Object(bucket, new_key, new_info, context=self._context) def exists(self): """ Checks if the object exists. 
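Example (hypothetical names; assumes default credentials with write access to
the bucket)::

    import google.datalab.storage as storage

    obj = storage.Object('my-bucket', 'data/report.csv')
    if not obj.exists():
      obj.write_stream('hello world', 'text/plain')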
""" try: return self.metadata is not None except google.datalab.utils.RequestException: return False except Exception as e: raise e def delete(self, wait_for_deletion=True): """Deletes this object from its bucket. Args: wait_for_deletion: If True, we poll until this object no longer appears in objects.list operations for this bucket before returning. Raises: Exception if there was an error deleting the object. """ if self.exists(): try: self._api.objects_delete(self._bucket, self._key) except Exception as e: raise e if wait_for_deletion: for _ in range(_MAX_POLL_ATTEMPTS): objects = Objects(self._bucket, prefix=self.key, delimiter='/', context=self._context) if any(o.key == self.key for o in objects): time.sleep(_POLLING_SLEEP) continue break else: logging.error('Failed to see object deletion after %d attempts.', _MAX_POLL_ATTEMPTS) @property def metadata(self): """Retrieves metadata about the object. Returns: An ObjectMetadata instance with information about this object. Raises: Exception if there was an error requesting the object's metadata. """ if self._info is None: try: self._info = self._api.objects_get(self._bucket, self._key) except Exception as e: raise e return ObjectMetadata(self._info) if self._info else None def read_stream(self, start_offset=0, byte_count=None): """Reads the content of this object as text. Args: start_offset: the start offset of bytes to read. byte_count: the number of bytes to read. If None, it reads to the end. Returns: The text content within the object. Raises: Exception if there was an error requesting the object's content. """ try: return self._api.object_download(self._bucket, self._key, start_offset=start_offset, byte_count=byte_count) except Exception as e: raise e def download(self): """Reads the content of this object. Returns: The content within the object. Raises: Exception if there was an error requesting the object's content. """ return self.read_stream() def read_lines(self, max_lines=None): """Reads the content of this object as text, and return a list of lines up to some max. Args: max_lines: max number of lines to return. If None, return all lines. Returns: The text content of the object as a list of lines. Raises: Exception if there was an error requesting the object's content. """ if max_lines is None: return self.read_stream().split('\n') max_to_read = self.metadata.size bytes_to_read = min(100 * max_lines, self.metadata.size) while True: content = self.read_stream(byte_count=bytes_to_read) lines = content.split('\n') if len(lines) > max_lines or bytes_to_read >= max_to_read: break # try 10 times more bytes or max bytes_to_read = min(bytes_to_read * 10, max_to_read) # remove the partial line at last del lines[-1] return lines[0:max_lines] def write_stream(self, content, content_type): """Writes text content to this object. Args: content: the text content to be written. content_type: the type of text content. Raises: Exception if there was an error requesting the object's content. """ try: self._api.object_upload(self._bucket, self._key, content, content_type) except Exception as e: raise e def upload(self, content): """Uploads content to this object. Args: content: the text content to be written. Raises: Exception if there was an error requesting the object's content. """ self.write_stream(content, content_type=None) class Objects(object): """Represents a list of Cloud Storage objects within a bucket.""" def __init__(self, bucket, prefix, delimiter, context=None): """Initializes an instance of an ObjectList. 
Args: bucket: the name of the bucket containing the objects. prefix: an optional prefix to match objects. delimiter: an optional string to simulate directory-like semantics. The returned objects will be those whose names do not contain the delimiter after the prefix. For the remaining objects, the names will be returned truncated after the delimiter with duplicates removed (i.e. as pseudo-directories). context: an optional Context object providing project_id and credentials. If a specific project id or credentials are unspecified, the default ones configured at the global level are used. """ if context is None: context = google.datalab.Context.default() self._context = context self._api = _api.Api(context) self._bucket = bucket self._prefix = prefix self._delimiter = delimiter def contains(self, key): """Checks if the specified object exists. Args: key: the key of the object to lookup. Returns: True if the object exists; False otherwise. Raises: Exception if there was an error requesting information about the object. """ try: self._api.objects_get(self._bucket, key) except google.datalab.utils.RequestException as e: if e.status == 404: return False raise e except Exception as e: raise e return True def _retrieve_objects(self, page_token, _): try: list_info = self._api.objects_list(self._bucket, prefix=self._prefix, delimiter=self._delimiter, page_token=page_token) except Exception as e: raise e objects = list_info.get('items', []) if len(objects): try: objects = [Object(self._bucket, info['name'], info, context=self._context) for info in objects] except KeyError: raise Exception('Unexpected response from server') page_token = list_info.get('nextPageToken', None) return objects, page_token def __iter__(self): return iter(google.datalab.utils.Iterator(self._retrieve_objects)) ================================================ FILE: google/datalab/storage/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from . import _storage __all__ = ['_storage'] ================================================ FILE: google/datalab/storage/commands/_storage.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
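# The %gcs magics defined below are built on the google.datalab.storage
# classes shown above. A short sketch of the underlying object listing
# (hypothetical names; a default Context with credentials is assumed):
#
#   import google.datalab.storage as storage
#
#   objs = storage.Objects('my-bucket', prefix='logs/', delimiter='/')
#   keys = [o.key for o in objs]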
"""Google Cloud Platform library - BigQuery IPython Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import fnmatch import json import re import google.datalab.storage import google.datalab.utils.commands def _extract_gcs_api_response_error(message): """ A helper function to extract user-friendly error messages from service exceptions. Args: message: An error message from an exception. If this is from our HTTP client code, it will actually be a tuple. Returns: A modified version of the message that is less cryptic. """ try: if len(message) == 3: # Try treat the last part as JSON data = json.loads(message[2]) return data['error']['errors'][0]['message'] except Exception: pass return message @IPython.core.magic.register_line_cell_magic def gcs(line, cell=None): """Implements the gcs cell magic for ipython notebooks. Args: line: the contents of the gcs line. Returns: The results of executing the cell. """ parser = google.datalab.utils.commands.CommandParser(prog='%gcs', description=""" Execute various Google Cloud Storage related operations. Use "%gcs -h" for help on a specific command. """) # TODO(gram): consider adding a move command too. I did try this already using the # objects.patch API to change the object name but that fails with an error: # # Value 'newname' in content does not agree with value 'oldname'. This can happen when a value # set through a parameter is inconsistent with a value set in the request. # # This is despite 'name' being identified as writable in the storage API docs. # The alternative would be to use a copy/delete. copy_parser = parser.subcommand('copy', 'Copy one or more Google Cloud Storage objects to a ' 'different location.') copy_parser.add_argument('-s', '--source', help='The name of the object(s) to copy', nargs='+') copy_parser.add_argument('-d', '--destination', required=True, help='The copy destination. 
For multiple source objects this must be a ' 'bucket.') copy_parser.set_defaults(func=_gcs_copy) create_parser = parser.subcommand('create', 'Create one or more Google Cloud Storage buckets.') create_parser.add_argument('-p', '--project', help='The project associated with the objects') create_parser.add_argument('-b', '--bucket', help='The name of the bucket(s) to create', nargs='+') create_parser.set_defaults(func=_gcs_create) delete_parser = parser.subcommand('delete', 'Delete one or more Google Cloud Storage buckets or ' 'objects.') delete_parser.add_argument('-b', '--bucket', nargs='*', help='The name of the bucket(s) to remove') delete_parser.add_argument('-o', '--object', nargs='*', help='The name of the object(s) to remove') delete_parser.set_defaults(func=_gcs_delete) list_parser = parser.subcommand('list', 'List buckets in a project, or contents of a bucket.') list_parser.add_argument('-p', '--project', help='The project associated with the objects') list_parser.add_argument('-o', '--objects', help='List objects under the given Google Cloud Storage path', nargs='?') list_parser.set_defaults(func=_gcs_list) read_parser = parser.subcommand('read', 'Read the contents of a Google Cloud Storage object into ' 'a Python variable.') read_parser.add_argument('-o', '--object', help='The name of the object to read', required=True) read_parser.add_argument('-v', '--variable', required=True, help='The name of the Python variable to set') read_parser.set_defaults(func=_gcs_read) view_parser = parser.subcommand('view', 'View the contents of a Google Cloud Storage object.') view_parser.add_argument('-n', '--head', type=int, default=20, help='The number of initial lines to view') view_parser.add_argument('-t', '--tail', type=int, default=20, help='The number of lines from end to view') view_parser.add_argument('-o', '--object', help='The name of the object to view', required=True) view_parser.set_defaults(func=_gcs_view) write_parser = parser.subcommand('write', 'Write the value of a Python variable to a Google ' 'Cloud Storage object.') write_parser.add_argument('-v', '--variable', help='The name of the source Python variable', required=True) write_parser.add_argument('-o', '--object', required=True, help='The name of the destination Google Cloud Storage object to write') write_parser.add_argument('-c', '--content_type', help='MIME type', default='text/plain') write_parser.set_defaults(func=_gcs_write) return google.datalab.utils.commands.handle_magic_line(line, cell, parser) def _parser_exit(status=0, message=None): """ Replacement exit method for argument parser. We want to stop processing args but not call sys.exit(), so we raise an exception here and catch it in the call to parse_args. """ raise Exception() def _expand_list(names): """ Do a wildchar name expansion of object names in a list and return expanded list. The objects are expected to exist as this is used for copy sources or delete targets. Currently we support wildchars in the key name only. """ if names is None: names = [] elif isinstance(names, basestring): names = [names] results = [] # The expanded list. objects = {} # Cached contents of buckets; used for matching. for name in names: bucket, key = google.datalab.storage._bucket.parse_name(name) results_len = len(results) # If we fail to add any we add name and let caller deal with it. if bucket: if not key: # Just a bucket; add it. 
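# (A bare bucket reference such as 'gs://my-bucket' (hypothetical name) is not
# expanded further; wildcard expansion below applies to object keys only.)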
results.append('gs://%s' % bucket) elif google.datalab.storage.Object(bucket, key).exists(): results.append('gs://%s/%s' % (bucket, key)) else: # Expand possible key values. if bucket not in objects and key[:1] == '*': # We need the full list; cache a copy for efficiency. objects[bucket] = [obj.metadata.name for obj in list(google.datalab.storage.Bucket(bucket).objects())] # If we have a cached copy use it if bucket in objects: candidates = objects[bucket] # else we have no cached copy but can use prefix matching which is more efficient than # getting the full contents. else: # Get the non-wildchar prefix. match = re.search('\?|\*|\[', key) prefix = key if match: prefix = key[0:match.start()] candidates = [obj.metadata.name for obj in google.datalab.storage.Bucket(bucket).objects(prefix=prefix)] for obj in candidates: if fnmatch.fnmatch(obj, key): results.append('gs://%s/%s' % (bucket, obj)) # If we added no matches, add the original name and let caller deal with it. if len(results) == results_len: results.append(name) return results def _gcs_copy(args, _): target = args['destination'] target_bucket, target_key = google.datalab.storage._bucket.parse_name(target) if target_bucket is None and target_key is None: raise Exception('Invalid copy target name %s' % target) sources = _expand_list(args['source']) if len(sources) > 1: # Multiple sources; target must be a bucket if target_bucket is None or target_key is not None: raise Exception('More than one source but target %s is not a bucket' % target) errs = [] for source in sources: source_bucket, source_key = google.datalab.storage._bucket.parse_name(source) if source_bucket is None or source_key is None: raise Exception('Invalid source object name %s' % source) destination_bucket = target_bucket if target_bucket else source_bucket destination_key = target_key if target_key else source_key try: google.datalab.storage.Object(source_bucket, source_key).copy_to(destination_key, bucket=destination_bucket) except Exception as e: errs.append("Couldn't copy %s to %s: %s" % (source, target, _extract_gcs_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _gcs_create(args, _): """ Create one or more buckets. """ errs = [] for name in args['bucket']: try: bucket, key = google.datalab.storage._bucket.parse_name(name) if bucket and not key: google.datalab.storage.Bucket(bucket).create(_make_context(args['project'])) else: raise Exception("Invalid bucket name %s" % name) except Exception as e: errs.append("Couldn't create %s: %s" % (name, _extract_gcs_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _gcs_delete(args, _): """ Delete one or more buckets or objects. 
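Example (notebook usage; the bucket and object names are hypothetical)::

    %gcs delete -b gs://my-old-bucket -o gs://my-bucket/stale.csv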
""" objects = _expand_list(args['bucket']) objects.extend(_expand_list(args['object'])) errs = [] for obj in objects: try: bucket, key = google.datalab.storage._bucket.parse_name(obj) if bucket and key: gcs_object = google.datalab.storage.Object(bucket, key) if gcs_object.exists(): google.datalab.storage.Object(bucket, key).delete() else: errs.append("%s does not exist" % obj) elif bucket: gcs_bucket = google.datalab.storage.Bucket(bucket) if gcs_bucket.exists(): gcs_bucket.delete() else: errs.append("%s does not exist" % obj) else: raise Exception("Can't delete object with invalid name %s" % obj) except Exception as e: errs.append("Couldn't delete %s: %s" % (obj, _extract_gcs_api_response_error(str(e)))) if errs: raise Exception('\n'.join(errs)) def _make_context(project_id=None): default_context = google.datalab.Context.default() project_id = project_id or default_context.project_id return google.datalab.Context(project_id, default_context.credentials) def _gcs_list_buckets(project, pattern): """ List all Google Cloud Storage buckets that match a pattern. """ data = [{'Bucket': 'gs://' + bucket.name, 'Created': bucket.metadata.created_on} for bucket in google.datalab.storage.Buckets(_make_context(project)) if fnmatch.fnmatch(bucket.name, pattern)] return google.datalab.utils.commands.render_dictionary(data, ['Bucket', 'Created']) def _gcs_get_keys(bucket, pattern): """ Get names of all Google Cloud Storage keys in a specified bucket that match a pattern. """ return [obj for obj in list(bucket.objects()) if fnmatch.fnmatch(obj.metadata.name, pattern)] def _gcs_get_key_names(bucket, pattern): """ Get names of all Google Cloud Storage keys in a specified bucket that match a pattern. """ return [obj.metadata.name for obj in _gcs_get_keys(bucket, pattern)] def _gcs_list_keys(bucket, pattern): """ List all Google Cloud Storage keys in a specified bucket that match a pattern. """ data = [{'Name': obj.metadata.name, 'Type': obj.metadata.content_type, 'Size': obj.metadata.size, 'Updated': obj.metadata.updated_on} for obj in _gcs_get_keys(bucket, pattern)] return google.datalab.utils.commands.render_dictionary(data, ['Name', 'Type', 'Size', 'Updated']) def _gcs_list(args, _): """ List the buckets or the contents of a bucket. This command is a bit different in that we allow wildchars in the bucket name and will list the buckets that match. """ target = args['objects'] project = args['project'] if target is None: return _gcs_list_buckets(project, '*') # List all buckets. bucket_name, key = google.datalab.storage._bucket.parse_name(target) if bucket_name is None: raise Exception('Cannot list %s; not a valid bucket name' % target) # If a target was specified, list keys inside it if target: if not re.search('\?|\*|\[', target): # If no wild characters are present in the key string, append a '/*' suffix to show all keys key = key.strip('/') + '/*' if key else '*' if project: # Only list if the bucket is in the project for bucket in google.datalab.storage.Buckets(_make_context(project)): if bucket.name == bucket_name: break else: raise Exception('%s does not exist in project %s' % (target, project)) else: bucket = google.datalab.storage.Bucket(bucket_name) if bucket.exists(): return _gcs_list_keys(bucket, key) else: raise Exception('Bucket %s does not exist' % target) else: # Treat the bucket name as a pattern and show matches. We don't use bucket_name as that # can strip off wildchars and so we need to strip off gs:// here. 
return _gcs_list_buckets(project, target.strip('/')[5:]) def _get_object_contents(source_name): source_bucket, source_key = google.datalab.storage._bucket.parse_name(source_name) if source_bucket is None: raise Exception('Invalid source object name %s; no bucket specified.' % source_name) if source_key is None: raise Exception('Invalid source object name %si; source cannot be a bucket.' % source_name) source = google.datalab.storage.Object(source_bucket, source_key) if not source.exists(): raise Exception('Source object %s does not exist' % source_name) return source.download() def _gcs_read(args, _): contents = _get_object_contents(args['object']) ipy = IPython.get_ipython() ipy.push({args['variable']: contents}) def _gcs_view(args, _): contents = _get_object_contents(args['object']) if not isinstance(contents, basestring): contents = str(contents) elif isinstance(contents, bytes): contents = str(contents, encoding='UTF-8') lines = contents.splitlines() head_count = args['head'] tail_count = args['tail'] if len(lines) > head_count + tail_count: head = '\n'.join(lines[:head_count]) tail = '\n'.join(lines[-tail_count:]) return head + '\n...\n' + tail else: return contents def _gcs_write(args, _): target_name = args['object'] target_bucket, target_key = google.datalab.storage._bucket.parse_name(target_name) if target_bucket is None or target_key is None: raise Exception('Invalid target object name %s' % target_name) target = google.datalab.storage.Object(target_bucket, target_key) ipy = IPython.get_ipython() contents = ipy.user_ns[args['variable']] # TODO(gram): would we want to to do any special handling here; e.g. for DataFrames? target.write_stream(str(contents), args['content_type']) ================================================ FILE: google/datalab/utils/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Internal Helpers.""" from ._async import async_, async_function, async_method from ._http import Http, RequestException from ._iterator import Iterator from ._json_encoder import JSONEncoder from ._lru_cache import LRUCache from ._lambda_job import LambdaJob from ._dataflow_job import DataflowJob from ._utils import print_exception_with_last_stack, get_item, compare_datetimes, \ pick_unused_port, is_http_running_on, gcs_copy_file, python_portable_string __all__ = ['async_', 'async_function', 'async_method', 'Http', 'RequestException', 'Iterator', 'JSONEncoder', 'LRUCache', 'LambdaJob', 'DataflowJob', 'print_exception_with_last_stack', 'get_item', 'compare_datetimes', 'pick_unused_port', 'is_http_running_on', 'gcs_copy_file', 'python_portable_string'] ================================================ FILE: google/datalab/utils/_async.py ================================================ # Copyright 2015 Google Inc. All rights reserved. 
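# Illustrative usage sketch (assumed example, not part of the original package
# __init__): the re-exported helpers above can be imported directly from
# google.datalab.utils. The keys and values below are made up.
import datetime

from google.datalab.utils import LRUCache, compare_datetimes

cache = LRUCache(2)     # capacity of two entries
cache['a'] = 1
cache['b'] = 2
cache['c'] = 3          # displaces 'a', the least recently used key
print('a' in cache, 'c' in cache)   # -> False True

# Naive datetimes are treated as UTC; -1 means the first argument is earlier.
print(compare_datetimes(datetime.datetime(2017, 1, 1), datetime.datetime(2017, 1, 2)))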
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Decorators for async methods and functions to dispatch on threads and support chained calls.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object import abc import concurrent.futures import functools from google.datalab._job import Job from future.utils import with_metaclass class async_(with_metaclass(abc.ABCMeta, object)): """ Base class for async_function/async_method. Creates a wrapped function/method that will run the original function/method on a thread pool worker thread and return a Job instance for monitoring the status of the thread. """ executor = concurrent.futures.ThreadPoolExecutor(max_workers=50) # Pool for doing the work. def __init__(self, function): self._function = function # Make the wrapper get attributes like docstring from wrapped method. functools.update_wrapper(self, function) @staticmethod def _preprocess_args(*args): # Pre-process arguments - if any are themselves Futures block until they can be resolved. return [arg.result() if isinstance(arg, concurrent.futures.Future) else arg for arg in args] @staticmethod def _preprocess_kwargs(**kwargs): # Pre-process keyword arguments - if any are Futures block until they can be resolved. return {kw: (arg.result() if isinstance(arg, concurrent.futures.Future) else arg) for kw, arg in list(kwargs.items())} @abc.abstractmethod def _call(self, *args, **kwargs): return def __call__(self, *args, **kwargs): # Queue the call up in the thread pool. return Job(future=self.executor.submit(self._call, *args, **kwargs)) class async_function(async_): """ This decorator can be applied to any static function that makes blocking calls to create a modified version that creates a Job and returns immediately; the original method will be called on a thread pool worker thread. """ def _call(self, *args, **kwargs): # Call the wrapped method. return self._function(*async_._preprocess_args(*args), **async_._preprocess_kwargs(**kwargs)) class async_method(async_): """ This decorator can be applied to any class instance method that makes blocking calls to create a modified version that creates a Job and returns immediately; the original method will be called on a thread pool worker thread. """ def _call(self, *args, **kwargs): # Call the wrapped method. return self._function(self.obj, *async_._preprocess_args(*args), **async_._preprocess_kwargs(**kwargs)) def __get__(self, instance, owner): # This is important for attribute inheritance and setting self.obj so it can be # passed as first argument to wrapped method. self.cls = owner self.obj = instance return self ================================================ FILE: google/datalab/utils/_dataflow_job.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
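# Illustrative sketch (assumed example, not part of the original module): the
# async_function decorator from _async.py above wraps a blocking call so it runs
# on the shared thread pool and immediately returns a google.datalab Job.
# "slow_add" is a made-up function; is_complete and fatal_error are the Job
# accessors also used by the job-status magic later in this package.
import time

from google.datalab.utils import async_function


@async_function
def slow_add(x, y):
  time.sleep(1)
  return x + y


job = slow_add(2, 3)        # returns a Job immediately; the work runs on a worker thread
while not job.is_complete:  # poll until the wrapped call finishes
  time.sleep(0.1)
print(job.fatal_error)      # None if slow_add raised no exception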
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements DataFlow Job functionality.""" from google.datalab import _job class DataflowJob(_job.Job): """Represents a DataFlow Job. """ def __init__(self, runner_results): """Initializes an instance of a DataFlow Job. Args: runner_results: a DataflowPipelineResult returned from Pipeline.run(). """ super(DataflowJob, self).__init__(runner_results._job.name) self._runner_results = runner_results def _refresh_state(self): """ Refresh the job info. """ # DataFlow's DataflowPipelineResult does not refresh state, so we have to do it ourselves # as a workaround. # TODO(Change this to use runner_results.state once it refreshes itself) dataflow_internal_job = ( self._runner_results._runner.dataflow_client.get_job(self._runner_results.job_id())) self._is_complete = str(dataflow_internal_job.currentState) in ['JOB_STATE_STOPPED', 'JOB_STATE_DONE', 'JOB_STATE_FAILED', 'JOB_STATE_CANCELLED'] self._fatal_error = getattr(self._runner_results._runner, 'last_error_msg', None) # Sometimes Dataflow does not populate runner.last_error_msg even if the job fails. if self._fatal_error is None and self._runner_results.state == 'FAILED': self._fatal_error = 'FAILED' ================================================ FILE: google/datalab/utils/_gcp_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements GCP Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals import google.datalab from google.datalab import _job class GCPJob(_job.Job): """Represents a BigQuery Job. """ def __init__(self, job_id, context): """Initializes an instance of a Job. Args: job_id: the BigQuery job ID corresponding to this job. context: a Context object providing project_id and credentials. """ super(GCPJob, self).__init__(job_id) if context is None: context = google.datalab.Context.default() self._context = context self._api = self._create_api(context) def _create_api(self, context): raise Exception('_create_api must be defined in a derived class') def __repr__(self): """Returns a representation for the job for showing in the notebook. """ return 'Job %s/%s %s' % (self._context.project_id, self._job_id, self.state) ================================================ FILE: google/datalab/utils/_http.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
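# Illustrative sketch (hypothetical names): GCPJob from _gcp_job.py above leaves
# _create_api abstract, so a concrete job class supplies its own API client.
# "SampleApi"/"SampleJob" are stand-ins for something like a BigQuery job class.
from google.datalab.utils._gcp_job import GCPJob


class SampleApi(object):
  def __init__(self, context):
    self._context = context  # a real wrapper would build its service client here


class SampleJob(GCPJob):
  def _create_api(self, context):
    return SampleApi(context)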
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements HTTP client helper functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from future import standard_library standard_library.install_aliases() # noqa from builtins import str from past.builtins import basestring from builtins import object import copy import datetime import json import urllib.request import urllib.parse import urllib.error import httplib2 import google_auth_httplib2 import logging log = logging.getLogger(__name__) # TODO(nikhilko): Start using the requests library instead. class RequestException(Exception): def __init__(self, status, content): self.status = status self.content = content self.message = 'HTTP request failed' # Try extract a message from the body; swallow possible resulting ValueErrors and KeyErrors. try: if isinstance(content, str): error = json.loads(content)['error'] else: error = json.loads(str(content, encoding='UTF-8'))['error'] if 'errors' in error: error = error['errors'][0] self.message += ': {}'.format(error['message']) except Exception: lines = content.splitlines() if isinstance(content, basestring) else [] if lines: self.message += ': {}'.format(lines[0]) def __str__(self): return self.message class Http(object): """A helper class for making HTTP requests. """ # Reuse one Http object across requests to take advantage of Keep-Alive, e.g. # for BigQuery queries that requires at least ~5 sequential http requests. # # TODO(nikhilko): # SSL cert validation seemingly fails, and workarounds are not amenable # to implementing in library code. So configure the Http object to skip # doing so, in the interim. http = httplib2.Http() http.disable_ssl_certificate_validation = True def __init__(self): pass @staticmethod def request(url, args=None, data=None, headers=None, method=None, credentials=None, raw_response=False, stats=None): """Issues HTTP requests. Args: url: the URL to request. args: optional query string arguments. data: optional data to be sent within the request. headers: optional headers to include in the request. method: optional HTTP method to use. If unspecified this is inferred (GET or POST) based on the existence of request data. credentials: optional set of credentials to authorize the request. raw_response: whether the raw response content should be returned as-is. stats: an optional dictionary that, if provided, will be populated with some useful info about the request, like 'duration' in seconds and 'data_size' in bytes. These may be useful optimizing the access to rate-limited APIs. Returns: The parsed response object. Raises: Exception when the HTTP request fails or the response cannot be processed. """ if headers is None: headers = {} headers['user-agent'] = 'GoogleCloudDataLab/1.0' # Add querystring to the URL if there are any arguments. if args is not None: qs = urllib.parse.urlencode(args) url = url + '?' + qs # Setup method to POST if unspecified, and appropriate request headers # if there is data to be sent within the request. if data is not None: if method is None: method = 'POST' if data != '': # If there is a content type specified, use it (and the data) as-is. 
# Otherwise, assume JSON, and serialize the data object. if 'Content-Type' not in headers: data = json.dumps(data) headers['Content-Type'] = 'application/json' headers['Content-Length'] = str(len(data)) else: if method == 'POST': headers['Content-Length'] = '0' # If the method is still unset, i.e. it was unspecified, and there # was no data to be POSTed, then default to GET request. if method is None: method = 'GET' http = Http.http # Authorize with credentials if given if credentials is not None: # Make a copy of the shared http instance before we modify it. http = copy.copy(http) http = google_auth_httplib2.AuthorizedHttp(credentials) if stats is not None: stats['duration'] = datetime.datetime.utcnow() response = None try: log.debug('request: method[%(method)s], url[%(url)s], body[%(data)s]' % locals()) response, content = http.request(url, method=method, body=data, headers=headers) if 200 <= response.status < 300: if raw_response: return content if type(content) == str: return json.loads(content) else: return json.loads(str(content, encoding='UTF-8')) else: raise RequestException(response.status, content) except ValueError: raise Exception('Failed to process HTTP response.') except httplib2.HttpLib2Error: raise Exception('Failed to send HTTP request.') finally: if stats is not None: stats['data_size'] = len(data) stats['status'] = response.status stats['duration'] = (datetime.datetime.utcnow() - stats['duration']).total_seconds() ================================================ FILE: google/datalab/utils/_iterator.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Iterator class for iterable cloud lists.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import object class Iterator(object): """An iterator implementation that handles paging over a cloud list.""" def __init__(self, retriever): """Initializes an instance of an Iterator. Args: retriever: a function that can retrieve the next page of items. """ self._page_token = None self._first_page = True self._retriever = retriever self._count = 0 def __iter__(self): """Provides iterator functionality.""" while self._first_page or (self._page_token is not None): items, next_page_token = self._retriever(self._page_token, self._count) self._page_token = next_page_token self._first_page = False self._count += len(items) for item in items: yield item def reset(self): """Resets the current iteration.""" self._page_token = None self._first_page = True self._count = 0 ================================================ FILE: google/datalab/utils/_json_encoder.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
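# Illustrative sketch (made-up page data): the Iterator class from _iterator.py
# above drives a paged cloud listing through a retriever callback that receives
# (page_token, count) and returns (items, next_page_token); a None token ends
# the iteration.
from google.datalab.utils import Iterator

_pages = {None: ([1, 2], 'page-2'), 'page-2': ([3, 4], None)}


def _retriever(page_token, _count):
  return _pages[page_token]


print(list(Iterator(_retriever)))   # -> [1, 2, 3, 4]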
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """ JSON encoder that can handle Python datetime objects. """ from __future__ import absolute_import from __future__ import unicode_literals import datetime import json class JSONEncoder(json.JSONEncoder): """ A JSON encoder that can handle Python datetime objects. """ def default(self, obj): if isinstance(obj, datetime.date) or isinstance(obj, datetime.datetime): return obj.isoformat() elif isinstance(obj, datetime.timedelta): return (datetime.datetime.min + obj).time().isoformat() else: return super(JSONEncoder, self).default(obj) ================================================ FILE: google/datalab/utils/_lambda_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements OS shell Job functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from . import _async from google.datalab import _job class LambdaJob(_job.Job): """Represents an lambda function as a Job. """ def __init__(self, fn, job_id, *args, **kwargs): """Initializes an instance of a Job. Args: fn: the lambda function to execute asyncronously job_id: an optional ID for the job. If None, a UUID will be generated. """ super(LambdaJob, self).__init__(job_id) self._future = _async.async_.executor.submit(fn, *args, **kwargs) def __repr__(self): """Returns a representation for the job for showing in the notebook. """ return 'Job %s %s' % (self._job_id, self.state) # TODO: ShellJob, once we need it, should inherit on LambdaJob: # import subprocess # LambdaJob(subprocess.check_output, id, command_line, shell=True) ================================================ FILE: google/datalab/utils/_lru_cache.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
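# Illustrative sketch (sample payload is made up): the JSONEncoder from
# _json_encoder.py above lets json.dumps handle datetime values (ISO-8601) and
# timedeltas (time of day) that the standard encoder rejects.
import datetime
import json

from google.datalab.utils import JSONEncoder

payload = {'when': datetime.datetime(2017, 5, 1, 12, 30),
           'elapsed': datetime.timedelta(minutes=5)}
print(json.dumps(payload, cls=JSONEncoder))
# -> {"when": "2017-05-01T12:30:00", "elapsed": "00:05:00"}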
"""A simple LRU cache.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from past.builtins import basestring from builtins import object import datetime class LRUCache(object): """A simple LRU cache.""" def __init__(self, cache_size): """ Initialize the cache with the given size. Args: cache_size: the maximum number of items the cache can hold. Attempts to add more items than this will result in the least recently used items being displaced to make room. """ self._cache = {} self._cache_size = cache_size def __getitem__(self, key): """ Get an item from the cache. Args: key: a string used as the lookup key. Returns: The cached item, if any. Raises: Exception if the key is not a string. KeyError if the key is not found. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings (%s has type %s)" % (str(key), str(type(key)))) if key in self._cache: entry = self._cache[key] entry['last_used'] = datetime.datetime.now() return entry['value'] else: raise KeyError(key) def __delitem__(self, key): """ Remove an item from the cache. Args: key: a string key for retrieving the item. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings") del self._cache[key] def __setitem__(self, key, value): """ Put an item in the cache. Args: key: a string key for retrieving the item. value: the item to cache. Raises: Exception if the key is not a string. """ if not isinstance(key, basestring): raise Exception("LRU cache can only be indexed by strings") if key in self._cache: entry = self._cache[key] elif len(self._cache) < self._cache_size: # Cache is not full; append an new entry self._cache[key] = entry = {} else: # Cache is full; displace an entry entry = min(list(self._cache.values()), key=lambda x: x['last_used']) self._cache.pop(entry['key']) self._cache[key] = entry entry['value'] = value entry['key'] = key entry['last_used'] = datetime.datetime.now() def __contains__(self, key): return key in self._cache def get(self, key, value): if key in self._cache: return self._cache[key]['value'] return value ================================================ FILE: google/datalab/utils/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Miscellaneous simple utility functions.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals from builtins import str try: import http.client as httplib except ImportError: import httplib import json import logging import os import pytz import six import subprocess import socket import traceback import types import oauth2client.client import google.auth import google.auth.exceptions import google.auth.credentials import google.auth._oauth2client def print_exception_with_last_stack(e): """ Print the call stack of the last exception plu sprint the passed exception. Args: e: the exception to print. 
""" traceback.print_exc() print(str(e)) def get_item(env, name, default=None): """ Get an item from a dictionary, handling nested lookups with dotted notation. Args: env: the environment (dictionary) to use to look up the name. name: the name to look up, in dotted notation. default: the value to return if the name if not found. Returns: The result of looking up the name, if found; else the default. """ # TODO: handle attributes if not name: return default for key in name.split('.'): if isinstance(env, dict) and key in env: env = env[key] elif isinstance(env, types.ModuleType) and key in env.__dict__: env = env.__dict__[key] else: return default return env def compare_datetimes(d1, d2): """ Compares two datetimes safely, whether they are timezone-naive or timezone-aware. If either datetime is naive it is converted to an aware datetime assuming UTC. Args: d1: first datetime. d2: second datetime. Returns: -1 if d1 < d2, 0 if they are the same, or +1 is d1 > d2. """ if d1.tzinfo is None or d1.tzinfo.utcoffset(d1) is None: d1 = d1.replace(tzinfo=pytz.UTC) if d2.tzinfo is None or d2.tzinfo.utcoffset(d2) is None: d2 = d2.replace(tzinfo=pytz.UTC) if d1 < d2: return -1 elif d1 > d2: return 1 return 0 def pick_unused_port(): """ get an unused port on the VM. Returns: An unused port. """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(('localhost', 0)) addr, port = s.getsockname() s.close() return port def is_http_running_on(port): """ Check if an http server runs on a given port. Args: The port to check. Returns: True if it is used by an http server. False otherwise. """ try: conn = httplib.HTTPConnection('127.0.0.1:' + str(port)) conn.connect() conn.close() return True except Exception: return False def gcs_copy_file(source, dest): """ Copy file from source to destination. The paths can be GCS or local. Args: source: the source file path. dest: the destination file path. """ subprocess.check_call(['gsutil', '-q', 'cp', source, dest]) """ Support for getting gcloud credentials. """ # TODO(ojarjur): This limits the APIs against which Datalab can be called # (when using a service account with a credentials file) to only being those # that are part of the Google Cloud Platform. We should either extend this # to all of the API scopes that Google supports, or make it extensible so # that the user can define for themselves which scopes they want to use. CREDENTIAL_SCOPES = [ 'https://www.googleapis.com/auth/cloud-platform', ] def _in_datalab_docker(): return os.path.exists('/datalab') and os.getenv('DATALAB_ENV') def get_config_dir(): config_dir = os.getenv('CLOUDSDK_CONFIG') if config_dir is None: if os.name == 'nt': try: config_dir = os.path.join(os.environ['APPDATA'], 'gcloud') except KeyError: # This should never happen unless someone is really messing with things. drive = os.environ.get('SystemDrive', 'C:') config_dir = os.path.join(drive, '\\gcloud') else: config_dir = os.path.join(os.path.expanduser('~'), '.config/gcloud') return config_dir def _convert_oauth2client_creds(credentials): new_credentials = google.oauth2.credentials.Credentials( token=credentials.access_token, refresh_token=credentials.refresh_token, token_uri=credentials.token_uri, client_id=credentials.client_id, client_secret=credentials.client_secret, scopes=credentials.scopes) new_credentials._expires = credentials.token_expiry return new_credentials def get_credentials(): """ Get the credentials to use. We try application credentials first, followed by user credentials. 
The path to the application credentials can be overridden by pointing the GOOGLE_APPLICATION_CREDENTIALS environment variable to some file; the path to the user credentials can be overridden by pointing the CLOUDSDK_CONFIG environment variable to some directory (after which we will look for the file $CLOUDSDK_CONFIG/gcloud/credentials). Unless you have specific reasons for overriding these the defaults should suffice. """ try: # We temporarily disable warning logs from the "_default" module to avoid # a spurious warning about the project not being set. authDefaultLogger = logging.getLogger("google.auth._default") previousLevel = authDefaultLogger.getEffectiveLevel() authDefaultLogger.setLevel(logging.ERROR) credentials, _ = google.auth.default() credentials = google.auth.credentials.with_scopes_if_required(credentials, CREDENTIAL_SCOPES) authDefaultLogger.setLevel(previousLevel) return credentials except Exception as e: # Try load user creds from file cred_file = get_config_dir() + '/credentials' if os.path.exists(cred_file): with open(cred_file) as f: creds = json.loads(f.read()) # Use the first gcloud one we find for entry in creds['data']: if entry['key']['type'] == 'google-cloud-sdk': creds = oauth2client.client.OAuth2Credentials.from_json(json.dumps(entry['credential'])) return _convert_oauth2client_creds(creds) if type(e) == google.auth.exceptions.DefaultCredentialsError: # If we are in Datalab container, change the message to be about signing in. if _in_datalab_docker(): raise Exception('No application credentials found. Perhaps you should sign in.') raise e def save_project_id(project_id): """ Save project id to config file. Args: project_id: the project_id to save. """ # Try gcloud first. If gcloud fails (probably because it does not exist), then # write to a config file. try: subprocess.call(['gcloud', 'config', 'set', 'project', project_id]) except: config_file = os.path.join(get_config_dir(), 'config.json') config = {} if os.path.exists(config_file): with open(config_file) as f: config = json.loads(f.read()) config['project_id'] = project_id with open(config_file, 'w') as f: f.write(json.dumps(config)) def get_default_project_id(): """ Get default project id from config or environment var. Returns: the project id if available, or None. """ # Try getting default project id from gcloud. If it fails try config.json. try: proc = subprocess.Popen(['gcloud', 'config', 'list', '--format', 'value(core.project)'], stdout=subprocess.PIPE) stdout, _ = proc.communicate() value = stdout.strip() if proc.poll() == 0 and value: if isinstance(value, six.string_types): return value else: # Hope it's a utf-8 string encoded in bytes. Otherwise an exception will # be thrown and config.json will be checked. return value.decode() except: pass config_file = os.path.join(get_config_dir(), 'config.json') if os.path.exists(config_file): with open(config_file) as f: config = json.loads(f.read()) if 'project_id' in config and config['project_id']: return str(config['project_id']) if os.getenv('PROJECT_ID') is not None: return os.getenv('PROJECT_ID') return None def _construct_context_for_args(args): """Construct a new Context for the parsed arguments. Args: args: the dictionary of magic arguments. Returns: A new Context based on the current default context, but with any explicitly specified arguments overriding the default's config. 
""" global_default_context = google.datalab.Context.default() config = {} for key in global_default_context.config: config[key] = global_default_context.config[key] billing_tier_arg = args.get('billing', None) if billing_tier_arg: config['bigquery_billing_tier'] = billing_tier_arg return google.datalab.Context( project_id=global_default_context.project_id, credentials=global_default_context.credentials, config=config) def python_portable_string(string, encoding='utf-8'): """Converts bytes into a string type. Valid string types are retuned without modification. So in Python 2, type str and unicode are not converted. In Python 3, type bytes is converted to type str (unicode) """ if isinstance(string, six.string_types): return string if six.PY3: return string.decode(encoding) raise ValueError('Unsupported type %s' % str(type(string))) ================================================ FILE: google/datalab/utils/commands/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # flake8: noqa from __future__ import absolute_import from __future__ import unicode_literals # Support functions for magics and display help. from ._commands import CommandParser from ._html import Html, HtmlBuilder from ._utils import * # Magics from . import _chart from . import _chart_data from . import _csv from . import _job ================================================ FILE: google/datalab/utils/commands/_chart.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - Chart cell magic.""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def chart(line, cell=None): """ Generate charts with Google Charts. Use %chart --help for more details. """ parser = _commands.CommandParser(prog='%chart', description=""" Generate an inline chart using Google Charts using the data in a Table, Query, dataframe, or list. Numerous types of charts are supported. Options for the charts can be specified in the cell body using YAML or JSON. 
""") for chart_type in ['annotation', 'area', 'bars', 'bubbles', 'calendar', 'candlestick', 'columns', 'combo', 'gauge', 'geo', 'heatmap', 'histogram', 'line', 'map', 'org', 'paged_table', 'pie', 'sankey', 'scatter', 'stepped_area', 'table', 'timeline', 'treemap']: subparser = parser.subcommand(chart_type, 'Generate a %s chart.' % chart_type) subparser.add_argument('-f', '--fields', help='The field(s) to include in the chart') subparser.add_argument('-d', '--data', help='The name of the variable referencing the Table or Query to chart', required=True) subparser.set_defaults(chart=chart_type) parser.set_defaults(func=_chart_cell) return _utils.handle_magic_line(line, cell, parser) def _chart_cell(args, cell): source = args['data'] ipy = IPython.get_ipython() chart_options = _utils.parse_config(cell, ipy.user_ns) if chart_options is None: chart_options = {} elif not isinstance(chart_options, dict): raise Exception("Could not parse chart options") chart_type = args['chart'] fields = args['fields'] if args['fields'] else '*' return IPython.core.display.HTML(_utils.chart_html('gcharts', chart_type, source=source, chart_options=chart_options, fields=fields)) ================================================ FILE: google/datalab/utils/commands/_chart_data.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - chart_data cell magic.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals try: import IPython import IPython.core.display import IPython.core.magic except ImportError: raise Exception('This module can only be loaded in ipython.') import json import datalab.utils.commands import google.datalab.data import google.datalab.utils from . import _utils @IPython.core.magic.register_cell_magic def _get_chart_data(line, cell_body=''): refresh = 0 options = {} try: metadata = json.loads(cell_body) if cell_body else {} source_index = metadata.get('source_index', None) fields = metadata.get('fields', '*') first_row = int(metadata.get('first', 0)) count = int(metadata.get('count', -1)) legacy = metadata.get('legacy', None) # Both legacy and non-legacy table viewer calls this magic for new pages of data. # Need to find their own data source --- one under datalab.utils.commands._utils # and the other under google.datalab.utils.commands._utils. if legacy is not None: data_source = datalab.utils.commands._utils._data_sources else: data_source = _utils._data_sources source_index = int(source_index) if source_index >= len(data_source): # Can happen after e.g. kernel restart # TODO(gram): get kernel restart events in charting.js and disable any refresh timers. 
print('No source %d' % source_index) return IPython.core.display.JSON({'data': {}}) source = data_source[source_index] schema = None controls = metadata['controls'] if 'controls' in metadata else {} if legacy is not None: data, _ = datalab.utils.commands.get_data( source, fields, controls, first_row, count, schema) else: data, _ = _utils.get_data(source, fields, controls, first_row, count, schema) except Exception as e: google.datalab.utils.print_exception_with_last_stack(e) print('Failed with exception %s' % e) data = {} # TODO(gram): The old way - commented out below - has the advantage that it worked # for datetimes, but it is strictly wrong. The correct way below may have issues if the # chart has datetimes though so test this. return IPython.core.display.JSON({'data': data, 'refresh_interval': refresh, 'options': options}) # return IPython.core.display.JSON(json.dumps({'data': data}, # cls=google.datalab.utils.JSONEncoder)) ================================================ FILE: google/datalab/utils/commands/_commands.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implementation of command parsing and handling within magics.""" from __future__ import absolute_import from __future__ import print_function from __future__ import unicode_literals try: import IPython except ImportError: raise Exception('This module can only be loaded in ipython.') import argparse import shlex import six import google.datalab.utils class CommandParser(argparse.ArgumentParser): """An argument parser to parse commands in line/cell magic declarations. """ def __init__(self, *args, **kwargs): """Initializes an instance of a CommandParser. """ super(CommandParser, self).__init__(*args, **kwargs) # Set _parser_class, so that subparsers added will also be of this type. self._parser_class = CommandParser self._subcommands = None # A dict such as {'argname': {'required': True, 'help': 'arg help string'}} self._cell_args = {} @staticmethod def create(name): """Creates a CommandParser for a specific magic. """ return CommandParser(prog=name) def exit(self, status=0, message=None): """Overridden exit method to stop parsing without calling sys.exit(). """ if status == 0 and message is None: # This happens when parsing '--help' raise Exception('exit_0') else: raise Exception(message) def format_help(self): """Override help doc to add cell args. """ if not self._cell_args: return super(CommandParser, self).format_help() else: # Print the standard argparse info, the cell arg block, and then the epilog # If we don't remove epilog before calling the super, then epilog will # be printed before the 'Cell args' block. epilog = self.epilog self.epilog = None orig_help = super(CommandParser, self).format_help() cell_args_help = '\nCell args:\n\n' for cell_arg, v in six.iteritems(self._cell_args): required = 'Required' if v['required'] else 'Optional' cell_args_help += '%s: %s. 
%s.\n\n' % (cell_arg, required, v['help']) orig_help += cell_args_help if epilog: orig_help += epilog + '\n\n' return orig_help def format_usage(self): """Overridden usage generator to use the full help message. """ return self.format_help() @staticmethod def create_args(line, namespace): """ Expand any meta-variable references in the argument list. """ args = [] # Using shlex.split handles quotes args and escape characters. for arg in shlex.split(line): if not arg: continue if arg[0] == '$': var_name = arg[1:] if var_name in namespace: args.append((namespace[var_name])) else: raise Exception('Undefined variable referenced in command line: %s' % arg) else: args.append(arg) return args def _get_subparsers(self): """Recursively get subparsers.""" subparsers = [] for action in self._actions: if isinstance(action, argparse._SubParsersAction): for _, subparser in action.choices.items(): subparsers.append(subparser) ret = subparsers for sp in subparsers: ret += sp._get_subparsers() return ret def _get_subparser_line_args(self, subparser_prog): """ Get line args of a specified subparser by its prog.""" subparsers = self._get_subparsers() for subparser in subparsers: if subparser_prog == subparser.prog: # Found the subparser. args_to_parse = [] for action in subparser._actions: if action.option_strings: for argname in action.option_strings: if argname.startswith('--'): args_to_parse.append(argname[2:]) return args_to_parse return None def _get_subparser_cell_args(self, subparser_prog): """ Get cell args of a specified subparser by its prog.""" subparsers = self._get_subparsers() for subparser in subparsers: if subparser_prog == subparser.prog: return subparser._cell_args return None def add_cell_argument(self, name, help, required=False): """ Add a cell only argument. Args: name: name of the argument. No need to start with "-" or "--". help: the help string of the argument. required: Whether it is required in cell content. """ for action in self._actions: if action.dest == name: raise ValueError('Arg "%s" was added by add_argument already.' % name) self._cell_args[name] = {'required': required, 'help': help} def parse(self, line, cell, namespace=None): """Parses a line and cell into a dictionary of arguments, expanding variables from a namespace. For each line parameters beginning with --, it also checks the cell content and see if it exists there. For example, if "--config1" is a line parameter, it checks to see if cell dict contains "config1" item, and if so, use the cell value. The "config1" item will also be removed from cell content. Args: line: line content. cell: cell content. namespace: user namespace. If None, IPython's user namespace is used. Returns: A tuple of: 1. parsed config dict. 2. remaining cell after line parameters are extracted. """ if namespace is None: ipy = IPython.get_ipython() namespace = ipy.user_ns # Find which subcommand in the line by comparing line with subcommand progs. # For example, assuming there are 3 subcommands with their progs # %bq tables # %bq tables list # %bq datasets # and the line is "tables list --dataset proj.myds" # it will find the second one --- "tables list" because it matches the prog and # it is the longest. args = CommandParser.create_args(line, namespace) # "prog" is a ArgumentParser's path splitted by namspace, such as '%bq tables list'. sub_parsers_progs = [x.prog for x in self._get_subparsers()] matched_progs = [] for prog in sub_parsers_progs: # Remove the leading magic such as "%bq". 
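# (Illustrative, assumed example: with subcommand progs '%bq tables' and
# '%bq tables list' and the line 'tables list --dataset proj.myds',
# match == ['tables', 'list'] for the longer prog and args[0:2] equals it, so
# '%bq tables list' is recorded and later wins as the longest match.)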
match = prog.split()[1:] for i in range(len(args)): if args[i:i + len(match)] == match: matched_progs.append(prog) break matched_prog = None if matched_progs: # Get the longest match. matched_prog = max(matched_progs, key=lambda x: len(x.split())) # Line args can be provided in cell too. If they are in cell, move them to line # so we can parse them all together. line_args = self._get_subparser_line_args(matched_prog) if line_args: cell_config = None try: cell_config, cell = google.datalab.utils.commands.parse_config_for_selected_keys( cell, line_args) except: # It is okay --- probably because cell is not in yaml or json format. pass if cell_config: google.datalab.utils.commands.replace_vars(cell_config, namespace) for arg_name in cell_config: arg_value = cell_config[arg_name] if arg_value is None: continue if '--' + arg_name in args: raise ValueError('config item "%s" is specified in both cell and line.' % arg_name) if isinstance(arg_value, bool): if arg_value: line += ' --%s' % arg_name else: line += ' --%s %s' % (arg_name, str(cell_config[arg_name])) # Parse args again with the new line. args = CommandParser.create_args(line, namespace) args = vars(self.parse_args(args)) # Parse cell args. cell_config = None cell_args = self._get_subparser_cell_args(matched_prog) if cell_args: try: cell_config, _ = google.datalab.utils.commands.parse_config_for_selected_keys( cell, cell_args) except: # It is okay --- probably because cell is not in yaml or json format. pass if cell_config: google.datalab.utils.commands.replace_vars(cell_config, namespace) for arg in cell_args: if (cell_args[arg]['required'] and (cell_config is None or cell_config.get(arg, None) is None)): raise ValueError('Cell config "%s" is required.' % arg) if cell_config: args.update(cell_config) return args, cell def subcommand(self, name, help, **kwargs): """Creates a parser for a sub-command. """ if self._subcommands is None: self._subcommands = self.add_subparsers(help='commands') return self._subcommands.add_parser(name, description=help, help=help, **kwargs) ================================================ FILE: google/datalab/utils/commands/_csv.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements CSV file exploration""" from __future__ import absolute_import from __future__ import unicode_literals try: import IPython import IPython.core.magic import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import pandas as pd import google.datalab.data from . import _commands from . import _utils @IPython.core.magic.register_line_cell_magic def csv(line, cell=None): parser = _commands.CommandParser.create('csv') view_parser = parser.subcommand('view', 'Browse CSV files without providing a schema. 
' + 'Each value is considered string type.') view_parser.add_argument('-i', '--input', help='Path of the input CSV data', required=True) view_parser.add_argument('-n', '--count', help='The number of lines to browse from head, default to 5.') view_parser.add_argument('-P', '--profile', action='store_true', default=False, help='Generate an interactive profile of the data') view_parser.set_defaults(func=_view) return _utils.handle_magic_line(line, cell, parser) def _view(args, cell): csv = google.datalab.data.CsvFile(args['input']) num_lines = int(args['count'] or 5) headers = None if cell: ipy = IPython.get_ipython() config = _utils.parse_config(cell, ipy.user_ns) if 'columns' in config: headers = [e.strip() for e in config['columns'].split(',')] df = pd.DataFrame(csv.browse(num_lines, headers)) if args['profile']: # TODO(gram): We need to generate a schema and type-convert the columns before this # will be useful for CSV return _utils.profile_df(df) else: return IPython.core.display.HTML(df.to_html(index=False)) ================================================ FILE: google/datalab/utils/commands/_html.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Google Cloud Platform library - IPython HTML display Functionality.""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import range from past.builtins import basestring from builtins import object import time class Html(object): """A helper to enable generating an HTML representation as display data in a notebook. This object supports the combination of HTML markup and/or associated JavaScript. """ _div_id_counter = 0 @staticmethod def next_id(): """ Return an ID containing a reproducible part (counter) and unique part (timestamp). """ Html._div_id_counter += 1 return '%d_%d' % (Html._div_id_counter, int(round(time.time() * 100))) def __init__(self, markup=None): """Initializes an instance of Html. """ self._id = Html.next_id() Html._div_id_counter += 1 self._class_name = '' self._markup = markup self._dependencies = [('element!hh_%d' % self._id, 'dom')] self._script = '' self._class = None def add_class(self, class_name): """Adds a CSS class to be generated on the output HTML. """ self._class = class_name def add_dependency(self, path, name): """Adds a script dependency to be loaded before any script is executed. """ self._dependencies.append((path, name)) def add_script(self, script): """Adds JavaScript that should be included along-side the HTML. """ self._script = script def _repr_html_(self): """Generates the HTML representation. """ parts = [] if self._class: parts.append('
<div id="hh_%s" class="%s">%s</div>' % (self._id, self._class, self._markup)) else: parts.append('<div id="hh_%s">%s</div>
' % (self._id, self._markup)) if len(self._script) != 0: parts.append('') return ''.join(parts) class HtmlBuilder(object): """A set of helpers to build HTML representations of objects. """ def __init__(self): """Initializes an instance of an HtmlBuilder. """ self._segments = [] def _render_objects(self, items, attributes=None, datatype='object'): """Renders an HTML table with the specified list of objects. Args: items: the iterable collection of objects to render. attributes: the optional list of properties or keys to render. datatype: the type of data; one of 'object' for Python objects, 'dict' for a list of dictionaries, or 'chartdata' for Google chart data. """ if not items: return if datatype == 'chartdata': if not attributes: attributes = [items['cols'][i]['label'] for i in range(0, len(items['cols']))] items = items['rows'] indices = {attributes[i]: i for i in range(0, len(attributes))} num_segments = len(self._segments) self._segments.append('
<table>') first = True for o in items: if first: first = False if datatype == 'dict' and not attributes: attributes = list(o.keys()) if attributes is not None: self._segments.append('<tr>') for attr in attributes: self._segments.append('<th>%s</th>' % attr) self._segments.append('</tr>') self._segments.append('<tr>') if attributes is None: self._segments.append('<td>%s</td>' % HtmlBuilder._format(o)) else: for attr in attributes: if datatype == 'dict': self._segments.append('<td>%s</td>' % HtmlBuilder._format(o.get(attr, None), nbsp=True)) elif datatype == 'chartdata': self._segments.append('<td>%s</td>' % HtmlBuilder._format(o['c'][indices[attr]]['v'], nbsp=True)) else: self._segments.append('<td>%s</td>' % HtmlBuilder._format(o.__getattribute__(attr), nbsp=True)) self._segments.append('</tr>') self._segments.append('</table>
') if first: # The table was empty; drop it from the segments. self._segments = self._segments[:num_segments] def _render_text(self, text, preformatted=False): """Renders an HTML formatted text block with the specified text. Args: text: the text to render preformatted: whether the text should be rendered as preformatted """ tag = 'pre' if preformatted else 'div' self._segments.append('<%s>%s' % (tag, HtmlBuilder._format(text), tag)) def _render_list(self, items, empty='
<pre>&lt;empty&gt;</pre>
'): """Renders an HTML list with the specified list of strings. Args: items: the iterable collection of objects to render. empty: what to render if the list is None or empty. """ if not items or len(items) == 0: self._segments.append(empty) return self._segments.append('
<ul>') for o in items: self._segments.append('<li>') self._segments.append(str(o)) self._segments.append('</li>') self._segments.append('</ul>
') def _to_html(self): """Returns the HTML that has been rendered. Returns: The HTML string that has been built. """ return ''.join(self._segments) @staticmethod def _format(value, nbsp=False): if value is None: return ' ' if nbsp else '' elif isinstance(value, basestring): return value.replace('&', '&').replace('<', '<').replace('>', '>') else: return str(value) @staticmethod def render_text(text, preformatted=False): """Renders an HTML formatted text block with the specified text. Args: text: the text to render preformatted: whether the text should be rendered as preformatted Returns: The formatted HTML. """ builder = HtmlBuilder() builder._render_text(text, preformatted=preformatted) return builder._to_html() @staticmethod def render_table(data, headers=None): """ Return a dictionary list formatted as a HTML table. Args: data: a list of dictionaries, one per row. headers: the keys in the dictionary to use as table columns, in order. """ builder = HtmlBuilder() builder._render_objects(data, headers, datatype='dict') return builder._to_html() @staticmethod def render_chart_data(data): """ Return a dictionary list formatted as a HTML table. Args: data: data in the form consumed by Google Charts. """ builder = HtmlBuilder() builder._render_objects(data, datatype='chartdata') return builder._to_html() @staticmethod def render_list(data): """ Return a list formatted as a HTML list. Args: data: a list of strings. """ builder = HtmlBuilder() builder._render_list(data) return builder._to_html() ================================================ FILE: google/datalab/utils/commands/_job.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Implements job view""" from __future__ import absolute_import from __future__ import unicode_literals from builtins import str try: import IPython import IPython.core.magic import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import google.datalab.utils from . import _html _local_jobs = {} def html_job_status(job_name, job_type, refresh_interval, html_on_running, html_on_success): """create html representation of status of a job (long running operation). Args: job_name: the full name of the job. job_type: type of job. Can be 'local' or 'cloud'. refresh_interval: how often should the client refresh status. html_on_running: additional html that the job view needs to include on job running. html_on_success: additional html that the job view needs to include on job success. """ _HTML_TEMPLATE = """
""" div_id = _html.Html.next_id() return IPython.core.display.HTML(_HTML_TEMPLATE % (div_id, div_id, job_name, job_type, refresh_interval, html_on_running, html_on_success)) @IPython.core.magic.register_line_magic def _get_job_status(line): """magic used as an endpoint for client to get job status. %_get_job_status Returns: A JSON object of the job status. """ try: args = line.strip().split() job_name = args[0] job = None if job_name in _local_jobs: job = _local_jobs[job_name] else: raise Exception('invalid job %s' % job_name) if job is not None: error = '' if job.fatal_error is None else str(job.fatal_error) data = {'exists': True, 'done': job.is_complete, 'error': error} else: data = {'exists': False} except Exception as e: google.datalab.utils.print_exception_with_last_stack(e) data = {'done': True, 'error': str(e)} return IPython.core.display.JSON(data) ================================================ FILE: google/datalab/utils/commands/_utils.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Utility functions.""" from __future__ import absolute_import from __future__ import division from __future__ import unicode_literals from builtins import str from past.builtins import basestring try: import IPython import IPython.core.display except ImportError: raise Exception('This module can only be loaded in ipython.') import json import pandas try: # Pandas profiling is not needed for build/test but will be in the container. import pandas_profiling except ImportError: pass import sys import yaml import google.datalab.data import google.datalab.bigquery import google.datalab.storage import google.datalab.utils from . import _html def notebook_environment(): """ Get the IPython user namespace. """ ipy = IPython.get_ipython() return ipy.user_ns def get_notebook_item(name): """ Get an item from the IPython environment. """ env = notebook_environment() return google.datalab.utils.get_item(env, name) def render_list(data): return IPython.core.display.HTML(_html.HtmlBuilder.render_list(data)) def render_dictionary(data, headers=None): """ Return a dictionary list formatted as a HTML table. Args: data: the dictionary list headers: the keys in the dictionary to use as table columns, in order. """ return IPython.core.display.HTML(_html.HtmlBuilder.render_table(data, headers)) def render_text(text, preformatted=False): """ Return text formatted as a HTML Args: text: the text to render preformatted: whether the text should be rendered as preformatted """ return IPython.core.display.HTML(_html.HtmlBuilder.render_text(text, preformatted)) def get_field_list(fields, schema): """ Convert a field list spec into a real list of field names. For tables, we return only the top-level non-RECORD fields as Google charts can't handle nested data. """ # If the fields weren't supplied get them from the schema. 
if schema: all_fields = [f['name'] for f in schema._bq_schema if f['type'] != 'RECORD'] if isinstance(fields, list): if schema: # validate fields exist for f in fields: if f not in all_fields: raise Exception('Cannot find field %s in given schema' % f) return fields if isinstance(fields, basestring) and fields != '*': if schema: # validate fields exist for f in fields.split(','): if f not in all_fields: raise Exception('Cannot find field %s in given schema' % f) return fields.split(',') if not schema: return [] return all_fields def _get_cols(fields, schema): """ Get column metadata for Google Charts based on field list and schema. """ typemap = { 'STRING': 'string', 'INT64': 'number', 'INTEGER': 'number', 'FLOAT': 'number', 'FLOAT64': 'number', 'BOOL': 'boolean', 'BOOLEAN': 'boolean', 'DATE': 'date', 'TIME': 'timeofday', 'DATETIME': 'datetime', 'TIMESTAMP': 'timestamp' } cols = [] for col in fields: if schema: f = schema[col] t = 'string' if f.mode == 'REPEATED' else typemap.get(f.type, 'string') cols.append({'id': f.name, 'label': f.name, 'type': t}) else: # This will only happen if we had no rows to infer a schema from, so the type # is not really important, except that GCharts will choke if we pass such a schema # to a chart if it is string x string so we default to number. cols.append({'id': col, 'label': col, 'type': 'number'}) return cols def _get_data_from_empty_list(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles empty lists. """ fields = get_field_list(fields, schema) return {'cols': _get_cols(fields, schema), 'rows': []}, 0 def _get_data_from_list_of_dicts(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles lists of dicts. """ if schema is None: schema = google.datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) gen = source[first_row:first_row + count] if count >= 0 else source rows = [{'c': [{'v': row[c]} if c in row else {} for c in fields]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, len(source) def _get_data_from_list_of_lists(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles lists of lists. """ if schema is None: schema = google.datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) gen = source[first_row:first_row + count] if count >= 0 else source cols = [schema.find(name) for name in fields] rows = [{'c': [{'v': row[i]} for i in cols]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, len(source) def _get_data_from_dataframe(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles Pandas DataFrames. """ if schema is None: schema = google.datalab.bigquery.Schema.from_data(source) fields = get_field_list(fields, schema) rows = [] if count < 0: count = len(source.index) df_slice = source.reset_index(drop=True)[first_row:first_row + count] for index, data_frame_row in df_slice.iterrows(): row = data_frame_row.to_dict() for key in list(row.keys()): val = row[key] if isinstance(val, pandas.Timestamp): row[key] = val.to_pydatetime() rows.append({'c': [{'v': row[c]} if c in row else {} for c in fields]}) cols = _get_cols(fields, schema) return {'cols': cols, 'rows': rows}, len(source) def _get_data_from_table(source, fields='*', first_row=0, count=-1, schema=None): """ Helper function for _get_data that handles BQ Tables. 
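The returned data uses the Google Charts DataTable shape shared by all of the
_get_data_* helpers. An illustrative (hypothetical) result for a two-column,
two-row table would be:

    ({'cols': [{'id': 'name', 'label': 'name', 'type': 'string'},
               {'id': 'age', 'label': 'age', 'type': 'number'}],
      'rows': [{'c': [{'v': 'Alice'}, {'v': 30}]},
               {'c': [{'v': 'Bob'}, {'v': 25}]}]},
     2)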
""" if not source.exists(): return _get_data_from_empty_list(source, fields, first_row, count) if schema is None: schema = source.schema fields = get_field_list(fields, schema) gen = source.range(first_row, count) if count >= 0 else source rows = [{'c': [{'v': row[c]} if c in row else {} for c in fields]} for row in gen] return {'cols': _get_cols(fields, schema), 'rows': rows}, source.length def get_data(source, fields='*', env=None, first_row=0, count=-1, schema=None): """ A utility function to get a subset of data from a Table, Query, Pandas dataframe or List. Args: source: the source of the data. Can be a Table, Pandas DataFrame, List of dictionaries or lists, or a string, in which case it is expected to be the name of a table in BQ. fields: a list of fields that we want to return as a list of strings, comma-separated string, or '*' for all. env: if the data source is a Query module, this is the set of variable overrides for parameterizing the Query. first_row: the index of the first row to return; default 0. Onl;y used if count is non-negative. count: the number or rows to return. If negative (the default), return all rows. schema: the schema of the data. Optional; if supplied this can be used to help do type-coercion. Returns: A tuple consisting of a dictionary and a count; the dictionary has two entries: 'cols' which is a list of column metadata entries for Google Charts, and 'rows' which is a list of lists of values. The count is the total number of rows in the source (independent of the first_row/count parameters). Raises: Exception if the request could not be fulfilled. """ ipy = IPython.get_ipython() if env is None: env = {} env.update(ipy.user_ns) if isinstance(source, basestring): source = google.datalab.utils.get_item(ipy.user_ns, source, source) if isinstance(source, basestring): source = google.datalab.bigquery.Table(source) if isinstance(source, list): if len(source) == 0: return _get_data_from_empty_list(source, fields, first_row, count, schema) elif isinstance(source[0], dict): return _get_data_from_list_of_dicts(source, fields, first_row, count, schema) elif isinstance(source[0], list): return _get_data_from_list_of_lists(source, fields, first_row, count, schema) else: raise Exception("To get tabular data from a list it must contain dictionaries or lists.") elif isinstance(source, pandas.DataFrame): return _get_data_from_dataframe(source, fields, first_row, count, schema) elif isinstance(source, google.datalab.bigquery.Query): return _get_data_from_table(source.execute().result(), fields, first_row, count, schema) elif isinstance(source, google.datalab.bigquery.Table): return _get_data_from_table(source, fields, first_row, count, schema) else: raise Exception("Cannot chart %s; unsupported object type" % source) def handle_magic_line(line, cell, parser, namespace=None): """ Helper function for handling magic command lines given a parser with handlers set. """ try: args, cell = parser.parse(line, cell, namespace) if args: return args['func'](args, cell) except Exception as e: # e.args[0] is 'exit_0' if --help is provided in line. # In this case don't write anything to stderr. if e.args and e.args[0] == 'exit_0': return sys.stderr.write('\n' + str(e)) sys.stderr.flush() def expand_var(v, env): """ If v is a variable reference (for example: '$myvar'), replace it using the supplied env dictionary. Args: v: the variable to replace if needed. env: user supplied dictionary. Raises: Exception if v is a variable reference but it is not found in env. 
""" if len(v) == 0: return v # Using len() and v[0] instead of startswith makes this Unicode-safe. if v[0] == '$': v = v[1:] if len(v) and v[0] != '$': if v in env: v = env[v] else: raise Exception('Cannot expand variable $%s' % v) return v def replace_vars(config, env): """ Replace variable references in config using the supplied env dictionary. Args: config: the config to parse. Can be a tuple, list or dict. env: user supplied dictionary. Raises: Exception if any variable references are not found in env. """ if isinstance(config, dict): for k, v in list(config.items()): if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) elif isinstance(v, basestring): config[k] = expand_var(v, env) elif isinstance(config, list): for i, v in enumerate(config): if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) elif isinstance(v, basestring): config[i] = expand_var(v, env) elif isinstance(config, tuple): # TODO(gram): figure out how to handle these if the tuple elements are scalar for v in config: if isinstance(v, dict) or isinstance(v, list) or isinstance(v, tuple): replace_vars(v, env) def parse_config(config, env, as_dict=True): """ Parse a config from a magic cell body. This could be JSON or YAML. We turn it into a Python dictionary then recursively replace any variable references using the supplied env dictionary. """ if config is None: return None stripped = config.strip() if len(stripped) == 0: config = {} elif stripped[0] == '{': config = json.loads(config) else: config = yaml.load(config) if as_dict: config = dict(config) # Now we need to walk the config dictionary recursively replacing any '$name' vars. replace_vars(config, env) return config def parse_config_for_selected_keys(content, keys): """ Parse a config from a magic cell body for selected config keys. For example, if 'content' is: config_item1: value1 config_item2: value2 config_item3: value3 and 'keys' are: [config_item1, config_item3] The results will be a tuple of 1. The parsed config items (dict): {config_item1: value1, config_item3: value3} 2. The remaining content (string): config_item2: value2 Args: content: the input content. A string. It has to be a yaml or JSON string. keys: a list of keys to retrieve from content. Note that it only checks top level keys in the dict. Returns: A tuple. First is the parsed config including only selected keys. Second is the remaining content. Raises: Exception if the content is not a valid yaml or JSON string. """ config_items = {key: None for key in keys} if not content: return config_items, content stripped = content.strip() if len(stripped) == 0: return {}, None elif stripped[0] == '{': config = json.loads(content) else: config = yaml.load(content) if not isinstance(config, dict): raise ValueError('Invalid config.') for key in keys: config_items[key] = config.pop(key, None) if not config: return config_items, None if stripped[0] == '{': content_out = json.dumps(config, indent=4) else: content_out = yaml.dump(config, default_flow_style=False) return config_items, content_out def validate_config(config, required_keys, optional_keys=None): """ Validate a config dictionary to make sure it includes all required keys and does not include any unexpected keys. Args: config: the config to validate. required_keys: the names of the keys that the config must have. optional_keys: the names of the keys that the config can have. Raises: Exception if the config is not a dict or invalid. 
""" if optional_keys is None: optional_keys = [] if not isinstance(config, dict): raise Exception('config is not dict type') invalid_keys = set(config) - set(required_keys + optional_keys) if len(invalid_keys) > 0: raise Exception('Invalid config with unexpected keys ' '"%s"' % ', '.join(e for e in invalid_keys)) missing_keys = set(required_keys) - set(config) if len(missing_keys) > 0: raise Exception('Invalid config with missing keys "%s"' % ', '.join(missing_keys)) def validate_config_must_have(config, required_keys): """ Validate a config dictionary to make sure it has all of the specified keys Args: config: the config to validate. required_keys: the list of possible keys that config must include. Raises: Exception if the config does not have any of them. """ missing_keys = set(required_keys) - set(config) if len(missing_keys) > 0: raise Exception('Invalid config with missing keys "%s"' % ', '.join(missing_keys)) def validate_config_has_one_of(config, one_of_keys): """ Validate a config dictionary to make sure it has one and only one key in one_of_keys. Args: config: the config to validate. one_of_keys: the list of possible keys that config can have one and only one. Raises: Exception if the config does not have any of them, or multiple of them. """ intersection = set(config).intersection(one_of_keys) if len(intersection) > 1: raise Exception('Only one of the values in "%s" is needed' % ', '.join(intersection)) if len(intersection) == 0: raise Exception('One of the values in "%s" is needed' % ', '.join(one_of_keys)) def validate_config_value(value, possible_values): """ Validate a config value to make sure it is one of the possible values. Args: value: the config value to validate. possible_values: the possible values the value can be Raises: Exception if the value is not one of possible values. """ if value not in possible_values: raise Exception('Invalid config value "%s". Possible values are ' '%s' % (value, ', '.join(e for e in possible_values))) # For chart and table HTML viewers, we use a list of table names and reference # instead the indices in the HTML, so as not to include things like projectID, etc, # in the HTML. _data_sources = [] def get_data_source_index(name): if name not in _data_sources: _data_sources.append(name) return _data_sources.index(name) def validate_gcs_path(path, require_object): """ Check whether a given path is a valid GCS path. Args: path: the config to check. require_object: if True, the path has to be an object path but not bucket path. Raises: Exception if the path is invalid """ bucket, key = google.datalab.storage._bucket.parse_name(path) if bucket is None: raise Exception('Invalid GCS path "%s"' % path) if require_object and key is None: raise Exception('It appears the GCS path "%s" is a bucket path but not an object path' % path) def parse_control_options(controls, variable_defaults=None): """ Parse a set of control options. Args: controls: The dictionary of control options. variable_defaults: If the controls are for a Query with variables, then this is the default variable values defined in the Query module. The options in the controls parameter can override these but if a variable has no 'value' property then we fall back to these. Returns: - the HTML for the controls. - the default values for the controls as a dict. - the list of DIV IDs of the controls. 
""" controls_html = '' control_defaults = {} control_ids = [] div_id = _html.Html.next_id() if variable_defaults is None: variable_defaults = {} for varname, control in list(controls.items()): label = control.get('label', varname) control_id = div_id + '__' + varname control_ids.append(control_id) value = control.get('value', variable_defaults.get(varname, None)) # The user should usually specify the type but we will default to 'textbox' for strings # and 'set' for lists. if isinstance(value, basestring): type = 'textbox' elif isinstance(value, list): type = 'set' else: type = None type = control.get('type', type) if type == 'picker': choices = control.get('choices', value) if not isinstance(choices, list) or len(choices) == 0: raise Exception('picker control must specify a nonempty set of choices') if value is None: value = choices[0] choices_html = '' for i, choice in enumerate(choices): choices_html += "" % \ (choice, ("selected=\"selected\"" if choice == value else ''), choice) control_html = "{label}" \ .format(label=label, id=control_id, choices=choices_html) elif type == 'set': # Multi-picker; implemented as checkboxes. # TODO(gram): consider using "name" property of the control to group checkboxes. That # way we can save the code of constructing and parsing control Ids with sequential # numbers in it. Multiple checkboxes can share the same name. choices = control.get('choices', value) if not isinstance(choices, list) or len(choices) == 0: raise Exception('set control must specify a nonempty set of choices') if value is None: value = choices choices_html = '' control_ids[-1] = '%s:%d' % (control_id, len(choices)) # replace ID to include count. for i, choice in enumerate(choices): checked = choice in value choice_id = '%s:%d' % (control_id, i) # TODO(gram): we may want a 'Submit/Refresh button as we may not want to rerun # query on each checkbox change. choices_html += """
""".format(id=choice_id, choice=choice, checked="checked" if checked else '') control_html = "{label}
{choices}
".format(label=label, choices=choices_html) elif type == 'checkbox': control_html = """ """.format(label=label, id=control_id, checked="checked" if value else '') elif type == 'slider': min_ = control.get('min', None) max_ = control.get('max', None) if min_ is None or max_ is None: raise Exception('slider control must specify a min and max value') if max_ <= min_: raise Exception('slider control must specify a min value less than max value') step = control.get('step', 1 if isinstance(min_, int) and isinstance(max_, int) else (float(max_ - min_) / 10.0)) if value is None: value = min_ control_html = """ {label} """.format(label=label, id=control_id, value=value, min=min_, max=max_, step=step) elif type == 'textbox': if value is None: value = '' control_html = "{label}" \ .format(label=label, value=value, id=control_id) else: raise Exception( 'Unknown control type %s (expected picker, slider, checkbox, textbox or set)' % type) control_defaults[varname] = value controls_html += "
{control}
\n" \ .format(control=control_html) controls_html = "
{controls}
".format(controls=controls_html) return controls_html, control_defaults, control_ids def chart_html(driver_name, chart_type, source, chart_options=None, fields='*', refresh_interval=0, refresh_data=None, control_defaults=None, control_ids=None, schema=None): """ Return HTML for a chart. Args: driver_name: the name of the chart driver. Currently we support 'plotly' or 'gcharts'. chart_type: string specifying type of chart. source: the data source for the chart. Can be actual data (e.g. list) or the name of a data source (e.g. the name of a query module). chart_options: a dictionary of options for the chart. Can contain a 'controls' entry specifying controls. Other entries are passed as JSON to Google Charts. fields: the fields to chart. Can be '*' for all fields (only sensible if the columns are ordered; e.g. a Query or list of lists, but not a list of dictionaries); otherwise a string containing a comma-separated list of field names. refresh_interval: a time in seconds after which the chart data will be refreshed. 0 if the chart should not be refreshed (i.e. the data is static). refresh_data: if the source is a list or other raw data, this is a YAML string containing metadata needed to support calls to refresh (get_chart_data). control_defaults: the default variable values for controls that are shared across charts including this one. control_ids: the DIV IDs for controls that are shared across charts including this one. schema: an optional schema for the data; if not supplied one will be inferred. Returns: A string containing the HTML for the chart. """ div_id = _html.Html.next_id() controls_html = '' if control_defaults is None: control_defaults = {} if control_ids is None: control_ids = [] if chart_options is not None and 'variables' in chart_options: controls = chart_options['variables'] del chart_options['variables'] # Just to make sure GCharts doesn't see them. controls_html, defaults, ids = parse_control_options(controls) # We augment what we are passed so that in principle we can have controls that are # shared by charts as well as controls that are specific to a chart. control_defaults.update(defaults) control_ids.extend(ids), _HTML_TEMPLATE = """
{controls}
""" count = 25 if chart_type == 'paged_table' else -1 data, total_count = get_data(source, fields, control_defaults, 0, count, schema) if refresh_data is None: if isinstance(source, basestring): source_index = get_data_source_index(source) refresh_data = {'source_index': source_index, 'name': source_index} else: refresh_data = {'name': 'raw data'} refresh_data['fields'] = fields # TODO(gram): check if we need to augment env with user_ns return _HTML_TEMPLATE \ .format(driver=driver_name, controls=controls_html, id=div_id, chart_type=chart_type, extra_class=" bqgc-controlled" if len(controls_html) else '', data=json.dumps(data, cls=google.datalab.utils.JSONEncoder), options=json.dumps(chart_options, cls=google.datalab.utils.JSONEncoder), refresh_data=json.dumps(refresh_data, cls=google.datalab.utils.JSONEncoder), refresh_interval=refresh_interval, control_ids=str(control_ids), total_rows=total_count) def profile_df(df): """ Generate a profile of data in a dataframe. Args: df: the Pandas dataframe. """ # The bootstrap CSS messes up the Datalab display so we tweak it to not have an effect. # TODO(gram): strip it out rather than this kludge. return IPython.core.display.HTML( pandas_profiling.ProfileReport(df).html.replace('bootstrap', 'nonexistent')) ================================================ FILE: google/datalab/utils/facets/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # The files in this directory are copied from # https://github.com/PAIR-code/facets/tree/master/facets_overview/python ================================================ FILE: google/datalab/utils/facets/base_feature_statistics_generator.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Base class for generating the feature_statistics proto from TensorFlow data. The proto is used as input for the Overview visualization. """ # flake8: noqa from functools import partial from .base_generic_feature_statistics_generator import BaseGenericFeatureStatisticsGenerator import tensorflow as tf # The feature name used to track sequence length when analyzing # tf.SequenceExamples. 
SEQUENCE_LENGTH_FEATURE_NAME = 'sequence length (derived feature)' class BaseFeatureStatisticsGenerator(BaseGenericFeatureStatisticsGenerator): """Base class for generator of stats proto from TF data.""" def __init__(self, fs_proto, datasets_proto, histogram_proto): BaseGenericFeatureStatisticsGenerator.__init__( self, fs_proto, datasets_proto, histogram_proto) def ProtoFromTfRecordFiles(self, files, max_entries=10000, features=None, is_sequence=False, iterator_options=None): """Creates a feature statistics proto from a set of TFRecord files. Args: files: A list of dicts describing files for each dataset for the proto. Each entry contains a 'path' field with the path to the TFRecord file on disk and a 'name' field to identify the dataset in the proto. max_entries: The maximum number of examples to load from each dataset in order to create the proto. Defaults to 10000. features: A list of strings that is an allowlist of feature names to create feature statistics for. If set to None then all features in the dataset are analyzed. Defaults to None. is_sequence: True if the input data from 'tables' are tf.SequenceExamples, False if tf.Examples. Defaults to false. iterator_options: Options to pass to the iterator that reads the examples. Defaults to None. Returns: The feature statistics proto for the provided files. """ datasets = [] for entry in files: entries, size = self._GetTfRecordEntries(entry['path'], max_entries, is_sequence, iterator_options) datasets.append({'entries': entries, 'size': size, 'name': entry['name']}) return self.GetDatasetsProto(datasets, features) def _ParseExample(self, example_features, example_feature_lists, entries, index): """Parses data from an example, populating a dictionary of feature values. Args: example_features: A map of strings to tf.Features from the example. example_feature_lists: A map of strings to tf.FeatureLists from the example. entries: A dictionary of all features parsed thus far and arrays of their values. This is mutated by the function. index: The index of the example to parse from a list of examples. Raises: TypeError: Raises an exception when a feature has inconsistent types across examples. """ features_seen = set() for feature_list, is_feature in zip( [example_features, example_feature_lists], [True, False]): sequence_length = None for feature_name in feature_list: # If this feature has not been seen in previous examples, then # initialize its entry into the entries dictionary. if feature_name not in entries: entries[feature_name] = { 'vals': [], 'counts': [], 'feat_lens': [], 'missing': index } feature_entry = entries[feature_name] feature = feature_list[feature_name] value_type = None value_list = [] if is_feature: # If parsing a tf.Feature, extract the type and values simply. if feature.HasField('float_list'): value_list = feature.float_list.value value_type = self.fs_proto.FLOAT elif feature.HasField('bytes_list'): value_list = feature.bytes_list.value value_type = self.fs_proto.STRING elif feature.HasField('int64_list'): value_list = feature.int64_list.value value_type = self.fs_proto.INT else: # If parsing a tf.FeatureList, get the type and values by iterating # over all Features in the FeatureList. 
sequence_length = len(feature.feature) if sequence_length != 0 and feature.feature[0].HasField('float_list'): for feat in feature.feature: for value in feat.float_list.value: value_list.append(value) value_type = self.fs_proto.FLOAT elif sequence_length != 0 and feature.feature[0].HasField( 'bytes_list'): for feat in feature.feature: for value in feat.bytes_list.value: value_list.append(value) value_type = self.fs_proto.STRING elif sequence_length != 0 and feature.feature[0].HasField( 'int64_list'): for feat in feature.feature: for value in feat.int64_list.value: value_list.append(value) value_type = self.fs_proto.INT if value_type is not None: if 'type' not in feature_entry: feature_entry['type'] = value_type elif feature_entry['type'] != value_type: raise TypeError('type mismatch for feature ' + feature_name) feature_entry['counts'].append(len(value_list)) feature_entry['vals'].extend(value_list) if sequence_length is not None: feature_entry['feat_lens'].append(sequence_length) if value_list: features_seen.add(feature_name) # For all previously-seen features not found in this example, update the # feature's missing value. for f in entries: fv = entries[f] if f not in features_seen: fv['missing'] += 1 def _GetEntries(self, paths, max_entries, iterator_from_file, is_sequence=False): """Extracts examples into a dictionary of feature values. Args: paths: A list of the paths to the files to parse. max_entries: The maximum number of examples to load. iterator_from_file: A method that takes a file path string and returns an iterator to the examples in that file. is_sequence: True if the input data from 'iterator_from_file' are tf.SequenceExamples, False if tf.Examples. Defaults to false. Returns: A tuple with two elements: - A dictionary of all features parsed thus far and arrays of their values. - The number of examples parsed. """ entries = {} index = 0 for filepath in paths: reader = iterator_from_file(filepath) for record in reader: if is_sequence: sequence_example = tf.train.SequenceExample.FromString(record) self._ParseExample(sequence_example.context.feature, sequence_example.feature_lists.feature_list, entries, index) else: self._ParseExample( tf.train.Example.FromString(record).features.feature, [], entries, index) index += 1 if index == max_entries: return entries, index return entries, index def _GetTfRecordEntries(self, path, max_entries, is_sequence, iterator_options): """Extracts TFRecord examples into a dictionary of feature values. Args: path: The path to the TFRecord file(s). max_entries: The maximum number of examples to load. is_sequence: True if the input data from 'path' are tf.SequenceExamples, False if tf.Examples. Defaults to false. iterator_options: Options to pass to the iterator that reads the examples. Defaults to None. Returns: A tuple with two elements: - A dictionary of all features parsed thus far and arrays of their values. - The number of examples parsed. """ return self._GetEntries([path], max_entries, partial( tf.python_io.tf_record_iterator, options=iterator_options), is_sequence) ================================================ FILE: google/datalab/utils/facets/base_generic_feature_statistics_generator.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # flake8: noqa """Base class for generating the feature_statistics proto from generic data. The proto is used as input for the Overview visualization. """ import numpy as np import pandas as pd class BaseGenericFeatureStatisticsGenerator(object): """Base class for generator of stats proto from generic data.""" def __init__(self, fs_proto, datasets_proto, histogram_proto): self.fs_proto = fs_proto self.datasets_proto = datasets_proto self.histogram_proto = histogram_proto def ProtoFromDataFrames(self, dataframes): """Creates a feature statistics proto from a set of pandas dataframes. Args: dataframes: A list of dicts describing tables for each dataset for the proto. Each entry contains a 'table' field of the dataframe of the data and a 'name' field to identify the dataset in the proto. Returns: The feature statistics proto for the provided tables. """ datasets = [] for dataframe in dataframes: table = dataframe['table'] table_entries = {} for col in table: table_entries[col] = self.NdarrayToEntry(table[col]) datasets.append({ 'entries': table_entries, 'size': len(table), 'name': dataframe['name'] }) return self.GetDatasetsProto(datasets) def DtypeToType(self, dtype): """Converts a Numpy dtype to the FeatureNameStatistics.Type proto enum.""" if dtype.char in np.typecodes['AllFloat']: return self.fs_proto.FLOAT elif (dtype.char in np.typecodes['AllInteger'] or dtype == np.bool or np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)): return self.fs_proto.INT else: return self.fs_proto.STRING def DtypeToNumberConverter(self, dtype): """Converts a Numpy dtype to a converter method if applicable. The converter method takes in a numpy array of objects of the provided dtype and returns a numpy array of the numbers backing that object for statistical analysis. Returns None if no converter is necessary. Args: dtype: The numpy dtype to make a converter for. Returns: The converter method or None. """ if np.issubdtype(dtype, np.datetime64): def DatetimesToNumbers(dt_list): return np.array([pd.Timestamp(dt).value for dt in dt_list]) return DatetimesToNumbers elif np.issubdtype(dtype, np.timedelta64): def TimedetlasToNumbers(td_list): return np.array([pd.Timedelta(td).value for td in td_list]) return TimedetlasToNumbers else: return None def NdarrayToEntry(self, x): """Converts an ndarray to the Entry format.""" row_counts = [] for row in x: try: rc = np.count_nonzero(~np.isnan(row)) if rc != 0: row_counts.append(rc) except TypeError: try: row_counts.append(row.size) except AttributeError: row_counts.append(1) data_type = self.DtypeToType(x.dtype) converter = self.DtypeToNumberConverter(x.dtype) flattened = x.ravel() orig_size = len(flattened) # Remove all None and nan values and count how many were removed. 
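# Illustrative effect of the filtering below (hypothetical values): for a float column
# np.array([1.0, np.nan, 2.0]) the original size is 3, the np.isnan filter further down
# keeps [1.0, 2.0], and 'missing' is therefore reported as 1; for object columns the
# comparison against np.array(None) drops None entries in the same way.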
flattened = flattened[flattened != np.array(None)] if converter: flattened = converter(flattened) if data_type == self.fs_proto.STRING: flattened_temp = [] for x in flattened: try: if str(x) != 'nan': flattened_temp.append(x) except UnicodeEncodeError: if x.encode('utf-8') != 'nan': flattened_temp.append(x) flattened = flattened_temp else: flattened = flattened[~np.isnan(flattened)].tolist() missing = orig_size - len(flattened) return { 'vals': flattened, 'counts': row_counts, 'missing': missing, 'type': data_type } def GetDatasetsProto(self, datasets, features=None): """Generates the feature stats proto from dictionaries of feature values. Args: datasets: An array of dictionaries, one per dataset, each one containing: - 'entries': The dictionary of features in the dataset from the parsed examples. - 'size': The number of examples parsed for the dataset. - 'name': The name of the dataset. features: A list of strings that is an allowlist of feature names to create feature statistics for. If set to None then all features in the dataset are analyzed. Defaults to None. Returns: The feature statistics proto for the provided datasets. """ features_seen = set() allowlist_features = set(features) if features else None all_datasets = self.datasets_proto() # TODO(jwexler): Add ability to generate weighted feature stats # if there is a specified weight feature in the dataset. # Initialize each dataset for dataset in datasets: all_datasets.datasets.add( name=dataset['name'], num_examples=dataset['size']) # This outer loop ensures that for each feature seen in any of the provided # datasets, we check the feature once against all datasets. for outer_dataset in datasets: for key, value in outer_dataset['entries'].items(): # If we have a feature allowlist and this feature is not in the # allowlist then do not process it. # If we have processed this feature already, no need to do it again. if ((allowlist_features and key not in allowlist_features) or key in features_seen): continue features_seen.add(key) # Default to type int if no type is found, so that the fact that all # values are missing from this feature can be displayed. feature_type = value['type'] if 'type' in value else self.fs_proto.INT # Process the found feature for each dataset. for j, dataset in enumerate(datasets): feat = all_datasets.datasets[j].features.add( type=feature_type, name=str(key)) value = dataset['entries'].get(key) has_data = value is not None and (value['vals'].size != 0 if isinstance( value['vals'], np.ndarray) else value['vals']) commonstats = None # For numeric features, calculate numeric statistics. if feat.type in (self.fs_proto.INT, self.fs_proto.FLOAT): featstats = feat.num_stats commonstats = featstats.common_stats if has_data: nums = value['vals'] featstats.std_dev = np.asscalar(np.std(nums)) featstats.mean = np.asscalar(np.mean(nums)) featstats.min = np.asscalar(np.min(nums)) featstats.max = np.asscalar(np.max(nums)) featstats.median = np.asscalar(np.median(nums)) featstats.num_zeros = len(nums) - np.count_nonzero(nums) nums = np.array(nums) num_nan = len(nums[np.isnan(nums)]) num_posinf = len(nums[np.isposinf(nums)]) num_neginf = len(nums[np.isneginf(nums)]) # Remove all non-finite (including NaN) values from the numeric # values in order to calculate histogram buckets/counts. The # inf values will be added back to the first and last buckets. 
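# Illustrative example of the adjustment below (hypothetical values): for
# nums = [1.0, 2.0, inf] only the finite values [1.0, 2.0] feed np.histogram;
# num_posinf is 1, so after the bucket loop the last bucket's high_value is set
# to inf and its sample_count is incremented by 1.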
nums = nums[np.isfinite(nums)] counts, buckets = np.histogram(nums) hist = featstats.histograms.add() hist.type = self.histogram_proto.STANDARD hist.num_nan = num_nan for bucket_count in range(len(counts)): bucket = hist.buckets.add( low_value=buckets[bucket_count], high_value=buckets[bucket_count + 1], sample_count=np.asscalar(counts[bucket_count])) # Add any negative or positive infinities to the first and last # buckets in the histogram. if bucket_count == 0 and num_neginf > 0: bucket.low_value = float('-inf') bucket.sample_count += num_neginf elif bucket_count == len(counts) - 1 and num_posinf > 0: bucket.high_value = float('inf') bucket.sample_count += num_posinf if not hist.buckets: if num_neginf: hist.buckets.add( low_value=float('-inf'), high_value=float('-inf'), sample_count=num_neginf) if num_posinf: hist.buckets.add( low_value=float('inf'), high_value=float('inf'), sample_count=num_posinf) self._PopulateQuantilesHistogram(featstats.histograms.add(), nums.tolist()) elif feat.type == self.fs_proto.STRING: featstats = feat.string_stats commonstats = featstats.common_stats if has_data: strs = [] for item in value['vals']: strs.append(item if hasattr(item, '__len__') else str(item)) featstats.avg_length = np.mean(np.vectorize(len)(strs)) vals, counts = np.unique(strs, return_counts=True) featstats.unique = len(vals) sorted_vals = sorted(zip(counts, vals), reverse=True) for val_index, val in enumerate(sorted_vals): if val[1].dtype.type is np.str_: printable_val = val[1] else: try: printable_val = val[1].decode('UTF-8', 'strict') except (UnicodeDecodeError, UnicodeEncodeError): printable_val = '__BYTES_VALUE__' bucket = featstats.rank_histogram.buckets.add( low_rank=val_index, high_rank=val_index, sample_count=np.asscalar(val[0]), label=printable_val) if val_index < 2: featstats.top_values.add( value=bucket.label, frequency=bucket.sample_count) # Add the common stats regardless of the feature type. if has_data: commonstats.num_missing = value['missing'] commonstats.num_non_missing = (all_datasets.datasets[j].num_examples - featstats.common_stats.num_missing) commonstats.min_num_values = int(np.min(value['counts']).astype(int)) commonstats.max_num_values = int(np.max(value['counts']).astype(int)) commonstats.avg_num_values = np.mean(value['counts']) if 'feat_lens' in value and value['feat_lens']: self._PopulateQuantilesHistogram( commonstats.feature_list_length_histogram, value['feat_lens']) self._PopulateQuantilesHistogram(commonstats.num_values_histogram, value['counts']) else: commonstats.num_non_missing = 0 commonstats.num_missing = all_datasets.datasets[j].num_examples return all_datasets def _PopulateQuantilesHistogram(self, hist, nums): """Fills in the histogram with quantile information from the provided array. Args: hist: A Histogram proto message to fill in. nums: A list of numbers to create a quantiles histogram from. """ if not nums: return num_quantile_buckets = 10 quantiles_to_get = [ x * 100 / num_quantile_buckets for x in range(num_quantile_buckets + 1) ] quantiles = np.percentile(nums, quantiles_to_get) hist.type = self.histogram_proto.QUANTILES quantiles_sample_count = float(len(nums)) / num_quantile_buckets for low, high in zip(quantiles, quantiles[1:]): hist.buckets.add( low_value=low, high_value=high, sample_count=quantiles_sample_count) ================================================ FILE: google/datalab/utils/facets/feature_statistics_generator.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Class for generating the feature_statistics proto. The proto is used as input for the Overview visualization. """ import warnings from .base_feature_statistics_generator import BaseFeatureStatisticsGenerator from . import feature_statistics_pb2 as fs class FeatureStatisticsGenerator(BaseFeatureStatisticsGenerator): """Generator of stats proto from TF data.""" def __init__(self): BaseFeatureStatisticsGenerator.__init__(self, fs.FeatureNameStatistics, fs.DatasetFeatureStatisticsList, fs.Histogram) def ProtoFromTfRecordFiles(files, max_entries=10000, features=None, is_sequence=False, iterator_options=None): """Creates a feature statistics proto from a set of TFRecord files. Args: files: A list of dicts describing files for each dataset for the proto. Each entry contains a 'path' field with the path to the TFRecord file on disk and a 'name' field to identify the dataset in the proto. max_entries: The maximum number of examples to load from each dataset in order to create the proto. Defaults to 10000. features: A list of strings that is a allowlist of feature names to create feature statistics for. If set to None then all features in the dataset are analyzed. Defaults to None. is_sequence: True if the input data from 'tables' are tf.SequenceExamples, False if tf.Examples. Defaults to false. iterator_options: Options to pass to the iterator that reads the examples. Defaults to None. Returns: The feature statistics proto for the provided files. """ warnings.warn( 'Use GenericFeatureStatisticsGenerator class method instead.', DeprecationWarning) return FeatureStatisticsGenerator().ProtoFromTfRecordFiles( files, max_entries, features, is_sequence, iterator_options) ================================================ FILE: google/datalab/utils/facets/feature_statistics_pb2.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: feature_statistics.proto # flake8: noqa import sys _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database from google.protobuf import descriptor_pb2 # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor.FileDescriptor( name='feature_statistics.proto', package='featureStatistics', syntax='proto3', serialized_pb=_b('\n\x18\x66\x65\x61ture_statistics.proto\x12\x11\x66\x65\x61tureStatistics\"]\n\x1c\x44\x61tasetFeatureStatisticsList\x12=\n\x08\x64\x61tasets\x18\x01 \x03(\x0b\x32+.featureStatistics.DatasetFeatureStatistics\"\x99\x01\n\x18\x44\x61tasetFeatureStatistics\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x14\n\x0cnum_examples\x18\x02 \x01(\x04\x12\x1d\n\x15weighted_num_examples\x18\x04 \x01(\x01\x12:\n\x08\x66\x65\x61tures\x18\x03 \x03(\x0b\x32(.featureStatistics.FeatureNameStatistics\"\x8b\x03\n\x15\x46\x65\x61tureNameStatistics\x12\x0c\n\x04name\x18\x01 \x01(\t\x12;\n\x04type\x18\x02 \x01(\x0e\x32-.featureStatistics.FeatureNameStatistics.Type\x12\x39\n\tnum_stats\x18\x03 \x01(\x0b\x32$.featureStatistics.NumericStatisticsH\x00\x12;\n\x0cstring_stats\x18\x04 \x01(\x0b\x32#.featureStatistics.StringStatisticsH\x00\x12\x39\n\x0b\x62ytes_stats\x18\x05 \x01(\x0b\x32\".featureStatistics.BytesStatisticsH\x00\x12\x38\n\x0c\x63ustom_stats\x18\x06 \x03(\x0b\x32\".featureStatistics.CustomStatistic\"1\n\x04Type\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\t\n\x05\x42YTES\x10\x03\x42\x07\n\x05stats\"x\n\x18WeightedCommonStatistics\x12\x17\n\x0fnum_non_missing\x18\x01 \x01(\x01\x12\x13\n\x0bnum_missing\x18\x02 \x01(\x01\x12\x16\n\x0e\x61vg_num_values\x18\x03 \x01(\x01\x12\x16\n\x0etot_num_values\x18\x04 \x01(\x01\"w\n\x0f\x43ustomStatistic\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x03num\x18\x02 \x01(\x01H\x00\x12\r\n\x03str\x18\x03 \x01(\tH\x00\x12\x31\n\thistogram\x18\x04 \x01(\x0b\x32\x1c.featureStatistics.HistogramH\x00\x42\x05\n\x03val\"\xaa\x02\n\x11NumericStatistics\x12\x39\n\x0c\x63ommon_stats\x18\x01 \x01(\x0b\x32#.featureStatistics.CommonStatistics\x12\x0c\n\x04mean\x18\x02 \x01(\x01\x12\x0f\n\x07std_dev\x18\x03 \x01(\x01\x12\x11\n\tnum_zeros\x18\x04 \x01(\x04\x12\x0b\n\x03min\x18\x05 \x01(\x01\x12\x0e\n\x06median\x18\x06 \x01(\x01\x12\x0b\n\x03max\x18\x07 \x01(\x01\x12\x30\n\nhistograms\x18\x08 \x03(\x0b\x32\x1c.featureStatistics.Histogram\x12L\n\x16weighted_numeric_stats\x18\t \x01(\x0b\x32,.featureStatistics.WeightedNumericStatistics\"\x8c\x03\n\x10StringStatistics\x12\x39\n\x0c\x63ommon_stats\x18\x01 \x01(\x0b\x32#.featureStatistics.CommonStatistics\x12\x0e\n\x06unique\x18\x02 \x01(\x04\x12\x44\n\ntop_values\x18\x03 \x03(\x0b\x32\x30.featureStatistics.StringStatistics.FreqAndValue\x12\x12\n\navg_length\x18\x04 \x01(\x02\x12\x38\n\x0erank_histogram\x18\x05 \x01(\x0b\x32 .featureStatistics.RankHistogram\x12J\n\x15weighted_string_stats\x18\x06 \x01(\x0b\x32+.featureStatistics.WeightedStringStatistics\x1aM\n\x0c\x46reqAndValue\x12\x1b\n\x0f\x64\x65precated_freq\x18\x01 \x01(\x04\x42\x02\x18\x01\x12\r\n\x05value\x18\x02 \x01(\t\x12\x11\n\tfrequency\x18\x03 \x01(\x01\"|\n\x19WeightedNumericStatistics\x12\x0c\n\x04mean\x18\x01 \x01(\x01\x12\x0f\n\x07std_dev\x18\x02 \x01(\x01\x12\x0e\n\x06median\x18\x03 \x01(\x01\x12\x30\n\nhistograms\x18\x04 
\x03(\x0b\x32\x1c.featureStatistics.Histogram\"\x9a\x01\n\x18WeightedStringStatistics\x12\x44\n\ntop_values\x18\x01 \x03(\x0b\x32\x30.featureStatistics.StringStatistics.FreqAndValue\x12\x38\n\x0erank_histogram\x18\x02 \x01(\x0b\x32 .featureStatistics.RankHistogram\"\xa1\x01\n\x0f\x42ytesStatistics\x12\x39\n\x0c\x63ommon_stats\x18\x01 \x01(\x0b\x32#.featureStatistics.CommonStatistics\x12\x0e\n\x06unique\x18\x02 \x01(\x04\x12\x15\n\ravg_num_bytes\x18\x03 \x01(\x02\x12\x15\n\rmin_num_bytes\x18\x04 \x01(\x02\x12\x15\n\rmax_num_bytes\x18\x05 \x01(\x02\"\xed\x02\n\x10\x43ommonStatistics\x12\x17\n\x0fnum_non_missing\x18\x01 \x01(\x04\x12\x13\n\x0bnum_missing\x18\x02 \x01(\x04\x12\x16\n\x0emin_num_values\x18\x03 \x01(\x04\x12\x16\n\x0emax_num_values\x18\x04 \x01(\x04\x12\x16\n\x0e\x61vg_num_values\x18\x05 \x01(\x02\x12\x16\n\x0etot_num_values\x18\x08 \x01(\x04\x12:\n\x14num_values_histogram\x18\x06 \x01(\x0b\x32\x1c.featureStatistics.Histogram\x12J\n\x15weighted_common_stats\x18\x07 \x01(\x0b\x32+.featureStatistics.WeightedCommonStatistics\x12\x43\n\x1d\x66\x65\x61ture_list_length_histogram\x18\t \x01(\x0b\x32\x1c.featureStatistics.Histogram\"\xc4\x02\n\tHistogram\x12\x0f\n\x07num_nan\x18\x01 \x01(\x04\x12\x15\n\rnum_undefined\x18\x02 \x01(\x04\x12\x34\n\x07\x62uckets\x18\x03 \x03(\x0b\x32#.featureStatistics.Histogram.Bucket\x12\x38\n\x04type\x18\x04 \x01(\x0e\x32*.featureStatistics.Histogram.HistogramType\x12\x0c\n\x04name\x18\x05 \x01(\t\x1a\x63\n\x06\x42ucket\x12\x11\n\tlow_value\x18\x01 \x01(\x01\x12\x12\n\nhigh_value\x18\x02 \x01(\x01\x12\x1c\n\x10\x64\x65precated_count\x18\x03 \x01(\x04\x42\x02\x18\x01\x12\x14\n\x0csample_count\x18\x04 \x01(\x01\",\n\rHistogramType\x12\x0c\n\x08STANDARD\x10\x00\x12\r\n\tQUANTILES\x10\x01\"\xc9\x01\n\rRankHistogram\x12\x38\n\x07\x62uckets\x18\x01 \x03(\x0b\x32\'.featureStatistics.RankHistogram.Bucket\x12\x0c\n\x04name\x18\x02 \x01(\t\x1ap\n\x06\x42ucket\x12\x10\n\x08low_rank\x18\x01 \x01(\x04\x12\x11\n\thigh_rank\x18\x02 \x01(\x04\x12\x1c\n\x10\x64\x65precated_count\x18\x03 \x01(\x04\x42\x02\x18\x01\x12\r\n\x05label\x18\x04 \x01(\t\x12\x14\n\x0csample_count\x18\x05 \x01(\x01\x62\x06proto3') ) _FEATURENAMESTATISTICS_TYPE = _descriptor.EnumDescriptor( name='Type', full_name='featureStatistics.FeatureNameStatistics.Type', filename=None, file=DESCRIPTOR, values=[ _descriptor.EnumValueDescriptor( name='INT', index=0, number=0, options=None, type=None), _descriptor.EnumValueDescriptor( name='FLOAT', index=1, number=1, options=None, type=None), _descriptor.EnumValueDescriptor( name='STRING', index=2, number=2, options=None, type=None), _descriptor.EnumValueDescriptor( name='BYTES', index=3, number=3, options=None, type=None), ], containing_type=None, options=None, serialized_start=636, serialized_end=685, ) _sym_db.RegisterEnumDescriptor(_FEATURENAMESTATISTICS_TYPE) _HISTOGRAM_HISTOGRAMTYPE = _descriptor.EnumDescriptor( name='HistogramType', full_name='featureStatistics.Histogram.HistogramType', filename=None, file=DESCRIPTOR, values=[ _descriptor.EnumValueDescriptor( name='STANDARD', index=0, number=0, options=None, type=None), _descriptor.EnumValueDescriptor( name='QUANTILES', index=1, number=1, options=None, type=None), ], containing_type=None, options=None, serialized_start=2735, serialized_end=2779, ) _sym_db.RegisterEnumDescriptor(_HISTOGRAM_HISTOGRAMTYPE) _DATASETFEATURESTATISTICSLIST = _descriptor.Descriptor( name='DatasetFeatureStatisticsList', full_name='featureStatistics.DatasetFeatureStatisticsList', filename=None, file=DESCRIPTOR, containing_type=None, 
fields=[ _descriptor.FieldDescriptor( name='datasets', full_name='featureStatistics.DatasetFeatureStatisticsList.datasets', index=0, number=1, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=47, serialized_end=140, ) _DATASETFEATURESTATISTICS = _descriptor.Descriptor( name='DatasetFeatureStatistics', full_name='featureStatistics.DatasetFeatureStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='name', full_name='featureStatistics.DatasetFeatureStatistics.name', index=0, number=1, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_examples', full_name='featureStatistics.DatasetFeatureStatistics.num_examples', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='weighted_num_examples', full_name='featureStatistics.DatasetFeatureStatistics.weighted_num_examples', index=2, number=4, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='features', full_name='featureStatistics.DatasetFeatureStatistics.features', index=3, number=3, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=143, serialized_end=296, ) _FEATURENAMESTATISTICS = _descriptor.Descriptor( name='FeatureNameStatistics', full_name='featureStatistics.FeatureNameStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='name', full_name='featureStatistics.FeatureNameStatistics.name', index=0, number=1, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='type', full_name='featureStatistics.FeatureNameStatistics.type', index=1, number=2, type=14, cpp_type=8, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_stats', full_name='featureStatistics.FeatureNameStatistics.num_stats', index=2, number=3, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='string_stats', full_name='featureStatistics.FeatureNameStatistics.string_stats', index=3, number=4, type=11, cpp_type=10, label=1, has_default_value=False, 
default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='bytes_stats', full_name='featureStatistics.FeatureNameStatistics.bytes_stats', index=4, number=5, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='custom_stats', full_name='featureStatistics.FeatureNameStatistics.custom_stats', index=5, number=6, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ _FEATURENAMESTATISTICS_TYPE, ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ _descriptor.OneofDescriptor( name='stats', full_name='featureStatistics.FeatureNameStatistics.stats', index=0, containing_type=None, fields=[]), ], serialized_start=299, serialized_end=694, ) _WEIGHTEDCOMMONSTATISTICS = _descriptor.Descriptor( name='WeightedCommonStatistics', full_name='featureStatistics.WeightedCommonStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='num_non_missing', full_name='featureStatistics.WeightedCommonStatistics.num_non_missing', index=0, number=1, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_missing', full_name='featureStatistics.WeightedCommonStatistics.num_missing', index=1, number=2, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='avg_num_values', full_name='featureStatistics.WeightedCommonStatistics.avg_num_values', index=2, number=3, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='tot_num_values', full_name='featureStatistics.WeightedCommonStatistics.tot_num_values', index=3, number=4, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=696, serialized_end=816, ) _CUSTOMSTATISTIC = _descriptor.Descriptor( name='CustomStatistic', full_name='featureStatistics.CustomStatistic', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='name', full_name='featureStatistics.CustomStatistic.name', index=0, number=1, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num', full_name='featureStatistics.CustomStatistic.num', index=1, number=2, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), 
message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='str', full_name='featureStatistics.CustomStatistic.str', index=2, number=3, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='histogram', full_name='featureStatistics.CustomStatistic.histogram', index=3, number=4, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ _descriptor.OneofDescriptor( name='val', full_name='featureStatistics.CustomStatistic.val', index=0, containing_type=None, fields=[]), ], serialized_start=818, serialized_end=937, ) _NUMERICSTATISTICS = _descriptor.Descriptor( name='NumericStatistics', full_name='featureStatistics.NumericStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='common_stats', full_name='featureStatistics.NumericStatistics.common_stats', index=0, number=1, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='mean', full_name='featureStatistics.NumericStatistics.mean', index=1, number=2, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='std_dev', full_name='featureStatistics.NumericStatistics.std_dev', index=2, number=3, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_zeros', full_name='featureStatistics.NumericStatistics.num_zeros', index=3, number=4, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='min', full_name='featureStatistics.NumericStatistics.min', index=4, number=5, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='median', full_name='featureStatistics.NumericStatistics.median', index=5, number=6, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='max', full_name='featureStatistics.NumericStatistics.max', index=6, number=7, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='histograms', full_name='featureStatistics.NumericStatistics.histograms', index=7, number=8, type=11, 
cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='weighted_numeric_stats', full_name='featureStatistics.NumericStatistics.weighted_numeric_stats', index=8, number=9, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=940, serialized_end=1238, ) _STRINGSTATISTICS_FREQANDVALUE = _descriptor.Descriptor( name='FreqAndValue', full_name='featureStatistics.StringStatistics.FreqAndValue', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='deprecated_freq', full_name='featureStatistics.StringStatistics.FreqAndValue.deprecated_freq', index=0, number=1, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))), _descriptor.FieldDescriptor( name='value', full_name='featureStatistics.StringStatistics.FreqAndValue.value', index=1, number=2, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='frequency', full_name='featureStatistics.StringStatistics.FreqAndValue.frequency', index=2, number=3, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=1560, serialized_end=1637, ) _STRINGSTATISTICS = _descriptor.Descriptor( name='StringStatistics', full_name='featureStatistics.StringStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='common_stats', full_name='featureStatistics.StringStatistics.common_stats', index=0, number=1, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='unique', full_name='featureStatistics.StringStatistics.unique', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='top_values', full_name='featureStatistics.StringStatistics.top_values', index=2, number=3, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='avg_length', full_name='featureStatistics.StringStatistics.avg_length', index=3, number=4, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, 
extension_scope=None, options=None), _descriptor.FieldDescriptor( name='rank_histogram', full_name='featureStatistics.StringStatistics.rank_histogram', index=4, number=5, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='weighted_string_stats', full_name='featureStatistics.StringStatistics.weighted_string_stats', index=5, number=6, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[_STRINGSTATISTICS_FREQANDVALUE, ], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=1241, serialized_end=1637, ) _WEIGHTEDNUMERICSTATISTICS = _descriptor.Descriptor( name='WeightedNumericStatistics', full_name='featureStatistics.WeightedNumericStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='mean', full_name='featureStatistics.WeightedNumericStatistics.mean', index=0, number=1, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='std_dev', full_name='featureStatistics.WeightedNumericStatistics.std_dev', index=1, number=2, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='median', full_name='featureStatistics.WeightedNumericStatistics.median', index=2, number=3, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='histograms', full_name='featureStatistics.WeightedNumericStatistics.histograms', index=3, number=4, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=1639, serialized_end=1763, ) _WEIGHTEDSTRINGSTATISTICS = _descriptor.Descriptor( name='WeightedStringStatistics', full_name='featureStatistics.WeightedStringStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='top_values', full_name='featureStatistics.WeightedStringStatistics.top_values', index=0, number=1, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='rank_histogram', full_name='featureStatistics.WeightedStringStatistics.rank_histogram', index=1, number=2, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', 
extension_ranges=[], oneofs=[ ], serialized_start=1766, serialized_end=1920, ) _BYTESSTATISTICS = _descriptor.Descriptor( name='BytesStatistics', full_name='featureStatistics.BytesStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='common_stats', full_name='featureStatistics.BytesStatistics.common_stats', index=0, number=1, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='unique', full_name='featureStatistics.BytesStatistics.unique', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='avg_num_bytes', full_name='featureStatistics.BytesStatistics.avg_num_bytes', index=2, number=3, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='min_num_bytes', full_name='featureStatistics.BytesStatistics.min_num_bytes', index=3, number=4, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='max_num_bytes', full_name='featureStatistics.BytesStatistics.max_num_bytes', index=4, number=5, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=1923, serialized_end=2084, ) _COMMONSTATISTICS = _descriptor.Descriptor( name='CommonStatistics', full_name='featureStatistics.CommonStatistics', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='num_non_missing', full_name='featureStatistics.CommonStatistics.num_non_missing', index=0, number=1, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_missing', full_name='featureStatistics.CommonStatistics.num_missing', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='min_num_values', full_name='featureStatistics.CommonStatistics.min_num_values', index=2, number=3, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='max_num_values', full_name='featureStatistics.CommonStatistics.max_num_values', index=3, number=4, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='avg_num_values', 
full_name='featureStatistics.CommonStatistics.avg_num_values', index=4, number=5, type=2, cpp_type=6, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='tot_num_values', full_name='featureStatistics.CommonStatistics.tot_num_values', index=5, number=8, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_values_histogram', full_name='featureStatistics.CommonStatistics.num_values_histogram', index=6, number=6, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='weighted_common_stats', full_name='featureStatistics.CommonStatistics.weighted_common_stats', index=7, number=7, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='feature_list_length_histogram', full_name='featureStatistics.CommonStatistics.feature_list_length_histogram', index=8, number=9, type=11, cpp_type=10, label=1, has_default_value=False, default_value=None, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=2087, serialized_end=2452, ) _HISTOGRAM_BUCKET = _descriptor.Descriptor( name='Bucket', full_name='featureStatistics.Histogram.Bucket', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='low_value', full_name='featureStatistics.Histogram.Bucket.low_value', index=0, number=1, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='high_value', full_name='featureStatistics.Histogram.Bucket.high_value', index=1, number=2, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='deprecated_count', full_name='featureStatistics.Histogram.Bucket.deprecated_count', index=2, number=3, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))), _descriptor.FieldDescriptor( name='sample_count', full_name='featureStatistics.Histogram.Bucket.sample_count', index=3, number=4, type=1, cpp_type=5, label=1, has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=2634, serialized_end=2733, ) _HISTOGRAM = _descriptor.Descriptor( 
name='Histogram', full_name='featureStatistics.Histogram', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='num_nan', full_name='featureStatistics.Histogram.num_nan', index=0, number=1, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='num_undefined', full_name='featureStatistics.Histogram.num_undefined', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='buckets', full_name='featureStatistics.Histogram.buckets', index=2, number=3, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='type', full_name='featureStatistics.Histogram.type', index=3, number=4, type=14, cpp_type=8, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='name', full_name='featureStatistics.Histogram.name', index=4, number=5, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[_HISTOGRAM_BUCKET, ], enum_types=[ _HISTOGRAM_HISTOGRAMTYPE, ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=2455, serialized_end=2779, ) _RANKHISTOGRAM_BUCKET = _descriptor.Descriptor( name='Bucket', full_name='featureStatistics.RankHistogram.Bucket', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='low_rank', full_name='featureStatistics.RankHistogram.Bucket.low_rank', index=0, number=1, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='high_rank', full_name='featureStatistics.RankHistogram.Bucket.high_rank', index=1, number=2, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='deprecated_count', full_name='featureStatistics.RankHistogram.Bucket.deprecated_count', index=2, number=3, type=4, cpp_type=4, label=1, has_default_value=False, default_value=0, message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))), _descriptor.FieldDescriptor( name='label', full_name='featureStatistics.RankHistogram.Bucket.label', index=3, number=4, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='sample_count', full_name='featureStatistics.RankHistogram.Bucket.sample_count', index=4, number=5, type=1, cpp_type=5, label=1, 
has_default_value=False, default_value=float(0), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=2871, serialized_end=2983, ) _RANKHISTOGRAM = _descriptor.Descriptor( name='RankHistogram', full_name='featureStatistics.RankHistogram', filename=None, file=DESCRIPTOR, containing_type=None, fields=[ _descriptor.FieldDescriptor( name='buckets', full_name='featureStatistics.RankHistogram.buckets', index=0, number=1, type=11, cpp_type=10, label=3, has_default_value=False, default_value=[], message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), _descriptor.FieldDescriptor( name='name', full_name='featureStatistics.RankHistogram.name', index=1, number=2, type=9, cpp_type=9, label=1, has_default_value=False, default_value=_b("").decode('utf-8'), message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None), ], extensions=[ ], nested_types=[_RANKHISTOGRAM_BUCKET, ], enum_types=[ ], options=None, is_extendable=False, syntax='proto3', extension_ranges=[], oneofs=[ ], serialized_start=2782, serialized_end=2983, ) _DATASETFEATURESTATISTICSLIST.fields_by_name['datasets'].message_type = _DATASETFEATURESTATISTICS _DATASETFEATURESTATISTICS.fields_by_name['features'].message_type = _FEATURENAMESTATISTICS _FEATURENAMESTATISTICS.fields_by_name['type'].enum_type = _FEATURENAMESTATISTICS_TYPE _FEATURENAMESTATISTICS.fields_by_name['num_stats'].message_type = _NUMERICSTATISTICS _FEATURENAMESTATISTICS.fields_by_name['string_stats'].message_type = _STRINGSTATISTICS _FEATURENAMESTATISTICS.fields_by_name['bytes_stats'].message_type = _BYTESSTATISTICS _FEATURENAMESTATISTICS.fields_by_name['custom_stats'].message_type = _CUSTOMSTATISTIC _FEATURENAMESTATISTICS_TYPE.containing_type = _FEATURENAMESTATISTICS _FEATURENAMESTATISTICS.oneofs_by_name['stats'].fields.append( _FEATURENAMESTATISTICS.fields_by_name['num_stats']) _FEATURENAMESTATISTICS.fields_by_name['num_stats'].containing_oneof = _FEATURENAMESTATISTICS.oneofs_by_name['stats'] _FEATURENAMESTATISTICS.oneofs_by_name['stats'].fields.append( _FEATURENAMESTATISTICS.fields_by_name['string_stats']) _FEATURENAMESTATISTICS.fields_by_name['string_stats'].containing_oneof = _FEATURENAMESTATISTICS.oneofs_by_name['stats'] _FEATURENAMESTATISTICS.oneofs_by_name['stats'].fields.append( _FEATURENAMESTATISTICS.fields_by_name['bytes_stats']) _FEATURENAMESTATISTICS.fields_by_name['bytes_stats'].containing_oneof = _FEATURENAMESTATISTICS.oneofs_by_name['stats'] _CUSTOMSTATISTIC.fields_by_name['histogram'].message_type = _HISTOGRAM _CUSTOMSTATISTIC.oneofs_by_name['val'].fields.append( _CUSTOMSTATISTIC.fields_by_name['num']) _CUSTOMSTATISTIC.fields_by_name['num'].containing_oneof = _CUSTOMSTATISTIC.oneofs_by_name['val'] _CUSTOMSTATISTIC.oneofs_by_name['val'].fields.append( _CUSTOMSTATISTIC.fields_by_name['str']) _CUSTOMSTATISTIC.fields_by_name['str'].containing_oneof = _CUSTOMSTATISTIC.oneofs_by_name['val'] _CUSTOMSTATISTIC.oneofs_by_name['val'].fields.append( _CUSTOMSTATISTIC.fields_by_name['histogram']) _CUSTOMSTATISTIC.fields_by_name['histogram'].containing_oneof = _CUSTOMSTATISTIC.oneofs_by_name['val'] _NUMERICSTATISTICS.fields_by_name['common_stats'].message_type = _COMMONSTATISTICS _NUMERICSTATISTICS.fields_by_name['histograms'].message_type = _HISTOGRAM 
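# --- Editor's note (illustrative sketch, not part of the generated module; kept commented out so the module still imports unchanged) ---
# The statements above and below only wire field and enum types together; the usable
# message classes (DatasetFeatureStatisticsList, DatasetFeatureStatistics,
# FeatureNameStatistics, Histogram, ...) are generated and registered further down.
# A hypothetical example of building a proto with those classes, assuming the
# Type enum values (INT, FLOAT, STRING, BYTES) declared earlier in this module:
#
#   from google.datalab.utils.facets import feature_statistics_pb2 as fs
#
#   stats_list = fs.DatasetFeatureStatisticsList()
#   dataset = stats_list.datasets.add(name='train', num_examples=3)
#   feature = dataset.features.add(name='age', type=fs.FeatureNameStatistics.INT)
#   feature.num_stats.mean = 35.0   # setting a sub-field selects the 'num_stats' arm of the 'stats' oneof
#   payload = stats_list.SerializeToString()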
_NUMERICSTATISTICS.fields_by_name['weighted_numeric_stats'].message_type = _WEIGHTEDNUMERICSTATISTICS _STRINGSTATISTICS_FREQANDVALUE.containing_type = _STRINGSTATISTICS _STRINGSTATISTICS.fields_by_name['common_stats'].message_type = _COMMONSTATISTICS _STRINGSTATISTICS.fields_by_name['top_values'].message_type = _STRINGSTATISTICS_FREQANDVALUE _STRINGSTATISTICS.fields_by_name['rank_histogram'].message_type = _RANKHISTOGRAM _STRINGSTATISTICS.fields_by_name['weighted_string_stats'].message_type = _WEIGHTEDSTRINGSTATISTICS _WEIGHTEDNUMERICSTATISTICS.fields_by_name['histograms'].message_type = _HISTOGRAM _WEIGHTEDSTRINGSTATISTICS.fields_by_name['top_values'].message_type = _STRINGSTATISTICS_FREQANDVALUE _WEIGHTEDSTRINGSTATISTICS.fields_by_name['rank_histogram'].message_type = _RANKHISTOGRAM _BYTESSTATISTICS.fields_by_name['common_stats'].message_type = _COMMONSTATISTICS _COMMONSTATISTICS.fields_by_name['num_values_histogram'].message_type = _HISTOGRAM _COMMONSTATISTICS.fields_by_name['weighted_common_stats'].message_type = _WEIGHTEDCOMMONSTATISTICS _COMMONSTATISTICS.fields_by_name['feature_list_length_histogram'].message_type = _HISTOGRAM _HISTOGRAM_BUCKET.containing_type = _HISTOGRAM _HISTOGRAM.fields_by_name['buckets'].message_type = _HISTOGRAM_BUCKET _HISTOGRAM.fields_by_name['type'].enum_type = _HISTOGRAM_HISTOGRAMTYPE _HISTOGRAM_HISTOGRAMTYPE.containing_type = _HISTOGRAM _RANKHISTOGRAM_BUCKET.containing_type = _RANKHISTOGRAM _RANKHISTOGRAM.fields_by_name['buckets'].message_type = _RANKHISTOGRAM_BUCKET DESCRIPTOR.message_types_by_name['DatasetFeatureStatisticsList'] = _DATASETFEATURESTATISTICSLIST DESCRIPTOR.message_types_by_name['DatasetFeatureStatistics'] = _DATASETFEATURESTATISTICS DESCRIPTOR.message_types_by_name['FeatureNameStatistics'] = _FEATURENAMESTATISTICS DESCRIPTOR.message_types_by_name['WeightedCommonStatistics'] = _WEIGHTEDCOMMONSTATISTICS DESCRIPTOR.message_types_by_name['CustomStatistic'] = _CUSTOMSTATISTIC DESCRIPTOR.message_types_by_name['NumericStatistics'] = _NUMERICSTATISTICS DESCRIPTOR.message_types_by_name['StringStatistics'] = _STRINGSTATISTICS DESCRIPTOR.message_types_by_name['WeightedNumericStatistics'] = _WEIGHTEDNUMERICSTATISTICS DESCRIPTOR.message_types_by_name['WeightedStringStatistics'] = _WEIGHTEDSTRINGSTATISTICS DESCRIPTOR.message_types_by_name['BytesStatistics'] = _BYTESSTATISTICS DESCRIPTOR.message_types_by_name['CommonStatistics'] = _COMMONSTATISTICS DESCRIPTOR.message_types_by_name['Histogram'] = _HISTOGRAM DESCRIPTOR.message_types_by_name['RankHistogram'] = _RANKHISTOGRAM _sym_db.RegisterFileDescriptor(DESCRIPTOR) DatasetFeatureStatisticsList = _reflection.GeneratedProtocolMessageType('DatasetFeatureStatisticsList', (_message.Message,), dict( DESCRIPTOR = _DATASETFEATURESTATISTICSLIST, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.DatasetFeatureStatisticsList) )) _sym_db.RegisterMessage(DatasetFeatureStatisticsList) DatasetFeatureStatistics = _reflection.GeneratedProtocolMessageType('DatasetFeatureStatistics', (_message.Message,), dict( DESCRIPTOR = _DATASETFEATURESTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.DatasetFeatureStatistics) )) _sym_db.RegisterMessage(DatasetFeatureStatistics) FeatureNameStatistics = _reflection.GeneratedProtocolMessageType('FeatureNameStatistics', (_message.Message,), dict( DESCRIPTOR = _FEATURENAMESTATISTICS, __module__ = 'feature_statistics_pb2' # 
@@protoc_insertion_point(class_scope:featureStatistics.FeatureNameStatistics) )) _sym_db.RegisterMessage(FeatureNameStatistics) WeightedCommonStatistics = _reflection.GeneratedProtocolMessageType('WeightedCommonStatistics', (_message.Message,), dict( DESCRIPTOR = _WEIGHTEDCOMMONSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.WeightedCommonStatistics) )) _sym_db.RegisterMessage(WeightedCommonStatistics) CustomStatistic = _reflection.GeneratedProtocolMessageType('CustomStatistic', (_message.Message,), dict( DESCRIPTOR = _CUSTOMSTATISTIC, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.CustomStatistic) )) _sym_db.RegisterMessage(CustomStatistic) NumericStatistics = _reflection.GeneratedProtocolMessageType('NumericStatistics', (_message.Message,), dict( DESCRIPTOR = _NUMERICSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.NumericStatistics) )) _sym_db.RegisterMessage(NumericStatistics) StringStatistics = _reflection.GeneratedProtocolMessageType('StringStatistics', (_message.Message,), dict( FreqAndValue = _reflection.GeneratedProtocolMessageType('FreqAndValue', (_message.Message,), dict( DESCRIPTOR = _STRINGSTATISTICS_FREQANDVALUE, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.StringStatistics.FreqAndValue) )) , DESCRIPTOR = _STRINGSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.StringStatistics) )) _sym_db.RegisterMessage(StringStatistics) _sym_db.RegisterMessage(StringStatistics.FreqAndValue) WeightedNumericStatistics = _reflection.GeneratedProtocolMessageType('WeightedNumericStatistics', (_message.Message,), dict( DESCRIPTOR = _WEIGHTEDNUMERICSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.WeightedNumericStatistics) )) _sym_db.RegisterMessage(WeightedNumericStatistics) WeightedStringStatistics = _reflection.GeneratedProtocolMessageType('WeightedStringStatistics', (_message.Message,), dict( DESCRIPTOR = _WEIGHTEDSTRINGSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.WeightedStringStatistics) )) _sym_db.RegisterMessage(WeightedStringStatistics) BytesStatistics = _reflection.GeneratedProtocolMessageType('BytesStatistics', (_message.Message,), dict( DESCRIPTOR = _BYTESSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.BytesStatistics) )) _sym_db.RegisterMessage(BytesStatistics) CommonStatistics = _reflection.GeneratedProtocolMessageType('CommonStatistics', (_message.Message,), dict( DESCRIPTOR = _COMMONSTATISTICS, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.CommonStatistics) )) _sym_db.RegisterMessage(CommonStatistics) Histogram = _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), dict( Bucket = _reflection.GeneratedProtocolMessageType('Bucket', (_message.Message,), dict( DESCRIPTOR = _HISTOGRAM_BUCKET, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.Histogram.Bucket) )) , DESCRIPTOR = _HISTOGRAM, __module__ = 'feature_statistics_pb2' # @@protoc_insertion_point(class_scope:featureStatistics.Histogram) )) _sym_db.RegisterMessage(Histogram) _sym_db.RegisterMessage(Histogram.Bucket) RankHistogram = 
_reflection.GeneratedProtocolMessageType('RankHistogram', (_message.Message,), dict(
  Bucket = _reflection.GeneratedProtocolMessageType('Bucket', (_message.Message,), dict(
    DESCRIPTOR = _RANKHISTOGRAM_BUCKET,
    __module__ = 'feature_statistics_pb2'
    # @@protoc_insertion_point(class_scope:featureStatistics.RankHistogram.Bucket)
    ))
  ,
  DESCRIPTOR = _RANKHISTOGRAM,
  __module__ = 'feature_statistics_pb2'
  # @@protoc_insertion_point(class_scope:featureStatistics.RankHistogram)
  ))
_sym_db.RegisterMessage(RankHistogram)
_sym_db.RegisterMessage(RankHistogram.Bucket)

_STRINGSTATISTICS_FREQANDVALUE.fields_by_name['deprecated_freq'].has_options = True
_STRINGSTATISTICS_FREQANDVALUE.fields_by_name['deprecated_freq']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))
_HISTOGRAM_BUCKET.fields_by_name['deprecated_count'].has_options = True
_HISTOGRAM_BUCKET.fields_by_name['deprecated_count']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))
_RANKHISTOGRAM_BUCKET.fields_by_name['deprecated_count'].has_options = True
_RANKHISTOGRAM_BUCKET.fields_by_name['deprecated_count']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\030\001'))
# @@protoc_insertion_point(module_scope)


================================================
FILE: google/datalab/utils/facets/generic_feature_statistics_generator.py
================================================
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for generating the feature_statistics proto from generic data.

The proto is used as input for the Overview visualization.
"""

import warnings

from .base_generic_feature_statistics_generator import BaseGenericFeatureStatisticsGenerator
from . import feature_statistics_pb2 as fs


class GenericFeatureStatisticsGenerator(BaseGenericFeatureStatisticsGenerator):
  """Generator of stats proto from generic data."""

  def __init__(self):
    BaseGenericFeatureStatisticsGenerator.__init__(
        self, fs.FeatureNameStatistics, fs.DatasetFeatureStatisticsList,
        fs.Histogram)


def ProtoFromDataFrames(dataframes):
  """Creates a feature statistics proto from a set of pandas dataframes.

  Args:
    dataframes: A list of dicts describing tables for each dataset for the
        proto. Each entry contains a 'table' field of the dataframe of the
        data and a 'name' field to identify the dataset in the proto.

  Returns:
    The feature statistics proto for the provided tables.
  """
  warnings.warn(
      'Use GenericFeatureStatisticsGenerator class method instead.',
      DeprecationWarning)
  return GenericFeatureStatisticsGenerator().ProtoFromDataFrames(dataframes)


================================================
FILE: install-no-virtualenv.sh
================================================
#!/bin/sh -e
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# Build a distribution package
tsc --module amd --noImplicitAny --outdir datalab/notebook/static datalab/notebook/static/*.ts
pip install .
jupyter nbextension install --py datalab.notebook
rm datalab/notebook/static/*.js


================================================
FILE: install-virtualenv.sh
================================================
#!/bin/sh -e
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

# Build a distribution package
tsc --module amd --noImplicitAny --outdir datalab/notebook/static datalab/notebook/static/*.ts
pip install .
jupyter nbextension install --py datalab.notebook --sys-prefix
rm datalab/notebook/static/*.js


================================================
FILE: legacy_tests/_util/__init__.py
================================================
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.


================================================
FILE: legacy_tests/_util/http_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from __future__ import absolute_import
from __future__ import unicode_literals

import mock
import unittest

# The httplib2 import is implicitly used when mocking its functionality.
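# --- Editor's note (illustrative sketch, not part of the original test file) ---
# The cases below patch 'datalab.utils._http.Http.http.request' and then drive the
# class-level helper. The call shapes they exercise look roughly like:
#
#   parsed = Http.request('http://www.example.org', args={'a': 1})       # GET with a query string
#   parsed = Http.request('http://www.example.org', data={'abc': 123})   # POST with a JSON body
#
# The helper returns the JSON-decoded response body; a non-2xx status raises an
# exception exposing .status and .content (see test_raises_http_error below).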
# pylint: disable=unused-import from datalab.utils._http import Http class TestCases(unittest.TestCase): @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_get_request_is_invoked(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') Http.request('http://www.example.org') self.assertEqual(mock_request.call_count, 1) self.assertEqual(mock_request.call_args[1]['method'], 'GET') @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_post_request_is_invoked(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') Http.request('http://www.example.org', data={}) self.assertEqual(mock_request.call_args[1]['method'], 'POST') @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_explicit_post_request_is_invoked(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') Http.request('http://www.example.org', method='POST') self.assertEqual(mock_request.call_args[1]['method'], 'POST') @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_query_string_format(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') Http.request('http://www.example.org', args={'a': 1, 'b': 'a b c'}) parts = mock_request.call_args[0][0].replace('?', '&').split('&') self.assertEqual(parts[0], 'http://www.example.org') self.assertTrue('a=1' in parts[1:]) self.assertTrue('b=a+b+c' in parts[1:]) @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_formats_json_request(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') data = {'abc': 123} Http.request('http://www.example.org', data=data) self.assertEqual(mock_request.call_args[1]['body'], '{"abc": 123}') self.assertEqual(mock_request.call_args[1]['headers']['Content-Type'], 'application/json') @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_supports_custom_content(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{}') headers = {'Content-Type': 'text/plain'} data = 'custom text' Http.request('http://www.example.org', data=data, headers=headers) self.assertEqual(mock_request.call_args[1]['body'], 'custom text') self.assertEqual(mock_request.call_args[1]['headers']['Content-Type'], 'text/plain') @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_parses_json_response(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, '{"abc":123}') data = Http.request('http://www.example.org') self.assertEqual(data['abc'], 123) @mock.patch('httplib2.Response') @mock.patch('datalab.utils._http.Http.http.request') def test_raises_http_error(self, mock_request, mock_response): TestCases._setup_mocks(mock_request, mock_response, 'Not Found', 404) with self.assertRaises(Exception) as error: Http.request('http://www.example.org') e = error.exception self.assertEqual(e.status, 404) self.assertEqual(e.content, 'Not Found') @staticmethod def _setup_mocks(mock_request, mock_response, content, status=200): response = mock_response() response.status = status mock_request.return_value = (response, content) ================================================ FILE: legacy_tests/_util/lru_cache_tests.py ================================================ # 
Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest from datalab.utils._lru_cache import LRUCache class TestCases(unittest.TestCase): def test_cache_no_entry(self): cache = LRUCache(3) with self.assertRaises(KeyError): cache['a'] def test_cache_lookup(self): cache = LRUCache(4) for x in ['a', 'b', 'c', 'd']: cache[x] = x for x in ['a', 'b', 'c', 'd']: self.assertEqual(x, cache[x]) def test_cache_overflow(self): cache = LRUCache(3) for x in ['a', 'b', 'c', 'd']: cache[x] = x for x in ['b', 'c', 'd']: self.assertEqual(x, cache[x]) with self.assertRaises(KeyError): cache['a'] cache['b'] cache['d'] # 'c' should be LRU now cache['e'] = 'e' with self.assertRaises(KeyError): cache['c'] for x in ['b', 'd', 'e']: self.assertEqual(x, cache[x]) ================================================ FILE: legacy_tests/_util/util_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import imp import unittest from datalab.utils._utils import get_item class TestCases(unittest.TestCase): @staticmethod def _get_data(): m = imp.new_module('baz') exec('x = 99', m.__dict__) data = { 'foo': { 'bar': { 'xyz': 0 }, 'm': m } } return data def test_no_entry(self): data = TestCases._get_data() self.assertIsNone(get_item(data, 'x')) self.assertIsNone(get_item(data, 'bar.x')) self.assertIsNone(get_item(data, 'foo.bar.x')) self.assertIsNone(get_item(globals(), 'datetime.bar.x')) def test_entry(self): data = TestCases._get_data() self.assertEquals(data['foo']['bar']['xyz'], get_item(data, 'foo.bar.xyz')) self.assertEquals(data['foo']['bar'], get_item(data, 'foo.bar')) self.assertEquals(data['foo'], get_item(data, 'foo')) self.assertEquals(data['foo']['m'], get_item(data, 'foo.m')) self.assertEquals(99, get_item(data, 'foo.m.x')) ================================================ FILE: legacy_tests/bigquery/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/bigquery/api_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest import mock import google.auth import datalab.bigquery import datalab.context import datalab.utils from datalab.bigquery._api import Api class TestCases(unittest.TestCase): def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None, expected_headers=None, expected_method=None): url = mock_http_request.call_args[0][0] kwargs = mock_http_request.call_args[1] self.assertEquals(expected_url, url) if expected_args is not None: self.assertEquals(expected_args, kwargs['args']) else: self.assertNotIn('args', kwargs) if expected_data is not None: self.assertEquals(expected_data, kwargs['data']) else: self.assertNotIn('data', kwargs) if expected_headers is not None: self.assertEquals(expected_headers, kwargs['headers']) else: self.assertNotIn('headers', kwargs) if expected_method is not None: self.assertEquals(expected_method, kwargs['method']) else: self.assertNotIn('method', kwargs) @mock.patch('datalab.utils.Http.request') def test_jobs_insert_load(self, mock_http_request): api = TestCases._create_api() api.jobs_insert_load('SOURCE', datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.maxDiff = None expected_data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': ['SOURCE'], 'destinationTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'createDisposition': 'CREATE_NEVER', 'writeDisposition': 'WRITE_EMPTY', 'sourceFormat': 'CSV', 'fieldDelimiter': ',', 'allowJaggedRows': False, 'allowQuotedNewlines': False, 'encoding': 'UTF-8', 'ignoreUnknownValues': False, 'maxBadRecords': 0, 'quote': '"', 'skipLeadingRows': 0 } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) api.jobs_insert_load('SOURCE2', datalab.bigquery._utils.TableName('p2', 'd2', 't2', ''), append=True, create=True, allow_jagged_rows=True, allow_quoted_newlines=True, ignore_unknown_values=True, source_format='JSON', max_bad_records=1) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': ['SOURCE2'], 'destinationTable': { 'projectId': 'p2', 'datasetId': 'd2', 'tableId': 't2' }, 'createDisposition': 'CREATE_IF_NEEDED', 'writeDisposition': 'WRITE_APPEND', 'sourceFormat': 'JSON', 'ignoreUnknownValues': True, 'maxBadRecords': 1 } } } self.validate(mock_http_request, 
'https://www.googleapis.com/bigquery/v2/projects/p2/jobs/', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_jobs_insert_query(self, mock_http_request): api = TestCases._create_api() api.jobs_insert_query('SQL') expected_data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': 'SQL', 'useQueryCache': True, 'userDefinedFunctionResources': [], 'allowLargeResults': False, 'useLegacySql': True, }, 'dryRun': False, 'priority': 'BATCH', }, } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/', expected_data=expected_data) api.jobs_insert_query('SQL2', ['CODE'], table_name=datalab.bigquery._utils.TableName('p', 'd', 't', ''), append=True, dry_run=True, use_cache=False, batch=False, allow_large_results=True, dialect='standard', billing_tier=1) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': 'SQL2', 'useQueryCache': False, 'allowLargeResults': True, 'useLegacySql': False, 'maximumBillingTier': 1, 'destinationTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'writeDisposition': 'WRITE_APPEND', 'userDefinedFunctionResources': [ { 'inlineCode': 'CODE' } ] }, 'dryRun': True, 'priority': 'INTERACTIVE', }, } self.maxDiff = None self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_jobs_query_results(self, mock_http_request): api = TestCases._create_api() api.jobs_query_results('JOB', 'PROJECT', 10, 20, 30) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/queries/JOB', expected_args={'maxResults': 10, 'timeoutMs': 20, 'startIndex': 30}) @mock.patch('datalab.utils.Http.request') def test_jobs_get(self, mock_http_request): api = TestCases._create_api() api.jobs_get('JOB', 'PROJECT') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/jobs/JOB') @mock.patch('datalab.utils.Http.request') def test_datasets_insert(self, mock_http_request): api = TestCases._create_api() api.datasets_insert(datalab.bigquery._utils.DatasetName('p', 'd')) expected_data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': 'p', 'datasetId': 'd', } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/', expected_data=expected_data) api.datasets_insert(datalab.bigquery._utils.DatasetName('p', 'd'), 'FRIENDLY', 'DESCRIPTION') expected_data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': 'p', 'datasetId': 'd' }, 'friendlyName': 'FRIENDLY', 'description': 'DESCRIPTION' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_datasets_delete(self, mock_http_request): api = TestCases._create_api() api.datasets_delete(datalab.bigquery._utils.DatasetName('p', 'd'), False) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_args={}, expected_method='DELETE') api.datasets_delete(datalab.bigquery._utils.DatasetName('p', 'd'), True) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_args={'deleteContents': True}, expected_method='DELETE') @mock.patch('datalab.utils.Http.request') def test_datasets_update(self, mock_http_request): api = TestCases._create_api() api.datasets_update(datalab.bigquery._utils.DatasetName('p', 
'd'), 'INFO') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_method='PUT', expected_data='INFO') @mock.patch('datalab.utils.Http.request') def test_datasets_get(self, mock_http_request): api = TestCases._create_api() api.datasets_get(datalab.bigquery._utils.DatasetName('p', 'd')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d') @mock.patch('datalab.utils.Http.request') def test_datasets_list(self, mock_http_request): api = TestCases._create_api() api.datasets_list() self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/datasets/', expected_args={}) api.datasets_list('PROJECT', 10, 'TOKEN') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/datasets/', expected_args={'maxResults': 10, 'pageToken': 'TOKEN'}) @mock.patch('datalab.utils.Http.request') def test_tables_get(self, mock_http_request): api = TestCases._create_api() api.tables_get(datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t') @mock.patch('datalab.utils.Http.request') def test_tables_list(self, mock_http_request): api = TestCases._create_api() api.tables_list(datalab.bigquery._utils.DatasetName('p', 'd')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_args={}) api.tables_list(datalab.bigquery._utils.DatasetName('p', 'd'), 10, 'TOKEN') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_args={'maxResults': 10, 'pageToken': 'TOKEN'}) @mock.patch('datalab.utils.Http.request') def test_tables_insert(self, mock_http_request): api = TestCases._create_api() api.tables_insert(datalab.bigquery._utils.TableName('p', 'd', 't', '')) expected_data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_data=expected_data) api.tables_insert(datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'SCHEMA', 'QUERY', 'FRIENDLY', 'DESCRIPTION') expected_data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'schema': { 'fields': 'SCHEMA' }, 'view': {'query': 'QUERY'}, 'friendlyName': 'FRIENDLY', 'description': 'DESCRIPTION' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_tabledata_insertAll(self, mock_http_request): api = TestCases._create_api() api.tabledata_insert_all(datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'ROWS') expected_data = { 'kind': 'bigquery#tableDataInsertAllRequest', 'rows': 'ROWS' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/insertAll', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_tabledata_list(self, mock_http_request): api = TestCases._create_api() api.tabledata_list(datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/data', expected_args={}) api.tabledata_list(datalab.bigquery._utils.TableName('p', 'd', 't', ''), 10, 20, 'TOKEN') self.validate(mock_http_request, 
'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/data', expected_args={ 'startIndex': 10, 'maxResults': 20, 'pageToken': 'TOKEN' }) @mock.patch('datalab.utils.Http.request') def test_table_delete(self, mock_http_request): api = TestCases._create_api() api.table_delete(datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t', expected_method='DELETE') @mock.patch('datalab.utils.Http.request') def test_table_extract(self, mock_http_request): api = TestCases._create_api() api.table_extract(datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'DEST') expected_data = { 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'compression': 'GZIP', 'fieldDelimiter': ',', 'printHeader': True, 'destinationUris': ['DEST'], 'destinationFormat': 'CSV', } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) api.table_extract(datalab.bigquery._utils.TableName('p', 'd', 't', ''), ['DEST'], format='JSON', compress=False, field_delimiter=':', print_header=False) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'compression': 'NONE', 'fieldDelimiter': ':', 'printHeader': False, 'destinationUris': ['DEST'], 'destinationFormat': 'JSON', } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) @mock.patch('datalab.utils.Http.request') def test_table_update(self, mock_http_request): api = TestCases._create_api() api.table_update(datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'INFO') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t', expected_method='PUT', expected_data='INFO') @staticmethod def _create_api(): context = TestCases._create_context() return Api(context) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) ================================================ FILE: legacy_tests/bigquery/dataset_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
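# --- Editor's note (illustrative sketch, not part of the original test file) ---
# The legacy BigQuery tests shown here all build a Context around mock credentials,
# so no real project or network call is involved. A standalone sketch equivalent to
# the _create_context helpers defined in these files:
#
#   import mock
#   import google.auth
#   import datalab.context
#
#   def _create_context(project_id='test'):
#     creds = mock.Mock(spec=google.auth.credentials.Credentials)
#     return datalab.context.Context(project_id, creds)
#
# Assuming a standard unittest layout, the suite can typically be run with
# something like: python -m unittest discover -s legacy_tests -p '*_tests.py'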
from __future__ import absolute_import
from __future__ import unicode_literals
from builtins import str

import mock
import unittest

import google.auth
import datalab.bigquery
import datalab.context
import datalab.utils


class TestCases(unittest.TestCase):

  def _check_name_parts(self, dataset):
    parsed_name = dataset._name_parts
    self.assertEqual('test', parsed_name[0])
    self.assertEqual('requestlogs', parsed_name[1])
    self.assertEqual('test:requestlogs', dataset._full_name)
    self.assertEqual('test:requestlogs', str(dataset))

  def test_parse_full_name(self):
    dataset = TestCases._create_dataset('test:requestlogs')
    self._check_name_parts(dataset)

  def test_parse_local_name(self):
    dataset = TestCases._create_dataset('requestlogs')
    self._check_name_parts(dataset)

  def test_parse_dict_full_name(self):
    dataset = TestCases._create_dataset({'project_id': 'test', 'dataset_id': 'requestlogs'})
    self._check_name_parts(dataset)

  def test_parse_dict_local_name(self):
    dataset = TestCases._create_dataset({'dataset_id': 'requestlogs'})
    self._check_name_parts(dataset)

  def test_parse_named_tuple_name(self):
    dataset = TestCases._create_dataset(datalab.bigquery._utils.DatasetName('test', 'requestlogs'))
    self._check_name_parts(dataset)

  def test_parse_tuple_full_name(self):
    dataset = TestCases._create_dataset(('test', 'requestlogs'))
    self._check_name_parts(dataset)

  def test_parse_tuple_local(self):
    dataset = TestCases._create_dataset(('requestlogs'))
    self._check_name_parts(dataset)

  def test_parse_array_full_name(self):
    dataset = TestCases._create_dataset(['test', 'requestlogs'])
    self._check_name_parts(dataset)

  def test_parse_array_local(self):
    dataset = TestCases._create_dataset(['requestlogs'])
    self._check_name_parts(dataset)

  def test_parse_invalid_name(self):
    with self.assertRaises(Exception):
      TestCases._create_dataset('today@')

  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  def test_dataset_exists(self, mock_api_datasets_get):
    mock_api_datasets_get.return_value = ''
    dataset = TestCases._create_dataset('test:requestlogs')
    self.assertTrue(dataset.exists())
    mock_api_datasets_get.side_effect = datalab.utils.RequestException(404, None)
    dataset._info = None
    self.assertFalse(dataset.exists())

  @mock.patch('datalab.bigquery._api.Api.datasets_insert')
  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  def test_datasets_create_fails(self, mock_api_datasets_get, mock_api_datasets_insert):
    mock_api_datasets_get.side_effect = datalab.utils.RequestException(None, 404)
    mock_api_datasets_insert.return_value = {}
    ds = TestCases._create_dataset('requestlogs')
    with self.assertRaises(Exception):
      ds.create()

  @mock.patch('datalab.bigquery._api.Api.datasets_insert')
  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  def test_datasets_create_succeeds(self, mock_api_datasets_get, mock_api_datasets_insert):
    mock_api_datasets_get.side_effect = datalab.utils.RequestException(404, None)
    mock_api_datasets_insert.return_value = {'selfLink': None}
    ds = TestCases._create_dataset('requestlogs')
    self.assertEqual(ds, ds.create())

  @mock.patch('datalab.bigquery._api.Api.datasets_insert')
  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  def test_datasets_create_redundant(self, mock_api_datasets_get, mock_api_datasets_insert):
    ds = TestCases._create_dataset('requestlogs', {})
    mock_api_datasets_get.return_value = None
    mock_api_datasets_insert.return_value = {}
    self.assertEqual(ds, ds.create())

  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  @mock.patch('datalab.bigquery._api.Api.datasets_delete')
  def test_datasets_delete_succeeds(self, mock_api_datasets_delete, mock_api_datasets_get):
    mock_api_datasets_get.return_value = ''
    mock_api_datasets_delete.return_value = None
    ds = TestCases._create_dataset('requestlogs')
    self.assertIsNone(ds.delete())

  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  @mock.patch('datalab.bigquery._api.Api.datasets_delete')
  def test_datasets_delete_fails(self, mock_api_datasets_delete, mock_api_datasets_get):
    mock_api_datasets_delete.return_value = None
    mock_api_datasets_get.side_effect = datalab.utils.RequestException(404, None)
    ds = TestCases._create_dataset('requestlogs')
    with self.assertRaises(Exception):
      ds.delete()

  @mock.patch('datalab.bigquery._api.Api.tables_list')
  def test_tables_list(self, mock_api_tables_list):
    mock_api_tables_list.return_value = {
      'tables': [
        {
          'type': 'TABLE',
          'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't1'}
        },
        {
          'type': 'TABLE',
          'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't2'}
        },
      ]
    }
    ds = TestCases._create_dataset('requestlogs')
    tables = [table for table in ds]
    self.assertEqual(2, len(tables))
    self.assertEqual('p:d.t1', str(tables[0]))
    self.assertEqual('p:d.t2', str(tables[1]))

  @mock.patch('datalab.bigquery.Dataset._get_info')
  @mock.patch('datalab.bigquery._api.Api.datasets_list')
  def test_datasets_list(self, mock_api_datasets_list, mock_dataset_get_info):
    mock_api_datasets_list.return_value = {
      'datasets': [
        {'datasetReference': {'projectId': 'p', 'datasetId': 'd1'}},
        {'datasetReference': {'projectId': 'p', 'datasetId': 'd2'}},
      ]
    }
    mock_dataset_get_info.return_value = {}
    datasets = [dataset for dataset in
                datalab.bigquery.Datasets('test', TestCases._create_context())]
    self.assertEqual(2, len(datasets))
    self.assertEqual('p:d1', str(datasets[0]))
    self.assertEqual('p:d2', str(datasets[1]))

  @mock.patch('datalab.bigquery._api.Api.tables_list')
  @mock.patch('datalab.bigquery._api.Api.datasets_get')
  @mock.patch('datalab.bigquery._api.Api.datasets_update')
  def test_datasets_update(self, mock_api_datasets_update, mock_api_datasets_get,
                           mock_api_tables_list):
    mock_api_tables_list.return_value = {
      'tables': [
        {'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't1'}},
        {'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't2'}},
      ]
    }
    info = {'friendlyName': 'casper', 'description': 'ghostly logs'}
    mock_api_datasets_get.return_value = info
    ds = TestCases._create_dataset('requestlogs')
    new_friendly_name = 'aziraphale'
    new_description = 'demon duties'
    ds.update(new_friendly_name, new_description)
    name, info = mock_api_datasets_update.call_args[0]
    self.assertEqual(ds.name, name)
    self.assertEqual(new_friendly_name, ds.friendly_name)
    self.assertEqual(new_description, ds.description)

  @staticmethod
  def _create_context():
    project_id = 'test'
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return datalab.context.Context(project_id, creds)

  @staticmethod
  def _create_dataset(name, metadata=None):
    # Patch get_info so we don't have to mock it everywhere else.
    orig = datalab.bigquery.Dataset._get_info
    datalab.bigquery.Dataset._get_info = mock.Mock(return_value=metadata)
    ds = datalab.bigquery.Dataset(name, context=TestCases._create_context())
    datalab.bigquery.Dataset._get_info = orig
    return ds



================================================
FILE: legacy_tests/bigquery/federated_table_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from __future__ import absolute_import
from __future__ import unicode_literals

import collections
import mock
import unittest

import google.auth
import datalab.bigquery
import datalab.context
import datalab.utils


class TestCases(unittest.TestCase):

  # The main thing we need to test is a query that references an external table and how
  # that translates into a REST call.

  @staticmethod
  def _request_result():
    return {
      'jobReference': {
        'jobId': 'job1234'
      },
      'configuration': {
        'query': {
          'destinationTable': {
            'projectId': 'test',
            'datasetId': 'dataset',
            'tableId': 'table'
          }
        }
      },
      'jobComplete': True
    }

  @staticmethod
  def _get_data():
    data = []
    day = 1
    for weight in [220, 221, 220, 219, 218]:
      d = collections.OrderedDict()
      data.append(d)
      d['day'] = day
      day += 1
      d['weight'] = weight
    return data

  @staticmethod
  def _get_table_definition(uris, skip_rows=0):
    if not isinstance(uris, list):
      uris = [uris]
    return {
      'compression': 'NONE',
      'csvOptions': {
        'allowJaggedRows': False,
        'quote': '"',
        'encoding': 'UTF-8',
        'skipLeadingRows': skip_rows,
        'fieldDelimiter': ',',
        'allowQuotedNewlines': False
      },
      'sourceFormat': 'CSV',
      'maxBadRecords': 0,
      'ignoreUnknownValues': False,
      'sourceUris': uris,
      'schema': {
        'fields': [
          {'type': 'INTEGER', 'name': 'day'},
          {'type': 'INTEGER', 'name': 'weight'}
        ]
      }
    }

  @staticmethod
  def _get_expected_request_data(sql, table_definitions):
    return {
      'kind': 'bigquery#job',
      'configuration': {
        'priority': 'INTERACTIVE',
        'query': {
          'query': sql,
          'useLegacySql': True,
          'allowLargeResults': False,
          'tableDefinitions': table_definitions,
          'useQueryCache': True,
          'userDefinedFunctionResources': []
        },
        'dryRun': False
      }
    }

  @mock.patch('datalab.utils.Http.request')
  def test_external_table_query(self, mock_http_request):
    mock_http_request.return_value = self._request_result()
    data = self._get_data()
    schema = datalab.bigquery.Schema.from_data(data)
    table_uri = 'gs://datalab/weight.csv'
    options = datalab.bigquery.CSVOptions(skip_leading_rows=1)
    sql = 'SELECT * FROM weight'
    weight = datalab.bigquery.FederatedTable.from_storage(table_uri, schema=schema,
                                                          csv_options=options)
    q = datalab.bigquery.Query(sql, data_sources={'weight': weight},
                               context=self._create_context())
    q.execute_async()
    table_definition = self._get_table_definition(table_uri, skip_rows=1)
    expected_data = self._get_expected_request_data(sql, {'weight': table_definition})
    request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/'
    mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data)

  # Test with multiple URLs and no non-default options
  @mock.patch('datalab.utils.Http.request')
  def test_external_table_query2(self, mock_http_request):
    mock_http_request.return_value = self._request_result()
    data = self._get_data()
    schema = datalab.bigquery.Schema.from_data(data)
    table_uris = ['gs://datalab/weight1.csv', 'gs://datalab/weight2.csv']
    sql = 'SELECT * FROM weight'
    weight = datalab.bigquery.FederatedTable.from_storage(table_uris, schema=schema)
    q = datalab.bigquery.Query(sql, data_sources={'weight': weight},
                               context=self._create_context())
    q.execute_async()
    table_definition = self._get_table_definition(table_uris)
    expected_data = self._get_expected_request_data(sql, {'weight': table_definition})
    request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/'
    mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data)

  # Test with multiple tables and using keyword args
  @mock.patch('datalab.utils.Http.request')
  def test_external_tables_query(self, mock_http_request):
    mock_http_request.return_value = self._request_result()
    data = self._get_data()
    schema = datalab.bigquery.Schema.from_data(data)
    table_uri1 = 'gs://datalab/weight1.csv'
    table_uri2 = 'gs://datalab/weight2.csv'
    sql = 'SELECT * FROM weight1 JOIN weight2 ON day'
    options = datalab.bigquery.CSVOptions(skip_leading_rows=1)
    weight1 = datalab.bigquery.FederatedTable.from_storage(table_uri1, schema=schema,
                                                           csv_options=options)
    weight2 = datalab.bigquery.FederatedTable.from_storage(table_uri2, schema=schema)
    q = datalab.bigquery.Query(sql, weight1=weight1, weight2=weight2,
                               context=self._create_context())
    q.execute_async()
    table_definition1 = self._get_table_definition(table_uri1, skip_rows=1)
    table_definition2 = self._get_table_definition(table_uri2)
    table_definitions = {'weight1': table_definition1, 'weight2': table_definition2}
    expected_data = self._get_expected_request_data(sql, table_definitions)
    request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/'
    mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data)

  @staticmethod
  def _create_context():
    project_id = 'test'
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return datalab.context.Context(project_id, creds)



================================================
FILE: legacy_tests/bigquery/jobs_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
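# Usage sketch (illustrative only; not part of the original file): the job
# status surface exercised by these tests. `ctx` stands in for the mocked
# context returned by _create_context().
#
#   j = datalab.bigquery.Job('job_id', ctx)
#   j.is_complete   # True once jobs_get reports state 'DONE'
#   j.failed        # True when the completed job carries an errorResult
#   j.fatal_error   # error object exposing .location/.message/.reason, or None
#   j.errors        # list of non-fatal errors with the same attributes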
from __future__ import absolute_import
from __future__ import unicode_literals

import mock
import unittest

import google.auth
import datalab.bigquery
import datalab.context


class TestCases(unittest.TestCase):

  @staticmethod
  def _make_job(id):
    return datalab.bigquery.Job(id, TestCases._create_context())

  @mock.patch('datalab.bigquery._api.Api.jobs_get')
  def test_job_complete(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {}
    j = TestCases._make_job('foo')
    self.assertFalse(j.is_complete)
    self.assertFalse(j.failed)
    mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}}
    self.assertTrue(j.is_complete)
    self.assertFalse(j.failed)

  @mock.patch('datalab.bigquery._api.Api.jobs_get')
  def test_job_fatal_error(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {
      'status': {
        'state': 'DONE',
        'errorResult': {
          'location': 'A',
          'message': 'B',
          'reason': 'C'
        }
      }
    }
    j = TestCases._make_job('foo')
    self.assertTrue(j.is_complete)
    self.assertTrue(j.failed)
    e = j.fatal_error
    self.assertIsNotNone(e)
    self.assertEqual('A', e.location)
    self.assertEqual('B', e.message)
    self.assertEqual('C', e.reason)

  @mock.patch('datalab.bigquery._api.Api.jobs_get')
  def test_job_errors(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {
      'status': {
        'state': 'DONE',
        'errors': [
          {
            'location': 'A',
            'message': 'B',
            'reason': 'C'
          },
          {
            'location': 'D',
            'message': 'E',
            'reason': 'F'
          }
        ]
      }
    }
    j = TestCases._make_job('foo')
    self.assertTrue(j.is_complete)
    self.assertFalse(j.failed)
    self.assertEqual(2, len(j.errors))
    self.assertEqual('A', j.errors[0].location)
    self.assertEqual('B', j.errors[0].message)
    self.assertEqual('C', j.errors[0].reason)
    self.assertEqual('D', j.errors[1].location)
    self.assertEqual('E', j.errors[1].message)
    self.assertEqual('F', j.errors[1].reason)

  @staticmethod
  def _create_context():
    project_id = 'test'
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return datalab.context.Context(project_id, creds)

  @staticmethod
  def _create_api():
    return datalab.bigquery._api.Api(TestCases._create_context())



================================================
FILE: legacy_tests/bigquery/parser_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
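# Usage sketch (illustrative only; not part of the original file): these tests
# feed Parser.parse_row a BigQuery schema (a list of field descriptors) plus a
# raw row payload of the form {'f': [{'v': ...}, ...]} as returned by the
# tabledata API, and expect a flat dict keyed by field name:
#
#   row = bq._parser.Parser.parse_row(schema, data)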
from __future__ import absolute_import from __future__ import unicode_literals import unittest import datalab.bigquery as bq class TestCases(unittest.TestCase): def test_repeating_data(self): schema = [{'name': 'counts', 'type': 'INTEGER', 'mode': 'REPEATED'}] data = {'f': [{'v': [{'v': 0}, {'v': 1}, {'v': 2}]}]} parsed = {'counts': [0, 1, 2]} result = bq._parser.Parser.parse_row(schema, data) self.assertEqual(parsed, result) def test_non_nested_data(self): data = {u'f': [{u'v': u'1969'}, {u'v': u'1969'}, {u'v': u'1'}, {u'v': u'20'}, {u'v': None}, {u'v': u'AL'}, {u'v': u'true'}, {u'v': u'1'}, {u'v': u'7.81318256528'}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'AL'}, {u'v': u'1'}, {u'v': u'20'}, {u'v': None}, {u'v': u'88881998'}, {u'v': u'true'}, {u'v': u''}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'1'}, {u'v': u'0'}, {u'v': u'0'}, {u'v': u'2'}, {u'v': u'1'}, {u'v': u'19'}, {u'v': u'2'}]} natality_schema = [{u'description': u'Four-digit year of the birth. Example: 1975.', u'mode': u'REQUIRED', u'name': u'source_year', u'type': u'INTEGER'}, {u'description': u'Four-digit year of the birth. Example: 1975.', u'mode': u'NULLABLE', u'name': u'year', u'type': u'INTEGER'}, {u'description': u'Month index of the date of birth, where 1=January.', u'mode': u'NULLABLE', u'name': u'month', u'type': u'INTEGER'}, {u'description': u'Day of birth, starting from 1.', u'mode': u'NULLABLE', u'name': u'day', u'type': u'INTEGER'}, {u'description': u'Day of the week, where 1 is Sunday and 7 is Saturday.', u'mode': u'NULLABLE', u'name': u'wday', u'type': u'INTEGER'}, {u'description': u'The two character postal code for the state. ' u'Entries after 2004 do not include this value.', u'mode': u'NULLABLE', u'name': u'state', u'type': u'STRING'}, {u'description': u'TRUE if the child is male, FALSE if female.', u'mode': u'REQUIRED', u'name': u'is_male', u'type': u'BOOLEAN'}, {u'description': u'The race of the child. One of the following numbers:\n\n' u'1 - White\n2 - Black\n3 - American Indian\n4 - Chinese\n' u'5 - Japanese\n6 - Hawaiian\n7 - Filipino\n' u'9 - Unknown/Other\n18 - Asian Indian\n28 - Korean\n' u'39 - Samoan\n48 - Vietnamese', u'mode': u'NULLABLE', u'name': u'child_race', u'type': u'INTEGER'}, {u'description': u'Weight of the child, in pounds.', u'mode': u'NULLABLE', u'name': u'weight_pounds', u'type': u'FLOAT'}, {u'description': u'How many children were born as a result of this ' u'pregnancy. twins=2, triplets=3, and so on.', u'mode': u'NULLABLE', u'name': u'plurality', u'type': u'INTEGER'}, {u'description': u'Apgar scores measure the health of a newborn child on a ' u'scale from 0-10. Value after 1 minute. Available from ' u'1978-2002.', u'mode': u'NULLABLE', u'name': u'apgar_1min', u'type': u'INTEGER'}, {u'description': u'Apgar scores measure the health of a newborn child on a ' u'scale from 0-10. Value after 5 minutes. Available from ' u'1978-2002.', u'mode': u'NULLABLE', u'name': u'apgar_5min', u'type': u'INTEGER'}, {u'description': u"The two-letter postal code of the mother's state of " u"residence when the child was born.", u'mode': u'NULLABLE', u'name': u'mother_residence_state', u'type': u'STRING'}, {u'description': u'Race of the mother. 
Same values as child_race.', u'mode': u'NULLABLE', u'name': u'mother_race', u'type': u'INTEGER'}, {u'description': u'Reported age of the mother when giving birth.', u'mode': u'NULLABLE', u'name': u'mother_age', u'type': u'INTEGER'}, {u'description': u'The number of weeks of the pregnancy.', u'mode': u'NULLABLE', u'name': u'gestation_weeks', u'type': u'INTEGER'}, {u'description': u'Date of the last menstrual period in the format ' u'MMDDYYYY. Unknown values are recorded as "99" or "9999".', u'mode': u'NULLABLE', u'name': u'lmp', u'type': u'STRING'}, {u'description': u'True if the mother was married when she gave birth.', u'mode': u'NULLABLE', u'name': u'mother_married', u'type': u'BOOLEAN'}, {u'description': u"The two-letter postal code of the mother's birth state.", u'mode': u'NULLABLE', u'name': u'mother_birth_state', u'type': u'STRING'}, {u'description': u'True if the mother smoked cigarettes. Available starting ' u'2003.', u'mode': u'NULLABLE', u'name': u'cigarette_use', u'type': u'BOOLEAN'}, {u'description': u'Number of cigarettes smoked by the mother per day. ' u'Available starting 2003.', u'mode': u'NULLABLE', u'name': u'cigarettes_per_day', u'type': u'INTEGER'}, {u'description': u'True if the mother used alcohol. Available starting ' u'1989.', u'mode': u'NULLABLE', u'name': u'alcohol_use', u'type': u'BOOLEAN'}, {u'description': u'Number of drinks per week consumed by the mother. ' u'Available starting 1989.', u'mode': u'NULLABLE', u'name': u'drinks_per_week', u'type': u'INTEGER'}, {u'description': u'Number of pounds gained by the mother during pregnancy.', u'mode': u'NULLABLE', u'name': u'weight_gain_pounds', u'type': u'INTEGER'}, {u'description': u'Number of children previously born to the mother who are ' u'now living.', u'mode': u'NULLABLE', u'name': u'born_alive_alive', u'type': u'INTEGER'}, {u'description': u'Number of children previously born to the mother who are ' u'now dead.', u'mode': u'NULLABLE', u'name': u'born_alive_dead', u'type': u'INTEGER'}, {u'description': u'Number of children who were born dead ' u'(i.e. miscarriages)', u'mode': u'NULLABLE', u'name': u'born_dead', u'type': u'INTEGER'}, {u'description': u'Total number of children to whom the woman has ever ' u'given birth (includes the current birth).', u'mode': u'NULLABLE', u'name': u'ever_born', u'type': u'INTEGER'}, {u'description': u'Race of the father. 
Same values as child_race.', u'mode': u'NULLABLE', u'name': u'father_race', u'type': u'INTEGER'}, {u'description': u'Age of the father when the child was born.', u'mode': u'NULLABLE', u'name': u'father_age', u'type': u'INTEGER'}, {u'description': u'1 or 2, where 1 is a row from a full-reporting area, and ' u'2 is a row from a 50% sample area.', u'mode': u'NULLABLE', u'name': u'record_weight', u'type': u'INTEGER'}] parsed = {u'alcohol_use': None, u'apgar_1min': None, u'apgar_5min': None, u'born_alive_alive': 1, u'born_alive_dead': 0, u'born_dead': 0, u'child_race': 1, u'cigarette_use': None, u'cigarettes_per_day': None, u'day': 20, u'drinks_per_week': None, u'ever_born': 2, u'father_age': 19, u'father_race': 1, u'gestation_weeks': None, u'is_male': True, u'lmp': u'88881998', u'month': 1, u'mother_age': 20, u'mother_birth_state': u'', u'mother_married': True, u'mother_race': 1, u'mother_residence_state': u'AL', u'plurality': None, u'record_weight': 2, u'source_year': 1969, u'state': u'AL', u'wday': None, u'weight_gain_pounds': None, u'weight_pounds': 7.81318256528, u'year': 1969} self.assertEqual(parsed, bq._parser.Parser.parse_row(natality_schema, data)) def test_parse_nested_data(self): self.maxDiff = None # Show full diff on failure data = {u'f': [{u'v': {u'f': [{u'v': u'https://github.com/foo'}, {u'v': u'true'}, {u'v': u'2011/04/12 20:04:19 -0700'}, {u'v': u'true'}, {u'v': u'A website.'}, {u'v': u'17'}, {u'v': u'false'}, {u'v': u'true'}, {u'v': u'http://foo.com/'}, {u'v': None}, {u'v': None}, {u'v': u'424'}, {u'v': u'false'}, {u'v': u'foo'}, {u'v': None}, {u'v': u'foo'}, {u'v': u'0'}, {u'v': u'95'}, {u'v': u'2012/03/15 00:00:00 -0700'}, {u'v': u'Ruby'}]}}, {u'v': {u'f': [{u'v': u'http://foo.com/'}, {u'v': u'Flickr'}, {u'v': u'd+github@foo.com'}, {u'v': u'94c21234567890abcdef25e704b88407'}, {u'v': u'San Francisco, California'}, {u'v': u'foo'}, {u'v': u'Foo Bar'}, {u'v': u'User'}]}}, {u'v': u'2012/03/15 00:00:01 -0700'}, {u'v': u'true'}, {u'v': u'foo'}, {u'v': {u'f': [{u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'2de950123456789abcdef01234451feaf8ce6ae'}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': []}, {u'v': None}, {u'v': u'refs/heads/master'}, {u'v': None}, {u'v': u'1'}, {u'v': [ {u'v': {u'f': [ {u'v': u'2de958ab480eabe2501b343425b451feaf8ce6ae'}, {u'v': u'd+github@foo.com'}, {u'v': u'Foo tastes good.'}, {u'v': u'Foo Bar'}]}}]}, {u'v': None}, {u'v': None}]}}, {u'v': u'https://github.com/compare/d3e91cb736...2de958ab48'}, {u'v': u'PushEvent'}]} github_nested_schema = [{u'fields': [{u'name': u'url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'integrate_branch', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'organization', u'type': u'STRING'}, {u'name': u'owner', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'language', 
u'type': u'STRING'}], u'name': u'repository', u'type': u'RECORD'}, {u'fields': [{u'name': u'blog', u'type': u'STRING'}, {u'name': u'company', u'type': u'STRING'}, {u'name': u'email', u'type': u'STRING'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'location', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'type', u'type': u'STRING'}], u'name': u'actor_attributes', u'type': u'RECORD'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'public', u'type': u'BOOLEAN'}, {u'name': u'actor', u'type': u'STRING'}, {u'fields': [ {u'name': u'action', u'type': u'STRING'}, {u'name': u'after', u'type': u'STRING'}, {u'name': u'before', u'type': u'STRING'}, {u'name': u'commit', u'type': u'STRING'}, {u'fields': [ {u'name': u'commit_id', u'type': u'STRING'}, {u'name': u'body', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'original_commit_id', u'type': u'STRING'}, {u'name': u'original_position', u'type': u'INTEGER'}, {u'name': u'path', u'type': u'STRING'}, {u'name': u'position', u'type': u'INTEGER'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'comment', u'type': u'RECORD'}, {u'name': u'comment_id', u'type': u'INTEGER'}, {u'name': u'desc', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'head', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'issue', u'type': u'INTEGER'}, {u'name': u'issue_id', u'type': u'INTEGER'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'master', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'member', u'type': u'RECORD'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'number', u'type': u'INTEGER'}, {u'fields': [{u'name': u'action', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'page_name', u'type': u'STRING'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'summary', u'type': u'STRING'}, {u'name': u'title', u'type': u'STRING'}], u'mode': u'REPEATED', u'name': u'pages', u'type': u'RECORD'}, {u'fields': [ {u'name': u'additions', u'type': u'INTEGER'}, {u'fields': [ {u'fields': [ {u'name': u'clone_url', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'git_url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'language', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', 
u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'owner', u'type': u'RECORD'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'ssh_url', u'type': u'STRING'}, {u'name': u'svn_url', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'repo', u'type': u'RECORD'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'ref', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'label', u'type': u'STRING'}], u'name': u'base', u'type': u'RECORD'}, {u'name': u'body', u'type': u'STRING'}, {u'name': u'changed_files', u'type': u'INTEGER'}, {u'name': u'closed_at', u'type': u'STRING'}, {u'name': u'comments', u'type': u'INTEGER'}, {u'name': u'commits', u'type': u'INTEGER'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'deletions', u'type': u'INTEGER'}, {u'name': u'diff_url', u'type': u'STRING'}, {u'fields': [ {u'fields': [ {u'name': u'clone_url', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'git_url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'language', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'owner', u'type': u'RECORD'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'ssh_url', u'type': u'STRING'}, {u'name': u'svn_url', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'repo', u'type': u'RECORD'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'ref', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'label', u'type': u'STRING'}], u'name': u'head', u'type': u'RECORD'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'issue_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'fields': [ {u'fields': [ {u'name': u'href', u'type': u'STRING'}], u'name': u'self', u'type': u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'html', u'type': u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'review_comments', u'type': 
u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'comments', u'type': u'RECORD'}], u'name': u'_links', u'type': u'RECORD'}, {u'name': u'mergeable', u'type': u'BOOLEAN'}, {u'name': u'merged', u'type': u'BOOLEAN'}, {u'name': u'merged_at', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'merged_by', u'type': u'RECORD'}, {u'name': u'number', u'type': u'INTEGER'}, {u'name': u'patch_url', u'type': u'STRING'}, {u'name': u'review_comments', u'type': u'INTEGER'}, {u'name': u'state', u'type': u'STRING'}, {u'name': u'title', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}], u'name': u'pull_request', u'type': u'RECORD'}, {u'name': u'ref', u'type': u'STRING'}, {u'name': u'ref_type', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'encoded', u'type': u'STRING'}, {u'name': u'actor_email', u'type': u'STRING'}, {u'name': u'message', u'type': u'STRING'}, {u'name': u'actor_login', u'type': u'STRING'}], u'mode': u'REPEATED', u'name': u'shas', u'type': u'RECORD'}, {u'fields': [{u'name': u'login', u'type': u'STRING'}, {u'name': u'repos', u'type': u'INTEGER'}, {u'name': u'followers', u'type': u'INTEGER'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}], u'name': u'target', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'payload', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}, {u'name': u'type', u'type': u'STRING'}] parsed = {u'actor': u'foo', u'actor_attributes': {u'blog': u'http://foo.com/', u'company': u'Flickr', u'email': u'd+github@foo.com', u'gravatar_id': u'94c21234567890abcdef25e704b88407', u'location': u'San Francisco, California', u'login': u'foo', u'name': u'Foo Bar', u'type': u'User'}, u'created_at': u'2012/03/15 00:00:01 -0700', u'payload': {u'action': None, u'after': None, u'before': None, u'comment': {}, u'comment_id': None, u'commit': None, u'desc': None, u'description': None, u'head': u'2de950123456789abcdef01234451feaf8ce6ae', u'id': None, u'issue': None, u'issue_id': None, u'master': None, u'master_branch': None, u'member': {}, u'name': None, u'number': None, u'pages': [], u'pull_request': {}, u'ref': u'refs/heads/master', u'ref_type': None, u'shas': [{u'actor_email': u'd+github@foo.com', u'actor_login': u'Foo Bar', u'encoded': u'2de958ab480eabe2501b343425b451feaf8ce6ae', u'message': u'Foo tastes good.'}], u'size': 1, u'target': {}, u'url': None}, u'public': True, u'repository': {u'created_at': u'2011/04/12 20:04:19 -0700', u'description': u'A website.', u'fork': False, u'forks': 17, u'has_downloads': True, u'has_issues': True, u'has_wiki': True, u'homepage': u'http://foo.com/', u'integrate_branch': None, u'language': u'Ruby', u'master_branch': None, u'name': u'foo', u'open_issues': 0, u'organization': None, u'owner': u'foo', u'private': False, u'pushed_at': u'2012/03/15 00:00:00 -0700', u'size': 424, u'url': u'https://github.com/foo', u'watchers': 95}, u'type': u'PushEvent', u'url': 
u'https://github.com/compare/d3e91cb736...2de958ab48'} self.assertEqual(parsed, bq._parser.Parser.parse_row(github_nested_schema, data)) ================================================ FILE: legacy_tests/bigquery/query_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals from builtins import str import mock import unittest import google.auth import datalab.bigquery import datalab.context class TestCases(unittest.TestCase): @mock.patch('datalab.bigquery._api.Api.tabledata_list') @mock.patch('datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('datalab.bigquery._api.Api.jobs_query_results') @mock.patch('datalab.bigquery._api.Api.jobs_get') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_single_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query, mock_api_tabledata_list): mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_insert_done_result() mock_api_tabledata_list.return_value = TestCases._create_single_row_result() sql = 'SELECT field1 FROM [table] LIMIT 1' q = TestCases._create_query(sql) results = q.results() self.assertEqual(sql, results.sql) self.assertEqual('(%s)' % sql, q._repr_sql_()) self.assertEqual(sql, str(q)) self.assertEqual(1, results.length) first_result = results[0] self.assertEqual('value1', first_result['field1']) @mock.patch('datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('datalab.bigquery._api.Api.jobs_query_results') @mock.patch('datalab.bigquery._api.Api.jobs_get') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_empty_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query): mock_api_tables_get.return_value = TestCases._create_tables_get_result(0) mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_insert_done_result() q = TestCases._create_query() results = q.results() self.assertEqual(0, results.length) @mock.patch('datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('datalab.bigquery._api.Api.jobs_query_results') @mock.patch('datalab.bigquery._api.Api.jobs_get') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_incomplete_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query): mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_incomplete_result() q = 
TestCases._create_query() results = q.results() self.assertEqual(1, results.length) self.assertEqual('test_job', results.job_id) @mock.patch('datalab.bigquery._api.Api.jobs_insert_query') def test_malformed_response_raises_exception(self, mock_api_insert_query): mock_api_insert_query.return_value = {} q = TestCases._create_query() with self.assertRaises(Exception) as error: q.results() self.assertEqual('Unexpected response from server', str(error.exception)) def test_udf_expansion(self): sql = 'SELECT * FROM udf(source)' udf = datalab.bigquery.UDF('inputs', [('foo', 'string'), ('bar', 'integer')], 'udf', 'code') context = TestCases._create_context() query = datalab.bigquery.Query(sql, udf=udf, context=context) self.assertEquals('SELECT * FROM (SELECT foo, bar FROM udf(source))', query.sql) # Alternate form query = datalab.bigquery.Query(sql, udfs=[udf], context=context) self.assertEquals('SELECT * FROM (SELECT foo, bar FROM udf(source))', query.sql) @staticmethod def _create_query(sql=None): if sql is None: sql = 'SELECT * ...' return datalab.bigquery.Query(sql, context=TestCases._create_context()) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) @staticmethod def _create_insert_done_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': True, } @staticmethod def _create_single_row_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 1, 'rows': [ {'f': [{'v': 'value1'}]} ] } @staticmethod def _create_empty_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 0 } @staticmethod def _create_incomplete_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': False } @staticmethod def _create_page_result(page_token=None): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 2, 'rows': [ {'f': [{'v': 'value1'}]} ], 'pageToken': page_token } @staticmethod def _create_tables_get_result(num_rows=1, schema=None): if schema is None: schema = [{'name': 'field1', 'type': 'string'}] return { 'numRows': num_rows, 'schema': { 'fields': schema }, } ================================================ FILE: legacy_tests/bigquery/sampling_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
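# Usage sketch (illustrative only; not part of the original file): a Sampling
# policy is a callable that wraps a SQL string, as the tests below verify.
#
#   Sampling.default(count=20)('[]')
#   # -> 'SELECT * FROM ([]) LIMIT 20'
#   Sampling.hashed('f1', 5, count=100)('[]')
#   # -> 'SELECT * FROM ([]) WHERE ABS(HASH(f1)) % 100 < 5 LIMIT 100'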
from __future__ import absolute_import
from __future__ import unicode_literals

import unittest

from datalab.bigquery import Sampling


class TestCases(unittest.TestCase):

  BASE_SQL = '[]'

  def test_default(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(), expected_sql)

  def test_default_custom_count(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 20' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(count=20), expected_sql)

  def test_default_custom_fields(self):
    expected_sql = 'SELECT f1,f2 FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(fields=['f1', 'f2']), expected_sql)

  def test_default_all_fields(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(fields=[]), expected_sql)

  def test_hashed(self):
    expected_sql = 'SELECT * FROM (%s) WHERE ABS(HASH(f1)) %% 100 < 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5), expected_sql)

  def test_hashed_and_limited(self):
    expected_sql = 'SELECT * FROM (%s) WHERE ABS(HASH(f1)) %% 100 < 5 LIMIT 100' \
        % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5, count=100), expected_sql)

  def test_hashed_with_fields(self):
    expected_sql = 'SELECT f1 FROM (%s) WHERE ABS(HASH(f1)) %% 100 < 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5, fields=['f1']), expected_sql)

  def test_sorted_ascending(self):
    expected_sql = 'SELECT * FROM (%s) ORDER BY f1 LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1'), expected_sql)

  def test_sorted_descending(self):
    expected_sql = 'SELECT * FROM (%s) ORDER BY f1 DESC LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1', ascending=False), expected_sql)

  def test_sorted_with_fields(self):
    expected_sql = 'SELECT f1,f2 FROM (%s) ORDER BY f1 LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1', fields=['f1', 'f2']), expected_sql)

  def _apply_sampling(self, sampling, expected_query):
    sampled_query = sampling(TestCases.BASE_SQL)
    self.assertEqual(sampled_query, expected_query)



================================================
FILE: legacy_tests/bigquery/schema_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
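# Usage sketch (illustrative only; not part of the original file): the tests
# below infer schemas from a pandas DataFrame, a list of records, or a single
# record, and check the error paths for ambiguous inputs.
#
#   schema = datalab.bigquery.Schema.from_data(df_or_list_of_records)
#   schema = datalab.bigquery.Schema.from_record(single_list_or_dict)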
from __future__ import absolute_import from __future__ import unicode_literals import collections import pandas import sys import unittest import datalab.bigquery import datalab.utils class TestCases(unittest.TestCase): def test_schema_from_dataframe(self): df = TestCases._create_data_frame() result = datalab.bigquery.Schema.from_data(df) self.assertEqual(datalab.bigquery.Schema.from_data(TestCases._create_inferred_schema()), result) def test_schema_from_data(self): variant1 = [ 3, 2.0, True, ['cow', 'horse', [0, []]] ] variant2 = collections.OrderedDict() variant2['Column1'] = 3 variant2['Column2'] = 2.0 variant2['Column3'] = True variant2['Column4'] = collections.OrderedDict() variant2['Column4']['Column1'] = 'cow' variant2['Column4']['Column2'] = 'horse' variant2['Column4']['Column3'] = collections.OrderedDict() variant2['Column4']['Column3']['Column1'] = 0 variant2['Column4']['Column3']['Column2'] = collections.OrderedDict() master = [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'FLOAT'}, {'name': 'Column3', 'type': 'BOOLEAN'}, {'name': 'Column4', 'type': 'RECORD', 'fields': [ {'name': 'Column1', 'type': 'STRING'}, {'name': 'Column2', 'type': 'STRING'}, {'name': 'Column3', 'type': 'RECORD', 'fields': [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'RECORD', 'fields': []} ]} ]} ] schema_master = datalab.bigquery.Schema(master) with self.assertRaises(Exception) as error1: datalab.bigquery.Schema.from_data(variant1) if sys.version_info[0] == 3: self.assertEquals('Cannot create a schema from heterogeneous list [3, 2.0, True, ' + '[\'cow\', \'horse\', [0, []]]]; perhaps you meant to use ' + 'Schema.from_record?', str(error1.exception)) else: self.assertEquals('Cannot create a schema from heterogeneous list [3, 2.0, True, ' + '[u\'cow\', u\'horse\', [0, []]]]; perhaps you meant to use ' + 'Schema.from_record?', str(error1.exception)) with self.assertRaises(Exception) as error2: datalab.bigquery.Schema.from_data(variant2) if sys.version_info[0] == 3: self.assertEquals('Cannot create a schema from dict OrderedDict([(\'Column1\', 3), ' + '(\'Column2\', 2.0), (\'Column3\', True), (\'Column4\', ' + 'OrderedDict([(\'Column1\', \'cow\'), (\'Column2\', \'horse\'), ' + '(\'Column3\', OrderedDict([(\'Column1\', 0), (\'Column2\', ' + 'OrderedDict())]))]))]); perhaps you meant to use Schema.from_record?', str(error2.exception)) else: self.assertEquals('Cannot create a schema from dict OrderedDict([(u\'Column1\', 3), ' + '(u\'Column2\', 2.0), (u\'Column3\', True), (u\'Column4\', ' + 'OrderedDict([(u\'Column1\', u\'cow\'), (u\'Column2\', u\'horse\'), ' + '(u\'Column3\', OrderedDict([(u\'Column1\', 0), (u\'Column2\', ' + 'OrderedDict())]))]))]); perhaps you meant to use Schema.from_record?', str(error2.exception)) schema3 = datalab.bigquery.Schema.from_data([variant1]) schema4 = datalab.bigquery.Schema.from_data([variant2]) schema5 = datalab.bigquery.Schema.from_data(master) schema6 = datalab.bigquery.Schema.from_record(variant1) schema7 = datalab.bigquery.Schema.from_record(variant2) self.assertEquals(schema_master, schema3, 'schema inferred from list of lists with from_data') self.assertEquals(schema_master, schema4, 'schema inferred from list of dicts with from_data') self.assertEquals(schema_master, schema5, 'schema inferred from BQ schema list with from_data') self.assertEquals(schema_master, schema6, 'schema inferred from list with from_record') self.assertEquals(schema_master, schema7, 'schema inferred from dict with from_record') @staticmethod def 
_create_data_frame(): data = { 'some': [ 0, 1, 2, 3 ], 'column': [ 'r0', 'r1', 'r2', 'r3' ], 'headers': [ 10.0, 10.0, 10.0, 10.0 ] } return pandas.DataFrame(data) @staticmethod def _create_inferred_schema(extra_field=None): schema = [ {'name': 'some', 'type': 'INTEGER'}, {'name': 'column', 'type': 'STRING'}, {'name': 'headers', 'type': 'FLOAT'}, ] if extra_field: schema.append({'name': extra_field, 'type': 'INTEGER'}) return schema ================================================ FILE: legacy_tests/bigquery/table_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import calendar import datetime as dt import mock import pandas import unittest import google.auth import datalab.bigquery import datalab.context import datalab.utils class TestCases(unittest.TestCase): def _check_name_parts(self, table): parsed_name = table._name_parts self.assertEqual('test', parsed_name[0]) self.assertEqual('requestlogs', parsed_name[1]) self.assertEqual('today', parsed_name[2]) self.assertEqual('', parsed_name[3]) self.assertEqual('[test:requestlogs.today]', table._repr_sql_()) self.assertEqual('test:requestlogs.today', str(table)) def test_api_paths(self): name = datalab.bigquery._utils.TableName('a', 'b', 'c', 'd') self.assertEqual('/projects/a/datasets/b/tables/cd', datalab.bigquery._api.Api._TABLES_PATH % name) self.assertEqual('/projects/a/datasets/b/tables/cd/data', datalab.bigquery._api.Api._TABLEDATA_PATH % name) name = datalab.bigquery._utils.DatasetName('a', 'b') self.assertEqual('/projects/a/datasets/b', datalab.bigquery._api.Api._DATASETS_PATH % name) def test_parse_full_name(self): table = TestCases._create_table('test:requestlogs.today') self._check_name_parts(table) def test_parse_local_name(self): table = TestCases._create_table('requestlogs.today') self._check_name_parts(table) def test_parse_dict_full_name(self): table = TestCases._create_table({'project_id': 'test', 'dataset_id': 'requestlogs', 'table_id': 'today'}) self._check_name_parts(table) def test_parse_dict_local_name(self): table = TestCases._create_table({'dataset_id': 'requestlogs', 'table_id': 'today'}) self._check_name_parts(table) def test_parse_named_tuple_name(self): table = TestCases._create_table(datalab.bigquery._utils.TableName('test', 'requestlogs', 'today', '')) self._check_name_parts(table) def test_parse_tuple_full_name(self): table = TestCases._create_table(('test', 'requestlogs', 'today')) self._check_name_parts(table) def test_parse_tuple_local(self): table = TestCases._create_table(('requestlogs', 'today')) self._check_name_parts(table) def test_parse_array_full_name(self): table = TestCases._create_table(['test', 'requestlogs', 'today']) self._check_name_parts(table) def test_parse_array_local(self): table = TestCases._create_table(['requestlogs', 'today']) self._check_name_parts(table) def 
test_parse_invalid_name(self): with self.assertRaises(Exception): TestCases._create_table('today@') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_metadata(self, mock_api_tables_get): name = 'test:requestlogs.today' ts = dt.datetime.utcnow() mock_api_tables_get.return_value = TestCases._create_table_info_result(ts=ts) t = TestCases._create_table(name) metadata = t.metadata self.assertEqual('Logs', metadata.friendly_name) self.assertEqual(2, metadata.rows) self.assertEqual(2, metadata.rows) self.assertTrue(abs((metadata.created_on - ts).total_seconds()) <= 1) self.assertEqual(None, metadata.expires_on) @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_schema(self, mock_api_tables): mock_api_tables.return_value = TestCases._create_table_info_result() t = TestCases._create_table('test:requestlogs.today') schema = t.schema self.assertEqual(2, len(schema)) self.assertEqual('name', schema[0].name) @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_schema_nested(self, mock_api_tables): mock_api_tables.return_value = TestCases._create_table_info_nested_schema_result() t = TestCases._create_table('test:requestlogs.today') schema = t.schema self.assertEqual(4, len(schema)) self.assertEqual('name', schema[0].name) self.assertEqual('val', schema[1].name) self.assertEqual('more', schema[2].name) self.assertEqual('more.xyz', schema[3].name) self.assertIsNone(schema['value']) self.assertIsNotNone(schema['val']) @mock.patch('datalab.bigquery._api.Api.tables_get') def test_malformed_response_raises_exception(self, mock_api_tables_get): mock_api_tables_get.return_value = {} t = TestCases._create_table('test:requestlogs.today') with self.assertRaises(Exception) as error: t.schema self.assertEqual('Unexpected table response: missing schema', str(error.exception)) @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_dataset_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds: tables.append(table) self.assertEqual(2, len(tables)) self.assertEqual('test:testds.testTable1', str(tables[0])) self.assertEqual('test:testds.testTable2', str(tables[1])) @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_table_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds.tables(): tables.append(table) self.assertEqual(2, len(tables)) self.assertEqual('test:testds.testTable1', str(tables[0])) self.assertEqual('test:testds.testTable2', str(tables[1])) @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_view_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = datalab.bigquery.Dataset('testds', context=TestCases._create_context()) views = [] for view in ds.views(): views.append(view) self.assertEqual(1, len(views)) self.assertEqual('test:testds.testView1', str(views[0])) @mock.patch('datalab.bigquery._api.Api.tables_list') 
@mock.patch('datalab.bigquery._api.Api.datasets_get') def test_table_list_empty(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_empty_result() ds = datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds: tables.append(table) self.assertEqual(0, len(tables)) @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_exists(self, mock_api_tables_get): mock_api_tables_get.return_value = None tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) self.assertTrue(tbl.exists()) mock_api_tables_get.side_effect = datalab.utils.RequestException(404, 'failed') self.assertFalse(tbl.exists()) @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_tables_create(self, mock_api_datasets_get, mock_api_tables_list, mock_api_tables_insert): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = [] schema = TestCases._create_inferred_schema() mock_api_tables_insert.return_value = {} with self.assertRaises(Exception) as error: TestCases._create_table_with_schema(schema) self.assertEqual('Table test:testds.testTable0 could not be created as it already exists', str(error.exception)) mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} self.assertIsNotNone(TestCases._create_table_with_schema(schema), 'Expected a table') @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_insert_data_no_table(self, mock_api_datasets_get, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_time_sleep, mock_uuid): mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.side_effect = datalab.utils.RequestException(404, 'failed') mock_api_tabledata_insert_all.return_value = {} mock_api_datasets_get.return_value = None table = TestCases._create_table_with_schema(TestCases._create_inferred_schema()) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert_data(df) self.assertEqual('Table %s does not exist.' % str(table), str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.datasets_get') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_data_missing_field(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid,): # Truncate the schema used when creating the table so we have an unmatched column in insert. 
schema = TestCases._create_inferred_schema()[:2] mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = None mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_list.return_value = [] mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert_data(df) self.assertEqual('Table does not contain field headers', str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_insert_data_mismatched_schema(self, mock_api_datasets_get, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_time_sleep, mock_uuid): # Change the schema used when creating the table so we get a mismatch when inserting. schema = TestCases._create_inferred_schema() schema[2]['type'] = 'STRING' mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} mock_api_datasets_get.return_value = None table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert_data(df) self.assertEqual('Field headers in data has type FLOAT but in table has type STRING', str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.datasets_get') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_data_dataframe(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema() mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() result = table.insert_data(df) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.datasets_get') 
@mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_data_dictlist(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema() mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert_data([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ]) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.datasets_get') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_data_dictlist_index(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema('Index') mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert_data([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ], include_index=True) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0, 'Index': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1, 'Index': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2, 'Index': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3, 'Index': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('datalab.bigquery._api.Api.datasets_get') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') 
@mock.patch('datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_data_dictlist_named_index(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema('Row') mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert_data([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ], include_index=True, index_name='Row') self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0, 'Row': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1, 'Row': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2, 'Row': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3, 'Row': 3}} ]) @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.jobs_insert_load') @mock.patch('datalab.bigquery._api.Api.jobs_get') def test_table_load(self, mock_api_jobs_get, mock_api_jobs_insert_load, mock_api_tables_get): schema = TestCases._create_inferred_schema('Row') mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_insert_load.return_value = None mock_api_tables_get.return_value = {'schema': {'fields': schema}} tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) job = tbl.load('gs://foo') self.assertIsNone(job) mock_api_jobs_insert_load.return_value = {'jobReference': {'jobId': 'bar'}} job = tbl.load('gs://foo') self.assertEquals('bar', job.id) @mock.patch('datalab.bigquery._api.Api.table_extract') @mock.patch('datalab.bigquery._api.Api.jobs_get') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_extract(self, mock_api_tables_get, mock_api_jobs_get, mock_api_table_extract): mock_api_tables_get.return_value = {} mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_table_extract.return_value = None tbl = datalab.bigquery.Table('testds.testTable0', context=self._create_context()) job = tbl.extract('gs://foo') self.assertIsNone(job) mock_api_table_extract.return_value = {'jobReference': {'jobId': 'bar'}} job = tbl.extract('gs://foo') self.assertEquals('bar', job.id) @mock.patch('datalab.bigquery._api.Api.tabledata_list') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_table_to_dataframe(self, mock_api_tables_get, mock_api_tabledata_list): schema = self._create_inferred_schema() mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_list.return_value = { 'rows': [ {'f': [{'v': 1}, {'v': 'foo'}, {'v': 3.1415}]}, {'f': [{'v': 2}, {'v': 'bar'}, {'v': 0.5}]}, ] } tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) df = tbl.to_dataframe() self.assertEquals(2, len(df)) self.assertEquals(1, df['some'][0]) self.assertEquals(2, 
df['some'][1])
    self.assertEquals('foo', df['column'][0])
    self.assertEquals('bar', df['column'][1])
    self.assertEquals(3.1415, df['headers'][0])
    self.assertEquals(0.5, df['headers'][1])

  def test_encode_dict_as_row(self):
    when = dt.datetime(2001, 2, 3, 4, 5, 6, 7)
    row = datalab.bigquery.Table._encode_dict_as_row({'fo@o': 'b@r', 'b+ar': when}, {})
    self.assertEqual({'foo': 'b@r', 'bar': '2001-02-03T04:05:06.000007'}, row)

  def test_decorators(self):
    tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context())
    tbl2 = tbl.snapshot(dt.timedelta(hours=-1))
    self.assertEquals('test:testds.testTable0@-3600000', str(tbl2))

    with self.assertRaises(Exception) as error:
      tbl2 = tbl2.snapshot(dt.timedelta(hours=-2))
    self.assertEqual('Cannot use snapshot() on an already decorated table',
                     str(error.exception))

    with self.assertRaises(Exception) as error:
      tbl2.window(dt.timedelta(hours=-2), 0)
    self.assertEqual('Cannot use window() on an already decorated table',
                     str(error.exception))

    with self.assertRaises(Exception) as error:
      tbl.snapshot(dt.timedelta(days=-8))
    self.assertEqual(
        'Invalid snapshot relative when argument: must be within 7 days: -8 days, 0:00:00',
        str(error.exception))

    tbl2 = tbl.snapshot(dt.timedelta(days=-1))
    self.assertEquals('test:testds.testTable0@-86400000', str(tbl2))

    with self.assertRaises(Exception) as error:
      tbl.snapshot(dt.timedelta(days=1))
    self.assertEqual('Invalid snapshot relative when argument: 1 day, 0:00:00',
                     str(error.exception))

    with self.assertRaises(Exception) as error:
      tbl2 = tbl.snapshot(1000)
    self.assertEqual('Invalid snapshot when argument type: 1000', str(error.exception))
    self.assertEquals('test:testds.testTable0@-86400000', str(tbl2))

    when = dt.datetime.utcnow() + dt.timedelta(1)
    with self.assertRaises(Exception) as error:
      tbl.snapshot(when)
    self.assertEqual('Invalid snapshot absolute when argument: %s' % when, str(error.exception))

    when = dt.datetime.utcnow() - dt.timedelta(8)
    with self.assertRaises(Exception) as error:
      tbl.snapshot(when)
    self.assertEqual('Invalid snapshot absolute when argument: %s' % when, str(error.exception))

  def test_window_decorators(self):
    # The at test above already tests many of the conversion cases. The extra things we
    # have to test are that we can use two values, we get a meaningful default for the second
    # if we pass None, and that the first time comes before the second.
tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) tbl2 = tbl.window(dt.timedelta(hours=-1)) self.assertEquals('test:testds.testTable0@-3600000-0', str(tbl2)) with self.assertRaises(Exception) as error: tbl2 = tbl2.window(-400000, 0) self.assertEqual('Cannot use window() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl2.snapshot(-400000) self.assertEqual('Cannot use snapshot() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl.window(dt.timedelta(0), dt.timedelta(hours=-1)) self.assertEqual( 'window: Between arguments: begin must be before end: 0:00:00, -1 day, 23:00:00', str(error.exception)) @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.table_update') def test_table_update(self, mock_api_table_update, mock_api_tables_get): schema = self._create_inferred_schema() info = {'schema': {'fields': schema}, 'friendlyName': 'casper', 'description': 'ghostly logs', 'expirationTime': calendar.timegm(dt.datetime(2020, 1, 1).utctimetuple()) * 1000} mock_api_tables_get.return_value = info tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) new_name = 'aziraphale' new_description = 'demon duties' new_schema = [{'name': 'injected', 'type': 'FLOAT'}] new_schema.extend(schema) new_expiry = dt.datetime(2030, 1, 1) tbl.update(new_name, new_description, new_expiry, new_schema) name, info = mock_api_table_update.call_args[0] self.assertEqual(tbl.name, name) self.assertEqual(new_name, tbl.metadata.friendly_name) self.assertEqual(new_description, tbl.metadata.description) self.assertEqual(new_expiry, tbl.metadata.expires_on) self.assertEqual(len(new_schema), len(tbl.schema)) def test_table_to_query(self): tbl = datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) q = tbl.to_query() self.assertEqual('SELECT * FROM [test:testds.testTable0]', q.sql) q = tbl.to_query('foo, bar') self.assertEqual('SELECT foo, bar FROM [test:testds.testTable0]', q.sql) q = tbl.to_query(['bar', 'foo']) self.assertEqual('SELECT bar,foo FROM [test:testds.testTable0]', q.sql) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) @staticmethod def _create_table(name): return datalab.bigquery.Table(name, TestCases._create_context()) @staticmethod def _create_table_info_result(ts=None): if ts is None: ts = dt.datetime.utcnow() epoch = dt.datetime.utcfromtimestamp(0) timestamp = (ts - epoch).total_seconds() * 1000 return { 'description': 'Daily Logs Table', 'friendlyName': 'Logs', 'numBytes': 1000, 'numRows': 2, 'creationTime': timestamp, 'lastModifiedTime': timestamp, 'schema': { 'fields': [ {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, {'name': 'val', 'type': 'INTEGER', 'mode': 'NULLABLE'} ] } } @staticmethod def _create_table_info_nested_schema_result(ts=None): if ts is None: ts = dt.datetime.utcnow() epoch = dt.datetime.utcfromtimestamp(0) timestamp = (ts - epoch).total_seconds() * 1000 return { 'description': 'Daily Logs Table', 'friendlyName': 'Logs', 'numBytes': 1000, 'numRows': 2, 'creationTime': timestamp, 'lastModifiedTime': timestamp, 'schema': { 'fields': [ {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, {'name': 'val', 'type': 'INTEGER', 'mode': 'NULLABLE'}, {'name': 'more', 'type': 'RECORD', 'mode': 'REPEATED', 'fields': [ {'name': 'xyz', 'type': 
'INTEGER', 'mode': 'NULLABLE'} ] } ] } } @staticmethod def _create_dataset(dataset_id): return datalab.bigquery.Dataset(dataset_id, context=TestCases._create_context()) @staticmethod def _create_table_list_result(): return { 'tables': [ { 'type': 'TABLE', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testTable1'} }, { 'type': 'VIEW', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testView1'} }, { 'type': 'TABLE', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testTable2'} } ] } @staticmethod def _create_table_list_empty_result(): return { 'tables': [] } @staticmethod def _create_data_frame(): data = { 'some': [ 0, 1, 2, 3 ], 'column': [ 'r0', 'r1', 'r2', 'r3' ], 'headers': [ 10.0, 10.0, 10.0, 10.0 ] } return pandas.DataFrame(data) @staticmethod def _create_inferred_schema(extra_field=None): schema = [ {'name': 'some', 'type': 'INTEGER'}, {'name': 'column', 'type': 'STRING'}, {'name': 'headers', 'type': 'FLOAT'}, ] if extra_field: schema.append({'name': extra_field, 'type': 'INTEGER'}) return schema @staticmethod def _create_table_with_schema(schema, name='test:testds.testTable0'): return datalab.bigquery.Table(name, TestCases._create_context()).create(schema) class _uuid(object): @property def hex(self): return '#' @staticmethod def _create_uuid(): return TestCases._uuid() ================================================ FILE: legacy_tests/bigquery/udf_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import
from __future__ import unicode_literals

import mock
import unittest

import google.auth

import datalab.bigquery
import datalab.context


class TestCases(unittest.TestCase):

  def test_sql_building(self):
    context = self._create_context()
    table = datalab.bigquery.Table('test:requestlogs.today', context=context)
    udf = self._create_udf()
    query = datalab.bigquery.Query('SELECT * FROM foo($t)', t=table, udfs=[udf], context=context)
    expected_js = '\nfoo=function(r,emit) { emit({output1: r.field2, output2: r.field1 }); };\n' +\
                  'bigquery.defineFunction(\'foo\', ["field1", "field2"], ' +\
                  '[{"name": "output1", "type": "integer"}, ' +\
                  '{"name": "output2", "type": "string"}], foo);'
    self.assertEqual(query.sql, 'SELECT * FROM '
                                '(SELECT output1, output2 FROM foo([test:requestlogs.today]))')
    self.assertEqual(udf._code, expected_js)

  @staticmethod
  def _create_udf():
    inputs = [('field1', 'string'), ('field2', 'integer')]
    outputs = [('output1', 'integer'), ('output2', 'string')]
    impl = 'function(r,emit) { emit({output1: r.field2, output2: r.field1 }); }'
    udf = datalab.bigquery.UDF(inputs, outputs, 'foo', impl)
    return udf

  @staticmethod
  def _create_context():
    project_id = 'test'
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return datalab.context.Context(project_id, creds)


================================================
FILE: legacy_tests/bigquery/view_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
from __future__ import absolute_import from __future__ import unicode_literals from builtins import str import mock import unittest import google.auth import datalab.bigquery import datalab.context class TestCases(unittest.TestCase): def test_view_repr_sql(self): name = 'test:testds.testView0' view = datalab.bigquery.View(name, TestCases._create_context()) self.assertEqual('[%s]' % name, view._repr_sql_()) @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.tables_list') @mock.patch('datalab.bigquery._api.Api.datasets_get') def test_view_create(self, mock_api_datasets_get, mock_api_tables_list, mock_api_tables_get, mock_api_tables_insert): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_get.return_value = None mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() name = 'test:testds.testView0' sql = 'select * from test:testds.testTable0' view = datalab.bigquery.View(name, TestCases._create_context()) result = view.create(sql) self.assertTrue(view.exists()) self.assertEqual(name, str(view)) self.assertEqual('[%s]' % name, view._repr_sql_()) self.assertIsNotNone(result, 'Expected a view') @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tabledata_list') @mock.patch('datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('datalab.bigquery._api.Api.jobs_query_results') @mock.patch('datalab.bigquery._api.Api.jobs_get') @mock.patch('datalab.bigquery._api.Api.tables_get') def test_view_result(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query, mock_api_tabledata_list, mock_api_tables_insert): mock_api_insert_query.return_value = TestCases._create_insert_done_result() mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_tabledata_list.return_value = TestCases._create_single_row_result() name = 'test:testds.testView0' sql = 'select * from test:testds.testTable0' view = datalab.bigquery.View(name, TestCases._create_context()) view.create(sql) results = view.results() self.assertEqual(1, results.length) first_result = results[0] self.assertEqual('value1', first_result['field1']) @mock.patch('datalab.bigquery._api.Api.tables_insert') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.bigquery._api.Api.table_update') @mock.patch('datalab.context.Context.default') def test_view_update(self, mock_context_default, mock_api_table_update, mock_api_tables_get, mock_api_tables_insert): mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() mock_context_default.return_value = TestCases._create_context() mock_api_table_update.return_value = None friendly_name = 'casper' description = 'ghostly logs' sql = 'select * from [test:testds.testTable0]' info = {'friendlyName': friendly_name, 'description': description, 'view': {'query': sql}} mock_api_tables_get.return_value = info name = 'test:testds.testView0' view = datalab.bigquery.View(name, TestCases._create_context()) view.create(sql) self.assertEqual(friendly_name, view.friendly_name) self.assertEqual(description, view.description) self.assertEqual(sql, view.query.sql) new_friendly_name = 'aziraphale' 
new_description = 'demon duties' new_query = 'SELECT 3 AS x' view.update(new_friendly_name, new_description, new_query) self.assertEqual(new_friendly_name, view.friendly_name) self.assertEqual(new_description, view.description) self.assertEqual(new_query, view.query.sql) @staticmethod def _create_tables_insert_success_result(): return {'selfLink': 'http://foo'} @staticmethod def _create_insert_done_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': True, } @staticmethod def _create_tables_get_result(num_rows=1, schema=None): if not schema: schema = [{'name': 'field1', 'type': 'string'}] return { 'numRows': num_rows, 'schema': { 'fields': schema }, } @staticmethod def _create_single_row_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 1, 'rows': [ {'f': [{'v': 'value1'}]} ] } @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) ================================================ FILE: legacy_tests/data/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/data/sql_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import from __future__ import unicode_literals import imp import unittest import datalab.data class TestCases(unittest.TestCase): def test_sql_tokenizer(self): query = "SELECT * FROM function(SELECT * FROM [table]) -- a comment\n" \ "WHERE x>0 AND y == 'cat'/*\nmulti-line comment */LIMIT 10" tokens = datalab.data.tokenize(query) # Make sure we get all the content self.assertEquals(''.join(tokens), query) # Then check the tokens expected = [ 'SELECT', ' ', '*', ' ', 'FROM', ' ', 'function', '(', 'SELECT', ' ', '*', ' ', 'FROM', ' ', '[', 'table', ']', ')', ' ', '-- a comment\n', 'WHERE', ' ', 'x', '>', '0', ' ', 'AND', ' ', 'y', ' ', '=', '=', ' ', '\'cat\'', '/*\nmulti-line comment */', 'LIMIT', ' ', '10' ] self.assertEquals(expected, tokens) def test_zero_placeholders(self): queries = ['SELECT * FROM [logs.today]', ' SELECT time FROM [logs.today] '] for query in queries: formatted_query = datalab.data.SqlStatement.format(query, None) self.assertEqual(query, formatted_query) def test_single_placeholder(self): query = 'SELECT time FROM [logs.today] WHERE status == $param' args = {'param': 200} formatted_query = datalab.data.SqlStatement.format(query, args) self.assertEqual(formatted_query, 'SELECT time FROM [logs.today] WHERE status == 200') def test_multiple_placeholders(self): query = ('SELECT time FROM [logs.today] ' 'WHERE status == $status AND path == $path') args = {'status': 200, 'path': '/home'} formatted_query = datalab.data.SqlStatement.format(query, args) self.assertEqual(formatted_query, ('SELECT time FROM [logs.today] ' 'WHERE status == 200 AND path == "/home"')) def test_escaped_placeholder(self): query = 'SELECT time FROM [logs.today] WHERE path == "/foo$$bar"' args = {'status': 200} formatted_query = datalab.data.SqlStatement.format(query, args) self.assertEqual(formatted_query, 'SELECT time FROM [logs.today] WHERE path == "/foo$bar"') def test_string_escaping(self): query = 'SELECT time FROM [logs.today] WHERE path == $path' args = {'path': 'xyz"xyz'} formatted_query = datalab.data.SqlStatement.format(query, args) self.assertEqual(formatted_query, 'SELECT time FROM [logs.today] WHERE path == "xyz\\"xyz"') def test_all_combinations(self): query = ('SELECT time FROM ' ' (SELECT * FROM [logs.today] ' ' WHERE path contains "$$" AND path contains $segment ' ' AND status == $status) ' 'WHERE success == $success AND server == "$$master" ' 'LIMIT $pageSize') args = {'status': 200, 'pageSize': 10, 'success': False, 'segment': 'home'} expected_query = ('SELECT time FROM ' ' (SELECT * FROM [logs.today] ' ' WHERE path contains "$" AND path contains "home" ' ' AND status == 200) ' 'WHERE success == False AND server == "$master" ' 'LIMIT 10') formatted_query = datalab.data.SqlStatement.format(query, args) self.assertEqual(formatted_query, expected_query) def test_missing_args(self): query = 'SELECT time FROM [logs.today] WHERE status == $status' args = {'s': 200} with self.assertRaises(Exception) as error: datalab.data.SqlStatement.format(query, args) e = error.exception self.assertEqual('Unsatisfied dependency $status', str(e)) def test_invalid_args(self): query = 'SELECT time FROM [logs.today] WHERE status == $0' with self.assertRaises(Exception) as error: datalab.data.SqlStatement.format(query, {}) e = error.exception self.assertEqual( 'Invalid sql; $ with no following $ or identifier: ' + query + '.', str(e)) def test_nested_queries(self): query1 = datalab.data.SqlStatement('SELECT 3 as x') query2 = datalab.data.SqlStatement('SELECT x FROM $query1') 
query3 = 'SELECT * FROM $query2 WHERE x == $count' self.assertEquals('SELECT 3 as x', query1.sql) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format(query3)[0] self.assertEquals('Unsatisfied dependency $query2', str(e.exception)) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format(query3, {'query1': query1}) self.assertEquals('Unsatisfied dependency $query2', str(e.exception)) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format(query3, {'query2': query2}) self.assertEquals('Unsatisfied dependency $query1', str(e.exception)) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format(query3, {'query1': query1, 'query2': query2}) self.assertEquals('Unsatisfied dependency $count', str(e.exception)) formatted_query =\ datalab.data.SqlStatement.format(query3, {'query1': query1, 'query2': query2, 'count': 5}) self.assertEqual('SELECT * FROM (SELECT x FROM (SELECT 3 as x)) WHERE x == 5', formatted_query) def test_shared_nested_queries(self): query1 = datalab.data.SqlStatement('SELECT 3 as x') query2 = datalab.data.SqlStatement('SELECT x FROM $query1') query3 = 'SELECT x AS y FROM $query1, x FROM $query2' formatted_query = datalab.data.SqlStatement.format(query3, {'query1': query1, 'query2': query2}) self.assertEqual('SELECT x AS y FROM (SELECT 3 as x), x FROM (SELECT x FROM (SELECT 3 as x))', formatted_query) def test_circular_references(self): query1 = datalab.data.SqlStatement('SELECT * FROM $query3') query2 = datalab.data.SqlStatement('SELECT x FROM $query1') query3 = datalab.data.SqlStatement('SELECT * FROM $query2 WHERE x == $count') args = {'query1': query1, 'query2': query2, 'query3': query3} with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format('SELECT * FROM $query1', args) self.assertEquals('Circular dependency in $query1', str(e.exception)) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format('SELECT * FROM $query2', args) self.assertEquals('Circular dependency in $query2', str(e.exception)) with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format('SELECT * FROM $query3', args) self.assertEquals('Circular dependency in $query3', str(e.exception)) def test_module_reference(self): m = imp.new_module('m') m.__dict__['q1'] = datalab.data.SqlStatement('SELECT 3 AS x') m.__dict__[datalab.data._utils._SQL_MODULE_LAST] =\ datalab.data.SqlStatement('SELECT * FROM $q1 LIMIT 10') with self.assertRaises(Exception) as e: datalab.data.SqlStatement.format('SELECT * FROM $s', {'s': m}) self.assertEquals('Unsatisfied dependency $q1', str(e.exception)) formatted_query = datalab.data.SqlStatement.format('SELECT * FROM $s', {'s': m, 'q1': m.q1}) self.assertEqual('SELECT * FROM (SELECT * FROM (SELECT 3 AS x) LIMIT 10)', formatted_query) formatted_query = datalab.data.SqlStatement.format('SELECT * FROM $s', {'s': m.q1}) self.assertEqual('SELECT * FROM (SELECT 3 AS x)', formatted_query) def test_get_sql_statement_with_environment(self): # TODO(gram). pass def test_get_query_from_module(self): # TODO(gram). pass def test_get_sql_args(self): # TODO(gram). pass ================================================ FILE: legacy_tests/kernel/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/kernel/bigquery_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest import google.auth # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() import datalab.bigquery # noqa: E402 import datalab.bigquery.commands # noqa: E402 import datalab.context # noqa: E402 import datalab.utils.commands # noqa: E402 class TestCases(unittest.TestCase): @mock.patch('datalab.utils.commands.notebook_environment') @mock.patch('datalab.context._context.Context.default') def test_udf_cell(self, mock_default_context, mock_notebook_environment): env = {} cell_body = """ /** * @param {{word: string, corpus: string, word_count: integer}} r * @param function({{word: string, corpus: string, count: integer}}) emitFn */ function(r, emitFn) { if (r.word.match(/[shakespeare]/) !== null) { var result = { word: r.word, corpus: r.corpus, count: r.word_count }; emitFn(result); } } """ mock_default_context.return_value = TestCases._create_context() mock_notebook_environment.return_value = env datalab.bigquery.commands._bigquery._udf_cell({'module': 'word_filter'}, cell_body) udf = env['word_filter'] self.assertIsNotNone(udf) self.assertEquals('word_filter', udf._name) self.assertEquals([('word', 'string'), ('corpus', 'string'), ('count', 'integer')], udf._outputs) self.assertEquals(cell_body, udf._implementation) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) def test_sample_cell(self): # TODO(gram): complete this test pass def test_get_schema(self): # TODO(gram): complete this test pass def test_get_table(self): # TODO(gram): complete this test pass def test_table_viewer(self): # TODO(gram): complete this test pass ================================================ FILE: legacy_tests/kernel/chart_data_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import json import mock import unittest # import Python so we can mock the parts we need to here. import IPython.core.display import IPython.core.magic import datalab.utils.commands def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.core.display.HTML = lambda x: x IPython.core.display.JSON = lambda x: x class TestCases(unittest.TestCase): @mock.patch('datalab.utils.get_item') def test_get_chart_data(self, mock_get_item): IPython.get_ipython().user_ns = {} t = [ {'country': 'US', 'quantity': 100}, {'country': 'ZA', 'quantity': 50}, {'country': 'UK', 'quantity': 75}, {'country': 'AU', 'quantity': 25} ] mock_get_item.return_value = t ds = datalab.utils.commands.get_data_source_index('t') data = datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 1, 'count': 1 })) self.assertEquals({"data": {"rows": [{"c": [{"v": "ZA"}]}], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) data = datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 6, 'count': 1 })) self.assertEquals({"data": {"rows": [], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) data = datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 2, 'count': 0 })) self.assertEquals({"data": {"rows": [], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) ================================================ FILE: legacy_tests/kernel/chart_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest # import Python so we can mock the parts we need to here. 
import IPython.core.display import IPython.core.magic import datalab.utils.commands def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.core.display.HTML = lambda x: x IPython.core.display.JSON = lambda x: x class TestCases(unittest.TestCase): def test_chart_cell(self): t = [{'country': 'US', 'quantity': 100}, {'country': 'ZA', 'quantity': 50}] IPython.get_ipython().user_ns = {} chart = datalab.utils.commands._chart._chart_cell({'chart': 'geo', 'data': t, 'fields': None}, '') self.assertTrue(chart.find('charts.render(') > 0) self.assertTrue(chart.find('\'geo\'') > 0) self.assertTrue(chart.find('"fields": "*"') > 0) self.assertTrue(chart.find('{"c": [{"v": "US"}, {"v": 100}]}') > 0 or chart.find('{"c": [{"v": 100}, {"v": "US"}]}') > 0) self.assertTrue(chart.find('{"c": [{"v": "ZA"}, {"v": 50}]}') > 0 or chart.find('{"c": [{"v": 50}, {"v": "ZA"}]}') > 0) def test_chart_magic(self): # TODO(gram): complete this test pass ================================================ FILE: legacy_tests/kernel/commands_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): def test_create_args(self): # TODO(gram): complete this test pass ================================================ FILE: legacy_tests/kernel/html_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest # import Python so we can mock the parts we need to here. 
import IPython import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): def test_render_table(self): # TODO(gram): complete this test pass def test_render_text(self): # TODO(gram): complete this test pass ================================================ FILE: legacy_tests/kernel/module_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import sys import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic import datalab.utils.commands def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): def test_create_python_module(self): datalab.utils.commands._modules._create_python_module('bar', 'y=1') self.assertIsNotNone(sys.modules['bar']) self.assertEqual(1, sys.modules['bar'].y) def test_pymodule(self): datalab.utils.commands._modules.pymodule('--name foo', 'x=1') self.assertIsNotNone(sys.modules['foo']) self.assertEqual(1, sys.modules['foo'].x) @mock.patch('datalab.utils.commands._modules._pymodule_cell', autospec=True) def test_pymodule_magic(self, mock_pymodule_cell): datalab.utils.commands._modules.pymodule('-n foo') mock_pymodule_cell.assert_called_with({ 'name': 'foo', 'func': datalab.utils.commands._modules._pymodule_cell }, None) ================================================ FILE: legacy_tests/kernel/sql_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import imp import mock import unittest # import Python so we can mock the parts we need to here. 
import IPython import IPython.core.magic import google.auth import datalab.bigquery import datalab.context import datalab.data import datalab.data.commands import datalab.utils.commands def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): _SQL_MODULE_MAIN = datalab.data._utils._SQL_MODULE_MAIN _SQL_MODULE_LAST = datalab.data._utils._SQL_MODULE_LAST def test_split_cell(self): # TODO(gram): add tests for argument parser. m = imp.new_module('m') query = datalab.data.commands._sql._split_cell('', m) self.assertIsNone(query) self.assertNotIn(TestCases._SQL_MODULE_LAST, m.__dict__) self.assertNotIn(TestCases._SQL_MODULE_MAIN, m.__dict__) m = imp.new_module('m') query = datalab.data.commands._sql._split_cell('\n\n', m) self.assertIsNone(query) self.assertNotIn(TestCases._SQL_MODULE_LAST, m.__dict__) self.assertNotIn(TestCases._SQL_MODULE_MAIN, m.__dict__) m = imp.new_module('m') query = datalab.data.commands._sql._split_cell('# This is a comment\n\nSELECT 3 AS x', m) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_MAIN]) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals('SELECT 3 AS x', m.__dict__[TestCases._SQL_MODULE_MAIN].sql) self.assertEquals('SELECT 3 AS x', m.__dict__[TestCases._SQL_MODULE_LAST].sql) m = imp.new_module('m') query = datalab.data.commands._sql._split_cell( '# This is a comment\n\nfoo="bar"\nSELECT 3 AS x', m) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_MAIN]) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals('SELECT 3 AS x', m.__dict__[TestCases._SQL_MODULE_MAIN].sql) self.assertEquals('SELECT 3 AS x', m.__dict__[TestCases._SQL_MODULE_LAST].sql) sql_string_list = ['SELECT 3 AS x', 'WITH q1 as (SELECT "1")\nSELECT * FROM q1', 'INSERT DataSet.Table (Id, Description)\nVALUES(100,"TestDesc")', 'INSERT DataSet.Table (Id, Description)\n' 'SELECT * FROM UNNEST([(200,"TestDesc2"),(300,"TestDesc3")])' 'INSERT DataSet.Table (Id, Description)\n' + 'WITH w as (SELECT ARRAY>\n' + '[(400, "TestDesc4"),(500, "TestDesc5")] col)\n' + 'SELECT Id, Description FROM w, UNNEST(w.col)' 'INSERT DataSet.Table (Id, Description)\n' + 'VALUES (600,\n' + '(SELECT Description FROM DataSet.Table WHERE Id = 400))', 'DELETE FROM DataSet.Table WHERE DESCRIPTION IS NULL' 'DELETE FROM DataSet.Table\n' + 'WHERE Id NOT IN (100, 200, 300)' ] for i in range(0, len(sql_string_list)): m = imp.new_module('m') query = datalab.data.commands._sql._split_cell(sql_string_list[i], m) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_MAIN]) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals(sql_string_list[i], m.__dict__[TestCases._SQL_MODULE_MAIN].sql) self.assertEquals(sql_string_list[i], m.__dict__[TestCases._SQL_MODULE_LAST].sql) m = imp.new_module('m') query = datalab.data.commands._sql._split_cell('DEFINE QUERY q1\nSELECT 3 AS x', m) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals('SELECT 3 AS x', m.q1.sql) self.assertNotIn(TestCases._SQL_MODULE_MAIN, m.__dict__) self.assertEquals('SELECT 3 AS x', m.__dict__[TestCases._SQL_MODULE_LAST].sql) m = imp.new_module('m') query = datalab.data.commands._sql._split_cell( 'DEFINE QUERY q1\nSELECT 3 AS x\nSELECT * FROM $q1', m) 
self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_MAIN]) self.assertEquals(query, m.__dict__[TestCases._SQL_MODULE_LAST]) self.assertEquals('SELECT 3 AS x', m.q1.sql) self.assertEquals('SELECT * FROM $q1', m.__dict__[TestCases._SQL_MODULE_MAIN].sql) self.assertEquals('SELECT * FROM $q1', m.__dict__[TestCases._SQL_MODULE_LAST].sql) @mock.patch('datalab.context._context.Context.default') def test_arguments(self, mock_default_context): mock_default_context.return_value = TestCases._create_context() m = imp.new_module('m') query = datalab.data.commands._sql._split_cell(""" words = ('thus', 'forsooth') limit = 10 SELECT * FROM [publicdata:samples.shakespeare] WHERE word IN $words LIMIT $limit """, m) sql = datalab.bigquery.Query(query, values={}).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("thus", "forsooth")\nLIMIT 10', sql) # As above but with overrides, using list sql = datalab.bigquery.Query(query, words=['eyeball'], limit=5).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("eyeball")\nLIMIT 5', sql) # As above but with overrides, using tuple and values dict sql = datalab.bigquery.Query(query, values={'limit': 3, 'words': ('thus',)}).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("thus")\nLIMIT 3', sql) # As above but with list argument m = imp.new_module('m') query = datalab.data.commands._sql._split_cell(""" words = ['thus', 'forsooth'] limit = 10 SELECT * FROM [publicdata:samples.shakespeare] WHERE word IN $words LIMIT $limit """, m) sql = datalab.bigquery.Query(query, values={}).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("thus", "forsooth")\nLIMIT 10', sql) # As above but with overrides, using list sql = datalab.bigquery.Query(query, values={'limit': 2, 'words': ['forsooth']}).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("forsooth")\nLIMIT 2', sql) # As above but with overrides, using tuple sql = datalab.bigquery.Query(query, words=('eyeball',)).sql self.assertEquals('SELECT * FROM [publicdata:samples.shakespeare]\n' + 'WHERE word IN ("eyeball")\nLIMIT 10', sql) # TODO(gram): add some tests for source and datestring variables def test_date(self): # TODO(gram): complete this test pass def test_sql_cell(self): # TODO(gram): complete this test pass @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) ================================================ FILE: legacy_tests/kernel/storage_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest # import Python so we can mock the parts we need to here. 
import IPython import IPython.core.magic import google.auth import datalab.context import datalab.storage import datalab.storage.commands def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): @mock.patch('datalab.storage._item.Item.exists', autospec=True) @mock.patch('datalab.storage._bucket.Bucket.items', autospec=True) @mock.patch('datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('datalab.context._context.Context.default') def test_expand_list(self, mock_context_default, mock_api_objects_get, mock_bucket_items, mock_item_exists): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for testing for item existence. Fail if called with name that includes wild char. def item_exists_side_effect(*args, **kwargs): return args[0].key.find('*') < 0 mock_item_exists.side_effect = item_exists_side_effect # Mock API for getting items in a bucket. mock_bucket_items.side_effect = TestCases._mock_bucket_items_return(context) # Mock API for getting item metadata. mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() items = datalab.storage.commands._storage._expand_list(None) self.assertEqual([], items) items = datalab.storage.commands._storage._expand_list([]) self.assertEqual([], items) items = datalab.storage.commands._storage._expand_list('gs://bar/o*') self.assertEqual(['gs://bar/object1', 'gs://bar/object3'], items) items = datalab.storage.commands._storage._expand_list(['gs://foo', 'gs://bar']) self.assertEqual(['gs://foo', 'gs://bar'], items) items = datalab.storage.commands._storage._expand_list(['gs://foo/*', 'gs://bar']) self.assertEqual(['gs://foo/item1', 'gs://foo/item2', 'gs://foo/item3', 'gs://bar'], items) items = datalab.storage.commands._storage._expand_list(['gs://bar/o*']) self.assertEqual(['gs://bar/object1', 'gs://bar/object3'], items) items = datalab.storage.commands._storage._expand_list(['gs://bar/i*']) # Note - if no match we return the pattern. self.assertEqual(['gs://bar/i*'], items) items = datalab.storage.commands._storage._expand_list(['gs://baz']) self.assertEqual(['gs://baz'], items) items = datalab.storage.commands._storage._expand_list(['gs://baz/*']) self.assertEqual(['gs://baz/*'], items) items = datalab.storage.commands._storage._expand_list(['gs://foo/i*3']) self.assertEqual(['gs://foo/item3'], items) @mock.patch('datalab.storage._item.Item.copy_to', autospec=True) @mock.patch('datalab.storage._bucket.Bucket.items', autospec=True) @mock.patch('datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('datalab.context._context.Context.default') def test_storage_copy(self, mock_context_default, mock_api_objects_get, mock_bucket_items, mock_storage_item_copy_to): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for getting items in a bucket. mock_bucket_items.side_effect = TestCases._mock_bucket_items_return(context) # Mock API for getting item metadata. 
mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() datalab.storage.commands._storage._storage_copy({ 'source': ['gs://foo/item1'], 'destination': 'gs://foo/bar1' }, None) mock_storage_item_copy_to.assert_called_with(mock.ANY, 'bar1', bucket='foo') self.assertEquals('item1', mock_storage_item_copy_to.call_args[0][0].key) self.assertEquals('foo', mock_storage_item_copy_to.call_args[0][0]._bucket) with self.assertRaises(Exception) as error: datalab.storage.commands._storage._storage_copy({ 'source': ['gs://foo/item*'], 'destination': 'gs://foo/bar1' }, None) self.assertEqual('More than one source but target gs://foo/bar1 is not a bucket', str(error.exception)) @mock.patch('datalab.storage.commands._storage._storage_copy', autospec=True) def test_storage_copy_magic(self, mock_storage_copy): datalab.storage.commands._storage.storage('copy --source gs://foo/item1 ' '--destination gs://foo/bar1') mock_storage_copy.assert_called_with({ 'source': ['gs://foo/item1'], 'destination': 'gs://foo/bar1', 'func': datalab.storage.commands._storage._storage_copy }, None) @mock.patch('datalab.storage._api.Api.buckets_insert', autospec=True) @mock.patch('datalab.context._context.Context.default') def test_storage_create(self, mock_context_default, mock_api_buckets_insert): context = TestCases._create_context() mock_context_default.return_value = context errs = datalab.storage.commands._storage._storage_create({ 'project': 'test', 'bucket': [ 'gs://baz' ] }, None) self.assertEqual(None, errs) mock_api_buckets_insert.assert_called_with(mock.ANY, 'baz', project_id='test') with self.assertRaises(Exception) as error: datalab.storage.commands._storage._storage_create({ 'project': 'test', 'bucket': [ 'gs://foo/bar' ] }, None) self.assertEqual("Couldn't create gs://foo/bar: Invalid bucket name gs://foo/bar", str(error.exception)) @mock.patch('datalab.storage._api.Api.buckets_get', autospec=True) @mock.patch('datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('datalab.storage._bucket.Bucket.items', autospec=True) @mock.patch('datalab.storage._api.Api.objects_delete', autospec=True) @mock.patch('datalab.storage._api.Api.buckets_delete', autospec=True) @mock.patch('datalab.context._context.Context.default') def test_storage_delete(self, mock_context_default, mock_api_bucket_delete, mock_api_objects_delete, mock_bucket_items, mock_api_objects_get, mock_api_buckets_get): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for getting items in a bucket. mock_bucket_items.side_effect = TestCases._mock_bucket_items_return(context) # Mock API for getting item metadata. 
mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() mock_api_buckets_get.side_effect = TestCases._mock_api_buckets_get() with self.assertRaises(Exception) as error: datalab.storage.commands._storage._storage_delete({ 'bucket': [ 'gs://bar', 'gs://baz' ], 'object': [ 'gs://foo/item1', 'gs://baz/item1', ] }, None) self.assertEqual('gs://baz does not exist\ngs://baz/item1 does not exist', str(error.exception)) mock_api_bucket_delete.assert_called_with(mock.ANY, 'bar') mock_api_objects_delete.assert_called_with(mock.ANY, 'foo', 'item1') @mock.patch('datalab.context._context.Context.default') def test_storage_view(self, mock_context_default): context = TestCases._create_context() mock_context_default.return_value = context # TODO(gram): complete this test @mock.patch('datalab.context._context.Context.default') def test_storage_write(self, mock_context_default): context = TestCases._create_context() mock_context_default.return_value = context # TODO(gram): complete this test @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) @staticmethod def _mock_bucket_items_return(context): # Mock API for getting items in a bucket. def bucket_items_side_effect(*args, **kwargs): bucket = args[0].name # self if bucket == 'foo': return [ datalab.storage._item.Item(bucket, 'item1', context=context), datalab.storage._item.Item(bucket, 'item2', context=context), datalab.storage._item.Item(bucket, 'item3', context=context), ] elif bucket == 'bar': return [ datalab.storage._item.Item(bucket, 'object1', context=context), datalab.storage._item.Item(bucket, 'object3', context=context), ] else: return [] return bucket_items_side_effect @staticmethod def _mock_api_objects_get(): # Mock API for getting item metadata. def api_objects_get_side_effect(*args, **kwargs): if args[1].find('baz') >= 0: return None key = args[2] if key.find('*') >= 0: return None return {'name': key} return api_objects_get_side_effect @staticmethod def _mock_api_buckets_get(): # Mock API for getting bucket metadata. def api_buckets_get_side_effect(*args, **kwargs): key = args[1] if key.find('*') >= 0 or key.find('baz') >= 0: return None return {'name': key} return api_buckets_get_side_effect ================================================ FILE: legacy_tests/kernel/utils_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals from builtins import range import datetime as dt import collections import mock import pandas import unittest import google.auth # import Python so we can mock the parts we need to here. 
import IPython import IPython.core.magic IPython.core.magic.register_line_cell_magic = mock.Mock() IPython.core.magic.register_line_magic = mock.Mock() IPython.core.magic.register_cell_magic = mock.Mock() IPython.get_ipython = mock.Mock() import datalab.bigquery # noqa: E402 import datalab.context # noqa: E402 import datalab.utils.commands # noqa: E402 class TestCases(unittest.TestCase): @staticmethod def _get_expected_cols(): cols = [ {'type': 'number', 'id': 'Column1', 'label': 'Column1'}, {'type': 'number', 'id': 'Column2', 'label': 'Column2'}, {'type': 'string', 'id': 'Column3', 'label': 'Column3'}, {'type': 'boolean', 'id': 'Column4', 'label': 'Column4'}, {'type': 'number', 'id': 'Column5', 'label': 'Column5'}, {'type': 'datetime', 'id': 'Column6', 'label': 'Column6'} ] return cols @staticmethod def _timestamp(d): return (d - dt.datetime(1970, 1, 1)).total_seconds() @staticmethod def _get_raw_rows(): rows = [ {'f': [ {'v': 1}, {'v': 2}, {'v': '3'}, {'v': 'true'}, {'v': 0.0}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 1))} ]}, {'f': [ {'v': 11}, {'v': 12}, {'v': '13'}, {'v': 'false'}, {'v': 0.2}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 2))} ]}, {'f': [ {'v': 21}, {'v': 22}, {'v': '23'}, {'v': 'true'}, {'v': 0.3}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 3))} ]}, {'f': [ {'v': 31}, {'v': 32}, {'v': '33'}, {'v': 'false'}, {'v': 0.4}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 4))} ]}, {'f': [ {'v': 41}, {'v': 42}, {'v': '43'}, {'v': 'true'}, {'v': 0.5}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 5))} ]}, {'f': [ {'v': 51}, {'v': 52}, {'v': '53'}, {'v': 'true'}, {'v': 0.6}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 6))} ]} ] return rows @staticmethod def _get_expected_rows(): rows = [ {'c': [ {'v': 1}, {'v': 2}, {'v': '3'}, {'v': True}, {'v': 0.0}, {'v': dt.datetime(2000, 1, 1)} ]}, {'c': [ {'v': 11}, {'v': 12}, {'v': '13'}, {'v': False}, {'v': 0.2}, {'v': dt.datetime(2000, 1, 2)} ]}, {'c': [ {'v': 21}, {'v': 22}, {'v': '23'}, {'v': True}, {'v': 0.3}, {'v': dt.datetime(2000, 1, 3)} ]}, {'c': [ {'v': 31}, {'v': 32}, {'v': '33'}, {'v': False}, {'v': 0.4}, {'v': dt.datetime(2000, 1, 4)} ]}, {'c': [ {'v': 41}, {'v': 42}, {'v': '43'}, {'v': True}, {'v': 0.5}, {'v': dt.datetime(2000, 1, 5)} ]}, {'c': [ {'v': 51}, {'v': 52}, {'v': '53'}, {'v': True}, {'v': 0.6}, {'v': dt.datetime(2000, 1, 6)} ]} ] return rows @staticmethod def _get_test_data_as_list_of_dicts(): test_data = [ {'Column1': 1, 'Column2': 2, 'Column3': '3', 'Column4': True, 'Column5': 0.0, 'Column6': dt.datetime(2000, 1, 1)}, {'Column1': 11, 'Column2': 12, 'Column3': '13', 'Column4': False, 'Column5': 0.2, 'Column6': dt.datetime(2000, 1, 2)}, {'Column1': 21, 'Column2': 22, 'Column3': '23', 'Column4': True, 'Column5': 0.3, 'Column6': dt.datetime(2000, 1, 3)}, {'Column1': 31, 'Column2': 32, 'Column3': '33', 'Column4': False, 'Column5': 0.4, 'Column6': dt.datetime(2000, 1, 4)}, {'Column1': 41, 'Column2': 42, 'Column3': '43', 'Column4': True, 'Column5': 0.5, 'Column6': dt.datetime(2000, 1, 5)}, {'Column1': 51, 'Column2': 52, 'Column3': '53', 'Column4': True, 'Column5': 0.6, 'Column6': dt.datetime(2000, 1, 6)} ] # Use OrderedDicts to make testing the result easier. 
for i in range(0, len(test_data)): test_data[i] = collections.OrderedDict(sorted(list(test_data[i].items()), key=lambda t: t[0])) return test_data def test_get_data_from_list_of_dicts(self): self._test_get_data(TestCases._get_test_data_as_list_of_dicts(), TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils._get_data_from_list_of_dicts) self._test_get_data(TestCases._get_test_data_as_list_of_dicts(), TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils.get_data) def test_get_data_from_list_of_lists(self): test_data = [ [1, 2, '3', True, 0.0, dt.datetime(2000, 1, 1)], [11, 12, '13', False, 0.2, dt.datetime(2000, 1, 2)], [21, 22, '23', True, 0.3, dt.datetime(2000, 1, 3)], [31, 32, '33', False, 0.4, dt.datetime(2000, 1, 4)], [41, 42, '43', True, 0.5, dt.datetime(2000, 1, 5)], [51, 52, '53', True, 0.6, dt.datetime(2000, 1, 6)], ] self._test_get_data(test_data, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils._get_data_from_list_of_lists) self._test_get_data(test_data, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils.get_data) def test_get_data_from_dataframe(self): df = pandas.DataFrame(self._get_test_data_as_list_of_dicts()) self._test_get_data(df, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils._get_data_from_dataframe) self._test_get_data(df, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils.get_data) @mock.patch('datalab.bigquery._api.Api.tabledata_list') @mock.patch('datalab.bigquery._table.Table.exists') @mock.patch('datalab.bigquery._api.Api.tables_get') @mock.patch('datalab.context._context.Context.default') def test_get_data_from_table(self, mock_context_default, mock_api_tables_get, mock_table_exists, mock_api_tabledata_list): data = TestCases._get_expected_rows() mock_context_default.return_value = TestCases._create_context() mock_api_tables_get.return_value = { 'numRows': len(data), 'schema': { 'fields': [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'INTEGER'}, {'name': 'Column3', 'type': 'STRING'}, {'name': 'Column4', 'type': 'BOOLEAN'}, {'name': 'Column5', 'type': 'FLOAT'}, {'name': 'Column6', 'type': 'TIMESTAMP'} ] } } mock_table_exists.return_value = True raw_data = self._get_raw_rows() def tabledata_list(*args, **kwargs): start_index = kwargs['start_index'] max_results = kwargs['max_results'] if max_results < 0: max_results = len(data) return {'rows': raw_data[start_index:start_index + max_results]} mock_api_tabledata_list.side_effect = tabledata_list t = datalab.bigquery.Table('foo.bar') self._test_get_data(t, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils._get_data_from_table) self._test_get_data(t, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, datalab.utils.commands._utils.get_data) def test_get_data_from_empty_list(self): self._test_get_data([], [], [], 0, datalab.utils.commands._utils.get_data) def test_get_data_from_malformed_list(self): with self.assertRaises(Exception) as error: self._test_get_data(['foo', 'bar'], [], [], 0, datalab.utils.commands._utils.get_data) self.assertEquals('To get tabular data from a list it must contain dictionaries or lists.', str(error.exception)) def _test_get_data(self, test_data, cols, rows, expected_count, fn): self.maxDiff = None data, count = fn(test_data) 
self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows}, data) # Test first_row. Note that count must be set in this case so we use a value greater than the # data set size. for first in range(0, 6): data, count = fn(test_data, first_row=first, count=10) self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows[first:]}, data) # Test first_row + count for first in range(0, 6): data, count = fn(test_data, first_row=first, count=2) self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows[first:first + 2]}, data) # Test subsets of columns # No columns data, count = fn(test_data, fields=[]) self.assertEquals({'cols': [], 'rows': [{'c': []}] * expected_count}, data) # Single column data, count = fn(test_data, fields=['Column3']) if expected_count == 0: return self.assertEquals({'cols': [cols[2]], 'rows': [{'c': [row['c'][2]]} for row in rows]}, data) # Multi-columns data, count = fn(test_data, fields=['Column1', 'Column3', 'Column6']) self.assertEquals({'cols': [cols[0], cols[2], cols[5]], 'rows': [{'c': [row['c'][0], row['c'][2], row['c'][5]]} for row in rows]}, data) # Switch order data, count = fn(test_data, fields=['Column3', 'Column1']) self.assertEquals({'cols': [cols[2], cols[0]], 'rows': [{'c': [row['c'][2], row['c'][0]]} for row in rows]}, data) # Select all data, count = fn(test_data, fields=['Column1', 'Column2', 'Column3', 'Column4', 'Column5', 'Column6']) self.assertEquals({'cols': cols, 'rows': rows}, data) @staticmethod def _create_api(): context = TestCases._create_context() return datalab.bigquery._api.Api(context.credentials, context.project_id) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) ================================================ FILE: legacy_tests/main.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import os import sys import unittest # Set up the path so that we can import our datalab.* packages. 
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../datalab'))) # noqa import bigquery.api_tests import bigquery.dataset_tests import bigquery.federated_table_tests import bigquery.jobs_tests import bigquery.parser_tests import bigquery.query_tests import bigquery.sampling_tests import bigquery.schema_tests import bigquery.table_tests import bigquery.udf_tests import bigquery.view_tests import data.sql_tests import kernel.bigquery_tests import kernel.chart_data_tests import kernel.chart_tests import kernel.commands_tests import kernel.html_tests import kernel.module_tests import kernel.sql_tests import kernel.storage_tests import kernel.utils_tests import stackdriver.commands.monitoring_tests import stackdriver.monitoring.group_tests import stackdriver.monitoring.metric_tests import stackdriver.monitoring.resource_tests import stackdriver.monitoring.query_metadata_tests import stackdriver.monitoring.query_tests import stackdriver.monitoring.utils_tests import storage.api_tests import storage.bucket_tests import storage.item_tests import _util.http_tests import _util.lru_cache_tests import _util.util_tests _TEST_MODULES = [ bigquery.api_tests, bigquery.dataset_tests, bigquery.federated_table_tests, bigquery.jobs_tests, bigquery.parser_tests, bigquery.query_tests, bigquery.sampling_tests, bigquery.schema_tests, bigquery.table_tests, bigquery.udf_tests, bigquery.view_tests, bigquery.sampling_tests, data.sql_tests, kernel.bigquery_tests, kernel.chart_data_tests, kernel.chart_tests, kernel.commands_tests, kernel.html_tests, kernel.module_tests, kernel.sql_tests, kernel.storage_tests, kernel.utils_tests, stackdriver.commands.monitoring_tests, stackdriver.monitoring.group_tests, stackdriver.monitoring.metric_tests, stackdriver.monitoring.resource_tests, stackdriver.monitoring.query_metadata_tests, stackdriver.monitoring.query_tests, stackdriver.monitoring.utils_tests, storage.api_tests, storage.bucket_tests, storage.item_tests, _util.http_tests, _util.lru_cache_tests, _util.util_tests ] if __name__ == '__main__': suite = unittest.TestSuite() for m in _TEST_MODULES: suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(m)) runner = unittest.TextTestRunner() result = runner.run(suite) sys.exit(len(result.errors) + len(result.failures)) ================================================ FILE: legacy_tests/stackdriver/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/stackdriver/commands/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/stackdriver/commands/monitoring_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import import mock import unittest import pandas import datalab.stackdriver.commands._monitoring as monitoring_commands PROJECT = 'my-project' class TestCases(unittest.TestCase): @mock.patch('datalab.stackdriver.commands._monitoring._render_dataframe') @mock.patch('datalab.stackdriver.monitoring.MetricDescriptors') def test_list_metric_descriptors(self, mock_metric_descriptors, mock_render_dataframe): METRIC_TYPES = ['compute.googleapis.com/instances/cpu/utilization', 'compute.googleapis.com/instances/cpu/usage_time'] DATAFRAME = pandas.DataFrame(METRIC_TYPES, columns=['Metric type']) PATTERN = 'compute*cpu*' mock_metric_class = mock_metric_descriptors.return_value mock_metric_class.as_dataframe.return_value = DATAFRAME monitoring_commands._list_metric_descriptors( {'project': PROJECT, 'type': PATTERN}, None) mock_metric_descriptors.assert_called_once_with(project_id=PROJECT) mock_metric_class.as_dataframe.assert_called_once_with(pattern=PATTERN) mock_render_dataframe.assert_called_once_with(DATAFRAME) @mock.patch('datalab.stackdriver.commands._monitoring._render_dataframe') @mock.patch('datalab.stackdriver.monitoring.ResourceDescriptors') def test_list_resource_descriptors(self, mock_resource_descriptors, mock_render_dataframe): RESOURCE_TYPES = ['gce_instance', 'aws_ec2_instance'] DATAFRAME = pandas.DataFrame(RESOURCE_TYPES, columns=['Resource type']) PATTERN = '*instance*' mock_resource_class = mock_resource_descriptors.return_value mock_resource_class.as_dataframe.return_value = DATAFRAME monitoring_commands._list_resource_descriptors( {'project': PROJECT, 'type': PATTERN}, None) mock_resource_descriptors.assert_called_once_with(project_id=PROJECT) mock_resource_class.as_dataframe.assert_called_once_with(pattern=PATTERN) mock_render_dataframe.assert_called_once_with(DATAFRAME) @mock.patch('datalab.stackdriver.commands._monitoring._render_dataframe') @mock.patch('datalab.stackdriver.monitoring.Groups') def test_list_groups(self, mock_groups, mock_render_dataframe): GROUP_IDS = ['GROUP-205', 'GROUP-101'] DATAFRAME = pandas.DataFrame(GROUP_IDS, columns=['Group ID']) PATTERN = 'GROUP-*' mock_group_class = mock_groups.return_value mock_group_class.as_dataframe.return_value = DATAFRAME monitoring_commands._list_groups( {'project': PROJECT, 'name': PATTERN}, None) mock_groups.assert_called_once_with(project_id=PROJECT) 
mock_group_class.as_dataframe.assert_called_once_with(pattern=PATTERN) mock_render_dataframe.assert_called_once_with(DATAFRAME) ================================================ FILE: legacy_tests/stackdriver/monitoring/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: legacy_tests/stackdriver/monitoring/group_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import import mock import unittest import google.cloud.monitoring_v3 import google.auth import google.datalab import google.datalab.stackdriver.monitoring as gcm DEFAULT_PROJECT = 'test' PROJECT = 'my-project' GROUP_IDS = ['GROUP-205', 'GROUP-101'] PARENT_IDS = ['', GROUP_IDS[0]] DISPLAY_NAMES = ['All Instances', 'GCE Instances'] PARENT_DISPLAY_NAMES = ['', DISPLAY_NAMES[0]] FILTER_STRINGS = ['resource.type = ends_with("instance")', 'resource.type = "gce_instance"'] IS_CLUSTERS = [False, True] class TestCases(unittest.TestCase): def setUp(self): self.context = self._create_context(DEFAULT_PROJECT) self.groups = gcm.Groups(context=self.context) @mock.patch('google.datalab.Context.default') def test_constructor_minimal(self, mock_context_default): mock_context_default.return_value = self.context groups = gcm.Groups() self.assertIs(groups._context, self.context) self.assertIsNone(groups._group_dict) self.assertEqual(groups._client.project, DEFAULT_PROJECT) def test_constructor_maximal(self): context = self._create_context(PROJECT) groups = gcm.Groups(context) self.assertIs(groups._context, context) self.assertIsNone(groups._group_dict) self.assertEqual(groups._client.project, PROJECT) @mock.patch('google.cloud.monitoring_v3.GroupServiceClient.list_groups') def test_list(self, mock_api_list_groups): mock_api_list_groups.return_value = self._list_groups_get_result( context=self.context) group_list = self.groups.list() mock_api_list_groups.assert_called_once_with(DEFAULT_PROJECT) self.assertEqual(len(group_list), 2) self.assertEqual(group_list[0].name, GROUP_IDS[0]) self.assertEqual(group_list[1].name, GROUP_IDS[1]) @mock.patch('google.cloud.monitoring_v3.GroupServiceClient.list_groups') def test_list_w_pattern_match(self, mock_api_list_groups): mock_api_list_groups.return_value = self._list_groups_get_result( context=self.context) group_list = self.groups.list(pattern='GCE*') 
mock_api_list_groups.assert_called_once_with(DEFAULT_PROJECT) self.assertEqual(len(group_list), 1) self.assertEqual(group_list[0].name, GROUP_IDS[1]) @mock.patch('google.cloud.monitoring_v3.GroupServiceClient.list_groups') def test_list_caching(self, mock_gcloud_list_groups): mock_gcloud_list_groups.return_value = self._list_groups_get_result( context=self.context) actual_list1 = self.groups.list() actual_list2 = self.groups.list() mock_gcloud_list_groups.assert_called_once_with(DEFAULT_PROJECT) self.assertEqual(actual_list1, actual_list2) @mock.patch('google.cloud.monitoring_v3.GroupServiceClient.list_groups') def test_as_dataframe(self, mock_gcloud_list_groups): mock_gcloud_list_groups.return_value = self._list_groups_get_result( context=self.context) dataframe = self.groups.as_dataframe() mock_gcloud_list_groups.assert_called_once_with(DEFAULT_PROJECT) expected_headers = list(gcm.Groups._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.columns.names, [None]) self.assertEqual(dataframe.index.tolist(), list(range(len(GROUP_IDS)))) self.assertEqual(dataframe.index.names, [None]) expected_values = [list(row) for row in zip(GROUP_IDS, DISPLAY_NAMES, PARENT_IDS, PARENT_DISPLAY_NAMES, IS_CLUSTERS, FILTER_STRINGS)] self.assertEqual(dataframe.values.tolist(), expected_values) @mock.patch('google.cloud.monitoring_v3.GroupServiceClient.list_groups') def test_as_dataframe_w_all_args(self, mock_gcloud_list_groups): mock_gcloud_list_groups.return_value = self._list_groups_get_result( context=self.context) dataframe = self.groups.as_dataframe(pattern='*Instance*', max_rows=1) mock_gcloud_list_groups.assert_called_once_with(DEFAULT_PROJECT) expected_headers = list(gcm.Groups._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.index.tolist(), [0]) self.assertEqual(dataframe.iloc[0, 0], GROUP_IDS[0]) @staticmethod def _create_context(project_id): creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _list_groups_get_result(context): groups = [] for group_id, parent_id, display_name, filter_string, is_cluster in \ zip(GROUP_IDS, PARENT_IDS, DISPLAY_NAMES, FILTER_STRINGS, IS_CLUSTERS): group = google.cloud.monitoring_v3.types.Group( name=group_id, display_name=display_name, parent_name=parent_id, filter=filter_string, is_cluster=is_cluster) groups.append(group) return groups ================================================ FILE: legacy_tests/stackdriver/monitoring/metric_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import import mock import unittest import google.cloud.monitoring_v3 import google.auth import google.datalab import google.datalab.stackdriver.monitoring as gcm DEFAULT_PROJECT = 'test' PROJECT = 'my-project' METRIC_TYPES = ['compute.googleapis.com/instances/cpu/utilization', 'compute.googleapis.com/instances/cpu/usage_time'] DISPLAY_NAMES = ['CPU Utilization', 'CPU Usage'] METRIC_KIND = 'GAUGE' VALUE_TYPE = 'DOUBLE' UNIT = '1' LABELS = [dict(key='instance_name', value_type='STRING', description='VM instance'), dict(key='device_name', value_type='STRING', description='Device name')] FILTER_STRING = 'metric.type:"cpu"' TYPE_PREFIX = 'compute' class TestCases(unittest.TestCase): def setUp(self): self.context = self._create_context(DEFAULT_PROJECT) self.descriptors = gcm.MetricDescriptors(context=self.context) @mock.patch('google.datalab.Context.default') def test_constructor_minimal(self, mock_context_default): mock_context_default.return_value = self.context descriptors = gcm.MetricDescriptors() self.assertEqual(descriptors._client.project, DEFAULT_PROJECT) self.assertIsNone(descriptors._filter_string) self.assertIsNone(descriptors._type_prefix) self.assertIsNone(descriptors._descriptors) def test_constructor_maximal(self): context = self._create_context(PROJECT) descriptors = gcm.MetricDescriptors( filter_string=FILTER_STRING, type_prefix=TYPE_PREFIX, context=context) self.assertEqual(descriptors._client.project, PROJECT) self.assertEqual(descriptors._filter_string, FILTER_STRING) self.assertEqual(descriptors._type_prefix, TYPE_PREFIX) self.assertIsNone(descriptors._descriptors) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_metric_descriptors') def test_list(self, mock_gcloud_list_descriptors): mock_gcloud_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) metric_descriptor_list = self.descriptors.list() mock_gcloud_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_='') self.assertEqual(len(metric_descriptor_list), 2) self.assertEqual(metric_descriptor_list[0].type, METRIC_TYPES[0]) self.assertEqual(metric_descriptor_list[1].type, METRIC_TYPES[1]) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_metric_descriptors') def test_list_w_api_filter(self, mock_gcloud_list_descriptors): mock_gcloud_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) descriptors = gcm.MetricDescriptors( filter_string=FILTER_STRING, type_prefix=TYPE_PREFIX, context=self.context) metric_descriptor_list = descriptors.list() expected_filter = '{} AND metric.type = starts_with("{}")'.format( FILTER_STRING, TYPE_PREFIX) mock_gcloud_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=expected_filter) self.assertEqual(len(metric_descriptor_list), 2) self.assertEqual(metric_descriptor_list[0].type, METRIC_TYPES[0]) self.assertEqual(metric_descriptor_list[1].type, METRIC_TYPES[1]) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_metric_descriptors') def test_list_w_pattern_match(self, mock_gcloud_list_descriptors): mock_gcloud_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) metric_descriptor_list = self.descriptors.list(pattern='*usage_time') mock_gcloud_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_='') self.assertEqual(len(metric_descriptor_list), 1) self.assertEqual(metric_descriptor_list[0].type, METRIC_TYPES[1]) 
@mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_metric_descriptors') def test_list_caching(self, mock_gcloud_list_descriptors): mock_gcloud_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) actual_list1 = self.descriptors.list() actual_list2 = self.descriptors.list() mock_gcloud_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_='') self.assertEqual(actual_list1, actual_list2) @mock.patch('google.datalab.stackdriver.monitoring.MetricDescriptors.list') def test_as_dataframe(self, mock_datalab_list_descriptors): mock_datalab_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) dataframe = self.descriptors.as_dataframe() mock_datalab_list_descriptors.assert_called_once_with('*') expected_headers = list(gcm.MetricDescriptors._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.columns.names, [None]) self.assertEqual(dataframe.index.tolist(), list(range(len(METRIC_TYPES)))) self.assertEqual(dataframe.index.names, [None]) expected_labels = 'instance_name, device_name' expected_values = [ [metric_type, display_name, METRIC_KIND, VALUE_TYPE, UNIT, expected_labels] for metric_type, display_name in zip(METRIC_TYPES, DISPLAY_NAMES)] self.assertEqual(dataframe.values.tolist(), expected_values) @mock.patch('google.datalab.stackdriver.monitoring.MetricDescriptors.list') def test_as_dataframe_w_all_args(self, mock_datalab_list_descriptors): mock_datalab_list_descriptors.return_value = self._list_metrics_get_result( context=self.context) dataframe = self.descriptors.as_dataframe(pattern='*cpu*', max_rows=1) mock_datalab_list_descriptors.assert_called_once_with('*cpu*') expected_headers = list(gcm.MetricDescriptors._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.index.tolist(), [0]) self.assertEqual(dataframe.iloc[0, 0], METRIC_TYPES[0]) @staticmethod def _create_context(project_id): creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _list_metrics_get_result(context): all_labels = [google.cloud.monitoring_v3.types.LabelDescriptor(**labels) for labels in LABELS] descriptors = [ google.cloud.monitoring_v3.types.MetricDescriptor( type=metric_type, metric_kind=METRIC_KIND, value_type=VALUE_TYPE, unit=UNIT, display_name=display_name, labels=all_labels, ) for metric_type, display_name in zip(METRIC_TYPES, DISPLAY_NAMES)] return descriptors ================================================ FILE: legacy_tests/stackdriver/monitoring/query_metadata_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
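# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). It shows, under the same API surface the tests below mock, how query
# metadata might be inspected; the metric type string is simply the constant
# those tests use.
def _example_query_metadata_usage():
  import google.datalab.stackdriver.monitoring as gcm

  # Build a CPU-utilization query over the last hour and fetch only the
  # time-series headers; as_dataframe() tabulates the resource/metric labels.
  query = gcm.Query('compute.googleapis.com/instance/cpu/utilization', hours=1)
  metadata = query.metadata()  # returns a gcm.QueryMetadata instance
  return metadata.as_dataframe(max_rows=5)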
from __future__ import absolute_import import mock import unittest from google.cloud.monitoring_v3.types import Metric from google.cloud.monitoring_v3.types import MonitoredResource from google.cloud.monitoring_v3.types import TimeSeries import google.auth import google.datalab import google.datalab.stackdriver.monitoring as gcm PROJECT = 'my-project' METRIC_TYPE = 'compute.googleapis.com/instance/cpu/utilization' RESOURCE_TYPE = 'gce_instance' INSTANCE_NAMES = ['instance-1', 'instance-2'] INSTANCE_ZONES = ['us-east1-a', 'us-east1-b'] INSTANCE_IDS = ['1234567890123456789', '9876543210987654321'] class TestCases(unittest.TestCase): def setUp(self): creds = mock.Mock(spec=google.auth.credentials.Credentials) context = google.datalab.Context(PROJECT, creds) self.query = gcm.Query(METRIC_TYPE, context=context) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_constructor(self, mock_query_iter): time_series_iterable = list(self._query_iter_get_result()) mock_query_iter.return_value = self._query_iter_get_result() query_metadata = gcm.QueryMetadata(self.query) mock_query_iter.assert_called_once_with(headers_only=True) self.assertEqual(query_metadata.metric_type, METRIC_TYPE) self.assertEqual(query_metadata.resource_types, set([RESOURCE_TYPE])) self.assertEqual(query_metadata._timeseries_list, time_series_iterable) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_iteration(self, mock_query_iter): time_series_iterable = list(self._query_iter_get_result()) mock_query_iter.return_value = self._query_iter_get_result() query_metadata = gcm.QueryMetadata(self.query) response = list(query_metadata) self.assertEqual(len(response), len(time_series_iterable)) self.assertEqual(response, time_series_iterable) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_as_dataframe(self, mock_query_iter): mock_query_iter.return_value = self._query_iter_get_result() query_metadata = gcm.QueryMetadata(self.query) dataframe = query_metadata.as_dataframe() NUM_INSTANCES = len(INSTANCE_IDS) self.assertEqual(dataframe.shape, (NUM_INSTANCES, 5)) expected_values = [ [RESOURCE_TYPE, PROJECT, zone, instance_id, instance_name] for zone, instance_id, instance_name in zip(INSTANCE_ZONES, INSTANCE_IDS, INSTANCE_NAMES)] self.assertEqual(dataframe.values.tolist(), expected_values) expected_headers = [ ('resource.type', ''), ('resource.labels', 'project_id'), ('resource.labels', 'zone'), ('resource.labels', 'instance_id'), ('metric.labels', 'instance_name') ] self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.columns.names, [None, None]) self.assertEqual(dataframe.index.tolist(), list(range(NUM_INSTANCES))) self.assertEqual(dataframe.index.names, [None]) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_as_dataframe_w_max_rows(self, mock_query_iter): mock_query_iter.return_value = self._query_iter_get_result() MAX_ROWS = 1 query_metadata = gcm.QueryMetadata(self.query) dataframe = query_metadata.as_dataframe(max_rows=MAX_ROWS) self.assertEqual(dataframe.shape, (MAX_ROWS, 5)) expected_values = [ [RESOURCE_TYPE, PROJECT, INSTANCE_ZONES[0], INSTANCE_IDS[0], INSTANCE_NAMES[0]], ] self.assertEqual(dataframe.values.tolist(), expected_values) expected_headers = [ ('resource.type', ''), ('resource.labels', 'project_id'), ('resource.labels', 'zone'), ('resource.labels', 'instance_id'), ('metric.labels', 'instance_name') ] self.assertEqual(dataframe.columns.tolist(), expected_headers) 
self.assertEqual(dataframe.columns.names, [None, None]) self.assertEqual(dataframe.index.tolist(), list(range(MAX_ROWS))) self.assertEqual(dataframe.index.names, [None]) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_as_dataframe_w_no_data(self, mock_query_iter): query_metadata = gcm.QueryMetadata(self.query) dataframe = query_metadata.as_dataframe() self.assertEqual(dataframe.shape, (0, 0)) self.assertIsNone(dataframe.columns.name) self.assertIsNone(dataframe.index.name) @staticmethod def _query_iter_get_result(): METRIC_LABELS = list({'instance_name': name} for name in INSTANCE_NAMES) RESOURCE_LABELS = list({ 'project_id': PROJECT, 'zone': zone, 'instance_id': instance_id, } for zone, instance_id in zip(INSTANCE_ZONES, INSTANCE_IDS)) for metric_labels, resource_labels in zip(METRIC_LABELS, RESOURCE_LABELS): yield TimeSeries( metric=Metric(type=METRIC_TYPE, labels=metric_labels), resource=MonitoredResource(type=RESOURCE_TYPE, labels=resource_labels), metric_kind='GAUGE', value_type='DOUBLE', points=[], ) ================================================ FILE: legacy_tests/stackdriver/monitoring/query_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
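# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). The constructor arguments mirror those exercised by
# test_constructor_maximal below; the end time and window sizes are placeholders.
def _example_query_time_window():
  import datetime
  import google.datalab.stackdriver.monitoring as gcm

  # An explicit end time plus days/hours/minutes defines the query window;
  # omitting them (as in test_constructor_minimal) leaves the window unset.
  end = datetime.datetime(2016, 4, 7, 2, 30, 30)
  return gcm.Query('compute.googleapis.com/instance/uptime',
                   end_time=end, days=1, hours=2, minutes=3)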
from __future__ import absolute_import import datetime import mock import unittest from google.cloud.monitoring_v3.query import Query as BaseQuery import google.auth import google.datalab import google.datalab.stackdriver.monitoring as gcm PROJECT = 'my-project' METRIC_TYPE = 'compute.googleapis.com/instance/cpu/utilization' RESOURCE_TYPE = 'gce_instance' INSTANCE_NAMES = ['instance-1', 'instance-2'] INSTANCE_ZONES = ['us-east1-a', 'us-east1-b'] INSTANCE_IDS = ['1234567890123456789', '9876543210987654321'] class TestCases(unittest.TestCase): def setUp(self): creds = mock.Mock(spec=google.auth.credentials.Credentials) self.context = google.datalab.Context(PROJECT, creds) @mock.patch('google.datalab.Context.default') def test_constructor_minimal(self, mock_context_default): mock_context_default.return_value = self.context query = gcm.Query() self.assertEqual(query._filter.metric_type, BaseQuery.DEFAULT_METRIC_TYPE) self.assertIsNone(query._start_time) self.assertIsNone(query._end_time) self.assertEqual(query._per_series_aligner, 0) self.assertEqual(query._alignment_period_seconds, 0) self.assertEqual(query._cross_series_reducer, 0) self.assertEqual(query._group_by_fields, ()) def test_constructor_maximal(self): UPTIME_METRIC = 'compute.googleapis.com/instance/uptime' T1 = datetime.datetime(2016, 4, 7, 2, 30, 30) DAYS, HOURS, MINUTES = 1, 2, 3 T0 = T1 - datetime.timedelta(days=DAYS, hours=HOURS, minutes=MINUTES) query = gcm.Query(UPTIME_METRIC, end_time=T1, days=DAYS, hours=HOURS, minutes=MINUTES, context=self.context) self.assertEqual(query._filter.metric_type, UPTIME_METRIC) self.assertEqual(query._start_time, T0) self.assertEqual(query._end_time, T1) self.assertEqual(query._per_series_aligner, 0) self.assertEqual(query._alignment_period_seconds, 0) self.assertEqual(query._cross_series_reducer, 0) self.assertEqual(query._group_by_fields, ()) @mock.patch('google.datalab.stackdriver.monitoring.Query.iter') def test_metadata(self, mock_query_iter): query = gcm.Query(METRIC_TYPE, hours=1, context=self.context) query_metadata = query.metadata() mock_query_iter.assert_called_once_with(headers_only=True) self.assertIsInstance(query_metadata, gcm.QueryMetadata) self.assertEqual(query_metadata.metric_type, METRIC_TYPE) ================================================ FILE: legacy_tests/stackdriver/monitoring/resource_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
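# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). It mirrors the ResourceDescriptors calls the tests below make against
# mocks; the filter string and pattern are the same placeholders used there.
def _example_resource_descriptors_usage():
  import google.datalab.stackdriver.monitoring as gcm

  # Restrict the listing with a server-side filter, then narrow further with a
  # client-side wildcard pattern when rendering to a DataFrame.
  descriptors = gcm.ResourceDescriptors(
      filter_string='resource.type = ends_with("instance")')
  return descriptors.as_dataframe(pattern='*instance*', max_rows=10)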
from __future__ import absolute_import import mock import unittest import google.cloud.monitoring_v3 import google.auth import google.datalab import google.datalab.stackdriver.monitoring as gcm DEFAULT_PROJECT = 'test' PROJECT = 'my-project' RESOURCE_TYPES = ['gce_instance', 'aws_ec2_instance'] DISPLAY_NAMES = ['GCE VM Instance', 'Amazon EC2 Instance'] LABELS = [dict(key='instance_id', value_type='STRING', description='VM instance ID'), dict(key='project_id', value_type='STRING', description='Project ID')] FILTER_STRING = 'resource.type = ends_with("instance")' class TestCases(unittest.TestCase): def setUp(self): self.context = self._create_context(DEFAULT_PROJECT) self.descriptors = gcm.ResourceDescriptors(context=self.context) @mock.patch('google.datalab.Context.default') def test_constructor_minimal(self, mock_context_default): mock_context_default.return_value = self.context descriptors = gcm.ResourceDescriptors() self.assertEqual(descriptors._client.project, DEFAULT_PROJECT) self.assertIsNone(descriptors._filter_string) self.assertIsNone(descriptors._descriptors) def test_constructor_maximal(self): context = self._create_context(PROJECT) descriptors = gcm.ResourceDescriptors( filter_string=FILTER_STRING, context=context) self.assertEqual(descriptors._client.project, PROJECT) self.assertEqual(descriptors._filter_string, FILTER_STRING) self.assertIsNone(descriptors._descriptors) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_list(self, mock_api_list_descriptors): mock_api_list_descriptors.return_value = self._list_resources_get_result() resource_descriptor_list = self.descriptors.list() mock_api_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=None) self.assertEqual(len(resource_descriptor_list), 2) self.assertEqual(resource_descriptor_list[0].type, RESOURCE_TYPES[0]) self.assertEqual(resource_descriptor_list[1].type, RESOURCE_TYPES[1]) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_list_w_api_filter(self, mock_api_list_descriptors): mock_api_list_descriptors.return_value = self._list_resources_get_result() descriptors = gcm.ResourceDescriptors( filter_string=FILTER_STRING, context=self.context) resource_descriptor_list = descriptors.list() mock_api_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=FILTER_STRING) self.assertEqual(len(resource_descriptor_list), 2) self.assertEqual(resource_descriptor_list[0].type, RESOURCE_TYPES[0]) self.assertEqual(resource_descriptor_list[1].type, RESOURCE_TYPES[1]) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_list_w_pattern_match(self, mock_api_list_descriptors): mock_api_list_descriptors.return_value = self._list_resources_get_result() resource_descriptor_list = self.descriptors.list(pattern='*ec2*') mock_api_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=None) self.assertEqual(len(resource_descriptor_list), 1) self.assertEqual(resource_descriptor_list[0].type, RESOURCE_TYPES[1]) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_list_caching(self, mock_gcloud_list_descriptors): mock_gcloud_list_descriptors.return_value = ( self._list_resources_get_result()) actual_list1 = self.descriptors.list() actual_list2 = self.descriptors.list() mock_gcloud_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=None) self.assertEqual(actual_list1, actual_list2) 
@mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_as_dataframe(self, mock_datalab_list_descriptors): mock_datalab_list_descriptors.return_value = ( self._list_resources_get_result()) dataframe = self.descriptors.as_dataframe() mock_datalab_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=None) expected_headers = list(gcm.ResourceDescriptors._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.columns.names, [None]) self.assertEqual(dataframe.index.tolist(), list(range(len(RESOURCE_TYPES)))) self.assertEqual(dataframe.index.names, [None]) expected_labels = 'instance_id, project_id' expected_values = [ [resource_type, display_name, expected_labels] for resource_type, display_name in zip(RESOURCE_TYPES, DISPLAY_NAMES)] self.assertEqual(dataframe.values.tolist(), expected_values) @mock.patch('google.cloud.monitoring_v3.MetricServiceClient.list_monitored_resource_descriptors') def test_as_dataframe_w_all_args(self, mock_datalab_list_descriptors): mock_datalab_list_descriptors.return_value = ( self._list_resources_get_result()) dataframe = self.descriptors.as_dataframe(pattern='*instance*', max_rows=1) mock_datalab_list_descriptors.assert_called_once_with( DEFAULT_PROJECT, filter_=None) expected_headers = list(gcm.ResourceDescriptors._DISPLAY_HEADERS) self.assertEqual(dataframe.columns.tolist(), expected_headers) self.assertEqual(dataframe.index.tolist(), [0]) self.assertEqual(dataframe.iloc[0, 0], RESOURCE_TYPES[0]) @staticmethod def _create_context(project_id): creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _list_resources_get_result(): all_labels = [google.cloud.monitoring_v3.types.LabelDescriptor(**labels) for labels in LABELS] descriptors = [ google.cloud.monitoring_v3.types.MonitoredResourceDescriptor( name=None, type=resource_type, display_name=display_name, description=None, labels=all_labels, ) for resource_type, display_name in zip(RESOURCE_TYPES, DISPLAY_NAMES)] return descriptors ================================================ FILE: legacy_tests/stackdriver/monitoring/utils_tests.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
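# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). make_client(), exercised by the tests below, builds a monitoring client
# for a Context; with no argument it falls back to Context.default().
def _example_make_client():
  import google.datalab
  import google.datalab.stackdriver.monitoring as gcm

  context = google.datalab.Context.default()
  client = gcm._utils.make_client(context)
  # The client is bound to the context's project.
  assert client.project == context.project_id
  return client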
from __future__ import absolute_import

import mock
import unittest

import google.auth
import google.datalab
import google.datalab.stackdriver.monitoring as gcm


class TestCases(unittest.TestCase):

  def test_make_client(self):
    context = self._create_context()
    client = gcm._utils.make_client(context)
    self.assertEqual(client.project, context.project_id)

  @mock.patch('google.datalab.Context.default')
  def test_make_client_w_defaults(self, mock_context_default):
    default_context = self._create_context()
    mock_context_default.return_value = default_context
    client = gcm._utils.make_client()
    self.assertEqual(client.project, default_context.project_id)

  @staticmethod
  def _create_context():
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return google.datalab.Context('test_project', creds)


================================================ FILE: legacy_tests/storage/__init__.py ================================================
# Copyright 2016 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

================================================ FILE: legacy_tests/storage/api_tests.py ================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
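# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). The tests below assert the REST URLs built by the low-level storage
# Api class; a hedged example of driving it directly (normally the Bucket/Item
# wrappers are used instead) might look like this. 'some-bucket' and
# 'some-object' are placeholder names.
def _example_storage_api_usage():
  import datalab.context
  from datalab.storage._api import Api

  # Api is constructed from a Context and exposes thin wrappers over the
  # Cloud Storage JSON API, e.g. listing buckets and fetching object metadata.
  api = Api(datalab.context.Context.default())
  buckets = api.buckets_list()
  metadata = api.objects_get('some-bucket', 'some-object')
  return buckets, metadata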
from __future__ import absolute_import from __future__ import unicode_literals import unittest import mock import google.auth import datalab.context from datalab.storage._api import Api class TestCases(unittest.TestCase): def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None, expected_headers=None, expected_method=None): url = mock_http_request.call_args[0][0] kwargs = mock_http_request.call_args[1] self.assertEquals(expected_url, url) if expected_args is not None: self.assertEquals(expected_args, kwargs['args']) if expected_data is not None: self.assertEquals(expected_data, kwargs['data']) if expected_headers is not None: self.assertEquals(expected_headers, kwargs['headers']) if expected_method is not None: self.assertEquals(expected_method, kwargs['method']) @mock.patch('datalab.utils.Http.request') def test_buckets_insert(self, mock_http_request): api = TestCases._create_api() api.buckets_insert('foo') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/', expected_args={'project': 'test'}, expected_data={'name': 'foo'}) api.buckets_insert('foo', 'bar') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/', expected_args={'project': 'bar'}, expected_data={'name': 'foo'}) @mock.patch('datalab.utils.Http.request') def test_buckets_delete(self, mock_http_request): api = TestCases._create_api() api.buckets_delete('foo') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/foo', expected_method='DELETE') @mock.patch('datalab.utils.Http.request') def test_buckets_get(self, mock_http_request): api = TestCases._create_api() api.buckets_get('foo') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/foo', expected_args={'projection': 'noAcl'}) api.buckets_get('foo', 'bar') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/foo', expected_args={'projection': 'bar'}) @mock.patch('datalab.utils.Http.request') def test_buckets_list(self, mock_http_request): api = TestCases._create_api() api.buckets_list() self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/', expected_args={'project': 'test', 'projection': 'noAcl', 'maxResults': 100}) api.buckets_list(projection='foo', max_results=99, page_token='xyz', project_id='bar') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/', expected_args={'project': 'bar', 'maxResults': 99, 'projection': 'foo', 'pageToken': 'xyz'}) @mock.patch('datalab.utils.Http.request') def test_object_download(self, mock_http_request): api = TestCases._create_api() api.object_download('foo', 'bar') self.validate(mock_http_request, 'https://www.googleapis.com/download/storage/v1/b/foo/o/bar', expected_args={'alt': 'media'}) @mock.patch('datalab.utils.Http.request') def test_object_upload(self, mock_http_request): api = TestCases._create_api() api.object_upload('b', 'k', 'c', 't') self.validate(mock_http_request, 'https://www.googleapis.com/upload/storage/v1/b/b/o/', expected_args={'uploadType': 'media', 'name': 'k'}, expected_data='c', expected_headers={'Content-Type': 't'}) @mock.patch('datalab.utils.Http.request') def test_objects_copy(self, mock_http_request): api = TestCases._create_api() api.objects_copy('sb', 'sk', 'tb', 'tk') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/sb/o/sk/copyTo/b/tb/o/tk', expected_method='POST') @mock.patch('datalab.utils.Http.request') def test_objects_delete(self, mock_http_request): api = TestCases._create_api() 
api.objects_delete('b', 'k') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/k', expected_method='DELETE') @mock.patch('datalab.utils.Http.request') def test_objects_get(self, mock_http_request): api = TestCases._create_api() api.objects_get('b', 'k') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/k', expected_args={'projection': 'noAcl'}) api.objects_get('b', 'k', 'p') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/k', expected_args={'projection': 'p'}) @mock.patch('datalab.utils.Http.request') def test_objects_list(self, mock_http_request): api = TestCases._create_api() api.objects_list('b') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/', expected_args={'projection': 'noAcl', 'maxResults': 100}) api.objects_list('b', 'p', 'd', 'pr', True, 99, 'foo') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/', expected_args={'projection': 'pr', 'maxResults': 99, 'prefix': 'p', 'delimiter': 'd', 'versions': 'true', 'pageToken': 'foo'}) @mock.patch('datalab.utils.Http.request') def test_objects_patch(self, mock_http_request): api = TestCases._create_api() api.objects_patch('b', 'k', 'i') self.validate(mock_http_request, 'https://www.googleapis.com/storage/v1/b/b/o/k', expected_method='PATCH', expected_data='i') @staticmethod def _create_api(): context = TestCases._create_context() return Api(context) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) ================================================ FILE: legacy_tests/storage/bucket_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
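# --- Illustrative sketch (editor addition, not part of the original legacy test
# file). It shows the Buckets/Bucket calls that the tests below exercise against
# a mocked API; 'my-bucket' is a placeholder name.
def _example_bucket_usage():
  import datalab.context
  import datalab.storage

  context = datalab.context.Context.default()
  # Buckets supports an existence check; Bucket exposes the bucket metadata.
  exists = datalab.storage.Buckets(context=context).contains('my-bucket')
  name = datalab.storage.Bucket('my-bucket', context=context).metadata.name
  return exists, name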
from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest import google.auth import datalab.context import datalab.storage import datalab.utils class TestCases(unittest.TestCase): @mock.patch('datalab.storage._api.Api.buckets_get') def test_bucket_existence(self, mock_api_buckets): mock_api_buckets.return_value = TestCases._create_buckets_get_result() buckets = datalab.storage.Buckets(context=TestCases._create_context()) self.assertTrue(buckets.contains('test_bucket')) mock_api_buckets.side_effect = datalab.utils.RequestException(404, 'failed') self.assertFalse(buckets.contains('test_bucket_2')) @mock.patch('datalab.storage._api.Api.buckets_get') def test_bucket_metadata(self, mock_api_buckets): mock_api_buckets.return_value = TestCases._create_buckets_get_result() b = TestCases._create_bucket() m = b.metadata self.assertEqual(m.name, 'test_bucket') @staticmethod def _create_bucket(name='test_bucket'): return datalab.storage.Bucket(name, context=TestCases._create_context()) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) @staticmethod def _create_buckets_get_result(): return {'name': 'test_bucket'} ================================================ FILE: legacy_tests/storage/item_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
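For orientation, a hedged sketch of how the Bucket/Buckets objects exercised by these tests are used outside of a test. It assumes a default datalab.context.Context with real credentials is available; the bucket and object names are hypothetical.

import datalab.storage as storage

bucket = storage.Bucket('my-bucket')                 # hypothetical bucket name
if storage.Buckets().contains('my-bucket'):
  for item in bucket.items():                        # paged enumeration, as tested above
    print(item.key)
  meta = bucket.item('path/to/object.txt').metadata  # hypothetical object key
  print(meta.name, meta.content_type)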
from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest import google.auth import datalab.context import datalab.storage import datalab.utils class TestCases(unittest.TestCase): @mock.patch('datalab.storage._api.Api.objects_list') @mock.patch('datalab.storage._api.Api.objects_get') def test_item_existence(self, mock_api_objects_get, mock_api_objects_list): mock_api_objects_list.return_value = TestCases._create_enumeration_single_result() mock_api_objects_get.return_value = TestCases._create_objects_get_result() b = TestCases._create_bucket() self.assertTrue(b.items().contains('test_item1')) mock_api_objects_get.side_effect = datalab.utils.RequestException(404, 'failed') self.assertFalse('test_item2' in list(b.items())) @mock.patch('datalab.storage._api.Api.objects_get') def test_item_metadata(self, mock_api_objects): mock_api_objects.return_value = TestCases._create_objects_get_result() b = TestCases._create_bucket() i = b.item('test_item1') m = i.metadata self.assertEqual(m.name, 'test_item1') self.assertEqual(m.content_type, 'text/plain') @mock.patch('datalab.storage._api.Api.objects_list') def test_enumerate_items_empty(self, mock_api_objects): mock_api_objects.return_value = TestCases._create_enumeration_empty_result() b = self._create_bucket() items = list(b.items()) self.assertEqual(len(items), 0) @mock.patch('datalab.storage._api.Api.objects_list') def test_enumerate_items_single(self, mock_api_objects): mock_api_objects.return_value = TestCases._create_enumeration_single_result() b = TestCases._create_bucket() items = list(b.items()) self.assertEqual(len(items), 1) self.assertEqual(items[0].key, 'test_item1') @mock.patch('datalab.storage._api.Api.objects_list') def test_enumerate_items_multi_page(self, mock_api_objects): mock_api_objects.side_effect = [ TestCases._create_enumeration_multipage_result1(), TestCases._create_enumeration_multipage_result2() ] b = TestCases._create_bucket() items = list(b.items()) self.assertEqual(len(items), 2) self.assertEqual(items[0].key, 'test_item1') self.assertEqual(items[1].key, 'test_item2') @staticmethod def _create_bucket(name='test_bucket'): return datalab.storage.Bucket(name, context=TestCases._create_context()) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return datalab.context.Context(project_id, creds) @staticmethod def _create_objects_get_result(): return {'name': 'test_item1', 'contentType': 'text/plain'} @staticmethod def _create_enumeration_empty_result(): return {} @staticmethod def _create_enumeration_single_result(): return { 'items': [ {'name': 'test_item1'} ] } @staticmethod def _create_enumeration_multipage_result1(): return { 'items': [ {'name': 'test_item1'} ], 'nextPageToken': 'test_token' } @staticmethod def _create_enumeration_multipage_result2(): return { 'items': [ {'name': 'test_item2'} ] } ================================================ FILE: release.sh ================================================ #!/bin/bash -e # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Compiles the typescript sources to javascript and submits the files # to the pypi server specified as first parameter, defaults to testpypi # In order to run this script locally, make sure you have the following: # - A Python 3 environment (due to urllib issues) # - Typescript installed # - A configured ~/.pypirc containing your pypi/testpypi credentials with # the server names matching the name you're passing in. Do not include # the repository URLs in the config file, this has been deprecated. # - Make sure the package version string in the setup.py file is updated. # It will get rejected by the server if it already exists # - If this is a new release, make sure the release notes are updated # and create a new release tag tsc --module amd --noImplicitAny datalab/notebook/static/*.ts tsc --module amd --noImplicitAny google/datalab/notebook/static/*.ts # This is the test url, you should change this to # https://upload.pypi.org/legacy/ for prod binaries server="${1:-https://test.pypi.python.org/pypi}" echo "Submitting package to ${server}" # Build and upload a distribution package rm -rf dist/* python setup.py sdist twine upload --repository-url "${server}" dist/* # Clean up rm -f datalab/notebook/static/*.js rm -f google/datalab/notebook/static/*.js ================================================ FILE: setup.cfg ================================================ [metadata] description-file = README.md [flake8] max-line-length = 100 exclude = docs ignore = # Indentation is not a multiple of four E111, # Indentation is not a multiple of four (comment) E114, # Continuation line under-indented for hanging indent E121 ================================================ FILE: setup.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
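release.sh above notes that an upload is rejected if the version string in setup.py already exists on the index. A small, hypothetical helper (not part of the release script) that checks the public PyPI JSON API before uploading:

import json
try:
  from urllib.request import urlopen   # Python 3
except ImportError:
  from urllib2 import urlopen          # Python 2 fallback


def version_already_released(package='datalab', version='1.2.1'):
  """Return True if `version` is already published for `package` on pypi.org."""
  data = json.load(urlopen('https://pypi.org/pypi/%s/json' % package))
  return version in data.get('releases', {})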
# To publish to PyPi use: python setup.py bdist_wheel upload -r pypi from setuptools import setup version = '1.2.1' setup( name='datalab', version=version, namespace_packages=['google', 'datalab'], packages=[ 'google.datalab', 'google.datalab.bigquery', 'google.datalab.bigquery.commands', 'google.datalab.commands', 'google.datalab.contrib', 'google.datalab.contrib.bigquery', 'google.datalab.contrib.bigquery.commands', 'google.datalab.contrib.bigquery.operators', 'google.datalab.contrib.mlworkbench', 'google.datalab.contrib.mlworkbench.commands', 'google.datalab.contrib.pipeline', 'google.datalab.contrib.pipeline.airflow', 'google.datalab.contrib.pipeline.composer', 'google.datalab.contrib.pipeline.commands', 'google.datalab.data', 'google.datalab.kernel', 'google.datalab.ml', 'google.datalab.notebook', 'google.datalab.stackdriver', 'google.datalab.stackdriver.commands', 'google.datalab.stackdriver.monitoring', 'google.datalab.storage', 'google.datalab.storage.commands', 'google.datalab.utils', 'google.datalab.utils.commands', 'google.datalab.utils.facets', 'datalab.bigquery', 'datalab.bigquery.commands', 'datalab.context', 'datalab.context.commands', 'datalab.data', 'datalab.data.commands', 'datalab.kernel', 'datalab.notebook', 'datalab.stackdriver', 'datalab.stackdriver.commands', 'datalab.stackdriver.monitoring', 'datalab.storage', 'datalab.storage.commands', 'datalab.utils', 'datalab.utils.commands' ], description='Google Cloud Datalab', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', url='https://github.com/googledatalab/pydatalab', download_url='https://github.com/googledatalab/pydatalab/archive/v1.1.0.zip', keywords=[ 'Google', 'GCP', 'GCS', 'bigquery' ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Programming Language :: Python :: 3", "Development Status :: 7 - Inactive", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description_content_type='text/markdown', long_description="""\ Datalab is deprecated. [Vertex AI Workbench](https://cloud.google.com/vertex-ai/docs/workbench) provides a notebook-based environment that offers capabilities beyond Datalab. We recommend that you use Vertex AI Workbench for new projects and [migrate your Datalab notebooks to Vertex AI Workbench](https://cloud.google.com/datalab/docs/resources/troubleshooting#migrate). For more information, see [Deprecation information](https://cloud.google.com/datalab/docs/resources/deprecation). To get help migrating Datalab projects to Vertex AI Workbench see [Get help](https://cloud.google.com/datalab/docs/resources/support#get-help). 
""", install_requires=[ 'configparser>=3.5.0', 'mock>=2.0.0', 'future>=0.16.0', 'google-cloud-monitoring==0.31.1', 'google-api-core>=1.10.0', 'google-api-python-client>=1.6.2', 'seaborn>=0.7.0', 'plotly>=1.12.5', 'httplib2>=0.10.3', 'oauth2client>=2.2.0', 'pandas>=0.22.0', 'google_auth_httplib2>=0.0.2', 'pandas-profiling==1.4.0', 'python-dateutil>=2.5.0', 'pytz>=2015.4', 'pyyaml>=3.11', 'requests>=2.9.1', 'scikit-image>=0.13.0', 'scikit-learn>=0.18.2', 'ipykernel>=4.5.2', 'psutil>=4.3.0', 'jsonschema>=2.6.0', 'six>=1.10.0', 'urllib3>=1.22', ], extras_require={ ':python_version == "2.7"': [ 'futures>=3.0.5', ] }, package_data={ 'google.datalab.notebook': [ 'static/bigquery.css', 'static/bigquery.js', 'static/charting.css', 'static/charting.js', 'static/job.css', 'static/job.js', 'static/element.js', 'static/style.js', 'static/visualization.js', 'static/codemirror/mode/bigquery.js', 'static/parcoords.js', 'static/extern/d3.parcoords.js', 'static/extern/d3.parcoords.css', 'static/extern/sylvester.js', 'static/extern/lantern-browser.html', 'static/extern/facets-jupyter.html', ], 'datalab.notebook': [ 'static/bigquery.css', 'static/bigquery.js', 'static/charting.css', 'static/charting.js', 'static/job.css', 'static/job.js', 'static/element.js', 'static/style.js', 'static/visualization.js', 'static/codemirror/mode/sql.js', 'static/parcoords.js', 'static/extern/d3.parcoords.js', 'static/extern/d3.parcoords.css', 'static/extern/sylvester.js', 'static/extern/lantern-browser.html', 'static/extern/facets-jupyter.html', ] } ) ================================================ FILE: solutionbox/image_classification/mltoolbox/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== __import__('pkg_resources').declare_namespace(__name__) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/__init__.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """ Single label image classification solution. Typical usage is: Run preprocess() or preprocess_async() to preprocess data for training. Run train() or train_async() to train models. Run predict(), batch_predict(), batch_predict_async() to perform predictions. The trained model can also be deployed online with google.datalab.ml.ModelVersions.deploy() call. """ from ._api import preprocess, preprocess_async, train, train_async, predict, batch_predict, \ batch_predict_async __all__ = ['preprocess', 'preprocess_async', 'train', 'train_async', 'predict', 'batch_predict', 'batch_predict_async'] ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_api.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Face functions for image classification. """ import warnings from . import _local from . import _cloud def preprocess_async(train_dataset, output_dir, eval_dataset=None, checkpoint=None, cloud=None): """Preprocess data. Produce output that can be used by training efficiently. Args: train_dataset: training data source to preprocess. Can be CsvDataset or BigQueryDataSet. If eval_dataset is None, the pipeline will randomly split train_dataset into train/eval set with 7:3 ratio. output_dir: The output directory to use. Preprocessing will create a sub directory under it for each run, and also update "latest" file which points to the latest preprocessed directory. Users are responsible for cleanup. Can be local or GCS path. eval_dataset: evaluation data source to preprocess. Can be CsvDataset or BigQueryDataSet. If specified, it will be used for evaluation during training, and train_dataset will be completely used for training. checkpoint: the Inception checkpoint to use. If None, a default checkpoint is used. cloud: a DataFlow pipeline option dictionary such as {'num_workers': 3}. If anything but not None, it will run in cloud. Otherwise, it runs locally. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud is None: return _local.Local.preprocess(train_dataset, output_dir, eval_dataset, checkpoint) if not isinstance(cloud, dict): cloud = {} return _cloud.Cloud.preprocess(train_dataset, output_dir, eval_dataset, checkpoint, cloud) def preprocess(train_dataset, output_dir, eval_dataset=None, checkpoint=None, cloud=None): """Blocking version of preprocess_async(). 
The only difference is that it blocks the caller until the job finishes, and it does not have a return value. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") job = preprocess_async(train_dataset, output_dir, eval_dataset, checkpoint, cloud) job.wait() print(job.state) def train_async(input_dir, batch_size, max_steps, output_dir, checkpoint=None, cloud=None): """Train model. The output can be used for batch prediction or online deployment. Args: input_dir: A directory path containing preprocessed results. Can be local or GCS path. batch_size: size of batch used for training. max_steps: number of steps to train. output_dir: The output directory to use. Can be local or GCS path. checkpoint: the Inception checkpoint to use. If None, a default checkpoint is used. cloud: a google.datalab.ml.CloudTrainingConfig object to let it run in cloud. If None, it runs locally. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud is None: return _local.Local.train(input_dir, batch_size, max_steps, output_dir, checkpoint) return _cloud.Cloud.train(input_dir, batch_size, max_steps, output_dir, checkpoint, cloud) def train(input_dir, batch_size, max_steps, output_dir, checkpoint=None, cloud=None): """Blocking version of train_async(). The only difference is that it blocks the caller until the job finishes, and it does not have a return value. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") job = train_async(input_dir, batch_size, max_steps, output_dir, checkpoint, cloud) job.wait() print(job.state) def predict(model, image_files, resize=False, show_image=True, cloud=None): """Predict using an model in a local or GCS directory (offline), or a deployed model (online). Args: model: if cloud is None, a local or GCS directory of a trained model. Otherwise, it specifies a deployed model identified by model.version, such as "imagemodel.v1". image_files: The paths to the image files to predict labels. Can be local or GCS paths. resize: Whether to resize the image to a reasonable size (300x300) before prediction. show_image: Whether to show images in the results. cloud: if None, predicts with offline model locally. Otherwise, predict with a deployed online model. Returns: A pandas DataFrame including the prediction results. """ print('Predicting...') with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud is None: results = _local.Local.predict(model, image_files, resize, show_image) else: results = _cloud.Cloud.predict(model, image_files, resize, show_image) return results def batch_predict_async(dataset, model_dir, output_csv=None, output_bq_table=None, cloud=None): """Batch prediction with an offline model. Args: dataset: CsvDataSet or BigQueryDataSet for batch prediction input. Can contain either one column 'image_url', or two columns with another being 'label'. model_dir: The directory of a trained inception model. Can be local or GCS paths. output_csv: The output csv file for prediction results. If specified, it will also output a csv schema file with the name output_csv + '.schema.json'. output_bq_table: if specified, the output BigQuery table for prediction results. output_csv and output_bq_table can both be set. cloud: a DataFlow pipeline option dictionary such as {'num_workers': 3}. If anything but not None, it will run in cloud. Otherwise, it runs locally. 
If specified, it must include 'temp_location' with value being a GCS path, because cloud run requires a staging GCS directory. Raises: ValueError if both output_csv and output_bq_table are None, or if cloud is not None but it does not include 'temp_location'. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud is None: return _local.Local.batch_predict(dataset, model_dir, output_csv, output_bq_table) if not isinstance(cloud, dict): cloud = {} return _cloud.Cloud.batch_predict(dataset, model_dir, output_csv, output_bq_table, cloud) def batch_predict(dataset, model_dir, output_csv=None, output_bq_table=None, cloud=None): """Blocking version of batch_predict_async(). The only difference is that it blocks the caller until the job finishes, and it does not have a return value. """ with warnings.catch_warnings(): warnings.simplefilter("ignore") job = batch_predict_async(dataset, model_dir, output_csv, output_bq_table, cloud) job.wait() print(job.state) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_cloud.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Cloud implementation for preprocessing, training and prediction for inception model. """ import base64 import datetime import logging import os import shutil import tempfile from tensorflow.python.lib.io import file_io import urllib from . import _util _TF_GS_URL = 'gs://cloud-datalab/deploy/tf/tensorflow-1.2.0-cp27-none-linux_x86_64.whl' _PROTOBUF_GS_URL = 'gs://cloud-datalab/deploy/tf/protobuf-3.1.0-py2.py3-none-any.whl' class Cloud(object): """Class for cloud training, preprocessing and prediction.""" @staticmethod def preprocess(train_dataset, output_dir, eval_dataset, checkpoint, pipeline_option): """Preprocess data in Cloud with DataFlow.""" import apache_beam as beam import google.datalab.utils from . import _preprocess if checkpoint is None: checkpoint = _util._DEFAULT_CHECKPOINT_GSURL job_name = ('preprocess-image-classification-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')) staging_package_url = _util.repackage_to_staging(output_dir) tmpdir = tempfile.mkdtemp() # suppress DataFlow warnings about wheel package as extra package. original_level = logging.getLogger().getEffectiveLevel() logging.getLogger().setLevel(logging.ERROR) try: # Workaround for DataFlow 2.0, which doesn't work well with extra packages in GCS. # Remove when the issue is fixed and new version of DataFlow is included in Datalab. 
extra_packages = [staging_package_url, _TF_GS_URL, _PROTOBUF_GS_URL] local_packages = [os.path.join(tmpdir, os.path.basename(p)) for p in extra_packages] for source, dest in zip(extra_packages, local_packages): file_io.copy(source, dest, overwrite=True) options = { 'staging_location': os.path.join(output_dir, 'tmp', 'staging'), 'temp_location': os.path.join(output_dir, 'tmp'), 'job_name': job_name, 'project': _util.default_project(), 'extra_packages': local_packages, 'teardown_policy': 'TEARDOWN_ALWAYS', 'no_save_main_session': True } if pipeline_option is not None: options.update(pipeline_option) opts = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline('DataflowRunner', options=opts) _preprocess.configure_pipeline(p, train_dataset, eval_dataset, checkpoint, output_dir, job_name) job_results = p.run() finally: shutil.rmtree(tmpdir) logging.getLogger().setLevel(original_level) if (_util.is_in_IPython()): import IPython dataflow_url = 'https://console.developers.google.com/dataflow?project=%s' % \ _util.default_project() html = 'Job "%s" submitted.' % job_name html += '
<p>Click <a href="%s" target="_blank">here</a> to track preprocessing job.</p>
' \ % dataflow_url IPython.display.display_html(html, raw=True) return google.datalab.utils.DataflowJob(job_results) @staticmethod def train(input_dir, batch_size, max_steps, output_dir, checkpoint, cloud_train_config): """Train model in the cloud with CloudML trainer service.""" import google.datalab.ml as ml if checkpoint is None: checkpoint = _util._DEFAULT_CHECKPOINT_GSURL staging_package_url = _util.repackage_to_staging(output_dir) job_args = { 'input_dir': input_dir, 'max_steps': max_steps, 'batch_size': batch_size, 'checkpoint': checkpoint } job_request = { 'package_uris': [staging_package_url, _TF_GS_URL, _PROTOBUF_GS_URL], 'python_module': 'mltoolbox.image.classification.task', 'job_dir': output_dir, 'args': job_args } job_request.update(dict(cloud_train_config._asdict())) job_id = 'image_classification_train_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S') job = ml.Job.submit_training(job_request, job_id) if (_util.is_in_IPython()): import IPython log_url_query_strings = { 'project': _util.default_project(), 'resource': 'ml.googleapis.com/job_id/' + job.info['jobId'] } log_url = 'https://console.developers.google.com/logs/viewer?' + \ urllib.urlencode(log_url_query_strings) html = 'Job "%s" submitted.' % job.info['jobId'] html += '
<p>Click <a href="%s" target="_blank">here</a> to view cloud log.</p>
' % log_url IPython.display.display_html(html, raw=True) return job @staticmethod def predict(model_id, image_files, resize, show_image): """Predict using a deployed (online) model.""" import google.datalab.ml as ml images = _util.load_images(image_files, resize=resize) parts = model_id.split('.') if len(parts) != 2: raise ValueError('Invalid model name for cloud prediction. Use "model.version".') if len(images) == 0: raise ValueError('images is empty.') data = [] for ii, image in enumerate(images): image_encoded = base64.b64encode(image) data.append({ 'key': str(ii), 'image_bytes': {'b64': image_encoded} }) predictions = ml.ModelVersions(parts[0]).predict(parts[1], data) if len(predictions) == 0: raise Exception('Prediction results are empty.') # Although prediction results contains a labels list in each instance, they are all the same # so taking the first one. labels = predictions[0]['labels'] labels_and_scores = [(x['prediction'], x['scores'][labels.index(x['prediction'])]) for x in predictions] results = zip(image_files, images, labels_and_scores) ret = _util.process_prediction_results(results, show_image) return ret @staticmethod def batch_predict(dataset, model_dir, output_csv, output_bq_table, pipeline_option): """Batch predict running in cloud.""" import apache_beam as beam import google.datalab.utils from . import _predictor if output_csv is None and output_bq_table is None: raise ValueError('output_csv and output_bq_table cannot both be None.') if 'temp_location' not in pipeline_option: raise ValueError('"temp_location" is not set in cloud.') job_name = ('batch-predict-image-classification-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')) staging_package_url = _util.repackage_to_staging(pipeline_option['temp_location']) tmpdir = tempfile.mkdtemp() # suppress DataFlow warnings about wheel package as extra package. original_level = logging.getLogger().getEffectiveLevel() logging.getLogger().setLevel(logging.ERROR) try: # Workaround for DataFlow 2.0, which doesn't work well with extra packages in GCS. # Remove when the issue is fixed and new version of DataFlow is included in Datalab. extra_packages = [staging_package_url, _TF_GS_URL, _PROTOBUF_GS_URL] local_packages = [os.path.join(tmpdir, os.path.basename(p)) for p in extra_packages] for source, dest in zip(extra_packages, local_packages): file_io.copy(source, dest, overwrite=True) options = { 'staging_location': os.path.join(pipeline_option['temp_location'], 'staging'), 'job_name': job_name, 'project': _util.default_project(), 'extra_packages': local_packages, 'teardown_policy': 'TEARDOWN_ALWAYS', 'no_save_main_session': True } options.update(pipeline_option) opts = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline('DataflowRunner', options=opts) _predictor.configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table) job_results = p.run() finally: shutil.rmtree(tmpdir) logging.getLogger().setLevel(original_level) if (_util.is_in_IPython()): import IPython dataflow_url = ('https://console.developers.google.com/dataflow?project=%s' % _util.default_project()) html = 'Job "%s" submitted.' % job_name html += ('
<p>Click <a href="%s" target="_blank">here</a> to track batch prediction job.</p>
' % dataflow_url) IPython.display.display_html(html, raw=True) return google.datalab.utils.DataflowJob(job_results) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_inceptionlib.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inception model building libraries. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf slim = tf.contrib.slim def trunc_normal(stddev): return tf.truncated_normal_initializer(0.0, stddev) def inception_v3_base(inputs, final_endpoint='Mixed_7c', min_depth=16, depth_multiplier=1.0, scope=None): """Inception model from http://arxiv.org/abs/1512.00567. Constructs an Inception v3 network from inputs to the given final endpoint. This method can construct the network up to the final inception block Mixed_7c. Note that the names of the layers in the paper do not correspond to the names of the endpoints registered by this function although they build the same network. Here is a mapping from the old_names to the new names: Old name | New name ======================================= conv0 | Conv2d_1a_3x3 conv1 | Conv2d_2a_3x3 conv2 | Conv2d_2b_3x3 pool1 | MaxPool_3a_3x3 conv3 | Conv2d_3b_1x1 conv4 | Conv2d_4a_3x3 pool2 | MaxPool_5a_3x3 mixed_35x35x256a | Mixed_5b mixed_35x35x288a | Mixed_5c mixed_35x35x288b | Mixed_5d mixed_17x17x768a | Mixed_6a mixed_17x17x768b | Mixed_6b mixed_17x17x768c | Mixed_6c mixed_17x17x768d | Mixed_6d mixed_17x17x768e | Mixed_6e mixed_8x8x1280a | Mixed_7a mixed_8x8x2048a | Mixed_7b mixed_8x8x2048b | Mixed_7c Args: inputs: a tensor of size [batch_size, height, width, channels]. final_endpoint: specifies the endpoint to construct the network up to. It can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']. min_depth: Minimum depth value (number of channels) for all convolution ops. Enforced when depth_multiplier < 1, and not an active constraint when depth_multiplier >= 1. depth_multiplier: Float multiplier for the depth (number of channels) for all convolution ops. The value must be greater than zero. Typical usage will be to set this value in (0, 1) to reduce the number of parameters or computation cost of the model. scope: Optional variable_scope. Returns: tensor_out: output tensor corresponding to the final_endpoint. end_points: a set of activations for external use, for example summaries or losses. Raises: ValueError: if final_endpoint is not set to one of the predefined values, or depth_multiplier <= 0 """ # end_points will collect relevant activations for external use, for example # summaries or losses. 
end_points = {} if depth_multiplier <= 0: raise ValueError('depth_multiplier is not greater than zero.') def depth(d): return max(int(d * depth_multiplier), min_depth) with tf.variable_scope(scope, 'InceptionV3', [inputs]): with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='VALID'): # 299 x 299 x 3 end_point = 'Conv2d_1a_3x3' net = slim.conv2d(inputs, depth(32), [3, 3], stride=2, scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 149 x 149 x 32 end_point = 'Conv2d_2a_3x3' net = slim.conv2d(net, depth(32), [3, 3], scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 147 x 147 x 32 end_point = 'Conv2d_2b_3x3' net = slim.conv2d(net, depth(64), [3, 3], padding='SAME', scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 147 x 147 x 64 end_point = 'MaxPool_3a_3x3' net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 73 x 73 x 64 end_point = 'Conv2d_3b_1x1' net = slim.conv2d(net, depth(80), [1, 1], scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 73 x 73 x 80. end_point = 'Conv2d_4a_3x3' net = slim.conv2d(net, depth(192), [3, 3], scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 71 x 71 x 192. end_point = 'MaxPool_5a_3x3' net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # 35 x 35 x 192. # Inception blocks with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): # mixed: 35 x 35 x 256. end_point = 'Mixed_5b' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(32), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_1: 35 x 35 x 288. 
end_point = 'Mixed_5c' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0b_1x1') branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], scope='Conv_1_0c_5x5') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_2: 35 x 35 x 288. end_point = 'Mixed_5d' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], scope='Conv2d_0b_5x5') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0b_3x3') branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], scope='Conv2d_0c_3x3') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(64), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_3: 17 x 17 x 768. end_point = 'Mixed_6a' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(384), [3, 3], stride=2, padding='VALID', scope='Conv2d_1a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3') branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], stride=2, padding='VALID', scope='Conv2d_1a_1x1') with tf.variable_scope('Branch_2'): branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', scope='MaxPool_1a_3x3') net = tf.concat([branch_0, branch_1, branch_2], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed4: 17 x 17 x 768. 
end_point = 'Mixed_6b' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(128), [1, 7], scope='Conv2d_0b_1x7') branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(128), [7, 1], scope='Conv2d_0b_7x1') branch_2 = slim.conv2d(branch_2, depth(128), [1, 7], scope='Conv2d_0c_1x7') branch_2 = slim.conv2d(branch_2, depth(128), [7, 1], scope='Conv2d_0d_7x1') branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_5: 17 x 17 x 768. end_point = 'Mixed_6c' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7') branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1') branch_2 = slim.conv2d(branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7') branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1') branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_6: 17 x 17 x 768. end_point = 'Mixed_6d' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(160), [1, 7], scope='Conv2d_0b_1x7') branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], scope='Conv2d_0b_7x1') branch_2 = slim.conv2d(branch_2, depth(160), [1, 7], scope='Conv2d_0c_1x7') branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], scope='Conv2d_0d_7x1') branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_7: 17 x 17 x 768. 
end_point = 'Mixed_6e' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(192), [1, 7], scope='Conv2d_0b_1x7') branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1') with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d(branch_2, depth(192), [7, 1], scope='Conv2d_0b_7x1') branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], scope='Conv2d_0c_1x7') branch_2 = slim.conv2d(branch_2, depth(192), [7, 1], scope='Conv2d_0d_7x1') branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], scope='Conv2d_0e_1x7') with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_8: 8 x 8 x 1280. end_point = 'Mixed_7a' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') branch_0 = slim.conv2d(branch_0, depth(320), [3, 3], stride=2, padding='VALID', scope='Conv2d_1a_3x3') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') branch_1 = slim.conv2d(branch_1, depth(192), [1, 7], scope='Conv2d_0b_1x7') branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], scope='Conv2d_0c_7x1') branch_1 = slim.conv2d(branch_1, depth(192), [3, 3], stride=2, padding='VALID', scope='Conv2d_1a_3x3') with tf.variable_scope('Branch_2'): branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', scope='MaxPool_1a_3x3') net = tf.concat([branch_0, branch_1, branch_2], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_9: 8 x 8 x 2048. end_point = 'Mixed_7b' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1') branch_1 = tf.concat([ slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'), slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0b_3x1')], 3) with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d( branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3') branch_2 = tf.concat([ slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'), slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')], 3) with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d( branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points # mixed_10: 8 x 8 x 2048. 
end_point = 'Mixed_7c' with tf.variable_scope(end_point): with tf.variable_scope('Branch_0'): branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1') with tf.variable_scope('Branch_1'): branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1') branch_1 = tf.concat([ slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'), slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0c_3x1')], 3) with tf.variable_scope('Branch_2'): branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1') branch_2 = slim.conv2d( branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3') branch_2 = tf.concat([ slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'), slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')], 3) with tf.variable_scope('Branch_3'): branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') branch_3 = slim.conv2d( branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') net = tf.concat([branch_0, branch_1, branch_2, branch_3], 3) end_points[end_point] = net if end_point == final_endpoint: return net, end_points raise ValueError('Unknown final endpoint %s' % final_endpoint) def inception_v3(inputs, num_classes=1000, is_training=True, dropout_keep_prob=0.8, min_depth=16, depth_multiplier=1.0, prediction_fn=slim.softmax, spatial_squeeze=True, reuse=None, scope='InceptionV3'): """Inception model from http://arxiv.org/abs/1512.00567. "Rethinking the Inception Architecture for Computer Vision" Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew Wojna. With the default arguments this method constructs the exact model defined in the paper. However, one can experiment with variations of the inception_v3 network by changing arguments dropout_keep_prob, min_depth and depth_multiplier. The default image size used to train this network is 299x299. Args: inputs: a tensor of size [batch_size, height, width, channels]. num_classes: number of predicted classes. is_training: whether is training or not. dropout_keep_prob: the percentage of activation values that are retained. min_depth: Minimum depth value (number of channels) for all convolution ops. Enforced when depth_multiplier < 1, and not an active constraint when depth_multiplier >= 1. depth_multiplier: Float multiplier for the depth (number of channels) for all convolution ops. The value must be greater than zero. Typical usage will be to set this value in (0, 1) to reduce the number of parameters or computation cost of the model. prediction_fn: a function to get predictions out of logits. spatial_squeeze: if True, logits is of shape is [B, C], if false logits is of shape [B, 1, 1, C], where B is batch_size and C is number of classes. reuse: whether or not the network and its variables should be reused. To be able to reuse 'scope' must be given. scope: Optional variable_scope. Returns: logits: the pre-softmax activations, a tensor of size [batch_size, num_classes] end_points: a dictionary from components of the network to the corresponding activation. Raises: ValueError: if 'depth_multiplier' is less than or equal to zero. 
""" if depth_multiplier <= 0: raise ValueError('depth_multiplier is not greater than zero.') def depth(d): return max(int(d * depth_multiplier), min_depth) with tf.variable_scope(scope, 'InceptionV3', [inputs, num_classes], reuse=reuse) as scope: with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training): net, end_points = inception_v3_base( inputs, scope=scope, min_depth=min_depth, depth_multiplier=depth_multiplier) # Auxiliary Head logits with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding='SAME'): aux_logits = end_points['Mixed_6e'] with tf.variable_scope('AuxLogits'): aux_logits = slim.avg_pool2d( aux_logits, [5, 5], stride=3, padding='VALID', scope='AvgPool_1a_5x5') aux_logits = slim.conv2d(aux_logits, depth(128), [1, 1], scope='Conv2d_1b_1x1') # Shape of feature map before the final layer. kernel_size = _reduced_kernel_size_for_small_input( aux_logits, [5, 5]) aux_logits = slim.conv2d( aux_logits, depth(768), kernel_size, weights_initializer=trunc_normal(0.01), padding='VALID', scope='Conv2d_2a_{}x{}'.format(*kernel_size)) aux_logits = slim.conv2d( aux_logits, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, weights_initializer=trunc_normal(0.001), scope='Conv2d_2b_1x1') if spatial_squeeze: aux_logits = tf.squeeze(aux_logits, [1, 2], name='SpatialSqueeze') end_points['AuxLogits'] = aux_logits # Final pooling and prediction with tf.variable_scope('Logits'): kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8]) net = slim.avg_pool2d(net, kernel_size, padding='VALID', scope='AvgPool_1a_{}x{}'.format(*kernel_size)) # 1 x 1 x 2048 net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b') end_points['PreLogits'] = net # 2048 logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, normalizer_fn=None, scope='Conv2d_1c_1x1') if spatial_squeeze: logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') # 1000 end_points['Logits'] = logits end_points['Predictions'] = prediction_fn(logits, scope='Predictions') return logits, end_points inception_v3.default_image_size = 299 def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): """Define kernel size which is automatically reduced for small input. If the shape of the input images is unknown at graph construction time this function assumes that the input images are is large enough. Args: input_tensor: input tensor of size [batch_size, height, width, channels]. kernel_size: desired kernel size of length 2: [kernel_height, kernel_width] Returns: a tensor with the kernel size. TODO(jrru): Make this function work with unknown shapes. Theoretically, this can be done with the code below. Problems are two-fold: (1) If the shape was known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot handle tensors that define the kernel size. shape = tf.shape(input_tensor) return = tf.stack([tf.minimum(shape[1], kernel_size[0]), tf.minimum(shape[2], kernel_size[1])]) """ shape = input_tensor.get_shape().as_list() if shape[1] is None or shape[2] is None: kernel_size_out = kernel_size else: kernel_size_out = [min(shape[1], kernel_size[0]), min(shape[2], kernel_size[1])] return kernel_size_out def inception_v3_arg_scope(weight_decay=0.00004, stddev=0.1, batch_norm_var_collection='moving_vars'): """Defines the default InceptionV3 arg scope. Args: weight_decay: The weight decay to use for regularizing the model. stddev: The standard deviation of the trunctated normal weight initializer. 
batch_norm_var_collection: The name of the collection for the batch norm variables. Returns: An `arg_scope` to use for the inception v3 model. """ batch_norm_params = { # Decay for the moving averages. 'decay': 0.9997, # epsilon to prevent 0s in variance. 'epsilon': 0.001, # collection containing update_ops. 'updates_collections': tf.GraphKeys.UPDATE_OPS, # collection containing the moving mean and moving variance. 'variables_collections': { 'beta': None, 'gamma': None, 'moving_mean': [batch_norm_var_collection], 'moving_variance': [batch_norm_var_collection], } } # Set weight_decay for weights in Conv and FC layers. with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_regularizer=slim.l2_regularizer(weight_decay)): with slim.arg_scope([slim.conv2d], weights_initializer=tf.truncated_normal_initializer(stddev=stddev), activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params) as sc: return sc ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_local.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Local implementation for preprocessing, training and prediction for inception model. """ import datetime from . import _model from . import _trainer from . import _util class Local(object): """Class for local training, preprocessing and prediction.""" @staticmethod def preprocess(train_dataset, output_dir, eval_dataset, checkpoint): """Preprocess data locally.""" import apache_beam as beam from google.datalab.utils import LambdaJob from . import _preprocess if checkpoint is None: checkpoint = _util._DEFAULT_CHECKPOINT_GSURL job_id = ('preprocess-image-classification-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')) # Project is needed for bigquery data source, even in local run. options = { 'project': _util.default_project(), } opts = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline('DirectRunner', options=opts) _preprocess.configure_pipeline(p, train_dataset, eval_dataset, checkpoint, output_dir, job_id) job = LambdaJob(lambda: p.run().wait_until_finish(), job_id) return job @staticmethod def train(input_dir, batch_size, max_steps, output_dir, checkpoint): """Train model locally.""" from google.datalab.utils import LambdaJob if checkpoint is None: checkpoint = _util._DEFAULT_CHECKPOINT_GSURL labels = _util.get_labels(input_dir) model = _model.Model(labels, 0.5, checkpoint) task_data = {'type': 'master', 'index': 0} task = type('TaskSpec', (object,), task_data) job = LambdaJob(lambda: _trainer.Trainer(input_dir, batch_size, max_steps, output_dir, model, None, task).run_training(), 'training') return job @staticmethod def predict(model_dir, image_files, resize, show_image): """Predict using an model in a local or GCS directory.""" from . 
import _predictor images = _util.load_images(image_files, resize=resize) labels_and_scores = _predictor.predict(model_dir, images) results = zip(image_files, images, labels_and_scores) ret = _util.process_prediction_results(results, show_image) return ret @staticmethod def batch_predict(dataset, model_dir, output_csv, output_bq_table): """Batch predict running locally.""" import apache_beam as beam from google.datalab.utils import LambdaJob from . import _predictor if output_csv is None and output_bq_table is None: raise ValueError('output_csv and output_bq_table cannot both be None.') job_id = ('batch-predict-image-classification-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')) # Project is needed for bigquery data source, even in local run. options = { 'project': _util.default_project(), } opts = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline('DirectRunner', options=opts) _predictor.configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table) job = LambdaJob(lambda: p.run().wait_until_finish(), job_id) return job ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_model.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Inception model tensorflow implementation. """ from enum import Enum import logging import tensorflow as tf from tensorflow.contrib import layers from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from . import _inceptionlib from . import _util slim = tf.contrib.slim LOGITS_TENSOR_NAME = 'logits_tensor' IMAGE_URI_COLUMN = 'image_uri' LABEL_COLUMN = 'label' EMBEDDING_COLUMN = 'embedding' BOTTLENECK_TENSOR_SIZE = 2048 class GraphMod(Enum): TRAIN = 1 EVALUATE = 2 PREDICT = 3 class GraphReferences(object): """Holder of base tensors used for training model using common task.""" def __init__(self): self.examples = None self.train = None self.global_step = None self.metric_updates = [] self.metric_values = [] self.keys = None self.predictions = [] class Model(object): """TensorFlow model for the flowers problem.""" def __init__(self, labels, dropout, inception_checkpoint_file): self.labels = labels self.labels.sort() self.dropout = dropout self.inception_checkpoint_file = inception_checkpoint_file def add_final_training_ops(self, embeddings, all_labels_count, bottleneck_tensor_size, hidden_layer_size=BOTTLENECK_TENSOR_SIZE / 4, dropout_keep_prob=None): """Adds a new softmax and fully-connected layer for training. The set up for the softmax and fully-connected layers is based on: https://tensorflow.org/versions/master/tutorials/mnist/beginners/index.html This function can be customized to add arbitrary layers for application-specific requirements. 
Args: embeddings: The embedding (bottleneck) tensor. all_labels_count: The number of all labels including the default label. bottleneck_tensor_size: The number of embeddings. hidden_layer_size: The size of the hidden_layer. Roughtly, 1/4 of the bottleneck tensor size. dropout_keep_prob: the percentage of activation values that are retained. Returns: softmax: The softmax or tensor. It stores the final scores. logits: The logits tensor. """ with tf.name_scope('input'): bottleneck_input = tf.placeholder_with_default( embeddings, shape=[None, bottleneck_tensor_size], name='ReshapeSqueezed') bottleneck_with_no_gradient = tf.stop_gradient(bottleneck_input) with tf.name_scope('Wx_plus_b'): hidden = layers.fully_connected(bottleneck_with_no_gradient, hidden_layer_size) # We need a dropout when the size of the dataset is rather small. if dropout_keep_prob: hidden = tf.nn.dropout(hidden, dropout_keep_prob) logits = layers.fully_connected( hidden, all_labels_count, activation_fn=None) softmax = tf.nn.softmax(logits, name='softmax') return softmax, logits def build_inception_graph(self): """Builds an inception graph and add the necessary input & output tensors. To use other Inception models modify this file. Also preprocessing must be modified accordingly. See tensorflow/contrib/slim/python/slim/nets/inception_v3.py for details about InceptionV3. Returns: input_jpeg: A placeholder for jpeg string batch that allows feeding the Inception layer with image bytes for prediction. inception_embeddings: The embeddings tensor. """ image_str_tensor = tf.placeholder(tf.string, shape=[None]) # The CloudML Prediction API always "feeds" the Tensorflow graph with # dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar # strings because it cannot guarantee a batch of images would have # the same output size. We use tf.map_fn to give decode_jpeg a scalar # string from dynamic batches. image = tf.map_fn( _util.decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8) # convert_image_dtype, also scales [0, uint8_max] -> [0 ,1). image = tf.image.convert_image_dtype(image, dtype=tf.float32) # Then shift images to [-1, 1) for Inception. image = tf.subtract(image, 0.5) image = tf.multiply(image, 2.0) # Build Inception layers, which expect A tensor of type float from [-1, 1) # and shape [batch_size, height, width, channels]. with slim.arg_scope(_inceptionlib.inception_v3_arg_scope()): _, end_points = _inceptionlib.inception_v3(image, is_training=False) inception_embeddings = end_points['PreLogits'] inception_embeddings = tf.squeeze( inception_embeddings, [1, 2], name='SpatialSqueeze') return image_str_tensor, inception_embeddings def build_graph(self, data_paths, batch_size, graph_mod): """Builds generic graph for training or eval.""" tensors = GraphReferences() is_training = graph_mod == GraphMod.TRAIN if data_paths: _, tensors.examples = _util.read_examples( data_paths, batch_size, shuffle=is_training, num_epochs=None if is_training else 2) else: tensors.examples = tf.placeholder(tf.string, name='input', shape=(None,)) if graph_mod == GraphMod.PREDICT: inception_input, inception_embeddings = self.build_inception_graph() # Build the Inception graph. We later add final training layers # to this graph. This is currently used only for prediction. # For training, we use pre-processed data, so it is not needed. embeddings = inception_embeddings tensors.input_jpeg = inception_input else: # For training and evaluation we assume data is preprocessed, so the # inputs are tf-examples. 
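# For reference, each preprocessed record is a serialized tf.train.Example
# carrying three features (example values below are illustrative only):
#   'image_uri': bytes    e.g. 'gs://bucket/daisy/img001.jpg'
#   'label':     int64    a single label id; unlabeled images default to len(labels)
#   'embedding': float32  BOTTLENECK_TENSOR_SIZE (2048) values taken from the
#                         Inception 'PreLogits' bottleneck by the preprocessing pipeline.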
# Generate placeholders for examples. with tf.name_scope('inputs'): feature_map = { 'image_uri': tf.FixedLenFeature( shape=[], dtype=tf.string, default_value=['']), # Some images may have no labels. For those, we assume a default # label. So the number of labels is label_count+1 for the default # label. 'label': tf.FixedLenFeature( shape=[1], dtype=tf.int64, default_value=[len(self.labels)]), 'embedding': tf.FixedLenFeature( shape=[BOTTLENECK_TENSOR_SIZE], dtype=tf.float32) } parsed = tf.parse_example(tensors.examples, features=feature_map) labels = tf.squeeze(parsed['label']) uris = tf.squeeze(parsed['image_uri']) embeddings = parsed['embedding'] # We assume a default label, so the total number of labels is equal to # label_count+1. all_labels_count = len(self.labels) + 1 with tf.name_scope('final_ops'): softmax, logits = self.add_final_training_ops( embeddings, all_labels_count, BOTTLENECK_TENSOR_SIZE, dropout_keep_prob=self.dropout if is_training else None) # Prediction is the index of the label with the highest score. We are # interested only in the top score. prediction = tf.argmax(softmax, 1) tensors.predictions = [prediction, softmax, embeddings] if graph_mod == GraphMod.PREDICT: return tensors with tf.name_scope('evaluate'): loss_value = loss(logits, labels) # Add to the Graph the Ops that calculate and apply gradients. if is_training: tensors.train, tensors.global_step = training(loss_value) else: tensors.global_step = tf.Variable(0, name='global_step', trainable=False) tensors.uris = uris # Add means across all batches. loss_updates, loss_op = _util.loss(loss_value) accuracy_updates, accuracy_op = _util.accuracy(logits, labels) if not is_training: tf.summary.scalar('accuracy', accuracy_op) tf.summary.scalar('loss', loss_op) tensors.metric_updates = loss_updates + accuracy_updates tensors.metric_values = [loss_op, accuracy_op] return tensors def build_train_graph(self, data_paths, batch_size): return self.build_graph(data_paths, batch_size, GraphMod.TRAIN) def build_eval_graph(self, data_paths, batch_size): return self.build_graph(data_paths, batch_size, GraphMod.EVALUATE) def restore_from_checkpoint(self, session, inception_checkpoint_file, trained_checkpoint_file): """To restore model variables from the checkpoint file. The graph is assumed to consist of an inception model and other layers including a softmax and a fully connected layer. The former is pre-trained and the latter is trained using the pre-processed data. So we restore this from two checkpoint files. Args: session: The session to be used for restoring from checkpoint. inception_checkpoint_file: Path to the checkpoint file for the Inception graph. trained_checkpoint_file: path to the trained checkpoint for the other layers. """ inception_exclude_scopes = [ 'InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step', 'final_ops' ] reader = tf.train.NewCheckpointReader(inception_checkpoint_file) var_to_shape_map = reader.get_variable_to_shape_map() # Get all variables to restore. Exclude Logits and AuxLogits because they # depend on the input data and we do not need to intialize them. all_vars = tf.contrib.slim.get_variables_to_restore( exclude=inception_exclude_scopes) # Remove variables that do not exist in the inception checkpoint (for # example the final softmax and fully-connected layers). 
inception_vars = { var.op.name: var for var in all_vars if var.op.name in var_to_shape_map } inception_saver = tf.train.Saver(inception_vars) inception_saver.restore(session, inception_checkpoint_file) # Restore the rest of the variables from the trained checkpoint. trained_vars = tf.contrib.slim.get_variables_to_restore( exclude=inception_exclude_scopes + inception_vars.keys()) trained_saver = tf.train.Saver(trained_vars) trained_saver.restore(session, trained_checkpoint_file) def build_prediction_graph(self): """Builds prediction graph and registers appropriate endpoints.""" tensors = self.build_graph(None, 1, GraphMod.PREDICT) keys_placeholder = tf.placeholder(tf.string, shape=[None]) inputs = { 'key': keys_placeholder, 'image_bytes': tensors.input_jpeg } # To extract the id, we need to add the identity function. keys = tf.identity(keys_placeholder) labels = self.labels + ['UNKNOWN'] labels_tensor = tf.constant(labels) labels_table = tf.contrib.lookup.index_to_string_table_from_tensor(mapping=labels_tensor) predicted_label = labels_table.lookup(tensors.predictions[0]) # Need to duplicate the labels by num_of_instances so the output is one batch # (all output members share the same outer dimension). # The labels are needed for client to match class scores list. labels_tensor = tf.expand_dims(tf.constant(labels), 0) num_instance = tf.shape(keys) labels_tensors_n = tf.tile(labels_tensor, tf.concat(axis=0, values=[num_instance, [1]])) outputs = { 'key': keys, 'prediction': predicted_label, 'labels': labels_tensors_n, 'scores': tensors.predictions[1], } return inputs, outputs def export(self, last_checkpoint, output_dir): """Builds a prediction graph and xports the model. Args: last_checkpoint: Path to the latest checkpoint file from training. output_dir: Path to the folder to be used to output the model. """ logging.info('Exporting prediction graph to %s', output_dir) with tf.Session(graph=tf.Graph()) as sess: # Build and save prediction meta graph and trained variable values. inputs, outputs = self.build_prediction_graph() signature_def_map = { 'serving_default': signature_def_utils.predict_signature_def(inputs, outputs) } init_op = tf.global_variables_initializer() sess.run(init_op) self.restore_from_checkpoint(sess, self.inception_checkpoint_file, last_checkpoint) init_op_serving = control_flow_ops.group( variables.local_variables_initializer(), tf.tables_initializer()) builder = saved_model_builder.SavedModelBuilder(output_dir) builder.add_meta_graph_and_variables( sess, [tag_constants.SERVING], signature_def_map=signature_def_map, legacy_init_op=init_op_serving) builder.save(False) def format_metric_values(self, metric_values): """Formats metric values - used for logging purpose.""" # Early in training, metric_values may actually be None. loss_str = 'N/A' accuracy_str = 'N/A' try: loss_str = 'loss: %.3f' % metric_values[0] accuracy_str = 'accuracy: %.3f' % metric_values[1] except (TypeError, IndexError): pass return '%s, %s' % (loss_str, accuracy_str) def format_prediction_values(self, prediction): """Formats prediction values - used for writing batch predictions as csv.""" return '%.3f' % (prediction[0]) def loss(logits, labels): """Calculates the loss from the logits and the labels. Args: logits: Logits tensor, float - [batch_size, NUM_CLASSES]. labels: Labels tensor, int32 - [batch_size]. Returns: loss: Loss tensor of type float. 
""" labels = tf.to_int64(labels) cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels, name='xentropy') return tf.reduce_mean(cross_entropy, name='xentropy_mean') def training(loss_op): """Calculates the loss from the logits and the labels. Args: logits: Logits tensor, float - [batch_size, NUM_CLASSES]. labels: Labels tensor, int32 - [batch_size]. Returns: loss: Loss tensor of type float. """ global_step = tf.Variable(0, name='global_step', trainable=False) with tf.name_scope('train'): optimizer = tf.train.AdamOptimizer(epsilon=0.001) train_op = optimizer.minimize(loss_op, global_step) return train_op, global_step ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_predictor.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Local implementation for preprocessing, training and prediction for inception model. """ import apache_beam as beam from apache_beam.transforms import window from apache_beam.utils.windowed_value import WindowedValue import collections import json import os from . import _util def _load_tf_model(model_dir): from tensorflow.python.saved_model import tag_constants from tensorflow.contrib.session_bundle import bundle_shim model_dir = os.path.join(model_dir, 'model') session, meta_graph = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( model_dir, tags=[tag_constants.SERVING]) signature = meta_graph.signature_def['serving_default'] inputs = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.inputs.items()} outputs = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.outputs.items()} return session, inputs, outputs def _tf_predict(model_dir, images): session, inputs, outputs = _load_tf_model(model_dir) with session: feed_dict = collections.defaultdict(list) for ii, image in enumerate(images): feed_dict[inputs['image_bytes']].append(image) feed_dict[inputs['key']].append(str(ii)) predictions, labels, scores = session.run( [outputs['prediction'], outputs['labels'], outputs['scores']], feed_dict=feed_dict) return zip(predictions, labels, scores) def predict(model_dir, images): """Local instant prediction.""" results = _tf_predict(model_dir, images) predicted_and_scores = [(predicted, label_scores[list(labels).index(predicted)]) for predicted, labels, label_scores in results] return predicted_and_scores # Helpers for batch prediction dataflow pipeline class EmitAsBatchDoFn(beam.DoFn): """A DoFn that buffers the records and emits them batch by batch.""" def __init__(self, batch_size): self._batch_size = batch_size self._cached = [] def process(self, element): self._cached.append(element) if len(self._cached) >= self._batch_size: emit = self._cached self._cached = [] yield emit def finish_bundle(self, context=None): if len(self._cached) > 0: # pylint: disable=g-explicit-length-test yield 
WindowedValue(self._cached, -1, [window.GlobalWindow()]) class UnbatchDoFn(beam.DoFn): """A DoFn expand batch into elements.""" def process(self, element): for item in element: yield item class LoadImagesDoFn(beam.DoFn): """A DoFn that reads image from url.""" def process(self, element): from tensorflow.python.lib.io import file_io as tf_file_io with tf_file_io.FileIO(element['image_url'], 'r') as ff: image_bytes = ff.read() out_element = {'image_bytes': image_bytes} out_element.update(element) yield out_element class PredictBatchDoFn(beam.DoFn): """A DoFn that does batch prediction.""" def __init__(self, model_dir): self._model_dir = model_dir self._session = None self._tf_inputs = None self._tf_outputs = None def start_bundle(self, context=None): self._session, self._tf_inputs, self._tf_outputs = _load_tf_model(self._model_dir) def finish_bundle(self, context=None): if self._session is not None: self._session.close() def process(self, element): import collections image_urls = [x['image_url'] for x in element] targets = None if 'label' in element[0] and element[0]['label'] is not None: targets = [x['label'] for x in element] feed_dict = collections.defaultdict(list) feed_dict[self._tf_inputs['image_bytes']] = [x['image_bytes'] for x in element] feed_dict[self._tf_inputs['key']] = image_urls predictions, labels, scores = self._session.run( [self._tf_outputs['prediction'], self._tf_outputs['labels'], self._tf_outputs['scores']], feed_dict=feed_dict) if targets is not None: yield zip(image_urls, targets, predictions, labels, scores) else: yield zip(image_urls, predictions, labels, scores) class ProcessResultsDoFn(beam.DoFn): """A DoFn that process prediction results by casting values and calculating target_prob. """ def process(self, element): target = None if len(element) == 5: image_url, target, prediction, labels, scores = element else: image_url, prediction, labels, scores = element labels = list(labels) predicted_prob = scores[labels.index(prediction)] out_element = { 'image_url': image_url, 'predicted': prediction, # Convert to float from np.float32 because BigQuery Sink can only handle intrinsic types. 
'predicted_prob': float(predicted_prob) } if target is not None: target_prob = scores[labels.index(target)] if target in labels else 0.0 out_element['target_prob'] = float(target_prob) out_element['target'] = target yield out_element class MakeCsvLineDoFn(beam.DoFn): """A DoFn that makes CSV lines out of prediction results.""" def process(self, element): import csv import StringIO line = StringIO.StringIO() if len(element) == 5: csv.DictWriter(line, ['image_url', 'target', 'predicted', 'target_prob', 'predicted_prob']).writerow(element) else: csv.DictWriter(line, ['image_url', 'predicted', 'predicted_prob']).writerow(element) yield line.getvalue() def configure_pipeline(p, dataset, model_dir, output_csv, output_bq_table): """Configures a dataflow pipeline for batch prediction.""" data = _util.get_sources_from_dataset(p, dataset, 'predict') if len(dataset.schema) == 2: output_schema = [ {'name': 'image_url', 'type': 'STRING'}, {'name': 'target', 'type': 'STRING'}, {'name': 'predicted', 'type': 'STRING'}, {'name': 'target_prob', 'type': 'FLOAT'}, {'name': 'predicted_prob', 'type': 'FLOAT'}, ] else: output_schema = [ {'name': 'image_url', 'type': 'STRING'}, {'name': 'predicted', 'type': 'STRING'}, {'name': 'predicted_prob', 'type': 'FLOAT'}, ] results = (data | 'Load Images' >> beam.ParDo(LoadImagesDoFn()) | 'Batch Inputs' >> beam.ParDo(EmitAsBatchDoFn(20)) | 'Batch Predict' >> beam.ParDo(PredictBatchDoFn(model_dir)) | 'Unbatch' >> beam.ParDo(UnbatchDoFn()) | 'Process Results' >> beam.ParDo(ProcessResultsDoFn())) if output_csv is not None: schema_file = output_csv + '.schema.json' results_save = (results | 'Prepare For Output' >> beam.ParDo(MakeCsvLineDoFn()) | 'Write Csv Results' >> beam.io.textio.WriteToText(output_csv, shard_name_template='')) (results_save | 'Sample One' >> beam.transforms.combiners.Sample.FixedSizeGlobally(1) | 'Serialize Schema' >> beam.Map(lambda path: json.dumps(output_schema)) | 'Write Schema' >> beam.io.textio.WriteToText(schema_file, shard_name_template='')) if output_bq_table is not None: # BigQuery sink takes schema in the form of 'field1:type1,field2:type2...' bq_schema_string = ','.join(x['name'] + ':' + x['type'] for x in output_schema) sink = beam.io.BigQuerySink(output_bq_table, schema=bq_schema_string, write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE) results | 'Write BQ Results' >> beam.io.Write(sink) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_preprocess.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Preprocess pipeline implementation with Cloud DataFlow. """ import apache_beam as beam from apache_beam.io import tfrecordio from apache_beam.metrics import Metrics import cStringIO import logging import os from PIL import Image from . import _inceptionlib from . 
import _util error_count = Metrics.counter('main', 'errorCount') rows_count = Metrics.counter('main', 'rowsCount') skipped_empty_line = Metrics.counter('main', 'skippedEmptyLine') embedding_good = Metrics.counter('main', 'embedding_good') embedding_bad = Metrics.counter('main', 'embedding_bad') incompatible_image = Metrics.counter('main', 'incompatible_image') invalid_uri = Metrics.counter('main', 'invalid_file_name') unlabeled_image = Metrics.counter('main', 'unlabeled_image') class ExtractLabelIdsDoFn(beam.DoFn): """Extracts (uri, label_ids) tuples from CSV rows. """ def start_bundle(self, context=None): self.label_to_id_map = {} def process(self, element, all_labels): all_labels = list(all_labels) # DataFlow cannot garuantee the order of the labels when materializing it. # The labels materialized and consumed by training may not be with the same order # as the one used in preprocessing. So we need to sort it in both preprocessing # and training so the order matches. all_labels.sort() if not self.label_to_id_map: for i, label in enumerate(all_labels): label = label.strip() if label: self.label_to_id_map[label] = i # Row format is: # image_uri,label_id if not element: skipped_empty_line.inc() return rows_count.inc() uri = element['image_url'] try: label_id = self.label_to_id_map[element['label'].strip()] except KeyError: unlabeled_image.inc() yield uri, label_id class ReadImageAndConvertToJpegDoFn(beam.DoFn): """Read files from GCS and convert images to JPEG format. We do this even for JPEG images to remove variations such as different number of channels. """ def process(self, element): from tensorflow.python.lib.io import file_io as tf_file_io uri, label_id = element try: with tf_file_io.FileIO(uri, 'r') as f: img = Image.open(f).convert('RGB') # A variety of different calling libraries throw different exceptions here. # They all correspond to an unreadable file so we treat them equivalently. # pylint: disable broad-except except Exception as e: logging.exception('Error processing image %s: %s', uri, str(e)) error_count.inc() return # Convert to desired format and output. output = cStringIO.StringIO() img.save(output, 'jpeg') image_bytes = output.getvalue() yield uri, label_id, image_bytes class EmbeddingsGraph(object): """Builds a graph and uses it to extract embeddings from images. """ # These constants are set by Inception v3's expectations. WIDTH = 299 HEIGHT = 299 CHANNELS = 3 def __init__(self, tf_session, checkpoint_path): import tensorflow as tf self.tf_session = tf_session # input_jpeg is the tensor that contains raw image bytes. # It is used to feed image bytes and obtain embeddings. self.input_jpeg, self.embedding = self.build_graph() self.tf_session.run(tf.global_variables_initializer()) self.restore_from_checkpoint(checkpoint_path) def build_graph(self): """Forms the core by building a wrapper around the inception graph. Here we add the necessary input & output tensors, to decode jpegs, serialize embeddings, restore from checkpoint etc. To use other Inception models modify this file. Note that to use other models beside Inception, you should make sure input_shape matches their input. Resizing or other modifications may be necessary as well. See tensorflow/contrib/slim/python/slim/nets/inception_v3.py for details about InceptionV3. Returns: input_jpeg: A tensor containing raw image bytes as the input layer. embedding: The embeddings tensor, that will be materialized later. 
""" import tensorflow as tf input_jpeg = tf.placeholder(tf.string, shape=None) image = tf.image.decode_jpeg(input_jpeg, channels=self.CHANNELS) # Note resize expects a batch_size, but we are feeding a single image. # So we have to expand then squeeze. Resize returns float32 in the # range [0, uint8_max] image = tf.expand_dims(image, 0) # convert_image_dtype also scales [0, uint8_max] -> [0 ,1). image = tf.image.convert_image_dtype(image, dtype=tf.float32) image = tf.image.resize_bilinear( image, [self.HEIGHT, self.WIDTH], align_corners=False) # Then rescale range to [-1, 1) for Inception. image = tf.subtract(image, 0.5) inception_input = tf.multiply(image, 2.0) # Build Inception layers, which expect a tensor of type float from [-1, 1) # and shape [batch_size, height, width, channels]. with tf.contrib.slim.arg_scope(_inceptionlib.inception_v3_arg_scope()): _, end_points = _inceptionlib.inception_v3(inception_input, is_training=False) embedding = end_points['PreLogits'] return input_jpeg, embedding def restore_from_checkpoint(self, checkpoint_path): """To restore inception model variables from the checkpoint file. Some variables might be missing in the checkpoint file, so it only loads the ones that are avialable, assuming the rest would be initialized later. Args: checkpoint_path: Path to the checkpoint file for the Inception graph. """ import tensorflow as tf # Get all variables to restore. Exclude Logits and AuxLogits because they # depend on the input data and we do not need to intialize them from # checkpoint. all_vars = tf.contrib.slim.get_variables_to_restore( exclude=['InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step']) saver = tf.train.Saver(all_vars) saver.restore(self.tf_session, checkpoint_path) def calculate_embedding(self, batch_image_bytes): """Get the embeddings for a given JPEG image. Args: batch_image_bytes: As if returned from [ff.read() for ff in file_list]. Returns: The Inception embeddings (bottleneck layer output) """ return self.tf_session.run( self.embedding, feed_dict={self.input_jpeg: batch_image_bytes}) class TFExampleFromImageDoFn(beam.DoFn): """Embeds image bytes and labels, stores them in tensorflow.Example. (uri, label_ids, image_bytes) -> (tensorflow.Example). Output proto contains 'label', 'image_uri' and 'embedding'. The 'embedding' is calculated by feeding image into input layer of image neural network and reading output of the bottleneck layer of the network. Attributes: image_graph_uri: an uri to gcs bucket where serialized image graph is stored. """ def __init__(self, checkpoint_path): self.tf_session = None self.graph = None self.preprocess_graph = None self._checkpoint_path = checkpoint_path def start_bundle(self, context=None): # There is one tensorflow session per instance of TFExampleFromImageDoFn. # The same instance of session is re-used between bundles. # Session is closed by the destructor of Session object, which is called # when instance of TFExampleFromImageDoFn() is destructed. 
import tensorflow as tf if not self.graph: self.graph = tf.Graph() self.tf_session = tf.InteractiveSession(graph=self.graph) with self.graph.as_default(): self.preprocess_graph = EmbeddingsGraph(self.tf_session, self._checkpoint_path) def finish_bundle(self, context=None): if self.tf_session is not None: self.tf_session.close() def process(self, element): import tensorflow as tf def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) def _float_feature(value): return tf.train.Feature(float_list=tf.train.FloatList(value=value)) uri, label_id, image_bytes = element try: embedding = self.preprocess_graph.calculate_embedding(image_bytes) except tf.errors.InvalidArgumentError as e: incompatible_image.inc() logging.warning('Could not encode an image from %s: %s', uri, str(e)) return if embedding.any(): embedding_good.inc() else: embedding_bad.inc() example = tf.train.Example(features=tf.train.Features(feature={ 'image_uri': _bytes_feature([str(uri)]), 'embedding': _float_feature(embedding.ravel().tolist()), })) example.features.feature['label'].int64_list.value.append(label_id) yield example class TrainEvalSplitPartitionFn(beam.PartitionFn): """Split train and eval data.""" def partition_for(self, element, num_partitions): import random return 1 if random.random() > 0.7 else 0 class ExampleProtoCoder(beam.coders.Coder): """A coder to encode and decode TensorFlow Example objects.""" def encode(self, example_proto): return example_proto.SerializeToString() def decode(self, serialized_str): import tensorflow as tf example = tf.train.Example() example.ParseFromString(serialized_str) return example class SaveFeatures(beam.PTransform): """Save Features in a TFRecordIO format. """ def __init__(self, file_path_prefix): super(SaveFeatures, self).__init__('SaveFeatures') self._file_path_prefix = file_path_prefix def expand(self, features): return (features | 'Write to %s' % self._file_path_prefix.replace('/', '_') >> tfrecordio.WriteToTFRecord(file_path_prefix=self._file_path_prefix, file_name_suffix='.tfrecord.gz', coder=ExampleProtoCoder())) def _labels_pipeline(sources): labels = (sources | 'Flatten Sources for labels' >> beam.Flatten() | 'Parse input for labels' >> beam.Map(lambda x: str(x['label'])) | 'Combine labels' >> beam.transforms.combiners.Count.PerElement() | 'Get labels' >> beam.Map(lambda label_count: label_count[0])) return labels def _transformation_pipeline(source, checkpoint, labels, mode): transformed = (source | 'Extract label ids(%s)' % mode >> beam.ParDo(ExtractLabelIdsDoFn(), beam.pvalue.AsIter(labels)) | 'Read and convert to JPEG(%s)' % mode >> beam.ParDo(ReadImageAndConvertToJpegDoFn()) | 'Embed and make TFExample(%s)' % mode >> beam.ParDo(TFExampleFromImageDoFn(checkpoint))) return transformed def configure_pipeline(p, dataset_train, dataset_eval, checkpoint_path, output_dir, job_id): source_train = _util.get_sources_from_dataset(p, dataset_train, 'train') labels_source = [source_train] if dataset_eval is not None: source_eval = _util.get_sources_from_dataset(p, dataset_eval, 'eval') labels_source.append(source_eval) labels = _labels_pipeline(labels_source) train_preprocessed = _transformation_pipeline(source_train, checkpoint_path, labels, 'train') if dataset_eval is not None: # explicit eval data. eval_preprocessed = _transformation_pipeline(source_eval, checkpoint_path, labels, 'eval') else: # Split train/eval. 
train_preprocessed, eval_preprocessed = (train_preprocessed | 'Random Partition' >> beam.Partition(TrainEvalSplitPartitionFn(), 2)) output_train_path = os.path.join(output_dir, job_id, 'train') output_eval_path = os.path.join(output_dir, job_id, 'eval') labels_file = os.path.join(output_dir, job_id, 'labels') labels_save = (labels | 'Write labels' >> beam.io.textio.WriteToText(labels_file, shard_name_template='')) train_save = train_preprocessed | 'Save train to disk' >> SaveFeatures(output_train_path) eval_save = eval_preprocessed | 'Save eval to disk' >> SaveFeatures(output_eval_path) # Make sure we write "latest" file after train and eval data are successfully written. output_latest_file = os.path.join(output_dir, 'latest') ([eval_save, train_save, labels_save] | 'Wait for train eval saving' >> beam.Flatten() | 'Fixed One' >> beam.transforms.combiners.Sample.FixedSizeGlobally(1) | beam.Map(lambda path: job_id) | 'WriteLatest' >> beam.io.textio.WriteToText(output_latest_file, shard_name_template='')) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_trainer.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Training implementation for inception model. """ import logging import os import tensorflow as tf import time from . import _util def start_server(cluster, task): if not task.type: raise ValueError('--task_type must be specified.') if task.index is None: raise ValueError('--task_index must be specified.') # Create and start a server. 
return tf.train.Server( tf.train.ClusterSpec(cluster), protocol='grpc', job_name=task.type, task_index=task.index) class Evaluator(object): """Loads variables from latest checkpoint and performs model evaluation.""" def __init__(self, model, data_paths, batch_size, output_path, dataset='eval'): data_size = self._data_size(data_paths) if data_size <= batch_size: raise Exception('Data size is smaller than batch size.') self.num_eval_batches = data_size // batch_size self.batch_of_examples = [] self.checkpoint_path = os.path.join(output_path, 'train') self.output_path = os.path.join(output_path, dataset) self.eval_data_paths = data_paths self.batch_size = batch_size self.model = model def _data_size(self, data_paths): n = 0 options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) for file in data_paths: for line in tf.python_io.tf_record_iterator(file, options=options): n += 1 return n def evaluate(self, num_eval_batches=None): """Run one round of evaluation, return loss and accuracy.""" num_eval_batches = num_eval_batches or self.num_eval_batches with tf.Graph().as_default() as graph: self.tensors = self.model.build_eval_graph(self.eval_data_paths, self.batch_size) self.summary = tf.summary.merge_all() self.saver = tf.train.Saver() self.summary_writer = tf.summary.FileWriter(self.output_path) self.sv = tf.train.Supervisor( graph=graph, logdir=self.output_path, summary_op=None, global_step=None, saver=self.saver) last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path) with self.sv.managed_session(master='', start_standard_services=False) as session: self.sv.saver.restore(session, last_checkpoint) if not self.batch_of_examples: self.sv.start_queue_runners(session) for i in range(num_eval_batches): self.batch_of_examples.append(session.run(self.tensors.examples)) for i in range(num_eval_batches): session.run(self.tensors.metric_updates, {self.tensors.examples: self.batch_of_examples[i]}) metric_values = session.run(self.tensors.metric_values) global_step = tf.train.global_step(session, self.tensors.global_step) summary = session.run(self.summary) self.summary_writer.add_summary(summary, global_step) self.summary_writer.flush() return metric_values class Trainer(object): """Performs model training and optionally evaluation.""" def __init__(self, input_dir, batch_size, max_steps, output_path, model, cluster, task): train_files, eval_files = _util.get_train_eval_files(input_dir) self.train_data_paths = train_files self.output_path = output_path self.batch_size = batch_size self.model = model self.max_steps = max_steps self.cluster = cluster self.task = task self.evaluator = Evaluator(self.model, eval_files, batch_size, output_path, 'eval_set') self.train_evaluator = Evaluator(self.model, train_files, batch_size, output_path, 'train_set') self.min_train_eval_rate = 8 def run_training(self): """Runs a Master.""" self.train_path = os.path.join(self.output_path, 'train') self.model_path = os.path.join(self.output_path, 'model') self.is_master = self.task.type != 'worker' log_interval = 15 self.eval_interval = 30 if self.is_master and self.task.index > 0: raise Exception('Only one replica of master expected') if self.cluster: logging.info('Starting %s/%d', self.task.type, self.task.index) server = start_server(self.cluster, self.task) target = server.target device_fn = tf.train.replica_device_setter( ps_device='/job:ps', worker_device='/job:%s/task:%d' % (self.task.type, self.task.index), cluster=self.cluster) # We use a device_filter to limit the 
communication between this job # and the parameter servers, i.e., there is no need to directly # communicate with the other workers; attempting to do so can result # in reliability problems. device_filters = [ '/job:ps', '/job:%s/task:%d' % (self.task.type, self.task.index) ] config = tf.ConfigProto(device_filters=device_filters) else: target = '' device_fn = '' config = None with tf.Graph().as_default() as graph: with tf.device(device_fn): # Build the training graph. self.tensors = self.model.build_train_graph(self.train_data_paths, self.batch_size) # Add the variable initializer Op. init_op = tf.global_variables_initializer() # Create a saver for writing training checkpoints. self.saver = tf.train.Saver() # Build the summary operation based on the TF collection of Summaries. self.summary_op = tf.summary.merge_all() # Create a "supervisor", which oversees the training process. self.sv = tf.train.Supervisor( graph, is_chief=self.is_master, logdir=self.train_path, init_op=init_op, saver=self.saver, # Write summary_ops by hand. summary_op=None, global_step=self.tensors.global_step, # No saving; we do it manually in order to easily evaluate immediately # afterwards. save_model_secs=0) should_retry = True to_run = [self.tensors.global_step, self.tensors.train] while should_retry: try: should_retry = False with self.sv.managed_session(target, config=config) as session: self.start_time = start_time = time.time() self.last_save = self.last_log = 0 self.global_step = self.last_global_step = 0 self.local_step = self.last_local_step = 0 self.last_global_time = self.last_local_time = start_time # Loop until the supervisor shuts down or max_steps have # completed. max_steps = self.max_steps while not self.sv.should_stop() and self.global_step < max_steps: try: # Run one step of the model. self.global_step = session.run(to_run)[0] self.local_step += 1 self.now = time.time() is_time_to_eval = (self.now - self.last_save) > self.eval_interval is_time_to_log = (self.now - self.last_log) > log_interval should_eval = self.is_master and is_time_to_eval should_log = is_time_to_log or should_eval if should_log: self.log(session) if should_eval: self.eval(session) except tf.errors.AbortedError: should_retry = True if self.is_master: # Take the final checkpoint and compute the final accuracy. # self.saver.save(session, self.sv.save_path, self.tensors.global_step) self.eval(session) except tf.errors.AbortedError: print('Hitting an AbortedError. Trying it again.') should_retry = True # Export the model for inference. if self.is_master: self.model.export(tf.train.latest_checkpoint(self.train_path), self.model_path) # Ask for all the services to stop. 
self.sv.stop() def log(self, session): """Logs training progress.""" logging.info('Train [%s/%d], step %d (%.3f sec) %.1f ' 'global steps/s, %.1f local steps/s', self.task.type, self.task.index, self.global_step, (self.now - self.start_time), (self.global_step - self.last_global_step) / (self.now - self.last_global_time), (self.local_step - self.last_local_step) / (self.now - self.last_local_time)) self.last_log = self.now self.last_global_step, self.last_global_time = self.global_step, self.now self.last_local_step, self.last_local_time = self.local_step, self.now def eval(self, session): """Runs evaluation loop.""" eval_start = time.time() self.saver.save(session, self.sv.save_path, self.tensors.global_step) logging.info( 'Eval, step %d:\n- on train set %s\n-- on eval set %s', self.global_step, self.model.format_metric_values(self.train_evaluator.evaluate()), self.model.format_metric_values(self.evaluator.evaluate())) now = time.time() # Make sure eval doesn't consume too much of total time. eval_time = now - eval_start train_eval_rate = self.eval_interval / eval_time if train_eval_rate < self.min_train_eval_rate and self.last_save > 0: logging.info('Adjusting eval interval from %.2fs to %.2fs', self.eval_interval, self.min_train_eval_rate * eval_time) self.eval_interval = self.min_train_eval_rate * eval_time self.last_save = now self.last_log = now def save_summaries(self, session): self.sv.summary_computed(session, session.run(self.summary_op), self.global_step) self.sv.summary_writer.flush() ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/_util.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Reusable utility functions. """ import collections import multiprocessing import os import tensorflow as tf from tensorflow.python.lib.io import file_io _DEFAULT_CHECKPOINT_GSURL = 'gs://cloud-ml-data/img/flower_photos/inception_v3_2016_08_28.ckpt' def is_in_IPython(): try: import IPython # noqa return True except ImportError: return False def default_project(): from google.datalab import Context return Context.default().project_id def _get_latest_data_dir(input_dir): latest_file = os.path.join(input_dir, 'latest') if not file_io.file_exists(latest_file): raise Exception(('Cannot find "latest" file in "%s". 
' + 'Please use a preprocessing output dir.') % input_dir) with file_io.FileIO(latest_file, 'r') as f: dir_name = f.read().rstrip() return os.path.join(input_dir, dir_name) def get_train_eval_files(input_dir): """Get preprocessed training and eval files.""" data_dir = _get_latest_data_dir(input_dir) train_pattern = os.path.join(data_dir, 'train*.tfrecord.gz') eval_pattern = os.path.join(data_dir, 'eval*.tfrecord.gz') train_files = file_io.get_matching_files(train_pattern) eval_files = file_io.get_matching_files(eval_pattern) return train_files, eval_files def get_labels(input_dir): """Get a list of labels from preprocessed output dir.""" data_dir = _get_latest_data_dir(input_dir) labels_file = os.path.join(data_dir, 'labels') with file_io.FileIO(labels_file, 'r') as f: labels = f.read().rstrip().split('\n') return labels def read_examples(input_files, batch_size, shuffle, num_epochs=None): """Creates readers and queues for reading example protos.""" files = [] for e in input_files: for path in e.split(','): files.extend(file_io.get_matching_files(path)) thread_count = multiprocessing.cpu_count() # The minimum number of instances in a queue from which examples are drawn # randomly. The larger this number, the more randomness at the expense of # higher memory requirements. min_after_dequeue = 1000 # When batching data, the queue's capacity will be larger than the batch_size # by some factor. The recommended formula is (num_threads + a small safety # margin). For now, we use a single thread for reading, so this can be small. queue_size_multiplier = thread_count + 3 # Convert num_epochs == 0 -> num_epochs is None, if necessary num_epochs = num_epochs or None # Build a queue of the filenames to be read. filename_queue = tf.train.string_input_producer(files, num_epochs, shuffle) options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) example_id, encoded_example = tf.TFRecordReader(options=options).read_up_to( filename_queue, batch_size) if shuffle: capacity = min_after_dequeue + queue_size_multiplier * batch_size return tf.train.shuffle_batch( [example_id, encoded_example], batch_size, capacity, min_after_dequeue, enqueue_many=True, num_threads=thread_count) else: capacity = queue_size_multiplier * batch_size return tf.train.batch( [example_id, encoded_example], batch_size, capacity=capacity, enqueue_many=True, num_threads=thread_count) def override_if_not_in_args(flag, argument, args): """Checks if flags is in args, and if not it adds the flag to args.""" if flag not in args: args.extend([flag, argument]) def loss(loss_value): """Calculates aggregated mean loss.""" total_loss = tf.Variable(0.0, False) loss_count = tf.Variable(0, False) total_loss_update = tf.assign_add(total_loss, loss_value) loss_count_update = tf.assign_add(loss_count, 1) loss_op = total_loss / tf.cast(loss_count, tf.float32) return [total_loss_update, loss_count_update], loss_op def accuracy(logits, labels): """Calculates aggregated accuracy.""" is_correct = tf.nn.in_top_k(logits, labels, 1) correct = tf.reduce_sum(tf.cast(is_correct, tf.int32)) incorrect = tf.reduce_sum(tf.cast(tf.logical_not(is_correct), tf.int32)) correct_count = tf.Variable(0, False) incorrect_count = tf.Variable(0, False) correct_count_update = tf.assign_add(correct_count, correct) incorrect_count_update = tf.assign_add(incorrect_count, incorrect) accuracy_op = tf.cast(correct_count, tf.float32) / tf.cast( correct_count + incorrect_count, tf.float32) return [correct_count_update, incorrect_count_update], 
accuracy_op def check_dataset(dataset, mode): """Validate we have a good dataset.""" names = [x['name'] for x in dataset.schema] types = [x['type'] for x in dataset.schema] if mode == 'train': if (set(['image_url', 'label']) != set(names) or any(t != 'STRING' for t in types)): raise ValueError('Invalid dataset. Expect only "image_url,label" STRING columns.') else: if (set(['image_url']) != set(names) and set(['image_url', 'label']) != set(names)) or \ any(t != 'STRING' for t in types): raise ValueError('Invalid dataset. Expect only "image_url" or "image_url,label" ' + 'STRING columns.') def get_sources_from_dataset(p, dataset, mode): """get pcollection from dataset.""" import apache_beam as beam import csv from google.datalab.ml import CsvDataSet, BigQueryDataSet check_dataset(dataset, mode) if type(dataset) is CsvDataSet: source_list = [] for ii, input_path in enumerate(dataset.files): source_list.append(p | 'Read from Csv %d (%s)' % (ii, mode) >> beam.io.ReadFromText(input_path, strip_trailing_newlines=True)) return (source_list | 'Flatten Sources (%s)' % mode >> beam.Flatten() | 'Create Dict from Csv (%s)' % mode >> beam.Map(lambda line: csv.DictReader([line], fieldnames=['image_url', 'label']).next())) elif type(dataset) is BigQueryDataSet: bq_source = (beam.io.BigQuerySource(table=dataset.table) if dataset.table is not None else beam.io.BigQuerySource(query=dataset.query)) return p | 'Read source from BigQuery (%s)' % mode >> beam.io.Read(bq_source) else: raise ValueError('Invalid DataSet. Expect CsvDataSet or BigQueryDataSet') def decode_and_resize(image_str_tensor): """Decodes jpeg string, resizes it and returns a uint8 tensor.""" # These constants are set by Inception v3's expectations. height = 299 width = 299 channels = 3 image = tf.image.decode_jpeg(image_str_tensor, channels=channels) # Note resize expects a batch_size, but tf_map supresses that index, # thus we have to expand then squeeze. Resize returns float32 in the # range [0, uint8_max] image = tf.expand_dims(image, 0) image = tf.image.resize_bilinear(image, [height, width], align_corners=False) image = tf.squeeze(image, squeeze_dims=[0]) image = tf.cast(image, dtype=tf.uint8) return image def resize_image(image_str_tensor): """Decodes jpeg string, resizes it and re-encode it to jpeg.""" image = decode_and_resize(image_str_tensor) image = tf.image.encode_jpeg(image, quality=100) return image def load_images(image_files, resize=True): """Load images from files and optionally resize it.""" images = [] for image_file in image_files: with file_io.FileIO(image_file, 'r') as ff: images.append(ff.read()) if resize is False: return images # To resize, run a tf session so we can reuse 'decode_and_resize()' # which is used in prediction graph. This makes sure we don't lose # any quality in prediction, while decreasing the size of the images # submitted to the model over network. image_str_tensor = tf.placeholder(tf.string, shape=[None]) image = tf.map_fn(resize_image, image_str_tensor, back_prop=False) feed_dict = collections.defaultdict(list) feed_dict[image_str_tensor.name] = images with tf.Session() as sess: images_resized = sess.run(image, feed_dict=feed_dict) return images_resized def process_prediction_results(results, show_image): """Create DataFrames out of prediction results, and display images in IPython if requested.""" import pandas as pd if (is_in_IPython() and show_image is True): import IPython for image_url, image, label_and_score in results: IPython.display.display_html('
%s(%.5f)
' % label_and_score, raw=True) IPython.display.display(IPython.display.Image(data=image)) result_dict = [{'image_url': url, 'label': r[0], 'score': r[1]} for url, _, r in results] return pd.DataFrame(result_dict) def repackage_to_staging(output_path): """Repackage it from local installed location and copy it to GCS.""" import google.datalab.ml as ml # Find the package root. __file__ is under [package_root]/mltoolbox/image/classification. package_root = os.path.join(os.path.dirname(__file__), '../../../') # We deploy setup.py in the same dir for repackaging purpose. setup_py = os.path.join(os.path.dirname(__file__), 'setup.py') staging_package_url = os.path.join(output_path, 'staging', 'image_classification.tar.gz') ml.package_and_copy(package_root, setup_py, staging_package_url) return staging_package_url ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # To publish to PyPi use: python setup.py bdist_wheel upload -r pypi import datetime from setuptools import setup minor = datetime.datetime.now().strftime("%y%m%d%H%M") version = '0.2' setup( name='mltoolbox_datalab_image_classification', namespace_packages=['mltoolbox'], version=version, packages=[ 'mltoolbox', 'mltoolbox.image', 'mltoolbox.image.classification', ], description='Google Cloud Datalab Inception Package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Development Status :: 4 - Beta", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description=""" """, install_requires=[ 'pillow==3.4.1', ], package_data={ } ) ================================================ FILE: solutionbox/image_classification/mltoolbox/image/classification/task.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Entry point for CloudML training. CloudML training requires a tarball package and a python module to run. This file provides such a "main" method and a list of args passed with the program. 
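An illustrative stand-alone invocation, with placeholder paths (in a real run the
Cloud ML training service assembles this command from the job configuration), might be:

  python -m mltoolbox.image.classification.task \
      --input_dir gs://my-bucket/preprocessed \
      --job-dir gs://my-bucket/training \
      --max_steps 2000 \
      --batch_size 100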
""" import argparse import json import logging import os import tensorflow as tf from . import _model from . import _trainer from . import _util def main(_): parser = argparse.ArgumentParser() parser.add_argument( '--input_dir', type=str, help='The input dir path for training and evaluation data.') parser.add_argument( '--job-dir', dest='job_dir', type=str, help='The GCS path to which checkpoints and other outputs should be saved.') parser.add_argument( '--max_steps', type=int,) parser.add_argument( '--batch_size', type=int, help='Number of examples to be processed per mini-batch.') parser.add_argument( '--checkpoint', type=str, default=_util._DEFAULT_CHECKPOINT_GSURL, help='Pretrained inception checkpoint path.') args, _ = parser.parse_known_args() labels = _util.get_labels(args.input_dir) model = _model.Model(labels, 0.5, args.checkpoint) env = json.loads(os.environ.get('TF_CONFIG', '{}')) # Print the job data as provided by the service. logging.info('Original job data: %s', env.get('job', {})) task_data = env.get('task', None) or {'type': 'master', 'index': 0} task = type('TaskSpec', (object,), task_data) cluster_data = env.get('cluster', None) cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None if not cluster or not task or task.type == 'master' or task.type == 'worker': _trainer.Trainer(args.input_dir, args.batch_size, args.max_steps, args.job_dir, model, cluster, task).run_training() elif task.type == 'ps': server = _trainer.start_server(cluster, task) server.join() else: raise ValueError('invalid task_type %s' % (task.type,)) if __name__ == '__main__': logging.basicConfig(level=logging.INFO) tf.app.run() ================================================ FILE: solutionbox/image_classification/setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # To publish to PyPi use: python setup.py bdist_wheel upload -r pypi import datetime from setuptools import setup minor = datetime.datetime.now().strftime("%y%m%d%H%M") version = '0.2' setup( name='mltoolbox_datalab_image_classification', namespace_packages=['mltoolbox'], version=version, packages=[ 'mltoolbox', 'mltoolbox.image', 'mltoolbox.image.classification', ], description='Google Cloud Datalab Inception Package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Development Status :: 4 - Beta", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description=""" """, install_requires=[ 'pillow==6.2.0', ], package_data={ } ) ================================================ FILE: solutionbox/ml_workbench/setup.py ================================================ # Copyright 2017 Google Inc. 
All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from setuptools import setup setup( name='mltoolbox_code_free', namespace_packages=['mltoolbox'], version='1.0.0', packages=[ 'mltoolbox', 'mltoolbox.code_free_ml', 'mltoolbox.code_free_ml.trainer', ], description='Google Cloud Datalab Structured Data Package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Development Status :: 4 - Beta", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description=""" """, install_requires=[ # TODO(brandondutra): fill this in. Add pydatalab? ], package_data={ }, data_files=[], ) ================================================ FILE: solutionbox/ml_workbench/tensorflow/__init__.py ================================================ ================================================ FILE: solutionbox/ml_workbench/tensorflow/analyze.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import copy import json import os import sys import six import textwrap from tensorflow.python.lib.io import file_io from trainer import feature_transforms as constant from trainer import feature_analysis as feature_analysis def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments, including program name. Returns: An argparse Namespace object. Raises: ValueError: for bad parameters """ parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=textwrap.dedent("""\ Runs analysis on structured data and produces auxiliary files for training. The output files can also be used by the Transform step to materialize TF.Examples files, which for some problems can speed up training. Description of input files -------------------------- 1) If using csv files, the --schema parameter must be the file path to a schema file. The format of this file must be a valid BigQuery schema file, which is a JSON file containing a list of dicts. 
Consider the example schema file below: [ {"name": "column_name_1", "type": "integer"}, {"name": "column_name_2", "type": "float"}, {"name": "column_name_3", "type": "string"}, {"name": "column_name_4", "type": "string"}, ] Note that the column names in the csv file much match the order in the schema list. Also, we only support three BigQuery types ( integer, float, and string). If instead of csv files, --bigquery is used, the schema file is not needed as this program will extract it from the table directly. 2) --features is a file path to a file describing the transformations. Below is an example features file: { "column_name_1": {"transform": "scale"}, "column_name_3": {"transform": "target"}, "column_name_2": {"transform": "one_hot"}, "new_feature_name": {"transform": "key", "source_column": "column_name_4"}, } The format of the dict is `name`: `transform-dict` where the `name` is the name of the transformed feature. The `source_column` value lists what column in the input data is the source for this transformation. If `source_column` is missing, it is assumed the `name` is a source column and the transformed feature will have the same name as the input column. A list of supported `transform-dict`s is below: {"transform": "identity"}: does nothing (for numerical columns). {"transform": "scale", "value": x}: scale a numerical column to [-a, a]. If value is missing, x defaults to 1. {"transform": "one_hot"}: makes a one-hot encoding of a string column. {"transform": "embedding", "embedding_dim": d}: makes an embedding of a string column. {"transform": "multi_hot", "separator": ' '}: makes a multi-hot encoding of a string column. {"transform": "bag_of_words"}: bag of words transform for string columns. {"transform": "tfidf"}: TFIDF transform for string columns. {"transform": "image_to_vec", "checkpoint": "gs://b/o"}: From image gs url to embeddings. "checkpoint" is a inception v3 checkpoint. If absent, a default checkpoint is used. {"transform": "target"}: denotes what column is the target. If the schema type of this column is string, a one_hot encoding is automatically applied. If type is numerical, a identity transform is automatically applied. {"transform": "key"}: column contains metadata-like information and is not included in the model. Note that for tfidf and bag_of_words, the input string is assumed to contain text separated by a space. So for example, the string "a, b c." has three tokens 'a,', 'b', and 'c.'. """)) parser.add_argument('--cloud', action='store_true', help='Analysis will use cloud services.') parser.add_argument('--output', metavar='DIR', type=str, required=True, help='GCS or local folder') input_group = parser.add_argument_group( title='Data Source Parameters', description='schema is only needed if using --csv') # CSV input input_group.add_argument('--csv', metavar='FILE', type=str, required=False, action='append', help='Input CSV absolute file paths. May contain a ' 'file pattern.') input_group.add_argument('--schema', metavar='FILE', type=str, required=False, help='Schema file path. 
Only required if using csv files') # Bigquery input input_group.add_argument('--bigquery', metavar='PROJECT_ID.DATASET.TABLE_NAME', type=str, required=False, help=('Must be in the form project.dataset.table_name')) parser.add_argument('--features', metavar='FILE', type=str, required=True, help='Features file path') args = parser.parse_args(args=argv[1:]) if args.cloud: if not args.output.startswith('gs://'): raise ValueError('--output must point to a location on GCS') if (args.csv and not all(x.startswith('gs://') for x in args.csv)): raise ValueError('--csv must point to a location on GCS') if args.schema and not args.schema.startswith('gs://'): raise ValueError('--schema must point to a location on GCS') if not args.cloud and args.bigquery: raise ValueError('--bigquery must be used with --cloud') if not ((args.bigquery and args.csv is None and args.schema is None) or (args.bigquery is None and args.csv and args.schema)): raise ValueError('either --csv and --schema must both' ' be set or just --bigquery is set') return args def run_cloud_analysis(output_dir, csv_file_pattern, bigquery_table, schema, features): """Use BigQuery to analyze input date. Only one of csv_file_pattern or bigquery_table should be non-None. Args: output_dir: output folder csv_file_pattern: list of csv file paths, may contain wildcards bigquery_table: project_id.dataset_name.table_name schema: schema list features: features config """ def _execute_sql(sql, table): """Runs a BigQuery job and dowloads the results into local memeory. Args: sql: a SQL string table: bq.ExternalDataSource or bq.Table Returns: A Pandas dataframe. """ import google.datalab.bigquery as bq if isinstance(table, bq.ExternalDataSource): query = bq.Query(sql, data_sources={'csv_table': table}) else: query = bq.Query(sql) return query.execute().result().to_dataframe() feature_analysis.expand_defaults(schema, features) # features are updated. inverted_features = feature_analysis.invert_features(features) feature_analysis.check_schema_transforms_match(schema, inverted_features) import google.datalab.bigquery as bq if bigquery_table: table_name = '`%s`' % bigquery_table table = None else: table_name = 'csv_table' table = bq.ExternalDataSource( source=csv_file_pattern, schema=bq.Schema(schema)) # Make a copy of inverted_features and update the target transform to be # identity or one hot depending on the schema. inverted_features_target = copy.deepcopy(inverted_features) for name, transforms in six.iteritems(inverted_features_target): transform_set = {x['transform'] for x in transforms} if transform_set == set([constant.TARGET_TRANSFORM]): target_schema = next(col['type'].lower() for col in schema if col['name'] == name) if target_schema in constant.NUMERIC_SCHEMA: inverted_features_target[name] = [{'transform': constant.IDENTITY_TRANSFORM}] else: inverted_features_target[name] = [{'transform': constant.ONE_HOT_TRANSFORM}] numerical_vocab_stats = {} for col_name, transform_set in six.iteritems(inverted_features_target): sys.stdout.write('Analyzing column %s...\n' % col_name) sys.stdout.flush() # All transforms in transform_set require the same analysis. So look # at the first transform. transform = next(iter(transform_set)) if (transform['transform'] in constant.CATEGORICAL_TRANSFORMS or transform['transform'] in constant.TEXT_TRANSFORMS): if transform['transform'] in constant.TEXT_TRANSFORMS: # Split strings on space, then extract labels and how many rows each # token is in. 
This is done by making two temp tables: # SplitTable: each text row is made into an array of strings. The # array may contain repeat tokens # TokenTable: SplitTable with repeated tokens removed per row. # Then to flatten the arrays, TokenTable has to be joined with itself. # See the sections 'Flattening Arrays' and 'Filtering Arrays' at # https://cloud.google.com/bigquery/docs/reference/standard-sql/arrays separator = transform.get('separator', ' ') sql = ('WITH SplitTable AS ' ' (SELECT SPLIT({name}, \'{separator}\') as token_array FROM {table}), ' ' TokenTable AS ' ' (SELECT ARRAY(SELECT DISTINCT x ' ' FROM UNNEST(token_array) AS x) AS unique_tokens_per_row ' ' FROM SplitTable) ' 'SELECT token, COUNT(token) as token_count ' 'FROM TokenTable ' 'CROSS JOIN UNNEST(TokenTable.unique_tokens_per_row) as token ' 'WHERE LENGTH(token) > 0 ' 'GROUP BY token ' 'ORDER BY token_count DESC, token ASC').format(separator=separator, name=col_name, table=table_name) else: # Extract label and frequency sql = ('SELECT {name} as token, count(*) as count ' 'FROM {table} ' 'WHERE {name} IS NOT NULL ' 'GROUP BY {name} ' 'ORDER BY count DESC, token ASC').format(name=col_name, table=table_name) df = _execute_sql(sql, table) # Save the vocab csv_string = df.to_csv(index=False, header=False) file_io.write_string_to_file( os.path.join(output_dir, constant.VOCAB_ANALYSIS_FILE % col_name), csv_string) numerical_vocab_stats[col_name] = {'vocab_size': len(df)} # free memeory del csv_string del df elif transform['transform'] in constant.NUMERIC_TRANSFORMS: # get min/max/average sql = ('SELECT max({name}) as max_value, min({name}) as min_value, ' 'avg({name}) as avg_value from {table}').format(name=col_name, table=table_name) df = _execute_sql(sql, table) numerical_vocab_stats[col_name] = {'min': df.iloc[0]['min_value'], 'max': df.iloc[0]['max_value'], 'mean': df.iloc[0]['avg_value']} sys.stdout.write('column %s analyzed.\n' % col_name) sys.stdout.flush() # get num examples sql = 'SELECT count(*) as num_examples from {table}'.format(table=table_name) df = _execute_sql(sql, table) num_examples = df.iloc[0]['num_examples'] # Write the stats file. stats = {'column_stats': numerical_vocab_stats, 'num_examples': num_examples} file_io.write_string_to_file( os.path.join(output_dir, constant.STATS_FILE), json.dumps(stats, indent=2, separators=(',', ': '))) feature_analysis.save_schema_features(schema, features, output_dir) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) if args.schema: schema = json.loads( file_io.read_file_to_string(args.schema).decode()) else: import google.datalab.bigquery as bq schema = bq.Table(args.bigquery).schema._bq_schema features = json.loads( file_io.read_file_to_string(args.features).decode()) file_io.recursive_create_dir(args.output) if args.cloud: run_cloud_analysis( output_dir=args.output, csv_file_pattern=args.csv, bigquery_table=args.bigquery, schema=schema, features=features) else: feature_analysis.run_local_analysis( output_dir=args.output, csv_file_pattern=args.csv, schema=schema, features=features) if __name__ == '__main__': main() ================================================ FILE: solutionbox/ml_workbench/tensorflow/setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # This setup file is used when running cloud training or cloud dataflow jobs. from setuptools import setup, find_packages setup( name='trainer', version='1.0.0', packages=find_packages(), description='Google Cloud Datalab helper sub-package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", long_description=""" """, install_requires=[ 'tensorflow==1.15.2', 'protobuf==3.1.0', 'pillow==6.2.0', # ML Engine does not have PIL installed ], package_data={ }, data_files=[], ) ================================================ FILE: solutionbox/ml_workbench/tensorflow/trainer/__init__.py ================================================ ================================================ FILE: solutionbox/ml_workbench/tensorflow/trainer/feature_analysis.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import csv import json import os import pandas as pd import sys import six from tensorflow.python.lib.io import file_io from . import feature_transforms as constant def check_schema_transforms_match(schema, inverted_features): """Checks that the transform and schema do not conflict. Args: schema: schema list inverted_features: inverted_features dict Raises: ValueError if transform cannot be applied given schema type. """ num_target_transforms = 0 for col_schema in schema: col_name = col_schema['name'] col_type = col_schema['type'].lower() # Check each transform and schema are compatible if col_name in inverted_features: for transform in inverted_features[col_name]: transform_name = transform['transform'] if transform_name == constant.TARGET_TRANSFORM: num_target_transforms += 1 continue elif col_type in constant.NUMERIC_SCHEMA: if transform_name not in constant.NUMERIC_TRANSFORMS: raise ValueError( 'Transform %s not supported by schema %s' % (transform_name, col_type)) elif col_type == constant.STRING_SCHEMA: if (transform_name not in constant.CATEGORICAL_TRANSFORMS + constant.TEXT_TRANSFORMS and transform_name != constant.IMAGE_TRANSFORM): raise ValueError( 'Transform %s not supported by schema %s' % (transform_name, col_type)) else: raise ValueError('Unsupported schema type %s' % col_type) # Check each transform is compatible for the same source column. # inverted_features[col_name] should belong to exactly 1 of the 5 groups. 
if col_name in inverted_features: transform_set = {x['transform'] for x in inverted_features[col_name]} if 1 != sum([transform_set.issubset(set(constant.NUMERIC_TRANSFORMS)), transform_set.issubset(set(constant.CATEGORICAL_TRANSFORMS)), transform_set.issubset(set(constant.TEXT_TRANSFORMS)), transform_set.issubset(set([constant.IMAGE_TRANSFORM])), transform_set.issubset(set([constant.TARGET_TRANSFORM]))]): message = """ The source column of a feature can only be used in multiple features within the same family of transforms. The familes are 1) text transformations: %s 2) categorical transformations: %s 3) numerical transformations: %s 4) image transformations: %s 5) target transform: %s Any column can also be a key column. But column %s is used by transforms %s. """ % (str(constant.TEXT_TRANSFORMS), str(constant.CATEGORICAL_TRANSFORMS), str(constant.NUMERIC_TRANSFORMS), constant.IMAGE_TRANSFORM, constant.TARGET_TRANSFORM, col_name, str(transform_set)) raise ValueError(message) if num_target_transforms != 1: raise ValueError('Must have exactly one target transform') def save_schema_features(schema, features, output): # Save a copy of the schema and features in the output folder. file_io.write_string_to_file( os.path.join(output, constant.SCHEMA_FILE), json.dumps(schema, indent=2)) file_io.write_string_to_file( os.path.join(output, constant.FEATURES_FILE), json.dumps(features, indent=2)) def expand_defaults(schema, features): """Add to features any default transformations. Not every column in the schema has an explicit feature transformation listed in the featurs file. For these columns, add a default transformation based on the schema's type. The features dict is modified by this function call. After this function call, every column in schema is used in a feature, and every feature uses a column in the schema. Args: schema: schema list features: features dict Raises: ValueError: if transform cannot be applied given schema type. """ schema_names = [x['name'] for x in schema] # Add missing source columns for name, transform in six.iteritems(features): if 'source_column' not in transform: transform['source_column'] = name # Check source columns are in the schema and collect which are used. used_schema_columns = [] for name, transform in six.iteritems(features): if transform['source_column'] not in schema_names: raise ValueError('source column %s is not in the schema for transform %s' % (transform['source_column'], name)) used_schema_columns.append(transform['source_column']) # Update default transformation based on schema. for col_schema in schema: schema_name = col_schema['name'] schema_type = col_schema['type'].lower() if schema_type not in constant.NUMERIC_SCHEMA + [constant.STRING_SCHEMA]: raise ValueError(('Only the following schema types are supported: %s' % ' '.join(constant.NUMERIC_SCHEMA + [constant.STRING_SCHEMA]))) if schema_name not in used_schema_columns: # add the default transform to the features if schema_type in constant.NUMERIC_SCHEMA: features[schema_name] = { 'transform': constant.DEFAULT_NUMERIC_TRANSFORM, 'source_column': schema_name} elif schema_type == constant.STRING_SCHEMA: features[schema_name] = { 'transform': constant.DEFAULT_CATEGORICAL_TRANSFORM, 'source_column': schema_name} else: raise NotImplementedError('Unknown type %s' % schema_type) # TODO(brandondutra): introduce the notion an analysis plan/classes if we # support more complicated transforms like binning by quratiles. 
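# Illustrative example (not part of the original module): with the hypothetical inputs
#   schema = [{'name': 'age', 'type': 'integer'},
#             {'name': 'city', 'type': 'string'}]
#   features = {'age': {'transform': 'scale'}}
# a call to expand_defaults(schema, features) would update features in place to
#   {'age': {'transform': 'scale', 'source_column': 'age'},
#    'city': {'transform': 'one_hot', 'source_column': 'city'}}
# because unused string columns receive the default categorical (one_hot) transform and
# unused numeric columns would receive the default numeric (identity) transform.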
def invert_features(features): """Make a dict in the form source column : set of transforms. Note that the key transform is removed. """ inverted_features = collections.defaultdict(list) for transform in six.itervalues(features): source_column = transform['source_column'] if transform['transform'] == constant.KEY_TRANSFORM: continue inverted_features[source_column].append(transform) return dict(inverted_features) # convert from defaultdict to dict def run_local_analysis(output_dir, csv_file_pattern, schema, features): """Use pandas to analyze csv files. Produces a stats file and vocab files. Args: output_dir: output folder csv_file_pattern: list of csv file paths, may contain wildcards schema: CSV schema list features: features config Raises: ValueError: on unknown transfrorms/schemas """ sys.stdout.write('Expanding any file patterns...\n') sys.stdout.flush() header = [column['name'] for column in schema] input_files = [] for file_pattern in csv_file_pattern: input_files.extend(file_io.get_matching_files(file_pattern)) sys.stdout.write('file list computed.\n') sys.stdout.flush() expand_defaults(schema, features) # features are updated. inverted_features = invert_features(features) check_schema_transforms_match(schema, inverted_features) # Make a copy of inverted_features and update the target transform to be # identity or one hot depending on the schema. inverted_features_target = copy.deepcopy(inverted_features) for name, transforms in six.iteritems(inverted_features_target): transform_set = {x['transform'] for x in transforms} if transform_set == set([constant.TARGET_TRANSFORM]): target_schema = next(col['type'].lower() for col in schema if col['name'] == name) if target_schema in constant.NUMERIC_SCHEMA: inverted_features_target[name] = [{'transform': constant.IDENTITY_TRANSFORM}] else: inverted_features_target[name] = [{'transform': constant.ONE_HOT_TRANSFORM}] # initialize the results def _init_numerical_results(): return {'min': float('inf'), 'max': float('-inf'), 'count': 0, 'sum': 0.0} numerical_results = collections.defaultdict(_init_numerical_results) vocabs = collections.defaultdict(lambda: collections.defaultdict(int)) num_examples = 0 # for each file, update the numerical stats from that file, and update the set # of unique labels. for input_file in input_files: sys.stdout.write('Analyzing file %s...\n' % input_file) sys.stdout.flush() with file_io.FileIO(input_file, 'r') as f: for line in csv.reader(f): if len(header) != len(line): raise ValueError('Schema has %d columns but a csv line only has %d columns.' % (len(header), len(line))) parsed_line = dict(zip(header, line)) num_examples += 1 for col_name, transform_set in six.iteritems(inverted_features_target): # All transforms in transform_set require the same analysis. So look # at the first transform. transform = next(iter(transform_set)) if transform['transform'] in constant.TEXT_TRANSFORMS: separator = transform.get('separator', ' ') split_strings = parsed_line[col_name].split(separator) # If a label is in the row N times, increase it's vocab count by 1. # This is needed for TFIDF, but it's also an interesting stat. 
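# For example (illustrative): a cell containing "red red blue" adds 1 to the counts of
# both 'red' and 'blue', regardless of how many times each token repeats within that row.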
for one_label in set(split_strings): # Filter out empty strings if one_label: vocabs[col_name][one_label] += 1 elif transform['transform'] in constant.CATEGORICAL_TRANSFORMS: if parsed_line[col_name]: vocabs[col_name][parsed_line[col_name]] += 1 elif transform['transform'] in constant.NUMERIC_TRANSFORMS: if not parsed_line[col_name].strip(): continue numerical_results[col_name]['min'] = ( min(numerical_results[col_name]['min'], float(parsed_line[col_name]))) numerical_results[col_name]['max'] = ( max(numerical_results[col_name]['max'], float(parsed_line[col_name]))) numerical_results[col_name]['count'] += 1 numerical_results[col_name]['sum'] += float(parsed_line[col_name]) sys.stdout.write('file %s analyzed.\n' % input_file) sys.stdout.flush() # Write the vocab files. Each label is on its own line. vocab_sizes = {} for name, label_count in six.iteritems(vocabs): # df is now: # label1,count # label2,count # ... # where label1 is the most frequent label, and label2 is the 2nd most, etc. df = pd.DataFrame([{'label': label, 'count': count} for label, count in sorted(six.iteritems(label_count), key=lambda x: x[1], reverse=True)], columns=['label', 'count']) csv_string = df.to_csv(index=False, header=False) file_io.write_string_to_file( os.path.join(output_dir, constant.VOCAB_ANALYSIS_FILE % name), csv_string) vocab_sizes[name] = {'vocab_size': len(label_count)} # Update numerical_results to just have min/min/mean for col_name in numerical_results: if float(numerical_results[col_name]['count']) == 0: raise ValueError('Column %s has a zero count' % col_name) mean = (numerical_results[col_name]['sum'] / float(numerical_results[col_name]['count'])) del numerical_results[col_name]['sum'] del numerical_results[col_name]['count'] numerical_results[col_name]['mean'] = mean # Write the stats file. numerical_results.update(vocab_sizes) stats = {'column_stats': numerical_results, 'num_examples': num_examples} file_io.write_string_to_file( os.path.join(output_dir, constant.STATS_FILE), json.dumps(stats, indent=2, separators=(',', ': '))) save_schema_features(schema, features, output_dir) ================================================ FILE: solutionbox/ml_workbench/tensorflow/trainer/feature_transforms.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import base64 import collections import cStringIO import json import os from PIL import Image import pandas as pd import six import shutil import tensorflow as tf import tempfile from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib import lookup from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope from tensorflow.python.lib.io import file_io # ------------------------------------------------------------------------------ # public constants. Changing these could break user's code # ------------------------------------------------------------------------------ # Individual transforms IDENTITY_TRANSFORM = 'identity' SCALE_TRANSFORM = 'scale' ONE_HOT_TRANSFORM = 'one_hot' EMBEDDING_TRANSFROM = 'embedding' MULTI_HOT_TRANSFORM = 'multi_hot' BOW_TRANSFORM = 'bag_of_words' TFIDF_TRANSFORM = 'tfidf' KEY_TRANSFORM = 'key' TARGET_TRANSFORM = 'target' IMAGE_TRANSFORM = 'image_to_vec' # ------------------------------------------------------------------------------ # internal constants. 
# ------------------------------------------------------------------------------ # Files SCHEMA_FILE = 'schema.json' FEATURES_FILE = 'features.json' STATS_FILE = 'stats.json' VOCAB_ANALYSIS_FILE = 'vocab_%s.csv' # Transform collections NUMERIC_TRANSFORMS = [IDENTITY_TRANSFORM, SCALE_TRANSFORM] CATEGORICAL_TRANSFORMS = [ONE_HOT_TRANSFORM, EMBEDDING_TRANSFROM] TEXT_TRANSFORMS = [MULTI_HOT_TRANSFORM, BOW_TRANSFORM, TFIDF_TRANSFORM] # If the features file is missing transforms, apply these. DEFAULT_NUMERIC_TRANSFORM = IDENTITY_TRANSFORM DEFAULT_CATEGORICAL_TRANSFORM = ONE_HOT_TRANSFORM # BigQuery Schema values supported INTEGER_SCHEMA = 'integer' FLOAT_SCHEMA = 'float' STRING_SCHEMA = 'string' NUMERIC_SCHEMA = [INTEGER_SCHEMA, FLOAT_SCHEMA] # Inception Checkpoint INCEPTION_V3_CHECKPOINT = 'gs://cloud-ml-data/img/flower_photos/inception_v3_2016_08_28.ckpt' INCEPTION_EXCLUDED_VARIABLES = ['InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step'] _img_buf = cStringIO.StringIO() Image.new('RGB', (16, 16)).save(_img_buf, 'jpeg') IMAGE_DEFAULT_STRING = base64.urlsafe_b64encode(_img_buf.getvalue()) IMAGE_BOTTLENECK_TENSOR_SIZE = 2048 IMAGE_HIDDEN_TENSOR_SIZE = int(IMAGE_BOTTLENECK_TENSOR_SIZE / 4) # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ # start of transform functions # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ def _scale(x, min_x_value, max_x_value, output_min, output_max): """Scale a column to [output_min, output_max]. Assumes the columns's range is [min_x_value, max_x_value]. If this is not true at training or prediction time, the output value of this scale could be outside the range [output_min, output_max]. Raises: ValueError: if min_x_value = max_x_value, as the column is constant. """ if round(min_x_value - max_x_value, 7) == 0: # There is something wrong with the data. # Why round to 7 places? It's the same as unittest's assertAlmostEqual. raise ValueError('In make_scale_tito, min_x_value == max_x_value') def _scale(x): min_x_valuef = tf.to_float(min_x_value) max_x_valuef = tf.to_float(max_x_value) output_minf = tf.to_float(output_min) output_maxf = tf.to_float(output_max) return ((((tf.to_float(x) - min_x_valuef) * (output_maxf - output_minf)) / (max_x_valuef - min_x_valuef)) + output_minf) return _scale(x) def _string_to_int(x, vocab): """Given a vocabulary and a string tensor `x`, maps `x` into an int tensor. Args: x: A `Column` representing a string value. vocab: list of strings. Returns: A `Column` where each string value is mapped to an integer representing its index in the vocab. Out of vocab values are mapped to len(vocab). """ def _map_to_int(x): """Maps string tensor into indexes using vocab. Args: x : a Tensor/SparseTensor of string. Returns: a Tensor/SparseTensor of indexes (int) of the same shape as x. """ table = lookup.index_table_from_tensor( vocab, default_value=len(vocab)) return table.lookup(x) return _map_to_int(x) # TODO(brandondura): update this to not depend on tf layer's feature column # 'sum' combiner in the future. def _tfidf(x, reduced_term_freq, vocab_size, corpus_size): """Maps the terms in x to their (1/doc_length) * inverse document frequency. Args: x: A `Column` representing int64 values (most likely that are the result of calling string_to_int on a tokenized string). 
reduced_term_freq: A dense tensor of shape (vocab_size,) that represents the count of the number of documents with each term. So vocab token i ( which is an int) occures in reduced_term_freq[i] examples in the corpus. This means reduced_term_freq should have a count for out-of-vocab tokens vocab_size: An int - the count of vocab used to turn the string into int64s including any out-of-vocab ids corpus_size: A scalar count of the number of documents in the corpus Returns: A `Column` where each int value is mapped to a double equal to (1 if that term appears in that row, 0 otherwise / the number of terms in that row) * the log of (the number of rows in `x` / (1 + the number of rows in `x` where the term appears at least once)) NOTE: This is intented to be used with the feature_column 'sum' combiner to arrive at the true term frequncies. """ def _map_to_vocab_range(x): """Enforces that the vocab_ids in x are positive.""" return tf.SparseTensor( indices=x.indices, values=tf.mod(x.values, vocab_size), dense_shape=x.dense_shape) def _map_to_tfidf(x): """Calculates the inverse document frequency of terms in the corpus. Args: x : a SparseTensor of int64 representing string indices in vocab. Returns: The tf*idf values """ # Add one to the reduced term freqnencies to avoid dividing by zero. idf = tf.log(tf.to_double(corpus_size) / ( 1.0 + tf.to_double(reduced_term_freq))) dense_doc_sizes = tf.to_double(tf.sparse_reduce_sum(tf.SparseTensor( indices=x.indices, values=tf.ones_like(x.values), dense_shape=x.dense_shape), 1)) # For every term in x, divide the idf by the doc size. # The two gathers both result in shape idf_over_doc_size = (tf.gather(idf, x.values) / tf.gather(dense_doc_sizes, x.indices[:, 0])) return tf.SparseTensor( indices=x.indices, values=idf_over_doc_size, dense_shape=x.dense_shape) cleaned_input = _map_to_vocab_range(x) weights = _map_to_tfidf(cleaned_input) return tf.to_float(weights) # TODO(brandondura): update this to not depend on tf layer's feature column # 'sum' combiner in the future. def _bag_of_words(x): """Computes bag of words weights Note the return type is a float sparse tensor, not a int sparse tensor. This is so that the output types batch tfidf, and any downstream transformation in tf layers during training can be applied to both. """ def _bow(x): """Comptue BOW weights. As tf layer's sum combiner is used, the weights can be just ones. Tokens are not summed together here. """ return tf.SparseTensor( indices=x.indices, values=tf.to_float(tf.ones_like(x.values)), dense_shape=x.dense_shape) return _bow(x) def _make_image_to_vec_tito(feature_name, tmp_dir=None, checkpoint=None): """Creates a tensor-in-tensor-out function that produces embeddings from image bytes. Image to embedding is implemented with Tensorflow's inception v3 model and a pretrained checkpoint. It returns 1x2048 'PreLogits' embeddings for each image. Args: feature_name: The name of the feature. Used only to identify the image tensors so we can get gradients for probe in image prediction explaining. tmp_dir: a local directory that is used for downloading the checkpoint. If non, a temp folder will be made and deleted. checkpoint: the inception v3 checkpoint gs or local path. If None, default checkpoint is used. Returns: a tensor-in-tensor-out function that takes image string tensor and returns embeddings. """ def _image_to_vec(image_str_tensor): def _decode_and_resize(image_tensor): """Decodes jpeg string, resizes it and returns a uint8 tensor.""" # These constants are set by Inception v3's expectations. 
height = 299 width = 299 channels = 3 image_tensor = tf.where(tf.equal(image_tensor, ''), IMAGE_DEFAULT_STRING, image_tensor) # Fork by whether image_tensor value is a file path, or a base64 encoded string. slash_positions = tf.equal(tf.string_split([image_tensor], delimiter="").values, '/') is_file_path = tf.cast(tf.count_nonzero(slash_positions), tf.bool) # The following two functions are required for tf.cond. Note that we can not replace them # with lambda. According to TF docs, if using inline lambda, both branches of condition # will be executed. The workaround is to use a function call. def _read_file(): return tf.read_file(image_tensor) def _decode_base64(): return tf.decode_base64(image_tensor) image = tf.cond(is_file_path, lambda: _read_file(), lambda: _decode_base64()) image = tf.image.decode_jpeg(image, channels=channels) image = tf.expand_dims(image, 0) image = tf.image.resize_bilinear(image, [height, width], align_corners=False) image = tf.squeeze(image, squeeze_dims=[0]) image = tf.cast(image, dtype=tf.uint8) return image # The CloudML Prediction API always "feeds" the Tensorflow graph with # dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar # strings because it cannot guarantee a batch of images would have # the same output size. We use tf.map_fn to give decode_jpeg a scalar # string from dynamic batches. image = tf.map_fn(_decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8) image = tf.image.convert_image_dtype(image, dtype=tf.float32) # "gradients_[feature_name]" will be used for computing integrated gradients. image = tf.identity(image, name='gradients_' + feature_name) image = tf.subtract(image, 0.5) inception_input = tf.multiply(image, 2.0) # Build Inception layers, which expect a tensor of type float from [-1, 1) # and shape [batch_size, height, width, channels]. with tf.contrib.slim.arg_scope(inception_v3_arg_scope()): _, end_points = inception_v3(inception_input, is_training=False) embeddings = end_points['PreLogits'] inception_embeddings = tf.squeeze(embeddings, [1, 2], name='SpatialSqueeze') return inception_embeddings def _tito_from_checkpoint(tito_in, checkpoint, exclude): """ Create an all-constants tito function from an original tito function. Given a tensor-in-tensor-out function which contains variables and a checkpoint path, create a new tensor-in-tensor-out function which includes only constants, and can be used in tft.map. """ def _tito_out(tensor_in): checkpoint_dir = tmp_dir if tmp_dir is None: checkpoint_dir = tempfile.mkdtemp() g = tf.Graph() with g.as_default(): si = tf.placeholder(dtype=tensor_in.dtype, shape=tensor_in.shape, name=tensor_in.op.name) so = tito_in(si) all_vars = tf.contrib.slim.get_variables_to_restore(exclude=exclude) saver = tf.train.Saver(all_vars) # Downloading the checkpoint from GCS to local speeds up saver.restore() a lot. 
checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint') with file_io.FileIO(checkpoint, 'r') as f_in, file_io.FileIO(checkpoint_tmp, 'w') as f_out: f_out.write(f_in.read()) with tf.Session() as sess: saver.restore(sess, checkpoint_tmp) output_graph_def = tf.graph_util.convert_variables_to_constants(sess, g.as_graph_def(), [so.op.name]) file_io.delete_file(checkpoint_tmp) if tmp_dir is None: shutil.rmtree(checkpoint_dir) tensors_out = tf.import_graph_def(output_graph_def, input_map={si.name: tensor_in}, return_elements=[so.name]) return tensors_out[0] return _tito_out if not checkpoint: checkpoint = INCEPTION_V3_CHECKPOINT return _tito_from_checkpoint(_image_to_vec, checkpoint, INCEPTION_EXCLUDED_VARIABLES) # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ # end of transform functions # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ def make_preprocessing_fn(output_dir, features, keep_target): """Makes a preprocessing function. Args: output_dir: folder path that contains the vocab and stats files. features: the features dict Returns: a function that takes a dict of input tensors """ def preprocessing_fn(inputs): """Preprocessing function. Args: inputs: dictionary of raw input tensors Returns: A dictionary of transformed tensors """ stats = json.loads( file_io.read_file_to_string( os.path.join(output_dir, STATS_FILE)).decode()) result = {} for name, transform in six.iteritems(features): transform_name = transform['transform'] source_column = transform['source_column'] if transform_name == KEY_TRANSFORM: transform_name = 'identity' elif transform_name == TARGET_TRANSFORM: if not keep_target: continue if file_io.file_exists(os.path.join(output_dir, VOCAB_ANALYSIS_FILE % source_column)): transform_name = 'one_hot' else: transform_name = 'identity' if transform_name == 'identity': result[name] = inputs[source_column] elif transform_name == 'scale': result[name] = _scale( inputs[name], min_x_value=stats['column_stats'][source_column]['min'], max_x_value=stats['column_stats'][source_column]['max'], output_min=transform.get('value', 1) * (-1), output_max=transform.get('value', 1)) elif transform_name in [ONE_HOT_TRANSFORM, EMBEDDING_TRANSFROM, MULTI_HOT_TRANSFORM, TFIDF_TRANSFORM, BOW_TRANSFORM]: vocab, ex_count = read_vocab_file( os.path.join(output_dir, VOCAB_ANALYSIS_FILE % source_column)) if transform_name == TFIDF_TRANSFORM: separator = transform.get('separator', ' ') tokens = tf.string_split(inputs[source_column], separator) ids = _string_to_int(tokens, vocab) weights = _tfidf( x=ids, reduced_term_freq=ex_count + [0], vocab_size=len(vocab) + 1, corpus_size=stats['num_examples']) result[name + '_ids'] = ids result[name + '_weights'] = weights elif transform_name == BOW_TRANSFORM: separator = transform.get('separator', ' ') tokens = tf.string_split(inputs[source_column], separator) ids = _string_to_int(tokens, vocab) weights = _bag_of_words(x=ids) result[name + '_ids'] = ids result[name + '_weights'] = weights elif transform_name == MULTI_HOT_TRANSFORM: separator = transform.get('separator', ' ') tokens = tf.string_split(inputs[source_column], separator) result[name] = _string_to_int(tokens, vocab) else: # ONE_HOT_TRANSFORM: making a dense vector is done at training # EMBEDDING_TRANSFROM: embedding vectors have to be done at training result[name] = 
_string_to_int(inputs[source_column], vocab) elif transform_name == IMAGE_TRANSFORM: make_image_to_vec_fn = _make_image_to_vec_tito( name, checkpoint=transform.get('checkpoint', None)) result[name] = make_image_to_vec_fn(inputs[source_column]) else: raise ValueError('unknown transform %s' % transform_name) return result return preprocessing_fn def get_transformed_feature_info(features, schema): """Returns information about the transformed features. Returns: Dict in the from {transformed_feature_name: {dtype: tf type, size: int or None}}. If the size is None, then the tensor is a sparse tensor. """ info = collections.defaultdict(dict) for name, transform in six.iteritems(features): transform_name = transform['transform'] source_column = transform['source_column'] if transform_name == IDENTITY_TRANSFORM: schema_type = next(col['type'].lower() for col in schema if col['name'] == source_column) if schema_type == FLOAT_SCHEMA: info[name]['dtype'] = tf.float32 elif schema_type == INTEGER_SCHEMA: info[name]['dtype'] = tf.int64 else: raise ValueError('itentity should only be applied to integer or float' 'columns, but was used on %s' % name) info[name]['size'] = 1 elif transform_name == SCALE_TRANSFORM: info[name]['dtype'] = tf.float32 info[name]['size'] = 1 elif transform_name == ONE_HOT_TRANSFORM: info[name]['dtype'] = tf.int64 info[name]['size'] = 1 elif transform_name == EMBEDDING_TRANSFROM: info[name]['dtype'] = tf.int64 info[name]['size'] = 1 elif transform_name == MULTI_HOT_TRANSFORM: info[name]['dtype'] = tf.int64 info[name]['size'] = None elif transform_name == BOW_TRANSFORM or transform_name == TFIDF_TRANSFORM: info[name + '_ids']['dtype'] = tf.int64 info[name + '_weights']['dtype'] = tf.float32 info[name + '_ids']['size'] = None info[name + '_weights']['size'] = None elif transform_name == KEY_TRANSFORM: schema_type = next(col['type'].lower() for col in schema if col['name'] == source_column) if schema_type == FLOAT_SCHEMA: info[name]['dtype'] = tf.float32 elif schema_type == INTEGER_SCHEMA: info[name]['dtype'] = tf.int64 else: info[name]['dtype'] = tf.string info[name]['size'] = 1 elif transform_name == TARGET_TRANSFORM: # If the input is a string, it gets converted to an int (id) schema_type = next(col['type'].lower() for col in schema if col['name'] == source_column) if schema_type in NUMERIC_SCHEMA: info[name]['dtype'] = tf.float32 else: info[name]['dtype'] = tf.int64 info[name]['size'] = 1 elif transform_name == IMAGE_TRANSFORM: info[name]['dtype'] = tf.float32 info[name]['size'] = IMAGE_BOTTLENECK_TENSOR_SIZE else: raise ValueError('Unknown transfrom %s' % transform_name) return info def csv_header_and_defaults(features, schema, stats, keep_target): """Gets csv header and default lists.""" target_name = get_target_name(features) if keep_target and not target_name: raise ValueError('Cannot find target transform') csv_header = [] record_defaults = [] for col in schema: if not keep_target and col['name'] == target_name: continue # Note that numerical key columns do not have a stats entry, hence the use # of get(col['name'], {}) csv_header.append(col['name']) if col['type'].lower() == INTEGER_SCHEMA: dtype = tf.int64 default = int(stats['column_stats'].get(col['name'], {}).get('mean', 0)) elif col['type'].lower() == FLOAT_SCHEMA: dtype = tf.float32 default = float(stats['column_stats'].get(col['name'], {}).get('mean', 0.0)) else: dtype = tf.string default = '' record_defaults.append(tf.constant([default], dtype=dtype)) return csv_header, record_defaults def 
build_csv_serving_tensors_for_transform_step(analysis_path, features, schema, stats, keep_target): """Builds a serving function starting from raw csv. This should only be used by transform.py (the transform step), and the For image columns, the image should be a base64 string encoding the image. The output of this function will transform that image to a 2048 long vector using the inception model. """ csv_header, record_defaults = csv_header_and_defaults(features, schema, stats, keep_target) placeholder = tf.placeholder(dtype=tf.string, shape=(None,), name='csv_input_placeholder') tensors = tf.decode_csv(placeholder, record_defaults) raw_features = dict(zip(csv_header, tensors)) transform_fn = make_preprocessing_fn(analysis_path, features, keep_target) transformed_tensors = transform_fn(raw_features) transformed_features = {} # Expand the dims of non-sparse tensors for k, v in six.iteritems(transformed_tensors): if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1: transformed_features[k] = tf.expand_dims(v, -1) else: transformed_features[k] = v return input_fn_utils.InputFnOps( transformed_features, None, {"csv_example": placeholder}) def build_csv_serving_tensors_for_training_step(analysis_path, features, schema, stats, keep_target): """Builds a serving function starting from raw csv, used at model export time. For image columns, the image should be a base64 string encoding the image. The output of this function will transform that image to a 2048 long vector using the inception model and then a fully connected net is attached to the 2048 long image embedding. """ transformed_features, _, placeholder_dict = build_csv_serving_tensors_for_transform_step( analysis_path=analysis_path, features=features, schema=schema, stats=stats, keep_target=keep_target) transformed_features = image_feature_engineering( features=features, feature_tensors_dict=transformed_features) return input_fn_utils.InputFnOps( transformed_features, None, placeholder_dict) def build_csv_transforming_training_input_fn(schema, features, stats, analysis_output_dir, raw_data_file_pattern, training_batch_size, num_epochs=None, randomize_input=False, min_after_dequeue=1, reader_num_threads=1, allow_smaller_final_batch=True): """Creates training input_fn that reads raw csv data and applies transforms. Args: schema: schema list features: features dict stats: stats dict analysis_output_dir: output folder from analysis raw_data_file_pattern: file path, or list of files training_batch_size: An int specifying the batch size to use. num_epochs: numer of epochs to read from the files. Use None to read forever. randomize_input: If true, the input rows are read out of order. This randomness is limited by the min_after_dequeue value. min_after_dequeue: Minimum number elements in the reading queue after a dequeue, used to ensure a level of mixing of elements. Only used if randomize_input is True. reader_num_threads: The number of threads enqueuing data. allow_smaller_final_batch: If false, fractional batches at the end of training or evaluation are not used. Returns: An input_fn suitable for training that reads raw csv training data and applies transforms. 
""" def raw_training_input_fn(): """Training input function that reads raw data and applies transforms.""" if isinstance(raw_data_file_pattern, six.string_types): filepath_list = [raw_data_file_pattern] else: filepath_list = raw_data_file_pattern files = [] for path in filepath_list: files.extend(file_io.get_matching_files(path)) filename_queue = tf.train.string_input_producer( files, num_epochs=num_epochs, shuffle=randomize_input) csv_id, csv_lines = tf.TextLineReader().read_up_to(filename_queue, training_batch_size) queue_capacity = (reader_num_threads + 3) * training_batch_size + min_after_dequeue if randomize_input: _, batch_csv_lines = tf.train.shuffle_batch( tensors=[csv_id, csv_lines], batch_size=training_batch_size, capacity=queue_capacity, min_after_dequeue=min_after_dequeue, enqueue_many=True, num_threads=reader_num_threads, allow_smaller_final_batch=allow_smaller_final_batch) else: _, batch_csv_lines = tf.train.batch( tensors=[csv_id, csv_lines], batch_size=training_batch_size, capacity=queue_capacity, enqueue_many=True, num_threads=reader_num_threads, allow_smaller_final_batch=allow_smaller_final_batch) csv_header, record_defaults = csv_header_and_defaults(features, schema, stats, keep_target=True) parsed_tensors = tf.decode_csv(batch_csv_lines, record_defaults, name='csv_to_tensors') raw_features = dict(zip(csv_header, parsed_tensors)) transform_fn = make_preprocessing_fn(analysis_output_dir, features, keep_target=True) transformed_tensors = transform_fn(raw_features) # Expand the dims of non-sparse tensors. This is needed by tf.learn. transformed_features = {} for k, v in six.iteritems(transformed_tensors): if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1: transformed_features[k] = tf.expand_dims(v, -1) else: transformed_features[k] = v # image_feature_engineering does not need to be called as images are not # supported in raw csv for training. # Remove the target tensor, and return it directly target_name = get_target_name(features) if not target_name or target_name not in transformed_features: raise ValueError('Cannot find target transform in features') transformed_target = transformed_features.pop(target_name) return transformed_features, transformed_target return raw_training_input_fn def build_tfexample_transfored_training_input_fn(schema, features, analysis_output_dir, raw_data_file_pattern, training_batch_size, num_epochs=None, randomize_input=False, min_after_dequeue=1, reader_num_threads=1, allow_smaller_final_batch=True): """Creates training input_fn that reads transformed tf.example files. Args: schema: schema list features: features dict analysis_output_dir: output folder from analysis raw_data_file_pattern: file path, or list of files training_batch_size: An int specifying the batch size to use. num_epochs: numer of epochs to read from the files. Use None to read forever. randomize_input: If true, the input rows are read out of order. This randomness is limited by the min_after_dequeue value. min_after_dequeue: Minimum number elements in the reading queue after a dequeue, used to ensure a level of mixing of elements. Only used if randomize_input is True. reader_num_threads: The number of threads enqueuing data. allow_smaller_final_batch: If false, fractional batches at the end of training or evaluation are not used. Returns: An input_fn suitable for training that reads transformed data in tf record files of tf.example. 
""" def transformed_training_input_fn(): """Training input function that reads transformed data.""" if isinstance(raw_data_file_pattern, six.string_types): filepath_list = [raw_data_file_pattern] else: filepath_list = raw_data_file_pattern files = [] for path in filepath_list: files.extend(file_io.get_matching_files(path)) filename_queue = tf.train.string_input_producer( files, num_epochs=num_epochs, shuffle=randomize_input) options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) ex_id, ex_str = tf.TFRecordReader(options=options).read_up_to( filename_queue, training_batch_size) queue_capacity = (reader_num_threads + 3) * training_batch_size + min_after_dequeue if randomize_input: _, batch_ex_str = tf.train.shuffle_batch( tensors=[ex_id, ex_str], batch_size=training_batch_size, capacity=queue_capacity, min_after_dequeue=min_after_dequeue, enqueue_many=True, num_threads=reader_num_threads, allow_smaller_final_batch=allow_smaller_final_batch) else: _, batch_ex_str = tf.train.batch( tensors=[ex_id, ex_str], batch_size=training_batch_size, capacity=queue_capacity, enqueue_many=True, num_threads=reader_num_threads, allow_smaller_final_batch=allow_smaller_final_batch) feature_spec = {} feature_info = get_transformed_feature_info(features, schema) for name, info in six.iteritems(feature_info): if info['size'] is None: feature_spec[name] = tf.VarLenFeature(dtype=info['dtype']) else: feature_spec[name] = tf.FixedLenFeature(shape=[info['size']], dtype=info['dtype']) parsed_tensors = tf.parse_example(batch_ex_str, feature_spec) # Expand the dims of non-sparse tensors. This is needed by tf.learn. transformed_features = {} for k, v in six.iteritems(parsed_tensors): if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1: transformed_features[k] = tf.expand_dims(v, -1) else: # Sparse tensor transformed_features[k] = v transformed_features = image_feature_engineering( features=features, feature_tensors_dict=transformed_features) # Remove the target tensor, and return it directly target_name = get_target_name(features) if not target_name or target_name not in transformed_features: raise ValueError('Cannot find target transform in features') transformed_target = transformed_features.pop(target_name) return transformed_features, transformed_target return transformed_training_input_fn def image_feature_engineering(features, feature_tensors_dict): """Add a hidden layer on image features. Args: features: features dict feature_tensors_dict: dict of feature-name: tensor """ engineered_features = {} for name, feature_tensor in six.iteritems(feature_tensors_dict): if name in features and features[name]['transform'] == IMAGE_TRANSFORM: with tf.name_scope(name, 'Wx_plus_b'): hidden = tf.contrib.layers.fully_connected( feature_tensor, IMAGE_HIDDEN_TENSOR_SIZE) engineered_features[name] = hidden else: engineered_features[name] = feature_tensor return engineered_features def get_target_name(features): for name, transform in six.iteritems(features): if transform['transform'] == TARGET_TRANSFORM: return name return None def read_vocab_file(file_path): """Reads a vocab file to memeory. Args: file_path: Each line of the vocab is in the form "token,example_count" Returns: Two lists, one for the vocab, and one for just the example counts. """ with file_io.FileIO(file_path, 'r') as f: vocab_pd = pd.read_csv( f, header=None, names=['vocab', 'count'], dtype=str, # Prevent pd from converting numerical categories. na_filter=False) # Prevent pd from converting 'NA' to a NaN. 
vocab = vocab_pd['vocab'].tolist() ex_count = vocab_pd['count'].astype(int).tolist() return vocab, ex_count ================================================ FILE: solutionbox/ml_workbench/tensorflow/trainer/task.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import json import math import multiprocessing import os import re import sys import six import tensorflow as tf from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.python.client import session as tf_session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.ops import resources from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver from tensorflow.python.util import compat from . import feature_transforms from . import feature_analysis # Constants for the Prediction Graph fetch tensors. PG_TARGET = 'target' # from input PG_REGRESSION_PREDICTED_TARGET = 'predicted' PG_CLASSIFICATION_FIRST_LABEL = 'predicted' PG_CLASSIFICATION_FIRST_SCORE = 'probability' PG_CLASSIFICATION_LABEL_TEMPLATE = 'predicted_%s' PG_CLASSIFICATION_SCORE_TEMPLATE = 'probability_%s' class DatalabParser(): """An arg parser that also prints package specific args with --datalab-help. When using Datalab magic's to run this trainer, it prints it's own help menu that describes the required options that are common to all trainers. In order to print just the options that are unique to this trainer, datalab calls this file with --datalab-help. This class implements --datalab-help by building a list of help string that only includes the unique parameters. """ def __init__(self, epilog=None, datalab_epilog=None): self.full_parser = argparse.ArgumentParser(epilog=epilog) self.datalab_help = [] self.datalab_epilog = datalab_epilog # Datalab help string self.full_parser.add_argument( '--datalab-help', action=self.make_datalab_help_action(), help='Show a smaller help message for DataLab only and exit') # The arguments added here are required to exist by Datalab's "%%ml train" magics. 
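# (For example, "--train" and "--eval" use action='append', so repeating the flag passes
# several file patterns to the trainer.)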
self.full_parser.add_argument( '--train', type=str, required=True, action='append', metavar='FILE') self.full_parser.add_argument( '--eval', type=str, required=True, action='append', metavar='FILE') self.full_parser.add_argument('--job-dir', type=str, required=True) self.full_parser.add_argument( '--analysis', type=str, metavar='ANALYSIS_OUTPUT_DIR', help=('Output folder of analysis. Should contain the schema, stats, and ' 'vocab files. Path must be on GCS if running cloud training. ' + 'If absent, --schema and --features must be provided and ' + 'the master trainer will do analysis locally.')) self.full_parser.add_argument( '--transform', action='store_true', default=False, help='If used, input data is raw csv that needs transformation. If analysis ' + 'is required to run in the trainer, this is automatically set to true.') self.full_parser.add_argument( '--schema', type=str, help='Schema of the training csv file. Only needed if analysis is required.') self.full_parser.add_argument( '--features', type=str, help='Feature transform config. Only needed if analysis is required.')
def make_datalab_help_action(self): """Custom action for --datalab-help. The action outputs the package-specific parameters, which become part of the "%%ml train" help string. """ datalab_help = self.datalab_help epilog = self.datalab_epilog class _CustomAction(argparse.Action): def __init__(self, option_strings, dest, help=None): super(_CustomAction, self).__init__( option_strings=option_strings, dest=dest, nargs=0, help=help) def __call__(self, parser, args, values, option_string=None): print('\n\n'.join(datalab_help)) if epilog: print(epilog) # We have printed all the help strings datalab needs. If we don't quit, it will complain about # missing required arguments later. quit() return _CustomAction
def add_argument(self, name, **kwargs): # Any argument added here is not required by Datalab, and so is unique # to this trainer. Add each argument to the main parser and the datalab help string. self.full_parser.add_argument(name, **kwargs) name = name.replace('--', '') # leading spaces are needed for datalab's help formatting. msg = ' ' + name + ': ' if 'help' in kwargs: msg += kwargs['help'] + ' ' if kwargs.get('required', False): msg += 'Required. ' else: msg += 'Optional. ' if 'choices' in kwargs: msg += 'One of ' + str(kwargs['choices']) + '. ' if 'default' in kwargs: msg += 'default: ' + str(kwargs['default']) + '.' self.datalab_help.append(msg)
def parse_known_args(self, args=None): return self.full_parser.parse_known_args(args=args)
def parse_arguments(argv): """Parse the command line arguments.""" parser = DatalabParser( epilog=('Note that if using a DNN model, --hidden-layer-size1=NUM, ' '--hidden-layer-size2=NUM, ..., is also required. '), datalab_epilog=(""" Note that if using a DNN model, hidden-layer-size1: NUM hidden-layer-size2: NUM ... is also required. """))
# HP parameters parser.add_argument( '--epsilon', type=float, default=0.0005, metavar='R', help='tf.train.AdamOptimizer epsilon. Only used in dnn models.') parser.add_argument( '--l1-regularization', type=float, default=0.0, metavar='R', help='L1 term for linear models.') parser.add_argument( '--l2-regularization', type=float, default=0.0, metavar='R', help='L2 term for linear models.')
# Model parameters parser.add_argument( '--model', required=True, choices=['linear_classification', 'linear_regression', 'dnn_classification', 'dnn_regression']) parser.add_argument( '--top-n', type=int, default=0, metavar='N', help=('For classification problems, the output graph will contain the ' 'labels and scores for the top n classes, and results will be in the form of ' '"predicted, predicted_2, ..., probability, probability_2, ...". ' 'If --top-n=0, then all labels and scores are returned in the form of ' '"predicted, class_name1, class_name2,...".'))
# HP parameters parser.add_argument( '--learning-rate', type=float, default=0.01, metavar='R', help='Optimizer learning rate.')
# Training input parameters parser.add_argument( '--max-steps', type=int, metavar='N', help='Maximum number of training steps to perform. If unspecified, will ' 'honor "max-epochs".') parser.add_argument( '--max-epochs', type=int, default=1000, metavar='N', help='Maximum number of training data epochs on which to train. If ' 'both "max-steps" and "max-epochs" are specified, the training ' 'job will run for "max-steps" or "num-epochs", whichever occurs ' 'first. If early stopping is enabled, training may also stop ' 'earlier.') parser.add_argument( '--train-batch-size', type=int, default=64, metavar='N', help='How many training examples are used per step. If num-epochs is ' 'used, the last batch may not be full.') parser.add_argument( '--eval-batch-size', type=int, default=64, metavar='N', help='Batch size during evaluation. Larger values increase performance ' 'but also increase peak memory usage on the master node. One pass ' 'over the full eval set is performed per evaluation run.') parser.add_argument( '--min-eval-frequency', type=int, default=1000, metavar='N', help='Minimum number of training steps between evaluations. Evaluation ' 'does not occur if no new checkpoint is available, hence, this is ' 'the minimum. If 0, the evaluation will only happen after training. ') parser.add_argument( '--early-stopping-num_evals', type=int, default=3, help='Stop training automatically when the results of the specified number of evals ' 'in a row show that model performance does not improve. Set to 0 to ' 'disable early stopping.') parser.add_argument( '--logging-level', choices=['error', 'warning', 'info'], help='The TF logging level. If absent, use info for cloud training ' 'and warning for local training.')
args, remaining_args = parser.parse_known_args(args=argv[1:])
# All HP parameters must be unique, so we need to support an unknown number # of --hidden-layer-size1=10 --hidden-layer-size2=10 ... # Look at remaining_args for hidden-layer-size\d+ to get the layer info. # Get number of layers pattern = re.compile('hidden-layer-size(\d+)') num_layers = 0 for other_arg in remaining_args: match = re.search(pattern, other_arg) if match: if int(match.group(1)) <= 0: raise ValueError('layer size must be a positive integer. Was given %s' % other_arg) num_layers = max(num_layers, int(match.group(1))) # Build a new parser so we catch unknown args and missing layer_sizes.
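# For example (illustrative values):
#   remaining_args = ['--hidden-layer-size1=100', '--hidden-layer-size2=50']
# makes the regex above find num_layers == 2, and the secondary parser built below then
# requires every size from 1 to num_layers, yielding args.hidden_layer_sizes == [100, 50].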
parser = argparse.ArgumentParser() for i in range(num_layers): parser.add_argument('--hidden-layer-size%s' % str(i + 1), type=int, required=True) layer_args = vars(parser.parse_args(args=remaining_args)) hidden_layer_sizes = [] for i in range(num_layers): key = 'hidden_layer_size%s' % str(i + 1) hidden_layer_sizes.append(layer_args[key]) assert len(hidden_layer_sizes) == num_layers args.hidden_layer_sizes = hidden_layer_sizes return args def is_linear_model(model_type): return model_type.startswith('linear_') def is_dnn_model(model_type): return model_type.startswith('dnn_') def is_regression_model(model_type): return model_type.endswith('_regression') def is_classification_model(model_type): return model_type.endswith('_classification') def build_feature_columns(features, stats, model_type): feature_columns = [] is_dnn = is_dnn_model(model_type) # Supported transforms: # for DNN # numerical number # one hot: sparse int column -> one_hot_column # ebmedding: sparse int column -> embedding_column # text: sparse int weighted column -> embedding_column # for linear # numerical number # one hot: sparse int column # ebmedding: sparse int column -> hash int # text: sparse int weighted column # It is unfortunate that tf.layers has different feature transforms if the # model is linear or DNN. This pacakge should not expose to the user that # we are using tf.layers. for name, transform in six.iteritems(features): transform_name = transform['transform'] source_column = transform['source_column'] if transform_name in feature_transforms.NUMERIC_TRANSFORMS: new_feature = tf.contrib.layers.real_valued_column(name, dimension=1) elif (transform_name == feature_transforms.ONE_HOT_TRANSFORM or transform_name == feature_transforms.MULTI_HOT_TRANSFORM): sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=stats['column_stats'][source_column]['vocab_size']) if is_dnn: new_feature = tf.contrib.layers.one_hot_column(sparse) else: new_feature = sparse elif transform_name == feature_transforms.EMBEDDING_TRANSFROM: if is_dnn: sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=stats['column_stats'][source_column]['vocab_size']) new_feature = tf.contrib.layers.embedding_column( sparse, dimension=transform['embedding_dim']) else: new_feature = tf.contrib.layers.sparse_column_with_hash_bucket( name, hash_bucket_size=transform['embedding_dim'], dtype=dtypes.int64) elif transform_name in feature_transforms.TEXT_TRANSFORMS: sparse_ids = tf.contrib.layers.sparse_column_with_integerized_feature( name + '_ids', bucket_size=stats['column_stats'][source_column]['vocab_size'], combiner='sum') sparse_weights = tf.contrib.layers.weighted_sparse_column( sparse_id_column=sparse_ids, weight_column_name=name + '_weights', dtype=dtypes.float32) if is_dnn: new_feature = tf.contrib.layers.one_hot_column(sparse_ids) dimension = int(math.log(stats['column_stats'][source_column]['vocab_size'])) + 1 new_feature = tf.contrib.layers.embedding_column( sparse_weights, dimension=dimension, combiner='sqrtn') else: new_feature = sparse_weights elif (transform_name == feature_transforms.TARGET_TRANSFORM or transform_name == feature_transforms.KEY_TRANSFORM): continue elif transform_name == feature_transforms.IMAGE_TRANSFORM: new_feature = tf.contrib.layers.real_valued_column( name, dimension=feature_transforms.IMAGE_HIDDEN_TENSOR_SIZE) else: raise ValueError('Unknown transfrom %s' % transform_name) feature_columns.append(new_feature) return feature_columns def recursive_copy(src_dir, 
dest_dir): """Copy the contents of src_dir into the folder dest_dir. Args: src_dir: gsc or local path. dest_dir: gcs or local path. """ file_io.recursive_create_dir(dest_dir) for file_name in file_io.list_directory(src_dir): old_path = os.path.join(src_dir, file_name) new_path = os.path.join(dest_dir, file_name) if file_io.is_directory(old_path): recursive_copy(old_path, new_path) else: file_io.copy(old_path, new_path, overwrite=True) def make_prediction_output_tensors(args, features, input_ops, model_fn_ops, keep_target): """Makes the final prediction output layer.""" target_name = feature_transforms.get_target_name(features) key_names = get_key_names(features) outputs = {} outputs.update({key_name: tf.squeeze(input_ops.features[key_name]) for key_name in key_names}) if is_classification_model(args.model): # build maps from ints to the origional categorical strings. class_names = read_vocab(args, target_name) table = tf.contrib.lookup.index_to_string_table_from_tensor( mapping=class_names, default_value='UNKNOWN') # Get the label of the input target. if keep_target: input_target_label = table.lookup(input_ops.features[target_name]) outputs[PG_TARGET] = tf.squeeze(input_target_label) # TODO(brandondutra): get the score of the target label too. probabilities = model_fn_ops.predictions['probabilities'] # if top_n == 0, this means use all the classes. We will use class names as # probabilities labels. if args.top_n == 0: predicted_index = tf.argmax(probabilities, axis=1) predicted = table.lookup(predicted_index) outputs.update({PG_CLASSIFICATION_FIRST_LABEL: predicted}) probabilities_list = tf.unstack(probabilities, axis=1) for class_name, p in zip(class_names, probabilities_list): outputs[class_name] = p else: top_n = args.top_n # get top k labels and their scores. (top_k_values, top_k_indices) = tf.nn.top_k(probabilities, k=top_n) top_k_labels = table.lookup(tf.to_int64(top_k_indices)) # Write the top_k values using 2*top_n columns. num_digits = int(math.ceil(math.log(top_n, 10))) if num_digits == 0: num_digits = 1 for i in range(0, top_n): # Pad i based on the size of k. So if k = 100, i = 23 -> i = '023'. This # makes sorting the columns easy. padded_i = str(i + 1).zfill(num_digits) if i == 0: label_alias = PG_CLASSIFICATION_FIRST_LABEL else: label_alias = PG_CLASSIFICATION_LABEL_TEMPLATE % padded_i label_tensor_name = (tf.squeeze( tf.slice(top_k_labels, [0, i], [tf.shape(top_k_labels)[0], 1]))) if i == 0: score_alias = PG_CLASSIFICATION_FIRST_SCORE else: score_alias = PG_CLASSIFICATION_SCORE_TEMPLATE % padded_i score_tensor_name = (tf.squeeze( tf.slice(top_k_values, [0, i], [tf.shape(top_k_values)[0], 1]))) outputs.update({label_alias: label_tensor_name, score_alias: score_tensor_name}) else: if keep_target: outputs[PG_TARGET] = tf.squeeze(input_ops.features[target_name]) scores = model_fn_ops.predictions['scores'] outputs[PG_REGRESSION_PREDICTED_TARGET] = tf.squeeze(scores) return outputs # This function is strongly based on # tensorflow/contrib/learn/python/learn/estimators/estimator.py:export_savedmodel() # The difference is we need to modify estimator's output layer. def make_export_strategy( args, keep_target, assets_extra, features, schema, stats): """Makes prediction graph that takes json input. Args: args: command line args keep_target: If ture, target column is returned in prediction graph. 
Target column must also exist in input data assets_extra: other fiels to copy to the output folder job_dir: root job folder features: features dict schema: schema list stats: stats dict """ target_name = feature_transforms.get_target_name(features) csv_header = [col['name'] for col in schema] if not keep_target: csv_header.remove(target_name) def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None): with ops.Graph().as_default() as g: contrib_variables.create_global_step(g) input_ops = feature_transforms.build_csv_serving_tensors_for_training_step( args.analysis, features, schema, stats, keep_target) model_fn_ops = estimator._call_model_fn(input_ops.features, None, model_fn_lib.ModeKeys.INFER) output_fetch_tensors = make_prediction_output_tensors( args=args, features=features, input_ops=input_ops, model_fn_ops=model_fn_ops, keep_target=keep_target) # Don't use signature_def_utils.predict_signature_def as that renames # tensor names if there is only 1 input/output tensor! signature_inputs = {key: tf.saved_model.utils.build_tensor_info(tensor) for key, tensor in six.iteritems(input_ops.default_inputs)} signature_outputs = {key: tf.saved_model.utils.build_tensor_info(tensor) for key, tensor in six.iteritems(output_fetch_tensors)} signature_def_map = { 'serving_default': signature_def_utils.build_signature_def( signature_inputs, signature_outputs, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)} if not checkpoint_path: # Locate the latest checkpoint checkpoint_path = saver.latest_checkpoint(estimator._model_dir) if not checkpoint_path: raise ValueError("Couldn't find trained model at %s." % estimator._model_dir) export_dir = saved_model_export_utils.get_timestamped_export_dir( export_dir_base) if (model_fn_ops.scaffold is not None and model_fn_ops.scaffold.saver is not None): saver_for_restore = model_fn_ops.scaffold.saver else: saver_for_restore = saver.Saver(sharded=True) with tf_session.Session('') as session: saver_for_restore.restore(session, checkpoint_path) init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), tf.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) builder.add_meta_graph_and_variables( session, [tag_constants.SERVING], signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), legacy_init_op=init_op) builder.save(False) # Add the extra assets if assets_extra: assets_extra_path = os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets.extra')) for dest_relative, source in assets_extra.items(): dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) dest_path = os.path.dirname(dest_absolute) file_io.recursive_create_dir(dest_path) file_io.copy(source, dest_absolute) # only keep the last 3 models saved_model_export_utils.garbage_collect_exports( export_dir_base, exports_to_keep=3) # save the last model to the model folder. 
# export_dir_base = A/B/intermediate_models/ if keep_target: final_dir = os.path.join(args.job_dir, 'evaluation_model') else: final_dir = os.path.join(args.job_dir, 'model') if file_io.is_directory(final_dir): file_io.delete_recursively(final_dir) file_io.recursive_create_dir(final_dir) recursive_copy(export_dir, final_dir) return export_dir if keep_target: intermediate_dir = 'intermediate_evaluation_models' else: intermediate_dir = 'intermediate_prediction_models' return export_strategy.ExportStrategy(intermediate_dir, export_fn) def get_estimator(args, output_dir, features, stats, target_vocab_size): # Check layers used for dnn models. if is_dnn_model(args.model) and not args.hidden_layer_sizes: raise ValueError('--hidden-layer-size* must be used with DNN models') if is_linear_model(args.model) and args.hidden_layer_sizes: raise ValueError('--hidden-layer-size* cannot be used with linear models') # Build tf.learn features feature_columns = build_feature_columns(features, stats, args.model) # Set how often to run checkpointing in terms of steps. config = tf.contrib.learn.RunConfig( save_checkpoints_steps=args.min_eval_frequency) train_dir = os.path.join(output_dir, 'train') if args.model == 'dnn_regression': estimator = tf.contrib.learn.DNNRegressor( feature_columns=feature_columns, hidden_units=args.hidden_layer_sizes, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model == 'linear_regression': estimator = tf.contrib.learn.LinearRegressor( feature_columns=feature_columns, config=config, model_dir=train_dir, optimizer=tf.train.FtrlOptimizer( args.learning_rate, l1_regularization_strength=args.l1_regularization, l2_regularization_strength=args.l2_regularization)) elif args.model == 'dnn_classification': estimator = tf.contrib.learn.DNNClassifier( feature_columns=feature_columns, hidden_units=args.hidden_layer_sizes, n_classes=target_vocab_size, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model == 'linear_classification': estimator = tf.contrib.learn.LinearClassifier( feature_columns=feature_columns, n_classes=target_vocab_size, config=config, model_dir=train_dir, optimizer=tf.train.FtrlOptimizer( args.learning_rate, l1_regularization_strength=args.l1_regularization, l2_regularization_strength=args.l2_regularization)) else: raise ValueError('bad --model-type value') return estimator def read_vocab(args, column_name): """Reads a vocab file if it exists. Args: args: command line flags column_name: name of column to that has a vocab file. Returns: List of vocab words or [] if the vocab file is not found. """ vocab_path = os.path.join(args.analysis, feature_transforms.VOCAB_ANALYSIS_FILE % column_name) if not file_io.file_exists(vocab_path): return [] vocab, _ = feature_transforms.read_vocab_file(vocab_path) return vocab def get_key_names(features): names = [] for name, transform in six.iteritems(features): if transform['transform'] == feature_transforms.KEY_TRANSFORM: names.append(name) return names def read_json_file(file_path): if not file_io.file_exists(file_path): raise ValueError('File not found: %s' % file_path) return json.loads(file_io.read_file_to_string(file_path).decode()) def get_experiment_fn(args): """Builds the experiment function for learn_runner.run. Args: args: the command line args Returns: A function that returns a tf.learn experiment object. 
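Example (illustrative usage): main() below passes the returned function straight to the runner,
e.g. learn_runner.run(experiment_fn=get_experiment_fn(args), output_dir=args.job_dir).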
""" def get_experiment(output_dir): # Read schema, input features, and transforms. schema_path_with_target = os.path.join(args.analysis, feature_transforms.SCHEMA_FILE) features_path = os.path.join(args.analysis, feature_transforms.FEATURES_FILE) stats_path = os.path.join(args.analysis, feature_transforms.STATS_FILE) schema = read_json_file(schema_path_with_target) features = read_json_file(features_path) stats = read_json_file(stats_path) target_column_name = feature_transforms.get_target_name(features) if not target_column_name: raise ValueError('target missing from features file.') # Make a copy of the schema file without the target column. schema_without_target = [col for col in schema if col['name'] != target_column_name] schema_path_without_target = os.path.join(args.job_dir, 'schema_without_target.json') file_io.recursive_create_dir(args.job_dir) file_io.write_string_to_file(schema_path_without_target, json.dumps(schema_without_target, indent=2)) # Make list of files to save with the trained model. additional_assets_with_target = { feature_transforms.FEATURES_FILE: features_path, feature_transforms.SCHEMA_FILE: schema_path_with_target} additional_assets_without_target = { feature_transforms.FEATURES_FILE: features_path, feature_transforms.SCHEMA_FILE: schema_path_without_target} # Get the model to train. target_vocab = read_vocab(args, target_column_name) estimator = get_estimator(args, output_dir, features, stats, len(target_vocab)) export_strategy_csv_notarget = make_export_strategy( args=args, keep_target=False, assets_extra=additional_assets_without_target, features=features, schema=schema, stats=stats) export_strategy_csv_target = make_export_strategy( args=args, keep_target=True, assets_extra=additional_assets_with_target, features=features, schema=schema, stats=stats) # Build readers for training. if args.transform: if any(v['transform'] == feature_transforms.IMAGE_TRANSFORM for k, v in six.iteritems(features)): raise ValueError('"image_to_vec" transform requires transformation step. 
' + 'Cannot train from raw data.') input_reader_for_train = feature_transforms.build_csv_transforming_training_input_fn( schema=schema, features=features, stats=stats, analysis_output_dir=args.analysis, raw_data_file_pattern=args.train, training_batch_size=args.train_batch_size, num_epochs=args.max_epochs, randomize_input=True, min_after_dequeue=10, reader_num_threads=multiprocessing.cpu_count()) input_reader_for_eval = feature_transforms.build_csv_transforming_training_input_fn( schema=schema, features=features, stats=stats, analysis_output_dir=args.analysis, raw_data_file_pattern=args.eval, training_batch_size=args.eval_batch_size, num_epochs=1, randomize_input=False, reader_num_threads=multiprocessing.cpu_count()) else: input_reader_for_train = feature_transforms.build_tfexample_transfored_training_input_fn( schema=schema, features=features, analysis_output_dir=args.analysis, raw_data_file_pattern=args.train, training_batch_size=args.train_batch_size, num_epochs=args.max_epochs, randomize_input=True, min_after_dequeue=10, reader_num_threads=multiprocessing.cpu_count()) input_reader_for_eval = feature_transforms.build_tfexample_transfored_training_input_fn( schema=schema, features=features, analysis_output_dir=args.analysis, raw_data_file_pattern=args.eval, training_batch_size=args.eval_batch_size, num_epochs=1, randomize_input=False, reader_num_threads=multiprocessing.cpu_count()) if args.early_stopping_num_evals == 0: train_monitors = None else: if is_classification_model(args.model): early_stop_monitor = tf.contrib.learn.monitors.ValidationMonitor( input_fn=input_reader_for_eval, every_n_steps=args.min_eval_frequency, early_stopping_rounds=(args.early_stopping_num_evals * args.min_eval_frequency), early_stopping_metric='accuracy', early_stopping_metric_minimize=False) else: early_stop_monitor = tf.contrib.learn.monitors.ValidationMonitor( input_fn=input_reader_for_eval, every_n_steps=args.min_eval_frequency, early_stopping_rounds=(args.early_stopping_num_evals * args.min_eval_frequency)) train_monitors = [early_stop_monitor] return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_reader_for_train, eval_input_fn=input_reader_for_eval, train_steps=args.max_steps, train_monitors=train_monitors, export_strategies=[export_strategy_csv_notarget, export_strategy_csv_target], min_eval_frequency=args.min_eval_frequency, eval_steps=None) # Return a function to create an Experiment. return get_experiment def local_analysis(args): if args.analysis: # Already analyzed. 
return if not args.schema or not args.features: raise ValueError('Either --analysis, or both --schema and --features must be provided.') tf_config = json.loads(os.environ.get('TF_CONFIG', '{}')) cluster_spec = tf_config.get('cluster', {}) if len(cluster_spec.get('worker', [])) > 0: raise ValueError('If "schema" and "features" are provided, local analysis will run and ' + 'only BASIC scale-tier (no worker nodes) is supported.') if cluster_spec and not (args.schema.startswith('gs://') and args.features.startswith('gs://')): raise ValueError('Cloud trainer requires GCS paths for --schema and --features.') print('Running analysis.') schema = json.loads(file_io.read_file_to_string(args.schema).decode()) features = json.loads(file_io.read_file_to_string(args.features).decode()) args.analysis = os.path.join(args.job_dir, 'analysis') args.transform = True file_io.recursive_create_dir(args.analysis) feature_analysis.run_local_analysis(args.analysis, args.train, schema, features) print('Analysis done.')
def set_logging_level(args): if 'TF_CONFIG' in os.environ: tf.logging.set_verbosity(tf.logging.INFO) else: tf.logging.set_verbosity(tf.logging.ERROR) if args.logging_level == 'error': tf.logging.set_verbosity(tf.logging.ERROR) elif args.logging_level == 'warning': tf.logging.set_verbosity(tf.logging.WARN) elif args.logging_level == 'info': tf.logging.set_verbosity(tf.logging.INFO)
def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) local_analysis(args) set_logging_level(args) # Suppress TensorFlow debugging info. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' learn_runner.run( experiment_fn=get_experiment_fn(args), output_dir=args.job_dir)
if __name__ == '__main__': main()
================================================ FILE: solutionbox/ml_workbench/tensorflow/transform.py ================================================
# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# Flake8 cannot disable a warning for the file. Flake8 does not like beam code # and reports many 'W503 line break before binary operator' errors. So turn off # flake8 for this file. # flake8: noqa
from __future__ import absolute_import from __future__ import division from __future__ import print_function
import argparse import datetime import json import logging import os import sys import apache_beam as beam import textwrap
def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments including program name. Returns: The parsed arguments as returned by argparse.ArgumentParser. """ parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=textwrap.dedent("""\ Runs preprocessing on raw data for TensorFlow training. This script applies some transformations to raw data to improve training performance. Some data transformations can be expensive, such as the tf-idf text column transformation. During training, the same raw data row might be used multiple times to train a model.
This means the same transformations are applied to the same data row multiple times. This can be very inefficient, so this script applies partial transformations to the raw data and writes an intermediate preprocessed datasource to disk for training. Running this transformation step is required for two usage paths: 1) If the img_url_to_vec transform is used. This is because preprocessing as image is expensive and TensorFlow cannot easily read raw image files during training. 2) If the raw data is in BigQuery. TensorFlow cannot read from a BigQuery source. Running this transformation step is recommended if a text transform is used (like tf-idf or bag-of-words), and the text value for each row is very long. Running this transformation step may not have an interesting training performance impact if the transforms are all simple like scaling numerical values.""")) source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument( '--csv', metavar='FILE', required=False, action='append', help='CSV data to transform.') source_group.add_argument( '--bigquery', metavar='PROJECT_ID.DATASET.TABLE_NAME', type=str, required=False, help=('Must be in the form `project.dataset.table_name`. BigQuery ' 'data to transform')) parser.add_argument( '--analysis', metavar='ANALYSIS_OUTPUT_DIR', required=True, help='The output folder of analyze') parser.add_argument( '--prefix', metavar='OUTPUT_FILENAME_PREFIX', required=True, type=str) parser.add_argument( '--output', metavar='DIR', default=None, required=True, help=('Google Cloud Storage or Local directory in which ' 'to place outputs.')) parser.add_argument( '--shuffle', action='store_true', default=False, help='If used, data source is shuffled. This is recommended for training data.') parser.add_argument( '--batch-size', metavar='N', type=int, default=100, help='Larger values increase performance and peak memory usage.') cloud_group = parser.add_argument_group( title='Cloud Parameters', description='These parameters are only used if --cloud is used.') cloud_group.add_argument( '--cloud', action='store_true', help='Run preprocessing on the cloud.') cloud_group.add_argument( '--job-name', type=str, help='Unique dataflow job name.') cloud_group.add_argument( '--project-id', help='The project to which the job will be submitted.') cloud_group.add_argument( '--num-workers', metavar='N', type=int, default=0, help='Set to 0 to use the default size determined by the Dataflow service.') cloud_group.add_argument( '--worker-machine-type', metavar='NAME', type=str, help='A machine name from https://cloud.google.com/compute/docs/machine-types. ' ' If not given, the service uses the default machine type.') cloud_group.add_argument( '--async', action='store_true', help='If used, this script returns before the dataflow job is completed.') args = parser.parse_args(args=argv[1:]) if args.cloud and not args.project_id: raise ValueError('--project-id is needed for --cloud') if args.async and not args.cloud: raise ValueError('--async should only be used with --cloud') if not args.job_name: args.job_name = ('dataflow-job-{}'.format( datetime.datetime.now().strftime('%Y%m%d%H%M%S'))) return args @beam.ptransform_fn def shuffle(pcoll): # pylint: disable=invalid-name import random return (pcoll | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x)) | 'GroupByRandom' >> beam.GroupByKey() | 'DropRandom' >> beam.FlatMap(lambda (k, vs): vs)) def image_transform_columns(features): """Returns a list of columns that prepare_image_transforms() should run on. 
Because of beam + pickle, IMAGE_URL_TO_VEC_TRANSFORM cannot be used inside of a beam function, so we extract the columns prepare_image_transforms() should run on outside of beam. """ import six from trainer import feature_transforms img_cols = [] for name, transform in six.iteritems(features): if transform['transform'] == feature_transforms.IMAGE_TRANSFORM: img_cols.append(name) return img_cols def prepare_image_transforms(element, image_columns): """Replace an images url with its jpeg bytes. Args: element: one input row, as a dict image_columns: list of columns that are image paths Return: element, where each image file path has been replaced by a base64 image. """ import base64 import cStringIO from PIL import Image from tensorflow.python.lib.io import file_io as tf_file_io from apache_beam.metrics import Metrics img_error_count = Metrics.counter('main', 'ImgErrorCount') img_missing_count = Metrics.counter('main', 'ImgMissingCount') for name in image_columns: uri = element[name] if not uri: img_missing_count.inc() continue try: with tf_file_io.FileIO(uri, 'r') as f: img = Image.open(f).convert('RGB') # A variety of different calling libraries throw different exceptions here. # They all correspond to an unreadable file so we treat them equivalently. # pylint: disable broad-except except Exception as e: logging.exception('Error processing image %s: %s', uri, str(e)) img_error_count.inc() return # Convert to desired format and output. output = cStringIO.StringIO() img.save(output, 'jpeg') element[name] = base64.urlsafe_b64encode(output.getvalue()) return element class EmitAsBatchDoFn(beam.DoFn): """A DoFn that buffers the records and emits them batch by batch.""" def __init__(self, batch_size): """Constructor of EmitAsBatchDoFn beam.DoFn class. Args: batch_size: the max size we want to buffer the records before emitting. """ self._batch_size = batch_size self._cached = [] def process(self, element): self._cached.append(element) if len(self._cached) >= self._batch_size: emit = self._cached self._cached = [] yield emit def finish_bundle(self, element=None): from apache_beam.transforms import window from apache_beam.utils.windowed_value import WindowedValue if len(self._cached) > 0: # pylint: disable=g-explicit-length-test yield WindowedValue(self._cached, -1, [window.GlobalWindow()]) class TransformFeaturesDoFn(beam.DoFn): """Converts raw data into transformed data.""" def __init__(self, analysis_output_dir, features, schema, stats): self._analysis_output_dir = analysis_output_dir self._features = features self._schema = schema self._stats = stats self._session = None def start_bundle(self, element=None): """Build the transfromation graph once.""" import tensorflow as tf from trainer import feature_transforms g = tf.Graph() session = tf.Session(graph=g) # Build the transformation graph with g.as_default(): transformed_features, _, placeholders = ( feature_transforms.build_csv_serving_tensors_for_transform_step( analysis_path=self._analysis_output_dir, features=self._features, schema=self._schema, stats=self._stats, keep_target=True)) session.run(tf.tables_initializer()) self._session = session self._transformed_features = transformed_features self._input_placeholder_tensor = placeholders['csv_example'] def finish_bundle(self, element=None): self._session.close() def process(self, element): """Run the transformation graph on batched input data Args: element: list of csv strings, representing one batch input to the TF graph. Returns: dict containing the transformed data. Results are un-batched. 
Sparse tensors are converted to lists. """ import apache_beam as beam import six import tensorflow as tf # This function is invoked by a separate sub-process so setting the logging level # does not affect Datalab's kernel process. tf.logging.set_verbosity(tf.logging.ERROR) try: clean_element = [] for line in element: clean_element.append(line.rstrip()) # batch_result is list of numpy arrays with batch_size many rows. batch_result = self._session.run( fetches=self._transformed_features, feed_dict={self._input_placeholder_tensor: clean_element}) # ex batch_result. # Dense tensor: {'col1': array([[batch_1], [batch_2]])} # Sparse tensor: {'col1': tf.SparseTensorValue( # indices=array([[batch_1, 0], [batch_1, 1], ..., # [batch_2, 0], [batch_2, 1], ...]], # values=array[value, value, value, ...])} # Unbatch the results. for i in range(len(clean_element)): transformed_features = {} for name, value in six.iteritems(batch_result): if isinstance(value, tf.SparseTensorValue): batch_i_indices = value.indices[:, 0] == i batch_i_values = value.values[batch_i_indices] transformed_features[name] = batch_i_values.tolist() else: transformed_features[name] = value[i].tolist() yield transformed_features except Exception as e: # pylint: disable=broad-except yield beam.pvalue.TaggedOutput('errors', (str(e), element)) def decode_csv(csv_string, column_names): """Parse a csv line into a dict. Args: csv_string: a csv string. May contain missing values "a,,c" column_names: list of column names Returns: Dict of {column_name, value_from_csv}. If there are missing values, value_from_csv will be ''. """ import csv r = next(csv.reader([csv_string])) if len(r) != len(column_names): raise ValueError('csv line %s does not have %d columns' % (csv_string, len(column_names))) return {k: v for k, v in zip(column_names, r)} def encode_csv(data_dict, column_names): """Builds a csv string. Args: data_dict: dict of {column_name: 1 value} column_names: list of column names Returns: A csv string version of data_dict """ import csv import six values = [str(data_dict[x]) for x in column_names] str_buff = six.StringIO() writer = csv.writer(str_buff, lineterminator='') writer.writerow(values) return str_buff.getvalue() def serialize_example(transformed_json_data, info_dict): """Makes a serialized tf.example. Args: transformed_json_data: dict of transformed data. info_dict: output of feature_transforms.get_transfrormed_feature_info() Returns: The serialized tf.example version of transformed_json_data. 
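Example (illustrative, with a hypothetical column name): given
transformed_json_data={'num1': [0.5]} and info_dict={'num1': {'dtype': tf.float32}},
the result is a serialized tf.train.Example whose 'num1' feature is a float_list
holding the single value 0.5.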
""" import six import tensorflow as tf def _make_int64_list(x): return tf.train.Feature(int64_list=tf.train.Int64List(value=x)) def _make_bytes_list(x): return tf.train.Feature(bytes_list=tf.train.BytesList(value=x)) def _make_float_list(x): return tf.train.Feature(float_list=tf.train.FloatList(value=x)) if sorted(six.iterkeys(transformed_json_data)) != sorted(six.iterkeys(info_dict)): raise ValueError('Keys do not match %s, %s' % (list(six.iterkeys(transformed_json_data)), list(six.iterkeys(info_dict)))) ex_dict = {} for name, info in six.iteritems(info_dict): if info['dtype'] == tf.int64: ex_dict[name] = _make_int64_list(transformed_json_data[name]) elif info['dtype'] == tf.float32: ex_dict[name] = _make_float_list(transformed_json_data[name]) elif info['dtype'] == tf.string: ex_dict[name] = _make_bytes_list(transformed_json_data[name]) else: raise ValueError('Unsupported data type %s' % info['dtype']) ex = tf.train.Example(features=tf.train.Features(feature=ex_dict)) return ex.SerializeToString() def preprocess(pipeline, args): """Transfrom csv data into transfromed tf.example files. Outline: 1) read the input data (as csv or bigquery) into a dict format 2) replace image paths with base64 encoded image files 3) build a csv input string with images paths replaced with base64. This matches the serving csv that a trained model would expect. 4) batch the csv strings 5) run the transformations 6) write the results to tf.example files and save any errors. """ from tensorflow.python.lib.io import file_io from trainer import feature_transforms schema = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.SCHEMA_FILE)).decode()) features = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.FEATURES_FILE)).decode()) stats = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.STATS_FILE)).decode()) column_names = [col['name'] for col in schema] if args.csv: all_files = [] for i, file_pattern in enumerate(args.csv): all_files.append(pipeline | ('ReadCSVFile%d' % i) >> beam.io.ReadFromText(file_pattern)) raw_data = ( all_files | 'MergeCSVFiles' >> beam.Flatten() | 'ParseCSVData' >> beam.Map(decode_csv, column_names)) else: columns = ', '.join(column_names) query = 'SELECT {columns} FROM `{table}`'.format(columns=columns, table=args.bigquery) raw_data = ( pipeline | 'ReadBiqQueryData' >> beam.io.Read(beam.io.BigQuerySource(query=query, use_standard_sql=True))) # Note that prepare_image_transforms does not make embeddings, it justs reads # the image files and converts them to byte stings. TransformFeaturesDoFn() # will make the image embeddings. 
image_columns = image_transform_columns(features) clean_csv_data = ( raw_data | 'PreprocessTransferredLearningTransformations' >> beam.Map(prepare_image_transforms, image_columns) | 'BuildCSVString' >> beam.Map(encode_csv, column_names)) if args.shuffle: clean_csv_data = clean_csv_data | 'ShuffleData' >> shuffle() transform_dofn = TransformFeaturesDoFn(args.analysis, features, schema, stats) (transformed_data, errors) = ( clean_csv_data | 'Batch Input' >> beam.ParDo(EmitAsBatchDoFn(args.batch_size)) | 'Run TF Graph on Batches' >> beam.ParDo(transform_dofn).with_outputs('errors', main='main')) _ = (transformed_data | 'SerializeExamples' >> beam.Map(serialize_example, feature_transforms.get_transformed_feature_info(features, schema)) | 'WriteExamples' >> beam.io.WriteToTFRecord( os.path.join(args.output, args.prefix), file_name_suffix='.tfrecord.gz')) _ = (errors | 'WriteErrors' >> beam.io.WriteToText( os.path.join(args.output, 'errors_' + args.prefix), file_name_suffix='.txt')) def main(argv=None): """Run Preprocessing as a Dataflow.""" args = parse_arguments(sys.argv if argv is None else argv) temp_dir = os.path.join(args.output, 'tmp') if args.cloud: pipeline_name = 'DataflowRunner' else: pipeline_name = 'DirectRunner' # Suppress TF warnings. os.environ['TF_CPP_MIN_LOG_LEVEL']='3' options = { 'job_name': args.job_name, 'temp_location': temp_dir, 'project': args.project_id, 'setup_file': os.path.abspath(os.path.join( os.path.dirname(__file__), 'setup.py')), } if args.num_workers: options['num_workers'] = args.num_workers if args.worker_machine_type: options['worker_machine_type'] = args.worker_machine_type pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline(pipeline_name, options=pipeline_options) preprocess(pipeline=p, args=args) pipeline_result = p.run() if not args.async: pipeline_result.wait_until_finish() if args.async and args.cloud: print('View job at https://console.developers.google.com/dataflow/job/%s?project=%s' % (pipeline_result.job_id(), args.project_id)) if __name__ == '__main__': main() ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/run_all.sh ================================================ #! /bin/bash set -e echo '*** Running tensorflow test_analyze.py ***' python test_analyze.py --verbose echo '*** Running tensorflow test_feature_transforms.py ***' python test_feature_transforms.py --verbose echo '*** Running tensorflow test_transform.py ***' python test_transform.py --verbose echo '*** Running tensorflow test_training.py ***' python test_training.py --verbose echo 'Finished tensorflow run_all.sh!' ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/test_analyze.py ================================================ from __future__ import absolute_import from __future__ import print_function import json import os import shutil import subprocess import sys import tempfile import uuid import unittest import pandas as pd import six from tensorflow.python.lib.io import file_io import google.datalab as dl import google.datalab.bigquery as bq import google.datalab.storage as storage # To make 'import analyze' work without installing it. 
CODE_PATH = os.path.abspath( os.path.join(os.path.dirname(__file__), '..', '', 'tensorflow')) sys.path.append(CODE_PATH) from trainer import feature_analysis as feature_analysis # noqa: E303 import analyze # noqa: E303 # TODO: travis tests failed because sometimes a VM has gcloud signed-in # (maybe due to failed cleanup) with default project set and BQ is not enabled. # In that case the cloud tests will fail. Disable it for now. RUN_CLOUD_TESTS = False class TestConfigFiles(unittest.TestCase): """Tests for checking the format between the schema and features files.""" def test_expand_defaults_do_nothing(self): schema = [{'name': 'col1', 'type': 'FLOAT'}, {'name': 'col2', 'type': 'INTEGER'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}} expected_features = { 'col1': {'transform': 'x', 'source_column': 'col1'}, 'col2': {'transform': 'y', 'source_column': 'col2'}} feature_analysis.expand_defaults(schema, features) # Nothing should change. self.assertEqual(expected_features, features) def test_expand_defaults_unknown_schema_type(self): schema = [{'name': 'col1', 'type': 'BYTES'}, {'name': 'col2', 'type': 'INTEGER'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}} with self.assertRaises(ValueError): feature_analysis.expand_defaults(schema, features) def test_expand_defaults(self): schema = [{'name': 'col1', 'type': 'FLOAT'}, {'name': 'col2', 'type': 'INTEGER'}, {'name': 'col3', 'type': 'STRING'}, {'name': 'col4', 'type': 'FLOAT'}, {'name': 'col5', 'type': 'INTEGER'}, {'name': 'col6', 'type': 'STRING'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}, 'col3': {'transform': 'z'}} feature_analysis.expand_defaults(schema, features) self.assertEqual( features, {'col1': {'transform': 'x', 'source_column': 'col1'}, 'col2': {'transform': 'y', 'source_column': 'col2'}, 'col3': {'transform': 'z', 'source_column': 'col3'}, 'col4': {'transform': 'identity', 'source_column': 'col4'}, 'col5': {'transform': 'identity', 'source_column': 'col5'}, 'col6': {'transform': 'one_hot', 'source_column': 'col6'}}) def test_check_schema_transforms_match(self): with self.assertRaises(ValueError): feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'INTEGER'}], feature_analysis.invert_features( {'col1': {'transform': 'one_hot', 'source_column': 'col1'}})) with self.assertRaises(ValueError): feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'FLOAT'}], feature_analysis.invert_features( {'col1': {'transform': 'embedding', 'source_column': 'col1'}})) with self.assertRaises(ValueError): feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'STRING'}], feature_analysis.invert_features( {'col1': {'transform': 'scale', 'source_column': 'col1'}})) with self.assertRaises(ValueError): feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'xxx'}], feature_analysis.invert_features( {'col1': {'transform': 'scale', 'source_column': 'col1'}})) with self.assertRaises(ValueError): feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'INTEGER'}], feature_analysis.invert_features( {'col1': {'transform': 'xxx', 'source_column': 'col1'}})) with self.assertRaises(ValueError): # scale and one_hot different transform family feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'INTEGER'}], feature_analysis.invert_features( {'col1': {'transform': 'scale', 'source_column': 'col1'}, 'col2': {'transform': 'one_hot', 'source_column': 'col1'}, 'col3': {'transform': 
'key', 'source_column': 'col1'}})) with self.assertRaises(ValueError): # Unknown transform feature_analysis.check_schema_transforms_match( [{'name': 'col1', 'type': 'INTEGER'}], feature_analysis.invert_features({'col1': {'transform': 'x', 'source_column': 'col1'}})) class TestLocalAnalyze(unittest.TestCase): """Test local analyze functions.""" def test_numerics(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: file_io.write_string_to_file( input_file_path, '\n'.join(['%s,%s,%s' % (i, 10 * i + 0.5, i + 0.5) for i in range(100)])) schema = [{'name': 'col1', 'type': 'INTEGER'}, {'name': 'col2', 'type': 'FLOAT'}, {'name': 'col3', 'type': 'FLOAT'}] features = {'col1': {'transform': 'scale', 'source_column': 'col1'}, 'col2': {'transform': 'identity', 'source_column': 'col2'}, 'col3': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['num_examples'], 100) col = stats['column_stats']['col1'] self.assertAlmostEqual(col['max'], 99.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 49.5) col = stats['column_stats']['col2'] self.assertAlmostEqual(col['max'], 990.5) self.assertAlmostEqual(col['min'], 0.5) self.assertAlmostEqual(col['mean'], 495.5) finally: shutil.rmtree(output_folder) def test_categorical(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: csv_file = ['red,car,apple', 'red,truck,pepper', 'red,van,apple', 'blue,bike,grape', 'blue,train,apple', 'green,airplane,pepper'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'color', 'type': 'STRING'}, {'name': 'transport', 'type': 'STRING'}, {'name': 'type', 'type': 'STRING'}] features = {'color': {'transform': 'one_hot', 'source_column': 'color'}, 'transport': {'transform': 'embedding', 'source_column': 'transport'}, 'type': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['color']['vocab_size'], 3) self.assertEqual(stats['column_stats']['transport']['vocab_size'], 6) # Color column. vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'color')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['color', 'count']) expected_vocab = pd.DataFrame( {'color': ['red', 'blue', 'green'], 'count': [3, 2, 1]}, columns=['color', 'count']) pd.util.testing.assert_frame_equal(vocab, expected_vocab) # transport column. As each vocab has the same count, order in file is # not known. 
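# (Illustrative vocab file contents, matching how it is read below: one 'word,count'
# line per vocab entry, e.g. 'car,1'.)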
vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'transport')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['transport', 'count']) self.assertEqual(vocab['count'].tolist(), [1 for i in range(6)]) self.assertItemsEqual(vocab['transport'].tolist(), ['car', 'truck', 'van', 'bike', 'train', 'airplane']) finally: shutil.rmtree(output_folder) def test_text(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: csv_file = ['the quick brown fox,raining in kir,cat1|cat2,true', 'quick brown brown chicken,raining in pdx,cat2|cat3|cat4,false'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'col1', 'type': 'STRING'}, {'name': 'col2', 'type': 'STRING'}, {'name': 'col3', 'type': 'STRING'}, {'name': 'col4', 'type': 'STRING'}] features = {'col1': {'transform': 'bag_of_words', 'source_column': 'col1'}, 'col2': {'transform': 'tfidf', 'source_column': 'col2'}, 'col3': {'transform': 'multi_hot', 'source_column': 'col3', 'separator': '|'}, 'col4': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['col1']['vocab_size'], 5) self.assertEqual(stats['column_stats']['col2']['vocab_size'], 4) self.assertEqual(stats['column_stats']['col3']['vocab_size'], 4) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col1')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col1', 'count']) # vocabs are sorted by count only col1_vocab = vocab['col1'].tolist() self.assertItemsEqual(col1_vocab[:2], ['brown', 'quick']) self.assertItemsEqual(col1_vocab[2:], ['chicken', 'fox', 'the']) self.assertEqual(vocab['count'].tolist(), [2, 2, 1, 1, 1]) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col2')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col2', 'count']) # vocabs are sorted by count only col2_vocab = vocab['col2'].tolist() self.assertItemsEqual(col2_vocab[:2], ['in', 'raining']) self.assertItemsEqual(col2_vocab[2:], ['kir', 'pdx']) self.assertEqual(vocab['count'].tolist(), [2, 2, 1, 1]) finally: shutil.rmtree(output_folder) @unittest.skipIf(not RUN_CLOUD_TESTS, 'GCS access missing') class TestCloudAnalyzeFromBQTable(unittest.TestCase): """Test the analyze functions using data in a BigQuery table. As the SQL statements do not change if the BigQuery source is csv fiels or a real table, there is no need to test every SQL analyze statement. We only run one test to make sure this path works. """ def test_numerics(self): """Build a BQ table, and then call analyze on it.""" schema = [{'name': 'col1', 'type': 'INTEGER'}, {'name': 'col2', 'type': 'FLOAT'}, {'name': 'col3', 'type': 'FLOAT'}] project_id = dl.Context.default().project_id dataset_name = 'temp_pydatalab_test_%s' % uuid.uuid4().hex table_name = 'temp_table' full_table_name = '%s.%s.%s' % (project_id, dataset_name, table_name) output_folder = tempfile.mkdtemp() try: # Make a dataset, a table, and insert data. 
db = bq.Dataset((project_id, dataset_name)) db.create() table = bq.Table(full_table_name) table.create(schema=bq.Schema(schema), overwrite=True) data = [{'col1': i, 'col2': 10 * i + 0.5, 'col3': i + 0.5} for i in range(100)] table.insert(data) features = {'col1': {'transform': 'scale', 'source_column': 'col1'}, 'col2': {'transform': 'identity', 'source_column': 'col2'}, 'col3': {'transform': 'target'}} analyze.run_cloud_analysis( output_dir=output_folder, csv_file_pattern=None, bigquery_table=full_table_name, schema=schema, features=features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['num_examples'], 100) col = stats['column_stats']['col1'] self.assertAlmostEqual(col['max'], 99.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 49.5) col = stats['column_stats']['col2'] self.assertAlmostEqual(col['max'], 990.5) self.assertAlmostEqual(col['min'], 0.5) self.assertAlmostEqual(col['mean'], 495.5) finally: shutil.rmtree(output_folder) db.delete(delete_contents=True) @unittest.skipIf(not RUN_CLOUD_TESTS, 'GCS access missing') class TestCloudAnalyzeFromCSVFiles(unittest.TestCase): """Test the analyze function using BigQuery from csv files that are on GCS.""" @classmethod def setUpClass(cls): cls._bucket_name = 'temp_pydatalab_test_%s' % uuid.uuid4().hex cls._bucket_root = 'gs://%s' % cls._bucket_name storage.Bucket(cls._bucket_name).create() @classmethod def tearDownClass(cls): bucket = storage.Bucket(cls._bucket_name) for obj in bucket.objects(): obj.delete() bucket.delete() def test_numerics(self): test_folder = os.path.join(self._bucket_root, 'test_numerics') input_file_path = os.path.join(test_folder, 'input.csv') output_folder = os.path.join(test_folder, 'test_output') file_io.recursive_create_dir(output_folder) file_io.write_string_to_file( input_file_path, '\n'.join(['%s,%s,%s' % (i, 10 * i + 0.5, i) for i in range(100)])) schema = [{'name': 'col1', 'type': 'INTEGER'}, {'name': 'col2', 'type': 'FLOAT'}, {'name': 'col3', 'type': 'FLOAT'}] features = {'col1': {'transform': 'scale', 'source_column': 'col1'}, 'col2': {'transform': 'identity', 'source_column': 'col2'}, 'col3': {'transform': 'target'}} analyze.run_cloud_analysis( output_dir=output_folder, csv_file_pattern=input_file_path, bigquery_table=None, schema=schema, features=features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['num_examples'], 100) col = stats['column_stats']['col1'] self.assertAlmostEqual(col['max'], 99.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 49.5) col = stats['column_stats']['col2'] self.assertAlmostEqual(col['max'], 990.5) self.assertAlmostEqual(col['min'], 0.5) self.assertAlmostEqual(col['mean'], 495.5) def test_categorical(self): test_folder = os.path.join(self._bucket_root, 'test_categorical') input_file_path = os.path.join(test_folder, 'input.csv') output_folder = os.path.join(test_folder, 'test_output') file_io.recursive_create_dir(output_folder) csv_file = ['red,car,apple', 'red,truck,pepper', 'red,van,apple', 'blue,bike,grape', 'blue,train,apple', 'green,airplane,pepper'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'color', 'type': 'STRING'}, {'name': 'transport', 'type': 'STRING'}, {'name': 'type', 'type': 'STRING'}] features = {'color': {'transform': 'one_hot', 'source_column': 'color'}, 'transport': 
{'transform': 'embedding', 'source_column': 'transport'}, 'type': {'transform': 'target'}} analyze.run_cloud_analysis( output_dir=output_folder, csv_file_pattern=input_file_path, bigquery_table=None, schema=schema, features=features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['color']['vocab_size'], 3) self.assertEqual(stats['column_stats']['transport']['vocab_size'], 6) # Color column. vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'color')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['color', 'count']) expected_vocab = pd.DataFrame( {'color': ['red', 'blue', 'green'], 'count': [3, 2, 1]}, columns=['color', 'count']) pd.util.testing.assert_frame_equal(vocab, expected_vocab) # transport column. vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'transport')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['transport', 'count']) self.assertEqual(vocab['count'].tolist(), [1 for i in range(6)]) self.assertEqual(vocab['transport'].tolist(), ['airplane', 'bike', 'car', 'train', 'truck', 'van']) def test_text(self): test_folder = os.path.join(self._bucket_root, 'test_text') input_file_path = os.path.join(test_folder, 'input.csv') output_folder = os.path.join(test_folder, 'test_output') file_io.recursive_create_dir(output_folder) csv_file = ['the quick brown fox,raining in kir,cat1|cat2,true', 'quick brown brown chicken,raining in pdx,cat2|cat3|cat4,false'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'col1', 'type': 'STRING'}, {'name': 'col2', 'type': 'STRING'}, {'name': 'col3', 'type': 'STRING'}, {'name': 'col4', 'type': 'STRING'}] features = {'col1': {'transform': 'bag_of_words', 'source_column': 'col1'}, 'col2': {'transform': 'tfidf', 'source_column': 'col2'}, 'col3': {'transform': 'multi_hot', 'source_column': 'col3', 'separator': '|'}, 'col4': {'transform': 'target'}} analyze.run_cloud_analysis( output_dir=output_folder, csv_file_pattern=input_file_path, bigquery_table=None, schema=schema, features=features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['col1']['vocab_size'], 5) self.assertEqual(stats['column_stats']['col2']['vocab_size'], 4) self.assertEqual(stats['column_stats']['col3']['vocab_size'], 4) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col1')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col1', 'count']) self.assertEqual(vocab['col1'].tolist(), ['brown', 'quick', 'chicken', 'fox', 'the', ]) self.assertEqual(vocab['count'].tolist(), [2, 2, 1, 1, 1]) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col2')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col2', 'count']) self.assertEqual(vocab['col2'].tolist(), ['in', 'raining', 'kir', 'pdx']) self.assertEqual(vocab['count'].tolist(), [2, 2, 1, 1]) class TestOneSourceColumnManyFeatures(unittest.TestCase): """Test input column can be used more than once.""" def test_multiple_usage(self): def _make_csv_row(i): """Makes a csv file with the following header. 
target,number,category,text,image """ return "%d,%d,%s,%s,%s" % (i * 2, i, 'red' if i % 2 else 'blue', 'hello world' if i % 2 else 'bye moon', '/image%d.jpeg' % i) output_folder = tempfile.mkdtemp() try: input_data_path = tempfile.mkstemp(dir=output_folder, prefix='data')[1] file_io.write_string_to_file( input_data_path, '\n'.join([_make_csv_row(i) for i in range(100)])) input_schema_path = tempfile.mkstemp(dir=output_folder, prefix='sch')[1] file_io.write_string_to_file( input_schema_path, json.dumps([{'name': 'target', 'type': 'INTEGER'}, {'name': 'int', 'type': 'INTEGER'}, {'name': 'cat', 'type': 'STRING'}, {'name': 'text', 'type': 'STRING'}, {'name': 'img', 'type': 'STRING'}], indent=2)) input_feature_path = tempfile.mkstemp(dir=output_folder, prefix='feat')[1] file_io.write_string_to_file( input_feature_path, json.dumps({'target': {'transform': 'target'}, 'int': {'transform': 'scale'}, 'int2': {'transform': 'identity', 'source_column': 'int'}, 'int3': {'transform': 'key', 'source_column': 'int'}, 'cat1': {'transform': 'one_hot', 'source_column': 'cat'}, 'cat2': {'transform': 'embedding', 'source_column': 'cat'}, 'text': {'transform': 'tfidf', 'source_column': 'text'}, 'text2': {'transform': 'bag_of_words', 'source_column': 'text'}, 'text3': {'transform': 'key', 'source_column': 'text'}, 'img': {'transform': 'image_to_vec'}}, indent=2)) cmd = ['python %s/analyze.py' % CODE_PATH, '--output=' + output_folder, '--csv=' + input_data_path, '--schema=' + input_schema_path, '--features=' + input_feature_path] subprocess.check_call(' '.join(cmd), shell=True) self.assertTrue(os.path.isfile(os.path.join(output_folder, 'vocab_cat.csv'))) self.assertTrue(os.path.isfile(os.path.join(output_folder, 'vocab_text.csv'))) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['num_examples'], 100) col = stats['column_stats']['int'] self.assertAlmostEqual(col['max'], 99.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 49.5) col = stats['column_stats']['target'] self.assertAlmostEqual(col['max'], 198.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 99.0) col = stats['column_stats']['cat'] self.assertEqual(col['vocab_size'], 2) finally: shutil.rmtree(output_folder) if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/test_cloud_workflow.py ================================================ from __future__ import absolute_import import base64 import json import logging import os from PIL import Image import random import six import shutil import subprocess import sys import tempfile import unittest import uuid from tensorflow.python.lib.io import file_io CODE_PATH = os.path.abspath(os.path.join( os.path.dirname(__file__), '..', 'tensorflow')) class TestCloudServices(unittest.TestCase): """Tests everything using the cloud services. Run cloud analyze, cloud transformation, cloud training, and cloud batch prediction. Easy step is done by making a subprocess call to python or gcloud. Each step has a local 'cloud' variable that can be mannually set to False to run the local version of the step. This is usefull when debugging as not every step needs to use cloud services. Because of the cloud overhead, this test easily takes 30-40 mins to finish. Test files will be uploaded into a new bucket named temp_pydatalab_test_* using the default project from gcloud. 
The bucket is removed at the end of the test. To run this test, the following environment is needed: * gcloud version >= 156 * gcloud with a default project that has access to dataflow and ml engine * google-cloud-dataflow 0.6.0 """ def __init__(self, *args, **kwargs): super(TestCloudServices, self).__init__(*args, **kwargs) self._max_steps = 2000 # Log everything self._logger = logging.getLogger('TestStructuredDataLogger') self._logger.setLevel(logging.DEBUG) if not self._logger.handlers: self._logger.addHandler(logging.StreamHandler(stream=sys.stdout)) def setUp(self): random.seed(12321) self._local_dir = tempfile.mkdtemp() # Local folder for temp files. self._gs_dir = 'gs://temp_pydatalab_test_%s' % uuid.uuid4().hex subprocess.check_call('gsutil mb %s' % self._gs_dir, shell=True) self._input_files = os.path.join(self._gs_dir, 'input_files') self._analysis_output = os.path.join(self._gs_dir, 'analysis_output') self._transform_output = os.path.join(self._gs_dir, 'transform_output') self._train_output = os.path.join(self._gs_dir, 'train_output') self._prediction_output = os.path.join(self._gs_dir, 'prediction_output') file_io.recursive_create_dir(self._input_files) self._csv_train_filename = os.path.join(self._input_files, 'train_csv_data.csv') self._csv_eval_filename = os.path.join(self._input_files, 'eval_csv_data.csv') self._csv_predict_filename = os.path.join(self._input_files, 'predict_csv_data.csv') self._schema_filename = os.path.join(self._input_files, 'schema_file.json') self._features_filename = os.path.join(self._input_files, 'features_file.json') self._image_files = None def tearDown(self): self._logger.debug('TestCloudServices: removing folders %s, %s' % (self._local_dir, self._gs_dir)) shutil.rmtree(self._local_dir) subprocess.check_call('gsutil -m rm -r %s' % self._gs_dir, shell=True) def _make_image_files(self): """Makes random images and uploads them to GCS. The images are first made locally and then moved to GCS for speed. """ self._image_files = [] for i in range(10): r = random.randint(0, 255) g = random.randint(0, 255) b = random.randint(0, 255) img_name = 'img%02d.jpg' % i local_img = os.path.join(self._local_dir, img_name) img = Image.new('RGB', size=(300, 300), color=(155, 0, 0)) img.save(local_img) self._image_files.append((r, g, b, os.path.join(self._input_files, img_name))) cmd = 'gsutil -m mv %s/img*.jpg %s/' % (self._local_dir, self._input_files) subprocess.check_call(cmd, shell=True) def _make_csv_data(self, filename, num_rows, keep_target=True, embedded_image=False): """Writes csv data. Builds a linear model that uses 1 numerical column and an image column. Args: filename: GCS file path num_rows: how many rows of data will be generated. keep_target: if false, the target column is missing.
embedded_image: if true, the image column contains the base64-encoded image data instead of a file path """ def _drop_out(x): # Make 5% of the data missing if random.uniform(0, 1) < 0.05: return '' return x local_file = os.path.join(self._local_dir, 'data.csv') with open(local_file, 'w') as f: for i in range(num_rows): num = random.randint(0, 20) r, g, b, img_path = random.choice(self._image_files) if embedded_image: with file_io.FileIO(img_path, 'r') as img_file: img_bytes = Image.open(img_file) buf = six.StringIO() img_bytes.save(buf, 'JPEG') img_data = base64.urlsafe_b64encode(buf.getvalue()) else: img_data = img_path # Build a simple linear model t = -10 + 0.5 * num + 0.1 * r num = _drop_out(num) if num != '': # Don't drop every column img_data = _drop_out(img_data) if keep_target: csv_line = "{key},{target},{num},{img_data}\n".format( key=i, target=t, num=num, img_data=img_data) else: csv_line = "{key},{num},{img_data}\n".format( key=i, num=num, img_data=img_data) f.write(csv_line) subprocess.check_call('gsutil cp %s %s' % (local_file, filename), shell=True) def _get_default_project_id(self): with open(os.devnull, 'w') as dev_null: cmd = 'gcloud config list project --format=\'value(core.project)\'' return subprocess.check_output(cmd, shell=True, stderr=dev_null).strip() def _run_analyze(self): """Runs analysis using BigQuery from csv files.""" cloud = True self._logger.debug('Create input files') features = { 'num': {'transform': 'scale'}, 'img': {'transform': 'image_to_vec'}, 'target': {'transform': 'target'}, 'key': {'transform': 'key'}} file_io.write_string_to_file(self._features_filename, json.dumps(features, indent=2)) schema = [ {'name': 'key', 'type': 'integer'}, {'name': 'target', 'type': 'float'}, {'name': 'num', 'type': 'integer'}, {'name': 'img', 'type': 'string'}] file_io.write_string_to_file(self._schema_filename, json.dumps(schema, indent=2)) self._make_image_files() self._make_csv_data(self._csv_train_filename, 30, True, False) self._make_csv_data(self._csv_eval_filename, 10, True, False) self._make_csv_data(self._csv_predict_filename, 5, False, True) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--cloud' if cloud else '', '--output=' + self._analysis_output, '--csv=' + self._csv_train_filename, '--schema=' + self._schema_filename, '--features=' + self._features_filename] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) self.assertTrue(file_io.file_exists(os.path.join(self._analysis_output, 'stats.json'))) self.assertTrue(file_io.file_exists(os.path.join(self._analysis_output, 'schema.json'))) self.assertTrue(file_io.file_exists(os.path.join(self._analysis_output, 'features.json'))) def _run_transform(self): """Runs DataFlow for making tf.example files. Only the train file uses DataFlow, the eval file runs Beam locally to save time. """ cloud = True extra_args = [] if cloud: extra_args = ['--cloud', '--job-name=test-mltoolbox-df-%s' % uuid.uuid4().hex, '--project-id=%s' % self._get_default_project_id(), '--num-workers=3'] cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self._csv_train_filename, '--analysis=' + self._analysis_output, '--prefix=features_train', '--output=' + self._transform_output, '--shuffle'] + extra_args self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Don't waste time running a 2nd DF job; run it locally.
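# (Note added for clarity: the eval-file run below is the same transform.py
# invocation as above, minus the --cloud, --job-name, --project-id and
# --num-workers flags, so it executes with the local Beam runner instead.)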
cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self._csv_eval_filename, '--analysis=' + self._analysis_output, '--prefix=features_eval', '--output=' + self._transform_output] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Check the files were made train_files = file_io.get_matching_files( os.path.join(self._transform_output, 'features_train*')) eval_files = file_io.get_matching_files( os.path.join(self._transform_output, 'features_eval*')) self.assertNotEqual([], train_files) self.assertNotEqual([], eval_files) def _run_training_transform(self): """Runs training starting with transformed tf.example files.""" cloud = True if cloud: cmd = ['gcloud ml-engine jobs submit training test_mltoolbox_train_%s' % uuid.uuid4().hex, '--runtime-version=1.0', '--scale-tier=STANDARD_1', '--stream-logs'] else: cmd = ['gcloud ml-engine local train'] cmd = cmd + [ '--module-name trainer.task', '--job-dir=' + self._train_output, '--package-path=' + os.path.join(CODE_PATH, 'trainer'), '--', '--train=' + os.path.join(self._transform_output, 'features_train*'), '--eval=' + os.path.join(self._transform_output, 'features_eval*'), '--analysis=' + self._analysis_output, '--model=linear_regression', '--train-batch-size=10', '--eval-batch-size=10', '--max-steps=' + str(self._max_steps)] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Check the saved model was made. self.assertTrue(file_io.file_exists( os.path.join(self._train_output, 'model', 'saved_model.pb'))) self.assertTrue(file_io.file_exists( os.path.join(self._train_output, 'evaluation_model', 'saved_model.pb'))) def _run_batch_prediction(self): """Run batch prediction using the cloudml engine prediction service. There is no local version of this step as it's the last step. """ job_name = 'test_mltoolbox_batchprediction_%s' % uuid.uuid4().hex cmd = ['gcloud ml-engine jobs submit prediction ' + job_name, '--data-format=TEXT', '--input-paths=' + self._csv_predict_filename, '--output-path=' + self._prediction_output, '--model-dir=' + os.path.join(self._train_output, 'model'), '--runtime-version=1.0', '--region=us-central1'] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # async call. subprocess.check_call('gcloud ml-engine jobs stream-logs ' + job_name, shell=True) # check that there was no errors. error_files = file_io.get_matching_files( os.path.join(self._prediction_output, 'prediction.errors_stats*')) self.assertEqual(1, len(error_files)) error_str = file_io.read_file_to_string(error_files[0]) self.assertEqual('', error_str) def test_cloud_workflow(self): self._run_analyze() self._run_transform() self._run_training_transform() self._run_batch_prediction() if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/test_feature_transforms.py ================================================ from __future__ import absolute_import from __future__ import print_function import base64 import cStringIO from PIL import Image import json import math import numpy as np import os import shutil import sys import tempfile import unittest import tensorflow as tf from tensorflow.python.lib.io import file_io # To make 'import analyze' work without installing it. 
sys.path.append(os.path.abspath( os.path.join(os.path.dirname(__file__), '..', 'tensorflow', 'trainer'))) import feature_transforms # noqa: E303 # Some tests put files in GCS or use BigQuery. If HAS_CREDENTIALS is false, # those tests will not run. HAS_CREDENTIALS = True try: import google.datalab as dl dl.Context.default().project_id except Exception: HAS_CREDENTIALS = False class TestGraphBuilding(unittest.TestCase): """Test the TITO functions work and can produce a working TF graph.""" def _run_graph(self, analysis_path, features, schema, stats, predict_data): """Runs the preprocessing graph. Args: analysis_path: path to folder containing analysis output. Should contain the stats file. features: features dict schema: schema list stats: stats dict predict_data: list of csv strings """ stats = {'column_stats': {}} with tf.Graph().as_default(): with tf.Session().as_default() as session: outputs, labels, inputs = feature_transforms.build_csv_serving_tensors_for_transform_step( analysis_path, features, schema, stats, keep_target=False) feed_inputs = {inputs['csv_example']: predict_data} session.run(tf.tables_initializer()) result = session.run(outputs, feed_dict=feed_inputs) return result def test_make_transform_graph_numerics(self): output_folder = tempfile.mkdtemp() stats_file_path = os.path.join(output_folder, feature_transforms.STATS_FILE) try: stats = {'column_stats': {'num1': {'max': 10.0, 'mean': 9.5, 'min': 0.0}, # noqa 'num2': {'max': 1.0, 'mean': 2.0, 'min': -1.0}, 'num3': {'max': 10.0, 'mean': 2.0, 'min': 5.0}}} schema = [{'name': 'num1', 'type': 'FLOAT'}, {'name': 'num2', 'type': 'FLOAT'}, {'name': 'num3', 'type': 'INTEGER'}] features = {'num1': {'transform': 'identity', 'source_column': 'num1'}, 'num2': {'transform': 'scale', 'value': 10, 'source_column': 'num2'}, 'num3': {'transform': 'scale', 'source_column': 'num3'}} input_data = ['5.0,-1.0,10', '10.0,1.0,5', '15.0,0.5,7'] file_io.write_string_to_file( stats_file_path, json.dumps(stats)) results = self._run_graph(output_folder, features, schema, stats, input_data) for result, expected_result in zip(results['num1'].flatten().tolist(), [5, 10, 15]): self.assertAlmostEqual(result, expected_result) for result, expected_result in zip(results['num2'].flatten().tolist(), [-10, 10, 5]): self.assertAlmostEqual(result, expected_result) for result, expected_result in zip(results['num3'].flatten().tolist(), [1, -1, (7.0 - 5) * 2.0 / 5.0 - 1]): self.assertAlmostEqual(result, expected_result) finally: shutil.rmtree(output_folder) def test_make_transform_graph_category(self): output_folder = tempfile.mkdtemp() try: file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.VOCAB_ANALYSIS_FILE % 'cat1'), '\n'.join(['red,300', 'blue,200', 'green,100'])) file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.VOCAB_ANALYSIS_FILE % 'cat2'), '\n'.join(['pizza,300', 'ice_cream,200', 'cookies,100'])) stats = {'column_stats': {}} # stats file needed but unused. 
file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.STATS_FILE), json.dumps(stats)) schema = [{'name': 'cat1', 'type': 'STRING'}, {'name': 'cat2', 'type': 'STRING'}] features = {'cat1': {'transform': 'one_hot', 'source_column': 'cat1'}, 'cat2': {'transform': 'embedding', 'source_column': 'cat2'}} input_data = ['red,pizza', 'blue,', 'green,extra'] results = self._run_graph(output_folder, features, schema, stats, input_data) for result, expected_result in zip(results['cat1'].flatten().tolist(), [0, 1, 2]): self.assertEqual(result, expected_result) for result, expected_result in zip(results['cat2'].flatten().tolist(), [0, 3, 3]): self.assertEqual(result, expected_result) finally: shutil.rmtree(output_folder) def test_make_transform_graph_text_tfidf(self): output_folder = tempfile.mkdtemp() try: # vocab id # red 0 # blue 1 # green 2 # oov 3 (out of vocab) # corpus size aka num_examples = 4 # IDF: log(num_examples/(1+number of examples that have this token)) # red: log(4/3) # blue: log(4/3) # green: log(4/2) # oov: log(4/1) file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.VOCAB_ANALYSIS_FILE % 'cat1'), '\n'.join(['red,2', 'blue,2', 'green,1'])) stats = {'column_stats': {}, 'num_examples': 4} file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.STATS_FILE), json.dumps(stats)) # decode_csv does not like 1 column files with an empty row, so add # a key column schema = [{'name': 'key', 'type': 'STRING'}, {'name': 'cat1', 'type': 'STRING'}] features = {'key': {'transform': 'key', 'source_column': 'key'}, 'cat1': {'transform': 'tfidf', 'source_column': 'cat1'}} input_data = ['0,red red red', # doc 0 '1,red green red', # doc 1 '2,blue', # doc 2 '3,blue blue', # doc 3 '4,', # doc 4 '5,brown', # doc 5 '6,brown blue'] # doc 6 results = self._run_graph(output_folder, features, schema, stats, input_data) # indices are in the form [doc id, vocab id] expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [3, 0], [3, 1], [5, 0], [6, 0], [6, 1]] expected_ids = [0, 0, 0, 0, 2, 0, 1, 1, 1, 3, 3, 1] self.assertEqual(results['cat1_ids'].indices.tolist(), expected_indices) self.assertEqual(results['cat1_ids'].dense_shape.tolist(), [7, 3]) self.assertEqual(results['cat1_ids'].values.tolist(), expected_ids) # Note, these are natural logs. 
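# Worked example (added for clarity, not part of the original test): each
# entry's weight is TF * IDF = (1 / doc_length) * IDF(token).
# Doc 0 is 'red red red' (3 tokens), so each of its three entries is
# (1/3) * log(4/3). Doc 2 is just 'blue', giving a single entry of log(4/3).
# Doc 6 is 'brown blue', giving (1/2) * log(4/1) for the out-of-vocab 'brown'
# and (1/2) * log(4/3) for 'blue'. These are exactly the values built below.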
log_4_3 = math.log(4.0 / 3.0) expected_weights = [ 1.0 / 3.0 * log_4_3, 1.0 / 3.0 * log_4_3, 1.0 / 3.0 * log_4_3, # doc 0 1.0 / 3.0 * log_4_3, 1.0 / 3.0 * math.log(2.0), 1.0 / 3.0 * log_4_3, # doc 1 math.log(4.0 / 3.0), # doc 2 1.0 / 2.0 * log_4_3, 1.0 / 2.0 * log_4_3, # doc 3 math.log(4.0), # doc 5 1.0 / 2.0 * math.log(4.0), 1.0 / 2.0 * log_4_3] # doc 6 self.assertEqual(results['cat1_weights'].indices.tolist(), expected_indices) self.assertEqual(results['cat1_weights'].dense_shape.tolist(), [7, 3]) self.assertEqual(results['cat1_weights'].values.size, len(expected_weights)) for weight, expected_weight in zip(results['cat1_weights'].values.tolist(), expected_weights): self.assertAlmostEqual(weight, expected_weight) finally: shutil.rmtree(output_folder) def test_make_transform_graph_text_multi_hot(self): output_folder = tempfile.mkdtemp() try: # vocab id # red 0 # blue 1 # green 2 # oov 3 (out of vocab) file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.VOCAB_ANALYSIS_FILE % 'cat1'), '\n'.join(['red,2', 'blue,2', 'green,1'])) stats = {'column_stats': {}} file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.STATS_FILE), json.dumps(stats)) # Stats file needed but unused. # decode_csv does not like 1 column files with an empty row, so add # a key column schema = [{'name': 'key', 'type': 'STRING'}, {'name': 'cat1', 'type': 'STRING'}] features = {'key': {'transform': 'key', 'source_column': 'key'}, 'cat1': {'transform': 'multi_hot', 'source_column': 'cat1', 'separator': '|'}} input_data = ['0,red', # doc 0 '1,red|green', # doc 1 '2,blue', # doc 2 '3,red|blue|green', # doc 3 '4,'] # doc 4 results = self._run_graph(output_folder, features, schema, stats, input_data) # indices are in the form [doc id, vocab id] expected_indices = [[0, 0], [1, 0], [1, 1], [2, 0], [3, 0], [3, 1], [3, 2]] # doc id 0 1 1 2 3 3 3 expected_ids = [0, 0, 2, 1, 0, 1, 2] # noqa self.assertEqual(results['cat1'].indices.tolist(), expected_indices) self.assertEqual(results['cat1'].dense_shape.tolist(), [5, 3]) self.assertEqual(results['cat1'].values.tolist(), expected_ids) finally: shutil.rmtree(output_folder) def test_make_transform_graph_text_bag_of_words(self): output_folder = tempfile.mkdtemp() try: # vocab id # red 0 # blue 1 # green 2 # oov 3 (out of vocab) file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.VOCAB_ANALYSIS_FILE % 'cat1'), '\n'.join(['red,2', 'blue,2', 'green,1'])) stats = {'column_stats': {}} file_io.write_string_to_file( os.path.join(output_folder, feature_transforms.STATS_FILE), json.dumps(stats)) # Stats file needed but unused. # decode_csv does not like 1 column files with an empty row, so add # a key column schema = [{'name': 'key', 'type': 'STRING'}, {'name': 'cat1', 'type': 'STRING'}] features = {'key': {'transform': 'key', 'source_column': 'key'}, 'cat1': {'transform': 'bag_of_words', 'source_column': 'cat1'}} input_data = ['0,red red red', # doc 0 '1,red green red', # doc 1 '2,blue', # doc 2 '3,blue blue', # doc 3 '4,', # doc 4 '5,brown', # doc 5 '6,brown blue'] # doc 6 results = self._run_graph(output_folder, features, schema, stats, input_data) # indices are in the form [doc id, vocab id] expected_indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [3, 0], [3, 1], [5, 0], [6, 0], [6, 1]] # Note in doc 6, is is blue, then brown. 
# doc id 0 0 0 1 1 1 2 3 3 5 6 6 expected_ids = [0, 0, 0, 0, 2, 0, 1, 1, 1, 3, 3, 1] # noqa expected_weights = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] self.assertEqual(results['cat1_ids'].indices.tolist(), expected_indices) self.assertEqual(results['cat1_ids'].dense_shape.tolist(), [7, 3]) self.assertEqual(results['cat1_ids'].values.tolist(), expected_ids) self.assertEqual(results['cat1_weights'].indices.tolist(), expected_indices) self.assertEqual(results['cat1_weights'].dense_shape.tolist(), [7, 3]) self.assertEqual(results['cat1_weights'].values.size, len(expected_weights)) for weight, exp_weight in zip(results['cat1_weights'].values.tolist(), expected_weights): self.assertAlmostEqual(weight, exp_weight) finally: shutil.rmtree(output_folder) @unittest.skipIf(not HAS_CREDENTIALS, 'GCS access missing') def test_make_transform_graph_images(self): print('Testing make_transform_graph with image_to_vec. ' + 'It may take a few minutes because it needs to download a large inception checkpoint.') def _open_and_encode_image(img_url): with file_io.FileIO(img_url, 'r') as f: img = Image.open(f).convert('RGB') output = cStringIO.StringIO() img.save(output, 'jpeg') return base64.urlsafe_b64encode(output.getvalue()) try: output_folder = tempfile.mkdtemp() stats_file_path = os.path.join(output_folder, feature_transforms.STATS_FILE) stats = {'column_stats': {}} file_io.write_string_to_file(stats_file_path, json.dumps(stats)) schema = [{'name': 'img', 'type': 'STRING'}] features = {'img': {'transform': 'image_to_vec', 'source_column': 'img'}} # Test transformation with encoded image content. img_string1 = _open_and_encode_image( 'gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg') img_string2 = _open_and_encode_image( 'gs://cloud-ml-data/img/flower_photos/dandelion/8980164828_04fbf64f79_n.jpg') # Test transformation with direct file path. img_string3 = 'gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg' img_string4 = 'gs://cloud-ml-data/img/flower_photos/dandelion/8980164828_04fbf64f79_n.jpg' input_data = [img_string1, img_string2, img_string3, img_string4] results = self._run_graph(output_folder, features, schema, stats, input_data) embeddings = results['img'] self.assertEqual(len(embeddings), 4) self.assertEqual(len(embeddings[0]), 2048) self.assertEqual(embeddings[0].dtype, np.float32) self.assertTrue(any(x != 0.0 for x in embeddings[1])) self.assertTrue(any(x != 0.0 for x in embeddings[3])) finally: shutil.rmtree(output_folder) if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/test_training.py ================================================ from __future__ import absolute_import import base64 import glob import json import logging import os import pandas as pd from PIL import Image import random import shutil from six.moves.urllib.request import urlopen import subprocess import sys import tempfile import unittest import tensorflow as tf from tensorflow.python.lib.io import file_io CODE_PATH = os.path.abspath(os.path.join( os.path.dirname(__file__), '..', 'tensorflow')) def run_exported_model(model_path, csv_data): """Runs an exported model. Model should have one placeholder of csv data. 
Args: model_path: path to the saved_model.pb csv_data: list of csv strings Return: The result of session.run """ with tf.Graph().as_default(), tf.Session() as sess: meta_graph_pb = tf.saved_model.loader.load( sess=sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=model_path) signature = meta_graph_pb.signature_def['serving_default'] input_alias_map = { friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.inputs.items()} output_alias_map = { friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.outputs.items()} _, csv_tensor_name = input_alias_map.items()[0] result = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: csv_data}) return result class TestSpecialCharacters(unittest.TestCase): """Test special characters are supported.""" def testCommaQuote(self): """Test when csv input data has quotes and commas.""" output_dir = tempfile.mkdtemp() try: features = { 'target': {'transform': 'target'}, 'cat': {'transform': 'one_hot'}, 'text': {'transform': 'bag_of_words'}} schema = [ {'name': 'target', 'type': 'string'}, {'name': 'cat', 'type': 'string'}, {'name': 'text', 'type': 'string'}] # Target column = cat column data = [{'cat': 'red,', 'text': 'one, two, three', 'target': 'red,'}, {'cat': 'blue"', 'text': 'one, "two"', 'target': 'blue"'}, {'cat': '"green"', 'text': '"two', 'target': '"green"'}, {'cat': "yellow, 'brown", 'text': "'one, two'", 'target': "yellow, 'brown"}] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file( os.path.join(output_dir, 'data.csv'), pd.DataFrame(data, columns=['target', 'cat', 'text']).to_csv(index=False, header=False)) # Run analysis and check the output vocabs are correctly encoded in csv cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) df_vocab_cat = pd.read_csv( os.path.join(output_dir, 'analysis', 'vocab_cat.csv'), header=None, names=['label', 'count'], dtype=str, na_filter=False) self.assertEqual(df_vocab_cat['count'].tolist(), ['1', '1', '1', '1']) self.assertItemsEqual( df_vocab_cat['label'].tolist(), ['blue"', '"green"', "yellow, 'brown", 'red,']) df_vocab_target = pd.read_csv( os.path.join(output_dir, 'analysis', 'vocab_target.csv'), header=None, names=['label', 'count'], dtype=str, na_filter=False) self.assertEqual(df_vocab_target['count'].tolist(), ['1', '1', '1', '1']) self.assertItemsEqual( df_vocab_target['label'].tolist(), ['blue"', '"green"', "yellow, 'brown", 'red,']) df_vocab_text = pd.read_csv( os.path.join(output_dir, 'analysis', 'vocab_text.csv'), header=None, names=['label', 'count'], dtype=str, na_filter=False) vocab_text = df_vocab_text['label'].tolist() self.assertEqual(vocab_text[0], 'one,') self.assertItemsEqual(vocab_text[1:], ['two,', '"two"', "'one,", '"two', "two'", 'three']) vocab_count = df_vocab_text['count'].tolist() self.assertEqual(vocab_count[0], '2') self.assertEqual(vocab_count[1:], ['1', '1', '1', '1', '1', '1']) # Run transform, and check there are no reported errors. 
cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--prefix=features_train', '--output=' + os.path.join(output_dir, 'transform')] subprocess.check_call(' '.join(cmd), shell=True) error_files = glob.glob(os.path.join(output_dir, 'transform', 'error*')) self.assertEqual(1, len(error_files)) self.assertEqual(0, os.path.getsize(error_files[0])) # Run training cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'data.csv'), '--eval=' + os.path.join(output_dir, 'data.csv'), '--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_classification', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=500', '--learning-rate=1.0', '--transform'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['"red,","one, two, three"']) # The prediction data is a training row. As the data is samll, the model # should have near 100% accuracy. Check it made the correct prediction. self.assertEqual(result['predicted'], 'red,') finally: shutil.rmtree(output_dir) class TestClassificationTopN(unittest.TestCase): """Test top_n works.""" def testTopNZero(self): """Test top_n=0 gives all the classes.""" output_dir = tempfile.mkdtemp() try: features = { 'num': {'transform': 'identity'}, 'target': {'transform': 'target'}} schema = [ {'name': 'num', 'type': 'integer'}, {'name': 'target', 'type': 'string'}] data = ['1,1\n', '4,2\n', '5,3\n', '11,1\n'] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'data.csv'), ''.join(data)) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'data.csv'), '--eval=' + os.path.join(output_dir, 'data.csv'), '--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_classification', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=1', '--top-n=0', # This parameter is tested in this test! 
'--learning-rate=0.1', '--transform'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['20']) keys = result.keys() self.assertIn('predicted', keys) self.assertIn('1', keys) self.assertIn('2', keys) self.assertIn('3', keys) finally: shutil.rmtree(output_dir) class TestMultipleFeatures(unittest.TestCase): """Test one source column can be used in many features.""" def testMultipleColumnsRaw(self): """Test training starting from raw csv.""" output_dir = tempfile.mkdtemp() try: features = { 'num': {'transform': 'identity'}, 'num2': {'transform': 'key', 'source_column': 'num'}, 'target': {'transform': 'target'}, 'text': {'transform': 'bag_of_words'}, 'text2': {'transform': 'multi_hot', 'source_column': 'text'}, 'text3': {'transform': 'tfidf', 'source_column': 'text'}, 'text4': {'transform': 'key', 'source_column': 'text'}} schema = [ {'name': 'num', 'type': 'integer'}, {'name': 'target', 'type': 'float'}, {'name': 'text', 'type': 'string'}] data = ['1,2,hello world\n', '4,8,bye moon\n', '5,10,hello moon\n', '11,22,moon moon\n'] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'data.csv'), ''.join(data)) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'data.csv'), '--eval=' + os.path.join(output_dir, 'data.csv'), '--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_regression', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=200', '--learning-rate=0.1', '--transform'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['20,hello moon']) # check keys were made self.assertEqual(20, result['num2']) self.assertEqual('hello moon', result['text4']) finally: shutil.rmtree(output_dir) def testMultipleColumnsTransformed(self): """Test training starting from tf.example.""" output_dir = tempfile.mkdtemp() try: features = { 'num': {'transform': 'identity'}, 'num2': {'transform': 'key', 'source_column': 'num'}, 'target': {'transform': 'target'}, 'text': {'transform': 'bag_of_words'}, 'text2': {'transform': 'multi_hot', 'source_column': 'text'}, 'text3': {'transform': 'tfidf', 'source_column': 'text'}, 'text4': {'transform': 'key', 'source_column': 'text'}} schema = [ {'name': 'num', 'type': 'integer'}, {'name': 'target', 'type': 'float'}, {'name': 'text', 'type': 'string'}] data = ['1,2,hello world\n', '4,8,bye moon\n', '5,10,hello moon\n', '11,22,moon moon\n'] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'data.csv'), ''.join(data)) cmd = ['python %s' % os.path.join(CODE_PATH, 
'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--output=' + os.path.join(output_dir, 'transform'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--prefix=features'] subprocess.check_call(' '.join(cmd), shell=True) # Check tf.example file has the expected features file_list = file_io.get_matching_files(os.path.join(output_dir, 'transform', 'features*')) options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) record_iter = tf.python_io.tf_record_iterator(path=file_list[0], options=options) tf_example = tf.train.Example() tf_example.ParseFromString(next(record_iter)) self.assertEqual(1, len(tf_example.features.feature['num'].int64_list.value)) self.assertEqual(1, len(tf_example.features.feature['num2'].int64_list.value)) self.assertEqual(1, len(tf_example.features.feature['target'].float_list.value)) self.assertEqual(2, len(tf_example.features.feature['text_ids'].int64_list.value)) self.assertEqual(2, len(tf_example.features.feature['text_weights'].float_list.value)) self.assertEqual(2, len(tf_example.features.feature['text2'].int64_list.value)) self.assertEqual(2, len(tf_example.features.feature['text3_ids'].int64_list.value)) self.assertEqual(2, len(tf_example.features.feature['text3_weights'].float_list.value)) self.assertEqual(1, len(tf_example.features.feature['text4'].bytes_list.value)) cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'transform', 'features*'), '--eval=' + os.path.join(output_dir, 'transform', 'features*'), '--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_regression', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=200', '--learning-rate=0.1'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['20,hello moon']) # check keys were made self.assertEqual(20, result['num2']) self.assertEqual('hello moon', result['text4']) finally: shutil.rmtree(output_dir) class TestOptionalKeys(unittest.TestCase): def testNoKeys(self): output_dir = tempfile.mkdtemp() try: features = { 'num': {'transform': 'identity'}, 'target': {'transform': 'target'}} schema = [ {'name': 'num', 'type': 'integer'}, {'name': 'target', 'type': 'float'}] data = ['1,2\n', '4,8\n', '5,10\n', '11,22\n'] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'data.csv'), ''.join(data)) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'data.csv'), '--eval=' + os.path.join(output_dir, 'data.csv'), 
'--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_regression', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=2000', '--learning-rate=0.1', '--transform'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['20']) self.assertTrue(abs(40 - result['predicted']) < 5) finally: shutil.rmtree(output_dir) def testManyKeys(self): output_dir = tempfile.mkdtemp() try: features = { 'keyint': {'transform': 'key'}, 'keyfloat': {'transform': 'key'}, 'keystr': {'transform': 'key'}, 'num': {'transform': 'identity'}, 'target': {'transform': 'target'}} schema = [ {'name': 'keyint', 'type': 'integer'}, {'name': 'keyfloat', 'type': 'float'}, {'name': 'keystr', 'type': 'string'}, {'name': 'num', 'type': 'integer'}, {'name': 'target', 'type': 'float'}] data = ['1,1.5,one,1,2\n', '2,2.5,two,4,8\n', '3,3.5,three,5,10\n'] file_io.recursive_create_dir(output_dir) file_io.write_string_to_file(os.path.join(output_dir, 'schema.json'), json.dumps(schema, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'features.json'), json.dumps(features, indent=2)) file_io.write_string_to_file(os.path.join(output_dir, 'data.csv'), ''.join(data)) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + os.path.join(output_dir, 'analysis'), '--csv=' + os.path.join(output_dir, 'data.csv'), '--schema=' + os.path.join(output_dir, 'schema.json'), '--features=' + os.path.join(output_dir, 'features.json')] subprocess.check_call(' '.join(cmd), shell=True) cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(output_dir, 'data.csv'), '--eval=' + os.path.join(output_dir, 'data.csv'), '--job-dir=' + os.path.join(output_dir, 'training'), '--analysis=' + os.path.join(output_dir, 'analysis'), '--model=linear_regression', '--train-batch-size=4', '--eval-batch-size=4', '--max-steps=2000', '--transform'] subprocess.check_call(' '.join(cmd), shell=True) result = run_exported_model( model_path=os.path.join(output_dir, 'training', 'model'), csv_data=['7,4.5,hello,1']) self.assertEqual(7, result['keyint']) self.assertAlmostEqual(4.5, result['keyfloat']) self.assertEqual('hello', result['keystr']) finally: shutil.rmtree(output_dir) class TestTrainer(unittest.TestCase): """Tests training. Runs analyze.py and transform.py on generated test data. Also loads the exported graphs and checks they run. No validation of the test results is done (i.e., the training loss is not checked). """ def __init__(self, *args, **kwargs): super(TestTrainer, self).__init__(*args, **kwargs) # Allow this class to be subclassed for quick tests that only care about # training working, not model loss/accuracy. 
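# A minimal sketch of such a subclass (hypothetical, for illustration only;
# it is not part of this test suite):
#   class QuickTestTrainer(TestTrainer):
#     def __init__(self, *args, **kwargs):
#       super(QuickTestTrainer, self).__init__(*args, **kwargs)
#       self._max_steps = 100          # train for far fewer steps
#       self._check_model_fit = False  # skip loss/accuracy validation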
self._max_steps = 2000 self._check_model_fit = True # Log everything self._logger = logging.getLogger('TestStructuredDataLogger') self._logger.setLevel(logging.DEBUG) if not self._logger.handlers: self._logger.addHandler(logging.StreamHandler(stream=sys.stdout)) def setUp(self): self._test_dir = tempfile.mkdtemp() self._analysis_output = os.path.join(self._test_dir, 'analysis_output') self._transform_output = os.path.join(self._test_dir, 'transform_output') self._train_output = os.path.join(self._test_dir, 'train_output') file_io.recursive_create_dir(self._analysis_output) file_io.recursive_create_dir(self._transform_output) file_io.recursive_create_dir(self._train_output) self._csv_train_filename = os.path.join(self._test_dir, 'train_csv_data.csv') self._csv_eval_filename = os.path.join(self._test_dir, 'eval_csv_data.csv') self._csv_predict_filename = os.path.join(self._test_dir, 'predict_csv_data.csv') self._schema_filename = os.path.join(self._test_dir, 'schema_file.json') self._features_filename = os.path.join(self._test_dir, 'features_file.json') def tearDown(self): self._logger.debug('TestTrainer: removing test dir ' + self._test_dir) shutil.rmtree(self._test_dir) def make_image_files(self): img1_file = os.path.join(self._test_dir, 'img1.jpg') image1 = Image.new('RGB', size=(300, 300), color=(155, 0, 0)) image1.save(img1_file) img2_file = os.path.join(self._test_dir, 'img2.jpg') image2 = Image.new('RGB', size=(50, 50), color=(125, 240, 0)) image2.save(img2_file) img3_file = os.path.join(self._test_dir, 'img3.jpg') image3 = Image.new('RGB', size=(800, 600), color=(33, 55, 77)) image3.save(img3_file) self._image_files = [img1_file, img2_file, img3_file] def make_csv_data(self, filename, num_rows, problem_type, keep_target=True, with_image=False): """Writes csv data for preprocessing and training. There is one csv column for each supported transform. Args: filename: writes data to local csv file. num_rows: how many rows of data will be generated. problem_type: 'classification' or 'regression'. Changes the target value. keep_target: if false, the csv file will have an empty column ',,' for the target. """ random.seed(12321) def _drop_out(x): # Make 5% of the data missing if random.uniform(0, 1) < 0.05: return '' return x with open(filename, 'w') as f: for i in range(num_rows): num_id = random.randint(0, 20) num_scale = random.uniform(0, 30) str_one_hot = random.choice(['red', 'blue', 'green', 'pink', 'yellow', 'brown', 'black']) str_embedding = random.choice(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr']) def _word_fn(): return random.choice(['car', 'truck', 'van', 'bike', 'train', 'drone']) str_bow = [_word_fn() for _ in range(random.randint(1, 4))] str_tfidf = [_word_fn() for _ in range(random.randint(1, 4))] color_map = {'red': 2, 'blue': 6, 'green': 4, 'pink': -5, 'yellow': -6, 'brown': -1, 'black': -7} abc_map = {'abc': -1, 'def': -1, 'ghi': 1, 'jkl': 1, 'mno': 2, 'pqr': 1} transport_map = {'car': 5, 'truck': 10, 'van': 15, 'bike': 20, 'train': -25, 'drone': -30} # Build some model: t id the dependent variable t = 0.5 + 0.5 * num_id - 2.5 * num_scale t += color_map[str_one_hot] t += abc_map[str_embedding] t += sum([transport_map[x] for x in str_bow]) t += sum([transport_map[x] * 0.5 for x in str_tfidf]) if problem_type == 'classification': # If you cange the weights above or add more columns, look at the new # distribution of t values and try to divide them into 3 buckets. 
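# (Clarifying note: the current cut points map the continuous t to three
# class labels: t < -40 -> 100, -40 <= t < 0 -> 101, t >= 0 -> 102.)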
if t < -40: t = 100 elif t < 0: t = 101 else: t = 102 str_bow = ' '.join(str_bow) str_tfidf = ' '.join(str_tfidf) if with_image: img_url = random.choice(self._image_files) _drop_out(img_url) num_id = _drop_out(num_id) num_scale = _drop_out(num_scale) str_one_hot = _drop_out(str_one_hot) str_embedding = _drop_out(str_embedding) str_bow = _drop_out(str_bow) str_tfidf = _drop_out(str_tfidf) if keep_target: if with_image: csv_line = "{key},{target},{num_id},{num_scale},{str_one_hot},{str_embedding},{str_bow},{str_tfidf},{img_url}\n".format( # noqa key=i, target=t, num_id=num_id, num_scale=num_scale, str_one_hot=str_one_hot, str_embedding=str_embedding, str_bow=str_bow, str_tfidf=str_tfidf, img_url=img_url) else: csv_line = "{key},{target},{num_id},{num_scale},{str_one_hot},{str_embedding},{str_bow},{str_tfidf}\n".format( # noqa key=i, target=t, num_id=num_id, num_scale=num_scale, str_one_hot=str_one_hot, str_embedding=str_embedding, str_bow=str_bow, str_tfidf=str_tfidf) else: if with_image: csv_line = "{key},{num_id},{num_scale},{str_one_hot},{str_embedding},{str_bow},{str_tfidf},{img_url}\n".format( # noqa key=i, num_id=num_id, num_scale=num_scale, str_one_hot=str_one_hot, str_embedding=str_embedding, str_bow=str_bow, str_tfidf=str_tfidf, img_url=img_url) else: csv_line = "{key},{num_id},{num_scale},{str_one_hot},{str_embedding},{str_bow},{str_tfidf}\n".format( # noqa key=i, num_id=num_id, num_scale=num_scale, str_one_hot=str_one_hot, str_embedding=str_embedding, str_bow=str_bow, str_tfidf=str_tfidf) f.write(csv_line) def _create_schema_features(self, problem_type, with_image=False): features = { 'num_id': {'transform': 'identity'}, 'num_scale': {'transform': 'scale', 'value': 4}, 'str_one_hot': {'transform': 'one_hot'}, 'str_embedding': {'transform': 'embedding', 'embedding_dim': 3}, 'str_bow': {'transform': 'bag_of_words'}, 'str_tfidf': {'transform': 'tfidf'}, 'target': {'transform': 'target'}, 'key': {'transform': 'key'}} if with_image: # Download inception checkpoint. Note that gs url doesn't work because # we may not have gcloud signed in when running the test. 
url = ('https://storage.googleapis.com/cloud-ml-data/img/' + 'flower_photos/inception_v3_2016_08_28.ckpt') checkpoint_path = os.path.join(self._test_dir, "checkpoint") response = urlopen(url) with open(checkpoint_path, 'wb') as f: f.write(response.read()) features['image'] = {'transform': 'image_to_vec', 'checkpoint': checkpoint_path} schema = [ {'name': 'key', 'type': 'integer'}, {'name': 'target', 'type': 'string' if problem_type == 'classification' else 'float'}, {'name': 'num_id', 'type': 'integer'}, {'name': 'num_scale', 'type': 'float'}, {'name': 'str_one_hot', 'type': 'string'}, {'name': 'str_embedding', 'type': 'string'}, {'name': 'str_bow', 'type': 'string'}, {'name': 'str_tfidf', 'type': 'string'}] if with_image: schema.append({'name': 'image', 'type': 'string'}) self._schema = schema file_io.write_string_to_file(self._schema_filename, json.dumps(schema, indent=2)) file_io.write_string_to_file(self._features_filename, json.dumps(features, indent=2)) if with_image: self.make_image_files() self.make_csv_data(self._csv_train_filename, 50, problem_type, True, with_image) self.make_csv_data(self._csv_eval_filename, 30, problem_type, True, with_image) self.make_csv_data(self._csv_predict_filename, 10, problem_type, False, with_image) def _run_analyze(self, problem_type, with_image=False): self._create_schema_features(problem_type, with_image=with_image) cmd = ['python %s' % os.path.join(CODE_PATH, 'analyze.py'), '--output=' + self._analysis_output, '--csv=' + self._csv_train_filename, '--schema=' + self._schema_filename, '--features=' + self._features_filename] subprocess.check_call(' '.join(cmd), shell=True) def _run_transform(self): cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self._csv_train_filename, '--analysis=' + self._analysis_output, '--prefix=features_train', '--output=' + self._transform_output, '--shuffle'] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self._csv_eval_filename, '--analysis=' + self._analysis_output, '--prefix=features_eval', '--output=' + self._transform_output] self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) def _run_training_transform(self, problem_type, model_type, extra_args=[]): """Runs training starting with transformed tf.example files. Args: problem_type: 'regression' or 'classification' model_type: 'linear' or 'dnn' extra_args: list of strings to pass to the trainer. """ cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + os.path.join(self._transform_output, 'features_train*'), '--eval=' + os.path.join(self._transform_output, 'features_eval*'), '--job-dir=' + self._train_output, '--analysis=' + self._analysis_output, '--model=%s_%s' % (model_type, problem_type), '--train-batch-size=100', '--eval-batch-size=50', '--max-steps=' + str(self._max_steps)] + extra_args self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) def _run_training_raw(self, problem_type, model_type, extra_args=[]): """Runs training starting from raw csv data. Args: problem_type: 'regression' or 'classification' model_type: 'linear' or 'dnn' extra_args: list of strings to pass to the trainer. 
""" cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + self._csv_train_filename, '--eval=' + self._csv_eval_filename, '--job-dir=' + self._train_output, '--analysis=' + self._analysis_output, '--model=%s_%s' % (model_type, problem_type), '--train-batch-size=100', '--eval-batch-size=50', '--max-steps=' + str(self._max_steps), '--transform'] + extra_args self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) def _run_training_with_analysis(self, problem_type, model_type, extra_args=[]): """Runs training starting from raw csv data. Args: problem_type: 'regression' or 'classification' model_type: 'linear' or 'dnn' extra_args: list of strings to pass to the trainer. """ cmd = ['cd %s && ' % CODE_PATH, 'python -m trainer.task', '--train=' + self._csv_train_filename, '--eval=' + self._csv_eval_filename, '--job-dir=' + self._train_output, '--model=%s_%s' % (model_type, problem_type), '--train-batch-size=100', '--eval-batch-size=50', '--max-steps=' + str(self._max_steps), '--features=' + self._features_filename, '--schema=' + self._schema_filename, '--transform'] + extra_args self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) def _check_model(self, problem_type, model_type, with_image=False): """Checks that both exported prediction graphs work.""" for has_target in [True, False]: if has_target: model_path = os.path.join(self._train_output, 'evaluation_model') else: model_path = os.path.join(self._train_output, 'model') self._logger.debug('Checking model %s %s at %s' % (problem_type, model_type, model_path)) # Check there is a saved model. self.assertTrue(os.path.isfile(os.path.join(model_path, 'saved_model.pb'))) # Must create new graphs as multiple graphs are loaded into memory. with tf.Graph().as_default(), tf.Session() as sess: meta_graph_pb = tf.saved_model.loader.load( sess=sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=model_path) signature = meta_graph_pb.signature_def['serving_default'] input_alias_map = { friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.inputs.items()} output_alias_map = { friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.outputs.items()} prediction_data = { 'key': [12, 11], 'target': [-49, -9] if problem_type == 'regression' else ['100', '101'], 'num_id': [11, 10], 'num_scale': [22.29, 5.20], 'str_one_hot': ['brown', 'brown'], 'str_embedding': ['def', 'def'], 'str_bow': ['drone', 'drone truck bike truck'], 'str_tfidf': ['bike train train car', 'train']} if with_image: image_bytes = [] for image_file in [self._image_files[0], self._image_files[2]]: with file_io.FileIO(image_file, 'r') as ff: image_bytes.append(base64.urlsafe_b64encode(ff.read())) prediction_data.update({'image': image_bytes}) # Convert the prediciton data to csv. 
csv_header = [col['name'] for col in self._schema if (has_target or col['name'] != 'target')] if not has_target: del prediction_data['target'] csv_data = [] for i in range(2): data = [str(prediction_data[name][i]) for name in csv_header] csv_data.append(','.join(data)) # Test the *_alias_maps have the expected keys expected_output_keys = ['predicted', 'key'] if has_target: expected_output_keys.append('target') if problem_type == 'classification': expected_output_keys.extend( ['probability', 'probability_2', 'probability_3', 'predicted_2', 'predicted_3']) self.assertEqual(1, len(input_alias_map.keys())) self.assertItemsEqual(expected_output_keys, output_alias_map.keys()) _, csv_tensor_name = input_alias_map.items()[0] result = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: csv_data}) self.assertItemsEqual(expected_output_keys, result.keys()) self.assertEqual([12, 11], result['key'].flatten().tolist()) def testClassificationLinear(self): self._logger.debug('\n\nTesting Classification Linear') problem_type = 'classification' model_type = 'linear' self._run_analyze(problem_type) self._run_training_raw( problem_type=problem_type, model_type=model_type, extra_args=['--top-n=3']) self._check_model( problem_type=problem_type, model_type=model_type) def testRegressionLinear(self): self._logger.debug('\n\nTesting Regression Linear') problem_type = 'regression' model_type = 'linear' self._run_analyze(problem_type) self._run_transform() self._run_training_transform( problem_type=problem_type, model_type=model_type) self._check_model( problem_type=problem_type, model_type=model_type) def testRegressionDNN(self): self._logger.debug('\n\nTesting Regression DNN') problem_type = 'regression' model_type = 'dnn' self._run_analyze(problem_type) self._run_training_raw( problem_type=problem_type, model_type=model_type, extra_args=['--top-n=3', '--hidden-layer-size1=10', '--hidden-layer-size2=2']) self._check_model( problem_type=problem_type, model_type=model_type) def testClassificationDNNWithImage(self): self._logger.debug('\n\nTesting Classification DNN With Image') problem_type = 'classification' model_type = 'dnn' self._run_analyze(problem_type, with_image=True) self._run_transform() self._run_training_transform( problem_type=problem_type, model_type=model_type, extra_args=['--top-n=3', '--hidden-layer-size1=10']) self._check_model( problem_type=problem_type, model_type=model_type, with_image=True) def testTrainingWithAnalysis(self): self._logger.debug('\n\nTesting Training with Analysis') self._create_schema_features('classification') self._run_training_with_analysis( problem_type='classification', model_type='linear', extra_args=['--top-n=3']) self._check_model( problem_type='classification', model_type='linear') if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_tensorflow/test_transform.py ================================================ from __future__ import absolute_import from __future__ import print_function import json import os import pandas as pd from PIL import Image import shutil from six.moves.urllib.request import urlopen import subprocess import tempfile import unittest import uuid import tensorflow as tf from tensorflow.python.lib.io import file_io import google.datalab as dl import google.datalab.bigquery as bq import google.datalab.storage as storage CODE_PATH = os.path.abspath(os.path.join( os.path.dirname(__file__), '..', 'tensorflow')) # TODO: travis tests failed because sometimes a VM has 
gcloud signed-in # (maybe due to failed cleanup) with default project set and BQ is not enabled. # In that case the cloud tests will fail. Disable it for now. RUN_CLOUD_TESTS = False class TestTransformRawData(unittest.TestCase): """Tests for applying a saved model""" @classmethod def setUpClass(cls): # Set up dirs. cls.working_dir = tempfile.mkdtemp() cls.source_dir = os.path.join(cls.working_dir, 'source') cls.analysis_dir = os.path.join(cls.working_dir, 'analysis') cls.output_dir = os.path.join(cls.working_dir, 'output') file_io.create_dir(cls.source_dir) # Make test image files. img1_file = os.path.join(cls.source_dir, 'img1.jpg') image1 = Image.new('RGB', size=(300, 300), color=(155, 0, 0)) image1.save(img1_file) img2_file = os.path.join(cls.source_dir, 'img2.jpg') image2 = Image.new('RGB', size=(50, 50), color=(125, 240, 0)) image2.save(img2_file) img3_file = os.path.join(cls.source_dir, 'img3.jpg') image3 = Image.new('RGB', size=(800, 600), color=(33, 55, 77)) image3.save(img3_file) # Download inception checkpoint. Note that gs url doesn't work because # we may not have gcloud signed in when running the test. url = ('https://storage.googleapis.com/cloud-ml-data/img/' + 'flower_photos/inception_v3_2016_08_28.ckpt') checkpoint_path = os.path.join(cls.working_dir, "checkpoint") response = urlopen(url) with open(checkpoint_path, 'wb') as f: f.write(response.read()) # Make csv input file cls.csv_input_filepath = os.path.join(cls.source_dir, 'input.csv') file_io.write_string_to_file( cls.csv_input_filepath, '1,1,Monday,23.0,%s\n' % img1_file + '2,0,Friday,18.0,%s\n' % img2_file + '3,0,Sunday,12.0,%s\n' % img3_file) # Call analyze.py to create analysis results. schema = [{'name': 'key_col', 'type': 'INTEGER'}, {'name': 'target_col', 'type': 'FLOAT'}, {'name': 'cat_col', 'type': 'STRING'}, {'name': 'num_col', 'type': 'FLOAT'}, {'name': 'img_col', 'type': 'STRING'}] schema_file = os.path.join(cls.source_dir, 'schema.json') file_io.write_string_to_file(schema_file, json.dumps(schema)) features = {'key_col': {'transform': 'key'}, 'target_col': {'transform': 'target'}, 'cat_col': {'transform': 'one_hot'}, 'num_col': {'transform': 'identity'}, 'img_col': {'transform': 'image_to_vec', 'checkpoint': checkpoint_path}} features_file = os.path.join(cls.source_dir, 'features.json') file_io.write_string_to_file(features_file, json.dumps(features)) cmd = ['python ' + os.path.join(CODE_PATH, 'analyze.py'), '--output=' + cls.analysis_dir, '--csv=' + cls.csv_input_filepath, '--schema=' + schema_file, '--features=' + features_file] subprocess.check_call(' '.join(cmd), shell=True) @classmethod def tearDownClass(cls): shutil.rmtree(cls.working_dir) def test_local_csv_transform(self): """Test transfrom from local csv files.""" cmd = ['python ' + os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self.csv_input_filepath, '--analysis=' + self.analysis_dir, '--prefix=features', '--output=' + self.output_dir] print('cmd ', ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Read the tf record file. There should only be one file. record_filepath = os.path.join(self.output_dir, 'features-00000-of-00001.tfrecord.gz') options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) serialized_examples = list(tf.python_io.tf_record_iterator(record_filepath, options=options)) self.assertEqual(len(serialized_examples), 3) # Find the example with key=1 in the file. 
first_example = None for ex in serialized_examples: example = tf.train.Example() example.ParseFromString(ex) if example.features.feature['key_col'].int64_list.value[0] == 1: first_example = example self.assertIsNotNone(first_example) transformed_number = first_example.features.feature['num_col'].float_list.value[0] self.assertAlmostEqual(transformed_number, 23.0) # transformed category = row number in the vocab file. transformed_category = first_example.features.feature['cat_col'].int64_list.value[0] vocab = pd.read_csv( os.path.join(self.analysis_dir, 'vocab_cat_col.csv'), header=None, names=['label', 'count'], dtype=str) origional_category = vocab.iloc[transformed_category]['label'] self.assertEqual(origional_category, 'Monday') image_bytes = first_example.features.feature['img_col'].float_list.value self.assertEqual(len(image_bytes), 2048) self.assertTrue(any(x != 0.0 for x in image_bytes)) @unittest.skipIf(not RUN_CLOUD_TESTS, 'GCS access missing') def test_local_bigquery_transform(self): """Test transfrom locally, but the data comes from bigquery.""" # Make a BQ table, and insert 1 row. try: bucket_name = 'temp_pydatalab_test_%s' % uuid.uuid4().hex bucket_root = 'gs://%s' % bucket_name bucket = storage.Bucket(bucket_name) bucket.create() project_id = dl.Context.default().project_id dataset_name = 'test_transform_raw_data_%s' % uuid.uuid4().hex table_name = 'tmp_table' dataset = bq.Dataset((project_id, dataset_name)).create() table = bq.Table((project_id, dataset_name, table_name)) table.create([{'name': 'key_col', 'type': 'INTEGER'}, {'name': 'target_col', 'type': 'FLOAT'}, {'name': 'cat_col', 'type': 'STRING'}, {'name': 'num_col', 'type': 'FLOAT'}, {'name': 'img_col', 'type': 'STRING'}]) img1_file = os.path.join(self.source_dir, 'img1.jpg') dest_file = os.path.join(bucket_root, 'img1.jpg') file_io.copy(img1_file, dest_file) data = [ { 'key_col': 1, 'target_col': 1.0, 'cat_col': 'Monday', 'num_col': 23.0, 'img_col': dest_file, }, ] table.insert(data=data) cmd = ['python ' + os.path.join(CODE_PATH, 'transform.py'), '--bigquery=%s.%s.%s' % (project_id, dataset_name, table_name), '--analysis=' + self.analysis_dir, '--prefix=features', '--project-id=' + project_id, '--output=' + self.output_dir] print('cmd ', ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Read the tf record file. There should only be one file. record_filepath = os.path.join(self.output_dir, 'features-00000-of-00001.tfrecord.gz') options = tf.python_io.TFRecordOptions( compression_type=tf.python_io.TFRecordCompressionType.GZIP) serialized_examples = list(tf.python_io.tf_record_iterator(record_filepath, options=options)) self.assertEqual(len(serialized_examples), 1) example = tf.train.Example() example.ParseFromString(serialized_examples[0]) transformed_number = example.features.feature['num_col'].float_list.value[0] self.assertAlmostEqual(transformed_number, 23.0) transformed_category = example.features.feature['cat_col'].int64_list.value[0] self.assertEqual(transformed_category, 2) image_bytes = example.features.feature['img_col'].float_list.value self.assertEqual(len(image_bytes), 2048) self.assertTrue(any(x != 0.0 for x in image_bytes)) finally: dataset.delete(delete_contents=True) for obj in bucket.objects(): obj.delete() bucket.delete() if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_xgboost/run_all.sh ================================================ #! 
/bin/bash set -e echo '*** Running xgboost test_analyze.py ***' python test_analyze.py --verbose echo '*** Running xgboost test_transform.py ***' python test_transform.py --verbose echo 'Finished xgboost run_all.sh!' ================================================ FILE: solutionbox/ml_workbench/test_xgboost/test_analyze.py ================================================ from __future__ import absolute_import from __future__ import print_function import json import os import shutil import sys import tempfile import unittest import pandas as pd import six from tensorflow.python.lib.io import file_io # To make 'import analyze' work without installing it. CODE_PATH = os.path.abspath( os.path.join(os.path.dirname(__file__), '..', '', 'xgboost')) sys.path.append(CODE_PATH) from trainer import feature_analysis as feature_analysis # noqa: E303 import analyze # noqa: E303 class TestConfigFiles(unittest.TestCase): """Tests for checking the format between the schema and features files.""" def test_expand_defaults_do_nothing(self): schema = [{'name': 'col1', 'type': 'FLOAT'}, {'name': 'col2', 'type': 'INTEGER'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}} expected_features = { 'col1': {'transform': 'x', 'source_column': 'col1'}, 'col2': {'transform': 'y', 'source_column': 'col2'}} feature_analysis.expand_defaults(schema, features) # Nothing should change. self.assertEqual(expected_features, features) def test_expand_defaults_unknown_schema_type(self): schema = [{'name': 'col1', 'type': 'BYTES'}, {'name': 'col2', 'type': 'INTEGER'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}} with self.assertRaises(ValueError): feature_analysis.expand_defaults(schema, features) def test_expand_defaults(self): schema = [{'name': 'col1', 'type': 'FLOAT'}, {'name': 'col2', 'type': 'INTEGER'}, {'name': 'col3', 'type': 'STRING'}, {'name': 'col4', 'type': 'FLOAT'}, {'name': 'col5', 'type': 'INTEGER'}, {'name': 'col6', 'type': 'STRING'}] features = {'col1': {'transform': 'x'}, 'col2': {'transform': 'y'}, 'col3': {'transform': 'z'}} feature_analysis.expand_defaults(schema, features) self.assertEqual( features, {'col1': {'transform': 'x', 'source_column': 'col1'}, 'col2': {'transform': 'y', 'source_column': 'col2'}, 'col3': {'transform': 'z', 'source_column': 'col3'}, 'col4': {'transform': 'identity', 'source_column': 'col4'}, 'col5': {'transform': 'identity', 'source_column': 'col5'}, 'col6': {'transform': 'one_hot', 'source_column': 'col6'}}) class TestLocalAnalyze(unittest.TestCase): """Test local analyze functions.""" def test_numerics(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: file_io.write_string_to_file( input_file_path, '\n'.join(['%s,%s,%s' % (i, 10 * i + 0.5, i + 0.5) for i in range(100)])) schema = [{'name': 'col1', 'type': 'INTEGER'}, {'name': 'col2', 'type': 'FLOAT'}, {'name': 'col3', 'type': 'FLOAT'}] features = {'col1': {'transform': 'scale', 'source_column': 'col1'}, 'col2': {'transform': 'identity', 'source_column': 'col2'}, 'col3': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['num_examples'], 100) col = stats['column_stats']['col1'] self.assertAlmostEqual(col['max'], 99.0) self.assertAlmostEqual(col['min'], 0.0) self.assertAlmostEqual(col['mean'], 49.5) col = stats['column_stats']['col2'] 
self.assertAlmostEqual(col['max'], 990.5) self.assertAlmostEqual(col['min'], 0.5) self.assertAlmostEqual(col['mean'], 495.5) finally: shutil.rmtree(output_folder) def test_categorical(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: csv_file = ['red,apple', 'red,pepper', 'red,apple', 'blue,grape', 'blue,apple', 'green,pepper'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'color', 'type': 'STRING'}, {'name': 'type', 'type': 'STRING'}] features = {'color': {'transform': 'one_hot', 'source_column': 'color'}, 'type': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['color']['vocab_size'], 3) # Color column. vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'color')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['color', 'count']) expected_vocab = pd.DataFrame( {'color': ['red', 'blue', 'green'], 'count': [3, 2, 1]}, columns=['color', 'count']) pd.util.testing.assert_frame_equal(vocab, expected_vocab) finally: shutil.rmtree(output_folder) def test_text(self): output_folder = tempfile.mkdtemp() input_file_path = tempfile.mkstemp(dir=output_folder)[1] try: csv_file = ['the quick brown fox,cat1|cat2,true', 'quick brown brown chicken,cat2|cat3|cat4,false'] file_io.write_string_to_file( input_file_path, '\n'.join(csv_file)) schema = [{'name': 'col1', 'type': 'STRING'}, {'name': 'col2', 'type': 'STRING'}, {'name': 'col3', 'type': 'STRING'}] features = {'col1': {'transform': 'multi_hot', 'source_column': 'col1'}, 'col2': {'transform': 'multi_hot', 'source_column': 'col2', 'separator': '|'}, 'col3': {'transform': 'target'}} feature_analysis.run_local_analysis( output_folder, [input_file_path], schema, features) stats = json.loads( file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.STATS_FILE)).decode()) self.assertEqual(stats['column_stats']['col1']['vocab_size'], 5) self.assertEqual(stats['column_stats']['col2']['vocab_size'], 4) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col1')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col1', 'count']) # vocabs are sorted by count only col1_vocab = vocab['col1'].tolist() self.assertItemsEqual(col1_vocab[:2], ['brown', 'quick']) self.assertItemsEqual(col1_vocab[2:], ['chicken', 'fox', 'the']) self.assertEqual(vocab['count'].tolist(), [2, 2, 1, 1, 1]) vocab_str = file_io.read_file_to_string( os.path.join(output_folder, analyze.constant.VOCAB_ANALYSIS_FILE % 'col2')) vocab = pd.read_csv(six.StringIO(vocab_str), header=None, names=['col2', 'count']) # vocabs are sorted by count only col2_vocab = vocab['col2'].tolist() self.assertItemsEqual(col2_vocab, ['cat2', 'cat1', 'cat3', 'cat4']) self.assertEqual(vocab['count'].tolist(), [2, 1, 1, 1]) finally: shutil.rmtree(output_folder) if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/test_xgboost/test_transform.py ================================================ from __future__ import absolute_import from __future__ import print_function import json import os import pandas as pd from PIL import Image from six.moves.urllib.request import urlopen import 
subprocess import tempfile import unittest import xgboost as xgb from tensorflow.python.lib.io import file_io CODE_PATH = os.path.abspath(os.path.join( os.path.dirname(__file__), '..', 'xgboost')) class TestTransformRawData(unittest.TestCase): """Tests for applying a saved model""" @classmethod def setUpClass(cls): # Set up dirs. cls.working_dir = tempfile.mkdtemp() cls.source_dir = os.path.join(cls.working_dir, 'source') cls.analysis_dir = os.path.join(cls.working_dir, 'analysis') cls.output_dir = os.path.join(cls.working_dir, 'output') file_io.create_dir(cls.source_dir) # Make test image files. img1_file = os.path.join(cls.source_dir, 'img1.jpg') image1 = Image.new('RGB', size=(300, 300), color=(155, 0, 0)) image1.save(img1_file) img2_file = os.path.join(cls.source_dir, 'img2.jpg') image2 = Image.new('RGB', size=(50, 50), color=(125, 240, 0)) image2.save(img2_file) img3_file = os.path.join(cls.source_dir, 'img3.jpg') image3 = Image.new('RGB', size=(800, 600), color=(33, 55, 77)) image3.save(img3_file) # Download inception checkpoint. Note that gs url doesn't work because # we may not have gcloud signed in when running the test. url = ('https://storage.googleapis.com/cloud-ml-data/img/' + 'flower_photos/inception_v3_2016_08_28.ckpt') checkpoint_path = os.path.join(cls.working_dir, "checkpoint") response = urlopen(url) with open(checkpoint_path, 'wb') as f: f.write(response.read()) # Make csv input file cls.csv_input_filepath = os.path.join(cls.source_dir, 'input.csv') file_io.write_string_to_file( cls.csv_input_filepath, '1,Monday,23.0,red blue,%s\n' % img1_file + '0,Friday,18.0,green,%s\n' % img2_file + '0,Sunday,12.0,green red blue green,%s\n' % img3_file) # Call analyze.py to create analysis results. schema = [{'name': 'target_col', 'type': 'FLOAT'}, {'name': 'cat_col', 'type': 'STRING'}, {'name': 'num_col', 'type': 'FLOAT'}, {'name': 'text_col', 'type': 'STRING'}, {'name': 'img_col', 'type': 'STRING'}] schema_file = os.path.join(cls.source_dir, 'schema.json') file_io.write_string_to_file(schema_file, json.dumps(schema)) features = {'target_col': {'transform': 'target'}, 'cat_col': {'transform': 'one_hot'}, 'num_col': {'transform': 'identity'}, 'text_col': {'transform': 'multi_hot'}, 'img_col': {'transform': 'image_to_vec', 'checkpoint': checkpoint_path}} features_file = os.path.join(cls.source_dir, 'features.json') file_io.write_string_to_file(features_file, json.dumps(features)) cmd = ['python ' + os.path.join(CODE_PATH, 'analyze.py'), '--output=' + cls.analysis_dir, '--csv=' + cls.csv_input_filepath, '--schema=' + schema_file, '--features=' + features_file] subprocess.check_call(' '.join(cmd), shell=True) @classmethod def tearDownClass(cls): pass # shutil.rmtree(cls.working_dir) def test_local_csv_transform(self): """Test transfrom from local csv files.""" cmd = ['python ' + os.path.join(CODE_PATH, 'transform.py'), '--csv=' + self.csv_input_filepath, '--analysis=' + self.analysis_dir, '--prefix=features', '--output=' + self.output_dir] print('cmd ', ' '.join(cmd)) subprocess.check_call(' '.join(cmd), shell=True) # Verify transformed file. libsvm_filepath = os.path.join(self.output_dir, 'features-00000-of-00001.libsvm') dtrain = xgb.DMatrix(libsvm_filepath) self.assertTrue(2056, dtrain.num_col()) self.assertTrue(3, dtrain.num_row()) # Verify featuremap file. 
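# --- Illustrative aside (not part of the original test) ----------------------
# The xgboost transform step writes sparse LIBSVM-style rows,
# "label index:value index:value ...", which the checks around here load with
# xgb.DMatrix. A minimal standalone sketch, assuming the pinned xgboost version
# accepts a LIBSVM text file path; the file contents are invented.
def _libsvm_sketch():
  import tempfile
  import xgboost as xgb
  content = '1.0 1:23.0 2:1 5:0.5\n0.0 1:18.0 3:1 6:0.25\n'
  with tempfile.NamedTemporaryFile(mode='w', suffix='.libsvm', delete=False) as f:
    f.write(content)
    path = f.name
  dtrain = xgb.DMatrix(path)
  # Expected: 2 rows; column count is the largest feature index plus one (7).
  print('%d rows, %d cols' % (dtrain.num_row(), dtrain.num_col()))
# -----------------------------------------------------------------------------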
featuremap_filepath = os.path.join(self.output_dir, 'featuremap-00000-of-00001.txt') df = pd.read_csv(featuremap_filepath, names=['index', 'description']) pd.util.testing.assert_series_equal(pd.Series(range(1, 2056), name='index'), df['index']) expected_descriptions = ['cat_col=Sunday', 'cat_col=Monday', 'img_col image feature 1000', 'num_col', 'text_col has "blue"'] self.assertTrue(all(x in df['description'].values for x in expected_descriptions)) if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/ml_workbench/xgboost/__init__.py ================================================ ================================================ FILE: solutionbox/ml_workbench/xgboost/analyze.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import copy import json import os import sys import six import textwrap from tensorflow.python.lib.io import file_io from trainer import feature_transforms as constant from trainer import feature_analysis as feature_analysis def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments, including program name. Returns: An argparse Namespace object. Raises: ValueError: for bad parameters """ parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=textwrap.dedent("""\ Runs analysis on structured data and produces auxiliary files for training. The output files can also be used by the Transform step to materialize TF.Examples files, which for some problems can speed up training. Description of input files -------------------------- 1) If using csv files, the --schema parameter must be the file path to a schema file. The format of this file must be a valid BigQuery schema file, which is a JSON file containing a list of dicts. Consider the example schema file below: [ {"name": "column_name_1", "type": "integer"}, {"name": "column_name_2", "type": "float"}, {"name": "column_name_3", "type": "string"}, {"name": "column_name_4", "type": "string"}, ] Note that the column names in the csv file much match the order in the schema list. Also, we only support three BigQuery types ( integer, float, and string). If instead of csv files, --bigquery is used, the schema file is not needed as this program will extract it from the table directly. 2) --features is a file path to a file describing the transformations. Below is an example features file: { "column_name_1": {"transform": "scale"}, "column_name_3": {"transform": "target"}, "column_name_2": {"transform": "one_hot"}, "new_feature_name": {"transform": "multi_hot", "source_column": "column_name_4"}, } The format of the dict is `name`: `transform-dict` where the `name` is the name of the transformed feature. 
The `source_column` value lists what column in the input data is the source for this transformation. If `source_column` is missing, it is assumed the `name` is a source column and the transformed feature will have the same name as the input column. A list of supported `transform-dict`s for xgboost is below: {"transform": "identity"}: does nothing (for numerical columns). {"transform": "scale", "value": x}: scale a numerical column to [-a, a]. If value is missing, x defaults to 1. {"transform": "one_hot"}: makes a one-hot encoding of a string column. {"transform": "multi_hot", "separator": ' '}: makes a multi-hot encoding of a string column. {"transform": "image_to_vec", "checkpoint": "gs://b/o"}: From image gs url to embeddings. "checkpoint" is an inception v3 checkpoint. If absent, a default checkpoint is used. {"transform": "target"}: denotes what column is the target. If the schema type of this column is string, a one_hot encoding is automatically applied. If type is numerical, an identity transform is automatically applied. """)) parser.add_argument('--cloud', action='store_true', help='Analysis will use cloud services.') parser.add_argument('--output', metavar='DIR', type=str, required=True, help='GCS or local folder') input_group = parser.add_argument_group( title='Data Source Parameters', description='schema is only needed if using --csv') # CSV input input_group.add_argument('--csv', metavar='FILE', type=str, required=False, action='append', help='Input CSV absolute file paths. May contain a ' 'file pattern.') input_group.add_argument('--schema', metavar='FILE', type=str, required=False, help='Schema file path. Only required if using csv files') # Bigquery input input_group.add_argument('--bigquery', metavar='PROJECT_ID.DATASET.TABLE_NAME', type=str, required=False, help=('Must be in the form project.dataset.table_name')) parser.add_argument('--features', metavar='FILE', type=str, required=True, help='Features file path') args = parser.parse_args(args=argv[1:]) if args.cloud: if not args.output.startswith('gs://'): raise ValueError('--output must point to a location on GCS') if (args.csv and not all(x.startswith('gs://') for x in args.csv)): raise ValueError('--csv must point to a location on GCS') if args.schema and not args.schema.startswith('gs://'): raise ValueError('--schema must point to a location on GCS') if not args.cloud and args.bigquery: raise ValueError('--bigquery must be used with --cloud') if not ((args.bigquery and args.csv is None and args.schema is None) or (args.bigquery is None and args.csv and args.schema)): raise ValueError('either --csv and --schema must both' ' be set or just --bigquery is set') return args def run_cloud_analysis(output_dir, csv_file_pattern, bigquery_table, schema, features): """Use BigQuery to analyze input data. Only one of csv_file_pattern or bigquery_table should be non-None. Args: output_dir: output folder csv_file_pattern: list of csv file paths, may contain wildcards bigquery_table: project_id.dataset_name.table_name schema: schema list features: features config """ def _execute_sql(sql, table): """Runs a BigQuery job and downloads the results into local memory. Args: sql: a SQL string table: bq.ExternalDataSource or bq.Table Returns: A Pandas dataframe.
""" import google.datalab.bigquery as bq if isinstance(table, bq.ExternalDataSource): query = bq.Query(sql, data_sources={'csv_table': table}) else: query = bq.Query(sql) return query.execute().result().to_dataframe() feature_analysis.expand_defaults(schema, features) # features are updated. inverted_features = feature_analysis.invert_features(features) feature_analysis.check_schema_transforms_match(schema, inverted_features) import google.datalab.bigquery as bq if bigquery_table: table_name = '`%s`' % bigquery_table table = None else: table_name = 'csv_table' table = bq.ExternalDataSource( source=csv_file_pattern, schema=bq.Schema(schema)) # Make a copy of inverted_features and update the target transform to be # identity or one hot depending on the schema. inverted_features_target = copy.deepcopy(inverted_features) for name, transforms in six.iteritems(inverted_features_target): transform_set = {x['transform'] for x in transforms} if transform_set == set([constant.TARGET_TRANSFORM]): target_schema = next(col['type'].lower() for col in schema if col['name'] == name) if target_schema in constant.NUMERIC_SCHEMA: inverted_features_target[name] = [{'transform': constant.IDENTITY_TRANSFORM}] else: inverted_features_target[name] = [{'transform': constant.ONE_HOT_TRANSFORM}] numerical_vocab_stats = {} for col_name, transform_set in six.iteritems(inverted_features_target): sys.stdout.write('Analyzing column %s...\n' % col_name) sys.stdout.flush() # All transforms in transform_set require the same analysis. So look # at the first transform. transform = next(iter(transform_set)) if (transform['transform'] in constant.CATEGORICAL_TRANSFORMS or transform['transform'] in constant.TEXT_TRANSFORMS): if transform['transform'] in constant.TEXT_TRANSFORMS: # Split strings on space, then extract labels and how many rows each # token is in. This is done by making two temp tables: # SplitTable: each text row is made into an array of strings. The # array may contain repeat tokens # TokenTable: SplitTable with repeated tokens removed per row. # Then to flatten the arrays, TokenTable has to be joined with itself. 
# See the sections 'Flattening Arrays' and 'Filtering Arrays' at # https://cloud.google.com/bigquery/docs/reference/standard-sql/arrays separator = transform.get('separator', ' ') sql = ('WITH SplitTable AS ' ' (SELECT SPLIT({name}, \'{separator}\') as token_array FROM {table}), ' ' TokenTable AS ' ' (SELECT ARRAY(SELECT DISTINCT x ' ' FROM UNNEST(token_array) AS x) AS unique_tokens_per_row ' ' FROM SplitTable) ' 'SELECT token, COUNT(token) as token_count ' 'FROM TokenTable ' 'CROSS JOIN UNNEST(TokenTable.unique_tokens_per_row) as token ' 'WHERE LENGTH(token) > 0 ' 'GROUP BY token ' 'ORDER BY token_count DESC, token ASC').format(separator=separator, name=col_name, table=table_name) else: # Extract label and frequency sql = ('SELECT {name} as token, count(*) as count ' 'FROM {table} ' 'WHERE {name} IS NOT NULL ' 'GROUP BY {name} ' 'ORDER BY count DESC, token ASC').format(name=col_name, table=table_name) df = _execute_sql(sql, table) # Save the vocab csv_string = df.to_csv(index=False, header=False) file_io.write_string_to_file( os.path.join(output_dir, constant.VOCAB_ANALYSIS_FILE % col_name), csv_string) numerical_vocab_stats[col_name] = {'vocab_size': len(df)} # free memeory del csv_string del df elif transform['transform'] in constant.NUMERIC_TRANSFORMS: # get min/max/average sql = ('SELECT max({name}) as max_value, min({name}) as min_value, ' 'avg({name}) as avg_value from {table}').format(name=col_name, table=table_name) df = _execute_sql(sql, table) numerical_vocab_stats[col_name] = {'min': df.iloc[0]['min_value'], 'max': df.iloc[0]['max_value'], 'mean': df.iloc[0]['avg_value']} sys.stdout.write('column %s analyzed.\n' % col_name) sys.stdout.flush() # get num examples sql = 'SELECT count(*) as num_examples from {table}'.format(table=table_name) df = _execute_sql(sql, table) num_examples = df.iloc[0]['num_examples'] # Write the stats file. stats = {'column_stats': numerical_vocab_stats, 'num_examples': num_examples} file_io.write_string_to_file( os.path.join(output_dir, constant.STATS_FILE), json.dumps(stats, indent=2, separators=(',', ': '))) feature_analysis.save_schema_features(schema, features, output_dir) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) if args.schema: schema = json.loads( file_io.read_file_to_string(args.schema).decode()) else: import google.datalab.bigquery as bq schema = bq.Table(args.bigquery).schema._bq_schema features = json.loads( file_io.read_file_to_string(args.features).decode()) file_io.recursive_create_dir(args.output) if args.cloud: run_cloud_analysis( output_dir=args.output, csv_file_pattern=args.csv, bigquery_table=args.bigquery, schema=schema, features=features) else: feature_analysis.run_local_analysis( output_dir=args.output, csv_file_pattern=args.csv, schema=schema, features=features) if __name__ == '__main__': main() ================================================ FILE: solutionbox/ml_workbench/xgboost/setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. # This setup file is used when running cloud training or cloud dataflow jobs. from setuptools import setup, find_packages setup( name='trainer', version='1.0.0', packages=find_packages(), description='Google Cloud Datalab helper sub-package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", long_description=""" """, install_requires=[ 'tensorflow==1.15.2', 'protobuf==3.4.0', 'pillow==6.2.0', # ML Engine does not have PIL installed 'xgboost==0.6a2', ], package_data={ }, data_files=[], ) ================================================ FILE: solutionbox/ml_workbench/xgboost/trainer/__init__.py ================================================ ================================================ FILE: solutionbox/ml_workbench/xgboost/trainer/feature_analysis.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import csv import json import os import pandas as pd import sys import six from tensorflow.python.lib.io import file_io from . import feature_transforms as constant def check_schema_transforms_match(schema, inverted_features): """Checks that the transform and schema do not conflict. Args: schema: schema list inverted_features: inverted_features dict Raises: ValueError if transform cannot be applied given schema type. """ num_target_transforms = 0 for col_schema in schema: col_name = col_schema['name'] col_type = col_schema['type'].lower() # Check each transform and schema are compatible if col_name in inverted_features: for transform in inverted_features[col_name]: transform_name = transform['transform'] if transform_name == constant.TARGET_TRANSFORM: num_target_transforms += 1 continue elif col_type in constant.NUMERIC_SCHEMA: if transform_name not in constant.NUMERIC_TRANSFORMS: raise ValueError( 'Transform %s not supported by schema %s' % (transform_name, col_type)) elif col_type == constant.STRING_SCHEMA: if (transform_name not in constant.CATEGORICAL_TRANSFORMS + constant.TEXT_TRANSFORMS and transform_name != constant.IMAGE_TRANSFORM): raise ValueError( 'Transform %s not supported by schema %s' % (transform_name, col_type)) else: raise ValueError('Unsupported schema type %s' % col_type) # Check each transform is compatible for the same source column. # inverted_features[col_name] should belong to exactly 1 of the 5 groups. 
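# --- Illustrative aside (not part of the original source) --------------------
# A minimal sketch of the rule enforced just below, using plain Python sets:
# transforms that read the same source column must all come from one family,
# and mixing families (e.g. 'scale' with 'one_hot') is rejected. The transform
# names here mirror the constants above; the sets are invented.
def _family_rule_sketch():
  numeric = {'identity', 'scale'}
  categorical = {'one_hot'}
  ok_set = {'identity', 'scale'}   # both numeric -> allowed
  bad_set = {'scale', 'one_hot'}   # mixes families -> rejected
  def in_exactly_one_family(transform_set):
    return 1 == sum([transform_set.issubset(numeric),
                     transform_set.issubset(categorical)])
  assert in_exactly_one_family(ok_set)
  assert not in_exactly_one_family(bad_set)
# -----------------------------------------------------------------------------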
if col_name in inverted_features: transform_set = {x['transform'] for x in inverted_features[col_name]} if 1 != sum([transform_set.issubset(set(constant.NUMERIC_TRANSFORMS)), transform_set.issubset(set(constant.CATEGORICAL_TRANSFORMS)), transform_set.issubset(set(constant.TEXT_TRANSFORMS)), transform_set.issubset(set([constant.IMAGE_TRANSFORM])), transform_set.issubset(set([constant.TARGET_TRANSFORM]))]): message = """ The source column of a feature can only be used in multiple features within the same family of transforms. The families are 1) text transformations: %s 2) categorical transformations: %s 3) numerical transformations: %s 4) image transformations: %s 5) target transform: %s Any column can also be a key column. But column %s is used by transforms %s. """ % (str(constant.TEXT_TRANSFORMS), str(constant.CATEGORICAL_TRANSFORMS), str(constant.NUMERIC_TRANSFORMS), constant.IMAGE_TRANSFORM, constant.TARGET_TRANSFORM, col_name, str(transform_set)) raise ValueError(message) if num_target_transforms != 1: raise ValueError('Must have exactly one target transform') def save_schema_features(schema, features, output): # Save a copy of the schema and features in the output folder. file_io.write_string_to_file( os.path.join(output, constant.SCHEMA_FILE), json.dumps(schema, indent=2)) file_io.write_string_to_file( os.path.join(output, constant.FEATURES_FILE), json.dumps(features, indent=2)) def expand_defaults(schema, features): """Add to features any default transformations. Not every column in the schema has an explicit feature transformation listed in the features file. For these columns, add a default transformation based on the schema's type. The features dict is modified by this function call. After this function call, every column in schema is used in a feature, and every feature uses a column in the schema. Args: schema: schema list features: features dict Raises: ValueError: if transform cannot be applied given schema type. """ schema_names = [x['name'] for x in schema] # Add missing source columns for name, transform in six.iteritems(features): if 'source_column' not in transform: transform['source_column'] = name # Check source columns are in the schema and collect which are used. used_schema_columns = [] for name, transform in six.iteritems(features): if transform['source_column'] not in schema_names: raise ValueError('source column %s is not in the schema for transform %s' % (transform['source_column'], name)) used_schema_columns.append(transform['source_column']) # Update default transformation based on schema. for col_schema in schema: schema_name = col_schema['name'] schema_type = col_schema['type'].lower() if schema_type not in constant.NUMERIC_SCHEMA + [constant.STRING_SCHEMA]: raise ValueError(('Only the following schema types are supported: %s' % ' '.join(constant.NUMERIC_SCHEMA + [constant.STRING_SCHEMA]))) if schema_name not in used_schema_columns: # add the default transform to the features if schema_type in constant.NUMERIC_SCHEMA: features[schema_name] = { 'transform': constant.DEFAULT_NUMERIC_TRANSFORM, 'source_column': schema_name} elif schema_type == constant.STRING_SCHEMA: features[schema_name] = { 'transform': constant.DEFAULT_CATEGORICAL_TRANSFORM, 'source_column': schema_name} else: raise NotImplementedError('Unknown type %s' % schema_type) # TODO(qimingj): introduce the notion of an analysis plan/classes if we # support more complicated transforms like binning by quartiles. def invert_features(features): """Make a dict in the form source column : set of transforms.
Note that the key transform is removed. """ inverted_features = collections.defaultdict(list) for transform in six.itervalues(features): source_column = transform['source_column'] inverted_features[source_column].append(transform) return dict(inverted_features) # convert from defaultdict to dict def run_local_analysis(output_dir, csv_file_pattern, schema, features): """Use pandas to analyze csv files. Produces a stats file and vocab files. Args: output_dir: output folder csv_file_pattern: list of csv file paths, may contain wildcards schema: CSV schema list features: features config Raises: ValueError: on unknown transfrorms/schemas """ sys.stdout.write('Expanding any file patterns...\n') sys.stdout.flush() header = [column['name'] for column in schema] input_files = [] for file_pattern in csv_file_pattern: input_files.extend(file_io.get_matching_files(file_pattern)) sys.stdout.write('file list computed.\n') sys.stdout.flush() expand_defaults(schema, features) # features are updated. inverted_features = invert_features(features) check_schema_transforms_match(schema, inverted_features) # Make a copy of inverted_features and update the target transform to be # identity or one hot depending on the schema. inverted_features_target = copy.deepcopy(inverted_features) for name, transforms in six.iteritems(inverted_features_target): transform_set = {x['transform'] for x in transforms} if transform_set == set([constant.TARGET_TRANSFORM]): target_schema = next(col['type'].lower() for col in schema if col['name'] == name) if target_schema in constant.NUMERIC_SCHEMA: inverted_features_target[name] = [{'transform': constant.IDENTITY_TRANSFORM}] else: inverted_features_target[name] = [{'transform': constant.ONE_HOT_TRANSFORM}] # initialize the results def _init_numerical_results(): return {'min': float('inf'), 'max': float('-inf'), 'count': 0, 'sum': 0.0} numerical_results = collections.defaultdict(_init_numerical_results) vocabs = collections.defaultdict(lambda: collections.defaultdict(int)) num_examples = 0 # for each file, update the numerical stats from that file, and update the set # of unique labels. for input_file in input_files: sys.stdout.write('Analyzing file %s...\n' % input_file) sys.stdout.flush() with file_io.FileIO(input_file, 'r') as f: for line in csv.reader(f): if len(header) != len(line): raise ValueError('Schema has %d columns but a csv line only has %d columns.' % (len(header), len(line))) parsed_line = dict(zip(header, line)) num_examples += 1 for col_name, transform_set in six.iteritems(inverted_features_target): # All transforms in transform_set require the same analysis. So look # at the first transform. transform = next(iter(transform_set)) if transform['transform'] in constant.TEXT_TRANSFORMS: separator = transform.get('separator', ' ') split_strings = parsed_line[col_name].split(separator) # If a label is in the row N times, increase it's vocab count by 1. # This is needed for TFIDF, but it's also an interesting stat. 
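# --- Illustrative aside (not part of the original source) --------------------
# A standalone sketch of the counting done below: wrapping the split tokens in
# set() gives a document-frequency style count (number of rows containing the
# token), not a raw token count. The rows here are invented.
def _doc_frequency_sketch():
  import collections
  rows = ['quick brown brown fox', 'brown chicken']
  vocab = collections.defaultdict(int)
  for row in rows:
    for token in set(row.split(' ')):   # 'brown' counted once per row
      vocab[token] += 1
  assert vocab['brown'] == 2            # would be 3 if duplicates were kept
# -----------------------------------------------------------------------------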
for one_label in set(split_strings): # Filter out empty strings if one_label: vocabs[col_name][one_label] += 1 elif transform['transform'] in constant.CATEGORICAL_TRANSFORMS: if parsed_line[col_name]: vocabs[col_name][parsed_line[col_name]] += 1 elif transform['transform'] in constant.NUMERIC_TRANSFORMS: if not parsed_line[col_name].strip(): continue numerical_results[col_name]['min'] = ( min(numerical_results[col_name]['min'], float(parsed_line[col_name]))) numerical_results[col_name]['max'] = ( max(numerical_results[col_name]['max'], float(parsed_line[col_name]))) numerical_results[col_name]['count'] += 1 numerical_results[col_name]['sum'] += float(parsed_line[col_name]) sys.stdout.write('file %s analyzed.\n' % input_file) sys.stdout.flush() # Write the vocab files. Each label is on its own line. vocab_sizes = {} for name, label_count in six.iteritems(vocabs): # df is now: # label1,count # label2,count # ... # where label1 is the most frequent label, and label2 is the 2nd most, etc. df = pd.DataFrame([{'label': label, 'count': count} for label, count in sorted(six.iteritems(label_count), key=lambda x: x[1], reverse=True)], columns=['label', 'count']) csv_string = df.to_csv(index=False, header=False) file_io.write_string_to_file( os.path.join(output_dir, constant.VOCAB_ANALYSIS_FILE % name), csv_string) vocab_sizes[name] = {'vocab_size': len(label_count)} # Update numerical_results to just have min/min/mean for col_name in numerical_results: if float(numerical_results[col_name]['count']) == 0: raise ValueError('Column %s has a zero count' % col_name) mean = (numerical_results[col_name]['sum'] / float(numerical_results[col_name]['count'])) del numerical_results[col_name]['sum'] del numerical_results[col_name]['count'] numerical_results[col_name]['mean'] = mean # Write the stats file. numerical_results.update(vocab_sizes) stats = {'column_stats': numerical_results, 'num_examples': num_examples} file_io.write_string_to_file( os.path.join(output_dir, constant.STATS_FILE), json.dumps(stats, indent=2, separators=(',', ': '))) save_schema_features(schema, features, output_dir) ================================================ FILE: solutionbox/ml_workbench/xgboost/trainer/feature_transforms.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import base64 import cStringIO import json import os from PIL import Image import pandas as pd import six import shutil import tensorflow as tf import tempfile from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib import lookup from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3 from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope from tensorflow.python.lib.io import file_io # ------------------------------------------------------------------------------ # public constants. Changing these could break user's code # ------------------------------------------------------------------------------ # Individual transforms IDENTITY_TRANSFORM = 'identity' SCALE_TRANSFORM = 'scale' ONE_HOT_TRANSFORM = 'one_hot' MULTI_HOT_TRANSFORM = 'multi_hot' TARGET_TRANSFORM = 'target' IMAGE_TRANSFORM = 'image_to_vec' # ------------------------------------------------------------------------------ # internal constants. 
# ------------------------------------------------------------------------------ # Files SCHEMA_FILE = 'schema.json' FEATURES_FILE = 'features.json' STATS_FILE = 'stats.json' VOCAB_ANALYSIS_FILE = 'vocab_%s.csv' # Transform collections NUMERIC_TRANSFORMS = [IDENTITY_TRANSFORM, SCALE_TRANSFORM] CATEGORICAL_TRANSFORMS = [ONE_HOT_TRANSFORM] TEXT_TRANSFORMS = [MULTI_HOT_TRANSFORM] # If the features file is missing transforms, apply these. DEFAULT_NUMERIC_TRANSFORM = IDENTITY_TRANSFORM DEFAULT_CATEGORICAL_TRANSFORM = ONE_HOT_TRANSFORM # BigQuery Schema values supported INTEGER_SCHEMA = 'integer' FLOAT_SCHEMA = 'float' STRING_SCHEMA = 'string' NUMERIC_SCHEMA = [INTEGER_SCHEMA, FLOAT_SCHEMA] # Inception Checkpoint INCEPTION_V3_CHECKPOINT = 'gs://cloud-ml-data/img/flower_photos/inception_v3_2016_08_28.ckpt' INCEPTION_EXCLUDED_VARIABLES = ['InceptionV3/AuxLogits', 'InceptionV3/Logits', 'global_step'] _img_buf = cStringIO.StringIO() Image.new('RGB', (16, 16)).save(_img_buf, 'jpeg') IMAGE_DEFAULT_STRING = base64.urlsafe_b64encode(_img_buf.getvalue()) IMAGE_BOTTLENECK_TENSOR_SIZE = 2048 # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ # start of transform functions # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ def _scale(x, min_x_value, max_x_value, output_min, output_max): """Scale a column to [output_min, output_max]. Assumes the columns's range is [min_x_value, max_x_value]. If this is not true at training or prediction time, the output value of this scale could be outside the range [output_min, output_max]. Raises: ValueError: if min_x_value = max_x_value, as the column is constant. """ if round(min_x_value - max_x_value, 7) == 0: # There is something wrong with the data. # Why round to 7 places? It's the same as unittest's assertAlmostEqual. raise ValueError('In make_scale_tito, min_x_value == max_x_value') def _scale(x): min_x_valuef = tf.to_float(min_x_value) max_x_valuef = tf.to_float(max_x_value) output_minf = tf.to_float(output_min) output_maxf = tf.to_float(output_max) return ((((tf.to_float(x) - min_x_valuef) * (output_maxf - output_minf)) / (max_x_valuef - min_x_valuef)) + output_minf) return _scale(x) def _string_to_int(x, vocab): """Given a vocabulary and a string tensor `x`, maps `x` into an int tensor. Args: x: A `Column` representing a string value. vocab: list of strings. Returns: A `Column` where each string value is mapped to an integer representing its index in the vocab. Out of vocab values are mapped to len(vocab). """ def _map_to_int(x): """Maps string tensor into indexes using vocab. Args: x : a Tensor/SparseTensor of string. Returns: a Tensor/SparseTensor of indexes (int) of the same shape as x. """ table = lookup.index_table_from_tensor( vocab, default_value=len(vocab)) return table.lookup(x) return _map_to_int(x) def _make_image_to_vec_tito(feature_name, tmp_dir=None, checkpoint=None): """Creates a tensor-in-tensor-out function that produces embeddings from image bytes. Image to embedding is implemented with Tensorflow's inception v3 model and a pretrained checkpoint. It returns 1x2048 'PreLogits' embeddings for each image. Args: feature_name: The name of the feature. Used only to identify the image tensors so we can get gradients for probe in image prediction explaining. tmp_dir: a local directory that is used for downloading the checkpoint. 
If non, a temp folder will be made and deleted. checkpoint: the inception v3 checkpoint gs or local path. If None, default checkpoint is used. Returns: a tensor-in-tensor-out function that takes image string tensor and returns embeddings. """ def _image_to_vec(image_str_tensor): def _decode_and_resize(image_tensor): """Decodes jpeg string, resizes it and returns a uint8 tensor.""" # These constants are set by Inception v3's expectations. height = 299 width = 299 channels = 3 image_tensor = tf.where(tf.equal(image_tensor, ''), IMAGE_DEFAULT_STRING, image_tensor) # Fork by whether image_tensor value is a file path, or a base64 encoded string. slash_positions = tf.equal(tf.string_split([image_tensor], delimiter="").values, '/') is_file_path = tf.cast(tf.count_nonzero(slash_positions), tf.bool) # The following two functions are required for tf.cond. Note that we can not replace them # with lambda. According to TF docs, if using inline lambda, both branches of condition # will be executed. The workaround is to use a function call. def _read_file(): return tf.read_file(image_tensor) def _decode_base64(): return tf.decode_base64(image_tensor) image = tf.cond(is_file_path, lambda: _read_file(), lambda: _decode_base64()) image = tf.image.decode_jpeg(image, channels=channels) image = tf.expand_dims(image, 0) image = tf.image.resize_bilinear(image, [height, width], align_corners=False) image = tf.squeeze(image, squeeze_dims=[0]) image = tf.cast(image, dtype=tf.uint8) return image # The CloudML Prediction API always "feeds" the Tensorflow graph with # dynamic batch sizes e.g. (?,). decode_jpeg only processes scalar # strings because it cannot guarantee a batch of images would have # the same output size. We use tf.map_fn to give decode_jpeg a scalar # string from dynamic batches. image = tf.map_fn(_decode_and_resize, image_str_tensor, back_prop=False, dtype=tf.uint8) image = tf.image.convert_image_dtype(image, dtype=tf.float32) # "gradients_[feature_name]" will be used for computing integrated gradients. image = tf.identity(image, name='gradients_' + feature_name) image = tf.subtract(image, 0.5) inception_input = tf.multiply(image, 2.0) # Build Inception layers, which expect a tensor of type float from [-1, 1) # and shape [batch_size, height, width, channels]. with tf.contrib.slim.arg_scope(inception_v3_arg_scope()): _, end_points = inception_v3(inception_input, is_training=False) embeddings = end_points['PreLogits'] inception_embeddings = tf.squeeze(embeddings, [1, 2], name='SpatialSqueeze') return inception_embeddings def _tito_from_checkpoint(tito_in, checkpoint, exclude): """ Create an all-constants tito function from an original tito function. Given a tensor-in-tensor-out function which contains variables and a checkpoint path, create a new tensor-in-tensor-out function which includes only constants, and can be used in tft.map. """ def _tito_out(tensor_in): checkpoint_dir = tmp_dir if tmp_dir is None: checkpoint_dir = tempfile.mkdtemp() g = tf.Graph() with g.as_default(): si = tf.placeholder(dtype=tensor_in.dtype, shape=tensor_in.shape, name=tensor_in.op.name) so = tito_in(si) all_vars = tf.contrib.slim.get_variables_to_restore(exclude=exclude) saver = tf.train.Saver(all_vars) # Downloading the checkpoint from GCS to local speeds up saver.restore() a lot. 
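# --- Illustrative aside (not part of the original source) --------------------
# A tiny, standalone sketch of the "freeze then splice" pattern used in the
# code below: convert a graph's variables to constants, then graft the frozen
# graph onto a new input tensor with tf.import_graph_def(input_map=...).
# Tensor names and values here are invented; TF 1.x graph-mode APIs assumed.
def _freeze_and_splice_sketch():
  import tensorflow as tf
  g = tf.Graph()
  with g.as_default():
    x = tf.placeholder(tf.float32, shape=[None], name='x')
    w = tf.Variable(2.0, name='w')
    y = tf.multiply(x, w, name='y')
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Replace the variable 'w' with a constant in the exported graph def.
      frozen = tf.graph_util.convert_variables_to_constants(
          sess, g.as_graph_def(), ['y'])
  with tf.Graph().as_default():
    new_x = tf.constant([1.0, 3.0])
    # Feed the frozen graph from new_x instead of the original placeholder.
    y_out, = tf.import_graph_def(frozen, input_map={'x:0': new_x},
                                 return_elements=['y:0'])
    with tf.Session() as sess:
      print(sess.run(y_out))  # expected: [2.0, 6.0]
# -----------------------------------------------------------------------------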
checkpoint_tmp = os.path.join(checkpoint_dir, 'checkpoint') with file_io.FileIO(checkpoint, 'r') as f_in, file_io.FileIO(checkpoint_tmp, 'w') as f_out: f_out.write(f_in.read()) with tf.Session() as sess: saver.restore(sess, checkpoint_tmp) output_graph_def = tf.graph_util.convert_variables_to_constants(sess, g.as_graph_def(), [so.op.name]) file_io.delete_file(checkpoint_tmp) if tmp_dir is None: shutil.rmtree(checkpoint_dir) tensors_out = tf.import_graph_def(output_graph_def, input_map={si.name: tensor_in}, return_elements=[so.name]) return tensors_out[0] return _tito_out if not checkpoint: checkpoint = INCEPTION_V3_CHECKPOINT return _tito_from_checkpoint(_image_to_vec, checkpoint, INCEPTION_EXCLUDED_VARIABLES) # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ # end of transform functions # ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------ def make_preprocessing_fn(output_dir, features, keep_target): """Makes a preprocessing function. Args: output_dir: folder path that contains the vocab and stats files. features: the features dict Returns: a function that takes a dict of input tensors """ def preprocessing_fn(inputs): """Preprocessing function. Args: inputs: dictionary of raw input tensors Returns: A dictionary of transformed tensors """ stats = json.loads( file_io.read_file_to_string( os.path.join(output_dir, STATS_FILE)).decode()) result = {} for name, transform in six.iteritems(features): transform_name = transform['transform'] source_column = transform['source_column'] if transform_name == TARGET_TRANSFORM: if not keep_target: continue if file_io.file_exists(os.path.join(output_dir, VOCAB_ANALYSIS_FILE % source_column)): transform_name = 'one_hot' else: transform_name = 'identity' if transform_name == 'identity': result[name] = inputs[source_column] elif transform_name == 'scale': result[name] = _scale( inputs[name], min_x_value=stats['column_stats'][source_column]['min'], max_x_value=stats['column_stats'][source_column]['max'], output_min=transform.get('value', 1) * (-1), output_max=transform.get('value', 1)) elif transform_name in [ONE_HOT_TRANSFORM, MULTI_HOT_TRANSFORM]: vocab, ex_count = read_vocab_file( os.path.join(output_dir, VOCAB_ANALYSIS_FILE % source_column)) if transform_name == MULTI_HOT_TRANSFORM: separator = transform.get('separator', ' ') tokens = tf.string_split(inputs[source_column], separator) result[name] = _string_to_int(tokens, vocab) else: result[name] = _string_to_int(inputs[source_column], vocab) elif transform_name == IMAGE_TRANSFORM: make_image_to_vec_fn = _make_image_to_vec_tito( name, checkpoint=transform.get('checkpoint', None)) result[name] = make_image_to_vec_fn(inputs[source_column]) else: raise ValueError('unknown transform %s' % transform_name) return result return preprocessing_fn def csv_header_and_defaults(features, schema, stats, keep_target): """Gets csv header and default lists.""" target_name = get_target_name(features) if keep_target and not target_name: raise ValueError('Cannot find target transform') csv_header = [] record_defaults = [] for col in schema: if not keep_target and col['name'] == target_name: continue # Note that numerical key columns do not have a stats entry, hence the use # of get(col['name'], {}) csv_header.append(col['name']) if col['type'].lower() == INTEGER_SCHEMA: dtype = tf.int64 default = 
int(stats['column_stats'].get(col['name'], {}).get('mean', 0)) elif col['type'].lower() == FLOAT_SCHEMA: dtype = tf.float32 default = float(stats['column_stats'].get(col['name'], {}).get('mean', 0.0)) else: dtype = tf.string default = '' record_defaults.append(tf.constant([default], dtype=dtype)) return csv_header, record_defaults def build_csv_serving_tensors_for_transform_step(analysis_path, features, schema, stats, keep_target): """Builds a serving function starting from raw csv. This should only be used by transform.py (the transform step), and the For image columns, the image should be a base64 string encoding the image. The output of this function will transform that image to a 2048 long vector using the inception model. """ csv_header, record_defaults = csv_header_and_defaults(features, schema, stats, keep_target) placeholder = tf.placeholder(dtype=tf.string, shape=(None,), name='csv_input_placeholder') tensors = tf.decode_csv(placeholder, record_defaults) raw_features = dict(zip(csv_header, tensors)) transform_fn = make_preprocessing_fn(analysis_path, features, keep_target) transformed_tensors = transform_fn(raw_features) transformed_features = {} # Expand the dims of non-sparse tensors for k, v in six.iteritems(transformed_tensors): if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1: transformed_features[k] = tf.expand_dims(v, -1) else: transformed_features[k] = v return input_fn_utils.InputFnOps( transformed_features, None, {"csv_example": placeholder}) def get_target_name(features): for name, transform in six.iteritems(features): if transform['transform'] == TARGET_TRANSFORM: return name return None def read_vocab_file(file_path): """Reads a vocab file to memeory. Args: file_path: Each line of the vocab is in the form "token,example_count" Returns: Two lists, one for the vocab, and one for just the example counts. """ with file_io.FileIO(file_path, 'r') as f: vocab_pd = pd.read_csv( f, header=None, names=['vocab', 'count'], dtype=str, # Prevent pd from converting numerical categories. na_filter=False) # Prevent pd from converting 'NA' to a NaN. vocab = vocab_pd['vocab'].tolist() ex_count = vocab_pd['count'].astype(int).tolist() return vocab, ex_count def get_transformed_feature_indices(features, stats): """Returns information about the transformed features. Returns: List in the from [(transformed_feature_name, {size: int, index_start: int})] """ feature_indices = [] index_start = 1 for name, transform in sorted(six.iteritems(features)): transform_name = transform['transform'] source_column = transform['source_column'] info = {} if transform_name in [IDENTITY_TRANSFORM, SCALE_TRANSFORM]: info['size'] = 1 elif transform_name in [ONE_HOT_TRANSFORM, MULTI_HOT_TRANSFORM]: info['size'] = stats['column_stats'][source_column]['vocab_size'] elif transform_name == IMAGE_TRANSFORM: info['size'] = IMAGE_BOTTLENECK_TENSOR_SIZE elif transform_name == TARGET_TRANSFORM: info['size'] = 0 else: raise ValueError('xgboost does not support transform "%s"' % transform) info['index_start'] = index_start index_start += info['size'] feature_indices.append((name, info)) return feature_indices def create_feature_map(features, feature_indices, output_dir): """Returns feature_map about the transformed features. feature_map includes information such as: 1, cat1=0 2, cat1=1 3, numeric1 ... 
Returns: List in the from [(index, feature_description)] """ feature_map = [] for name, info in feature_indices: transform_name = features[name]['transform'] source_column = features[name]['source_column'] if transform_name in [IDENTITY_TRANSFORM, SCALE_TRANSFORM]: feature_map.append((info['index_start'], name)) elif transform_name in [ONE_HOT_TRANSFORM, MULTI_HOT_TRANSFORM]: vocab, _ = read_vocab_file( os.path.join(output_dir, VOCAB_ANALYSIS_FILE % source_column)) for i, word in enumerate(vocab): if transform_name == ONE_HOT_TRANSFORM: feature_map.append((info['index_start'] + i, '%s=%s' % (source_column, word))) elif transform_name == MULTI_HOT_TRANSFORM: feature_map.append((info['index_start'] + i, '%s has "%s"' % (source_column, word))) elif transform_name == IMAGE_TRANSFORM: for i in range(info['size']): feature_map.append((info['index_start'] + i, '%s image feature %d' % (source_column, i))) return feature_map ================================================ FILE: solutionbox/ml_workbench/xgboost/trainer/task.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import json import math import multiprocessing import os import re import sys import six import tensorflow as tf from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.python.client import session as tf_session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.ops import resources from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variables from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver from tensorflow.python.util import compat from . import feature_transforms from . import feature_analysis # Constants for the Prediction Graph fetch tensors. PG_TARGET = 'target' # from input PG_REGRESSION_PREDICTED_TARGET = 'predicted' PG_CLASSIFICATION_FIRST_LABEL = 'predicted' PG_CLASSIFICATION_FIRST_SCORE = 'probability' PG_CLASSIFICATION_LABEL_TEMPLATE = 'predicted_%s' PG_CLASSIFICATION_SCORE_TEMPLATE = 'probability_%s' class DatalabParser(): """An arg parser that also prints package specific args with --datalab-help. When using Datalab magic's to run this trainer, it prints it's own help menu that describes the required options that are common to all trainers. 
In order to print just the options that are unique to this trainer, datalab calls this file with --datalab-help. This class implements --datalab-help by building a list of help strings that only includes the unique parameters. """ def __init__(self, epilog=None, datalab_epilog=None): self.full_parser = argparse.ArgumentParser(epilog=epilog) self.datalab_help = [] self.datalab_epilog = datalab_epilog # Datalab help string self.full_parser.add_argument( '--datalab-help', action=self.make_datalab_help_action(), help='Show a smaller help message for DataLab only and exit') # The arguments added here are required to exist by Datalab's "%%ml train" magics. self.full_parser.add_argument( '--train', type=str, required=True, action='append', metavar='FILE') self.full_parser.add_argument( '--eval', type=str, required=True, action='append', metavar='FILE') self.full_parser.add_argument('--job-dir', type=str, required=True) self.full_parser.add_argument( '--analysis', type=str, metavar='ANALYSIS_OUTPUT_DIR', help=('Output folder of analysis. Should contain the schema, stats, and ' 'vocab files. Path must be on GCS if running cloud training. ' + 'If absent, --schema and --features must be provided and ' + 'the master trainer will do analysis locally.')) self.full_parser.add_argument( '--transform', action='store_true', default=False, help='If used, input data is raw csv that needs transformation. If analysis ' + 'is required to run in trainer, this is automatically set to true.') self.full_parser.add_argument( '--schema', type=str, help='Schema of the training csv file. Only needed if analysis is required.') self.full_parser.add_argument( '--features', type=str, help='Feature transform config. Only needed if analysis is required.') def make_datalab_help_action(self): """Custom action for --datalab-help. The action outputs the package specific parameters and will be part of "%%ml train" help string. """ datalab_help = self.datalab_help epilog = self.datalab_epilog class _CustomAction(argparse.Action): def __init__(self, option_strings, dest, help=None): super(_CustomAction, self).__init__( option_strings=option_strings, dest=dest, nargs=0, help=help) def __call__(self, parser, args, values, option_string=None): print('\n\n'.join(datalab_help)) if epilog: print(epilog) # We have printed all help strings datalab needs. If we don't quit, it will complain about # missing required arguments later. quit() return _CustomAction def add_argument(self, name, **kwargs): # Any argument added here is not required by Datalab, and so is unique # to this trainer. Add each argument to the main parser and the datalab helper string. self.full_parser.add_argument(name, **kwargs) name = name.replace('--', '') # leading spaces are needed for datalab's help formatting. msg = ' ' + name + ': ' if 'help' in kwargs: msg += kwargs['help'] + ' ' if kwargs.get('required', False): msg += 'Required. ' else: msg += 'Optional. ' if 'choices' in kwargs: msg += 'One of ' + str(kwargs['choices']) + '. ' if 'default' in kwargs: msg += 'default: ' + str(kwargs['default']) + '.' self.datalab_help.append(msg) def parse_known_args(self, args=None): return self.full_parser.parse_known_args(args=args) def parse_arguments(argv): """Parse the command line arguments.""" parser = DatalabParser( epilog=('Note that if using a DNN model, --hidden-layer-size1=NUM, ' '--hidden-layer-size2=NUM, ..., is also required. '), datalab_epilog=(""" Note that if using a DNN model, hidden-layer-size1: NUM hidden-layer-size2: NUM ... is also required.
""")) # HP parameters parser.add_argument( '--epsilon', type=float, default=0.0005, metavar='R', help='tf.train.AdamOptimizer epsilon. Only used in dnn models.') parser.add_argument( '--l1-regularization', type=float, default=0.0, metavar='R', help='L1 term for linear models.') parser.add_argument( '--l2-regularization', type=float, default=0.0, metavar='R', help='L2 term for linear models.') # Model parameters parser.add_argument( '--model', required=True, choices=['linear_classification', 'linear_regression', 'dnn_classification', 'dnn_regression']) parser.add_argument( '--top-n', type=int, default=0, metavar='N', help=('For classification problems, the output graph will contain the ' 'labels and scores for the top n classes, and results will be in the form of ' '"predicted, predicted_2, ..., probability, probability_2, ...". ' 'If --top-n=0, then all labels and scores are returned in the form of ' '"predicted, class_name1, class_name2,...".')) # HP parameters parser.add_argument( '--learning-rate', type=float, default=0.01, metavar='R', help='optimizer learning rate.') # Training input parameters parser.add_argument( '--max-steps', type=int, metavar='N', help='Maximum number of training steps to perform. If unspecified, will ' 'honor "max-epochs".') parser.add_argument( '--max-epochs', type=int, default=1000, metavar='N', help='Maximum number of training data epochs on which to train. If ' 'both "max-steps" and "max-epochs" are specified, the training ' 'job will run for "max-steps" or "num-epochs", whichever occurs ' 'first. If early stopping is enabled, training may also stop ' 'earlier.') parser.add_argument( '--train-batch-size', type=int, default=64, metavar='N', help='How many training examples are used per step. If num-epochs is ' 'used, the last batch may not be full.') parser.add_argument( '--eval-batch-size', type=int, default=64, metavar='N', help='Batch size during evaluation. Larger values increase performance ' 'but also increase peak memory usage on the master node. One pass ' 'over the full eval set is performed per evaluation run.') parser.add_argument( '--min-eval-frequency', type=int, default=1000, metavar='N', help='Minimum number of training steps between evaluations. Evaluation ' 'does not occur if no new checkpoint is available, hence, this is ' 'the minimum. If 0, the evaluation will only happen after training. ') parser.add_argument( '--early-stopping-num_evals', type=int, default=3, help='Automatically stop training if the results of the specified number of ' 'evals in a row show that model performance does not improve. Set to 0 to ' 'disable early stopping.') parser.add_argument( '--logging-level', choices=['error', 'warning', 'info'], help='The TF logging level. If absent, use info for cloud training ' 'and warning for local training.') args, remaining_args = parser.parse_known_args(args=argv[1:]) # All HP parameters must be unique, so we need to support an unknown number # of --hidden-layer-size1=10 --hidden-layer-size2=10 ... # Look at remaining_args for hidden-layer-size\d+ to get the layer info. # Get number of layers pattern = re.compile('hidden-layer-size(\d+)') num_layers = 0 for other_arg in remaining_args: match = re.search(pattern, other_arg) if match: if int(match.group(1)) <= 0: raise ValueError('layer size must be a positive integer. Was given %s' % other_arg) num_layers = max(num_layers, int(match.group(1))) # Build a new parser so we catch unknown args and missing layer_sizes.
parser = argparse.ArgumentParser() for i in range(num_layers): parser.add_argument('--hidden-layer-size%s' % str(i + 1), type=int, required=True) layer_args = vars(parser.parse_args(args=remaining_args)) hidden_layer_sizes = [] for i in range(num_layers): key = 'hidden_layer_size%s' % str(i + 1) hidden_layer_sizes.append(layer_args[key]) assert len(hidden_layer_sizes) == num_layers args.hidden_layer_sizes = hidden_layer_sizes return args def is_linear_model(model_type): return model_type.startswith('linear_') def is_dnn_model(model_type): return model_type.startswith('dnn_') def is_regression_model(model_type): return model_type.endswith('_regression') def is_classification_model(model_type): return model_type.endswith('_classification') def build_feature_columns(features, stats, model_type): feature_columns = [] is_dnn = is_dnn_model(model_type) # Supported transforms: # for DNN # numerical number # one hot: sparse int column -> one_hot_column # ebmedding: sparse int column -> embedding_column # text: sparse int weighted column -> embedding_column # for linear # numerical number # one hot: sparse int column # ebmedding: sparse int column -> hash int # text: sparse int weighted column # It is unfortunate that tf.layers has different feature transforms if the # model is linear or DNN. This pacakge should not expose to the user that # we are using tf.layers. for name, transform in six.iteritems(features): transform_name = transform['transform'] source_column = transform['source_column'] if transform_name in feature_transforms.NUMERIC_TRANSFORMS: new_feature = tf.contrib.layers.real_valued_column(name, dimension=1) elif (transform_name == feature_transforms.ONE_HOT_TRANSFORM or transform_name == feature_transforms.MULTI_HOT_TRANSFORM): sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=stats['column_stats'][source_column]['vocab_size']) if is_dnn: new_feature = tf.contrib.layers.one_hot_column(sparse) else: new_feature = sparse elif transform_name == feature_transforms.EMBEDDING_TRANSFROM: if is_dnn: sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=stats['column_stats'][source_column]['vocab_size']) new_feature = tf.contrib.layers.embedding_column( sparse, dimension=transform['embedding_dim']) else: new_feature = tf.contrib.layers.sparse_column_with_hash_bucket( name, hash_bucket_size=transform['embedding_dim'], dtype=dtypes.int64) elif transform_name in feature_transforms.TEXT_TRANSFORMS: sparse_ids = tf.contrib.layers.sparse_column_with_integerized_feature( name + '_ids', bucket_size=stats['column_stats'][source_column]['vocab_size'], combiner='sum') sparse_weights = tf.contrib.layers.weighted_sparse_column( sparse_id_column=sparse_ids, weight_column_name=name + '_weights', dtype=dtypes.float32) if is_dnn: new_feature = tf.contrib.layers.one_hot_column(sparse_ids) dimension = int(math.log(stats['column_stats'][source_column]['vocab_size'])) + 1 new_feature = tf.contrib.layers.embedding_column( sparse_weights, dimension=dimension, combiner='sqrtn') else: new_feature = sparse_weights elif (transform_name == feature_transforms.TARGET_TRANSFORM or transform_name == feature_transforms.KEY_TRANSFORM): continue elif transform_name == feature_transforms.IMAGE_TRANSFORM: new_feature = tf.contrib.layers.real_valued_column( name, dimension=feature_transforms.IMAGE_HIDDEN_TENSOR_SIZE) else: raise ValueError('Unknown transfrom %s' % transform_name) feature_columns.append(new_feature) return feature_columns def recursive_copy(src_dir, 
dest_dir): """Copy the contents of src_dir into the folder dest_dir. Args: src_dir: gsc or local path. dest_dir: gcs or local path. """ file_io.recursive_create_dir(dest_dir) for file_name in file_io.list_directory(src_dir): old_path = os.path.join(src_dir, file_name) new_path = os.path.join(dest_dir, file_name) if file_io.is_directory(old_path): recursive_copy(old_path, new_path) else: file_io.copy(old_path, new_path, overwrite=True) def make_prediction_output_tensors(args, features, input_ops, model_fn_ops, keep_target): """Makes the final prediction output layer.""" target_name = feature_transforms.get_target_name(features) key_names = get_key_names(features) outputs = {} outputs.update({key_name: tf.squeeze(input_ops.features[key_name]) for key_name in key_names}) if is_classification_model(args.model): # build maps from ints to the origional categorical strings. class_names = read_vocab(args, target_name) table = tf.contrib.lookup.index_to_string_table_from_tensor( mapping=class_names, default_value='UNKNOWN') # Get the label of the input target. if keep_target: input_target_label = table.lookup(input_ops.features[target_name]) outputs[PG_TARGET] = tf.squeeze(input_target_label) # TODO(brandondutra): get the score of the target label too. probabilities = model_fn_ops.predictions['probabilities'] # if top_n == 0, this means use all the classes. We will use class names as # probabilities labels. if args.top_n == 0: predicted_index = tf.argmax(probabilities, axis=1) predicted = table.lookup(predicted_index) outputs.update({PG_CLASSIFICATION_FIRST_LABEL: predicted}) probabilities_list = tf.unstack(probabilities, axis=1) for class_name, p in zip(class_names, probabilities_list): outputs[class_name] = p else: top_n = args.top_n # get top k labels and their scores. (top_k_values, top_k_indices) = tf.nn.top_k(probabilities, k=top_n) top_k_labels = table.lookup(tf.to_int64(top_k_indices)) # Write the top_k values using 2*top_n columns. num_digits = int(math.ceil(math.log(top_n, 10))) if num_digits == 0: num_digits = 1 for i in range(0, top_n): # Pad i based on the size of k. So if k = 100, i = 23 -> i = '023'. This # makes sorting the columns easy. padded_i = str(i + 1).zfill(num_digits) if i == 0: label_alias = PG_CLASSIFICATION_FIRST_LABEL else: label_alias = PG_CLASSIFICATION_LABEL_TEMPLATE % padded_i label_tensor_name = (tf.squeeze( tf.slice(top_k_labels, [0, i], [tf.shape(top_k_labels)[0], 1]))) if i == 0: score_alias = PG_CLASSIFICATION_FIRST_SCORE else: score_alias = PG_CLASSIFICATION_SCORE_TEMPLATE % padded_i score_tensor_name = (tf.squeeze( tf.slice(top_k_values, [0, i], [tf.shape(top_k_values)[0], 1]))) outputs.update({label_alias: label_tensor_name, score_alias: score_tensor_name}) else: if keep_target: outputs[PG_TARGET] = tf.squeeze(input_ops.features[target_name]) scores = model_fn_ops.predictions['scores'] outputs[PG_REGRESSION_PREDICTED_TARGET] = tf.squeeze(scores) return outputs # This function is strongly based on # tensorflow/contrib/learn/python/learn/estimators/estimator.py:export_savedmodel() # The difference is we need to modify estimator's output layer. def make_export_strategy( args, keep_target, assets_extra, features, schema, stats): """Makes prediction graph that takes json input. Args: args: command line args keep_target: If ture, target column is returned in prediction graph. 
Target column must also exist in input data assets_extra: other fiels to copy to the output folder job_dir: root job folder features: features dict schema: schema list stats: stats dict """ target_name = feature_transforms.get_target_name(features) csv_header = [col['name'] for col in schema] if not keep_target: csv_header.remove(target_name) def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None): with ops.Graph().as_default() as g: contrib_variables.create_global_step(g) input_ops = feature_transforms.build_csv_serving_tensors_for_training_step( args.analysis, features, schema, stats, keep_target) model_fn_ops = estimator._call_model_fn(input_ops.features, None, model_fn_lib.ModeKeys.INFER) output_fetch_tensors = make_prediction_output_tensors( args=args, features=features, input_ops=input_ops, model_fn_ops=model_fn_ops, keep_target=keep_target) # Don't use signature_def_utils.predict_signature_def as that renames # tensor names if there is only 1 input/output tensor! signature_inputs = {key: tf.saved_model.utils.build_tensor_info(tensor) for key, tensor in six.iteritems(input_ops.default_inputs)} signature_outputs = {key: tf.saved_model.utils.build_tensor_info(tensor) for key, tensor in six.iteritems(output_fetch_tensors)} signature_def_map = { 'serving_default': signature_def_utils.build_signature_def( signature_inputs, signature_outputs, tf.saved_model.signature_constants.PREDICT_METHOD_NAME)} if not checkpoint_path: # Locate the latest checkpoint checkpoint_path = saver.latest_checkpoint(estimator._model_dir) if not checkpoint_path: raise ValueError("Couldn't find trained model at %s." % estimator._model_dir) export_dir = saved_model_export_utils.get_timestamped_export_dir( export_dir_base) if (model_fn_ops.scaffold is not None and model_fn_ops.scaffold.saver is not None): saver_for_restore = model_fn_ops.scaffold.saver else: saver_for_restore = saver.Saver(sharded=True) with tf_session.Session('') as session: saver_for_restore.restore(session, checkpoint_path) init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), tf.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) builder.add_meta_graph_and_variables( session, [tag_constants.SERVING], signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), legacy_init_op=init_op) builder.save(False) # Add the extra assets if assets_extra: assets_extra_path = os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets.extra')) for dest_relative, source in assets_extra.items(): dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) dest_path = os.path.dirname(dest_absolute) file_io.recursive_create_dir(dest_path) file_io.copy(source, dest_absolute) # only keep the last 3 models saved_model_export_utils.garbage_collect_exports( export_dir_base, exports_to_keep=3) # save the last model to the model folder. 
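# For example (paths are illustrative): with --job-dir=gs://bucket/run1 the
# timestamped export produced above is also copied to a stable location,
# gs://bucket/run1/evaluation_model when keep_target is True or
# gs://bucket/run1/model otherwise, so consumers never need to know the
# export timestamp.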
# export_dir_base = A/B/intermediate_models/ if keep_target: final_dir = os.path.join(args.job_dir, 'evaluation_model') else: final_dir = os.path.join(args.job_dir, 'model') if file_io.is_directory(final_dir): file_io.delete_recursively(final_dir) file_io.recursive_create_dir(final_dir) recursive_copy(export_dir, final_dir) return export_dir if keep_target: intermediate_dir = 'intermediate_evaluation_models' else: intermediate_dir = 'intermediate_prediction_models' return export_strategy.ExportStrategy(intermediate_dir, export_fn) def get_estimator(args, output_dir, features, stats, target_vocab_size): # Check layers used for dnn models. if is_dnn_model(args.model) and not args.hidden_layer_sizes: raise ValueError('--hidden-layer-size* must be used with DNN models') if is_linear_model(args.model) and args.hidden_layer_sizes: raise ValueError('--hidden-layer-size* cannot be used with linear models') # Build tf.learn features feature_columns = build_feature_columns(features, stats, args.model) # Set how often to run checkpointing in terms of steps. config = tf.contrib.learn.RunConfig( save_checkpoints_steps=args.min_eval_frequency) train_dir = os.path.join(output_dir, 'train') if args.model == 'dnn_regression': estimator = tf.contrib.learn.DNNRegressor( feature_columns=feature_columns, hidden_units=args.hidden_layer_sizes, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model == 'linear_regression': estimator = tf.contrib.learn.LinearRegressor( feature_columns=feature_columns, config=config, model_dir=train_dir, optimizer=tf.train.FtrlOptimizer( args.learning_rate, l1_regularization_strength=args.l1_regularization, l2_regularization_strength=args.l2_regularization)) elif args.model == 'dnn_classification': estimator = tf.contrib.learn.DNNClassifier( feature_columns=feature_columns, hidden_units=args.hidden_layer_sizes, n_classes=target_vocab_size, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model == 'linear_classification': estimator = tf.contrib.learn.LinearClassifier( feature_columns=feature_columns, n_classes=target_vocab_size, config=config, model_dir=train_dir, optimizer=tf.train.FtrlOptimizer( args.learning_rate, l1_regularization_strength=args.l1_regularization, l2_regularization_strength=args.l2_regularization)) else: raise ValueError('bad --model-type value') return estimator def read_vocab(args, column_name): """Reads a vocab file if it exists. Args: args: command line flags column_name: name of column to that has a vocab file. Returns: List of vocab words or [] if the vocab file is not found. """ vocab_path = os.path.join(args.analysis, feature_transforms.VOCAB_ANALYSIS_FILE % column_name) if not file_io.file_exists(vocab_path): return [] vocab, _ = feature_transforms.read_vocab_file(vocab_path) return vocab def get_key_names(features): names = [] for name, transform in six.iteritems(features): if transform['transform'] == feature_transforms.KEY_TRANSFORM: names.append(name) return names def read_json_file(file_path): if not file_io.file_exists(file_path): raise ValueError('File not found: %s' % file_path) return json.loads(file_io.read_file_to_string(file_path).decode()) def get_experiment_fn(args): """Builds the experiment function for learn_runner.run. Args: args: the command line args Returns: A function that returns a tf.learn experiment object. 
""" def get_experiment(output_dir): # Read schema, input features, and transforms. schema_path_with_target = os.path.join(args.analysis, feature_transforms.SCHEMA_FILE) features_path = os.path.join(args.analysis, feature_transforms.FEATURES_FILE) stats_path = os.path.join(args.analysis, feature_transforms.STATS_FILE) schema = read_json_file(schema_path_with_target) features = read_json_file(features_path) stats = read_json_file(stats_path) target_column_name = feature_transforms.get_target_name(features) if not target_column_name: raise ValueError('target missing from features file.') # Make a copy of the schema file without the target column. schema_without_target = [col for col in schema if col['name'] != target_column_name] schema_path_without_target = os.path.join(args.job_dir, 'schema_without_target.json') file_io.recursive_create_dir(args.job_dir) file_io.write_string_to_file(schema_path_without_target, json.dumps(schema_without_target, indent=2)) # Make list of files to save with the trained model. additional_assets_with_target = { feature_transforms.FEATURES_FILE: features_path, feature_transforms.SCHEMA_FILE: schema_path_with_target} additional_assets_without_target = { feature_transforms.FEATURES_FILE: features_path, feature_transforms.SCHEMA_FILE: schema_path_without_target} # Get the model to train. target_vocab = read_vocab(args, target_column_name) estimator = get_estimator(args, output_dir, features, stats, len(target_vocab)) export_strategy_csv_notarget = make_export_strategy( args=args, keep_target=False, assets_extra=additional_assets_without_target, features=features, schema=schema, stats=stats) export_strategy_csv_target = make_export_strategy( args=args, keep_target=True, assets_extra=additional_assets_with_target, features=features, schema=schema, stats=stats) # Build readers for training. if args.transform: if any(v['transform'] == feature_transforms.IMAGE_TRANSFORM for k, v in six.iteritems(features)): raise ValueError('"image_to_vec" transform requires transformation step. 
' + 'Cannot train from raw data.') input_reader_for_train = feature_transforms.build_csv_transforming_training_input_fn( schema=schema, features=features, stats=stats, analysis_output_dir=args.analysis, raw_data_file_pattern=args.train, training_batch_size=args.train_batch_size, num_epochs=args.max_epochs, randomize_input=True, min_after_dequeue=10, reader_num_threads=multiprocessing.cpu_count()) input_reader_for_eval = feature_transforms.build_csv_transforming_training_input_fn( schema=schema, features=features, stats=stats, analysis_output_dir=args.analysis, raw_data_file_pattern=args.eval, training_batch_size=args.eval_batch_size, num_epochs=1, randomize_input=False, reader_num_threads=multiprocessing.cpu_count()) else: input_reader_for_train = feature_transforms.build_tfexample_transfored_training_input_fn( schema=schema, features=features, analysis_output_dir=args.analysis, raw_data_file_pattern=args.train, training_batch_size=args.train_batch_size, num_epochs=args.max_epochs, randomize_input=True, min_after_dequeue=10, reader_num_threads=multiprocessing.cpu_count()) input_reader_for_eval = feature_transforms.build_tfexample_transfored_training_input_fn( schema=schema, features=features, analysis_output_dir=args.analysis, raw_data_file_pattern=args.eval, training_batch_size=args.eval_batch_size, num_epochs=1, randomize_input=False, reader_num_threads=multiprocessing.cpu_count()) if args.early_stopping_num_evals == 0: train_monitors = None else: if is_classification_model(args.model): early_stop_monitor = tf.contrib.learn.monitors.ValidationMonitor( input_fn=input_reader_for_eval, every_n_steps=args.min_eval_frequency, early_stopping_rounds=(args.early_stopping_num_evals * args.min_eval_frequency), early_stopping_metric='accuracy', early_stopping_metric_minimize=False) else: early_stop_monitor = tf.contrib.learn.monitors.ValidationMonitor( input_fn=input_reader_for_eval, every_n_steps=args.min_eval_frequency, early_stopping_rounds=(args.early_stopping_num_evals * args.min_eval_frequency)) train_monitors = [early_stop_monitor] return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_reader_for_train, eval_input_fn=input_reader_for_eval, train_steps=args.max_steps, train_monitors=train_monitors, export_strategies=[export_strategy_csv_notarget, export_strategy_csv_target], min_eval_frequency=args.min_eval_frequency, eval_steps=None) # Return a function to create an Experiment. return get_experiment def local_analysis(args): if args.analysis: # Already analyzed. 
return if not args.schema or not args.features: raise ValueError('Either --analysis, or both --schema and --features are provided.') tf_config = json.loads(os.environ.get('TF_CONFIG', '{}')) cluster_spec = tf_config.get('cluster', {}) if len(cluster_spec.get('worker', [])) > 0: raise ValueError('If "schema" and "features" are provided, local analysis will run and ' + 'only BASIC scale-tier (no workers node) is supported.') if cluster_spec and not (args.schema.startswith('gs://') and args.features.startswith('gs://')): raise ValueError('Cloud trainer requires GCS paths for --schema and --features.') print('Running analysis.') schema = json.loads(file_io.read_file_to_string(args.schema).decode()) features = json.loads(file_io.read_file_to_string(args.features).decode()) args.analysis = os.path.join(args.job_dir, 'analysis') args.transform = True file_io.recursive_create_dir(args.analysis) feature_analysis.run_local_analysis(args.analysis, args.train, schema, features) print('Analysis done.') def set_logging_level(args): if 'TF_CONFIG' in os.environ: tf.logging.set_verbosity(tf.logging.INFO) else: tf.logging.set_verbosity(tf.logging.ERROR) if args.logging_level == 'error': tf.logging.set_verbosity(tf.logging.ERROR) elif args.logging_level == 'warning': tf.logging.set_verbosity(tf.logging.WARN) elif args.logging_level == 'info': tf.logging.set_verbosity(tf.logging.INFO) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) local_analysis(args) set_logging_level(args) # Supress TensorFlow Debugging info. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' learn_runner.run( experiment_fn=get_experiment_fn(args), output_dir=args.job_dir) if __name__ == '__main__': main() ================================================ FILE: solutionbox/ml_workbench/xgboost/transform.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Flake8 cannot disable a warning for the file. Flake8 does not like beam code # and reports many 'W503 line break before binary operator' errors. So turn off # flake8 for this file. # flake8: noqa from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import datetime import json import logging import os import sys import apache_beam as beam import textwrap def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments including program name. Returns: The parsed arguments as returned by argparse.ArgumentParser. """ parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description=textwrap.dedent("""\ Runs preprocessing on raw data for TensorFlow training. This script applies some transformations to raw data to improve training performance. Some data transformations can be expensive such as the tf-idf text column transformation. During training, the same raw data row might be used multiply times to train a model. 
This means the same transformations are applied to the same data row multiple times. This can be very inefficient, so this script applies partial transformations to the raw data and writes an intermediate preprocessed datasource to disk for training. Running this transformation step is required for two usage paths: 1) If the img_url_to_vec transform is used. This is because preprocessing as image is expensive and TensorFlow cannot easily read raw image files during training. 2) If the raw data is in BigQuery. TensorFlow cannot read from a BigQuery source. Running this transformation step is recommended if a text transform is used (like tf-idf or bag-of-words), and the text value for each row is very long. Running this transformation step may not have an interesting training performance impact if the transforms are all simple like scaling numerical values.""")) source_group = parser.add_mutually_exclusive_group(required=True) source_group.add_argument( '--csv', metavar='FILE', required=False, action='append', help='CSV data to transform.') source_group.add_argument( '--bigquery', metavar='PROJECT_ID.DATASET.TABLE_NAME', type=str, required=False, help=('Must be in the form `project.dataset.table_name`. BigQuery ' 'data to transform')) parser.add_argument( '--analysis', metavar='ANALYSIS_OUTPUT_DIR', required=True, help='The output folder of analyze') parser.add_argument( '--prefix', metavar='OUTPUT_FILENAME_PREFIX', required=True, type=str) parser.add_argument( '--output', metavar='DIR', default=None, required=True, help=('Google Cloud Storage or Local directory in which ' 'to place outputs.')) parser.add_argument( '--shuffle', action='store_true', default=False, help='If used, data source is shuffled. This is recommended for training data.') parser.add_argument( '--batch-size', metavar='N', type=int, default=100, help='Larger values increase performance and peak memory usage.') cloud_group = parser.add_argument_group( title='Cloud Parameters', description='These parameters are only used if --cloud is used.') cloud_group.add_argument( '--cloud', action='store_true', help='Run preprocessing on the cloud.') cloud_group.add_argument( '--job-name', type=str, help='Unique dataflow job name.') cloud_group.add_argument( '--project-id', help='The project to which the job will be submitted.') cloud_group.add_argument( '--num-workers', metavar='N', type=int, default=0, help='Set to 0 to use the default size determined by the Dataflow service.') cloud_group.add_argument( '--worker-machine-type', metavar='NAME', type=str, help='A machine name from https://cloud.google.com/compute/docs/machine-types. ' ' If not given, the service uses the default machine type.') cloud_group.add_argument( '--async', action='store_true', help='If used, this script returns before the dataflow job is completed.') args = parser.parse_args(args=argv[1:]) if args.cloud and not args.project_id: raise ValueError('--project-id is needed for --cloud') if args.async and not args.cloud: raise ValueError('--async should only be used with --cloud') if not args.job_name: args.job_name = ('dataflow-job-{}'.format( datetime.datetime.now().strftime('%Y%m%d%H%M%S'))) return args @beam.ptransform_fn def shuffle(pcoll): # pylint: disable=invalid-name import random return (pcoll | 'PairWithRandom' >> beam.Map(lambda x: (random.random(), x)) | 'GroupByRandom' >> beam.GroupByKey() | 'DropRandom' >> beam.FlatMap(lambda (k, vs): vs)) def image_transform_columns(features): """Returns a list of columns that prepare_image_transforms() should run on. 
Because of beam + pickle, IMAGE_URL_TO_VEC_TRANSFORM cannot be used inside of a beam function, so we extract the columns prepare_image_transforms() should run on outside of beam. """ import six from trainer import feature_transforms img_cols = [] for name, transform in six.iteritems(features): if transform['transform'] == feature_transforms.IMAGE_TRANSFORM: img_cols.append(name) return img_cols def prepare_image_transforms(element, image_columns): """Replace an images url with its jpeg bytes. Args: element: one input row, as a dict image_columns: list of columns that are image paths Return: element, where each image file path has been replaced by a base64 image. """ import base64 import cStringIO from PIL import Image from tensorflow.python.lib.io import file_io as tf_file_io from apache_beam.metrics import Metrics img_error_count = Metrics.counter('main', 'ImgErrorCount') img_missing_count = Metrics.counter('main', 'ImgMissingCount') for name in image_columns: uri = element[name] if not uri: img_missing_count.inc() continue try: with tf_file_io.FileIO(uri, 'r') as f: img = Image.open(f).convert('RGB') # A variety of different calling libraries throw different exceptions here. # They all correspond to an unreadable file so we treat them equivalently. # pylint: disable broad-except except Exception as e: logging.exception('Error processing image %s: %s', uri, str(e)) img_error_count.inc() return # Convert to desired format and output. output = cStringIO.StringIO() img.save(output, 'jpeg') element[name] = base64.urlsafe_b64encode(output.getvalue()) return element class EmitAsBatchDoFn(beam.DoFn): """A DoFn that buffers the records and emits them batch by batch.""" def __init__(self, batch_size): """Constructor of EmitAsBatchDoFn beam.DoFn class. Args: batch_size: the max size we want to buffer the records before emitting. """ self._batch_size = batch_size self._cached = [] def process(self, element): self._cached.append(element) if len(self._cached) >= self._batch_size: emit = self._cached self._cached = [] yield emit def finish_bundle(self, element=None): from apache_beam.transforms import window from apache_beam.utils.windowed_value import WindowedValue if len(self._cached) > 0: # pylint: disable=g-explicit-length-test yield WindowedValue(self._cached, -1, [window.GlobalWindow()]) class TransformFeaturesDoFn(beam.DoFn): """Converts raw data into transformed data.""" def __init__(self, analysis_output_dir, features, schema, stats): self._analysis_output_dir = analysis_output_dir self._features = features self._schema = schema self._stats = stats self._session = None def start_bundle(self, element=None): """Build the transfromation graph once.""" import tensorflow as tf from trainer import feature_transforms g = tf.Graph() session = tf.Session(graph=g) # Build the transformation graph with g.as_default(): transformed_features, _, placeholders = ( feature_transforms.build_csv_serving_tensors_for_transform_step( analysis_path=self._analysis_output_dir, features=self._features, schema=self._schema, stats=self._stats, keep_target=True)) session.run(tf.tables_initializer()) self._session = session self._transformed_features = transformed_features self._input_placeholder_tensor = placeholders['csv_example'] def finish_bundle(self, element=None): self._session.close() def process(self, element): """Run the transformation graph on batched input data Args: element: list of csv strings, representing one batch input to the TF graph. Returns: dict containing the transformed data. Results are un-batched. 
Sparse tensors are converted to lists. """ import apache_beam as beam import six import tensorflow as tf # This function is invoked by a separate sub-process so setting the logging level # does not affect Datalab's kernel process. tf.logging.set_verbosity(tf.logging.ERROR) try: clean_element = [] for line in element: clean_element.append(line.rstrip()) # batch_result is list of numpy arrays with batch_size many rows. batch_result = self._session.run( fetches=self._transformed_features, feed_dict={self._input_placeholder_tensor: clean_element}) # ex batch_result. # Dense tensor: {'col1': array([[batch_1], [batch_2]])} # Sparse tensor: {'col1': tf.SparseTensorValue( # indices=array([[batch_1, 0], [batch_1, 1], ..., # [batch_2, 0], [batch_2, 1], ...]], # values=array[value, value, value, ...])} # Unbatch the results. for i in range(len(clean_element)): transformed_features = {} for name, value in six.iteritems(batch_result): if isinstance(value, tf.SparseTensorValue): batch_i_indices = value.indices[:, 0] == i batch_i_values = value.values[batch_i_indices] transformed_features[name] = batch_i_values.tolist() else: transformed_features[name] = value[i].tolist() yield transformed_features except Exception as e: # pylint: disable=broad-except yield beam.pvalue.TaggedOutput('errors', (str(e), element)) def decode_csv(csv_string, column_names): """Parse a csv line into a dict. Args: csv_string: a csv string. May contain missing values "a,,c" column_names: list of column names Returns: Dict of {column_name, value_from_csv}. If there are missing values, value_from_csv will be ''. """ import csv r = next(csv.reader([csv_string])) if len(r) != len(column_names): raise ValueError('csv line %s does not have %d columns' % (csv_string, len(column_names))) return {k: v for k, v in zip(column_names, r)} def encode_csv(data_dict, column_names): """Builds a csv string. Args: data_dict: dict of {column_name: 1 value} column_names: list of column names Returns: A csv string version of data_dict """ import csv import six values = [str(data_dict[x]) for x in column_names] str_buff = six.StringIO() writer = csv.writer(str_buff, lineterminator='') writer.writerow(values) return str_buff.getvalue() def serialize_example(transformed_json_data, features, feature_indices, target_name): """Makes an instance of data in libsvm format. Args: transformed_json_data: dict of transformed data. features: features config. feature_indices: output of feature_transforms.get_transformed_feature_indices() Returns: The text line representation of an instance in libsvm format. """ import six import tensorflow as tf from trainer import feature_transforms line = str(transformed_json_data[target_name][0]) for name, info in feature_indices: if features[name]['transform'] in [feature_transforms.IDENTITY_TRANSFORM, feature_transforms.SCALE_TRANSFORM]: line += ' %d:%s' % (info['index_start'], str(transformed_json_data[name][0])) elif features[name]['transform'] in [feature_transforms.ONE_HOT_TRANSFORM, feature_transforms.MULTI_HOT_TRANSFORM]: for i in range(info['size']): if i in transformed_json_data[name]: line += ' %d:1' % (info['index_start'] + i) elif features[name]['transform'] in [feature_transforms.IMAGE_TRANSFORM]: for i in range(info['size']): line += ' %d:%s' % (info['index_start'] + i, str(transformed_json_data[name][i])) return line def preprocess(pipeline, args): """Transfrom csv data into transfromed tf.example files. 
Outline: 1) read the input data (as csv or bigquery) into a dict format 2) replace image paths with base64 encoded image files 3) build a csv input string with images paths replaced with base64. This matches the serving csv that a trained model would expect. 4) batch the csv strings 5) run the transformations 6) write the results to tf.example files and save any errors. """ import six from tensorflow.python.lib.io import file_io from trainer import feature_transforms schema = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.SCHEMA_FILE)).decode()) features = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.FEATURES_FILE)).decode()) stats = json.loads(file_io.read_file_to_string( os.path.join(args.analysis, feature_transforms.STATS_FILE)).decode()) column_names = [col['name'] for col in schema] if args.csv: all_files = [] for i, file_pattern in enumerate(args.csv): all_files.append(pipeline | ('ReadCSVFile%d' % i) >> beam.io.ReadFromText(file_pattern)) raw_data = ( all_files | 'MergeCSVFiles' >> beam.Flatten() | 'ParseCSVData' >> beam.Map(decode_csv, column_names)) else: columns = ', '.join(column_names) query = 'SELECT {columns} FROM `{table}`'.format(columns=columns, table=args.bigquery) raw_data = ( pipeline | 'ReadBiqQueryData' >> beam.io.Read(beam.io.BigQuerySource(query=query, use_standard_sql=True))) # Note that prepare_image_transforms does not make embeddings, it justs reads # the image files and converts them to byte stings. TransformFeaturesDoFn() # will make the image embeddings. image_columns = image_transform_columns(features) clean_csv_data = ( raw_data | 'PreprocessTransferredLearningTransformations' >> beam.Map(prepare_image_transforms, image_columns) | 'BuildCSVString' >> beam.Map(encode_csv, column_names)) if args.shuffle: clean_csv_data = clean_csv_data | 'ShuffleData' >> shuffle() transform_dofn = TransformFeaturesDoFn(args.analysis, features, schema, stats) (transformed_data, errors) = ( clean_csv_data | 'Batch Input' >> beam.ParDo(EmitAsBatchDoFn(args.batch_size)) | 'Run TF Graph on Batches' >> beam.ParDo(transform_dofn).with_outputs('errors', main='main')) target_name = next((name for name, transform in six.iteritems(features) if transform['transform'] == feature_transforms.TARGET_TRANSFORM), None) feature_indices = feature_transforms.get_transformed_feature_indices(features, stats) _ = (transformed_data | 'SerializeExamples' >> beam.Map(serialize_example, features, feature_indices, target_name) | 'WriteExamples' >> beam.io.WriteToText( os.path.join(args.output, args.prefix), file_name_suffix='.libsvm', num_shards=1)) feature_map = feature_transforms.create_feature_map(features, feature_indices, args.analysis) # Create the whole file content as one string to avoid dataflow reordering the entries. feature_map_content = ['\n'.join(['%d,%s' % x for x in feature_map])] _ = (pipeline | beam.Create(feature_map_content) | 'WriteFeatureMap' >> beam.io.WriteToText( os.path.join(args.output, 'featuremap'), file_name_suffix='.txt', num_shards=1)) def main(argv=None): """Run Preprocessing as a Dataflow.""" args = parse_arguments(sys.argv if argv is None else argv) temp_dir = os.path.join(args.output, 'tmp') if args.cloud: pipeline_name = 'DataflowRunner' else: pipeline_name = 'DirectRunner' # Suppress TF warnings. 
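# (Level '3' below filters INFO, WARNING and ERROR messages from TensorFlow's
# C++ backend, leaving only FATAL output.)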
os.environ['TF_CPP_MIN_LOG_LEVEL']='3' options = { 'job_name': args.job_name, 'temp_location': temp_dir, 'project': args.project_id, 'setup_file': os.path.abspath(os.path.join( os.path.dirname(__file__), 'setup.py')), } if args.num_workers: options['num_workers'] = args.num_workers if args.worker_machine_type: options['worker_machine_type'] = args.worker_machine_type pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options) p = beam.Pipeline(pipeline_name, options=pipeline_options) preprocess(pipeline=p, args=args) pipeline_result = p.run() if not args.async: pipeline_result.wait_until_finish() if args.async and args.cloud: print('View job at https://console.developers.google.com/dataflow/job/%s?project=%s' % (pipeline_result.job_id(), args.project_id)) if __name__ == '__main__': main() ================================================ FILE: solutionbox/structured_data/build.sh ================================================ #! /bin/bash rm -fr dist cp setup.py mltoolbox/_structured_data/master_setup.py python setup.py sdist ================================================ FILE: solutionbox/structured_data/mltoolbox/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. __import__('pkg_resources').declare_namespace(__name__) ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/__init__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from ._package import analyze, analyze_async, train_async, predict, batch_predict, \ batch_predict_async __all__ = ['analyze', 'analyze_async', 'train_async', 'predict', 'batch_predict', 'batch_predict_async'] ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/__version__.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. 
See the License for the specific language governing permissions and limitations under # the License. # Source of truth for the version of this package. __version__ = '1.0.1' ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/_package.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Provides interface for Datalab. Datalab will look for functions with the below names: local_preprocess local_train local_predict cloud_preprocess cloud_train cloud_predict """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import datetime import io import os import shutil import sys import tempfile import json import glob import six import subprocess import pandas as pd from tensorflow.python.lib.io import file_io import warnings # Note that subpackages of _structured_data are locally imported. # This is because this part of the mltoolbox is packaged up during training # and batch prediction. This package depends on datafow, but the trainer # workers should not need it; therefore at training time, a package cannot # import apache_beam. # Likewise, datalab packages are locally imported. This is because TF and # dataflow workers should not need it. _TF_GS_URL = 'gs://cloud-datalab/deploy/tf/tensorflow-1.2.0-cp27-none-linux_x86_64.whl' _PROTOBUF_GS_URL = 'gs://cloud-datalab/deploy/tf/protobuf-3.1.0-py2.py3-none-any.whl' class FileNotFoundError(IOError): pass def _default_project(): from google.datalab import Context return Context.default().project_id def _assert_gcs_files(files): """Check files starts wtih gs://. Args: files: string to file path, or list of file paths. """ if sys.version_info.major > 2: string_type = (str, bytes) # for python 3 compatibility else: string_type = basestring # noqa if isinstance(files, string_type): files = [files] for f in files: if f is not None and not f.startswith('gs://'): raise ValueError('File %s is not a gcs path' % f) def _package_to_staging(staging_package_url): """Repackage this package from local installed location and copy it to GCS. Args: staging_package_url: GCS path. """ import google.datalab.ml as ml # Find the package root. __file__ is under [package_root]/mltoolbox/_structured_data/this_file package_root = os.path.abspath( os.path.join(os.path.dirname(__file__), '../../')) setup_path = os.path.abspath( os.path.join(os.path.dirname(__file__), 'master_setup.py')) tar_gz_path = os.path.join(staging_package_url, 'staging', 'trainer.tar.gz') print('Building package and uploading to %s' % tar_gz_path) ml.package_and_copy(package_root, setup_path, tar_gz_path) return tar_gz_path def _wait_and_kill(pid_to_wait, pids_to_kill): """ Helper function. Wait for a process to finish if it exists, and then try to kill a list of processes. Used by local_train Args: pid_to_wait: the process to wait for. 
pids_to_kill: a list of processes to kill after the process of pid_to_wait finishes. """ # cloud workers don't have psutil import psutil if psutil.pid_exists(pid_to_wait): psutil.Process(pid=pid_to_wait).wait() for pid_to_kill in pids_to_kill: if psutil.pid_exists(pid_to_kill): p = psutil.Process(pid=pid_to_kill) p.kill() p.wait() # ============================================================================== # Analyze # ============================================================================== def analyze(output_dir, dataset, cloud=False, project_id=None): """Blocking version of analyze_async. See documentation of analyze_async.""" job = analyze_async( output_dir=output_dir, dataset=dataset, cloud=cloud, project_id=project_id) job.wait() print('Analyze: ' + str(job.state)) def analyze_async(output_dir, dataset, cloud=False, project_id=None): """Analyze data locally or in the cloud with BigQuery. Produce analysis used by training. This can take a while, even for small datasets. For small datasets, it may be faster to use local_analysis. Args: output_dir: The output directory to use. dataset: only CsvDataSet is supported currently. cloud: If False, runs analysis locally with Pandas. If Ture, runs analysis in the cloud with BigQuery. project_id: Uses BigQuery with this project id. Default is datalab's default project id. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ import google.datalab.utils as du with warnings.catch_warnings(): warnings.simplefilter("ignore") fn = lambda: _analyze(output_dir, dataset, cloud, project_id) # noqa return du.LambdaJob(fn, job_id=None) def _analyze(output_dir, dataset, cloud=False, project_id=None): import google.datalab.ml as ml from . import preprocess if not isinstance(dataset, ml.CsvDataSet): raise ValueError('Only CsvDataSet is supported') if len(dataset.input_files) != 1: raise ValueError('CsvDataSet should be built with a file pattern, not a ' 'list of files.') if project_id and not cloud: raise ValueError('project_id only needed if cloud is True') if cloud: _assert_gcs_files([output_dir, dataset.input_files[0]]) tmp_dir = tempfile.mkdtemp() try: # write the schema file. _, schema_file_path = tempfile.mkstemp(dir=tmp_dir, suffix='.json', prefix='schema') file_io.write_string_to_file(schema_file_path, json.dumps(dataset.schema)) # TODO(brandondutra) use project_id in the local preprocess function. 
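# The argument list below mirrors the command line of the bundled preprocess
# modules, so a hypothetical local run is equivalent to
# preprocess.local_preprocess.main(['preprocess',
#     '--input-file-pattern=gs://bucket/train-*.csv',
#     '--output-dir=gs://bucket/analysis',
#     '--schema-file=/tmp/schemaXXXX.json'])
# where the schema file is the temporary JSON dump written just above
# (all paths shown are hypothetical).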
args = ['preprocess', '--input-file-pattern=%s' % dataset.input_files[0], '--output-dir=%s' % output_dir, '--schema-file=%s' % schema_file_path] if cloud: if not project_id: project_id = _default_project() print('Track BigQuery status at') print('https://bigquery.cloud.google.com/queries/%s' % project_id) preprocess.cloud_preprocess.main(args) else: preprocess.local_preprocess.main(args) finally: shutil.rmtree(tmp_dir) # ============================================================================== # Train # ============================================================================== def train_async(train_dataset, eval_dataset, analysis_dir, output_dir, features, model_type, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, top_n=None, layer_sizes=None, learning_rate=0.01, epsilon=0.0005, job_name=None, # cloud param job_name_prefix='', # cloud param cloud=None, # cloud param ): # NOTE: if you make a chane go this doc string, you MUST COPY it 4 TIMES in # mltoolbox.{classification|regression}.{dnn|linear}, but you must remove # the model_type parameter, and maybe change the layer_sizes and top_n # parameters! # Datalab does some tricky things and messing with train.__doc__ will # not work! """Train model locally or in the cloud. Local Training: Args: train_dataset: CsvDataSet eval_dataset: CsvDataSet analysis_dir: The output directory from local_analysis output_dir: Output directory of training. features: file path or features object. Example: { "col_A": {"transform": "scale", "default": 0.0}, "col_B": {"transform": "scale","value": 4}, # Note col_C is missing, so default transform used. "col_D": {"transform": "hash_one_hot", "hash_bucket_size": 4}, "col_target": {"transform": "target"}, "col_key": {"transform": "key"} } The keys correspond to the columns in the input files as defined by the schema file during preprocessing. Some notes 1) The "key" and "target" transforms are required. 2) Default values are optional. These are used if the input data has missing values during training and prediction. If not supplied for a column, the default value for a numerical column is that column's mean vlaue, and for a categorical column the empty string is used. 3) For numerical colums, the following transforms are supported: i) {"transform": "identity"}: does nothing to the number. (default) ii) {"transform": "scale"}: scales the colum values to -1, 1. iii) {"transform": "scale", "value": a}: scales the colum values to -a, a. For categorical colums, the following transforms are supported: i) {"transform": "one_hot"}: A one-hot vector using the full vocabulary is used. (default) ii) {"transform": "embedding", "embedding_dim": d}: Each label is embedded into an d-dimensional space. model_type: One of 'linear_classification', 'linear_regression', 'dnn_classification', 'dnn_regression'. max_steps: Int. Number of training steps to perform. num_epochs: Maximum number of training data epochs on which to train. The training job will run for max_steps or num_epochs, whichever occurs first. train_batch_size: number of rows to train on in one step. eval_batch_size: number of rows to eval in one step. One pass of the eval dataset is done. If eval_batch_size does not perfectly divide the numer of eval instances, the last fractional batch is not used. min_eval_frequency: Minimum number of training steps between evaluations. top_n: Int. For classification problems, the output graph will contain the labels and scores for the top n classes with a default of n=1. 
Use None for regression problems. layer_sizes: List. Represents the layers in the connected DNN. If the model type is DNN, this must be set. Example [10, 3, 2], this will create three DNN layers where the first layer will have 10 nodes, the middle layer will have 3 nodes, and the laster layer will have 2 nodes. learning_rate: tf.train.AdamOptimizer's learning rate, epsilon: tf.train.AdamOptimizer's epsilon value. Cloud Training: Args: All local training arguments are valid for cloud training. Cloud training contains two additional args: cloud: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. job_name_prefix: If job_name is None, the job will be named '_'. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ import google.datalab.utils as du if model_type not in ['linear_classification', 'linear_regression', 'dnn_classification', 'dnn_regression']: raise ValueError('Unknown model_type %s' % model_type) with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud: return cloud_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type=model_type, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, layer_sizes=layer_sizes, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, job_name_prefix=job_name_prefix, config=cloud, ) else: def fn(): return local_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type=model_type, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, layer_sizes=layer_sizes, learning_rate=learning_rate, epsilon=epsilon) return du.LambdaJob(fn, job_id=None) def local_train(train_dataset, eval_dataset, analysis_dir, output_dir, features, model_type, max_steps, num_epochs, train_batch_size, eval_batch_size, min_eval_frequency, top_n, layer_sizes, learning_rate, epsilon): if len(train_dataset.input_files) != 1 or len(eval_dataset.input_files) != 1: raise ValueError('CsvDataSets must be built with a file pattern, not list ' 'of files.') if file_io.file_exists(output_dir): raise ValueError('output_dir already exist. Use a new output path.') if eval_dataset.size < eval_batch_size: raise ValueError('Eval batch size must be smaller than the eval data size.') if isinstance(features, dict): # Make a features file. if not file_io.file_exists(output_dir): file_io.recursive_create_dir(output_dir) features_file = os.path.join(output_dir, 'features_file.json') file_io.write_string_to_file( features_file, json.dumps(features)) else: features_file = features def _get_abs_path(input_path): cur_path = os.getcwd() full_path = os.path.abspath(os.path.join(cur_path, input_path)) # put path in quotes as it could contain spaces. 
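# e.g. a relative pattern such as data/train*.csv (hypothetical) becomes
# "'/current/working/dir/data/train*.csv'", so the shell command assembled
# just below survives spaces in the path.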
return "'" + full_path + "'" args = ['cd %s &&' % os.path.abspath(os.path.dirname(__file__)), 'python -m trainer.task', '--train-data-paths=%s' % _get_abs_path(train_dataset.input_files[0]), '--eval-data-paths=%s' % _get_abs_path(eval_dataset.input_files[0]), '--job-dir=%s' % _get_abs_path(output_dir), '--preprocess-output-dir=%s' % _get_abs_path(analysis_dir), '--transforms-file=%s' % _get_abs_path(features_file), '--model-type=%s' % model_type, '--max-steps=%s' % str(max_steps), '--train-batch-size=%s' % str(train_batch_size), '--eval-batch-size=%s' % str(eval_batch_size), '--min-eval-frequency=%s' % str(min_eval_frequency), '--learning-rate=%s' % str(learning_rate), '--epsilon=%s' % str(epsilon)] if num_epochs: args.append('--num-epochs=%s' % str(num_epochs)) if top_n: args.append('--top-n=%s' % str(top_n)) if layer_sizes: for i in range(len(layer_sizes)): args.append('--layer-size%s=%s' % (i + 1, str(layer_sizes[i]))) monitor_process = None try: p = subprocess.Popen(' '.join(args), shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) pids_to_kill = [p.pid] # script -> name = datalab_structured_data._package script = 'import %s; %s._wait_and_kill(%s, %s)' % (__name__, __name__, str(os.getpid()), str(pids_to_kill)) monitor_process = subprocess.Popen(['python', '-c', script]) while p.poll() is None: line = p.stdout.readline() if not six.PY2: line = line.decode() if (line.startswith('INFO:tensorflow:global') or line.startswith('INFO:tensorflow:loss') or line.startswith('INFO:tensorflow:Saving dict')): sys.stdout.write(line) finally: if monitor_process: monitor_process.kill() monitor_process.wait() def cloud_train(train_dataset, eval_dataset, analysis_dir, output_dir, features, model_type, max_steps, num_epochs, train_batch_size, eval_batch_size, min_eval_frequency, top_n, layer_sizes, learning_rate, epsilon, job_name, job_name_prefix, config): """Train model using CloudML. See local_train() for a description of the args. Args: config: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. """ import google.datalab.ml as ml if len(train_dataset.input_files) != 1 or len(eval_dataset.input_files) != 1: raise ValueError('CsvDataSets must be built with a file pattern, not list ' 'of files.') if file_io.file_exists(output_dir): raise ValueError('output_dir already exist. Use a new output path.') if isinstance(features, dict): # Make a features file. 
if not file_io.file_exists(output_dir): file_io.recursive_create_dir(output_dir) features_file = os.path.join(output_dir, 'features_file.json') file_io.write_string_to_file( features_file, json.dumps(features)) else: features_file = features if not isinstance(config, ml.CloudTrainingConfig): raise ValueError('cloud should be an instance of ' 'google.datalab.ml.CloudTrainingConfig for cloud training.') _assert_gcs_files([output_dir, train_dataset.input_files[0], eval_dataset.input_files[0], features_file, analysis_dir]) args = ['--train-data-paths=%s' % train_dataset.input_files[0], '--eval-data-paths=%s' % eval_dataset.input_files[0], '--preprocess-output-dir=%s' % analysis_dir, '--transforms-file=%s' % features_file, '--model-type=%s' % model_type, '--max-steps=%s' % str(max_steps), '--train-batch-size=%s' % str(train_batch_size), '--eval-batch-size=%s' % str(eval_batch_size), '--min-eval-frequency=%s' % str(min_eval_frequency), '--learning-rate=%s' % str(learning_rate), '--epsilon=%s' % str(epsilon)] if num_epochs: args.append('--num-epochs=%s' % str(num_epochs)) if top_n: args.append('--top-n=%s' % str(top_n)) if layer_sizes: for i in range(len(layer_sizes)): args.append('--layer-size%s=%s' % (i + 1, str(layer_sizes[i]))) job_request = { 'package_uris': [_package_to_staging(output_dir), _TF_GS_URL, _PROTOBUF_GS_URL], 'python_module': 'mltoolbox._structured_data.trainer.task', 'job_dir': output_dir, 'args': args } job_request.update(dict(config._asdict())) if not job_name: job_name = job_name_prefix or 'structured_data_train' job_name += '_' + datetime.datetime.now().strftime('%y%m%d_%H%M%S') job = ml.Job.submit_training(job_request, job_name) print('Job request send. View status of job at') print('https://console.developers.google.com/ml/jobs?project=%s' % _default_project()) return job # ============================================================================== # Predict # ============================================================================== def predict(data, training_dir=None, model_name=None, model_version=None, cloud=False): """Runs prediction locally or on the cloud. Args: data: List of csv strings or a Pandas DataFrame that match the model schema. training_dir: local path to the trained output folder. model_name: deployed model name model_version: depoyed model version cloud: bool. If False, does local prediction and data and training_dir must be set. If True, does cloud prediction and data, model_name, and model_version must be set. For cloud prediction, the model must be created. This can be done by running two gcloud commands:: 1) gcloud beta ml models create NAME 2) gcloud beta ml versions create VERSION --model NAME --origin gs://BUCKET/training_dir/model or these datalab commands: 1) import google.datalab as datalab model = datalab.ml.ModelVersions(MODEL_NAME) model.deploy(version_name=VERSION, path='gs://BUCKET/training_dir/model') Note that the model must be on GCS. Returns: Pandas DataFrame. 
""" if cloud: if not model_version or not model_name: raise ValueError('model_version or model_name is not set') if training_dir: raise ValueError('training_dir not needed when cloud is True') with warnings.catch_warnings(): warnings.simplefilter("ignore") return cloud_predict(model_name, model_version, data) else: if not training_dir: raise ValueError('training_dir is not set') if model_version or model_name: raise ValueError('model_name and model_version not needed when cloud is ' 'False.') with warnings.catch_warnings(): warnings.simplefilter("ignore") return local_predict(training_dir, data) def local_predict(training_dir, data): """Runs local prediction on the prediction graph. Runs local prediction and returns the result in a Pandas DataFrame. For running prediction on a large dataset or saving the results, run local_batch_prediction or batch_prediction. Input data should fully match the schema that was used at training, except the target column should not exist. Args: training_dir: local path to the trained output folder. data: List of csv strings or a Pandas DataFrame that match the model schema. Raises: ValueError: if training_dir does not contain the folder 'model'. FileNotFoundError: if the prediction data is not found. """ # from . import predict as predict_module from .prediction import predict as predict_module # Save the instances to a file, call local batch prediction, and return it tmp_dir = tempfile.mkdtemp() _, input_file_path = tempfile.mkstemp(dir=tmp_dir, suffix='.csv', prefix='input') try: if isinstance(data, pd.DataFrame): data.to_csv(input_file_path, header=False, index=False) else: with open(input_file_path, 'w') as f: for line in data: f.write(line + '\n') model_dir = os.path.join(training_dir, 'model') if not file_io.file_exists(model_dir): raise ValueError('training_dir should contain the folder model') cmd = ['predict.py', '--predict-data=%s' % input_file_path, '--trained-model-dir=%s' % model_dir, '--output-dir=%s' % tmp_dir, '--output-format=csv', '--batch-size=16', '--mode=prediction', '--no-shard-files'] # runner_results = predict_module.predict.main(cmd) runner_results = predict_module.main(cmd) runner_results.wait_until_finish() # Read the header file. schema_file = os.path.join(tmp_dir, 'csv_schema.json') with open(schema_file, 'r') as f: schema = json.loads(f.read()) # Print any errors to the screen. errors_file = glob.glob(os.path.join(tmp_dir, 'errors*')) if errors_file and os.path.getsize(errors_file[0]) > 0: print('Warning: there are errors. See below:') with open(errors_file[0], 'r') as f: text = f.read() print(text) # Read the predictions data. prediction_file = glob.glob(os.path.join(tmp_dir, 'predictions*')) if not prediction_file: raise FileNotFoundError('Prediction results not found') predictions = pd.read_csv(prediction_file[0], header=None, names=[col['name'] for col in schema]) return predictions finally: shutil.rmtree(tmp_dir) def cloud_predict(model_name, model_version, data): """Use Online prediction. Runs online prediction in the cloud and prints the results to the screen. For running prediction on a large dataset or saving the results, run local_batch_prediction or batch_prediction. Args: model_name: deployed model name model_version: depoyed model version data: List of csv strings or a Pandas DataFrame that match the model schema. Before using this, the model must be created. 
This can be done by running two gcloud commands: 1) gcloud beta ml models create NAME 2) gcloud beta ml versions create VERSION --model NAME \ --origin gs://BUCKET/training_dir/model or these datalab commands: 1) import google.datalab as datalab model = datalab.ml.ModelVersions(MODEL_NAME) model.deploy(version_name=VERSION, path='gs://BUCKET/training_dir/model') Note that the model must be on GCS. """ import google.datalab.ml as ml if isinstance(data, pd.DataFrame): # write the df to csv. string_buffer = io.StringIO() data.to_csv(string_buffer, header=None, index=False) input_data = string_buffer.getvalue().split('\n') # remove empty strings input_data = [line for line in input_data if line] else: input_data = data predictions = ml.ModelVersions(model_name).predict(model_version, input_data) # Convert predictions into a dataframe df = pd.DataFrame(columns=sorted(predictions[0].keys())) for i in range(len(predictions)): for k, v in six.iteritems(predictions[i]): df.loc[i, k] = v return df # ============================================================================== # Batch predict # ============================================================================== def batch_predict(training_dir, prediction_input_file, output_dir, mode, batch_size=16, shard_files=True, output_format='csv', cloud=False): """Blocking version of batch_predict_async. See the documentation of batch_predict_async. """ job = batch_predict_async( training_dir=training_dir, prediction_input_file=prediction_input_file, output_dir=output_dir, mode=mode, batch_size=batch_size, shard_files=shard_files, output_format=output_format, cloud=cloud) job.wait() print('Batch predict: ' + str(job.state)) def batch_predict_async(training_dir, prediction_input_file, output_dir, mode, batch_size=16, shard_files=True, output_format='csv', cloud=False): """Local and cloud batch prediction. Args: training_dir: The output folder of training. prediction_input_file: csv file or file pattern. Must be on GCS if running cloud prediction. output_dir: output location to save the results. Must be a GCS path if running cloud prediction. mode: 'evaluation' or 'prediction'. If 'evaluation', the input data must contain a target column. If 'prediction', the input data must not contain a target column. batch_size: Int. How many instances to run in memory at once. Larger values mean better performance but more memory consumed. shard_files: If False, the output files are not sharded. output_format: csv or json. JSON files are newline-delimited JSON. cloud: If True, does cloud batch prediction. If False, runs batch prediction locally. Returns: A google.datalab.utils.Job object that can be used to query state from or wait. """ import google.datalab.utils as du with warnings.catch_warnings(): warnings.simplefilter("ignore") if cloud: runner_results = cloud_batch_predict(training_dir, prediction_input_file, output_dir, mode, batch_size, shard_files, output_format) job = du.DataflowJob(runner_results) else: runner_results = local_batch_predict(training_dir, prediction_input_file, output_dir, mode, batch_size, shard_files, output_format) job = du.LambdaJob(lambda: runner_results.wait_until_finish(), job_id=None) return job def local_batch_predict(training_dir, prediction_input_file, output_dir, mode, batch_size, shard_files, output_format): """See batch_predict""" # from .
import predict as predict_module from .prediction import predict as predict_module if mode == 'evaluation': model_dir = os.path.join(training_dir, 'evaluation_model') elif mode == 'prediction': model_dir = os.path.join(training_dir, 'model') else: raise ValueError('mode must be evaluation or prediction') if not file_io.file_exists(model_dir): raise ValueError('Model folder %s does not exist' % model_dir) cmd = ['predict.py', '--predict-data=%s' % prediction_input_file, '--trained-model-dir=%s' % model_dir, '--output-dir=%s' % output_dir, '--output-format=%s' % output_format, '--batch-size=%s' % str(batch_size), '--shard-files' if shard_files else '--no-shard-files', '--has-target' if mode == 'evaluation' else '--no-has-target' ] # return predict_module.predict.main(cmd) return predict_module.main(cmd) def cloud_batch_predict(training_dir, prediction_input_file, output_dir, mode, batch_size, shard_files, output_format): """See batch_predict""" # from . import predict as predict_module from .prediction import predict as predict_module if mode == 'evaluation': model_dir = os.path.join(training_dir, 'evaluation_model') elif mode == 'prediction': model_dir = os.path.join(training_dir, 'model') else: raise ValueError('mode must be evaluation or prediction') if not file_io.file_exists(model_dir): raise ValueError('Model folder %s does not exist' % model_dir) _assert_gcs_files([training_dir, prediction_input_file, output_dir]) cmd = ['predict.py', '--cloud', '--project-id=%s' % _default_project(), '--predict-data=%s' % prediction_input_file, '--trained-model-dir=%s' % model_dir, '--output-dir=%s' % output_dir, '--output-format=%s' % output_format, '--batch-size=%s' % str(batch_size), '--shard-files' if shard_files else '--no-shard-files', '--extra-package=%s' % _TF_GS_URL, '--extra-package=%s' % _PROTOBUF_GS_URL, '--extra-package=%s' % _package_to_staging(output_dir) ] return predict_module.main(cmd) ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/master_setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # A copy of this file must be made in datalab_structured_data/setup.py import os import re from setuptools import setup # The version is saved in an __init__ file. def get_version(): VERSIONFILE = 'mltoolbox/_structured_data/__version__.py' if not os.path.isfile(VERSIONFILE): raise ValueError('setup.py: File not found %s' % VERSIONFILE) initfile_lines = open(VERSIONFILE, 'rt').readlines() VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" for line in initfile_lines: mo = re.search(VSRE, line, re.M) if mo: return mo.group(1) raise RuntimeError('Unable to find version string in %s.' % (VERSIONFILE,)) # Calling setuptools.find_packages does not work with cloud training repackaging # because this script is not ran from this folder. 
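# ==============================================================================
# Illustrative sketch (not part of the source): what the VSRE regular
# expression in get_version() above matches. The sample version string is
# made up.
# ==============================================================================
import re

_sample_line = "__version__ = '0.1.2'\n"
_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", _sample_line, re.M)
assert _match and _match.group(1) == '0.1.2'
# ==============================================================================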
setup( name='mltoolbox_datalab_classification_and_regression', namespace_packages=['mltoolbox'], version=get_version(), packages=[ 'mltoolbox', 'mltoolbox.classification', 'mltoolbox.classification.linear', 'mltoolbox.classification.dnn', 'mltoolbox.regression', 'mltoolbox.regression.linear', 'mltoolbox.regression.dnn', 'mltoolbox._structured_data', 'mltoolbox._structured_data.preprocess', 'mltoolbox._structured_data.prediction', # 'mltoolbox._structured_data.test', 'mltoolbox._structured_data.trainer', ], description='Google Cloud Datalab Structured Data Package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Development Status :: 4 - Beta", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description=""" """, install_requires=[ ], package_data={ }, data_files=[], ) ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/prediction/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from . import predict __all__ = ['predict'] ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/prediction/predict.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Runs prediction on a trained model.""" import argparse import datetime import os import shutil import sys import tempfile from tensorflow.python.lib.io import file_io import apache_beam as beam from apache_beam.transforms import window from apache_beam.utils.windowed_value import WindowedValue def parse_arguments(argv): """Parse command line arguments. Args: argv: includes the script's name. 
Returns: argparse object """ parser = argparse.ArgumentParser( description='Runs Prediction inside a beam or Dataflow job.') # cloud options parser.add_argument('--project-id', help='The project to which the job will be submitted.') parser.add_argument('--cloud', action='store_true', help='Run preprocessing on the cloud.') parser.add_argument('--job-name', default=('mltoolbox-batch-prediction-' + datetime.datetime.now().strftime('%Y%m%d%H%M%S')), help='Dataflow job name. Must be unique over all jobs.') parser.add_argument('--extra-package', default=[], action='append', help=('If using --cloud, also installs these packages on ' 'each dataflow worker')) # I/O args parser.add_argument('--predict-data', required=True, help='Data to run prediction on') parser.add_argument('--trained-model-dir', required=True, help='Usually train_output_path/model.') parser.add_argument('--output-dir', required=True, help=('Location to save output.')) # Other args parser.add_argument('--batch-size', required=False, default=1000, type=int, help=('Batch size. Larger values consumes more memrory ' 'but takes less time to finish.')) parser.add_argument('--shard-files', dest='shard_files', action='store_true', help='Shard files') parser.add_argument('--no-shard-files', dest='shard_files', action='store_false', help='Don\'t shard files') parser.set_defaults(shard_files=True) parser.add_argument('--output-format', choices=['csv', 'json'], default='csv', help=""" The output results. raw_json: produces a newline file where each line is json. No post processing is performed and the output matches what the trained model produces. csv: produces a csv file without a header row and a header csv file. For classification problems, the vector of probabalities for each target class is split into individual csv columns.""") args, _ = parser.parse_known_args(args=argv[1:]) if args.cloud: if not args.project_id: raise ValueError('--project-id needed with --cloud') if not args.trained_model_dir.startswith('gs://'): raise ValueError('--trained-model-dir needs to be a GCS path,') if not args.output_dir.startswith('gs://'): raise ValueError('--output-dir needs to be a GCS path.') if not args.predict_data.startswith('gs://'): raise ValueError('--predict-data needs to be a GCS path.') return args class EmitAsBatchDoFn(beam.DoFn): """A DoFn that buffers the records and emits them batch by batch.""" def __init__(self, batch_size): """Constructor of EmitAsBatchDoFn beam.DoFn class. Args: batch_size: the max size we want to buffer the records before emitting. 
""" self._batch_size = batch_size self._cached = [] def process(self, element): self._cached.append(element) if len(self._cached) >= self._batch_size: emit = self._cached self._cached = [] yield emit def finish_bundle(self, element=None): if len(self._cached) > 0: # pylint: disable=g-explicit-length-test yield WindowedValue(self._cached, -1, [window.GlobalWindow()]) class RunGraphDoFn(beam.DoFn): """A DoFn for running the TF graph.""" def __init__(self, trained_model_dir): self._trained_model_dir = trained_model_dir self._session = None def start_bundle(self, element=None): from tensorflow.python.saved_model import tag_constants from tensorflow.contrib.session_bundle import bundle_shim self._session, meta_graph = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( self._trained_model_dir, tags=[tag_constants.SERVING]) signature = meta_graph.signature_def['serving_default'] # get the mappings between aliases and tensor names # for both inputs and outputs self._input_alias_map = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.inputs.items()} self._output_alias_map = {friendly_name: tensor_info_proto.name for (friendly_name, tensor_info_proto) in signature.outputs.items()} self._aliases, self._tensor_names = zip(*self._output_alias_map.items()) def finish_bundle(self, element=None): import tensorflow as tf self._session.close() tf.reset_default_graph() def process(self, element): """Run batch prediciton on a TF graph. Args: element: list of strings, representing one batch input to the TF graph. """ import collections import apache_beam as beam num_in_batch = 0 try: assert self._session is not None feed_dict = collections.defaultdict(list) for line in element: # Remove trailing newline. if line.endswith('\n'): line = line[:-1] feed_dict[self._input_alias_map.values()[0]].append(line) num_in_batch += 1 # batch_result is list of numpy arrays with batch_size many rows. batch_result = self._session.run(fetches=self._tensor_names, feed_dict=feed_dict) # ex batch_result for batch_size > 1: # (array([value1, value2, ..., value_batch_size]), # array([[a1, b1, c1]], ..., [a_batch_size, b_batch_size, c_batch_size]]), # ...) # ex batch_result for batch_size == 1: # (value, # array([a1, b1, c1]), # ...) # Convert the results into a dict and unbatch the results. if num_in_batch > 1: for result in zip(*batch_result): predictions = {} for name, value in zip(self._aliases, result): predictions[name] = (value.tolist() if getattr(value, 'tolist', None) else value) yield predictions else: predictions = {} for i in range(len(self._aliases)): value = batch_result[i] value = (value.tolist() if getattr(value, 'tolist', None) else value) predictions[self._aliases[i]] = value yield predictions except Exception as e: # pylint: disable=broad-except yield beam.pvalue.TaggedOutput('errors', (str(e), element)) class RawJsonCoder(beam.coders.Coder): """Coder for json newline files.""" def encode(self, obj): """Encodes a python object into a JSON string. Args: obj: python object. Returns: JSON string. """ import json return json.dumps(obj, separators=(',', ': ')) class CSVCoder(beam.coders.Coder): """Coder for CSV files containing the output of prediction.""" def __init__(self, header): """Sets the headers in the csv file. Args: header: list of strings that correspond to keys in the predictions dict. """ self._header = header def make_header_string(self): return ','.join(self._header) def encode(self, tf_graph_predictions): """Encodes the graph json prediction into csv. 
Args: tf_graph_predictions: python dict. Returns: csv string. """ row = [] for col in self._header: row.append(str(tf_graph_predictions[col])) return ','.join(row) class FormatAndSave(beam.PTransform): def __init__(self, args): self._shard_name_template = None if args.shard_files else '' self._output_format = args.output_format self._output_dir = args.output_dir # Get the BQ schema if csv. if self._output_format == 'csv': from tensorflow.python.saved_model import tag_constants from tensorflow.contrib.session_bundle import bundle_shim from tensorflow.core.framework import types_pb2 session, meta_graph = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( args.trained_model_dir, tags=[tag_constants.SERVING]) signature = meta_graph.signature_def['serving_default'] self._schema = [] for friendly_name in sorted(signature.outputs): tensor_info_proto = signature.outputs[friendly_name] # TODO(brandondutra): Could dtype be DT_INVALID? # Consider getting the dtype from the graph via # session.graph.get_tensor_by_name(tensor_info_proto.name).dtype) dtype = tensor_info_proto.dtype if dtype == types_pb2.DT_FLOAT or dtype == types_pb2.DT_DOUBLE: bq_type = 'FLOAT' elif dtype == types_pb2.DT_INT32 or dtype == types_pb2.DT_INT64: bq_type = 'INTEGER' else: bq_type = 'STRING' self._schema.append({'mode': 'NULLABLE', 'name': friendly_name, 'type': bq_type}) session.close() def apply(self, datasets): return self.expand(datasets) def expand(self, datasets): import json tf_graph_predictions, errors = datasets if self._output_format == 'json': (tf_graph_predictions | 'Write Raw JSON' >> beam.io.textio.WriteToText(os.path.join(self._output_dir, 'predictions'), file_name_suffix='.json', coder=RawJsonCoder(), shard_name_template=self._shard_name_template)) elif self._output_format == 'csv': # make a csv header file header = [col['name'] for col in self._schema] csv_coder = CSVCoder(header) (tf_graph_predictions.pipeline | 'Make CSV Header' >> beam.Create([json.dumps(self._schema, indent=2)]) | 'Write CSV Schema File' >> beam.io.textio.WriteToText(os.path.join(self._output_dir, 'csv_schema'), file_name_suffix='.json', shard_name_template='')) # Write the csv predictions (tf_graph_predictions | 'Write CSV' >> beam.io.textio.WriteToText(os.path.join(self._output_dir, 'predictions'), file_name_suffix='.csv', coder=csv_coder, shard_name_template=self._shard_name_template)) else: raise ValueError('FormatAndSave: unknown format %s', self._output_format) # Write the errors to a text file. (errors | 'Write Errors' >> beam.io.textio.WriteToText(os.path.join(self._output_dir, 'errors'), file_name_suffix='.txt', shard_name_template=self._shard_name_template)) def make_prediction_pipeline(pipeline, args): """Builds the prediction pipeline. Reads the csv files, prepends a ',' if the target column is missing, run prediction, and then prints the formated results to a file. 
Args: pipeline: the pipeline args: command line args """ # DF bug: DF does not work with unicode strings predicted_values, errors = ( pipeline | 'Read CSV Files' >> beam.io.ReadFromText(str(args.predict_data), strip_trailing_newlines=True) | 'Batch Input' >> beam.ParDo(EmitAsBatchDoFn(args.batch_size)) | 'Run TF Graph on Batches' >> beam.ParDo(RunGraphDoFn(args.trained_model_dir)).with_outputs('errors', main='main')) ((predicted_values, errors) | 'Format and Save' >> FormatAndSave(args)) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) if args.cloud: tmpdir = tempfile.mkdtemp() try: local_packages = [os.path.join(tmpdir, os.path.basename(p)) for p in args.extra_package] for source, dest in zip(args.extra_package, local_packages): file_io.copy(source, dest, overwrite=True) options = { 'staging_location': os.path.join(args.output_dir, 'tmp', 'staging'), 'temp_location': os.path.join(args.output_dir, 'tmp', 'staging'), 'job_name': args.job_name, 'project': args.project_id, 'no_save_main_session': True, 'extra_packages': local_packages, 'teardown_policy': 'TEARDOWN_ALWAYS', } opts = beam.pipeline.PipelineOptions(flags=[], **options) # Or use BlockingDataflowPipelineRunner p = beam.Pipeline('DataflowRunner', options=opts) make_prediction_pipeline(p, args) print(('Dataflow Job submitted, see Job %s at ' 'https://console.developers.google.com/dataflow?project=%s') % (options['job_name'], args.project_id)) sys.stdout.flush() runner_results = p.run() finally: shutil.rmtree(tmpdir) else: p = beam.Pipeline('DirectRunner') make_prediction_pipeline(p, args) runner_results = p.run() return runner_results if __name__ == '__main__': runner_results = main() runner_results.wait_until_finish() ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/preprocess/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from . import cloud_preprocess from . import local_preprocess __all__ = ['cloud_preprocess', 'local_preprocess'] ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/preprocess/cloud_preprocess.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
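# ==============================================================================
# Illustrative sketch (not part of the source): invoking the batch-prediction
# module defined above (prediction/predict.py) directly, mirroring the command
# line that local_batch_predict() assembles. The file paths are placeholders;
# without --cloud the pipeline runs on Beam's DirectRunner.
# ==============================================================================
from mltoolbox._structured_data.prediction import predict as predict_module

runner_results = predict_module.main([
    'predict.py',
    '--predict-data=/tmp/input.csv',
    '--trained-model-dir=/tmp/training/model',
    '--output-dir=/tmp/predictions',
    '--output-format=csv',
    '--batch-size=100',
    '--no-shard-files',
])
runner_results.wait_until_finish()
# ==============================================================================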
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import json import os import six import sys from tensorflow.python.lib.io import file_io SCHEMA_FILE = 'schema.json' NUMERICAL_ANALYSIS_FILE = 'stats.json' CATEGORICAL_ANALYSIS_FILE = 'vocab_%s.csv' def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments, includeing programe name. Returns: An argparse Namespace object. Raises: ValueError: for bad parameters """ parser = argparse.ArgumentParser( description='Runs Preprocessing on structured data.') parser.add_argument('--output-dir', type=str, required=True, help='Google Cloud Storage which to place outputs.') parser.add_argument('--schema-file', type=str, required=False, help=('BigQuery json schema file')) parser.add_argument('--input-file-pattern', type=str, required=False, help='Input CSV file names. May contain a file pattern') # If using bigquery table # TODO(brandondutra): maybe also support an sql input, so the table can be # ad-hoc. parser.add_argument('--bigquery-table', type=str, required=False, help=('project:dataset.table_name')) args = parser.parse_args(args=argv[1:]) if not args.output_dir.startswith('gs://'): raise ValueError('--output-dir must point to a location on GCS') if args.bigquery_table: if args.schema_file or args.input_file_pattern: raise ValueError('If using --bigquery-table, then --schema-file and ' '--input-file-pattern, ' 'are not needed.') else: if not args.schema_file or not args.input_file_pattern: raise ValueError('If not using --bigquery-table, then --schema-file and ' '--input-file-pattern ' 'are required.') if not args.input_file_pattern.startswith('gs://'): raise ValueError('--input-file-pattern must point to files on GCS') return args def parse_table_name(bigquery_table): """Giving a string a:b.c, returns b.c. Args: bigquery_table: full table name project_id:dataset:table Returns: dataset:table Raises: ValueError: if a, b, or c contain the character ':'. """ id_name = bigquery_table.split(':') if len(id_name) != 2: raise ValueError('Bigquery table name should be in the form ' 'project_id:dataset.table_name. Got %s' % bigquery_table) return id_name[1] def run_numerical_analysis(table, schema_list, args): """Find min/max values for the numerical columns and writes a json file. Args: table: Reference to FederatedTable (if bigquery_table is false) or a regular Table (otherwise) schema_list: Bigquery schema json object args: the command line args """ import google.datalab.bigquery as bq # Get list of numerical columns. numerical_columns = [] for col_schema in schema_list: col_type = col_schema['type'].lower() if col_type == 'integer' or col_type == 'float': numerical_columns.append(col_schema['name']) # Run the numerical analysis if numerical_columns: sys.stdout.write('Running numerical analysis...') max_min = [ ('max({name}) as max_{name}, ' 'min({name}) as min_{name}, ' 'avg({name}) as avg_{name} ').format(name=name) for name in numerical_columns] if args.bigquery_table: sql = 'SELECT %s from `%s`' % (', '.join(max_min), parse_table_name(args.bigquery_table)) numerical_results = bq.Query(sql).execute().result().to_dataframe() else: sql = 'SELECT %s from csv_table' % ', '.join(max_min) query = bq.Query(sql, data_sources={'csv_table': table}) numerical_results = query.execute().result().to_dataframe() # Convert the numerical results to a json file. 
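# ==============================================================================
# Illustrative sketch (not part of the source): the SQL that the max/min/avg
# template in run_numerical_analysis() above produces, shown here for two
# made-up numeric columns of a made-up table.
# ==============================================================================
numerical_columns = ['age', 'fare']
max_min = [('max({name}) as max_{name}, '
            'min({name}) as min_{name}, '
            'avg({name}) as avg_{name} ').format(name=name)
           for name in numerical_columns]
sql = 'SELECT %s from `%s`' % (', '.join(max_min), 'mydataset.mytable')
print(sql)
# Produces (roughly):
#   SELECT max(age) as max_age, min(age) as min_age, avg(age) as avg_age ,
#          max(fare) as max_fare, min(fare) as min_fare, avg(fare) as avg_fare
#   from `mydataset.mytable`
# ==============================================================================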
results_dict = {} for name in numerical_columns: results_dict[name] = {'max': numerical_results.iloc[0]['max_%s' % name], 'min': numerical_results.iloc[0]['min_%s' % name], 'mean': numerical_results.iloc[0]['avg_%s' % name]} file_io.write_string_to_file( os.path.join(args.output_dir, NUMERICAL_ANALYSIS_FILE), json.dumps(results_dict, indent=2, separators=(',', ': '))) sys.stdout.write('done.\n') def run_categorical_analysis(table, schema_list, args): """Find vocab values for the categorical columns and writes a csv file. The vocab files are in the from label1 label2 label3 ... Args: table: Reference to FederatedTable (if bigquery_table is false) or a regular Table (otherwise) schema_list: Bigquery schema json object args: the command line args """ import google.datalab.bigquery as bq # Get list of categorical columns. categorical_columns = [] for col_schema in schema_list: col_type = col_schema['type'].lower() if col_type == 'string': categorical_columns.append(col_schema['name']) if categorical_columns: sys.stdout.write('Running categorical analysis...') for name in categorical_columns: if args.bigquery_table: table_name = parse_table_name(args.bigquery_table) else: table_name = 'table_name' sql = """ SELECT {name} FROM {table} WHERE {name} IS NOT NULL GROUP BY {name} ORDER BY {name} """.format(name=name, table=table_name) out_file = os.path.join(args.output_dir, CATEGORICAL_ANALYSIS_FILE % name) # extract_async seems to have a bug and sometimes hangs. So get the # results direclty. if args.bigquery_table: df = bq.Query(sql).execute().result().to_dataframe() else: query = bq.Query(sql, data_sources={'table_name': table}) df = query.execute().result().to_dataframe() # Write the results to a file. string_buff = six.StringIO() df.to_csv(string_buff, index=False, header=False) file_io.write_string_to_file(out_file, string_buff.getvalue()) sys.stdout.write('done.\n') def run_analysis(args): """Builds an analysis file for training. Uses BiqQuery tables to do the analysis. Args: args: command line args Raises: ValueError if schema contains unknown types. """ import google.datalab.bigquery as bq if args.bigquery_table: table = bq.Table(args.bigquery_table) schema_list = table.schema._bq_schema else: schema_list = json.loads( file_io.read_file_to_string(args.schema_file).decode()) table = bq.ExternalDataSource( source=args.input_file_pattern, schema=bq.Schema(schema_list)) # Check the schema is supported. for col_schema in schema_list: col_type = col_schema['type'].lower() if col_type != 'string' and col_type != 'integer' and col_type != 'float': raise ValueError('Schema contains an unsupported type %s.' % col_type) run_numerical_analysis(table, schema_list, args) run_categorical_analysis(table, schema_list, args) # Save a copy of the schema to the output location. file_io.write_string_to_file( os.path.join(args.output_dir, SCHEMA_FILE), json.dumps(schema_list, indent=2, separators=(',', ': '))) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) run_analysis(args) if __name__ == '__main__': main() ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/preprocess/local_preprocess.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import argparse import collections import json import os import six import sys from tensorflow.python.lib.io import file_io SCHEMA_FILE = 'schema.json' NUMERICAL_ANALYSIS_FILE = 'stats.json' CATEGORICAL_ANALYSIS_FILE = 'vocab_%s.csv' def parse_arguments(argv): """Parse command line arguments. Args: argv: list of command line arguments, includeing programe name. Returns: An argparse Namespace object. """ parser = argparse.ArgumentParser( description='Runs Preprocessing on structured CSV data.') parser.add_argument('--input-file-pattern', type=str, required=True, help='Input CSV file names. May contain a file pattern') parser.add_argument('--output-dir', type=str, required=True, help='Google Cloud Storage which to place outputs.') parser.add_argument('--schema-file', type=str, required=True, help=('BigQuery json schema file')) args = parser.parse_args(args=argv[1:]) # Make sure the output folder exists if local folder. file_io.recursive_create_dir(args.output_dir) return args def run_numerical_categorical_analysis(args, schema_list): """Makes the numerical and categorical analysis files. Args: args: the command line args schema_list: python object of the schema json file. Raises: ValueError: if schema contains unknown column types. """ header = [column['name'] for column in schema_list] input_files = file_io.get_matching_files(args.input_file_pattern) # Check the schema is valid for col_schema in schema_list: col_type = col_schema['type'].lower() if col_type != 'string' and col_type != 'integer' and col_type != 'float': raise ValueError('Schema contains an unsupported type %s.' % col_type) # initialize the results def _init_numerical_results(): return {'min': float('inf'), 'max': float('-inf'), 'count': 0, 'sum': 0.0} numerical_results = collections.defaultdict(_init_numerical_results) categorical_results = collections.defaultdict(set) # for each file, update the numerical stats from that file, and update the set # of unique labels. for input_file in input_files: with file_io.FileIO(input_file, 'r') as f: for line in f: parsed_line = dict(zip(header, line.strip().split(','))) for col_schema in schema_list: col_name = col_schema['name'] col_type = col_schema['type'] if col_type.lower() == 'string': categorical_results[col_name].update([parsed_line[col_name]]) else: # numerical column. 
# if empty, skip if not parsed_line[col_name].strip(): continue numerical_results[col_name]['min'] = ( min(numerical_results[col_name]['min'], float(parsed_line[col_name]))) numerical_results[col_name]['max'] = ( max(numerical_results[col_name]['max'], float(parsed_line[col_name]))) numerical_results[col_name]['count'] += 1 numerical_results[col_name]['sum'] += float(parsed_line[col_name]) # Update numerical_results to just have min/min/mean for col_schema in schema_list: if col_schema['type'].lower() != 'string': col_name = col_schema['name'] mean = numerical_results[col_name]['sum'] / numerical_results[col_name]['count'] del numerical_results[col_name]['sum'] del numerical_results[col_name]['count'] numerical_results[col_name]['mean'] = mean # Write the numerical_results to a json file. file_io.write_string_to_file( os.path.join(args.output_dir, NUMERICAL_ANALYSIS_FILE), json.dumps(numerical_results, indent=2, separators=(',', ': '))) # Write the vocab files. Each label is on its own line. for name, unique_labels in six.iteritems(categorical_results): labels = '\n'.join(list(unique_labels)) file_io.write_string_to_file( os.path.join(args.output_dir, CATEGORICAL_ANALYSIS_FILE % name), labels) def run_analysis(args): """Builds an analysis files for training.""" # Read the schema and input feature types schema_list = json.loads( file_io.read_file_to_string(args.schema_file)) run_numerical_categorical_analysis(args, schema_list) # Also save a copy of the schema in the output folder. file_io.copy(args.schema_file, os.path.join(args.output_dir, SCHEMA_FILE), overwrite=True) def main(argv=None): args = parse_arguments(sys.argv if argv is None else argv) run_analysis(args) if __name__ == '__main__': main() ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/trainer/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import from . import task __all__ = ['task'] ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/trainer/task.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
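# ==============================================================================
# Illustrative sketch (not part of the source): running the local analysis
# step above (preprocess/local_preprocess.py) on a made-up CSV file pattern
# and schema file.
# ==============================================================================
from mltoolbox._structured_data.preprocess import local_preprocess

local_preprocess.main([
    'local_preprocess.py',
    '--input-file-pattern=/tmp/data/train*.csv',
    '--output-dir=/tmp/analysis',
    '--schema-file=/tmp/data/schema.json',
])
# Writes schema.json, stats.json, and one vocab_<column>.csv per string column
# into /tmp/analysis.
# ==============================================================================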
from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import os import re import sys from . import util import tensorflow as tf from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.python.lib.io import file_io def get_reader_input_fn(train_config, preprocess_output_dir, model_type, data_paths, batch_size, shuffle, num_epochs=None): """Builds input layer for training.""" def get_input_features(): """Read the input features from the given data paths.""" _, examples = util.read_examples( input_files=data_paths, batch_size=batch_size, shuffle=shuffle, num_epochs=num_epochs) features = util.parse_example_tensor(examples=examples, train_config=train_config, keep_target=True) target_name = train_config['target_column'] target = features.pop(target_name) features, target = util.preprocess_input( features=features, target=target, train_config=train_config, preprocess_output_dir=preprocess_output_dir, model_type=model_type) return features, target # Return a function to input the feaures into the model from a data path. return get_input_features def get_experiment_fn(args): """Builds the experiment function for learn_runner.run. Args: args: the command line args Returns: A function that returns a tf.learn experiment object. """ def get_experiment(output_dir): # Merge schema, input features, and transforms. train_config = util.merge_metadata(args.preprocess_output_dir, args.transforms_file) # Get the model to train. estimator = util.get_estimator(output_dir, train_config, args) # Save a copy of the scehma and input to the model folder. schema_file = os.path.join(args.preprocess_output_dir, util.SCHEMA_FILE) # Make list of files to save with the trained model. additional_assets = {'features.json': args.transforms_file, util.SCHEMA_FILE: schema_file} if util.is_classification_model(args.model_type): target_name = train_config['target_column'] vocab_file_name = util.CATEGORICAL_ANALYSIS % target_name vocab_file_path = os.path.join( args.preprocess_output_dir, vocab_file_name) assert file_io.file_exists(vocab_file_path) additional_assets[vocab_file_name] = vocab_file_path export_strategy_target = util.make_export_strategy( train_config=train_config, args=args, keep_target=True, assets_extra=additional_assets) export_strategy_notarget = util.make_export_strategy( train_config=train_config, args=args, keep_target=False, assets_extra=additional_assets) input_reader_for_train = get_reader_input_fn( train_config=train_config, preprocess_output_dir=args.preprocess_output_dir, model_type=args.model_type, data_paths=args.train_data_paths, batch_size=args.train_batch_size, shuffle=True, num_epochs=args.num_epochs) input_reader_for_eval = get_reader_input_fn( train_config=train_config, preprocess_output_dir=args.preprocess_output_dir, model_type=args.model_type, data_paths=args.eval_data_paths, batch_size=args.eval_batch_size, shuffle=False, num_epochs=1) return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_reader_for_train, eval_input_fn=input_reader_for_eval, train_steps=args.max_steps, export_strategies=[export_strategy_target, export_strategy_notarget], min_eval_frequency=args.min_eval_frequency, eval_steps=None, ) # Return a function to create an Experiment. return get_experiment def parse_arguments(argv): """Parse the command line arguments.""" parser = argparse.ArgumentParser( description=('Train a regression or classification model. 
Note that if ' 'using a DNN model, --layer-size1=NUM, --layer-size2=NUM, ' 'should be used. ')) # I/O file parameters parser.add_argument('--train-data-paths', type=str, action='append', required=True) parser.add_argument('--eval-data-paths', type=str, action='append', required=True) parser.add_argument('--job-dir', type=str, required=True) parser.add_argument('--preprocess-output-dir', type=str, required=True, help=('Output folder of preprocessing. Should contain the' ' schema file, and numerical stats and vocab files.' ' Path must be on GCS if running' ' cloud training.')) parser.add_argument('--transforms-file', type=str, required=True, help=('File describing the the transforms to apply on ' 'each column')) # HP parameters parser.add_argument('--learning-rate', type=float, default=0.01, help='tf.train.AdamOptimizer learning rate') parser.add_argument('--epsilon', type=float, default=0.0005, help='tf.train.AdamOptimizer epsilon') # --layer_size See below # Model problems parser.add_argument('--model-type', choices=['linear_classification', 'linear_regression', 'dnn_classification', 'dnn_regression'], required=True) parser.add_argument('--top-n', type=int, default=1, help=('For classification problems, the output graph ' 'will contain the labels and scores for the top ' 'n classes.')) # Training input parameters parser.add_argument('--max-steps', type=int, default=5000, help='Maximum number of training steps to perform.') parser.add_argument('--num-epochs', type=int, help=('Maximum number of training data epochs on which ' 'to train. If both --max-steps and --num-epochs ' 'are specified, the training job will run for ' '--max-steps or --num-epochs, whichever occurs ' 'first. If unspecified will run for --max-steps.')) parser.add_argument('--train-batch-size', type=int, default=1000) parser.add_argument('--eval-batch-size', type=int, default=1000) parser.add_argument('--min-eval-frequency', type=int, default=100, help=('Minimum number of training steps between ' 'evaluations')) # other parameters parser.add_argument('--save-checkpoints-secs', type=int, default=600, help=('How often the model should be checkpointed/saved ' 'in seconds')) args, remaining_args = parser.parse_known_args(args=argv[1:]) # All HP parambeters must be unique, so we need to support an unknown number # of --layer_size1=10 --layer_size2=10 ... # Look at remaining_args for layer_size\d+ to get the layer info. # Get number of layers pattern = re.compile('layer-size(\d+)') num_layers = 0 for other_arg in remaining_args: match = re.search(pattern, other_arg) if match: num_layers = max(num_layers, int(match.group(1))) # Build a new parser so we catch unknown args and missing layer_sizes. 
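# ==============================================================================
# Illustrative sketch (not part of the source): how a layer_sizes list is
# turned into the repeated --layer-size<N> flags that parse_arguments() above
# reassembles from remaining_args. The sizes are made up.
# ==============================================================================
layer_sizes = [64, 32, 8]
flags = ['--layer-size%s=%s' % (i + 1, str(size))
         for i, size in enumerate(layer_sizes)]
print(flags)  # ['--layer-size1=64', '--layer-size2=32', '--layer-size3=8']
# ==============================================================================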
parser = argparse.ArgumentParser() for i in range(num_layers): parser.add_argument('--layer-size%s' % str(i + 1), type=int, required=True) layer_args = vars(parser.parse_args(args=remaining_args)) layer_sizes = [] for i in range(num_layers): key = 'layer_size%s' % str(i + 1) layer_sizes.append(layer_args[key]) assert len(layer_sizes) == num_layers args.layer_sizes = layer_sizes return args def main(argv=None): """Run a Tensorflow model on the Iris dataset.""" args = parse_arguments(sys.argv if argv is None else argv) tf.logging.set_verbosity(tf.logging.INFO) learn_runner.run( experiment_fn=get_experiment_fn(args), output_dir=args.job_dir) if __name__ == '__main__': main() ================================================ FILE: solutionbox/structured_data/mltoolbox/_structured_data/trainer/util.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import json import multiprocessing import os import math import six import tensorflow as tf from tensorflow.python.lib.io import file_io from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn.utils import ( saved_model_export_utils) from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.training import saver from tensorflow.python.framework import ops from tensorflow.python.client import session as tf_session from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import tag_constants from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import compat from tensorflow.python.platform import gfile from tensorflow.python.saved_model import signature_def_utils SCHEMA_FILE = 'schema.json' NUMERICAL_ANALYSIS = 'stats.json' CATEGORICAL_ANALYSIS = 'vocab_%s.csv' # Constants for the Prediction Graph fetch tensors. PG_TARGET = 'target' # from input PG_REGRESSION_PREDICTED_TARGET = 'predicted' PG_CLASSIFICATION_FIRST_LABEL = 'predicted' PG_CLASSIFICATION_FIRST_SCORE = 'score' PG_CLASSIFICATION_LABEL_TEMPLATE = 'predicted_%s' PG_CLASSIFICATION_SCORE_TEMPLATE = 'score_%s' class NotFittedError(ValueError): pass # ============================================================================== # Functions for saving the exported graphs. # ============================================================================== def _recursive_copy(src_dir, dest_dir): """Copy the contents of src_dir into the folder dest_dir. Args: src_dir: gsc or local path. dest_dir: gcs or local path. When called, dest_dir should exist. 
""" src_dir = python_portable_string(src_dir) dest_dir = python_portable_string(dest_dir) file_io.recursive_create_dir(dest_dir) for file_name in file_io.list_directory(src_dir): old_path = os.path.join(src_dir, file_name) new_path = os.path.join(dest_dir, file_name) if file_io.is_directory(old_path): _recursive_copy(old_path, new_path) else: file_io.copy(old_path, new_path, overwrite=True) def serving_from_csv_input(train_config, args, keep_target): """Read the input features from a placeholder csv string tensor.""" examples = tf.placeholder( dtype=tf.string, shape=(None,), name='csv_input_string') features = parse_example_tensor(examples=examples, train_config=train_config, keep_target=keep_target) if keep_target: target = features.pop(train_config['target_column']) else: target = None features, target = preprocess_input( features=features, target=target, train_config=train_config, preprocess_output_dir=args.preprocess_output_dir, model_type=args.model_type) return input_fn_utils.InputFnOps(features, target, {'csv_line': examples} ) def make_output_tensors(train_config, args, input_ops, model_fn_ops, keep_target=True): target_name = train_config['target_column'] key_name = train_config['key_column'] outputs = {} outputs[key_name] = tf.squeeze(input_ops.features[key_name]) if is_classification_model(args.model_type): # build maps from ints to the origional categorical strings. string_value = get_vocabulary(args.preprocess_output_dir, target_name) table = tf.contrib.lookup.index_to_string_table_from_tensor( mapping=string_value, default_value='UNKNOWN') # Get the label of the input target. if keep_target: input_target_label = table.lookup(input_ops.labels) outputs[PG_TARGET] = tf.squeeze(input_target_label) # TODO(brandondutra): get the score of the target label too. probabilities = model_fn_ops.predictions['probabilities'] # get top k labels and their scores. (top_k_values, top_k_indices) = tf.nn.top_k(probabilities, k=args.top_n) top_k_labels = table.lookup(tf.to_int64(top_k_indices)) # Write the top_k values using 2*top_k columns. num_digits = int(math.ceil(math.log(args.top_n, 10))) if num_digits == 0: num_digits = 1 for i in range(0, args.top_n): # Pad i based on the size of k. So if k = 100, i = 23 -> i = '023'. This # makes sorting the columns easy. 
padded_i = str(i + 1).zfill(num_digits) if i == 0: label_alias = PG_CLASSIFICATION_FIRST_LABEL else: label_alias = PG_CLASSIFICATION_LABEL_TEMPLATE % padded_i label_tensor_name = (tf.squeeze( tf.slice(top_k_labels, [0, i], [tf.shape(top_k_labels)[0], 1]))) if i == 0: score_alias = PG_CLASSIFICATION_FIRST_SCORE else: score_alias = PG_CLASSIFICATION_SCORE_TEMPLATE % padded_i score_tensor_name = (tf.squeeze( tf.slice(top_k_values, [0, i], [tf.shape(top_k_values)[0], 1]))) outputs.update({label_alias: label_tensor_name, score_alias: score_tensor_name}) else: if keep_target: outputs[PG_TARGET] = tf.squeeze(input_ops.labels) scores = model_fn_ops.predictions['scores'] outputs[PG_REGRESSION_PREDICTED_TARGET] = tf.squeeze(scores) return outputs def make_export_strategy(train_config, args, keep_target, assets_extra=None): def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None): with ops.Graph().as_default() as g: contrib_variables.create_global_step(g) input_ops = serving_from_csv_input(train_config, args, keep_target) model_fn_ops = estimator._call_model_fn(input_ops.features, None, model_fn_lib.ModeKeys.INFER) output_fetch_tensors = make_output_tensors( train_config=train_config, args=args, input_ops=input_ops, model_fn_ops=model_fn_ops, keep_target=keep_target) signature_def_map = { 'serving_default': signature_def_utils.predict_signature_def(input_ops.default_inputs, output_fetch_tensors) } if not checkpoint_path: # Locate the latest checkpoint checkpoint_path = saver.latest_checkpoint(estimator._model_dir) if not checkpoint_path: raise NotFittedError("Couldn't find trained model at %s." % estimator._model_dir) export_dir = saved_model_export_utils.get_timestamped_export_dir( export_dir_base) if (model_fn_ops.scaffold is not None and model_fn_ops.scaffold.saver is not None): saver_for_restore = model_fn_ops.scaffold.saver else: saver_for_restore = saver.Saver(sharded=True) with tf_session.Session('') as session: saver_for_restore.restore(session, checkpoint_path) init_op = control_flow_ops.group( variables.local_variables_initializer(), resources.initialize_resources(resources.shared_resources()), tf.tables_initializer()) # Perform the export builder = saved_model_builder.SavedModelBuilder(export_dir) builder.add_meta_graph_and_variables( session, [tag_constants.SERVING], signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), legacy_init_op=init_op) builder.save(False) # Add the extra assets if assets_extra: assets_extra_path = os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets.extra')) for dest_relative, source in assets_extra.items(): dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) dest_path = os.path.dirname(dest_absolute) gfile.MakeDirs(dest_path) gfile.Copy(source, dest_absolute) # only keep the last 3 models saved_model_export_utils.garbage_collect_exports( python_portable_string(export_dir_base), exports_to_keep=3) # save the last model to the model folder. 
# export_dir_base = A/B/intermediate_models/ if keep_target: final_dir = os.path.join(args.job_dir, 'evaluation_model') else: final_dir = os.path.join(args.job_dir, 'model') if file_io.is_directory(final_dir): file_io.delete_recursively(final_dir) file_io.recursive_create_dir(final_dir) _recursive_copy(export_dir, final_dir) return export_dir if keep_target: intermediate_dir = 'intermediate_evaluation_models' else: intermediate_dir = 'intermediate_prediction_models' return export_strategy.ExportStrategy(intermediate_dir, export_fn) # ============================================================================== # Reading the input csv files and parsing its output into tensors. # ============================================================================== def parse_example_tensor(examples, train_config, keep_target): """Read the csv files. Args: examples: string tensor train_config: training config keep_target: if true, the target column is expected to exist and it is returned in the features dict. Returns: Dict of feature_name to tensor. Target feature is in the dict. """ csv_header = [] if keep_target: csv_header = train_config['csv_header'] else: csv_header = [name for name in train_config['csv_header'] if name != train_config['target_column']] # record_defaults are used by tf.decode_csv to insert defaults, and to infer # the datatype. record_defaults = [[train_config['csv_defaults'][name]] for name in csv_header] tensors = tf.decode_csv(examples, record_defaults, name='csv_to_tensors') # I'm not really sure why expand_dims needs to be called. If using regression # models, it errors without it. tensors = [tf.expand_dims(x, axis=1) for x in tensors] tensor_dict = dict(zip(csv_header, tensors)) return tensor_dict def read_examples(input_files, batch_size, shuffle, num_epochs=None): """Creates readers and queues for reading example protos.""" files = [] for e in input_files: for path in e.split(','): files.extend(file_io.get_matching_files(path)) thread_count = multiprocessing.cpu_count() # The minimum number of instances in a queue from which examples are drawn # randomly. The larger this number, the more randomness at the expense of # higher memory requirements. min_after_dequeue = 1000 # When batching data, the queue's capacity will be larger than the batch_size # by some factor. The recommended formula is (num_threads + a small safety # margin). For now, we use a single thread for reading, so this can be small. queue_size_multiplier = thread_count + 3 # Convert num_epochs == 0 -> num_epochs is None, if necessary num_epochs = num_epochs or None # Build a queue of the filenames to be read. filename_queue = tf.train.string_input_producer(files, num_epochs, shuffle) example_id, encoded_example = tf.TextLineReader().read_up_to( filename_queue, batch_size) if shuffle: capacity = min_after_dequeue + queue_size_multiplier * batch_size return tf.train.shuffle_batch( [example_id, encoded_example], batch_size, capacity, min_after_dequeue, enqueue_many=True, num_threads=thread_count) else: capacity = queue_size_multiplier * batch_size return tf.train.batch( [example_id, encoded_example], batch_size, capacity=capacity, enqueue_many=True, num_threads=thread_count) # ============================================================================== # Building the TF learn estimators # ============================================================================== def get_estimator(output_dir, train_config, args): """Returns a tf learn estimator. 
We only support {DNN, Linear}Regressor and {DNN, Linear}Classifier. This is controlled by the values of model_type in the args. Args: output_dir: Modes are saved into outputdir/train train_config: our training config args: command line parameters Returns: TF lean estimator Raises: ValueError: if config is wrong. """ # Check the requested mode fits the preprocessed data. target_name = train_config['target_column'] if is_classification_model(args.model_type) and target_name not in \ train_config['categorical_columns']: raise ValueError('When using a classification model, the target must be a ' 'categorical variable.') if is_regression_model(args.model_type) and target_name not in \ train_config['numerical_columns']: raise ValueError('When using a regression model, the target must be a ' 'numerical variable.') # Check layers used for dnn models. if is_dnn_model(args.model_type) and not args.layer_sizes: raise ValueError('--layer-size* must be used with DNN models') if is_linear_model(args.model_type) and args.layer_sizes: raise ValueError('--layer-size* cannot be used with linear models') # Build tf.learn features feature_columns = _tflearn_features(train_config, args) # Set how often to run checkpointing in terms of time. config = tf.contrib.learn.RunConfig( save_checkpoints_secs=args.save_checkpoints_secs) train_dir = os.path.join(output_dir, 'train') if args.model_type == 'dnn_regression': estimator = tf.contrib.learn.DNNRegressor( feature_columns=feature_columns, hidden_units=args.layer_sizes, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model_type == 'linear_regression': estimator = tf.contrib.learn.LinearRegressor( feature_columns=feature_columns, config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model_type == 'dnn_classification': estimator = tf.contrib.learn.DNNClassifier( feature_columns=feature_columns, hidden_units=args.layer_sizes, n_classes=train_config['vocab_stats'][target_name]['n_classes'], config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) elif args.model_type == 'linear_classification': estimator = tf.contrib.learn.LinearClassifier( feature_columns=feature_columns, n_classes=train_config['vocab_stats'][target_name]['n_classes'], config=config, model_dir=train_dir, optimizer=tf.train.AdamOptimizer( args.learning_rate, epsilon=args.epsilon)) else: raise ValueError('bad --model-type value') return estimator def preprocess_input(features, target, train_config, preprocess_output_dir, model_type): """Perform some transformations after reading in the input tensors. Args: features: dict of feature_name to tensor target: tensor train_config: our training config object preprocess_output_dir: folder should contain the vocab files. model_type: the tf model type. Raises: ValueError: if wrong transforms are used Returns: New features dict and new target tensor. """ target_name = train_config['target_column'] key_name = train_config['key_column'] # Do the numerical transforms. 
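  # Quick worked example of the 'scale' transform handled below (illustrative
  # numbers only): with a numerical analysis of min=0, max=10 and a transform of
  # {"transform": "scale", "value": 1}, _scale_tensor computes
  #   scaled = (x - 0) * (1 - (-1)) / (10 - 0) + (-1)
  # so x = 0 -> -1.0, x = 5 -> 0.0, and x = 10 -> 1.0.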
# Numerical transforms supported for regression/classification # 1) num -> do nothing (identity, default) # 2) num -> scale to -1, 1 (scale) # 3) num -> scale to -a, a (scale with value parameter) with tf.name_scope('numerical_feature_preprocess'): if train_config['numerical_columns']: numerical_analysis_file = os.path.join(preprocess_output_dir, NUMERICAL_ANALYSIS) if not file_io.file_exists(numerical_analysis_file): raise ValueError('File %s not found in %s' % (NUMERICAL_ANALYSIS, preprocess_output_dir)) numerical_anlysis = json.loads( python_portable_string( file_io.read_file_to_string(numerical_analysis_file))) for name in train_config['numerical_columns']: if name == target_name or name == key_name: continue transform_config = train_config['transforms'].get(name, {}) transform_name = transform_config.get('transform', None) if transform_name == 'scale': value = float(transform_config.get('value', 1.0)) features[name] = _scale_tensor( features[name], range_min=numerical_anlysis[name]['min'], range_max=numerical_anlysis[name]['max'], scale_min=-value, scale_max=value) elif transform_name == 'identity' or transform_name is None: pass else: raise ValueError(('For numerical variables, only scale ' 'and identity are supported: ' 'Error for %s') % name) # Do target transform if it exists. if target is not None: with tf.name_scope('target_feature_preprocess'): if target_name in train_config['categorical_columns']: labels = train_config['vocab_stats'][target_name]['labels'] table = tf.contrib.lookup.string_to_index_table_from_tensor(labels) target = table.lookup(target) # target = tf.contrib.lookup.string_to_index(target, labels) # Do categorical transforms. Only apply vocab mapping. The real # transforms are done with tf learn column features. with tf.name_scope('categorical_feature_preprocess'): for name in train_config['categorical_columns']: if name == key_name or name == target_name: continue transform_config = train_config['transforms'].get(name, {}) transform_name = transform_config.get('transform', None) if is_dnn_model(model_type): if transform_name == 'embedding' or transform_name == 'one_hot' or transform_name is None: map_vocab = True else: raise ValueError('Unknown transform %s' % transform_name) elif is_linear_model(model_type): if (transform_name == 'one_hot' or transform_name is None): map_vocab = True elif transform_name == 'embedding': map_vocab = False else: raise ValueError('Unknown transform %s' % transform_name) if map_vocab: labels = train_config['vocab_stats'][name]['labels'] table = tf.contrib.lookup.string_to_index_table_from_tensor(labels) features[name] = table.lookup(features[name]) return features, target def _scale_tensor(tensor, range_min, range_max, scale_min, scale_max): """Scale a tensor to scale_min to scale_max. Args: tensor: input tensor. Should be a numerical tensor. range_min: min expected value for this feature/tensor. range_max: max expected Value. scale_min: new expected min value. scale_max: new expected max value. Returns: scaled tensor. """ if range_min == range_max: return tensor float_tensor = tf.to_float(tensor) scaled_tensor = tf.divide((tf.subtract(float_tensor, range_min) * tf.constant(float(scale_max - scale_min))), tf.constant(float(range_max - range_min))) shifted_tensor = scaled_tensor + tf.constant(float(scale_min)) return shifted_tensor def _tflearn_features(train_config, args): """Builds the tf.learn feature list. 
All numerical features are just given real_valued_column because all the preprocessing transformations are done in preprocess_input. Categoriacl features are processed here depending if the vocab map (from string to int) was applied in preprocess_input. Args: train_config: our train config object args: command line args. Returns: List of TF lean feature columns. Raises: ValueError: if wrong transforms are used for the model type. """ feature_columns = [] target_name = train_config['target_column'] key_name = train_config['key_column'] for name in train_config['numerical_columns']: if name != target_name and name != key_name: feature_columns.append(tf.contrib.layers.real_valued_column( name, dimension=1)) # Supported transforms: # for DNN # 1) string -> make int -> embedding (embedding) # 2) string -> make int -> one_hot (one_hot, default) # for linear # 1) string -> sparse_column_with_hash_bucket (embedding) # 2) string -> make int -> sparse_column_with_integerized_feature (one_hot, default) # It is unfortunate that tf.layers has different feature transforms if the # model is linear or DNN. This pacakge should not expose to the user that # we are using tf.layers. It is crazy that DNN models support more feature # types (like string -> hash sparse column -> embedding) for name in train_config['categorical_columns']: if name != target_name and name != key_name: transform_config = train_config['transforms'].get(name, {}) transform_name = transform_config.get('transform', None) if is_dnn_model(args.model_type): if transform_name == 'embedding': sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=train_config['vocab_stats'][name]['n_classes']) learn_feature = tf.contrib.layers.embedding_column( sparse, dimension=transform_config['embedding_dim']) elif transform_name == 'one_hot' or transform_name is None: sparse = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=train_config['vocab_stats'][name]['n_classes']) learn_feature = tf.contrib.layers.one_hot_column(sparse) else: raise ValueError(('Unknown transform name. Only \'embedding\' ' 'and \'one_hot\' transforms are supported. Got %s') % transform_name) elif is_linear_model(args.model_type): if transform_name == 'one_hot' or transform_name is None: learn_feature = tf.contrib.layers.sparse_column_with_integerized_feature( name, bucket_size=train_config['vocab_stats'][name]['n_classes']) elif transform_name == 'embedding': learn_feature = tf.contrib.layers.sparse_column_with_hash_bucket( name, hash_bucket_size=transform_config['embedding_dim']) else: raise ValueError(('Unknown transform name. Only \'embedding\' ' 'and \'one_hot\' transforms are supported. Got %s') % transform_name) # Save the feature feature_columns.append(learn_feature) return feature_columns # ============================================================================== # Functions for dealing with the parameter files. # ============================================================================== def get_vocabulary(preprocess_output_dir, name): """Loads the vocabulary file as a list of strings. Args: preprocess_output_dir: Should contain the file CATEGORICAL_ANALYSIS % name. name: name of the csv column. Returns: List of strings. Raises: ValueError: if file is missing. 
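  Example (illustrative): a vocab file containing the lines 'red', 'blue' and a
  trailing blank line is returned as ['red', 'blue']; empty lines are dropped.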
""" vocab_file = os.path.join(preprocess_output_dir, CATEGORICAL_ANALYSIS % name) if not file_io.file_exists(vocab_file): raise ValueError('File %s not found in %s' % (CATEGORICAL_ANALYSIS % name, preprocess_output_dir)) labels = python_portable_string( file_io.read_file_to_string(vocab_file)).split('\n') label_values = [x for x in labels if x] # remove empty lines return label_values def merge_metadata(preprocess_output_dir, transforms_file): """Merge schema, analysis, and transforms files into one python object. Args: preprocess_output_dir: the output folder of preprocessing. Should contain the schema, and the numerical and categorical analysis files. transforms_file: the training transforms file. Returns: A dict in the form { csv_header: [name1, name2, ...], csv_defaults: {name1: value, name2: value}, key_column: name, target_column: name, categorical_columns: [] numerical_columns: [] transforms: { name1: {transform: scale, value: 2}, name2: {transform: embedding, dim: 50}, ... } vocab_stats: { name3: {n_classes: 23, labels: ['1', '2', ..., '23']}, name4: {n_classes: 102, labels: ['red', 'blue', ...]}} } Raises: ValueError: if one of the input metadata files is wrong. """ numerical_anlysis_file = os.path.join(preprocess_output_dir, NUMERICAL_ANALYSIS) schema_file = os.path.join(preprocess_output_dir, SCHEMA_FILE) numerical_anlysis = json.loads( python_portable_string( file_io.read_file_to_string(numerical_anlysis_file))) schema = json.loads( python_portable_string(file_io.read_file_to_string(schema_file))) transforms = json.loads( python_portable_string(file_io.read_file_to_string(transforms_file))) result_dict = {} result_dict['csv_header'] = [col_schema['name'] for col_schema in schema] result_dict['key_column'] = None result_dict['target_column'] = None result_dict['categorical_columns'] = [] result_dict['numerical_columns'] = [] result_dict['transforms'] = {} result_dict['csv_defaults'] = {} result_dict['vocab_stats'] = {} # get key column. for name, trans_config in six.iteritems(transforms): if trans_config.get('transform', None) == 'key': result_dict['key_column'] = name break if result_dict['key_column'] is None: raise ValueError('Key transform missing form transfroms file.') # get target column. result_dict['target_column'] = schema[0]['name'] for name, trans_config in six.iteritems(transforms): if trans_config.get('transform', None) == 'target': result_dict['target_column'] = name break if result_dict['target_column'] is None: raise ValueError('Target transform missing from transforms file.') # Get the numerical/categorical columns. for col_schema in schema: col_name = col_schema['name'] col_type = col_schema['type'].lower() if col_name == result_dict['key_column']: continue if col_type == 'string': result_dict['categorical_columns'].append(col_name) elif col_type == 'integer' or col_type == 'float': result_dict['numerical_columns'].append(col_name) else: raise ValueError('Unsupported schema type %s' % col_type) # Get the transforms. for name, trans_config in six.iteritems(transforms): if name != result_dict['target_column'] and name != result_dict['key_column']: result_dict['transforms'][name] = trans_config # Get the vocab_stats for name in result_dict['categorical_columns']: if name == result_dict['key_column']: continue label_values = get_vocabulary(preprocess_output_dir, name) if name != result_dict['target_column'] and '' not in label_values: label_values.append('') # append a 'missing' label. 
n_classes = len(label_values) result_dict['vocab_stats'][name] = {'n_classes': n_classes, 'labels': label_values} # Get the csv_defaults for col_schema in schema: name = col_schema['name'] col_type = col_schema['type'].lower() default = transforms.get(name, {}).get('default', None) if name == result_dict['target_column']: if name in result_dict['numerical_columns']: default = float(default or 0.0) else: default = default or '' elif name == result_dict['key_column']: if col_type == 'string': default = str(default or '') elif col_type == 'float': default = float(default or 0.0) else: default = int(default or 0) else: if col_type == 'string': default = str(default or '') if default not in result_dict['vocab_stats'][name]['labels']: raise ValueError('Default %s is not in the vocab for %s' % (default, name)) else: default = float(default or numerical_anlysis[name]['mean']) result_dict['csv_defaults'][name] = default validate_metadata(result_dict) return result_dict def validate_metadata(train_config): """Perform some checks that the trainig config is correct. Args: train_config: train config as produced by merge_metadata() Raises: ValueError: if columns look wrong. """ # Make sure we have a default for every column if len(train_config['csv_header']) != len(train_config['csv_defaults']): raise ValueError('Unequal number of columns in input features file and ' 'schema file.') # Check there are no missing columns. sorted_colums has two copies of the # target column because the target column is also listed in # categorical_columns or numerical_columns. sorted_columns = sorted(train_config['csv_header'] + [train_config['target_column']]) sorted_columns2 = sorted(train_config['categorical_columns'] + train_config['numerical_columns'] + [train_config['key_column']] + [train_config['target_column']]) if sorted_columns2 != sorted_columns: raise ValueError('Each csv header must be a numerical/categorical type, a ' ' key, or a target.') def is_linear_model(model_type): return model_type.startswith('linear_') def is_dnn_model(model_type): return model_type.startswith('dnn_') def is_regression_model(model_type): return model_type.endswith('_regression') def is_classification_model(model_type): return model_type.endswith('_classification') # Note that this function exists in google.datalab.utils, but that is not # installed on the training workers. def python_portable_string(string, encoding='utf-8'): """Converts bytes into a string type. Valid string types are retuned without modification. So in Python 2, type str and unicode are not converted. In Python 3, type bytes is converted to type str (unicode) """ if isinstance(string, six.string_types): return string if six.PY3: return string.decode(encoding) raise ValueError('Unsupported type %s' % str(type(string))) ================================================ FILE: solutionbox/structured_data/mltoolbox/classification/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from mltoolbox._structured_data.__version__ import __version__ __all__ = ['__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/classification/dnn/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """This module contains functions for classification problems modeled as a fully connected feedforward deep neural network. Every function can run locally or use Google Cloud Platform. """ from ._classification_dnn import train, train_async from mltoolbox._structured_data import analyze, analyze_async, predict, batch_predict, \ batch_predict_async from mltoolbox._structured_data.__version__ import __version__ __all__ = ['train', 'train_async', 'analyze', 'analyze_async', 'predict', 'batch_predict', 'batch_predict_async', '__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/classification/dnn/_classification_dnn.py ================================================ from mltoolbox._structured_data import train_async as core_train def train(train_dataset, eval_dataset, analysis_dir, output_dir, features, layer_sizes, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, top_n=None, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Blocking version of train_async. See documentation for train_async.""" job = train_async( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, layer_sizes=layer_sizes, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, cloud=cloud, ) job.wait() print('Training: ' + str(job.state)) def train_async(train_dataset, eval_dataset, analysis_dir, output_dir, features, layer_sizes, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, top_n=None, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Train model locally or in the cloud. Local Training: Args: train_dataset: CsvDataSet eval_dataset: CsvDataSet analysis_dir: The output directory from local_analysis output_dir: Output directory of training. features: file path or features object. Example: { "col_A": {"transform": "scale", "default": 0.0}, "col_B": {"transform": "scale","value": 4}, # Note col_C is missing, so default transform used. "col_D": {"transform": "hash_one_hot", "hash_bucket_size": 4}, "col_target": {"transform": "target"}, "col_key": {"transform": "key"} } The keys correspond to the columns in the input files as defined by the schema file during preprocessing. 
Some notes 1) The "key" and "target" transforms are required. 2) Default values are optional. These are used if the input data has missing values during training and prediction. If not supplied for a column, the default value for a numerical column is that column's mean vlaue, and for a categorical column the empty string is used. 3) For numerical colums, the following transforms are supported: i) {"transform": "identity"}: does nothing to the number. (default) ii) {"transform": "scale"}: scales the colum values to -1, 1. iii) {"transform": "scale", "value": a}: scales the colum values to -a, a. For categorical colums, the following transforms are supported: i) {"transform": "one_hot"}: A one-hot vector using the full vocabulary is used. (default) ii) {"transform": "embedding", "embedding_dim": d}: Each label is embedded into an d-dimensional space. max_steps: Int. Number of training steps to perform. num_epochs: Maximum number of training data epochs on which to train. The training job will run for max_steps or num_epochs, whichever occurs first. train_batch_size: number of rows to train on in one step. eval_batch_size: number of rows to eval in one step. One pass of the eval dataset is done. If eval_batch_size does not perfectly divide the numer of eval instances, the last fractional batch is not used. min_eval_frequency: Minimum number of training steps between evaluations. top_n: Int. For classification problems, the output graph will contain the labels and scores for the top n classes with a default of n=1. Use None for regression problems. layer_sizes: List. Represents the layers in the connected DNN. If the model type is DNN, this must be set. Example [10, 3, 2], this will create three DNN layers where the first layer will have 10 nodes, the middle layer will have 3 nodes, and the laster layer will have 2 nodes. learning_rate: tf.train.AdamOptimizer's learning rate, epsilon: tf.train.AdamOptimizer's epsilon value. Cloud Training: All local training arguments are valid for cloud training. Cloud training contains two additional args: Args: cloud: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. Returns: Datalab job """ return core_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type='dnn_classification', max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, layer_sizes=layer_sizes, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, job_name_prefix='mltoolbox_classification_dnn', cloud=cloud, ) ================================================ FILE: solutionbox/structured_data/mltoolbox/classification/linear/__init__.py ================================================ """This module contains functions for multinomial logistic regression problems. Every function can run locally or use Google Cloud Platform. 
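Example (illustrative sketch only; file names and directories are placeholders):

  import google.datalab.ml as dlml
  import mltoolbox.classification.linear as classlin

  classlin.train(
      train_dataset=dlml.CsvDataSet(file_pattern='train.csv', schema_file='schema.json'),
      eval_dataset=dlml.CsvDataSet(file_pattern='eval.csv', schema_file='schema.json'),
      analysis_dir='./analysis_output',   # output of analyze()
      output_dir='./train_output',
      features='features.json',           # transforms file as described in train_async
      max_steps=1000)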
""" from ._classification_linear import train, train_async from mltoolbox._structured_data import analyze, analyze_async, predict, batch_predict, \ batch_predict_async from mltoolbox._structured_data.__version__ import __version__ __all__ = ['train', 'train_async', 'analyze', 'analyze_async', 'predict', 'batch_predict', 'batch_predict_async', '__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/classification/linear/_classification_linear.py ================================================ from mltoolbox._structured_data import train_async as core_train def train(train_dataset, eval_dataset, analysis_dir, output_dir, features, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, top_n=None, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Blocking version of train_async. See documentation for train_async.""" job = train_async( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, cloud=cloud, ) job.wait() print('Training: ' + str(job.state)) def train_async(train_dataset, eval_dataset, analysis_dir, output_dir, features, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, top_n=None, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Train model locally or in the cloud. Local Training: Args: train_dataset: CsvDataSet eval_dataset: CsvDataSet analysis_dir: The output directory from local_analysis output_dir: Output directory of training. features: file path or features object. Example: { "col_A": {"transform": "scale", "default": 0.0}, "col_B": {"transform": "scale","value": 4}, # Note col_C is missing, so default transform used. "col_D": {"transform": "hash_one_hot", "hash_bucket_size": 4}, "col_target": {"transform": "target"}, "col_key": {"transform": "key"} } The keys correspond to the columns in the input files as defined by the schema file during preprocessing. Some notes 1) The "key" and "target" transforms are required. 2) Default values are optional. These are used if the input data has missing values during training and prediction. If not supplied for a column, the default value for a numerical column is that column's mean vlaue, and for a categorical column the empty string is used. 3) For numerical colums, the following transforms are supported: i) {"transform": "identity"}: does nothing to the number. (default) ii) {"transform": "scale"}: scales the colum values to -1, 1. iii) {"transform": "scale", "value": a}: scales the colum values to -a, a. For categorical colums, the following transforms are supported: i) {"transform": "one_hot"}: A one-hot vector using the full vocabulary is used. (default) ii) {"transform": "embedding", "embedding_dim": d}: Each label is embedded into an d-dimensional space. max_steps: Int. Number of training steps to perform. num_epochs: Maximum number of training data epochs on which to train. The training job will run for max_steps or num_epochs, whichever occurs first. train_batch_size: number of rows to train on in one step. eval_batch_size: number of rows to eval in one step. One pass of the eval dataset is done. 
If eval_batch_size does not perfectly divide the numer of eval instances, the last fractional batch is not used. min_eval_frequency: Minimum number of training steps between evaluations. top_n: Int. For classification problems, the output graph will contain the labels and scores for the top n classes with a default of n=1. Use None for regression problems. learning_rate: tf.train.AdamOptimizer's learning rate, epsilon: tf.train.AdamOptimizer's epsilon value. Cloud Training: All local training arguments are valid for cloud training. Cloud training contains two additional args: Args: cloud: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. Returns: Datalab job """ return core_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type='linear_classification', max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=top_n, layer_sizes=None, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, job_name_prefix='mltoolbox_classification_linear', cloud=cloud, ) ================================================ FILE: solutionbox/structured_data/mltoolbox/regression/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from mltoolbox._structured_data.__version__ import __version__ __all__ = ['__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/regression/dnn/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """This module contains functions for regression problems modeled as a fully connected feedforward deep neural network. Every function can run locally or use Google Cloud Platform. 
""" from ._regression_dnn import train, train_async from mltoolbox._structured_data import analyze, analyze_async, predict, batch_predict, \ batch_predict_async from mltoolbox._structured_data.__version__ import __version__ __all__ = ['train', 'train_async', 'analyze', 'analyze_async', 'predict', 'batch_predict', 'batch_predict_async', '__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/regression/dnn/_regression_dnn.py ================================================ from mltoolbox._structured_data import train_async as core_train def train(train_dataset, eval_dataset, analysis_dir, output_dir, features, layer_sizes, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Blocking version of train_async. See documentation for train_async.""" job = train_async( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, layer_sizes=layer_sizes, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, cloud=cloud, ) job.wait() print('Training: ' + str(job.state)) def train_async(train_dataset, eval_dataset, analysis_dir, output_dir, features, layer_sizes, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Train model locally or in the cloud. Local Training: Args: train_dataset: CsvDataSet eval_dataset: CsvDataSet analysis_dir: The output directory from local_analysis output_dir: Output directory of training. features: file path or features object. Example: { "col_A": {"transform": "scale", "default": 0.0}, "col_B": {"transform": "scale","value": 4}, # Note col_C is missing, so default transform used. "col_D": {"transform": "hash_one_hot", "hash_bucket_size": 4}, "col_target": {"transform": "target"}, "col_key": {"transform": "key"} } The keys correspond to the columns in the input files as defined by the schema file during preprocessing. Some notes 1) The "key" and "target" transforms are required. 2) Default values are optional. These are used if the input data has missing values during training and prediction. If not supplied for a column, the default value for a numerical column is that column's mean vlaue, and for a categorical column the empty string is used. 3) For numerical colums, the following transforms are supported: i) {"transform": "identity"}: does nothing to the number. (default) ii) {"transform": "scale"}: scales the colum values to -1, 1. iii) {"transform": "scale", "value": a}: scales the colum values to -a, a. For categorical colums, the following transforms are supported: i) {"transform": "one_hot"}: A one-hot vector using the full vocabulary is used. (default) ii) {"transform": "embedding", "embedding_dim": d}: Each label is embedded into an d-dimensional space. max_steps: Int. Number of training steps to perform. num_epochs: Maximum number of training data epochs on which to train. The training job will run for max_steps or num_epochs, whichever occurs first. train_batch_size: number of rows to train on in one step. eval_batch_size: number of rows to eval in one step. One pass of the eval dataset is done. 
If eval_batch_size does not perfectly divide the numer of eval instances, the last fractional batch is not used. min_eval_frequency: Minimum number of training steps between evaluations. layer_sizes: List. Represents the layers in the connected DNN. If the model type is DNN, this must be set. Example [10, 3, 2], this will create three DNN layers where the first layer will have 10 nodes, the middle layer will have 3 nodes, and the laster layer will have 2 nodes. learning_rate: tf.train.AdamOptimizer's learning rate, epsilon: tf.train.AdamOptimizer's epsilon value. Cloud Training: All local training arguments are valid for cloud training. Cloud training contains two additional args: Args: cloud: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. Returns: Datalab job """ return core_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type='dnn_regression', max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=None, layer_sizes=layer_sizes, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, job_name_prefix='mltoolbox_regression_dnn', cloud=cloud, ) ================================================ FILE: solutionbox/structured_data/mltoolbox/regression/linear/__init__.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """This module contains functions for linear regression problems. Every function can run locally or use Google Cloud Platform. """ from ._regression_linear import train, train_async from mltoolbox._structured_data import analyze, analyze_async, predict, batch_predict, \ batch_predict_async from mltoolbox._structured_data.__version__ import __version__ __all__ = ['train', 'train_async', 'analyze', 'analyze_async', 'predict', 'batch_predict', 'batch_predict_async', '__version__'] ================================================ FILE: solutionbox/structured_data/mltoolbox/regression/linear/_regression_linear.py ================================================ from mltoolbox._structured_data import train_async as core_train def train(train_dataset, eval_dataset, analysis_dir, output_dir, features, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Blocking version of train_async. 
See documentation for train_async.""" job = train_async( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, cloud=cloud, ) job.wait() print('Training: ' + str(job.state)) def train_async(train_dataset, eval_dataset, analysis_dir, output_dir, features, max_steps=5000, num_epochs=None, train_batch_size=100, eval_batch_size=16, min_eval_frequency=100, learning_rate=0.01, epsilon=0.0005, job_name=None, cloud=None, ): """Train model locally or in the cloud. Local Training: Args: train_dataset: CsvDataSet eval_dataset: CsvDataSet analysis_dir: The output directory from local_analysis output_dir: Output directory of training. features: file path or features object. Example: { "col_A": {"transform": "scale", "default": 0.0}, "col_B": {"transform": "scale","value": 4}, # Note col_C is missing, so default transform used. "col_D": {"transform": "hash_one_hot", "hash_bucket_size": 4}, "col_target": {"transform": "target"}, "col_key": {"transform": "key"} } The keys correspond to the columns in the input files as defined by the schema file during preprocessing. Some notes 1) The "key" and "target" transforms are required. 2) Default values are optional. These are used if the input data has missing values during training and prediction. If not supplied for a column, the default value for a numerical column is that column's mean vlaue, and for a categorical column the empty string is used. 3) For numerical colums, the following transforms are supported: i) {"transform": "identity"}: does nothing to the number. (default) ii) {"transform": "scale"}: scales the colum values to -1, 1. iii) {"transform": "scale", "value": a}: scales the colum values to -a, a. For categorical colums, the following transforms are supported: i) {"transform": "one_hot"}: A one-hot vector using the full vocabulary is used. (default) ii) {"transform": "embedding", "embedding_dim": d}: Each label is embedded into an d-dimensional space. max_steps: Int. Number of training steps to perform. num_epochs: Maximum number of training data epochs on which to train. The training job will run for max_steps or num_epochs, whichever occurs first. train_batch_size: number of rows to train on in one step. eval_batch_size: number of rows to eval in one step. One pass of the eval dataset is done. If eval_batch_size does not perfectly divide the numer of eval instances, the last fractional batch is not used. min_eval_frequency: Minimum number of training steps between evaluations. learning_rate: tf.train.AdamOptimizer's learning rate, epsilon: tf.train.AdamOptimizer's epsilon value. Cloud Training: All local training arguments are valid for cloud training. Cloud training contains two additional args: Args: cloud: A CloudTrainingConfig object. job_name: Training job name. A default will be picked if None. 
Returns: Datalab job """ return core_train( train_dataset=train_dataset, eval_dataset=eval_dataset, analysis_dir=analysis_dir, output_dir=output_dir, features=features, model_type='linear_regression', max_steps=max_steps, num_epochs=num_epochs, train_batch_size=train_batch_size, eval_batch_size=eval_batch_size, min_eval_frequency=min_eval_frequency, top_n=None, layer_sizes=None, learning_rate=learning_rate, epsilon=epsilon, job_name=job_name, job_name_prefix='mltoolbox_regression_linear', cloud=cloud, ) ================================================ FILE: solutionbox/structured_data/setup.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. # A copy of this file must be made in datalab_structured_data/setup.py import os import re from setuptools import setup # The version is saved in an __init__ file. def get_version(): VERSIONFILE = 'mltoolbox/_structured_data/__version__.py' if not os.path.isfile(VERSIONFILE): raise ValueError('setup.py: File not found %s' % VERSIONFILE) initfile_lines = open(VERSIONFILE, 'rt').readlines() VSRE = r"^__version__ = ['\"]([^'\"]*)['\"]" for line in initfile_lines: mo = re.search(VSRE, line, re.M) if mo: return mo.group(1) raise RuntimeError('Unable to find version string in %s.' % (VERSIONFILE,)) # Calling setuptools.find_packages does not work with cloud training repackaging # because this script is not ran from this folder. setup( name='mltoolbox_datalab_classification_and_regression', namespace_packages=['mltoolbox'], version=get_version(), packages=[ 'mltoolbox', 'mltoolbox.classification', 'mltoolbox.classification.linear', 'mltoolbox.classification.dnn', 'mltoolbox.regression', 'mltoolbox.regression.linear', 'mltoolbox.regression.dnn', 'mltoolbox._structured_data', 'mltoolbox._structured_data.preprocess', 'mltoolbox._structured_data.prediction', # 'mltoolbox._structured_data.test', 'mltoolbox._structured_data.trainer', ], description='Google Cloud Datalab Structured Data Package', author='Google', author_email='google-cloud-datalab-feedback@googlegroups.com', keywords=[ ], license="Apache Software License", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 2", "Development Status :: 4 - Beta", "Environment :: Other Environment", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Topic :: Software Development :: Libraries :: Python Modules" ], long_description=""" """, install_requires=[ ], package_data={ }, data_files=[], ) ================================================ FILE: solutionbox/structured_data/test_mltoolbox/__init__.py ================================================ ================================================ FILE: solutionbox/structured_data/test_mltoolbox/e2e_functions.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import import os import random import json import six import subprocess def make_csv_data(filename, num_rows, problem_type, keep_target=True): """Writes csv data for preprocessing and training. Args: filename: writes data to local csv file. num_rows: how many rows of data will be generated. problem_type: 'classification' or 'regression'. Changes the target value. keep_target: if false, the csv file will have an empty column ',,' for the target. """ random.seed(12321) with open(filename, 'w') as f1: for i in range(num_rows): num1 = random.uniform(0, 30) num2 = random.randint(0, 20) num3 = random.uniform(0, 10) str1 = random.choice(['red', 'blue', 'green', 'pink', 'yellow', 'brown', 'black']) str2 = random.choice(['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr']) str3 = random.choice(['car', 'truck', 'van', 'bike', 'train', 'drone']) map1 = {'red': 2, 'blue': 6, 'green': 4, 'pink': -5, 'yellow': -6, 'brown': -1, 'black': 7} map2 = {'abc': 10, 'def': 1, 'ghi': 1, 'jkl': 1, 'mno': 1, 'pqr': 1} map3 = {'car': 5, 'truck': 10, 'van': 15, 'bike': 20, 'train': 25, 'drone': 30} # Build some model. t = 0.5 + 0.5 * num1 - 2.5 * num2 + num3 t += map1[str1] + map2[str2] + map3[str3] if problem_type == 'classification': if t < 0: t = 100 elif t < 20: t = 101 else: t = 102 if keep_target: csv_line = "{id},{target},{num1},{num2},{num3},{str1},{str2},{str3}\n".format( id=i, target=t, num1=num1, num2=num2, num3=num3, str1=str1, str2=str2, str3=str3) else: csv_line = "{id},{num1},{num2},{num3},{str1},{str2},{str3}\n".format( id=i, num1=num1, num2=num2, num3=num3, str1=str1, str2=str2, str3=str3) f1.write(csv_line) def make_preprocess_schema(filename, problem_type): """Makes a schema file compatable with the output of make_csv_data. Writes a json file. 
Args: filename: local output file path problem_type: regression or classification """ schema = [ { "mode": "NULLABLE", "name": "key", "type": "STRING" }, { "mode": "REQUIRED", "name": "target", "type": ("STRING" if problem_type == 'classification' else "FLOAT") }, { "mode": "NULLABLE", "name": "num1", "type": "FLOAT" }, { "mode": "NULLABLE", "name": "num2", "type": "INTEGER" }, { "mode": "NULLABLE", "name": "num3", "type": "FLOAT" }, { "mode": "NULLABLE", "name": "str1", "type": "STRING" }, { "mode": "NULLABLE", "name": "str2", "type": "STRING" }, { "mode": "NULLABLE", "name": "str3", "type": "STRING" } ] with open(filename, 'w') as f: f.write(json.dumps(schema)) def run_preprocess(output_dir, csv_filename, schema_filename, logger): """Run preprocess via subprocess call to local_preprocess.py Args: output_dir: local or gcs folder to write output to csv_filename: local or gcs file to do analysis on schema_filename: local or gcs file path to schema file logger: python logging object """ preprocess_script = os.path.abspath( os.path.join(os.path.dirname(__file__), '../mltoolbox/_structured_data/preprocess/local_preprocess.py')) cmd = ['python', preprocess_script, '--output-dir', output_dir, '--input-file-pattern', csv_filename, '--schema-file', schema_filename ] logger.debug('Going to run command: %s' % ' '.join(cmd)) subprocess.check_call(cmd) # , stderr=open(os.devnull, 'wb')) def run_training( train_data_paths, eval_data_paths, output_path, preprocess_output_dir, transforms_file, max_steps, model_type, logger, extra_args=[]): """Runs Training via subprocess call to python -m Args: train_data_paths: local or gcs training csv files eval_data_paths: local or gcs eval csv files output_path: local or gcs folder to write output to preprocess_output_dir: local or gcs output location of preprocessing transforms_file: local or gcs path to transforms file max_steps: max training steps model_type: {dnn,linear}_{regression,classification} logger: python logging object extra_args: array of strings, passed to the trainer. Returns: The stderr of training as one string. TF writes to stderr, so basically, the output of training. """ # Gcloud has the fun bug that you have to be in the parent folder of task.py # when you call it. So cd there first. 
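  # Illustrative example of the command this assembles (paths are placeholders):
  #   cd <...>/mltoolbox/_structured_data && \
  #   python -m trainer.task --train-data-paths=train.csv --eval-data-paths=eval.csv \
  #     --job-dir=train_out --preprocess-output-dir=analysis_out \
  #     --transforms-file=features.json --model-type=linear_regression \
  #     --train-batch-size=100 --eval-batch-size=10 --max-steps=2500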
task_parent_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), '../mltoolbox/_structured_data')) cmd = ['cd %s &&' % task_parent_folder, 'python -m trainer.task', '--train-data-paths=%s' % train_data_paths, '--eval-data-paths=%s' % eval_data_paths, '--job-dir=%s' % output_path, '--preprocess-output-dir=%s' % preprocess_output_dir, '--transforms-file=%s' % transforms_file, '--model-type=%s' % model_type, '--train-batch-size=100', '--eval-batch-size=10', '--max-steps=%s' % max_steps] + extra_args logger.debug('Going to run command: %s' % ' '.join(cmd)) sp = subprocess.Popen(' '.join(cmd), shell=True, stderr=subprocess.PIPE) _, err = sp.communicate() if not six.PY2: err = err.decode() return err if __name__ == '__main__': make_csv_data('raw_train_regression.csv', 5000, 'regression', True) make_csv_data('raw_eval_regression.csv', 1000, 'regression', True) make_csv_data('raw_predict_regression.csv', 100, 'regression', False) make_preprocess_schema('schema_regression.json', 'regression') make_csv_data('raw_train_classification.csv', 5000, 'classification', True) make_csv_data('raw_eval_classification.csv', 1000, 'classification', True) make_csv_data('raw_predict_classification.csv', 100, 'classification', False) make_preprocess_schema('schema_classification.json', 'classification') ================================================ FILE: solutionbox/structured_data/test_mltoolbox/test_datalab_e2e.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. """Test analyze, training, and prediction. """ from __future__ import absolute_import from __future__ import print_function import json import logging import os import pandas as pd import shutil import six import sys import tempfile import unittest from . import e2e_functions from tensorflow.python.lib.io import file_io sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))) import mltoolbox.regression.linear as reglinear # noqa: E402 import google.datalab.ml as dlml # noqa: E402 class TestLinearRegression(unittest.TestCase): """Test linear regression works e2e locally. Note that there should be little need for testing the other scenarios (linear classification, dnn regression, dnn classification) as they should only differ at training time. The training coverage of task.py is already done in test_sd_trainer. 
""" def __init__(self, *args, **kwargs): super(TestLinearRegression, self).__init__(*args, **kwargs) # Log everything self._logger = logging.getLogger('TestStructuredDataLogger') self._logger.setLevel(logging.DEBUG) if not self._logger.handlers: self._logger.addHandler(logging.StreamHandler(stream=sys.stdout)) def _make_test_files(self): """Builds test files and folders""" # Make the output folders self._test_dir = tempfile.mkdtemp() self._preprocess_output = os.path.join(self._test_dir, 'preprocess') self._train_output = os.path.join(self._test_dir, 'train') self._batch_predict_output = os.path.join(self._test_dir, 'batch_predict') # Don't make train_output folder as it should not exist at training time. os.mkdir(self._preprocess_output) os.mkdir(self._batch_predict_output) # Make csv files self._csv_train_filename = os.path.join(self._test_dir, 'train_csv_data.csv') self._csv_eval_filename = os.path.join(self._test_dir, 'eval_csv_data.csv') self._csv_predict_filename = os.path.join(self._test_dir, 'predict_csv_data.csv') e2e_functions.make_csv_data(self._csv_train_filename, 100, 'regression', True) e2e_functions.make_csv_data(self._csv_eval_filename, 100, 'regression', True) self._predict_num_rows = 10 e2e_functions.make_csv_data(self._csv_predict_filename, self._predict_num_rows, 'regression', False) # Make schema file self._schema_filename = os.path.join(self._test_dir, 'schema.json') e2e_functions.make_preprocess_schema(self._schema_filename, 'regression') # Make feature file self._input_features_filename = os.path.join(self._test_dir, 'input_features_file.json') transforms = { "num1": {"transform": "scale"}, "num2": {"transform": "scale", "value": 4}, "str1": {"transform": "one_hot"}, "str2": {"transform": "embedding", "embedding_dim": 3}, "target": {"transform": "target"}, "key": {"transform": "key"}, } file_io.write_string_to_file( self._input_features_filename, json.dumps(transforms, indent=2)) def _run_analyze(self): reglinear.analyze( output_dir=self._preprocess_output, dataset=dlml.CsvDataSet( file_pattern=self._csv_train_filename, schema_file=self._schema_filename)) self.assertTrue(os.path.isfile( os.path.join(self._preprocess_output, 'stats.json'))) self.assertTrue(os.path.isfile( os.path.join(self._preprocess_output, 'vocab_str1.csv'))) def _run_train(self): reglinear.train( train_dataset=dlml.CsvDataSet( file_pattern=self._csv_train_filename, schema_file=self._schema_filename), eval_dataset=dlml.CsvDataSet( file_pattern=self._csv_eval_filename, schema_file=self._schema_filename), analysis_dir=self._preprocess_output, output_dir=self._train_output, features=self._input_features_filename, max_steps=100, train_batch_size=100) self.assertTrue(os.path.isfile( os.path.join(self._train_output, 'model', 'saved_model.pb'))) self.assertTrue(os.path.isfile( os.path.join(self._train_output, 'evaluation_model', 'saved_model.pb'))) def _run_predict(self): data = pd.read_csv(self._csv_predict_filename, header=None) df = reglinear.predict(data=data, training_dir=self._train_output) self.assertEqual(len(df.index), self._predict_num_rows) self.assertEqual(list(df), ['key', 'predicted']) def _run_batch_prediction(self, output_dir, use_target): reglinear.batch_predict( training_dir=self._train_output, prediction_input_file=(self._csv_eval_filename if use_target else self._csv_predict_filename), output_dir=output_dir, mode='evaluation' if use_target else 'prediction', batch_size=4, output_format='csv') # check errors file is empty errors = file_io.get_matching_files(os.path.join(output_dir, 
'errors*')) self.assertEqual(len(errors), 1) if os.path.getsize(errors[0]): with open(errors[0]) as errors_file: self.fail(msg=errors_file.read()) # check predictions files are not empty predictions = file_io.get_matching_files(os.path.join(output_dir, 'predictions*')) self.assertGreater(os.path.getsize(predictions[0]), 0) # check the schema is correct schema_file = os.path.join(output_dir, 'csv_schema.json') self.assertTrue(os.path.isfile(schema_file)) schema = json.loads(file_io.read_file_to_string(schema_file)) self.assertEqual(schema[0]['name'], 'key') self.assertEqual(schema[1]['name'], 'predicted') if use_target: self.assertEqual(schema[2]['name'], 'target') self.assertEqual(len(schema), 3) else: self.assertEqual(len(schema), 2) def _cleanup(self): shutil.rmtree(self._test_dir) def test_e2e(self): try: self._make_test_files() self._run_analyze() self._run_train() if six.PY2: # Dataflow is only supported by python 2. Prediction assumes Dataflow # is installed. self._run_predict() self._run_batch_prediction( os.path.join(self._batch_predict_output, 'with_target'), True) self._run_batch_prediction( os.path.join(self._batch_predict_output, 'without_target'), False) else: print('only tested analyze in TestLinearRegression') finally: self._cleanup() if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/structured_data/test_mltoolbox/test_package_functions.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
"""Test the datalab interface functions in _package.py """ from __future__ import absolute_import from __future__ import print_function import os import six import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))) import inspect # noqa: E402 import mltoolbox._structured_data._package as core_sd # noqa: E402 import mltoolbox.classification.linear as classlin # noqa: E402 import mltoolbox.classification.dnn as classdnn # noqa: E402 import mltoolbox.regression.linear as reglin # noqa: E402 import mltoolbox.regression.dnn as regdnn # noqa: E402 import google.datalab.ml as dlml # noqa: E402 import unittest # noqa: E402 @unittest.skipIf(not six.PY2, 'Python 2 is required') class TestAnalyze(unittest.TestCase): def test_not_csvdataset(self): """Test csvdataset is used""" # not a CsvDataSet job = core_sd.analyze_async('some_dir', 'some_file.txt').wait() self.assertIn('Only CsvDataSet is supported', job.fatal_error.message) def test_csvdataset_one_file(self): """Test CsvDataSet has only one file/pattern""" # TODO(brandondutra) remove this restriction job = core_sd.analyze_async( 'some_dir', dlml.CsvDataSet( file_pattern=['file1.txt', 'file2.txt'], schema='col1:STRING,col2:INTEGER,col3:FLOAT')).wait() self.assertIn('should be built with a file pattern', job.fatal_error.message) def test_projectid(self): """Test passing project id but cloud is false""" job = core_sd.analyze_async( 'some_dir', dlml.CsvDataSet( file_pattern=['file1.txt'], schema='col1:STRING,col2:INTEGER,col3:FLOAT'), project_id='project_id').wait() self.assertIn('project_id only needed if cloud is True', job.fatal_error.message) def test_cloud_with_local_output_folder(self): job = core_sd.analyze_async( 'some_dir', dlml.CsvDataSet( file_pattern=['gs://file1.txt'], schema='col1:STRING,col2:INTEGER,col3:FLOAT'), project_id='project_id', cloud=True).wait() self.assertIn('File some_dir is not a gcs path', job.fatal_error.message) def test_cloud_but_local_files(self): job = core_sd.analyze_async( 'gs://some_dir', dlml.CsvDataSet( file_pattern=['file1.txt'], schema='col1:STRING,col2:INTEGER,col3:FLOAT'), project_id='project_id', cloud=True).wait() self.assertIn('File file1.txt is not a gcs path', job.fatal_error.message) def test_unsupported_schema(self): """Test supported schema values. Note that not all valid BQ schema values are valid/used in the structured data package """ unsupported_col_types = ['bytes', 'boolean', 'timestamp', 'date', 'time', 'datetime', 'record'] for col_type in unsupported_col_types: schema = 'col_name:%s' % col_type job = core_sd.analyze_async( 'some_dir', dlml.CsvDataSet( file_pattern=['file1.txt'], schema=schema), cloud=False).wait() self.assertIn('Schema contains an unsupported type %s.' % col_type, job.fatal_error.message) job = core_sd.analyze_async( 'gs://some_dir', dlml.CsvDataSet( file_pattern=['gs://file1.txt'], schema=schema), cloud=True, project_id='junk_project_id').wait() self.assertIn('Schema contains an unsupported type %s.' 
% col_type, job.fatal_error.message) @unittest.skipIf(not six.PY2, 'Python 2 is required') class TestFunctionSignature(unittest.TestCase): def _argspec(self, fn_obj): if six.PY2: return inspect.getargspec(fn_obj) else: return inspect.getfullargspec(fn_obj) def test_same_analysis(self): """Test that there is only one analyze function""" self.assertIs(core_sd.analyze, classlin.analyze) self.assertIs(core_sd.analyze, classdnn.analyze) self.assertIs(core_sd.analyze, reglin.analyze) self.assertIs(core_sd.analyze, regdnn.analyze) def test_same_analysis_async(self): """Test that there is only one analyze_async function""" self.assertIs(core_sd.analyze_async, classlin.analyze_async) self.assertIs(core_sd.analyze_async, classdnn.analyze_async) self.assertIs(core_sd.analyze_async, reglin.analyze_async) self.assertIs(core_sd.analyze_async, regdnn.analyze_async) def test_analysis_argspec(self): """Test all analyze functions have the same parameters""" self.assertEqual(self._argspec(core_sd.analyze), self._argspec(core_sd.analyze_async)) self.assertEqual(self._argspec(core_sd.analyze), self._argspec(core_sd._analyze)) ================================================ FILE: solutionbox/structured_data/test_mltoolbox/test_sd_preprocess.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import import glob import json import logging import os import shutil import sys import filecmp import tempfile import unittest from . import e2e_functions class TestPreprocess(unittest.TestCase): """Tests preprocessing. Runs analysis on a test dataset. Checks that the expected files are made. """ def __init__(self, *args, **kwargs): super(TestPreprocess, self).__init__(*args, **kwargs) # Log everything self._logger = logging.getLogger('TestStructuredDataLogger') self._logger.setLevel(logging.DEBUG) if not self._logger.handlers: self._logger.addHandler(logging.StreamHandler(stream=sys.stdout)) def setUp(self): self._test_dir = tempfile.mkdtemp() self._csv_filename = os.path.join(self._test_dir, 'raw_csv_data.csv') self._schema_filename = os.path.join(self._test_dir, 'schema.json') self._preprocess_output = os.path.join(self._test_dir, 'pout') def tearDown(self): self._logger.debug('TestPreprocess: removing test dir: ' + self._test_dir) shutil.rmtree(self._test_dir) def _make_test_data(self, problem_type): """Makes input files to run preprocessing on. 
Args: problem_type: 'regression' or 'classification' """ e2e_functions.make_csv_data(self._csv_filename, 100, problem_type, True) e2e_functions.make_preprocess_schema(self._schema_filename, problem_type) def _test_preprocess(self, problem_type): self._make_test_data(problem_type) e2e_functions.run_preprocess( output_dir=self._preprocess_output, csv_filename=self._csv_filename, schema_filename=self._schema_filename, logger=self._logger) schema_file = os.path.join(self._preprocess_output, 'schema.json') numerical_analysis_file = os.path.join(self._preprocess_output, 'stats.json') # test schema file was copied self.assertTrue(filecmp.cmp(schema_file, self._schema_filename)) expected_numerical_keys = ['num1', 'num2', 'num3'] if problem_type == 'regression': expected_numerical_keys.append('target') # Load the numerical analysis file and check it has the right keys with open(numerical_analysis_file, 'r') as f: analysis = json.load(f) self.assertEqual(sorted(expected_numerical_keys), sorted(analysis.keys())) # Check that the vocab files are made expected_vocab_files = ['vocab_str1.csv', 'vocab_str2.csv', 'vocab_str3.csv', 'vocab_key.csv'] if problem_type == 'classification': expected_vocab_files.append('vocab_target.csv') for name in expected_vocab_files: vocab_file = os.path.join(self._preprocess_output, name) self.assertTrue(os.path.exists(vocab_file)) self.assertGreater(os.path.getsize(vocab_file), 0) all_expected_files = (expected_vocab_files + ['stats.json', 'schema.json']) all_file_paths = glob.glob(os.path.join(self._preprocess_output, '*')) all_files = [os.path.basename(path) for path in all_file_paths] self.assertEqual(sorted(all_expected_files), sorted(all_files)) def testRegression(self): self._test_preprocess('regression') def testClassification(self): self._test_preprocess('classification') if __name__ == '__main__': unittest.main() ================================================ FILE: solutionbox/structured_data/test_mltoolbox/test_sd_trainer.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from __future__ import absolute_import import json import logging import os import re import shutil import sys import tempfile import unittest from . import e2e_functions class TestTrainer(unittest.TestCase): """Tests training. Each test builds a csv test dataset. Preprocessing is run on the data to produce analysis. Training is then ran, and the output is collected and the accuracy/loss values are inspected. """ def __init__(self, *args, **kwargs): super(TestTrainer, self).__init__(*args, **kwargs) # Allow this class to be subclassed for quick tests that only care about # training working, not model loss/accuracy. 
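# Note: _check_model_fit gates the loss/accuracy assertions in _check_training_screen_output,
# and _max_steps is forwarded to the trainer via _run_training.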
    self._max_steps = 2500
    self._check_model_fit = True

    # Log everything
    self._logger = logging.getLogger('TestStructuredDataLogger')
    self._logger.setLevel(logging.DEBUG)
    if not self._logger.handlers:
      self._logger.addHandler(logging.StreamHandler(stream=sys.stdout))

  def setUp(self):
    self._test_dir = tempfile.mkdtemp()
    self._preprocess_output = os.path.join(self._test_dir, 'pre')
    self._train_output = os.path.join(self._test_dir, 'train')
    os.mkdir(self._preprocess_output)
    os.mkdir(self._train_output)

    self._csv_train_filename = os.path.join(self._test_dir, 'train_csv_data.csv')
    self._csv_eval_filename = os.path.join(self._test_dir, 'eval_csv_data.csv')
    self._schema_filename = os.path.join(self._test_dir, 'schema.json')
    self._input_features_filename = os.path.join(self._test_dir, 'input_features_file.json')
    self._transforms_filename = os.path.join(self._test_dir, 'features.json')

  def tearDown(self):
    self._logger.debug('TestTrainer: removing test dir ' + self._test_dir)
    shutil.rmtree(self._test_dir)

  def _run_training(self, problem_type, model_type, transforms, extra_args=[]):
    """Runs training.

    Output is saved to _training_screen_output. Nothing from training should be
    printed to the screen.

    Args:
      problem_type: 'regression' or 'classification'
      model_type: 'linear' or 'dnn'
      transforms: JSON object of the transforms file.
      extra_args: list of strings to pass to the trainer.
    """
    # Run preprocessing.
    e2e_functions.make_csv_data(self._csv_train_filename, 100, problem_type, True)
    e2e_functions.make_csv_data(self._csv_eval_filename, 100, problem_type, True)
    e2e_functions.make_preprocess_schema(self._schema_filename, problem_type)
    e2e_functions.run_preprocess(
        output_dir=self._preprocess_output,
        csv_filename=self._csv_train_filename,
        schema_filename=self._schema_filename,
        logger=self._logger)

    # Write the transforms file.
    with open(self._transforms_filename, 'w') as f:
      f.write(json.dumps(transforms, indent=2, separators=(',', ': ')))

    # Run training and save the output.
    output = e2e_functions.run_training(
        train_data_paths=self._csv_train_filename,
        eval_data_paths=self._csv_eval_filename,
        output_path=self._train_output,
        preprocess_output_dir=self._preprocess_output,
        transforms_file=self._transforms_filename,
        max_steps=self._max_steps,
        model_type=model_type + '_' + problem_type,
        logger=self._logger,
        extra_args=extra_args)
    self._training_screen_output = output

  def _check_training_screen_output(self, accuracy=None, loss=None):
    """Should be called after _run_training.

    Inspects self._training_screen_output for correct output.

    Args:
      accuracy: float. Eval accuracy should be greater than this number.
      loss: float. Eval loss should be less than this number.
    """
    if not self._check_model_fit:
      self._logger.debug('Skipping model loss/accuracy checks')
      return

    # Find the last training loss line in the output
    lines = self._training_screen_output.splitlines()
    last_line = None
    for line in lines:
      if line.startswith('INFO:tensorflow:Saving dict for global step %d' % self._max_steps):
        last_line = line
        break

    if not last_line:
      self._logger.debug('Skipping _check_training_screen_output as could not '
                         'find last eval line')
      return

    self._logger.debug(last_line)

    # Matches positive numbers (int or real), with optional exponential form.
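    # e.g. it matches '2500', '24.5' and '1.5e-03'; a leading minus sign is deliberately not matched.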
    positive_number_re = re.compile(r'[+]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?')

    # Check training made it to the final step (self._max_steps).
    saving_num_re = re.compile(r'global_step = \d+')
    saving_num = saving_num_re.findall(last_line)
    # saving_num == ['global_step = NUM']
    self.assertEqual(len(saving_num), 1)
    step_num = positive_number_re.findall(saving_num[0])
    # step_num == ['2500']
    self.assertEqual(len(step_num), 1)
    self.assertEqual(int(step_num[0]), self._max_steps)

    # Check the accuracy
    if accuracy is not None:
      accuracy_eq_num_re = re.compile(r'accuracy = [+]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?')
      accuracy_eq_num = accuracy_eq_num_re.findall(last_line)
      # accuracy_eq_num == ['accuracy = NUM']
      self.assertEqual(len(accuracy_eq_num), 1)
      accuracy_num = positive_number_re.findall(accuracy_eq_num[0])
      # accuracy_num == ['X.XXX']
      self.assertEqual(len(accuracy_num), 1)
      self.assertGreater(float(accuracy_num[0]), accuracy)

    if loss is not None:
      loss_eq_num_re = re.compile(r'loss = [+]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?')
      loss_eq_num = loss_eq_num_re.findall(last_line)
      # loss_eq_num == ['loss = NUM']
      self.assertEqual(len(loss_eq_num), 1)
      loss_num = positive_number_re.findall(loss_eq_num[0])
      # loss_num == ['X.XXX']
      self.assertEqual(len(loss_num), 1)
      self.assertLess(float(loss_num[0]), loss)

  def _check_train_files(self):
    self._check_savedmodel(os.path.join(self._train_output, 'model'))
    self._check_savedmodel(os.path.join(self._train_output, 'evaluation_model'))

  def _check_savedmodel(self, model_folder):
    self.assertTrue(
        os.path.isfile(os.path.join(model_folder, 'saved_model.pb')))
    self.assertTrue(
        os.path.isfile(os.path.join(model_folder, 'variables/variables.index')))
    self.assertTrue(
        os.path.isfile(os.path.join(model_folder, 'assets.extra/schema.json')))
    self.assertTrue(
        os.path.isfile(os.path.join(model_folder, 'assets.extra/features.json')))

  def testRegressionDnn(self):
    self._logger.debug('\n\nTesting Regression DNN')

    transforms = {
        "num1": {"transform": "scale"},
        "num2": {"transform": "scale", "value": 4},
        "str1": {"transform": "one_hot"},
        "str2": {"transform": "embedding", "embedding_dim": 3},
        "target": {"transform": "target"},
        "key": {"transform": "key"},
    }
    extra_args = ['--layer-size1=10', '--layer-size2=10', '--layer-size3=5']
    self._run_training(problem_type='regression',
                       model_type='dnn',
                       transforms=transforms,
                       extra_args=extra_args)

    self._check_training_screen_output(loss=20)
    self._check_train_files()

  def testRegressionLinear(self):
    self._logger.debug('\n\nTesting Regression Linear')

    transforms = {
        "num1": {"transform": "scale"},
        "num2": {"transform": "scale", "value": 4},
        "str1": {"transform": "one_hot"},
        "str2": {"transform": "embedding", "embedding_dim": 3},
        "target": {"transform": "target"},
        "key": {"transform": "key"},
    }
    self._run_training(problem_type='regression',
                       model_type='linear',
                       transforms=transforms)

    self._check_training_screen_output(loss=20)
    self._check_train_files()

  def testClassificationDnn(self):
    self._logger.debug('\n\nTesting classification DNN')

    transforms = {
        "num1": {"transform": "scale"},
        "num2": {"transform": "scale", "value": 4},
        "str1": {"transform": "one_hot"},
        "str2": {"transform": "embedding", "embedding_dim": 3},
        "target": {"transform": "target"},
        "key": {"transform": "key"},
    }
    extra_args = ['--layer-size1=10', '--layer-size2=10', '--layer-size3=5']
    self._run_training(problem_type='classification',
                       model_type='dnn',
                       transforms=transforms,
                       extra_args=extra_args)

    self._check_training_screen_output(accuracy=0.70, loss=0.10)
    self._check_train_files()

  def testClassificationLinear(self):
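    # Uses the same transforms as the DNN classification test above, but trains the linear model.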
self._logger.debug('\n\nTesting classification Linear') transforms = { "num1": {"transform": "scale"}, "num2": {"transform": "scale", "value": 4}, "str1": {"transform": "one_hot"}, "str2": {"transform": "embedding", "embedding_dim": 3}, "target": {"transform": "target"}, "key": {"transform": "key"}, } self._run_training(problem_type='classification', model_type='linear', transforms=transforms) self._check_training_screen_output(accuracy=0.70, loss=0.2) self._check_train_files() if __name__ == '__main__': unittest.main() ================================================ FILE: tests/_util/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: tests/_util/commands_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import argparse import contextlib import six import sys import unittest import yaml from google.datalab.utils.commands import CommandParser class TestCases(unittest.TestCase): @staticmethod @contextlib.contextmanager # Redirects some stderr temporarily; can be used to prevent console output from some tests. def redirect_stderr(target): original = sys.stderr sys.stderr = target yield sys.stderr = original def test_subcommand_line(self): parser = CommandParser( prog='%test_subcommand_line', description='test_subcommand_line description') subcommand1 = parser.subcommand('subcommand1', help='subcommand1 help') subcommand1.add_argument('--string1', help='string1 help.') subcommand1.add_argument('--flag1', action='store_true', default=False, help='flag1 help.') args, cell = parser.parse('subcommand1 --string1 value1 --flag1', None) self.assertEqual(args, {'string1': 'value1', 'flag1': True}) self.assertIsNone(cell) args, cell = parser.parse('subcommand1 --string1 value1', None) self.assertEqual(args, {'string1': 'value1', 'flag1': False}) self.assertIsNone(cell) args, cell = parser.parse('subcommand1', None) self.assertEqual(args, {'string1': None, 'flag1': False}) self.assertIsNone(cell) # Adding same arg twice will cause argparse to raise its own ArgumentError. 
with self.assertRaises(argparse.ArgumentError): subcommand1.add_argument('--string2', help='string2 help.') subcommand1.add_argument('--string2', help='string2 help.') def test_subcommand_line_cell(self): parser = CommandParser( prog='%test_subcommand_line', description='test_subcommand_line description') subcommand1 = parser.subcommand('subcommand1', help='subcommand1 help') subcommands_of_subcommand1 = subcommand1.add_subparsers(dest='command') subcommand2 = subcommands_of_subcommand1.add_parser('subcommand2', help='subcommand2 help') subcommand2.add_argument('--string1', '-s', required=True, help='string1 help.') subcommand2.add_argument('--string2', '--string2again', dest='string2', help='string2 help.') subcommand2.add_cell_argument('string3', help='string3 help.') subcommand2.add_argument('--flag1', action='store_true', default=False, help='flag1 help.') args, cell = parser.parse('subcommand1 subcommand2 -s value1 --string2 value2', 'flag1: true') self.assertEqual(args, {'string1': 'value1', 'string2': 'value2', 'string3': None, 'command': 'subcommand2', 'flag1': True}) self.assertIsNone(cell) args, cell = parser.parse('subcommand1 subcommand2 --flag1', 'string1: value1\nstring2again: value2') self.assertEqual(args, {'string1': 'value1', 'string2': 'value2', 'string3': None, 'command': 'subcommand2', 'flag1': True}) self.assertIsNone(cell) args, cell = parser.parse('subcommand1 subcommand2', 'string1: value1\nstring2: value2\nstring3: value3\n' + 'string4: value4\nflag1: false') self.assertEqual(args, {'string1': 'value1', 'string2': 'value2', 'string3': 'value3', 'command': 'subcommand2', 'flag1': False}) self.assertEqual(yaml.load(cell), {'string3': 'value3', 'string4': 'value4'}) # Regular arg and cell arg cannot be the same name. with self.assertRaises(ValueError): subcommand2.add_argument('--duparg', help='help.') subcommand2.add_cell_argument('duparg', help='help.') # Do not allow same arg in both line and cell. with self.assertRaises(ValueError): parser.parse('subcommand1 subcommand2 -s value1 --duparg v1', 'duparg: v2') # 'string3' is a cell arg. Argparse will raise Exception after finding an unrecognized param. with self.assertRaisesRegexp(Exception, 'unrecognized arguments: --string3 value3'): with TestCases.redirect_stderr(six.StringIO()): parser.parse('subcommand1 subcommand2 -s value1 --string3 value3', 'a: b') # 'string4' is required but missing. 
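# After registering the required cell arg below, parse() should raise ValueError when the cell body omits it.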
subcommand2.add_cell_argument('string4', required=True, help='string4 help.') with self.assertRaises(ValueError): parser.parse('subcommand1 subcommand2 -s value1', 'a: b') def test_subcommand_var_replacement(self): parser = CommandParser( prog='%test_subcommand_line', description='test_subcommand_line description') subcommand1 = parser.subcommand('subcommand1', help='subcommand1 help') subcommand1.add_argument('--string1', help='string1 help.') subcommand1.add_argument('--flag1', action='store_true', default=False, help='flag1 help.') subcommand1.add_cell_argument('string2', help='string2 help.') subcommand1.add_cell_argument('dict1', help='dict1 help.') namespace = {'var1': 'value1', 'var2': 'value2', 'var3': [1, 2]} args, cell = parser.parse('subcommand1 --string1 $var1', 'a: b\nstring2: $var2', namespace) self.assertEqual(args, {'string1': 'value1', 'string2': 'value2', 'flag1': False, 'dict1': None}) self.assertEqual(yaml.load(cell), {'a': 'b', 'string2': '$var2'}) cell = """ dict1: k1: $var1 k2: $var3 """ args, cell = parser.parse('subcommand1', cell, namespace) self.assertEqual(args['dict1'], {'k1': 'value1', 'k2': [1, 2]}) def test_subcommand_help(self): parser = CommandParser( prog='%test_subcommand_line', description='test_subcommand_line description') subcommand1 = parser.subcommand('subcommand1', help='subcommand1 help') subcommands_of_subcommand1 = subcommand1.add_subparsers(dest='command') subcommand2 = subcommands_of_subcommand1.add_parser('subcommand2', help='subcommand2 help') subcommand2.add_argument('--string1', '-s', required=True, help='string1 help.') subcommand2.add_argument('--string2', '--string2again', dest='string2', help='string2 help.') subcommand2.add_cell_argument('string3', help='string3 help.') subcommand2.add_argument('--flag1', action='store_true', default=False, help='flag1 help.') old_stdout = sys.stdout buf = six.StringIO() sys.stdout = buf with self.assertRaises(Exception): parser.parse('subcommand1 subcommand2 --help', None) sys.stdout = old_stdout help_string = buf.getvalue() self.assertIn('string1 help.', help_string) self.assertIn('string2 help.', help_string) self.assertIn('string3 help.', help_string) self.assertIn('flag1 help.', help_string) ================================================ FILE: tests/_util/feature_statistics_generator_test.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from google.datalab.utils.facets.feature_statistics_generator import FeatureStatisticsGenerator import numpy as np import tensorflow as tf from tensorflow.python.platform import googletest class FeatureStatisticsGeneratorTest(googletest.TestCase): def setUp(self): self.fs = FeatureStatisticsGenerator() def testParseExampleInt(self): # Tests parsing examples of integers examples = [] for i in range(50): example = tf.train.Example() example.features.feature['num'].int64_list.value.append(i) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) self.assertEqual(1, len(entries)) self.assertIn('num', entries) info = entries['num'] self.assertEqual(0, info['missing']) self.assertEqual(self.fs.fs_proto.INT, info['type']) for i in range(len(examples)): self.assertEqual(1, info['counts'][i]) self.assertEqual(i, info['vals'][i]) def testParseExampleMissingValueList(self): # Tests parsing examples of integers examples = [] example = tf.train.Example() # pylint: disable=pointless-statement example.features.feature['str'] # pylint: enable=pointless-statement examples.append(example) example = tf.train.Example() example.features.feature['str'].bytes_list.value.append(b'test') examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) self.assertEqual(1, len(entries)) self.assertIn('str', entries) info = entries['str'] self.assertEqual(1, info['missing']) self.assertEqual(self.fs.fs_proto.STRING, info['type']) self.assertEqual(0, info['counts'][0]) self.assertEqual(1, info['counts'][1]) def _check_sequence_example_entries(self, entries, n_examples, n_features, feat_len=None): self.assertIn('num', entries) info = entries['num'] self.assertEqual(0, info['missing']) self.assertEqual(self.fs.fs_proto.INT, info['type']) for i in range(n_examples): self.assertEqual(n_features, info['counts'][i]) if feat_len is not None: self.assertEqual(feat_len, info['feat_lens'][i]) for i in range(n_examples * n_features): self.assertEqual(i, info['vals'][i]) if feat_len is None: self.assertEqual(0, len(info['feat_lens'])) def testParseExampleSequenceContext(self): # Tests parsing examples of integers in context field examples = [] for i in range(50): example = tf.train.SequenceExample() example.context.feature['num'].int64_list.value.append(i) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.context.feature, example.feature_lists.feature_list, entries, i) self._check_sequence_example_entries(entries, 50, 1) self.assertEqual(1, len(entries)) def testParseExampleSequenceFeatureList(self): examples = [] for i in range(50): example = tf.train.SequenceExample() feat = example.feature_lists.feature_list['num'].feature.add() feat.int64_list.value.append(i) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.context.feature, example.feature_lists.feature_list, entries, i) self._check_sequence_example_entries(entries, 50, 1, 1) def testParseExampleSequenceFeatureListMultipleEntriesInner(self): examples = [] for i in range(2): example = tf.train.SequenceExample() feat = example.feature_lists.feature_list['num'].feature.add() for j in range(25): feat.int64_list.value.append(i * 25 + j) examples.append(example) entries = {} for i, example in enumerate(examples): 
self.fs._ParseExample(example.context.feature, example.feature_lists.feature_list, entries, i) self._check_sequence_example_entries(entries, 2, 25, 1) def testParseExampleSequenceFeatureListMultipleEntriesOuter(self): # Tests parsing examples of integers in context field examples = [] for i in range(2): example = tf.train.SequenceExample() for j in range(25): feat = example.feature_lists.feature_list['num'].feature.add() feat.int64_list.value.append(i * 25 + j) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.context.feature, example.feature_lists.feature_list, entries, i) self._check_sequence_example_entries(entries, 2, 25, 25) def testVaryingCountsAndMissing(self): # Tests parsing examples of when some examples have missing features examples = [] for i in range(5): example = tf.train.Example() example.features.feature['other'].int64_list.value.append(0) for _ in range(i): example.features.feature['num'].int64_list.value.append(i) examples.append(example) example = tf.train.Example() example.features.feature['other'].int64_list.value.append(0) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) info = entries['num'] self.assertEqual(2, info['missing']) self.assertEqual(4, len(info['counts'])) for i in range(4): self.assertEqual(i + 1, info['counts'][i]) self.assertEqual(10, len(info['vals'])) def testParseExampleStringsAndFloats(self): # Tests parsing examples of string and float features examples = [] for i in range(50): example = tf.train.Example() example.features.feature['str'].bytes_list.value.append(b'hi') example.features.feature['float'].float_list.value.append(i) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) self.assertEqual(2, len(entries)) self.assertEqual(self.fs.fs_proto.FLOAT, entries['float']['type']) self.assertEqual(self.fs.fs_proto.STRING, entries['str']['type']) for i in range(len(examples)): self.assertEqual(1, entries['str']['counts'][i]) self.assertEqual(1, entries['float']['counts'][i]) self.assertEqual(i, entries['float']['vals'][i]) self.assertEqual('hi', entries['str']['vals'][i].decode( 'UTF-8', 'strict')) def testParseExamplesTypeMismatch(self): examples = [] example = tf.train.Example() example.features.feature['feat'].int64_list.value.append(0) examples.append(example) example = tf.train.Example() example.features.feature['feat'].bytes_list.value.append(b'str') examples.append(example) entries = {} self.fs._ParseExample(examples[0].features.feature, [], entries, 0) with self.assertRaises(TypeError): self.fs._ParseExample(examples[1].features.feature, [], entries, 1) def testGetDatasetsProtoFromEntriesLists(self): entries = {} entries['testFeature'] = { 'vals': [1, 2, 3], 'counts': [1, 1, 1], 'missing': 0, 'type': self.fs.fs_proto.INT } datasets = [{'entries': entries, 'size': 3, 'name': 'testDataset'}] p = self.fs.GetDatasetsProto(datasets) self.assertEqual(1, len(p.datasets)) test_data = p.datasets[0] self.assertEqual('testDataset', test_data.name) self.assertEqual(3, test_data.num_examples) self.assertEqual(1, len(test_data.features)) numfeat = test_data.features[0] self.assertEqual('testFeature', numfeat.name) self.assertEqual(self.fs.fs_proto.INT, numfeat.type) self.assertEqual(1, numfeat.num_stats.min) self.assertEqual(3, numfeat.num_stats.max) def testGetProtoNums(self): # Tests converting int examples into the 
feature stats proto examples = [] for i in range(50): example = tf.train.Example() example.features.feature['num'].int64_list.value.append(i) examples.append(example) example = tf.train.Example() example.features.feature['other'].int64_list.value.append(0) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) datasets = [{'entries': entries, 'size': len(examples), 'name': 'test'}] p = self.fs.GetDatasetsProto(datasets) self.assertEqual(1, len(p.datasets)) test_data = p.datasets[0] self.assertEqual('test', test_data.name) self.assertEqual(51, test_data.num_examples) numfeat = test_data.features[0] if ( test_data.features[0].name == 'num') else test_data.features[1] self.assertEqual('num', numfeat.name) self.assertEqual(self.fs.fs_proto.INT, numfeat.type) self.assertEqual(0, numfeat.num_stats.min) self.assertEqual(49, numfeat.num_stats.max) self.assertEqual(24.5, numfeat.num_stats.mean) self.assertEqual(24.5, numfeat.num_stats.median) self.assertEqual(1, numfeat.num_stats.num_zeros) self.assertAlmostEqual(14.430869689, numfeat.num_stats.std_dev, 4) self.assertEqual(1, numfeat.num_stats.common_stats.num_missing) self.assertEqual(50, numfeat.num_stats.common_stats.num_non_missing) self.assertEqual(1, numfeat.num_stats.common_stats.min_num_values) self.assertEqual(1, numfeat.num_stats.common_stats.max_num_values) self.assertAlmostEqual(1, numfeat.num_stats.common_stats.avg_num_values, 4) hist = numfeat.num_stats.common_stats.num_values_histogram buckets = hist.buckets self.assertEqual(self.fs.histogram_proto.QUANTILES, hist.type) self.assertEqual(10, len(buckets)) self.assertEqual(1, buckets[0].low_value) self.assertEqual(1, buckets[0].high_value) self.assertEqual(5, buckets[0].sample_count) self.assertEqual(1, buckets[9].low_value) self.assertEqual(1, buckets[9].high_value) self.assertEqual(5, buckets[9].sample_count) self.assertEqual(2, len(numfeat.num_stats.histograms)) buckets = numfeat.num_stats.histograms[0].buckets self.assertEqual(self.fs.histogram_proto.STANDARD, numfeat.num_stats.histograms[0].type) self.assertEqual(10, len(buckets)) self.assertEqual(0, buckets[0].low_value) self.assertEqual(4.9, buckets[0].high_value) self.assertEqual(5, buckets[0].sample_count) self.assertAlmostEqual(44.1, buckets[9].low_value) self.assertEqual(49, buckets[9].high_value) self.assertEqual(5, buckets[9].sample_count) buckets = numfeat.num_stats.histograms[1].buckets self.assertEqual(self.fs.histogram_proto.QUANTILES, numfeat.num_stats.histograms[1].type) self.assertEqual(10, len(buckets)) self.assertEqual(0, buckets[0].low_value) self.assertEqual(4.9, buckets[0].high_value) self.assertEqual(5, buckets[0].sample_count) self.assertAlmostEqual(44.1, buckets[9].low_value) self.assertEqual(49, buckets[9].high_value) self.assertEqual(5, buckets[9].sample_count) def testQuantiles(self): examples = [] for i in range(50): example = tf.train.Example() example.features.feature['num'].int64_list.value.append(i) examples.append(example) for i in range(50): example = tf.train.Example() example.features.feature['num'].int64_list.value.append(100) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) datasets = [{'entries': entries, 'size': len(examples), 'name': 'test'}] p = self.fs.GetDatasetsProto(datasets) numfeat = p.datasets[0].features[0] self.assertEqual(2, len(numfeat.num_stats.histograms)) 
self.assertEqual(self.fs.histogram_proto.QUANTILES, numfeat.num_stats.histograms[1].type) buckets = numfeat.num_stats.histograms[1].buckets self.assertEqual(10, len(buckets)) self.assertEqual(0, buckets[0].low_value) self.assertEqual(9.9, buckets[0].high_value) self.assertEqual(10, buckets[0].sample_count) self.assertEqual(100, buckets[9].low_value) self.assertEqual(100, buckets[9].high_value) self.assertEqual(10, buckets[9].sample_count) def testInfinityAndNan(self): examples = [] for i in range(50): example = tf.train.Example() example.features.feature['num'].float_list.value.append(i) examples.append(example) example = tf.train.Example() example.features.feature['num'].float_list.value.append(float('inf')) examples.append(example) example = tf.train.Example() example.features.feature['num'].float_list.value.append(float('-inf')) examples.append(example) example = tf.train.Example() example.features.feature['num'].float_list.value.append(float('nan')) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) datasets = [{'entries': entries, 'size': len(examples), 'name': 'test'}] p = self.fs.GetDatasetsProto(datasets) numfeat = p.datasets[0].features[0] self.assertEqual('num', numfeat.name) self.assertEqual(self.fs.fs_proto.FLOAT, numfeat.type) self.assertTrue(np.isnan(numfeat.num_stats.min)) self.assertTrue(np.isnan(numfeat.num_stats.max)) self.assertTrue(np.isnan(numfeat.num_stats.mean)) self.assertTrue(np.isnan(numfeat.num_stats.median)) self.assertEqual(1, numfeat.num_stats.num_zeros) self.assertTrue(np.isnan(numfeat.num_stats.std_dev)) self.assertEqual(53, numfeat.num_stats.common_stats.num_non_missing) hist = buckets = numfeat.num_stats.histograms[0] buckets = hist.buckets self.assertEqual(self.fs.histogram_proto.STANDARD, hist.type) self.assertEqual(1, hist.num_nan) self.assertEqual(10, len(buckets)) self.assertEqual(float('-inf'), buckets[0].low_value) self.assertEqual(4.9, buckets[0].high_value) self.assertEqual(6, buckets[0].sample_count) self.assertEqual(44.1, buckets[9].low_value) self.assertEqual(float('inf'), buckets[9].high_value) self.assertEqual(6, buckets[9].sample_count) def testInfinitysOnly(self): examples = [] example = tf.train.Example() example.features.feature['num'].float_list.value.append(float('inf')) examples.append(example) example = tf.train.Example() example.features.feature['num'].float_list.value.append(float('-inf')) examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) datasets = [{'entries': entries, 'size': len(examples), 'name': 'test'}] p = self.fs.GetDatasetsProto(datasets) numfeat = p.datasets[0].features[0] hist = buckets = numfeat.num_stats.histograms[0] buckets = hist.buckets self.assertEqual(self.fs.histogram_proto.STANDARD, hist.type) self.assertEqual(10, len(buckets)) self.assertEqual(float('-inf'), buckets[0].low_value) self.assertEqual(0.1, buckets[0].high_value) self.assertEqual(1, buckets[0].sample_count) self.assertEqual(0.9, buckets[9].low_value) self.assertEqual(float('inf'), buckets[9].high_value) self.assertEqual(1, buckets[9].sample_count) def testGetProtoStrings(self): # Tests converting string examples into the feature stats proto examples = [] for i in range(2): example = tf.train.Example() example.features.feature['str'].bytes_list.value.append(b'hello') examples.append(example) for i in range(3): example = tf.train.Example() 
example.features.feature['str'].bytes_list.value.append(b'hi') examples.append(example) example = tf.train.Example() example.features.feature['str'].bytes_list.value.append(b'hey') examples.append(example) entries = {} for i, example in enumerate(examples): self.fs._ParseExample(example.features.feature, [], entries, i) datasets = [{'entries': entries, 'size': len(examples), 'name': 'test'}] p = self.fs.GetDatasetsProto(datasets) self.assertEqual(1, len(p.datasets)) test_data = p.datasets[0] self.assertEqual('test', test_data.name) self.assertEqual(6, test_data.num_examples) strfeat = test_data.features[0] self.assertEqual('str', strfeat.name) self.assertEqual(self.fs.fs_proto.STRING, strfeat.type) self.assertEqual(3, strfeat.string_stats.unique) self.assertAlmostEqual(19 / 6.0, strfeat.string_stats.avg_length, 4) self.assertEqual(0, strfeat.string_stats.common_stats.num_missing) self.assertEqual(6, strfeat.string_stats.common_stats.num_non_missing) self.assertEqual(1, strfeat.string_stats.common_stats.min_num_values) self.assertEqual(1, strfeat.string_stats.common_stats.max_num_values) self.assertEqual(1, strfeat.string_stats.common_stats.avg_num_values) hist = strfeat.string_stats.common_stats.num_values_histogram buckets = hist.buckets self.assertEqual(self.fs.histogram_proto.QUANTILES, hist.type) self.assertEqual(10, len(buckets)) self.assertEqual(1, buckets[0].low_value) self.assertEqual(1, buckets[0].high_value) self.assertEqual(.6, buckets[0].sample_count) self.assertEqual(1, buckets[9].low_value) self.assertEqual(1, buckets[9].high_value) self.assertEqual(.6, buckets[9].sample_count) self.assertEqual(2, len(strfeat.string_stats.top_values)) self.assertEqual(3, strfeat.string_stats.top_values[0].frequency) self.assertEqual('hi', strfeat.string_stats.top_values[0].value) self.assertEqual(2, strfeat.string_stats.top_values[1].frequency) self.assertEqual('hello', strfeat.string_stats.top_values[1].value) buckets = strfeat.string_stats.rank_histogram.buckets self.assertEqual(3, len(buckets)) self.assertEqual(0, buckets[0].low_rank) self.assertEqual(0, buckets[0].high_rank) self.assertEqual(3, buckets[0].sample_count) self.assertEqual('hi', buckets[0].label) self.assertEqual(2, buckets[2].low_rank) self.assertEqual(2, buckets[2].high_rank) self.assertEqual(1, buckets[2].sample_count) self.assertEqual('hey', buckets[2].label) def testGetProtoMultipleDatasets(self): # Tests converting multiple datsets into the feature stats proto # including ensuring feature order is consistent in the protos. 
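# The two datasets below add the 'num' and 'str' features in opposite orders.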
examples1 = [] for i in range(2): example = tf.train.Example() example.features.feature['str'].bytes_list.value.append(b'one') example.features.feature['num'].int64_list.value.append(0) examples1.append(example) examples2 = [] example = tf.train.Example() example.features.feature['num'].int64_list.value.append(1) example.features.feature['str'].bytes_list.value.append(b'two') examples2.append(example) entries1 = {} for i, example1 in enumerate(examples1): self.fs._ParseExample(example1.features.feature, [], entries1, i) entries2 = {} for i, example2 in enumerate(examples2): self.fs._ParseExample(example2.features.feature, [], entries2, i) datasets = [{ 'entries': entries1, 'size': len(examples1), 'name': 'test1' }, { 'entries': entries2, 'size': len(examples2), 'name': 'test2' }] p = self.fs.GetDatasetsProto(datasets) self.assertEqual(2, len(p.datasets)) test_data_1 = p.datasets[0] self.assertEqual('test1', test_data_1.name) self.assertEqual(2, test_data_1.num_examples) num_feat_index = 0 if test_data_1.features[0].name == 'num' else 1 self.assertEqual(0, test_data_1.features[num_feat_index].num_stats.max) test_data_2 = p.datasets[1] self.assertEqual('test2', test_data_2.name) self.assertEqual(1, test_data_2.num_examples) self.assertEqual(1, test_data_2.features[num_feat_index].num_stats.max) def testGetEntriesNoFiles(self): features, num_examples = self.fs._GetEntries(['test'], 10, lambda unused_path: []) self.assertEqual(0, num_examples) self.assertEqual({}, features) @staticmethod def get_example_iter(): def ex_iter(unused_filename): examples = [] for i in range(50): example = tf.train.Example() example.features.feature['num'].int64_list.value.append(i) examples.append(example.SerializeToString()) return examples return ex_iter def testGetEntries_one(self): features, num_examples = self.fs._GetEntries(['test'], 1, self.get_example_iter()) self.assertEqual(1, num_examples) self.assertTrue('num' in features) def testGetEntries_oneFile(self): unused_features, num_examples = self.fs._GetEntries(['test'], 1000, self.get_example_iter()) self.assertEqual(50, num_examples) def testGetEntries_twoFiles(self): unused_features, num_examples = self.fs._GetEntries(['test0', 'test1'], 1000, self.get_example_iter()) self.assertEqual(100, num_examples) def testGetEntries_stopInSecondFile(self): unused_features, num_examples = self.fs._GetEntries([ 'test@0', 'test@1', 'test@2', 'test@3', 'test@4', 'test@5', 'test@6', 'test@7', 'test@8', 'test@9' ], 75, self.get_example_iter()) self.assertEqual(75, num_examples) if __name__ == '__main__': googletest.main() ================================================ FILE: tests/_util/generic_feature_statistics_generator_test.py ================================================ # Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== from google.datalab.utils.facets.generic_feature_statistics_generator \ import GenericFeatureStatisticsGenerator import numpy as np import pandas as pd from tensorflow.python.platform import googletest class GenericFeatureStatisticsGeneratorTest(googletest.TestCase): def setUp(self): self.gfsg = GenericFeatureStatisticsGenerator() def testProtoFromDataFrames(self): data = [[1, 'hi'], [2, 'hello'], [3, 'hi']] df = pd.DataFrame(data, columns=['testFeatureInt', 'testFeatureString']) dataframes = [{'table': df, 'name': 'testDataset'}] p = self.gfsg.ProtoFromDataFrames(dataframes) self.assertEqual(1, len(p.datasets)) test_data = p.datasets[0] self.assertEqual('testDataset', test_data.name) self.assertEqual(3, test_data.num_examples) self.assertEqual(2, len(test_data.features)) if test_data.features[0].name == 'testFeatureInt': numfeat = test_data.features[0] stringfeat = test_data.features[1] else: numfeat = test_data.features[1] stringfeat = test_data.features[0] self.assertEqual('testFeatureInt', numfeat.name) self.assertEqual(self.gfsg.fs_proto.INT, numfeat.type) self.assertEqual(1, numfeat.num_stats.min) self.assertEqual(3, numfeat.num_stats.max) self.assertEqual('testFeatureString', stringfeat.name) self.assertEqual(self.gfsg.fs_proto.STRING, stringfeat.type) self.assertEqual(2, stringfeat.string_stats.unique) def testNdarrayToEntry(self): arr = np.array([1.0, 2.0, None, float('nan'), 3.0], dtype=float) entry = self.gfsg.NdarrayToEntry(arr) self.assertEqual(2, entry['missing']) arr = np.array(['a', 'b', float('nan'), 'c'], dtype=str) entry = self.gfsg.NdarrayToEntry(arr) self.assertEqual(1, entry['missing']) def testNdarrayToEntryTimeTypes(self): arr = np.array( [np.datetime64('2005-02-25'), np.datetime64('2006-02-25')], dtype=np.datetime64) entry = self.gfsg.NdarrayToEntry(arr) self.assertEqual([1109289600000000000, 1140825600000000000], entry['vals']) arr = np.array( [np.datetime64('2009-01-01') - np.datetime64('2008-01-01')], dtype=np.timedelta64) entry = self.gfsg.NdarrayToEntry(arr) self.assertEqual([31622400000000000], entry['vals']) def testDTypeToType(self): self.assertEqual(self.gfsg.fs_proto.INT, self.gfsg.DtypeToType(np.dtype(np.int32))) # Boolean and time types treated as int self.assertEqual(self.gfsg.fs_proto.INT, self.gfsg.DtypeToType(np.dtype(np.bool))) self.assertEqual(self.gfsg.fs_proto.INT, self.gfsg.DtypeToType(np.dtype(np.datetime64))) self.assertEqual(self.gfsg.fs_proto.INT, self.gfsg.DtypeToType(np.dtype(np.timedelta64))) self.assertEqual(self.gfsg.fs_proto.FLOAT, self.gfsg.DtypeToType(np.dtype(np.float32))) self.assertEqual(self.gfsg.fs_proto.STRING, self.gfsg.DtypeToType(np.dtype(np.str))) # Unsupported types treated as string for now self.assertEqual(self.gfsg.fs_proto.STRING, self.gfsg.DtypeToType(np.dtype(np.void))) def testGetDatasetsProtoFromEntriesLists(self): entries = {} entries['testFeature'] = { 'vals': [1, 2, 3], 'counts': [1, 1, 1], 'missing': 0, 'type': self.gfsg.fs_proto.INT } datasets = [{'entries': entries, 'size': 3, 'name': 'testDataset'}] p = self.gfsg.GetDatasetsProto(datasets) self.assertEqual(1, len(p.datasets)) test_data = p.datasets[0] self.assertEqual('testDataset', test_data.name) self.assertEqual(3, test_data.num_examples) self.assertEqual(1, len(test_data.features)) numfeat = test_data.features[0] self.assertEqual('testFeature', numfeat.name) self.assertEqual(self.gfsg.fs_proto.INT, numfeat.type) self.assertEqual(1, numfeat.num_stats.min) self.assertEqual(3, 
                     numfeat.num_stats.max)
    hist = numfeat.num_stats.common_stats.num_values_histogram
    buckets = hist.buckets
    self.assertEqual(self.gfsg.histogram_proto.QUANTILES, hist.type)
    self.assertEqual(10, len(buckets))
    self.assertEqual(1, buckets[0].low_value)
    self.assertEqual(1, buckets[0].high_value)
    self.assertEqual(.3, buckets[0].sample_count)
    self.assertEqual(1, buckets[9].low_value)
    self.assertEqual(1, buckets[9].high_value)
    self.assertEqual(.3, buckets[9].sample_count)

  def testGetDatasetsProtoSequenceExampleHistogram(self):
    entries = {}
    entries['testFeature'] = {
        'vals': [1, 2, 2, 3],
        'counts': [1, 2, 1],
        'feat_lens': [1, 2, 1],
        'missing': 0,
        'type': self.gfsg.fs_proto.INT
    }
    datasets = [{'entries': entries, 'size': 3, 'name': 'testDataset'}]
    p = self.gfsg.GetDatasetsProto(datasets)
    hist = p.datasets[0].features[
        0].num_stats.common_stats.feature_list_length_histogram
    buckets = hist.buckets
    self.assertEqual(self.gfsg.histogram_proto.QUANTILES, hist.type)
    self.assertEqual(10, len(buckets))
    self.assertEqual(1, buckets[0].low_value)
    self.assertEqual(1, buckets[0].high_value)
    self.assertEqual(.3, buckets[0].sample_count)
    self.assertEqual(1.8, buckets[9].low_value)
    self.assertEqual(2, buckets[9].high_value)
    self.assertEqual(.3, buckets[9].sample_count)

  def testGetDatasetsProtoWithAllowlist(self):
    entries = {}
    entries['testFeature'] = {
        'vals': [1, 2, 3],
        'counts': [1, 1, 1],
        'missing': 0,
        'type': self.gfsg.fs_proto.INT
    }
    entries['ignoreFeature'] = {
        'vals': [5, 6],
        'counts': [1, 1],
        'missing': 1,
        'type': self.gfsg.fs_proto.INT
    }
    datasets = [{'entries': entries, 'size': 3, 'name': 'testDataset'}]
    p = self.gfsg.GetDatasetsProto(datasets, features=['testFeature'])
    self.assertEqual(1, len(p.datasets))
    test_data = p.datasets[0]
    self.assertEqual('testDataset', test_data.name)
    self.assertEqual(3, test_data.num_examples)
    self.assertEqual(1, len(test_data.features))
    numfeat = test_data.features[0]
    self.assertEqual('testFeature', numfeat.name)
    self.assertEqual(1, numfeat.num_stats.min)


if __name__ == '__main__':
  googletest.main()



================================================
FILE: tests/_util/http_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from __future__ import absolute_import
from __future__ import unicode_literals

import mock
import unittest

# The httplib2 import is implicitly used when mocking its functionality.
# pylint: disable=unused-import
from google.datalab.utils._http import Http


class TestCases(unittest.TestCase):

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_get_request_is_invoked(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    Http.request('http://www.example.org')
    self.assertEqual(mock_request.call_count, 1)
    self.assertEqual(mock_request.call_args[1]['method'], 'GET')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_post_request_is_invoked(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    Http.request('http://www.example.org', data={})
    self.assertEqual(mock_request.call_args[1]['method'], 'POST')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_explicit_post_request_is_invoked(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    Http.request('http://www.example.org', method='POST')
    self.assertEqual(mock_request.call_args[1]['method'], 'POST')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_query_string_format(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    Http.request('http://www.example.org', args={'a': 1, 'b': 'a b c'})
    parts = mock_request.call_args[0][0].replace('?', '&').split('&')
    self.assertEqual(parts[0], 'http://www.example.org')
    self.assertTrue('a=1' in parts[1:])
    self.assertTrue('b=a+b+c' in parts[1:])

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_formats_json_request(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    data = {'abc': 123}
    Http.request('http://www.example.org', data=data)
    self.assertEqual(mock_request.call_args[1]['body'], '{"abc": 123}')
    self.assertEqual(mock_request.call_args[1]['headers']['Content-Type'], 'application/json')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_supports_custom_content(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{}')
    headers = {'Content-Type': 'text/plain'}
    data = 'custom text'
    Http.request('http://www.example.org', data=data, headers=headers)
    self.assertEqual(mock_request.call_args[1]['body'], 'custom text')
    self.assertEqual(mock_request.call_args[1]['headers']['Content-Type'], 'text/plain')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_parses_json_response(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, '{"abc":123}')
    data = Http.request('http://www.example.org')
    self.assertEqual(data['abc'], 123)

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_raises_http_error_json(self, mock_request, mock_response):
    TestCases._setup_mocks(
        mock_request, mock_response, b'{"error": {"errors": [{"message": "Not Found"}]}}', 404)
    with self.assertRaises(Exception) as error:
      Http.request('http://www.example.org')
    e = error.exception
    self.assertEqual(e.status, 404)
    self.assertEqual(e.message, 'HTTP request failed: Not Found')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_raises_http_error_str(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, 'Not Found', 404)
    with self.assertRaises(Exception) as error:
      Http.request('http://www.example.org')
    e = error.exception
    self.assertEqual(e.status, 404)
    self.assertEqual(e.content, 'Not Found')

  @mock.patch('httplib2.Response')
  @mock.patch('google.datalab.utils._http.Http.http.request')
  def test_raises_http_error_bytes(self, mock_request, mock_response):
    TestCases._setup_mocks(mock_request, mock_response, b'Not Found', 404)
    with self.assertRaises(Exception) as error:
      Http.request('http://www.example.org')
    e = error.exception
    self.assertEqual(e.status, 404)
    self.assertEqual(e.content, b'Not Found')

  @staticmethod
  def _setup_mocks(mock_request, mock_response, content, status=200):
    response = mock_response()
    response.status = status
    mock_request.return_value = (response, content)



================================================
FILE: tests/_util/lru_cache_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from __future__ import absolute_import
from __future__ import unicode_literals

import unittest

from google.datalab.utils._lru_cache import LRUCache


class TestCases(unittest.TestCase):

  def test_cache_no_entry(self):
    cache = LRUCache(3)
    with self.assertRaises(KeyError):
      cache['a']

  def test_cache_lookup(self):
    cache = LRUCache(4)
    for x in ['a', 'b', 'c', 'd']:
      cache[x] = x
    for x in ['a', 'b', 'c', 'd']:
      self.assertEqual(x, cache[x])

  def test_cache_overflow(self):
    cache = LRUCache(3)
    for x in ['a', 'b', 'c', 'd']:
      cache[x] = x
    for x in ['b', 'c', 'd']:
      self.assertEqual(x, cache[x])
    with self.assertRaises(KeyError):
      cache['a']
    cache['b']
    cache['d']
    # 'c' should be LRU now
    cache['e'] = 'e'
    with self.assertRaises(KeyError):
      cache['c']
    for x in ['b', 'd', 'e']:
      self.assertEqual(x, cache[x])



================================================
FILE: tests/_util/util_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
from __future__ import absolute_import from __future__ import unicode_literals import imp import unittest import pytz import mock import os import google.datalab.utils._utils as _utils import google.datalab.utils._iterator as _iterator from datetime import datetime import google.auth import google.auth.exceptions class TestCases(unittest.TestCase): @staticmethod def _get_data(): m = imp.new_module('baz') exec('x = 99', m.__dict__) data = { 'foo': { 'bar': { 'xyz': 0 }, 'm': m } } return data def test_no_entry(self): data = TestCases._get_data() self.assertIsNone(_utils.get_item(data, '')) self.assertIsNone(_utils.get_item(data, 'x')) self.assertIsNone(_utils.get_item(data, 'bar.x')) self.assertIsNone(_utils.get_item(data, 'foo.bar.x')) self.assertIsNone(_utils.get_item(globals(), 'datetime.bar.x')) def test_entry(self): data = TestCases._get_data() self.assertEquals(data['foo']['bar']['xyz'], _utils.get_item(data, 'foo.bar.xyz')) self.assertEquals(data['foo']['bar'], _utils.get_item(data, 'foo.bar')) self.assertEquals(data['foo'], _utils.get_item(data, 'foo')) self.assertEquals(data['foo']['m'], _utils.get_item(data, 'foo.m')) self.assertEquals(99, _utils.get_item(data, 'foo.m.x')) def test_compare_datetimes(self): t1, t2 = datetime(2017, 2, 2, 12, 0, 0), datetime(2017, 2, 2, 12, 0, 0) self.assertEquals(_utils.compare_datetimes(t1, t2), 0) t2 = t2.replace(hour=11) self.assertEquals(_utils.compare_datetimes(t1, t2), 1) def test_compare_datetimes_tz(self): t1 = datetime(2017, 2, 2, 12, 0, 0) t2 = datetime(2017, 2, 2, 12, 0, 0, tzinfo=pytz.timezone('US/Eastern')) self.assertEquals(_utils.compare_datetimes(t1, t2), -1) t1 = t1.replace(tzinfo=pytz.timezone('US/Pacific')) self.assertEquals(_utils.compare_datetimes(t1, t2), 1) @mock.patch('os.path.expanduser') def test_get_config_dir(self, mock_expand_user): mock_expand_user.return_value = 'user/relative/path' with mock.patch.dict(os.environ, {'CLOUDSDK_CONFIG': 'test/path'}): self.assertEquals(_utils.get_config_dir(), 'test/path') self.assertEquals(_utils.get_config_dir(), 'user/relative/path/.config/gcloud') @mock.patch('os.name', 'nt') @mock.patch('os.path.join') def test_get_config_dir_win(self, mock_path_join): mock_path_join.side_effect = lambda x, y: x + y self.assertEquals(_utils.get_config_dir(), 'C:\\gcloud') mock_path_join.side_effect = lambda x, y: x + '\\' + y with mock.patch.dict(os.environ, {'APPDATA': 'test\\path'}): self.assertEquals(_utils.get_config_dir(), 'test\\path\\gcloud') @mock.patch('google.datalab.utils._utils._in_datalab_docker') @mock.patch('google.auth.credentials.with_scopes_if_required') @mock.patch('google.auth.default') @mock.patch('os.path.exists') def test_get_credentials_from_file(self, mock_path_exists, mock_google_auth_default, mock_with_scopes_if_required, mock_in_datalab): # If application default credentials exist, use them creds = mock.Mock(spec=google.auth.credentials.Credentials) mock_google_auth_default.return_value = [creds, ''] _utils.get_credentials() mock_google_auth_default.assert_called_once() mock_with_scopes_if_required.assert_called_once() # If application default credentials are not defined, should load from file test_creds = ''' { "data": [{ "key": { "type": "google-cloud-sdk" }, "credential": { "access_token": "test-access-token", "client_id": "test-id", "client_secret": "test-secret", "refresh_token": "test-token", "token_expiry": "test-expiry", "token_uri": "test-url", "user_agent": "test-agent", "invalid": "false" } }] } ''' with mock.patch('google.datalab.utils._utils.open', 
mock.mock_open(read_data=test_creds)): mock_google_auth_default.side_effect = Exception cred = _utils.get_credentials() self.assertEquals(cred.token, 'test-access-token') mock_path_exists.return_value = False with self.assertRaises(Exception): cred = _utils.get_credentials() # If default creds are not defined, and no file exists with credentials, throw # something more meaningful. mock_google_auth_default.side_effect = google.auth.exceptions.DefaultCredentialsError with self.assertRaisesRegexp(Exception, 'No application credentials found. Perhaps you should sign in'): cred = _utils.get_credentials() @mock.patch('subprocess.call') @mock.patch('os.path.exists') def test_save_project_id(self, mock_path_exists, mock_subprocess_call): _utils.save_project_id('test-project') mock_subprocess_call.assert_called_with([ 'gcloud', 'config', 'set', 'project', 'test-project' ]) mock_subprocess_call.side_effect = Exception test_config = ''' { "project_id": "" } ''' opener = mock.mock_open(read_data=test_config) with mock.patch('google.datalab.utils._utils.open', opener): _utils.save_project_id('test-project') opener.assert_has_calls([mock.call().write('{"project_id": "test-project"}')]) @mock.patch('subprocess.Popen') @mock.patch('os.path.exists') def test_get_default_project_id(self, mock_path_exists, mock_subprocess_call): mock_subprocess_call.return_value.communicate.return_value = ('test-project', '') mock_subprocess_call.return_value.poll.return_value = 0 self.assertEquals(_utils.get_default_project_id(), 'test-project') mock_subprocess_call.assert_called_with( ['gcloud', 'config', 'list', '--format', 'value(core.project)'], stdout=-1) mock_subprocess_call.side_effect = Exception test_config = ''' { "project_id": "test-project2" } ''' opener = mock.mock_open(read_data=test_config) with mock.patch('google.datalab.utils._utils.open', opener): self.assertEquals(_utils.get_default_project_id(), 'test-project2') mock_path_exists.return_value = False self.assertIsNone(_utils.get_default_project_id()) with mock.patch.dict(os.environ, {'PROJECT_ID': 'test-project3'}): self.assertEquals(_utils.get_default_project_id(), 'test-project3') def test_iterator(self): max_count = 100 page_size = 10 def limited_retriever(next_item, running_count): next_item = next_item or 1 result_count = min(page_size, max_count - running_count) if result_count <= 0: return [], None return range(next_item, next_item + result_count), next_item + result_count read_count = 0 for item in _iterator.Iterator(limited_retriever): read_count += 1 self.assertLessEqual(read_count, max_count) ================================================ FILE: tests/bigquery/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: tests/bigquery/api_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest import mock import google.auth import google.datalab import google.datalab.bigquery import google.datalab.utils from google.datalab.bigquery._api import Api class TestCases(unittest.TestCase): def validate(self, mock_http_request, expected_url, expected_args=None, expected_data=None, expected_headers=None, expected_method=None): url = mock_http_request.call_args[0][0] kwargs = mock_http_request.call_args[1] self.assertEquals(expected_url, url) if expected_args is not None: self.assertEquals(expected_args, kwargs['args']) else: self.assertNotIn('args', kwargs) if expected_data is not None: self.assertEquals(expected_data, kwargs['data']) else: self.assertNotIn('data', kwargs) if expected_headers is not None: self.assertEquals(expected_headers, kwargs['headers']) else: self.assertNotIn('headers', kwargs) if expected_method is not None: self.assertEquals(expected_method, kwargs['method']) else: self.assertNotIn('method', kwargs) @mock.patch('google.datalab.utils.Http.request') def test_jobs_insert_load(self, mock_http_request): api = TestCases._create_api() api.jobs_insert_load('SOURCE', google.datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.maxDiff = None expected_data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': ['SOURCE'], 'destinationTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'createDisposition': 'CREATE_NEVER', 'writeDisposition': 'WRITE_EMPTY', 'sourceFormat': 'CSV', 'fieldDelimiter': ',', 'allowJaggedRows': False, 'allowQuotedNewlines': False, 'encoding': 'UTF-8', 'ignoreUnknownValues': False, 'maxBadRecords': 0, 'quote': '"', 'skipLeadingRows': 0 } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) api.jobs_insert_load('SOURCE2', google.datalab.bigquery._utils.TableName('p2', 'd2', 't2', ''), append=True, create=True, allow_jagged_rows=True, allow_quoted_newlines=True, ignore_unknown_values=True, source_format='JSON', max_bad_records=1) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'load': { 'sourceUris': ['SOURCE2'], 'destinationTable': { 'projectId': 'p2', 'datasetId': 'd2', 'tableId': 't2' }, 'createDisposition': 'CREATE_IF_NEEDED', 'writeDisposition': 'WRITE_APPEND', 'sourceFormat': 'JSON', 'ignoreUnknownValues': True, 'maxBadRecords': 1 } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p2/jobs/', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_jobs_insert_query(self, mock_http_request): context = TestCases._create_context() context.config['bigquery_billing_tier'] = None api = TestCases._create_api(context) api.jobs_insert_query('SQL') expected_data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': 'SQL', 'useQueryCache': True, 'allowLargeResults': False, 'useLegacySql': False, }, 'dryRun': False, 'priority': 'BATCH', }, } 
self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/', expected_data=expected_data) context.config['bigquery_billing_tier'] = 1 api.jobs_insert_query('SQL2', table_name=google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), append=True, dry_run=True, use_cache=False, batch=False, allow_large_results=True) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'query': { 'query': 'SQL2', 'useQueryCache': False, 'allowLargeResults': True, 'useLegacySql': False, 'maximumBillingTier': 1, 'destinationTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'writeDisposition': 'WRITE_APPEND', }, 'dryRun': True, 'priority': 'INTERACTIVE', }, } self.maxDiff = None self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_jobs_query_results(self, mock_http_request): api = TestCases._create_api() api.jobs_query_results('JOB', 'PROJECT', 10, 20, 30) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/queries/JOB', expected_args={'maxResults': 10, 'timeoutMs': 20, 'startIndex': 30}) @mock.patch('google.datalab.utils.Http.request') def test_jobs_get(self, mock_http_request): api = TestCases._create_api() api.jobs_get('JOB', 'PROJECT') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/jobs/JOB') @mock.patch('google.datalab.utils.Http.request') def test_datasets_insert(self, mock_http_request): api = TestCases._create_api() api.datasets_insert(google.datalab.bigquery._utils.DatasetName('p', 'd')) expected_data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': 'p', 'datasetId': 'd', } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/', expected_data=expected_data) api.datasets_insert(google.datalab.bigquery._utils.DatasetName('p', 'd'), 'FRIENDLY', 'DESCRIPTION') expected_data = { 'kind': 'bigquery#dataset', 'datasetReference': { 'projectId': 'p', 'datasetId': 'd' }, 'friendlyName': 'FRIENDLY', 'description': 'DESCRIPTION' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_datasets_delete(self, mock_http_request): api = TestCases._create_api() api.datasets_delete(google.datalab.bigquery._utils.DatasetName('p', 'd'), False) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_args={}, expected_method='DELETE') api.datasets_delete(google.datalab.bigquery._utils.DatasetName('p', 'd'), True) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_args={'deleteContents': True}, expected_method='DELETE') @mock.patch('google.datalab.utils.Http.request') def test_datasets_update(self, mock_http_request): api = TestCases._create_api() api.datasets_update(google.datalab.bigquery._utils.DatasetName('p', 'd'), 'INFO') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d', expected_method='PUT', expected_data='INFO') @mock.patch('google.datalab.utils.Http.request') def test_datasets_get(self, mock_http_request): api = TestCases._create_api() api.datasets_get(google.datalab.bigquery._utils.DatasetName('p', 'd')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d') 
@mock.patch('google.datalab.utils.Http.request') def test_datasets_list(self, mock_http_request): api = TestCases._create_api() api.datasets_list() self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/test/datasets/', expected_args={}) api.datasets_list('PROJECT', 10, 'TOKEN') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/PROJECT/datasets/', expected_args={'maxResults': 10, 'pageToken': 'TOKEN'}) @mock.patch('google.datalab.utils.Http.request') def test_tables_get(self, mock_http_request): api = TestCases._create_api() api.tables_get(google.datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t') @mock.patch('google.datalab.utils.Http.request') def test_tables_list(self, mock_http_request): api = TestCases._create_api() api.tables_list(google.datalab.bigquery._utils.DatasetName('p', 'd')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_args={}) api.tables_list(google.datalab.bigquery._utils.DatasetName('p', 'd'), 10, 'TOKEN') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_args={'maxResults': 10, 'pageToken': 'TOKEN'}) @mock.patch('google.datalab.utils.Http.request') def test_tables_insert(self, mock_http_request): api = TestCases._create_api() api.tables_insert(google.datalab.bigquery._utils.TableName('p', 'd', 't', '')) expected_data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_data=expected_data) api.tables_insert(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'SCHEMA', 'QUERY', 'FRIENDLY', 'DESCRIPTION') expected_data = { 'kind': 'bigquery#table', 'tableReference': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'schema': { 'fields': 'SCHEMA' }, 'view': {'query': 'QUERY'}, 'friendlyName': 'FRIENDLY', 'description': 'DESCRIPTION' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_tabledata_insertAll(self, mock_http_request): api = TestCases._create_api() api.tabledata_insert_all(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'ROWS') expected_data = { 'kind': 'bigquery#tableDataInsertAllRequest', 'rows': 'ROWS' } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/insertAll', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_tabledata_list(self, mock_http_request): api = TestCases._create_api() api.tabledata_list(google.datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/data', expected_args={}) api.tabledata_list(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), 10, 20, 'TOKEN') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t/data', expected_args={ 'startIndex': 10, 'maxResults': 20, 'pageToken': 'TOKEN' }) @mock.patch('google.datalab.utils.Http.request') def test_table_delete(self, mock_http_request): api = TestCases._create_api() 
api.table_delete(google.datalab.bigquery._utils.TableName('p', 'd', 't', '')) self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t', expected_method='DELETE') @mock.patch('google.datalab.utils.Http.request') def test_table_extract(self, mock_http_request): api = TestCases._create_api() api.table_extract(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'DEST') expected_data = { 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'compression': 'GZIP', 'fieldDelimiter': ',', 'printHeader': True, 'destinationUris': ['DEST'], 'destinationFormat': 'CSV', } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) api.table_extract(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), ['DEST'], format='JSON', compress=False, field_delimiter=':', print_header=False) expected_data = { 'kind': 'bigquery#job', 'configuration': { 'extract': { 'sourceTable': { 'projectId': 'p', 'datasetId': 'd', 'tableId': 't' }, 'compression': 'NONE', 'fieldDelimiter': ':', 'printHeader': False, 'destinationUris': ['DEST'], 'destinationFormat': 'JSON', } } } self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/jobs/', expected_data=expected_data) @mock.patch('google.datalab.utils.Http.request') def test_table_update(self, mock_http_request): api = TestCases._create_api() api.table_update(google.datalab.bigquery._utils.TableName('p', 'd', 't', ''), 'INFO') self.validate(mock_http_request, 'https://www.googleapis.com/bigquery/v2/projects/p/datasets/d/tables/t', expected_method='PUT', expected_data='INFO') @staticmethod def _create_api(context=None): if not context: context = TestCases._create_context() return Api(context) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) ================================================ FILE: tests/bigquery/dataset_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import from __future__ import unicode_literals from builtins import str import mock import unittest import google.auth import google.datalab import google.datalab.bigquery import google.datalab.utils class TestCases(unittest.TestCase): def _check_name_parts(self, dataset): parsed_name = dataset._name_parts self.assertEqual('test', parsed_name[0]) self.assertEqual('requestlogs', parsed_name[1]) self.assertEqual('test.requestlogs', dataset._full_name) self.assertEqual('test.requestlogs', str(dataset)) def test_parse_full_name(self): dataset = TestCases._create_dataset('test.requestlogs') self._check_name_parts(dataset) def test_parse_local_name(self): dataset = TestCases._create_dataset('requestlogs') self._check_name_parts(dataset) def test_parse_dict_full_name(self): dataset = TestCases._create_dataset({'project_id': 'test', 'dataset_id': 'requestlogs'}) self._check_name_parts(dataset) def test_parse_dict_local_name(self): dataset = TestCases._create_dataset({'dataset_id': 'requestlogs'}) self._check_name_parts(dataset) def test_parse_named_tuple_name(self): dataset = TestCases._create_dataset(google.datalab.bigquery._utils.DatasetName('test', 'requestlogs')) self._check_name_parts(dataset) def test_parse_tuple_full_name(self): dataset = TestCases._create_dataset(('test', 'requestlogs')) self._check_name_parts(dataset) def test_parse_tuple_local(self): dataset = TestCases._create_dataset(('requestlogs')) self._check_name_parts(dataset) def test_parse_array_full_name(self): dataset = TestCases._create_dataset(['test', 'requestlogs']) self._check_name_parts(dataset) def test_parse_array_local(self): dataset = TestCases._create_dataset(['requestlogs']) self._check_name_parts(dataset) def test_parse_invalid_name(self): with self.assertRaises(Exception): TestCases._create_dataset('today@') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_dataset_exists(self, mock_api_datasets_get): mock_api_datasets_get.return_value = '' dataset = TestCases._create_dataset('test.requestlogs') self.assertTrue(dataset.exists()) mock_api_datasets_get.side_effect = google.datalab.utils.RequestException(404, None) dataset._info = None self.assertFalse(dataset.exists()) @mock.patch('google.datalab.bigquery._api.Api.datasets_insert') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_datasets_create_fails(self, mock_api_datasets_get, mock_api_datasets_insert): mock_api_datasets_get.side_effect = google.datalab.utils.RequestException(None, 404) mock_api_datasets_insert.return_value = {} ds = TestCases._create_dataset('requestlogs') with self.assertRaises(Exception): ds.create() @mock.patch('google.datalab.bigquery._api.Api.datasets_insert') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_datasets_create_succeeds(self, mock_api_datasets_get, mock_api_datasets_insert): mock_api_datasets_get.side_effect = google.datalab.utils.RequestException(404, None) mock_api_datasets_insert.return_value = {'selfLink': None} ds = TestCases._create_dataset('requestlogs') self.assertEqual(ds, ds.create()) @mock.patch('google.datalab.bigquery._api.Api.datasets_insert') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_datasets_create_redundant(self, mock_api_datasets_get, mock_api_datasets_insert): ds = TestCases._create_dataset('requestlogs', {}) mock_api_datasets_get.return_value = None mock_api_datasets_insert.return_value = {} self.assertEqual(ds, ds.create()) @mock.patch('google.datalab.bigquery._api.Api.datasets_get') 
@mock.patch('google.datalab.bigquery._api.Api.datasets_delete') def test_datasets_delete_succeeds(self, mock_api_datasets_delete, mock_api_datasets_get): mock_api_datasets_get.return_value = '' mock_api_datasets_delete.return_value = None ds = TestCases._create_dataset('requestlogs') self.assertIsNone(ds.delete()) @mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.datasets_delete') def test_datasets_delete_fails(self, mock_api_datasets_delete, mock_api_datasets_get): mock_api_datasets_delete.return_value = None mock_api_datasets_get.side_effect = google.datalab.utils.RequestException(404, None) ds = TestCases._create_dataset('requestlogs') with self.assertRaises(Exception): ds.delete() @mock.patch('google.datalab.bigquery._api.Api.tables_list') def test_tables_list(self, mock_api_tables_list): mock_api_tables_list.return_value = { 'tables': [ { 'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't1'} }, { 'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't2'} }, ] } ds = TestCases._create_dataset('requestlogs') tables = [table for table in ds] self.assertEqual(2, len(tables)) self.assertEqual('`p.d.t1`', tables[0]._repr_sql_()) self.assertEqual('`p.d.t2`', tables[1]._repr_sql_()) @mock.patch('google.datalab.bigquery.Dataset._get_info') @mock.patch('google.datalab.bigquery._api.Api.datasets_list') def test_datasets_list(self, mock_api_datasets_list, mock_dataset_get_info): mock_api_datasets_list.return_value = { 'datasets': [ {'datasetReference': {'projectId': 'p', 'datasetId': 'd1'}}, {'datasetReference': {'projectId': 'p', 'datasetId': 'd2'}}, ] } mock_dataset_get_info.return_value = {} datasets = [dataset for dataset in google.datalab.bigquery.Datasets(TestCases._create_context())] self.assertEqual(2, len(datasets)) self.assertEqual('p.d1', str(datasets[0])) self.assertEqual('p.d2', str(datasets[1])) @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.datasets_update') def test_datasets_update(self, mock_api_datasets_update, mock_api_datasets_get, mock_api_tables_list): mock_api_tables_list.return_value = { 'tables': [ {'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't1'}}, {'type': 'TABLE', 'tableReference': {'projectId': 'p', 'datasetId': 'd', 'tableId': 't2'}}, ] } info = {'friendlyName': 'casper', 'description': 'ghostly logs'} mock_api_datasets_get.return_value = info ds = TestCases._create_dataset('requestlogs') new_friendly_name = 'aziraphale' new_description = 'demon duties' ds.update(new_friendly_name, new_description) name, info = mock_api_datasets_update.call_args[0] self.assertEqual(ds.name, name) self.assertEqual(new_friendly_name, ds.friendly_name) self.assertEqual(new_description, ds.description) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _create_dataset(name, metadata=None): # Patch get_info so we don't have to mock it everywhere else. 
orig = google.datalab.bigquery.Dataset._get_info google.datalab.bigquery.Dataset._get_info = mock.Mock(return_value=metadata) ds = google.datalab.bigquery.Dataset(name, context=TestCases._create_context()) google.datalab.bigquery.Dataset._get_info = orig return ds ================================================ FILE: tests/bigquery/external_data_source_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import collections import mock import unittest import google.auth import google.datalab import google.datalab.bigquery import google.datalab.utils class TestCases(unittest.TestCase): # The main thing we need to test is a query that references an external table and how # that translates into a REST call. @staticmethod def _request_result(): return { 'jobReference': { 'jobId': 'job1234' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'test', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': True } @staticmethod def _get_data(): data = [] day = 1 for weight in [220, 221, 220, 219, 218]: d = collections.OrderedDict() data.append(d) d['day'] = day day += 1 d['weight'] = weight return data @staticmethod def _get_table_definition(uris, skip_rows=0): if not isinstance(uris, list): uris = [uris] return { 'compression': 'NONE', 'csvOptions': { 'allowJaggedRows': False, 'quote': '"', 'encoding': 'UTF-8', 'skipLeadingRows': skip_rows, 'fieldDelimiter': ',', 'allowQuotedNewlines': False }, 'sourceFormat': 'CSV', 'maxBadRecords': 0, 'ignoreUnknownValues': False, 'sourceUris': uris, 'schema': { 'fields': [ {'type': 'INTEGER', 'name': 'day'}, {'type': 'INTEGER', 'name': 'weight'} ] } } @staticmethod def _get_expected_request_data(sql, table_definitions): return { 'kind': 'bigquery#job', 'configuration': { 'priority': 'INTERACTIVE', 'query': { 'query': sql, 'useLegacySql': True, 'allowLargeResults': False, 'tableDefinitions': table_definitions, 'useQueryCache': True, 'userDefinedFunctionResources': [] }, 'dryRun': False } } @mock.patch('google.datalab.utils.Http.request') def test_external_table_query(self, mock_http_request): mock_http_request.return_value = self._request_result() data = self._get_data() schema = google.datalab.bigquery.Schema.from_data(data) table_uri = 'gs://google.datalab/weight.csv' options = google.datalab.bigquery.CSVOptions(skip_leading_rows=1) sql = 'SELECT * FROM weight' weight = google.datalab.bigquery.ExternalDataSource(table_uri, schema=schema, csv_options=options) q = google.datalab.bigquery.Query(sql, data_sources={'weight': weight}) q.execute_async() table_definition = self._get_table_definition(table_uri, skip_rows=1) expected_data = self._get_expected_request_data(sql, {'weight': table_definition}) request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/' mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data) # Test with 
multiple URLs and no non-default options @mock.patch('google.datalab.utils.Http.request') def test_external_table_query2(self, mock_http_request): mock_http_request.return_value = self._request_result() data = self._get_data() schema = google.datalab.bigquery.Schema.from_data(data) table_uris = ['gs://google.datalab/weight1.csv', 'gs://google.datalab/weight2.csv'] sql = 'SELECT * FROM weight' weight = google.datalab.bigquery.ExternalDataSource(table_uris, schema=schema) q = google.datalab.bigquery.Query(sql, data_sources={'weight': weight}) q.execute_async() table_definition = self._get_table_definition(table_uris) expected_data = self._get_expected_request_data(sql, {'weight': table_definition}) request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/' mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data) # Test with multiple tables and using keyword args @mock.patch('google.datalab.utils.Http.request') def test_external_tables_query(self, mock_http_request): mock_http_request.return_value = self._request_result() data = self._get_data() schema = google.datalab.bigquery.Schema.from_data(data) table_uri1 = 'gs://google.datalab/weight1.csv' table_uri2 = 'gs://google.datalab/weight2.csv' sql = 'SELECT * FROM weight1 JOIN weight2 ON day' options = google.datalab.bigquery.CSVOptions(skip_leading_rows=1) weight1 = google.datalab.bigquery.ExternalDataSource(table_uri1, schema=schema, csv_options=options) weight2 = google.datalab.bigquery.ExternalDataSource(table_uri2, schema=schema) q = google.datalab.bigquery.Query(sql, env={'weight1': weight1, 'weight2': weight2}) q.execute_async() table_definition1 = self._get_table_definition(table_uri1, skip_rows=1) table_definition2 = self._get_table_definition(table_uri2) table_definitions = {'weight1': table_definition1, 'weight2': table_definition2} expected_data = self._get_expected_request_data(sql, table_definitions) request_url = 'https://www.googleapis.com/bigquery/v2/projects/test/jobs/' mock_http_request.assert_called_with(request_url, credentials=mock.ANY, data=expected_data) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) ================================================ FILE: tests/bigquery/jobs_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import
from __future__ import unicode_literals

import mock
import unittest

import google.auth
import google.datalab
import google.datalab.bigquery
from google.datalab.bigquery._job import Job


class TestCases(unittest.TestCase):

  @staticmethod
  def _make_job(id):
    return Job(id, TestCases._create_context())

  @mock.patch('google.datalab.bigquery._api.Api.jobs_get')
  def test_job_complete(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {}
    j = TestCases._make_job('foo')
    self.assertFalse(j.is_complete)
    self.assertFalse(j.failed)
    mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}}
    self.assertTrue(j.is_complete)
    self.assertFalse(j.failed)

  @mock.patch('google.datalab.bigquery._api.Api.jobs_get')
  def test_job_fatal_error(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {
        'status': {
            'state': 'DONE',
            'errorResult': {
                'location': 'A',
                'message': 'B',
                'reason': 'C'
            }
        }
    }
    j = TestCases._make_job('foo')
    self.assertTrue(j.is_complete)
    self.assertTrue(j.failed)
    e = j.fatal_error
    self.assertIsNotNone(e)
    self.assertEqual('A', e.location)
    self.assertEqual('B', e.message)
    self.assertEqual('C', e.reason)

  @mock.patch('google.datalab.bigquery._api.Api.jobs_get')
  def test_job_errors(self, mock_api_jobs_get):
    mock_api_jobs_get.return_value = {
        'status': {
            'state': 'DONE',
            'errors': [
                {
                    'location': 'A',
                    'message': 'B',
                    'reason': 'C'
                },
                {
                    'location': 'D',
                    'message': 'E',
                    'reason': 'F'
                }
            ]
        }
    }
    j = TestCases._make_job('foo')
    self.assertTrue(j.is_complete)
    self.assertFalse(j.failed)
    self.assertEqual(2, len(j.errors))
    self.assertEqual('A', j.errors[0].location)
    self.assertEqual('B', j.errors[0].message)
    self.assertEqual('C', j.errors[0].reason)
    self.assertEqual('D', j.errors[1].location)
    self.assertEqual('E', j.errors[1].message)
    self.assertEqual('F', j.errors[1].reason)

  @staticmethod
  def _create_context():
    project_id = 'test'
    creds = mock.Mock(spec=google.auth.credentials.Credentials)
    return google.datalab.Context(project_id, creds)

  @staticmethod
  def _create_api():
    return google.datalab.bigquery._api.Api(TestCases._create_context())



================================================
FILE: tests/bigquery/operator_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

import google.datalab.contrib.pipeline._pipeline as pipeline
import mock
import unittest

# import Python so we can mock the parts we need to here.
import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() import google.datalab # noqa import google.datalab.bigquery # noqa import google.datalab.bigquery.commands # noqa import google.datalab.utils.commands # noqa from google.datalab.contrib.bigquery.operators._bq_extract_operator import ExtractOperator # noqa from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator # noqa from google.datalab.contrib.bigquery.operators._bq_load_operator import LoadOperator # noqa class TestCases(unittest.TestCase): test_project_id = 'test_project' test_table_name = 'project.test.table' test_schema = [ {"type": "INTEGER", "name": "key"}, {"type": "FLOAT", "name": "var1"}, {"type": "FLOAT", "name": "var2"} ] @staticmethod def _create_context(): project_id = TestCases.test_project_id creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Table.extract') def test_extract_operator(self, mock_table_extract, mock_context_default): mock_context_default.return_value = TestCases._create_context() extract_operator = ExtractOperator(table=TestCases.test_project_id + '.test_table', path='test_path', format=None, task_id='test_extract_operator', csv_options={'delimiter': '$'}) # Happy path mock_table_extract.return_value.result = lambda: 'test-results' mock_table_extract.return_value.failed = False mock_table_extract.return_value.errors = None self.assertDictEqual(extract_operator.execute(context=None), {'result': 'test-results'}) mock_table_extract.assert_called_with('test_path', format='NEWLINE_DELIMITED_JSON', csv_delimiter='$') # Extract failed mock_table_extract.return_value.result = lambda: 'test-results' mock_table_extract.return_value.failed = True mock_table_extract.return_value.errors = None with self.assertRaisesRegexp(Exception, "Extract failed:"): extract_operator.execute(context=None) # Extract completed with errors mock_table_extract.return_value.result = lambda: 'test-results' mock_table_extract.return_value.failed = False mock_table_extract.return_value.errors = 'foo_error' with self.assertRaisesRegexp(Exception, 'Extract completed with errors: foo_error'): extract_operator.execute(context=None) @mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_execute_operator_definition(self, mock_get_notebook_item, mock_query_execute): mock_get_notebook_item.return_value = google.datalab.bigquery.Query('test_sql') task_id = 'foo' task_details = {} task_details['type'] = 'pydatalab.bq.execute' task_details['sql'] = 'test_sql' task_details['mode'] = 'create' actual = pipeline.PipelineGenerator._get_operator_definition(task_id, task_details, None) expected = """foo = ExecuteOperator(task_id='foo_id', mode=\"\"\"create\"\"\", sql=\"\"\"test_sql\"\"\", dag=dag) """ # noqa self.assertEqual(actual, expected) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.bigquery.QueryOutput.table') @mock.patch('google.datalab.bigquery._query_job.QueryJob') def test_execute_operator(self, mock_query_job, mock_query_output_table, mock_query_execute, mock_context_default): mock_context_default.return_value = 
self._create_context() execute_operator = ExecuteOperator(task_id='test_execute_operator', sql='test_sql') mock_query_execute.return_value = mock_query_job mock_query_job.result.return_value = google.datalab.bigquery.Table(TestCases.test_table_name) self.assertDictEqual(execute_operator.execute(context=None), {'table': TestCases.test_table_name}) mock_query_output_table.assert_called_with(name=None, use_cache=False, allow_large_results=False) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.ExternalDataSource') @mock.patch('google.datalab.bigquery.Query') @mock.patch('google.datalab.bigquery.QueryOutput.table') @mock.patch('google.datalab.bigquery._query_job.QueryJob') def test_execute_operator_with_data_source(self, mock_query_job, mock_query_output_table, mock_query_class, mock_external_data_source, mock_context_default): mock_context_default.return_value = self._create_context() csv_options = {'delimiter': 'f', 'skip': 9, 'strict': True, 'quote': 'l'} execute_operator = ExecuteOperator(task_id='test_execute_operator', sql='test_sql', data_source='foo_data_source', path='foo_path', max_bad_records=20, schema=TestCases.test_schema, csv_options=csv_options, format='csv') mock_query_instance = mock_query_class.return_value mock_query_instance.execute.return_value = mock_query_job mock_query_job.result.return_value = google.datalab.bigquery.Table(TestCases.test_table_name) self.assertDictEqual(execute_operator.execute(context=None), {'table': TestCases.test_table_name}) mock_query_output_table.assert_called_with(name=None, use_cache=False, allow_large_results=False) mock_query_class.assert_called_with( sql='test_sql', data_sources={'foo_data_source': mock_external_data_source.return_value}) mock_external_data_source.assert_called_with( source='foo_path', max_bad_records=20, csv_options=mock.ANY, source_format='csv', schema=google.datalab.bigquery.Schema(TestCases.test_schema)) execute_operator = ExecuteOperator(task_id='test_execute_operator', sql='test_sql', data_source='foo_data_source', path='foo_path', schema=TestCases.test_schema) mock_query_instance = mock_query_class.return_value mock_query_instance.execute.return_value = mock_query_job execute_operator.execute(None) mock_query_output_table.assert_called_with(name=None, use_cache=False, allow_large_results=False) mock_query_class.assert_called_with(sql='test_sql', data_sources={'foo_data_source': mock_external_data_source.return_value}) mock_external_data_source.assert_called_with(source='foo_path', schema=google.datalab.bigquery.Schema( TestCases.test_schema)) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery.Table.create') @mock.patch('google.datalab.bigquery.Table.exists') @mock.patch('google.datalab.bigquery.Table.load') def test_load_operator(self, mock_table_load, mock_table_exists, mock_table_create, mock_api_tables_insert, mock_context_default): mock_context_default.return_value = self._create_context() # Table exists mock_table_exists.return_value = True load_operator = LoadOperator(table=TestCases.test_table_name, path='test/path', mode='append', format=None, csv_options=None, schema=None, task_id='test_operator_id') mock_job = mock.Mock() mock_job.result.return_value = 'test-result' mock_job.failed = False mock_job.errors = False mock_table_load.return_value = mock_job load_operator.execute(context=None) mock_table_load.assert_called_with('test/path', mode='append', 
source_format='NEWLINE_DELIMITED_JSON', csv_options=mock.ANY, ignore_unknown_values=True) # Table does not exist mock_table_exists.return_value = False csv_options = {'delimiter': 'f', 'skip': 9, 'strict': True, 'quote': '"'} schema = [ {"type": "INTEGER", "name": "key"}, {"type": "FLOAT", "name": "var1"}, {"type": "FLOAT", "name": "var2"} ] load_operator = LoadOperator(table=TestCases.test_table_name, path='test/path', mode='append', format=None, csv_options=csv_options, schema=schema, task_id='test_operator_id') mock_job = mock.Mock() mock_job.result.return_value = 'test-result' mock_job.failed = False mock_job.errors = False mock_table_load.return_value = mock_job load_operator.execute(context=None) mock_table_load.assert_called_with('test/path', mode='append', source_format='NEWLINE_DELIMITED_JSON', csv_options=mock.ANY, ignore_unknown_values=False) mock_table_create.assert_called_with(schema=schema) # Table load fails load_operator = LoadOperator(table=TestCases.test_table_name, path='test/path', mode='append', format=None, csv_options=None, schema=schema, task_id='test_operator_id') mock_job = mock.Mock() mock_job.failed = True mock_job.fatal_error = 'fatal error' mock_table_load.return_value = mock_job with self.assertRaisesRegexp(Exception, 'Load failed: fatal error'): load_operator.execute(context=None) # Table load completes with errors load_operator = LoadOperator(table=TestCases.test_table_name, path='test/path', mode='append', format=None, csv_options=None, schema=TestCases.test_schema, task_id='test_operator_id') mock_job = mock.Mock() mock_job.failed = False mock_job.errors = 'error' mock_table_load.return_value = mock_job with self.assertRaisesRegexp(Exception, 'Load completed with errors: error'): load_operator.execute(context=None) def test_execute_operator_defaults(self): execute_operator = ExecuteOperator(task_id='foo_task_id', sql='foo_sql') self.assertIsNone(execute_operator.parameters) self.assertIsNone(execute_operator.table) self.assertIsNone(execute_operator.mode) self.assertEqual(execute_operator.template_fields, ('table', 'parameters', 'path', 'sql')) def test_extract_operator_defaults(self): extract_operator = ExtractOperator(task_id='foo_task_id', path='foo_path', table='foo_table') self.assertEquals(extract_operator.format, 'csv') self.assertDictEqual(extract_operator.csv_options, {}) self.assertEqual(extract_operator.template_fields, ('table', 'path')) def test_load_operator_defaults(self): load_operator = LoadOperator(task_id='foo_task_id', path='foo_path', table='foo_table') self.assertEquals(load_operator.format, 'csv') self.assertEquals(load_operator.mode, 'append') self.assertIsNone(load_operator.schema) self.assertDictEqual(load_operator.csv_options, {}) self.assertEqual(load_operator.template_fields, ('table', 'path')) ================================================ FILE: tests/bigquery/parser_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
from __future__ import absolute_import from __future__ import unicode_literals import unittest import google.datalab.bigquery as bq class TestCases(unittest.TestCase): def test_repeating_data(self): schema = [{'name': 'counts', 'type': 'INTEGER', 'mode': 'REPEATED'}] data = {'f': [{'v': [{'v': 0}, {'v': 1}, {'v': 2}]}]} parsed = {'counts': [0, 1, 2]} result = bq._parser.Parser.parse_row(schema, data) self.assertEqual(parsed, result) def test_non_nested_data(self): data = {u'f': [{u'v': u'1969'}, {u'v': u'1969'}, {u'v': u'1'}, {u'v': u'20'}, {u'v': None}, {u'v': u'AL'}, {u'v': u'true'}, {u'v': u'1'}, {u'v': u'7.81318256528'}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'AL'}, {u'v': u'1'}, {u'v': u'20'}, {u'v': None}, {u'v': u'88881998'}, {u'v': u'true'}, {u'v': u''}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'1'}, {u'v': u'0'}, {u'v': u'0'}, {u'v': u'2'}, {u'v': u'1'}, {u'v': u'19'}, {u'v': u'2'}]} natality_schema = [{u'description': u'Four-digit year of the birth. Example: 1975.', u'mode': u'REQUIRED', u'name': u'source_year', u'type': u'INTEGER'}, {u'description': u'Four-digit year of the birth. Example: 1975.', u'mode': u'NULLABLE', u'name': u'year', u'type': u'INTEGER'}, {u'description': u'Month index of the date of birth, where 1=January.', u'mode': u'NULLABLE', u'name': u'month', u'type': u'INTEGER'}, {u'description': u'Day of birth, starting from 1.', u'mode': u'NULLABLE', u'name': u'day', u'type': u'INTEGER'}, {u'description': u'Day of the week, where 1 is Sunday and 7 is Saturday.', u'mode': u'NULLABLE', u'name': u'wday', u'type': u'INTEGER'}, {u'description': u'The two character postal code for the state. ' u'Entries after 2004 do not include this value.', u'mode': u'NULLABLE', u'name': u'state', u'type': u'STRING'}, {u'description': u'TRUE if the child is male, FALSE if female.', u'mode': u'REQUIRED', u'name': u'is_male', u'type': u'BOOLEAN'}, {u'description': u'The race of the child. One of the following numbers:\n\n' u'1 - White\n2 - Black\n3 - American Indian\n4 - Chinese\n' u'5 - Japanese\n6 - Hawaiian\n7 - Filipino\n' u'9 - Unknown/Other\n18 - Asian Indian\n28 - Korean\n' u'39 - Samoan\n48 - Vietnamese', u'mode': u'NULLABLE', u'name': u'child_race', u'type': u'INTEGER'}, {u'description': u'Weight of the child, in pounds.', u'mode': u'NULLABLE', u'name': u'weight_pounds', u'type': u'FLOAT'}, {u'description': u'How many children were born as a result of this ' u'pregnancy. twins=2, triplets=3, and so on.', u'mode': u'NULLABLE', u'name': u'plurality', u'type': u'INTEGER'}, {u'description': u'Apgar scores measure the health of a newborn child on a ' u'scale from 0-10. Value after 1 minute. Available from ' u'1978-2002.', u'mode': u'NULLABLE', u'name': u'apgar_1min', u'type': u'INTEGER'}, {u'description': u'Apgar scores measure the health of a newborn child on a ' u'scale from 0-10. Value after 5 minutes. Available from ' u'1978-2002.', u'mode': u'NULLABLE', u'name': u'apgar_5min', u'type': u'INTEGER'}, {u'description': u"The two-letter postal code of the mother's state of " u"residence when the child was born.", u'mode': u'NULLABLE', u'name': u'mother_residence_state', u'type': u'STRING'}, {u'description': u'Race of the mother. 
Same values as child_race.', u'mode': u'NULLABLE', u'name': u'mother_race', u'type': u'INTEGER'}, {u'description': u'Reported age of the mother when giving birth.', u'mode': u'NULLABLE', u'name': u'mother_age', u'type': u'INTEGER'}, {u'description': u'The number of weeks of the pregnancy.', u'mode': u'NULLABLE', u'name': u'gestation_weeks', u'type': u'INTEGER'}, {u'description': u'Date of the last menstrual period in the format ' u'MMDDYYYY. Unknown values are recorded as "99" or "9999".', u'mode': u'NULLABLE', u'name': u'lmp', u'type': u'STRING'}, {u'description': u'True if the mother was married when she gave birth.', u'mode': u'NULLABLE', u'name': u'mother_married', u'type': u'BOOLEAN'}, {u'description': u"The two-letter postal code of the mother's birth state.", u'mode': u'NULLABLE', u'name': u'mother_birth_state', u'type': u'STRING'}, {u'description': u'True if the mother smoked cigarettes. Available starting ' u'2003.', u'mode': u'NULLABLE', u'name': u'cigarette_use', u'type': u'BOOLEAN'}, {u'description': u'Number of cigarettes smoked by the mother per day. ' u'Available starting 2003.', u'mode': u'NULLABLE', u'name': u'cigarettes_per_day', u'type': u'INTEGER'}, {u'description': u'True if the mother used alcohol. Available starting ' u'1989.', u'mode': u'NULLABLE', u'name': u'alcohol_use', u'type': u'BOOLEAN'}, {u'description': u'Number of drinks per week consumed by the mother. ' u'Available starting 1989.', u'mode': u'NULLABLE', u'name': u'drinks_per_week', u'type': u'INTEGER'}, {u'description': u'Number of pounds gained by the mother during pregnancy.', u'mode': u'NULLABLE', u'name': u'weight_gain_pounds', u'type': u'INTEGER'}, {u'description': u'Number of children previously born to the mother who are ' u'now living.', u'mode': u'NULLABLE', u'name': u'born_alive_alive', u'type': u'INTEGER'}, {u'description': u'Number of children previously born to the mother who are ' u'now dead.', u'mode': u'NULLABLE', u'name': u'born_alive_dead', u'type': u'INTEGER'}, {u'description': u'Number of children who were born dead ' u'(i.e. miscarriages)', u'mode': u'NULLABLE', u'name': u'born_dead', u'type': u'INTEGER'}, {u'description': u'Total number of children to whom the woman has ever ' u'given birth (includes the current birth).', u'mode': u'NULLABLE', u'name': u'ever_born', u'type': u'INTEGER'}, {u'description': u'Race of the father. 
Same values as child_race.', u'mode': u'NULLABLE', u'name': u'father_race', u'type': u'INTEGER'}, {u'description': u'Age of the father when the child was born.', u'mode': u'NULLABLE', u'name': u'father_age', u'type': u'INTEGER'}, {u'description': u'1 or 2, where 1 is a row from a full-reporting area, and ' u'2 is a row from a 50% sample area.', u'mode': u'NULLABLE', u'name': u'record_weight', u'type': u'INTEGER'}] parsed = {u'alcohol_use': None, u'apgar_1min': None, u'apgar_5min': None, u'born_alive_alive': 1, u'born_alive_dead': 0, u'born_dead': 0, u'child_race': 1, u'cigarette_use': None, u'cigarettes_per_day': None, u'day': 20, u'drinks_per_week': None, u'ever_born': 2, u'father_age': 19, u'father_race': 1, u'gestation_weeks': None, u'is_male': True, u'lmp': u'88881998', u'month': 1, u'mother_age': 20, u'mother_birth_state': u'', u'mother_married': True, u'mother_race': 1, u'mother_residence_state': u'AL', u'plurality': None, u'record_weight': 2, u'source_year': 1969, u'state': u'AL', u'wday': None, u'weight_gain_pounds': None, u'weight_pounds': 7.81318256528, u'year': 1969} self.assertEqual(parsed, bq._parser.Parser.parse_row(natality_schema, data)) def test_parse_nested_data(self): self.maxDiff = None # Show full diff on failure data = {u'f': [{u'v': {u'f': [{u'v': u'https://github.com/foo'}, {u'v': u'true'}, {u'v': u'2011/04/12 20:04:19 -0700'}, {u'v': u'true'}, {u'v': u'A website.'}, {u'v': u'17'}, {u'v': u'false'}, {u'v': u'true'}, {u'v': u'http://foo.com/'}, {u'v': None}, {u'v': None}, {u'v': u'424'}, {u'v': u'false'}, {u'v': u'foo'}, {u'v': None}, {u'v': u'foo'}, {u'v': u'0'}, {u'v': u'95'}, {u'v': u'2012/03/15 00:00:00 -0700'}, {u'v': u'Ruby'}]}}, {u'v': {u'f': [{u'v': u'http://foo.com/'}, {u'v': u'Flickr'}, {u'v': u'd+github@foo.com'}, {u'v': u'94c21234567890abcdef25e704b88407'}, {u'v': u'San Francisco, California'}, {u'v': u'foo'}, {u'v': u'Foo Bar'}, {u'v': u'User'}]}}, {u'v': u'2012/03/15 00:00:01 -0700'}, {u'v': u'true'}, {u'v': u'foo'}, {u'v': {u'f': [ {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': u'2de950123456789abcdef01234451feaf8ce6ae'}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': None}, {u'v': []}, {u'v': None}, {u'v': u'refs/heads/master'}, {u'v': None}, {u'v': u'1'}, {u'v': [ {u'v': {u'f': [{u'v': u'2de958ab480eabe2501b343425b451feaf8ce6ae'}, {u'v': u'd+github@foo.com'}, {u'v': u'Foo tastes good.'}, {u'v': u'Foo Bar'}]}}]}, {u'v': None}, {u'v': None}]}}, {u'v': u'https://github.com/compare/d3e91cb736...2de958ab48'}, {u'v': u'PushEvent'}]} github_nested_schema = [{u'fields': [{u'name': u'url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'integrate_branch', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'organization', u'type': u'STRING'}, {u'name': u'owner', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'language', 
u'type': u'STRING'}], u'name': u'repository', u'type': u'RECORD'}, {u'fields': [{u'name': u'blog', u'type': u'STRING'}, {u'name': u'company', u'type': u'STRING'}, {u'name': u'email', u'type': u'STRING'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'location', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'type', u'type': u'STRING'}], u'name': u'actor_attributes', u'type': u'RECORD'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'public', u'type': u'BOOLEAN'}, {u'name': u'actor', u'type': u'STRING'}, {u'fields': [ {u'name': u'action', u'type': u'STRING'}, {u'name': u'after', u'type': u'STRING'}, {u'name': u'before', u'type': u'STRING'}, {u'name': u'commit', u'type': u'STRING'}, {u'fields': [ {u'name': u'commit_id', u'type': u'STRING'}, {u'name': u'body', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'original_commit_id', u'type': u'STRING'}, {u'name': u'original_position', u'type': u'INTEGER'}, {u'name': u'path', u'type': u'STRING'}, {u'name': u'position', u'type': u'INTEGER'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'comment', u'type': u'RECORD'}, {u'name': u'comment_id', u'type': u'INTEGER'}, {u'name': u'desc', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'head', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'issue', u'type': u'INTEGER'}, {u'name': u'issue_id', u'type': u'INTEGER'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'master', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'member', u'type': u'RECORD'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'number', u'type': u'INTEGER'}, {u'fields': [{u'name': u'action', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'page_name', u'type': u'STRING'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'summary', u'type': u'STRING'}, {u'name': u'title', u'type': u'STRING'}], u'mode': u'REPEATED', u'name': u'pages', u'type': u'RECORD'}, {u'fields': [ {u'name': u'additions', u'type': u'INTEGER'}, {u'fields': [ {u'fields': [ {u'name': u'clone_url', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'git_url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'language', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', 
u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'owner', u'type': u'RECORD'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'ssh_url', u'type': u'STRING'}, {u'name': u'svn_url', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'repo', u'type': u'RECORD'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'ref', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'label', u'type': u'STRING'}], u'name': u'base', u'type': u'RECORD'}, {u'name': u'body', u'type': u'STRING'}, {u'name': u'changed_files', u'type': u'INTEGER'}, {u'name': u'closed_at', u'type': u'STRING'}, {u'name': u'comments', u'type': u'INTEGER'}, {u'name': u'commits', u'type': u'INTEGER'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'deletions', u'type': u'INTEGER'}, {u'name': u'diff_url', u'type': u'STRING'}, {u'fields': [ {u'fields': [ {u'name': u'clone_url', u'type': u'STRING'}, {u'name': u'created_at', u'type': u'STRING'}, {u'name': u'description', u'type': u'STRING'}, {u'name': u'fork', u'type': u'BOOLEAN'}, {u'name': u'forks', u'type': u'INTEGER'}, {u'name': u'git_url', u'type': u'STRING'}, {u'name': u'has_downloads', u'type': u'BOOLEAN'}, {u'name': u'has_issues', u'type': u'BOOLEAN'}, {u'name': u'has_wiki', u'type': u'BOOLEAN'}, {u'name': u'homepage', u'type': u'STRING'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'language', u'type': u'STRING'}, {u'name': u'master_branch', u'type': u'STRING'}, {u'name': u'name', u'type': u'STRING'}, {u'name': u'open_issues', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'owner', u'type': u'RECORD'}, {u'name': u'private', u'type': u'BOOLEAN'}, {u'name': u'pushed_at', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'name': u'ssh_url', u'type': u'STRING'}, {u'name': u'svn_url', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'watchers', u'type': u'INTEGER'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'repo', u'type': u'RECORD'}, {u'name': u'sha', u'type': u'STRING'}, {u'name': u'ref', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}, {u'name': u'label', u'type': u'STRING'}], u'name': u'head', u'type': u'RECORD'}, {u'name': u'html_url', u'type': u'STRING'}, {u'name': u'issue_url', u'type': u'STRING'}, {u'name': u'id', u'type': u'INTEGER'}, {u'fields': [ {u'fields': [ {u'name': u'href', u'type': u'STRING'}], u'name': u'self', u'type': u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'html', u'type': u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'review_comments', u'type': 
u'RECORD'}, {u'fields': [{u'name': u'href', u'type': u'STRING'}], u'name': u'comments', u'type': u'RECORD'}], u'name': u'_links', u'type': u'RECORD'}, {u'name': u'mergeable', u'type': u'BOOLEAN'}, {u'name': u'merged', u'type': u'BOOLEAN'}, {u'name': u'merged_at', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'merged_by', u'type': u'RECORD'}, {u'name': u'number', u'type': u'INTEGER'}, {u'name': u'patch_url', u'type': u'STRING'}, {u'name': u'review_comments', u'type': u'INTEGER'}, {u'name': u'state', u'type': u'STRING'}, {u'name': u'title', u'type': u'STRING'}, {u'name': u'updated_at', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}, {u'fields': [{u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}, {u'name': u'avatar_url', u'type': u'STRING'}, {u'name': u'login', u'type': u'STRING'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'user', u'type': u'RECORD'}], u'name': u'pull_request', u'type': u'RECORD'}, {u'name': u'ref', u'type': u'STRING'}, {u'name': u'ref_type', u'type': u'STRING'}, {u'name': u'size', u'type': u'INTEGER'}, {u'fields': [ {u'name': u'encoded', u'type': u'STRING'}, {u'name': u'actor_email', u'type': u'STRING'}, {u'name': u'message', u'type': u'STRING'}, {u'name': u'actor_login', u'type': u'STRING'}], u'mode': u'REPEATED', u'name': u'shas', u'type': u'RECORD'}, {u'fields': [{u'name': u'login', u'type': u'STRING'}, {u'name': u'repos', u'type': u'INTEGER'}, {u'name': u'followers', u'type': u'INTEGER'}, {u'name': u'id', u'type': u'INTEGER'}, {u'name': u'gravatar_id', u'type': u'STRING'}], u'name': u'target', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}], u'name': u'payload', u'type': u'RECORD'}, {u'name': u'url', u'type': u'STRING'}, {u'name': u'type', u'type': u'STRING'}] parsed = {u'actor': u'foo', u'actor_attributes': {u'blog': u'http://foo.com/', u'company': u'Flickr', u'email': u'd+github@foo.com', u'gravatar_id': u'94c21234567890abcdef25e704b88407', u'location': u'San Francisco, California', u'login': u'foo', u'name': u'Foo Bar', u'type': u'User'}, u'created_at': u'2012/03/15 00:00:01 -0700', u'payload': {u'action': None, u'after': None, u'before': None, u'comment': {}, u'comment_id': None, u'commit': None, u'desc': None, u'description': None, u'head': u'2de950123456789abcdef01234451feaf8ce6ae', u'id': None, u'issue': None, u'issue_id': None, u'master': None, u'master_branch': None, u'member': {}, u'name': None, u'number': None, u'pages': [], u'pull_request': {}, u'ref': u'refs/heads/master', u'ref_type': None, u'shas': [{u'actor_email': u'd+github@foo.com', u'actor_login': u'Foo Bar', u'encoded': u'2de958ab480eabe2501b343425b451feaf8ce6ae', u'message': u'Foo tastes good.'}], u'size': 1, u'target': {}, u'url': None}, u'public': True, u'repository': {u'created_at': u'2011/04/12 20:04:19 -0700', u'description': u'A website.', u'fork': False, u'forks': 17, u'has_downloads': True, u'has_issues': True, u'has_wiki': True, u'homepage': u'http://foo.com/', u'integrate_branch': None, u'language': u'Ruby', u'master_branch': None, u'name': u'foo', u'open_issues': 0, u'organization': None, u'owner': u'foo', u'private': False, u'pushed_at': u'2012/03/15 00:00:00 -0700', u'size': 424, u'url': u'https://github.com/foo', u'watchers': 95}, u'type': u'PushEvent', u'url': 
u'https://github.com/compare/d3e91cb736...2de958ab48'} self.assertEqual(parsed, bq._parser.Parser.parse_row(github_nested_schema, data)) ================================================ FILE: tests/bigquery/pipeline_tests.py ================================================ #!/usr/bin/python # -*- coding: utf-8 -*- # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. import google import google.auth import google.datalab.contrib.bigquery.commands._bigquery as bq import mock import re import unittest class TestCases(unittest.TestCase): test_input_config = { 'path': 'test_path_%(_ts_month)s', 'table': 'test_table', 'schema': 'test_schema', 'mode': 'append', 'format': 'csv', 'csv': { 'delimiter': ';', 'skip': 9, 'strict': False, 'quote': '"' }, } test_pipeline_config = { 'input': test_input_config, } @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) def assertPipelineConfigEquals(self, actual, expected, expected_pipeline_parameters): expected_copy = expected.copy() actual_execute_config = actual['tasks'].get('bq_pipeline_execute_task', None) expected_execute_config = expected_copy['tasks'].get('bq_pipeline_execute_task', None) if actual_execute_config and expected_execute_config: self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, expected_pipeline_parameters) del actual['tasks']['bq_pipeline_execute_task'] del expected_copy['tasks']['bq_pipeline_execute_task'] actual_params = actual['parameters'] or [] actual_paramaters_dict = {item['name']: (item['value'], item['type']) for item in actual_params} expected_pipeline_parameters = expected_pipeline_parameters or [] expected_paramaters_dict = {item['name']: (item['value'], item['type']) for item in expected_pipeline_parameters} self.assertDictEqual(actual_paramaters_dict, expected_paramaters_dict) del actual['parameters'] self.assertDictEqual(actual, expected_copy) @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_get_pipeline_spec_from_config(self, mock_notebook_item): mock_notebook_item.return_value = google.datalab.bigquery.Query('foo_query_sql_string') # empty pipeline_spec with self.assertRaisesRegexp(Exception, 'Pipeline has no tasks to execute.'): bq._get_pipeline_spec_from_config({}) # empty input , transformation, output as path pipeline_config = { 'transformation': { 'query': 'foo_query' }, 'output': { 'path': 'foo_table' } } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'foo_query_sql_string', 'type': 'pydatalab.bq.execute', }, 'bq_pipeline_extract_task': { 'table': """{{ ti.xcom_pull(task_ids='bq_pipeline_execute_task_id').get('table') }}""", 'path': 'foo_table', 'type': 'pydatalab.bq.extract', 'up_stream': ['bq_pipeline_execute_task'] } } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as path, transformation, output as path pipeline_config = { 'input': { 
'path': 'foo_path', 'data_source': 'foo_data_source', }, 'transformation': { 'query': 'foo_query' }, 'output': { 'path': 'foo_table' } } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'foo_query_sql_string', 'data_source': 'foo_data_source', 'path': 'foo_path', 'type': 'pydatalab.bq.execute', }, 'bq_pipeline_extract_task': { 'table': """{{ ti.xcom_pull(task_ids='bq_pipeline_execute_task_id').get('table') }}""", 'path': 'foo_table', 'type': 'pydatalab.bq.extract', 'up_stream': ['bq_pipeline_execute_task'] } } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as path->table, transformation, output as path pipeline_config = { 'input': { 'path': 'foo_path', 'table': 'foo_table_1' }, 'transformation': { 'query': 'foo_query' }, 'output': { 'path': 'foo_path_2' } } expected = { 'tasks': { 'bq_pipeline_load_task': { 'type': 'pydatalab.bq.load', 'path': 'foo_path', 'table': 'foo_table_1', }, 'bq_pipeline_execute_task': { 'sql': u'WITH input AS (\n SELECT * FROM `foo_table_1`\n)\n\nfoo_query_sql_string', 'type': 'pydatalab.bq.execute', 'up_stream': ['bq_pipeline_load_task'], }, 'bq_pipeline_extract_task': { 'table': """{{ ti.xcom_pull(task_ids='bq_pipeline_execute_task_id').get('table') }}""", 'path': 'foo_path_2', 'type': 'pydatalab.bq.extract', 'up_stream': ['bq_pipeline_execute_task'] } } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as table, transformation, output as path pipeline_config = { 'input': { 'table': 'foo_table_1' }, 'transformation': { 'query': 'foo_query' }, 'output': { 'path': 'foo_path_2' } } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'WITH input AS (\n SELECT * FROM `foo_table_1`\n)\n\nfoo_query_sql_string', 'type': 'pydatalab.bq.execute', }, 'bq_pipeline_extract_task': { 'table': """{{ ti.xcom_pull(task_ids='bq_pipeline_execute_task_id').get('table') }}""", 'path': 'foo_path_2', 'type': 'pydatalab.bq.extract', 'up_stream': ['bq_pipeline_execute_task'] } } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as table, transformation, output as table pipeline_config = { 'input': { 'table': 'foo_table_1' }, 'transformation': { 'query': 'foo_query' }, 'output': { 'table': 'foo_table_1' } } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'WITH input AS (\n SELECT * FROM `foo_table_1`\n)\n\nfoo_query_sql_string', 'type': 'pydatalab.bq.execute', 'table': 'foo_table_1', }, } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as table, no transformation, output as path pipeline_config = { 'input': { 'table': 'foo_table' }, 'output': { 'path': 'foo_path' } } expected = { 'tasks': { 'bq_pipeline_extract_task': { 'type': 'pydatalab.bq.extract', 'path': 'foo_path', 'table': 'foo_table' }, } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # output only; this should be identical to the above pipeline_config = { 'output': { 'table': 'foo_table', 'path': 'foo_path' } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # output can also be called extract, and it should be identical to the above pipeline_config = { 'extract': { 'table': 'foo_table', 'path': 'foo_path' } } actual = 
bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input as path, no transformation, output as table pipeline_config = { 'input': { 'path': 'foo_path' }, 'output': { 'table': 'foo_table' } } expected = { 'tasks': { 'bq_pipeline_load_task': { 'type': 'pydatalab.bq.load', 'path': 'foo_path', 'table': 'foo_table' }, } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input only; this should be identical to the above pipeline_config = { 'input': { 'path': 'foo_path', 'table': 'foo_table' }, } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # input can also be called load, and it should be identical to the above pipeline_config = { 'load': { 'path': 'foo_path', 'table': 'foo_table' }, } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) # only transformation pipeline_config = { 'transformation': { 'query': 'foo_query' }, } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'foo_query_sql_string', 'type': 'pydatalab.bq.execute', }, } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, None) user_parameters = [ {'name': 'foo1', 'value': 'foo1', 'type': 'STRING'}, {'name': 'foo2', 'value': 'foo2', 'type': 'INTEGER'}, ] # only transformation with parameters pipeline_config = { 'transformation': { 'query': 'foo_query' }, 'parameters': user_parameters } expected = { 'tasks': { 'bq_pipeline_execute_task': { 'sql': u'foo_query_sql_string', 'type': 'pydatalab.bq.execute', }, } } actual = bq._get_pipeline_spec_from_config(pipeline_config) self.assertPipelineConfigEquals(actual, expected, user_parameters) def test_get_load_parameters(self): actual_load_config = bq._get_load_parameters(TestCases.test_input_config, None, None) expected_load_config = { 'type': 'pydatalab.bq.load', 'path': 'test_path_%(_ts_month)s', 'table': 'test_table', 'schema': 'test_schema', 'mode': 'append', 'format': 'csv', 'csv_options': {'delimiter': ';', 'quote': '"', 'skip': 9, 'strict': False}, } self.assertDictEqual(actual_load_config, expected_load_config) # Table is present in output config input_config = { 'path': 'test_path_%(_ts_month)s', 'format': 'csv', 'csv': {'delimiter': ';', 'quote': '"', 'skip': 9, 'strict': False}, } output_config = { 'table': 'test_table', 'schema': 'test_schema', 'mode': 'append', } actual_load_config = bq._get_load_parameters(input_config, None, output_config) self.assertDictEqual(actual_load_config, expected_load_config) # Path is absent input_config = { 'table': 'test_table', 'schema': 'test_schema' } actual_load_config = bq._get_load_parameters(input_config, None, None) self.assertIsNone(actual_load_config) # Path and table are absent input_config = { 'schema': 'test_schema' } actual_load_config = bq._get_load_parameters(input_config, None, None) self.assertIsNone(actual_load_config) # Table is absent input_config = { 'path': 'test_path', 'schema': 'test_schema' } actual_load_config = bq._get_load_parameters(input_config, None, None) self.assertIsNone(actual_load_config) def test_get_extract_parameters(self): output_config = { 'path': 'test_path_%(_ts_month)s', 'table': 'test_table_%(_ts_month)s', } actual_extract_config = bq._get_extract_parameters('foo_execute_task', None, None, output_config) expected_extract_config = { 'type': 'pydatalab.bq.extract', 'up_stream': 
['foo_execute_task'], 'path': 'test_path_%(_ts_month)s', 'table': 'test_table_%(_ts_month)s', } self.assertDictEqual(actual_extract_config, expected_extract_config) input_config = { 'table': 'test_table_%(_ts_month)s', } output_config = { 'path': 'test_path_%(_ts_month)s', } actual_extract_config = bq._get_extract_parameters('foo_execute_task', input_config, None, output_config) self.assertDictEqual(actual_extract_config, expected_extract_config) @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_get_execute_parameters(self, mock_notebook_item): mock_notebook_item.return_value = google.datalab.bigquery.Query("""SELECT @column FROM publicdata.samples.wikipedia WHERE endpoint=@endpoint""") transformation_config = { 'query': 'foo_query' } output_config = { 'table': 'foo_table_%(_ts_month)s', 'mode': 'foo_mode' } parameters_config = [ { 'type': 'STRING', 'name': 'endpoint', 'value': 'Interact2' }, { 'type': 'INTEGER', 'name': 'column', 'value': '1234' } ] # Empty input config actual_execute_config = bq._get_execute_parameters('foo_load_task', {}, transformation_config, output_config, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': 'SELECT @column\nFROM publicdata.samples.wikipedia\nWHERE endpoint=@endpoint', 'table': 'foo_table_%(_ts_month)s', 'mode': 'foo_mode', } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) # Empty input and parameters config actual_execute_config = bq._get_execute_parameters('foo_load_task', {}, transformation_config, output_config, None) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': 'SELECT @column\nFROM publicdata.samples.wikipedia\nWHERE endpoint=@endpoint', 'table': 'foo_table_%(_ts_month)s', 'mode': 'foo_mode', } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, None) # Empty input and empty output configs actual_execute_config = bq._get_execute_parameters('foo_load_task', {}, transformation_config, {}, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': 'SELECT @column\nFROM publicdata.samples.wikipedia\nWHERE endpoint=@endpoint', } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) # Empty output config. Expected config is same as output with empty input and empty output. 
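# Illustrative sketch (hypothetical helper, not part of the pydatalab sources): the expected
# 'sql' values below show that when the input config names a table, the transformation query
# is prefixed with a WITH clause exposing that table to the query as `input`. Roughly:
def wrap_with_input(sql, table=None):
  # With an input table configured, alias it as a CTE named `input`; otherwise pass through.
  if table:
    return 'WITH input AS (\n  SELECT * FROM `%s`\n)\n\n%s' % (table, sql)
  return sql

# e.g. wrap_with_input('SELECT @column FROM input', 'test_table') yields the same shape as the
# expected SQL asserted below:
#   WITH input AS (
#     SELECT * FROM `test_table`
#   )
#
#   SELECT @column FROM input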
actual_execute_config = bq._get_execute_parameters('foo_load_task', TestCases.test_input_config, transformation_config, {}, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': """WITH input AS ( SELECT * FROM `test_table` ) SELECT @column FROM publicdata.samples.wikipedia WHERE endpoint=@endpoint""", } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) # With no table, and implicit data_source input_config = TestCases.test_input_config.copy() del input_config['table'] actual_execute_config = bq._get_execute_parameters('foo_load_task', input_config, transformation_config, {}, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': 'SELECT @column\nFROM publicdata.samples.wikipedia\nWHERE endpoint=@endpoint', 'data_source': 'input', 'path': 'test_path_%(_ts_month)s', 'schema': 'test_schema', 'source_format': 'csv', 'csv_options': {'delimiter': ';', 'quote': '"', 'skip': 9, 'strict': False}, } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) # With no table, and explicit data_source input_config['data_source'] = 'foo_data_source' actual_execute_config = bq._get_execute_parameters('foo_load_task', input_config, transformation_config, {}, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'up_stream': ['foo_load_task'], 'sql': 'SELECT @column\nFROM publicdata.samples.wikipedia\nWHERE endpoint=@endpoint', 'data_source': 'foo_data_source', 'path': 'test_path_%(_ts_month)s', 'schema': 'test_schema', 'source_format': 'csv', 'csv_options': {'delimiter': ';', 'quote': '"', 'skip': 9, 'strict': False}, } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) # With table and implicit sub-query mock_notebook_item.return_value = google.datalab.bigquery.Query("""SELECT @column FROM input WHERE endpoint=@endpoint""") input_config = { 'path': 'test_path_%(_ds)s', 'table': 'test_table_%(_ds)s', } actual_execute_config = bq._get_execute_parameters(None, input_config, transformation_config, {}, parameters_config) expected_execute_config = { 'type': 'pydatalab.bq.execute', 'sql': """WITH input AS ( SELECT * FROM `test_table_{{ ds }}` ) SELECT @column FROM input WHERE endpoint=@endpoint""" } self.assertExecuteConfigEquals(actual_execute_config, expected_execute_config, parameters_config) def assertExecuteConfigEquals(self, actual_execute_config, expected_execute_config, parameters_config): actual_parameters = actual_execute_config['parameters'] if actual_execute_config else [] self.compare_parameters(actual_parameters, parameters_config) if actual_execute_config: del actual_execute_config['parameters'] self.assertDictEqual(actual_execute_config, expected_execute_config) def compare_parameters(self, actual_parameters, user_parameters): actual_paramaters_dict = user_parameters_dict = {} if actual_parameters: actual_paramaters_dict = {item['name']: (item['value'], item['type']) for item in actual_parameters} if user_parameters: user_parameters_dict = {item['name']: (item['value'], item['type']) for item in user_parameters} self.assertDictEqual(actual_paramaters_dict, user_parameters_dict) @mock.patch('google.datalab.contrib.pipeline.composer._api.Api.get_environment_details') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.utils.commands.notebook_environment') 
@mock.patch('google.datalab.utils.commands.get_notebook_item') @mock.patch('google.datalab.bigquery.Table.exists') @mock.patch('google.datalab.bigquery.commands._bigquery._get_table') @mock.patch('google.datalab.storage.Bucket') def test_pipeline_cell_golden(self, mock_bucket_class, mock_get_table, mock_table_exists, mock_notebook_item, mock_environment, mock_default_context, mock_composer_env): import google.datalab.contrib.pipeline.airflow table = google.datalab.bigquery.Table('project.test.table') mock_get_table.return_value = table mock_table_exists.return_value = True context = TestCases._create_context() mock_default_context.return_value = context mock_composer_env.return_value = { 'config': {'gcsDagLocation': 'gs://foo_bucket/dags'} } env = { 'endpoint': 'Interact2', 'job_id': '1234', 'input_table_format': 'cloud-datalab-samples.httplogs.logs_%(_ds_nodash)s', 'output_table_format': 'cloud-datalab-samples.endpoints.logs_%(_ds_nodash)s' } mock_notebook_item.return_value = google.datalab.bigquery.Query( 'SELECT @column FROM input where endpoint=@endpoint') mock_environment.return_value = env name = 'bq_pipeline_test' args = {'name': name, 'environment': 'foo_environment', 'location': 'foo_location', 'gcs_dag_bucket': 'foo_bucket', 'gcs_dag_file_path': 'foo_file_path', 'debug': True} cell_body = """ emails: foo1@test.com,foo2@test.com schedule: start: 2009-05-05T22:28:15Z end: 2009-05-06T22:28:15Z interval: '@hourly' input: path: gs://bucket/cloud-datalab-samples-httplogs_%(_ds_nodash)s table: $input_table_format csv: header: True strict: False quote: '"' skip: 5 delimiter: ',' schema: - name: col1 type: int64 mode: NULLABLE description: description1 - name: col2 type: STRING mode: required description: description1 transformation: query: foo_query output: path: gs://bucket/cloud-datalab-samples-endpoints_%(_ds_nodash)s.csv table: $output_table_format parameters: - name: endpoint type: STRING value: $endpoint - name: column type: INTEGER value: $job_id """ output = google.datalab.bigquery.commands._bigquery._pipeline_cell(args, cell_body) error_message = ("Airflow pipeline successfully deployed! View dashboard for more details.\n" "Composer pipeline successfully deployed! 
View dashboard for more details.\n") airflow_spec_pattern = """ import datetime from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.contrib.operators.bigquery_operator import BigQueryOperator from airflow.contrib.operators.bigquery_table_delete_operator import BigQueryTableDeleteOperator from airflow.contrib.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator from airflow.contrib.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator from google.datalab.contrib.bigquery.operators._bq_load_operator import LoadOperator from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator from google.datalab.contrib.bigquery.operators._bq_extract_operator import ExtractOperator from datetime import timedelta default_args = { 'owner': 'Google Cloud Datalab', 'email': \['foo1@test.com', 'foo2@test.com'\], 'start_date': datetime.datetime.strptime\('2009-05-05T22:28:15', '%Y-%m-%dT%H:%M:%S'\), 'end_date': datetime.datetime.strptime\('2009-05-06T22:28:15', '%Y-%m-%dT%H:%M:%S'\), } dag = DAG\(dag_id='bq_pipeline_test', schedule_interval='@hourly', catchup=False, default_args=default_args\) bq_pipeline_execute_task = ExecuteOperator\(task_id='bq_pipeline_execute_task_id', parameters=(.*), sql=\"\"\"WITH input AS \( SELECT \* FROM `cloud-datalab-samples\.httplogs\.logs_{{ ds_nodash }}` \) SELECT @column FROM input where endpoint=@endpoint\"\"\", table=\"\"\"cloud-datalab-samples\.endpoints\.logs_{{ ds_nodash }}\"\"\", dag=dag\) bq_pipeline_extract_task = ExtractOperator\(task_id='bq_pipeline_extract_task_id', path=\"\"\"gs://bucket/cloud-datalab-samples-endpoints_{{ ds_nodash }}\.csv\"\"\", table=\"\"\"cloud-datalab-samples\.endpoints\.logs_{{ ds_nodash }}\"\"\", dag=dag\).* bq_pipeline_load_task = LoadOperator\(task_id='bq_pipeline_load_task_id', csv_options=(.*), path=\"\"\"gs://bucket/cloud-datalab-samples-httplogs_{{ ds_nodash }}\"\"\", schema=(.*), table=\"\"\"cloud-datalab-samples\.httplogs\.logs_{{ ds_nodash }}\"\"\", dag=dag\).* bq_pipeline_execute_task.set_upstream\(bq_pipeline_load_task\) bq_pipeline_extract_task.set_upstream\(bq_pipeline_execute_task\) """ # noqa pattern = re.compile(error_message + '\n\n' + airflow_spec_pattern) self.assertIsNotNone(pattern.match(output)) # String that follows the "parameters=", for the execute operator. actual_parameter_dict_str = pattern.match(output).group(1) self.assertIn("'type': 'STRING'", actual_parameter_dict_str) self.assertIn("'name': 'endpoint'", actual_parameter_dict_str) self.assertIn("'value': 'Interact2'", actual_parameter_dict_str) self.assertIn("'type': 'INTEGER'", actual_parameter_dict_str) self.assertIn("'name': 'column'", actual_parameter_dict_str) self.assertIn("'value': '1234'", actual_parameter_dict_str) # String that follows the "csv_options=", for the load operator. actual_csv_options_dict_str = pattern.match(output).group(2) self.assertIn("'header': True", actual_csv_options_dict_str) self.assertIn("'delimiter': ','", actual_csv_options_dict_str) self.assertIn("'skip': 5", actual_csv_options_dict_str) self.assertIn("'strict': False", actual_csv_options_dict_str) self.assertIn("'quote': '\"'", actual_csv_options_dict_str) # String that follows the "schema=", i.e. the list of dicts. 
actual_schema_str = pattern.match(output).group(3) self.assertIn("'type': 'int64'", actual_schema_str) self.assertIn("'mode': 'NULLABLE'", actual_schema_str) self.assertIn("'name': 'col1'", actual_schema_str) self.assertIn("'description': 'description1'", actual_schema_str) self.assertIn("'type': 'STRING'", actual_schema_str) self.assertIn("'mode': 'required'", actual_schema_str) self.assertIn("'name': 'col2'", actual_schema_str) self.assertIn("'description': 'description1'", actual_schema_str) import google.datalab.utils as utils cell_body_dict = utils.commands.parse_config(cell_body, utils.commands.notebook_environment()) expected_airflow_spec = \ google.datalab.contrib.bigquery.commands._bigquery.get_airflow_spec_from_config( name, cell_body_dict) mock_bucket_class.assert_called_with('foo_bucket') mock_bucket_class.return_value.object.assert_called_with('dags/bq_pipeline_test.py') mock_bucket_class.return_value.object.return_value.write_stream.assert_called_with( expected_airflow_spec, 'text/plain') ================================================ FILE: tests/bigquery/query_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals from builtins import str import datetime import mock import unittest import google.auth import google.datalab import google.datalab.bigquery class TestCases(unittest.TestCase): def test_parameter_validation(self): sql = 'SELECT * FROM table' with self.assertRaises(Exception): TestCases._create_query(sql, subqueries=['subquery']) sq = TestCases._create_query() env = {'subquery': sq} q = TestCases._create_query(sql, env=env, subqueries=['subquery']) self.assertIsNotNone(q) self.assertEqual(q.subqueries, {'subquery': sq}) self.assertEqual(q._sql, sql) with self.assertRaises(Exception): TestCases._create_query(sql, udfs=['udf']) udf = TestCases._create_udf('test_udf', 'code', 'TYPE') env = {'testudf': udf} q = TestCases._create_query(sql, env=env, udfs=['testudf']) self.assertIsNotNone(q) self.assertEqual(q.udfs, {'testudf': udf}) self.assertEqual(q._sql, sql) with self.assertRaises(Exception): TestCases._create_query(sql, data_sources=['test_datasource']) test_datasource = TestCases._create_data_source('gs://test/path') env = {'test_datasource': test_datasource} q = TestCases._create_query(sql, env=env, data_sources=['test_datasource']) self.assertIsNotNone(q) self.assertEqual(q.data_sources, {'test_datasource': test_datasource}) self.assertEqual(q._sql, sql) def test_query_with_udf_object(self): udf = TestCases._create_udf('test_udf', 'udf body', 'TYPE') q = TestCases._create_query('SELECT * FROM table', udfs={'test_udf': udf}) self.assertIn('udf body', q.sql) @mock.patch('google.datalab.bigquery._api.Api.tabledata_list') @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('google.datalab.bigquery._api.Api.jobs_query_results') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') 
@mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_single_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query, mock_api_tabledata_list): mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_insert_done_result() mock_api_tabledata_list.return_value = TestCases._create_single_row_result() sql = 'SELECT field1 FROM [table] LIMIT 1' q = TestCases._create_query(sql) context = TestCases._create_context() results = q.execute(context=context).result() self.assertEqual(sql, results.sql) self.assertEqual('(%s)' % sql, q._repr_sql_()) self.assertEqual(1, results.length) first_result = results[0] self.assertEqual('value1', first_result['field1']) @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('google.datalab.bigquery._api.Api.jobs_query_results') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_empty_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query): mock_api_tables_get.return_value = TestCases._create_tables_get_result(0) mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_insert_done_result() q = TestCases._create_query() context = TestCases._create_context() results = q.execute(context=context).result() self.assertEqual(0, results.length) @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('google.datalab.bigquery._api.Api.jobs_query_results') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_incomplete_result_query(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query): mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_insert_query.return_value = TestCases._create_incomplete_result() q = TestCases._create_query() context = TestCases._create_context() results = q.execute(context=context).result() self.assertEqual(1, results.length) self.assertEqual('test_job', results.job_id) @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') def test_malformed_response_raises_exception(self, mock_api_insert_query): mock_api_insert_query.return_value = {} q = TestCases._create_query() with self.assertRaises(Exception) as error: context = TestCases._create_context() q.execute(context=context).result() self.assertEqual('Unexpected response from server', str(error.exception)) def test_nested_subquery_expansion(self): # test expanding subquery and udf validation with self.assertRaises(Exception): TestCases._create_query('SELECT * FROM subquery', subqueries=['subquery']) with self.assertRaises(Exception): TestCases._create_query('SELECT test_udf(field1) FROM test_table', udfs=['test_udf']) env = {} # test direct subquery expansion q1 = TestCases._create_query('SELECT * FROM test_table', name='q1', env=env) q2 = TestCases._create_query('SELECT * FROM q1', name='q2', subqueries=['q1'], env=env) self.assertEqual('''\ WITH q1 AS ( SELECT * FROM 
test_table ) SELECT * FROM q1''', q2.sql) # test recursive, second level subquery expansion q3 = TestCases._create_query('SELECT * FROM q2', name='q3', subqueries=['q2'], env=env) # subquery listing order is random, try both possibilities expected_sql1 = '''\ WITH q1 AS ( %s ), q2 AS ( %s ) %s''' % (q1._sql, q2._sql, q3._sql) expected_sql2 = '''\ WITH q2 AS ( %s ), q1 AS ( %s ) %s''' % (q2._sql, q1._sql, q3._sql) self.assertTrue((expected_sql1 == q3.sql) or (expected_sql2 == q3.sql)) # @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') def test_subquery_expansion_order(self): env = {} TestCases._create_query('SELECT * FROM test_table', name='snps', env=env) TestCases._create_query('SELECT * FROM snps', subqueries=['snps'], name='windows', env=env) titv = TestCases._create_query('SELECT * FROM snps, windows', subqueries=['snps', 'windows'], env=env) # make sure snps appears before windows in the expanded sql of titv snps_pos, windows_pos = titv.sql.find('snps AS'), titv.sql.find('windows AS') self.assertNotEqual(snps_pos, -1, 'Could not find snps definition in expanded sql') self.assertNotEqual(windows_pos, -1, 'Could not find windows definition in expanded sql') self.assertLess(snps_pos, windows_pos) # reverse the order they're referenced in titv, and make sure snps still appears before windows titv = TestCases._create_query('SELECT * FROM snps, windows', subqueries=['windows', 'snps'], env=env) snps_pos, windows_pos = titv.sql.find('snps AS'), titv.sql.find('windows AS') self.assertNotEqual(snps_pos, -1, 'Could not find snps definition in expanded sql') self.assertNotEqual(windows_pos, -1, 'Could not find windows definition in expanded sql') self.assertLess(snps_pos, windows_pos) @staticmethod def _create_query(sql='SELECT * ...', name=None, env=None, udfs=None, data_sources=None, subqueries=None): if env is None: env = {} q = google.datalab.bigquery.Query(sql, env=env, udfs=udfs, data_sources=data_sources, subqueries=subqueries) if name: env[name] = q return q @staticmethod def _create_udf(name, code, return_type): return google.datalab.bigquery.UDF(name, code, return_type) @staticmethod def _create_data_source(source): return google.datalab.bigquery.ExternalDataSource(source=source) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _create_insert_done_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': True, } @staticmethod def _create_single_row_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 1, 'rows': [ {'f': [{'v': 'value1'}]} ] } @staticmethod def _create_empty_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 0 } @staticmethod def _create_incomplete_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': False } @staticmethod def _create_page_result(page_token=None): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 2, 'rows': [ {'f': [{'v': 'value1'}]} ], 'pageToken': page_token } @staticmethod def _create_tables_get_result(num_rows=1, 
schema=None): if schema is None: schema = [{'name': 'field1', 'type': 'string'}] return { 'numRows': num_rows, 'schema': { 'fields': schema }, } def test_merged_parameters(self): parameters = [ {'type': 'foo1', 'name': 'foo1', 'value': 'foo1'}, {'type': 'foo2', 'name': 'foo2', 'value': 'foo2'}, {'type': 'foo3', 'name': '_ds', 'value': 'foo3'}, ] merged_parameters = google.datalab.bigquery.Query.merge_parameters( parameters, date_time=datetime.datetime.now(), macros=True, types_and_values=False) expected = { 'foo1': 'foo1', 'foo2': 'foo2', '_ds': 'foo3', '_ts': '{{ ts }}', '_ds_nodash': '{{ ds_nodash }}', '_ts_nodash': '{{ ts_nodash }}', '_ts_year': """{{ '{:04d}'.format(execution_date.year) }}""", '_ts_month': """{{ '{:02d}'.format(execution_date.month) }}""", '_ts_day': """{{ '{:02d}'.format(execution_date.day) }}""", '_ts_hour': """{{ '{:02d}'.format(execution_date.hour) }}""", '_ts_minute': """{{ '{:02d}'.format(execution_date.minute) }}""", '_ts_second': """{{ '{:02d}'.format(execution_date.second) }}""", } self.assertDictEqual(merged_parameters, expected) date_time = datetime.datetime.now() day = date_time.date() merged_parameters = google.datalab.bigquery.Query.merge_parameters( parameters, date_time=date_time, macros=False, types_and_values=False) expected = { u'_ts_nodash': date_time.strftime('%Y%m%d%H%M%S%f'), u'_ts_second': date_time.strftime('%S'), u'_ts_day': day.strftime('%d'), u'_ts_minute': date_time.strftime('%M'), u'_ts': date_time.isoformat(), u'_ts_hour': date_time.strftime('%H'), u'_ts_month': day.strftime('%m'), u'_ds_nodash': day.strftime('%Y%m%d'), u'_ds': 'foo3', u'_ts_year': day.strftime('%Y'), u'foo1': 'foo1', u'foo2': 'foo2' } self.assertDictEqual(merged_parameters, expected) merged_parameters = google.datalab.bigquery.Query.merge_parameters( parameters, date_time=date_time, macros=False, types_and_values=True) expected = { u'_ts_nodash': {u'type': u'STRING', u'value': date_time.strftime('%Y%m%d%H%M%S%f')}, u'_ts_second': {u'type': u'STRING', u'value': date_time.strftime('%S')}, u'_ts_day': {u'type': u'STRING', u'value': day.strftime('%d')}, u'_ts_minute': {u'type': u'STRING', u'value': date_time.strftime('%M')}, u'_ts': {u'type': u'STRING', u'value': date_time.isoformat()}, u'_ts_hour': {u'type': u'STRING', u'value': date_time.strftime('%H')}, u'_ts_month': {u'type': u'STRING', u'value': day.strftime('%m')}, u'_ds_nodash': {u'type': u'STRING', u'value': day.strftime('%Y%m%d')}, u'_ds': {u'type': u'foo3', u'value': 'foo3'}, u'_ts_year': {u'type': u'STRING', u'value': day.strftime('%Y')}, u'foo1': {u'type': u'foo1', u'value': u'foo1'}, u'foo2': {u'type': u'foo2', u'value': u'foo2'} } self.assertDictEqual(merged_parameters, expected) def test_resolve_parameters(self): date_time = datetime.datetime.now() day = date_time.date() day_string = day.isoformat() self.assertEqual(google.datalab.bigquery.Query.resolve_parameters('foo%(_ds)s', []), 'foo{0}'.format(day_string)) self.assertListEqual(google.datalab.bigquery.Query.resolve_parameters( ['foo%(_ds)s', 'bar%(_ds)s'], []), ['foo{0}'.format(day_string), 'bar{0}'.format(day_string)]) self.assertDictEqual(google.datalab.bigquery.Query.resolve_parameters( {'key%(_ds)s': 'value%(_ds)s'}, []), {'key{0}'.format(day_string): 'value{0}'.format(day_string)}) self.assertDictEqual(google.datalab.bigquery.Query.resolve_parameters( {'key%(_ds)s': {'key': 'value%(_ds)s'}}, []), {'key{0}'.format(day_string): {'key': 'value{0}'.format(day_string)}}) params = [{'name': 'custom_key', 'value': 'custom_value'}] 
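# Illustrative sketch of the behaviour exercised by test_resolve_parameters (simplified;
# `resolve` is a hypothetical stand-in for google.datalab.bigquery.Query.resolve_parameters,
# and only the `_ds` macro is modelled here): user-supplied parameters are merged over the
# built-in date macros, and the merged mapping is applied with %-style substitution to
# strings, and recursively to lists and to dict keys and values.
import datetime

def resolve(value, parameters):
  merged = {'_ds': datetime.date.today().isoformat()}  # assumed subset of the built-in macros
  merged.update({p['name']: p['value'] for p in (parameters or [])})
  if isinstance(value, list):
    return [resolve(v, parameters) for v in value]
  if isinstance(value, dict):
    return {resolve(k, parameters): resolve(v, parameters) for k, v in value.items()}
  if isinstance(value, str):
    return value % merged  # old-style mapping substitution, e.g. 'foo%(_ds)s'
  return value

# e.g. resolve('key%(custom_key)s', [{'name': 'custom_key', 'value': 'custom_value'}])
# returns 'keycustom_value', mirroring the dict-based cases asserted just below.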
    self.assertDictEqual(google.datalab.bigquery.Query.resolve_parameters(
        {'key%(custom_key)s': 'value%(custom_key)s'}, params),
        {'keycustom_value': 'valuecustom_value'})
    params = [{'name': '_ds', 'value': 'custom_value'}]
    self.assertDictEqual(google.datalab.bigquery.Query.resolve_parameters(
        {'key%(_ds)s': 'value%(_ds)s'}, params),
        {'keycustom_value': 'valuecustom_value'})


================================================
FILE: tests/bigquery/sampling_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.

from __future__ import absolute_import
from __future__ import unicode_literals
import unittest

from google.datalab.bigquery import Sampling


class TestCases(unittest.TestCase):

  BASE_SQL = '[]'

  def test_default(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(), expected_sql)

  def test_default_custom_count(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 20' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(count=20), expected_sql)

  def test_default_custom_fields(self):
    expected_sql = 'SELECT f1,f2 FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(fields=['f1', 'f2']), expected_sql)

  def test_default_all_fields(self):
    expected_sql = 'SELECT * FROM (%s) LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.default(fields=[]), expected_sql)

  def test_hashed(self):
    expected_sql = 'SELECT * FROM (%s) ' \
                   'WHERE MOD(ABS(FARM_FINGERPRINT(CAST(f1 AS STRING))), 100) < 5' \
                   % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5), expected_sql)

  def test_hashed_and_limited(self):
    expected_sql = 'SELECT * FROM (%s) ' \
                   'WHERE MOD(ABS(FARM_FINGERPRINT(CAST(f1 AS STRING))), 100) < 5 LIMIT 100' \
                   % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5, count=100), expected_sql)

  def test_hashed_with_fields(self):
    expected_sql = 'SELECT f1 FROM (%s) ' \
                   'WHERE MOD(ABS(FARM_FINGERPRINT(CAST(f1 AS STRING))), 100) < 5' \
                   % TestCases.BASE_SQL
    self._apply_sampling(Sampling.hashed('f1', 5, fields=['f1']), expected_sql)

  def test_sorted_ascending(self):
    expected_sql = 'SELECT * FROM (%s) ORDER BY f1 LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1'), expected_sql)

  def test_sorted_descending(self):
    expected_sql = 'SELECT * FROM (%s) ORDER BY f1 DESC LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1', ascending=False), expected_sql)

  def test_sorted_with_fields(self):
    expected_sql = 'SELECT f1,f2 FROM (%s) ORDER BY f1 LIMIT 5' % TestCases.BASE_SQL
    self._apply_sampling(Sampling.sorted('f1', fields=['f1', 'f2']), expected_sql)

  def _apply_sampling(self, sampling, expected_query):
    sampled_query = sampling(TestCases.BASE_SQL)
    self.assertEqual(sampled_query, expected_query)


================================================
FILE: tests/bigquery/schema_tests.py
================================================
# Copyright 2015 Google Inc. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import collections import pandas import sys import unittest import google.datalab.bigquery import google.datalab.utils class TestCases(unittest.TestCase): def test_schema_from_dataframe(self): df = TestCases._create_data_frame() result = google.datalab.bigquery.Schema.from_data(df) self.assertEqual(google.datalab.bigquery.Schema.from_data(TestCases._create_inferred_schema()), result) def test_schema_from_data(self): variant1 = [ 3, 2.0, True, ['cow', 'horse', [0, []]] ] variant2 = collections.OrderedDict() variant2['Column1'] = 3 variant2['Column2'] = 2.0 variant2['Column3'] = True variant2['Column4'] = collections.OrderedDict() variant2['Column4']['Column1'] = 'cow' variant2['Column4']['Column2'] = 'horse' variant2['Column4']['Column3'] = collections.OrderedDict() variant2['Column4']['Column3']['Column1'] = 0 variant2['Column4']['Column3']['Column2'] = collections.OrderedDict() master = [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'FLOAT'}, {'name': 'Column3', 'type': 'BOOLEAN'}, {'name': 'Column4', 'type': 'RECORD', 'fields': [ {'name': 'Column1', 'type': 'STRING'}, {'name': 'Column2', 'type': 'STRING'}, {'name': 'Column3', 'type': 'RECORD', 'fields': [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'RECORD', 'fields': []} ]} ]} ] schema_master = google.datalab.bigquery.Schema(master) with self.assertRaises(Exception) as error1: google.datalab.bigquery.Schema.from_data(variant1) if sys.version_info[0] == 3: self.assertEquals('Cannot create a schema from heterogeneous list [3, 2.0, True, ' + '[\'cow\', \'horse\', [0, []]]]; perhaps you meant to use ' + 'Schema.from_record?', str(error1.exception)) else: self.assertEquals('Cannot create a schema from heterogeneous list [3, 2.0, True, ' + '[u\'cow\', u\'horse\', [0, []]]]; perhaps you meant to use ' + 'Schema.from_record?', str(error1.exception)) schema3 = google.datalab.bigquery.Schema.from_data([variant1]) schema4 = google.datalab.bigquery.Schema.from_data([variant2]) schema5 = google.datalab.bigquery.Schema.from_data(master) schema6 = google.datalab.bigquery.Schema.from_record(variant1) schema7 = google.datalab.bigquery.Schema.from_record(variant2) self.assertEquals(schema_master, schema3, 'schema inferred from list of lists with from_data') self.assertEquals(schema_master, schema4, 'schema inferred from list of dicts with from_data') self.assertEquals(schema_master, schema5, 'schema inferred from BQ schema list with from_data') self.assertEquals(schema_master, schema6, 'schema inferred from list with from_record') self.assertEquals(schema_master, schema7, 'schema inferred from dict with from_record') @staticmethod def _create_data_frame(): data = { 'some': [ 0, 1, 2, 3 ], 'column': [ 'r0', 'r1', 'r2', 'r3' ], 'headers': [ 10.0, 10.0, 10.0, 10.0 ] } return pandas.DataFrame(data) @staticmethod def _create_inferred_schema(extra_field=None): schema = [ {'name': 'some', 'type': 'INTEGER'}, 
{'name': 'column', 'type': 'STRING'}, {'name': 'headers', 'type': 'FLOAT'}, ] if extra_field: schema.append({'name': extra_field, 'type': 'INTEGER'}) return schema ================================================ FILE: tests/bigquery/table_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals from builtins import str from builtins import object import calendar import datetime as dt import mock import pandas import unittest import google.auth import google.datalab import google.datalab.bigquery import google.datalab.utils class TestCases(unittest.TestCase): def _check_name_parts(self, table): parsed_name = table._name_parts self.assertEqual('test', parsed_name[0]) self.assertEqual('requestlogs', parsed_name[1]) self.assertEqual('today', parsed_name[2]) self.assertEqual('', parsed_name[3]) self.assertEqual('`test.requestlogs.today`', table._repr_sql_()) def test_api_paths(self): name = google.datalab.bigquery._utils.TableName('a', 'b', 'c', 'd') self.assertEqual('/projects/a/datasets/b/tables/cd', google.datalab.bigquery._api.Api._TABLES_PATH % name) self.assertEqual('/projects/a/datasets/b/tables/cd/data', google.datalab.bigquery._api.Api._TABLEDATA_PATH % name) name = google.datalab.bigquery._utils.DatasetName('a', 'b') self.assertEqual('/projects/a/datasets/b', google.datalab.bigquery._api.Api._DATASETS_PATH % name) def test_parse_full_name(self): table = TestCases._create_table('test.requestlogs.today') self._check_name_parts(table) def test_parse_local_name(self): table = TestCases._create_table('requestlogs.today') self._check_name_parts(table) def test_parse_dict_full_name(self): table = TestCases._create_table({'project_id': 'test', 'dataset_id': 'requestlogs', 'table_id': 'today'}) self._check_name_parts(table) def test_parse_dict_local_name(self): table = TestCases._create_table({'dataset_id': 'requestlogs', 'table_id': 'today'}) self._check_name_parts(table) def test_parse_named_tuple_name(self): table = TestCases._create_table(google.datalab.bigquery._utils.TableName('test', 'requestlogs', 'today', '')) self._check_name_parts(table) def test_parse_tuple_full_name(self): table = TestCases._create_table(('test', 'requestlogs', 'today')) self._check_name_parts(table) def test_parse_tuple_local(self): table = TestCases._create_table(('requestlogs', 'today')) self._check_name_parts(table) def test_parse_array_full_name(self): table = TestCases._create_table(['test', 'requestlogs', 'today']) self._check_name_parts(table) def test_parse_array_local(self): table = TestCases._create_table(['requestlogs', 'today']) self._check_name_parts(table) def test_parse_invalid_name(self): with self.assertRaises(Exception): TestCases._create_table('today@') @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_metadata(self, mock_api_tables_get): name = 'test.requestlogs.today' ts = dt.datetime.utcnow() 
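# Illustrative sketch (hypothetical `to_name_parts` helper, not the pydatalab implementation)
# of the flexible table-name handling covered by the parse_* tests above: strings such as
# 'project.dataset.table' or 'dataset.table', dicts, tuples, lists and named tuples are all
# normalized to (project, dataset, table, decorator), with the project defaulting to the
# context's project ('test' in these tests), and the SQL form is the back-ticked dotted name.
def to_name_parts(name, default_project='test'):
  if isinstance(name, str):
    parts = name.split('.')
  elif isinstance(name, dict):
    parts = [name.get('project_id', default_project), name['dataset_id'], name['table_id']]
  else:  # tuple, list or named tuple
    parts = list(name)
  if len(parts) == 2:  # local name: fall back to the default project
    parts = [default_project] + parts
  project, dataset, table = parts[:3]
  decorator = parts[3] if len(parts) > 3 else ''
  return project, dataset, table, decorator

# '`%s.%s.%s`' % to_name_parts('requestlogs.today')[:3] evaluates to '`test.requestlogs.today`',
# matching the _repr_sql_ value checked by _check_name_parts.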
mock_api_tables_get.return_value = TestCases._create_table_info_result(ts=ts) t = TestCases._create_table(name) metadata = t.metadata self.assertEqual('Logs', metadata.friendly_name) self.assertEqual(2, metadata.rows) self.assertEqual(2, metadata.rows) self.assertTrue(abs((metadata.created_on - ts).total_seconds()) <= 1) self.assertEqual(None, metadata.expires_on) @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_schema(self, mock_api_tables): mock_api_tables.return_value = TestCases._create_table_info_result() t = TestCases._create_table('test.requestlogs.today') schema = t.schema self.assertEqual(2, len(schema)) self.assertEqual('name', schema[0].name) @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_schema_nested(self, mock_api_tables): mock_api_tables.return_value = TestCases._create_table_info_nested_schema_result() t = TestCases._create_table('test.requestlogs.today') schema = t.schema self.assertEqual(4, len(schema)) self.assertEqual('name', schema[0].name) self.assertEqual('val', schema[1].name) self.assertEqual('more', schema[2].name) self.assertEqual('more.xyz', schema[3].name) self.assertIsNone(schema['value']) self.assertIsNotNone(schema['val']) @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_malformed_response_raises_exception(self, mock_api_tables_get): mock_api_tables_get.return_value = {} t = TestCases._create_table('test.requestlogs.today') with self.assertRaises(Exception) as error: t.schema self.assertEqual('Unexpected table response: missing schema', str(error.exception)) @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_dataset_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = google.datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds: tables.append(table) self.assertEqual(2, len(tables)) self.assertEqual('`test.testds.testTable1`', tables[0]._repr_sql_()) self.assertEqual('`test.testds.testTable2`', tables[1]._repr_sql_()) @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_table_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = google.datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds.tables(): tables.append(table) self.assertEqual(2, len(tables)) self.assertEqual('`test.testds.testTable1`', tables[0]._repr_sql_()) self.assertEqual('`test.testds.testTable2`', tables[1]._repr_sql_()) @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_view_list(self, mock_api_datasets_get, mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_result() ds = google.datalab.bigquery.Dataset('testds', context=TestCases._create_context()) views = [] for view in ds.views(): views.append(view) self.assertEqual(1, len(views)) self.assertEqual('`test.testds.testView1`', views[0]._repr_sql_()) @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_table_list_empty(self, mock_api_datasets_get, 
mock_api_tables_list): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = TestCases._create_table_list_empty_result() ds = google.datalab.bigquery.Dataset('testds', context=TestCases._create_context()) tables = [] for table in ds: tables.append(table) self.assertEqual(0, len(tables)) @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_exists(self, mock_api_tables_get): mock_api_tables_get.return_value = None tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) self.assertTrue(tbl.exists()) mock_api_tables_get.side_effect = google.datalab.utils.RequestException(404, 'failed') self.assertFalse(tbl.exists()) @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_tables_create(self, mock_api_datasets_get, mock_api_tables_list, mock_api_tables_insert): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = [] schema = TestCases._create_inferred_schema() mock_api_tables_insert.return_value = {} with self.assertRaises(Exception) as error: TestCases._create_table_with_schema(schema) self.assertEqual('Table test.testds.testTable0 could not be created as it already exists', str(error.exception)) mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} self.assertIsNotNone(TestCases._create_table_with_schema(schema), 'Expected a table') @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_insert_no_table(self, mock_api_datasets_get, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_time_sleep, mock_uuid): mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.side_effect = google.datalab.utils.RequestException(404, 'failed') mock_api_tabledata_insert_all.return_value = {} mock_api_datasets_get.return_value = None table = TestCases._create_table_with_schema(TestCases._create_inferred_schema()) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert(df) self.assertEqual('Table %s does not exist.' % table._full_name, str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_missing_field(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid,): # Truncate the schema used when creating the table so we have an unmatched column in insert. 
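# (With the truncated schema the mocked table only carries 'some' and 'column', while the test
# DataFrame still has a 'headers' column, so insert() is expected to fail below with
# 'Table does not contain field headers'.)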
schema = TestCases._create_inferred_schema()[:2] mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = None mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_list.return_value = [] mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert(df) self.assertEqual('Table does not contain field headers', str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_insert_mismatched_schema(self, mock_api_datasets_get, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_time_sleep, mock_uuid): # Change the schema used when creating the table so we get a mismatch when inserting. schema = TestCases._create_inferred_schema() schema[2]['type'] = 'STRING' mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} mock_api_datasets_get.return_value = None table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() with self.assertRaises(Exception) as error: table.insert(df) self.assertEqual('Field headers in data has type FLOAT but in table has type STRING', str(error.exception)) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_dataframe(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema() mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) df = TestCases._create_data_frame() result = table.insert(df) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') 
@mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_dictlist(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema() mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ]) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_dictlist_index(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema('Index') mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ], include_index=True) self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0, 'Index': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1, 'Index': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2, 'Index': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3, 'Index': 3}} ]) @mock.patch('uuid.uuid4') @mock.patch('time.sleep') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') 
@mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tabledata_insert_all') def test_insert_dictlist_named_index(self, mock_api_tabledata_insert_all, mock_api_tables_get, mock_api_tables_insert, mock_api_tables_list, mock_api_datasets_get, mock_time_sleep, mock_uuid): schema = TestCases._create_inferred_schema('Row') mock_uuid.return_value = TestCases._create_uuid() mock_time_sleep.return_value = None mock_api_datasets_get.return_value = True mock_api_tables_list.return_value = [] mock_api_tables_insert.return_value = {'selfLink': 'http://foo'} mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_insert_all.return_value = {} table = TestCases._create_table_with_schema(schema) result = table.insert([ {u'column': 'r0', u'headers': 10.0, u'some': 0}, {u'column': 'r1', u'headers': 10.0, u'some': 1}, {u'column': 'r2', u'headers': 10.0, u'some': 2}, {u'column': 'r3', u'headers': 10.0, u'some': 3} ], include_index=True, index_name='Row') self.assertIsNotNone(result, "insert_all should return the table object") mock_api_tabledata_insert_all.assert_called_with(('test', 'testds', 'testTable0', ''), [ {'insertId': '#0', 'json': {u'column': 'r0', u'headers': 10.0, u'some': 0, 'Row': 0}}, {'insertId': '#1', 'json': {u'column': 'r1', u'headers': 10.0, u'some': 1, 'Row': 1}}, {'insertId': '#2', 'json': {u'column': 'r2', u'headers': 10.0, u'some': 2, 'Row': 2}}, {'insertId': '#3', 'json': {u'column': 'r3', u'headers': 10.0, u'some': 3, 'Row': 3}} ]) @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_load') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') def test_table_load(self, mock_api_jobs_get, mock_api_jobs_insert_load, mock_api_tables_get): schema = TestCases._create_inferred_schema('Row') mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_jobs_insert_load.return_value = None mock_api_tables_get.return_value = {'schema': {'fields': schema}} tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) job = tbl.load('gs://foo') self.assertIsNone(job) mock_api_jobs_insert_load.return_value = {'jobReference': {'jobId': 'bar'}} job = tbl.load('gs://foo') self.assertEquals('bar', job.id) @mock.patch('google.datalab.bigquery._api.Api.table_extract') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_extract(self, mock_api_tables_get, mock_api_jobs_get, mock_api_table_extract): mock_api_tables_get.return_value = {} mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_table_extract.return_value = None tbl = google.datalab.bigquery.Table('testds.testTable0', context=self._create_context()) tbl._api.table_extract = mock.Mock(return_value={'jobReference': {'jobId': 'bar'}}) job = tbl.extract('gs://foo') tbl._api.table_extract.assert_called_with(tbl.name, 'gs://foo', 'CSV', False, ',', True) self.assertEquals('bar', job.id) tbl.extract('gs://foo', format='json') tbl._api.table_extract.assert_called_with(tbl.name, 'gs://foo', 'NEWLINE_DELIMITED_JSON', False, None, True) tbl.extract('gs://foo', format='avro') tbl._api.table_extract.assert_called_with(tbl.name, 'gs://foo', 'AVRO', False, None, True) @mock.patch('google.datalab.bigquery._api.Api.tabledata_list') 
@mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_table_to_dataframe(self, mock_api_tables_get, mock_api_tabledata_list): schema = self._create_inferred_schema() mock_api_tables_get.return_value = {'schema': {'fields': schema}} mock_api_tabledata_list.return_value = { 'rows': [ {'f': [{'v': 1}, {'v': 'foo'}, {'v': 3.1415}]}, {'f': [{'v': 2}, {'v': 'bar'}, {'v': 0.5}]}, ] } tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) df = tbl.to_dataframe() self.assertEquals(2, len(df)) self.assertEquals(1, df['some'][0]) self.assertEquals(2, df['some'][1]) self.assertEquals('foo', df['column'][0]) self.assertEquals('bar', df['column'][1]) self.assertEquals(3.1415, df['headers'][0]) self.assertEquals(0.5, df['headers'][1]) def test_encode_dict_as_row_datetime(self): when = dt.datetime(2001, 2, 3, 4, 5, 6, 7) row = google.datalab.bigquery.Table._encode_dict_as_row({'fo@o': 'b@r', 'b+ar': when}, {}) self.assertEqual({'foo': 'b@r', 'bar': '2001-02-03T04:05:06.000007'}, row) def test_encode_dict_as_row_date(self): when = dt.date(2001, 2, 3) row = google.datalab.bigquery.Table._encode_dict_as_row({'fo@o': 'b@r', 'b+ar': when}, {}) self.assertEqual({'foo': 'b@r', 'bar': '2001-02-03'}, row) def test_encode_dict_as_row_time(self): when = dt.time(1, 2, 3, 4) row = google.datalab.bigquery.Table._encode_dict_as_row({'fo@o': 'b@r', 'b+ar': when}, {}) self.assertEqual({'foo': 'b@r', 'bar': '01:02:03.000004'}, row) def test_decorators(self): tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) tbl2 = tbl.snapshot(dt.timedelta(hours=-1)) self.assertEquals('`test.testds.testTable0@-3600000`', tbl2._repr_sql_()) with self.assertRaises(Exception) as error: tbl2 = tbl2.snapshot(dt.timedelta(hours=-2)) self.assertEqual('Cannot use snapshot() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl2.window(dt.timedelta(hours=-2), 0) self.assertEqual('Cannot use window() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl.snapshot(dt.timedelta(days=-8)) self.assertEqual( 'Invalid snapshot relative when argument: must be within 7 days: -8 days, 0:00:00', str(error.exception)) with self.assertRaises(Exception) as error: tbl.snapshot(dt.timedelta(days=-8)) self.assertEqual( 'Invalid snapshot relative when argument: must be within 7 days: -8 days, 0:00:00', str(error.exception)) tbl2 = tbl.snapshot(dt.timedelta(days=-1)) self.assertEquals('`test.testds.testTable0@-86400000`', tbl2._repr_sql_()) with self.assertRaises(Exception) as error: tbl.snapshot(dt.timedelta(days=1)) self.assertEqual('Invalid snapshot relative when argument: 1 day, 0:00:00', str(error.exception)) with self.assertRaises(Exception) as error: tbl2 = tbl.snapshot(1000) self.assertEqual('Invalid snapshot when argument type: 1000', str(error.exception)) self.assertEquals('`test.testds.testTable0@-86400000`', tbl2._repr_sql_()) when = dt.datetime.utcnow() + dt.timedelta(1) with self.assertRaises(Exception) as error: tbl.snapshot(when) self.assertEqual('Invalid snapshot absolute when argument: %s' % when, str(error.exception)) when = dt.datetime.utcnow() - dt.timedelta(8) with self.assertRaises(Exception) as error: tbl.snapshot(when) self.assertEqual('Invalid snapshot absolute when argument: %s' % when, str(error.exception)) def test_window_decorators(self): # The at test above already tests many of the conversion cases. 
The extra things we # have to test are that we can use two values, we get a meaningful default for the second # if we pass None, and that the first time comes before the second. tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) tbl2 = tbl.window(dt.timedelta(hours=-1)) self.assertEquals('`test.testds.testTable0@-3600000-0`', tbl2._repr_sql_()) with self.assertRaises(Exception) as error: tbl2 = tbl2.window(-400000, 0) self.assertEqual('Cannot use window() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl2.snapshot(-400000) self.assertEqual('Cannot use snapshot() on an already decorated table', str(error.exception)) with self.assertRaises(Exception) as error: tbl.window(dt.timedelta(0), dt.timedelta(hours=-1)) self.assertEqual( 'window: Between arguments: begin must be before end: 0:00:00, -1 day, 23:00:00', str(error.exception)) @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.table_update') def test_table_update(self, mock_api_table_update, mock_api_tables_get): schema = self._create_inferred_schema() info = {'schema': {'fields': schema}, 'friendlyName': 'casper', 'description': 'ghostly logs', 'expirationTime': calendar.timegm(dt.datetime(2020, 1, 1).utctimetuple()) * 1000} mock_api_tables_get.return_value = info tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) new_name = 'aziraphale' new_description = 'demon duties' new_schema = [{'name': 'injected', 'type': 'FLOAT'}] new_schema.extend(schema) new_expiry = dt.datetime(2030, 1, 1) tbl.update(new_name, new_description, new_expiry, new_schema) name, info = mock_api_table_update.call_args[0] self.assertEqual(tbl.name, name) self.assertEqual(new_name, tbl.metadata.friendly_name) self.assertEqual(new_description, tbl.metadata.description) self.assertEqual(new_expiry, tbl.metadata.expires_on) self.assertEqual(len(new_schema), len(tbl.schema)) def test_table_to_query(self): tbl = google.datalab.bigquery.Table('testds.testTable0', context=TestCases._create_context()) q = google.datalab.bigquery.Query.from_table(tbl) self.assertEqual('SELECT * FROM `test.testds.testTable0`', q.sql) q = google.datalab.bigquery.Query.from_table(tbl, fields='foo, bar') self.assertEqual('SELECT foo, bar FROM `test.testds.testTable0`', q.sql) q = google.datalab.bigquery.Query.from_table(tbl, fields=['bar', 'foo']) self.assertEqual('SELECT bar,foo FROM `test.testds.testTable0`', q.sql) @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_row_fetcher(self, mock_api_tables_get): schema = self._create_inferred_schema() mock_api_tables_get.return_value = {'schema': {'fields': schema}} dummy_row = {'f': [{'v': 1}, {'v': 'foo'}, {'v': 3.1415}]} results = { 'rows': [dummy_row] * 10, 'pageToken': None } tbl = TestCases._create_table('test.table') tbl._api.tabledata_list = mock.Mock(return_value=results) # using Table.to_dataframe should use large pages to reduce traffic tbl.to_dataframe() tbl._api.tabledata_list.assert_called_with(tbl.name, max_results=100000, start_index=0) # using Table.range or iterator should use smaller pages to reduce latency list(tbl.range(start_row=0)) tbl._api.tabledata_list.assert_called_with(tbl.name, max_results=1024, start_index=0) tbl[0] tbl._api.tabledata_list.assert_called_with(tbl.name, max_results=1024, start_index=0) @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) 
return google.datalab.Context(project_id, creds) @staticmethod def _create_table(name): return google.datalab.bigquery.Table(name, TestCases._create_context()) @staticmethod def _create_table_info_result(ts=None): if ts is None: ts = dt.datetime.utcnow() epoch = dt.datetime.utcfromtimestamp(0) timestamp = (ts - epoch).total_seconds() * 1000 return { 'description': 'Daily Logs Table', 'friendlyName': 'Logs', 'numBytes': 1000, 'numRows': 2, 'creationTime': timestamp, 'lastModifiedTime': timestamp, 'schema': { 'fields': [ {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, {'name': 'val', 'type': 'INTEGER', 'mode': 'NULLABLE'} ] } } @staticmethod def _create_table_info_nested_schema_result(ts=None): if ts is None: ts = dt.datetime.utcnow() epoch = dt.datetime.utcfromtimestamp(0) timestamp = (ts - epoch).total_seconds() * 1000 return { 'description': 'Daily Logs Table', 'friendlyName': 'Logs', 'numBytes': 1000, 'numRows': 2, 'creationTime': timestamp, 'lastModifiedTime': timestamp, 'schema': { 'fields': [ {'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'}, {'name': 'val', 'type': 'INTEGER', 'mode': 'NULLABLE'}, {'name': 'more', 'type': 'RECORD', 'mode': 'REPEATED', 'fields': [ {'name': 'xyz', 'type': 'INTEGER', 'mode': 'NULLABLE'} ] } ] } } @staticmethod def _create_dataset(dataset_id): return google.datalab.bigquery.Dataset(dataset_id, context=TestCases._create_context()) @staticmethod def _create_table_list_result(): return { 'tables': [ { 'type': 'TABLE', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testTable1'} }, { 'type': 'VIEW', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testView1'} }, { 'type': 'TABLE', 'tableReference': {'projectId': 'test', 'datasetId': 'testds', 'tableId': 'testTable2'} } ] } @staticmethod def _create_table_list_empty_result(): return { 'tables': [] } @staticmethod def _create_data_frame(): data = { 'some': [ 0, 1, 2, 3 ], 'column': [ 'r0', 'r1', 'r2', 'r3' ], 'headers': [ 10.0, 10.0, 10.0, 10.0 ] } return pandas.DataFrame(data) @staticmethod def _create_inferred_schema(extra_field=None): schema = [ {'name': 'some', 'type': 'INTEGER'}, {'name': 'column', 'type': 'STRING'}, {'name': 'headers', 'type': 'FLOAT'}, ] if extra_field: schema.append({'name': extra_field, 'type': 'INTEGER'}) return schema @staticmethod def _create_table_with_schema(schema, name='test.testds.testTable0'): return google.datalab.bigquery.Table(name, TestCases._create_context()).create(schema) class _uuid(object): @property def hex(self): return '#' @staticmethod def _create_uuid(): return TestCases._uuid() ================================================ FILE: tests/bigquery/udf_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
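# The cases below build google.datalab.bigquery.UDF objects directly and compare the SQL produced
# by _expanded_sql() / _repr_sql_(). A minimal usage sketch, assuming the same positional
# constructor arguments these tests pass (name, code, return_type, params, language, imports);
# the names and values here are illustrative only:
#
#   udf = google.datalab.bigquery.UDF('add_one', 'return x + 1;', 'INT64',
#                                     [('x', 'INT64')], 'js', [])
#   sql = udf._expanded_sql()  # "CREATE TEMPORARY FUNCTION add_one (x INT64) ..."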
from __future__ import absolute_import from __future__ import unicode_literals import unittest import google.datalab import google.datalab.bigquery class TestCases(unittest.TestCase): def _create_udf(self, name='test_udf', code='console.log("test");', return_type='integer', params=None, language='js', imports=None): if params is None: params = [('test_param', 'integer')] if code is None: code = 'test code;' if imports is None: imports = ['gcs://mylib'] return google.datalab.bigquery.UDF(name, code, return_type, params, language, imports) def test_building_udf(self): code = 'console.log("test");' imports = ['gcs://test_lib'] udf = self._create_udf(code=code, imports=imports) expected_sql = '''\ CREATE TEMPORARY FUNCTION test_udf (test_param integer) RETURNS integer LANGUAGE js AS """ console.log("test"); """ OPTIONS ( library="gcs://test_lib" );\ ''' self.assertEqual(udf.name, 'test_udf') self.assertEqual(udf.code, code) self.assertEqual(udf.imports, imports) self.assertEqual(udf._language, 'js') self.assertEqual(udf._repr_sql_(), udf._expanded_sql()) self.assertEqual(udf.__repr__(), 'BigQuery UDF - code:\n%s' % udf._code) self.assertEqual(udf._expanded_sql(), expected_sql) def test_udf_bad_return_type(self): with self.assertRaisesRegexp(TypeError, 'Argument return_type should be a string'): self._create_udf(return_type=['integer']) def test_udf_bad_params(self): with self.assertRaisesRegexp(TypeError, 'Argument params should be a list'): self._create_udf(params={'param1': 'param2'}) def test_udf_params_order(self): udf = self._create_udf(params=[('param1', 'int'), ('param2', 'string'), ('param3', 'array')]) self.assertIn('param1 int,param2 string,param3 array', udf._expanded_sql()) def test_udf_bad_imports(self): with self.assertRaisesRegexp(TypeError, 'Argument imports should be a list'): self._create_udf(imports='gcs://mylib') def test_udf_imports_non_js(self): with self.assertRaisesRegexp(Exception, 'Imports are available for Javascript'): self._create_udf(language='sql') def test_query_with_udf(self): code = 'console.log("test");' return_type = 'integer' params = [('test_param', 'integer')] language = 'js' imports = '' udf = google.datalab.bigquery.UDF('test_udf', code, return_type, params, language, imports) sql = 'SELECT test_udf(col) FROM mytable' expected_sql = '''\ CREATE TEMPORARY FUNCTION test_udf (test_param integer) RETURNS integer LANGUAGE js AS """ console.log("test"); """ OPTIONS ( ); SELECT test_udf(col) FROM mytable\ ''' query = google.datalab.bigquery.Query(sql, udfs={'udf': udf}) self.assertEquals(query.sql, expected_sql) # Alternate form of passing the udf using notebook environment query = google.datalab.bigquery.Query(sql, udfs=['udf'], env={'udf': udf}) self.assertEquals(query.sql, expected_sql) def test_query_with_sql_udf(self): code = 'test_param + 1' return_type = 'INT64' params = [('test_param', 'INT64')] language = 'sql' imports = '' udf = google.datalab.bigquery.UDF('test_udf', code, return_type, params, language, imports) sql = 'SELECT test_udf(col) FROM mytable' expected_sql = '''\ CREATE TEMPORARY FUNCTION test_udf (test_param INT64) RETURNS INT64 AS ( test_param + 1 ); SELECT test_udf(col) FROM mytable\ ''' query = google.datalab.bigquery.Query(sql, udfs={'udf': udf}) self.assertEquals(query.sql, expected_sql) # Alternate form of passing the udf using notebook environment query = google.datalab.bigquery.Query(sql, udfs=['udf'], env={'udf': udf}) self.assertEquals(query.sql, expected_sql) ================================================ FILE: 
tests/bigquery/view_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest import google.auth import google.datalab import google.datalab.bigquery class TestCases(unittest.TestCase): def test_view_repr_sql(self): name = 'test.testds.testView0' view = google.datalab.bigquery.View(name, TestCases._create_context()) self.assertEqual('`%s`' % name, view._repr_sql_()) @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.tables_list') @mock.patch('google.datalab.bigquery._api.Api.datasets_get') def test_view_create(self, mock_api_datasets_get, mock_api_tables_list, mock_api_tables_get, mock_api_tables_insert): mock_api_datasets_get.return_value = None mock_api_tables_list.return_value = [] mock_api_tables_get.return_value = None mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() name = 'test.testds.testView0' sql = 'select * from test.testds.testTable0' view = google.datalab.bigquery.View(name, TestCases._create_context()) result = view.create(sql) self.assertTrue(view.exists()) self.assertEqual('`%s`' % name, view._repr_sql_()) self.assertIsNotNone(result, 'Expected a view') @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tabledata_list') @mock.patch('google.datalab.bigquery._api.Api.jobs_insert_query') @mock.patch('google.datalab.bigquery._api.Api.jobs_query_results') @mock.patch('google.datalab.bigquery._api.Api.jobs_get') @mock.patch('google.datalab.bigquery._api.Api.tables_get') def test_view_result(self, mock_api_tables_get, mock_api_jobs_get, mock_api_jobs_query_results, mock_api_insert_query, mock_api_tabledata_list, mock_api_tables_insert): mock_api_insert_query.return_value = TestCases._create_insert_done_result() mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() mock_api_jobs_query_results.return_value = {'jobComplete': True} mock_api_tables_get.return_value = TestCases._create_tables_get_result() mock_api_jobs_get.return_value = {'status': {'state': 'DONE'}} mock_api_tabledata_list.return_value = TestCases._create_single_row_result() name = 'test.testds.testView0' sql = 'select * from test.testds.testTable0' context = TestCases._create_context() view = google.datalab.bigquery.View(name, context) view.create(sql) q = google.datalab.bigquery.Query.from_view(view) results = q.execute(context=context).result() self.assertEqual(1, results.length) first_result = results[0] self.assertEqual('value1', first_result['field1']) @mock.patch('google.datalab.bigquery._api.Api.tables_insert') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.bigquery._api.Api.table_update') @mock.patch('google.datalab.Context.default') def 
test_view_update(self, mock_context_default, mock_api_table_update, mock_api_tables_get, mock_api_tables_insert): mock_api_tables_insert.return_value = TestCases._create_tables_insert_success_result() mock_context_default.return_value = TestCases._create_context() mock_api_table_update.return_value = None friendly_name = 'casper' description = 'ghostly logs' sql = 'select * from `test.testds.testTable0`' info = {'friendlyName': friendly_name, 'description': description, 'view': {'query': sql}} mock_api_tables_get.return_value = info name = 'test.testds.testView0' view = google.datalab.bigquery.View(name, TestCases._create_context()) view.create(sql) self.assertEqual(friendly_name, view.friendly_name) self.assertEqual(description, view.description) self.assertEqual(sql, view.query.sql) new_friendly_name = 'aziraphale' new_description = 'demon duties' new_query = 'SELECT 3 AS x' view.update(new_friendly_name, new_description, new_query) self.assertEqual(new_friendly_name, view.friendly_name) self.assertEqual(new_description, view.description) self.assertEqual(new_query, view.query.sql) @staticmethod def _create_tables_insert_success_result(): return {'selfLink': 'http://foo'} @staticmethod def _create_insert_done_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'jobReference': { 'jobId': 'test_job' }, 'configuration': { 'query': { 'destinationTable': { 'projectId': 'project', 'datasetId': 'dataset', 'tableId': 'table' } } }, 'jobComplete': True, } @staticmethod def _create_tables_get_result(num_rows=1, schema=None): if not schema: schema = [{'name': 'field1', 'type': 'string'}] return { 'numRows': num_rows, 'schema': { 'fields': schema }, } @staticmethod def _create_single_row_result(): # pylint: disable=g-continuation-in-parens-misaligned return { 'totalRows': 1, 'rows': [ {'f': [{'v': 'value1'}]} ] } @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) ================================================ FILE: tests/context_tests.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
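# These tests exercise google.datalab.Context directly: the (project_id, credentials, config)
# constructor, the set_project_id / set_credentials / set_config mutators, and the
# Context.default() singleton backed by google.datalab.utils._utils. A minimal sketch, assuming
# the same constructor shape used in the tests (values are placeholders):
#
#   from google.datalab import Context
#   c = Context('my-project', credentials=None, config=None)
#   c.set_project_id('another-project')
#   assert c.project_id == 'another-project'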
from __future__ import absolute_import from __future__ import unicode_literals import unittest import mock from google.datalab import Context from google.datalab.utils import _utils as du class TestCases(unittest.TestCase): def test_credentials(self): dummy_creds = {} c = Context('test_project', credentials=dummy_creds) self.assertEqual(c.credentials, dummy_creds) dummy_creds = {'test': 'test'} c.set_credentials(dummy_creds) self.assertEqual(c.credentials, dummy_creds) def test_config(self): dummy_config = {} c = Context('test_project', credentials=None, config=dummy_config) self.assertEqual(c.config, dummy_config) dummy_config = {'test': 'test'} c.set_config(dummy_config) self.assertEqual(c.config, dummy_config) c = Context('test_project', None, None) self.assertEqual(c.config, Context._get_default_config()) def test_project(self): dummy_project = 'test_project' c = Context(dummy_project, credentials=None, config=None) self.assertEqual(c.project_id, dummy_project) dummy_project = 'test_project2' c.set_project_id(dummy_project) self.assertEqual(c.project_id, dummy_project) c = Context(None, None, None) with self.assertRaises(Exception): print(c.project_id) @mock.patch('google.datalab.utils._utils.get_credentials') @mock.patch('google.datalab.utils._utils.get_default_project_id') @mock.patch('google.datalab.utils._utils.save_project_id') def test_default_project(self, mock_save_project_id, mock_get_default_project_id, mock_get_credentials): # verify setting the project on a default Context object sets # the global default project global default_project default_project = '' def save_project(project=None): global default_project default_project = project def get_project(): global default_project return default_project mock_save_project_id.side_effect = save_project mock_get_default_project_id.side_effect = get_project mock_get_credentials.return_value = '' c = Context.default() dummy_project = 'test_project3' c.set_project_id(dummy_project) self.assertEqual(du.get_default_project_id(), dummy_project) @mock.patch('google.datalab.utils._utils.get_credentials') def test_is_signed_in(self, mock_get_credentials): mock_get_credentials.side_effect = Exception('No creds!') self.assertFalse(Context._is_signed_in()) mock_get_credentials.return_value = {} mock_get_credentials.side_effect = None self.assertTrue(Context._is_signed_in()) @mock.patch('google.datalab.utils._utils.get_credentials') @mock.patch('google.datalab.utils._utils.get_default_project_id') @mock.patch('google.datalab.utils._utils.save_project_id') def test_default_context(self, mock_save_project_id, mock_get_default_project_id, mock_get_credentials): mock_get_default_project_id.return_value = 'default_project' mock_get_credentials.return_value = '' c = Context.default() default_project = c.project_id self.assertEqual(default_project, 'default_project') # deliberately change the default project and make sure it's reset c.set_project_id('test_project4') self.assertEqual(Context.default().project_id, 'default_project') ================================================ FILE: tests/integration/storage_test.py ================================================ """Integration tests for google.datalab.storage.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import logging import random import string import unittest import google.datalab from google.datalab import storage class StorageTest(unittest.TestCase): def setUp(self): self._context = google.datalab.Context.default() 
logging.info('Using project: %s', self._context.project_id) suffix = ''.join(random.choice(string.lowercase) for _ in range(8)) self._test_bucket_name = '{}-{}'.format(self._context.project_id, suffix) logging.info('test bucket: %s', self._test_bucket_name) def test_object_deletion_consistency(self): b = storage.Bucket(self._test_bucket_name, context=self._context) b.create() o = b.object('sample') o.write_stream('contents', 'text/plain') o.delete() b.delete() if __name__ == '__main__': logging.basicConfig(level=logging.INFO) unittest.main() ================================================ FILE: tests/kernel/__init__.py ================================================ # Copyright 2016 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. ================================================ FILE: tests/kernel/bigquery_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import json import mock import pandas import six import unittest from datetime import datetime import google import google.auth import google.datalab # noqa import google.datalab.bigquery as bq import google.datalab.bigquery.commands import google.datalab.storage import google.datalab.utils.commands # noqa # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): def test_table_schema(self): import jsonschema good_schemas = [ { 'schema': [ {'name': 'col1', 'type': 'int64', 'mode': 'NULLABLE', 'description': 'description1'}, {'name': 'col1', 'type': 'STRING', 'mode': 'required', 'description': 'description1'} ] }, { 'schema': [ {'name': 'col1', 'type': 'record', 'mode': 'repeated', 'description': 'description1', 'fields': [ {'name': 'field1', 'type': 'int64'}, {'name': 'field2', 'type': 'int64'}, {'name': 'field3', 'type': 'record', 'fields': [ {'name': 'nestedField1', 'type': 'STRING'}, {'name': 'nestedField2', 'type': 'STRING'} ]} ]} ] } ] bad_schemas = [ { # Bad type. 'schema': [ {'name': 'col1', 'type': 'badtype'} ] }, { # Bad type. Strictly upper and lower case are supported. 
'schema': [ {'name': 'col1', 'type': 'stRING'} ] }, { # Missing name. 'schema': [ {'type': 'string'} ] }, { # Missing type. 'schema': [ {'name': 'col1'} ] }, { # Fields should be an array. 'schema': [ {'name': 'col1', 'type': 'string', 'fields': 'badfields'} ] } ] for s in good_schemas: record = google.datalab.utils.commands.parse_config(json.dumps(s), {}) jsonschema.validate(record, bq.commands._bigquery.BigQuerySchema.TABLE_SCHEMA_SCHEMA) for s in bad_schemas: record = google.datalab.utils.commands.parse_config(json.dumps(s), {}) with self.assertRaises(Exception): jsonschema.validate(record, bq.commands._bigquery.table_schema_schema) @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.Context.default') def test_udf_cell(self, mock_default_context, mock_notebook_environment): env = {} mock_default_context.return_value = TestCases._create_context() mock_notebook_environment.return_value = env # no cell body with self.assertRaisesRegexp(Exception, 'UDF return type must be defined'): bq.commands._bigquery._udf_cell({'name': 'test_udf', 'language': 'js'}, '') # no name with self.assertRaisesRegexp(Exception, 'Declaration must be of the form %%bq udf --name'): bq.commands._bigquery._udf_cell({'name': None, 'language': 'js'}, 'test_cell_body') # no return type cell_body = """ // @param word STRING // @param corpus STRING re = new RegExp(word, 'g'); return corpus.match(re || []).length; """ with self.assertRaisesRegexp(Exception, 'UDF return type must be defined'): bq.commands._bigquery._udf_cell({'name': 'count_occurrences', 'language': 'js'}, cell_body) # too many return statements cell_body = """ // @param word STRING // @param corpus STRING // @returns INTEGER // @returns STRING re = new RegExp(word, 'g'); return corpus.match(re || []).length; """ with self.assertRaisesRegexp(Exception, 'Found more than one return'): bq.commands._bigquery._udf_cell({'name': 'count_occurrences', 'language': 'js'}, cell_body) cell_body = """ // @param word STRING // @param corpus STRING // @returns INTEGER re = new RegExp(word, 'g'); return corpus.match(re || []).length; """ bq.commands._bigquery._udf_cell({'name': 'count_occurrences', 'language': 'js'}, cell_body) udf = env['count_occurrences'] self.assertIsNotNone(udf) self.assertEquals('count_occurrences', udf._name) self.assertEquals('js', udf._language) self.assertEquals('INTEGER', udf._return_type) self.assertEquals([('word', 'STRING'), ('corpus', 'STRING')], udf._params) self.assertEquals([], udf._imports) # param types with spaces (regression for pull request 373) cell_body = """ // @param test_param ARRAY> // @returns INTEGER """ bq.commands._bigquery._udf_cell({'name': 'count_occurrences', 'language': 'js'}, cell_body) udf = env['count_occurrences'] self.assertIsNotNone(udf) self.assertEquals([('test_param', 'ARRAY>')], udf._params) @mock.patch('google.datalab.utils.commands.notebook_environment') def test_datasource_cell(self, mock_notebook_env): env = {} mock_notebook_env.return_value = env args = {'name': 'test_ds', 'paths': 'test_path', 'format': None, 'compressed': None} cell_body = { 'schema': [ {'name': 'col1', 'type': 'int64', 'mode': 'NULLABLE', 'description': 'description1'}, {'name': 'col1', 'type': 'STRING', 'mode': 'required', 'description': 'description1'} ] } bq.commands._bigquery._datasource_cell(args, json.dumps(cell_body)) self.assertIsInstance(env['test_ds'], bq.ExternalDataSource) self.assertEqual(env['test_ds']._source, ['test_path']) self.assertEqual(env['test_ds']._source_format, 'csv') 
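# The cell-magic tests below call the helpers in bq.commands._bigquery directly rather than going
# through IPython; each args dict mirrors the flags of the corresponding %%bq subcommand. Rough
# shape of a direct call, using the same argument names as test_query_cell below:
#
#   bq.commands._bigquery._query_cell(
#       {'name': 'q1', 'udfs': None, 'datasources': None, 'subqueries': None},
#       'SELECT * FROM test_table')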
@mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.Context.default') def test_query_cell(self, mock_default_context, mock_notebook_environment, mock_query_execute): env = {} mock_default_context.return_value = TestCases._create_context() mock_notebook_environment.return_value = env IPython.get_ipython().user_ns = env # test query creation q1_body = 'SELECT * FROM test_table' # no query name specified. should execute bq.commands._bigquery._query_cell( {'name': None, 'udfs': None, 'datasources': None, 'subqueries': None}, q1_body) mock_query_execute.assert_called_with() # test query creation bq.commands._bigquery._query_cell( {'name': 'q1', 'udfs': None, 'datasources': None, 'subqueries': None}, q1_body) mock_query_execute.assert_called_with() q1 = env['q1'] self.assertIsNotNone(q1) self.assertEqual(q1.udfs, {}) self.assertEqual(q1.subqueries, {}) self.assertEqual(q1_body, q1._sql) self.assertEqual(q1_body, q1.sql) # test subquery reference and expansion q2_body = 'SELECT * FROM q1' bq.commands._bigquery._query_cell( {'name': 'q2', 'udfs': None, 'datasources': None, 'subqueries': ['q1']}, q2_body) q2 = env['q2'] self.assertIsNotNone(q2) self.assertEqual(q2.udfs, {}) self.assertEqual({'q1': q1}, q2.subqueries) expected_sql = '''\ WITH q1 AS ( %s ) %s''' % (q1_body, q2_body) self.assertEqual(q2_body, q2._sql) self.assertEqual(expected_sql, q2.sql) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_execute_cell(self, mock_get_notebook_item, mock_query_execute, mock_default_context): args = {'query': 'test_query', 'verbose': None, 'to_dataframe': None, 'table': None, 'dataframe_start_row': None, 'dataframe_max_rows': None, 'nocache': None, 'mode': None, 'large': None} cell_body = '' mock_get_notebook_item.return_value = bq.Query('test_sql') bq.commands._bigquery._execute_cell(args, cell_body) args['to_dataframe'] = True bq.commands._bigquery._execute_cell(args, cell_body) # test --verbose args['verbose'] = True with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout: bq.commands._bigquery._execute_cell(args, cell_body) self.assertEqual(mocked_stdout.getvalue(), 'test_sql\n') args['verbose'] = False @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @mock.patch('google.datalab.bigquery.commands._bigquery._get_table') @mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_sample_cell(self, mock_get_notebook_item, mock_context_default, mock_query_execute, mock_get_table): args = {'query': None, 'table': None, 'view': None, 'fields': None, 'count': 5, 'percent': 1, 'key_field': None, 'order': None, 'profile': None, 'verbose': None, 'method': 'limit'} cell_body = '' with self.assertRaises(Exception): bq.commands._bigquery._sample_cell(args, cell_body) args['query'] = 'test_query' mock_get_notebook_item.return_value = None with self.assertRaises(Exception): bq.commands._bigquery._sample_cell(args, cell_body) # query passed, no other parameters mock_get_notebook_item.return_value = bq.Query('test_sql') bq.commands._bigquery._sample_cell(args, cell_body) call_args = mock_query_execute.call_args[0] call_kwargs = mock_query_execute.call_args[1] 
self.assertEqual(call_args[0]._output_type, 'table') self.assertEqual(call_kwargs['sampling']('test_sql'), bq._sampling.Sampling.default()('test_sql')) # test --profile args['profile'] = True mock_query_execute.return_value.result = lambda: pandas.DataFrame({'c': 'v'}, index=['c']) bq.commands._bigquery._sample_cell(args, cell_body) call_args = mock_query_execute.call_args[0] self.assertEqual(call_args[0]._output_type, 'dataframe') # test --verbose args['verbose'] = True with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout: bq.commands._bigquery._sample_cell(args, cell_body) self.assertEqual(mocked_stdout.getvalue(), 'test_sql\n') args['verbose'] = False # bad query mock_get_notebook_item.return_value = None with self.assertRaises(Exception): bq.commands._bigquery._sample_cell(args, cell_body) # table passed args['query'] = None args['table'] = 'test.table' mock_get_notebook_item.return_value = bq.Table('test.table') bq.commands._bigquery._sample_cell(args, cell_body) # bad table mock_get_table.return_value = None with self.assertRaises(Exception): bq.commands._bigquery._sample_cell(args, cell_body) # view passed args['table'] = None args['view'] = 'test_view' mock_get_notebook_item.return_value = bq.View('test.view') bq.commands._bigquery._sample_cell(args, cell_body) # bad view mock_get_notebook_item.return_value = None with self.assertRaises(Exception): bq.commands._bigquery._sample_cell(args, cell_body) @mock.patch('google.datalab.bigquery.Query.dry_run') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.commands._bigquery._get_query_argument') def test_dry_run_cell(self, mock_get_query_argument, mock_context_default, mock_dry_run): args = {'query': 'test_query'} cell_body = '' mock_get_query_argument.return_value = bq.Query('test_sql') # test --verbose args['verbose'] = True with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout: bq.commands._bigquery._dryrun_cell(args, cell_body) self.assertEqual(mocked_stdout.getvalue(), 'test_sql\n') args['verbose'] = False @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.utils._utils.get_credentials') @mock.patch('google.datalab.bigquery.Table.exists') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_get_table(self, mock_get_notebook_item, mock_table_exists, mock_get_credentials, mock_default_context): # test bad name mock_get_notebook_item.return_value = None mock_table_exists.return_value = False t = bq.commands._bigquery._get_table('bad.name') self.assertIsNone(t) # test good table name test_table_name = 'testproject.test.table' mock_get_notebook_item.return_value = bq.Table(test_table_name) t = bq.commands._bigquery._get_table(test_table_name) self.assertEqual(t.full_name, test_table_name) # test table name reference mock_get_notebook_item.return_value = test_table_name mock_table_exists.return_value = True t = bq.commands._bigquery._get_table(test_table_name) self.assertEqual(t.full_name, test_table_name) self.assertIn(test_table_name, bq.commands._bigquery._existing_table_cache) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') @mock.patch('google.datalab.bigquery.commands._bigquery._render_list') def test_dataset_line_list(self, mock_render_list, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': None, 'project': None} datasets = ['ds1', 'ds2', 'ds11'] mock_datasets.return_value = iter(datasets) bq.commands._bigquery._dataset_line(args) 
mock_render_list.assert_called_with(datasets) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') @mock.patch('google.datalab.bigquery.commands._bigquery._render_list') def test_dataset_line_list_asterisk(self, mock_render_list, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': '*', 'project': None} datasets = ['ds1', 'ds2', 'ds11'] mock_datasets.return_value = iter(datasets) bq.commands._bigquery._dataset_line(args) mock_render_list.assert_called_with(datasets) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') @mock.patch('google.datalab.bigquery.commands._bigquery._render_list') def test_dataset_line_list_substr_filter(self, mock_render_list, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': 'ds1*', 'project': None} datasets = ['ds1', 'ds2', 'ds11'] mock_datasets.return_value = iter(datasets) bq.commands._bigquery._dataset_line(args) mock_render_list.assert_called_with(['ds1', 'ds11']) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') @mock.patch('google.datalab.bigquery.commands._bigquery._render_list') def test_dataset_line_list_exact_filter(self, mock_render_list, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': 'ds1', 'project': None} datasets = ['ds1', 'ds2', 'ds11'] mock_datasets.return_value = iter(datasets) bq.commands._bigquery._dataset_line(args) mock_render_list.assert_called_with(['ds1']) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') @mock.patch('google.datalab.bigquery.commands._bigquery._render_list') def test_dataset_line_list_project(self, mock_render_list, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': None, 'project': 'testproject'} mock_default_context.return_value = self._create_context() bq.commands._bigquery._dataset_line(args) self.assertEqual(mock_datasets.call_args[0][0].project_id, 'testproject') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Dataset') def test_dataset_line_create(self, mock_dataset, mock_default_context): args = {'command': 'create', 'name': 'dataset-name', 'friendly': 'test-name'} bq.commands._bigquery._dataset_line(args) mock_dataset.assert_called_with('dataset-name') mock_dataset.return_value.create.assert_called_with(friendly_name='test-name') mock_dataset.side_effect = Exception('error') with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout: bq.commands._bigquery._dataset_line(args) self.assertIn('Failed to create dataset dataset-name', mocked_stdout.getvalue()) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Dataset') def test_dataset_line_delete(self, mock_dataset, mock_default_context): args = {'command': 'delete', 'name': 'dataset-name'} bq.commands._bigquery._dataset_line(args) mock_dataset.assert_called_with('dataset-name') mock_dataset.return_value.delete.assert_called_with() mock_dataset.side_effect = Exception('error') with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout: bq.commands._bigquery._dataset_line(args) self.assertIn('Failed to delete dataset dataset-name', mocked_stdout.getvalue()) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Datasets') def test_table_cell_list(self, mock_datasets, mock_default_context): args = {'command': 'list', 'filter': None, 'dataset': None, 'project': None} tables = 
[bq.Table('project.test.' + name) for name in ['t1', 't2', 't3']]
    ds1 = mock.MagicMock()
    ds1.__iter__.return_value = iter([tables[0], tables[1]])
    ds2 = mock.MagicMock()
    ds2.__iter__.return_value = iter([tables[2]])
    mock_datasets.return_value = iter([ds1, ds2])
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t2\n  • project.test.t3\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Datasets')
  def test_table_cell_list_asterisk(self, mock_datasets, mock_default_context):
    args = {'command': 'list', 'filter': '*', 'dataset': None, 'project': None}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2', 't3']]
    ds1 = mock.MagicMock()
    ds1.__iter__.return_value = iter([tables[0], tables[1]])
    ds2 = mock.MagicMock()
    ds2.__iter__.return_value = iter([tables[2]])
    mock_datasets.return_value = iter([ds1, ds2])
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t2\n  • project.test.t3\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Datasets')
  def test_table_cell_list_substr_filter(self, mock_datasets, mock_default_context):
    args = {'command': 'list', 'filter': '*t1*', 'dataset': None, 'project': None}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2', 't11']]
    ds1 = mock.MagicMock()
    ds1.__iter__.return_value = iter([tables[0], tables[1]])
    ds2 = mock.MagicMock()
    ds2.__iter__.return_value = iter([tables[2]])
    mock_datasets.return_value = iter([ds1, ds2])
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t11\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Datasets')
  def test_table_cell_list_bad_filter(self, mock_datasets, mock_default_context):
    args = {'command': 'list', 'filter': 't7', 'dataset': None, 'project': None}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2', 't11']]
    ds1 = mock.MagicMock()
    ds1.__iter__.return_value = iter([tables[0], tables[1]])
    ds2 = mock.MagicMock()
    ds2.__iter__.return_value = iter([tables[2]])
    mock_datasets.return_value = iter([ds1, ds2])
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n <empty>\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Dataset')
  def test_table_cell_list_dataset(self, mock_dataset, mock_default_context):
    args = {'command': 'list', 'filter': '', 'dataset': 'test-dataset', 'project': None}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2']]
    mock_dataset.return_value = iter(tables)
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t2\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Datasets')
  def test_table_cell_list_project(self, mock_datasets, mock_default_context):
    args = {'command': 'list', 'filter': '', 'dataset': None, 'project': 'test-project'}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2', 't3']]
    ds1 = mock.MagicMock()
    ds1.__iter__.return_value = iter([tables[0], tables[1]])
    ds2 = mock.MagicMock()
    ds2.__iter__.return_value = iter([tables[2]])
    mock_datasets.return_value = iter([ds1, ds2])
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t2\n  • project.test.t3\n')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Dataset')
  def test_table_cell_list_dataset_project(self, mock_dataset, mock_default_context):
    args = {'command': 'list', 'filter': '', 'dataset': 'test-dataset', 'project': 'test-project'}
    tables = [bq.Table('project.test.' + name) for name in ['t1', 't2']]
    mock_dataset.return_value = iter(tables)
    self.assertEqual(
        bq.commands._bigquery._table_cell(args, None),
        '\n  • project.test.t1\n  • project.test.t2\n')
    call_args = mock_dataset.call_args[0]
    self.assertEqual(call_args[0], 'test-dataset')
    self.assertEqual(call_args[1].project_id, 'test-project')

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Table')
  def test_table_cell_create_bad_params(self, mock_table, mock_default_context):
    args = {'command': 'create', 'name': 'test-table', 'overwrite': None}
    with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout:
      bq.commands._bigquery._table_cell(args, None)
    self.assertIn('Failed to create test-table: no schema', mocked_stdout.getvalue())
    mock_table.side_effect = Exception
    with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout:
      bq.commands._bigquery._table_cell(args, json.dumps({}))
    self.assertIn('\'schema\' is a required property', mocked_stdout.getvalue())

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Table')
  def test_table_cell_create(self, mock_table, mock_default_context):
    args = {'command': 'create', 'name': 'test-table', 'overwrite': None}
    cell_body = {
        'schema': [
            {'name': 'col1', 'type': 'int64', 'mode': 'NULLABLE', 'description': 'description1'},
            {'name': 'col1', 'type': 'STRING', 'mode': 'required', 'description': 'description1'}
        ]
    }
    bq.commands._bigquery._table_cell(args, json.dumps(cell_body))
    call_kwargs = mock_table.return_value.create.call_args[1]
    self.assertEqual(None, call_kwargs['overwrite'])
    self.assertEqual(bq.Schema(cell_body['schema']), call_kwargs['schema'])

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.commands._bigquery._get_table')
  def test_table_cell_describe(self, mock_get_table, mock_default_context):
    args = {'command': 'describe', 'name': 'test-table', 'overwrite': None}
    mock_get_table.return_value = None
    with self.assertRaisesRegexp(Exception, 'Could not find table'):
      bq.commands._bigquery._table_cell(args, None)
    mock_get_table.return_value = bq.Table('project.test.table')
    schema = bq.Schema([{'name': 'col1', 'type': 'string'}])
    mock_get_table.return_value._schema = schema
    rendered = bq.commands._bigquery._table_cell(args, None)
    expected_html1 = 'bq.renderSchema(dom, [{"type": "string", "name": "col1"}]);'
    expected_html2 = 'bq.renderSchema(dom, [{"name": "col1", "type": "string"}]);'
    self.assertTrue(expected_html1 in rendered or expected_html2 in rendered)

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Table')
  def test_table_cell_delete(self, mock_table, mock_default_context):
    args = {'command': 'delete', 'name': 'test-table'}
    mock_table.return_value.delete.side_effect = Exception
    with mock.patch('sys.stdout', new=six.StringIO()) as mocked_stdout:
      bq.commands._bigquery._table_cell(args, None)
    self.assertIn('Failed to delete table test-table', mocked_stdout.getvalue())

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.commands._bigquery._get_table')
  def test_table_cell_view(self, mock_get_table, mock_default_context):
    args = {'command': 'view', 'name': 'test-table'}
    table = bq.Table('project.test.table')
    mock_get_table.return_value = None
    with self.assertRaisesRegexp(Exception, 'Could not find table test-table'):
      bq.commands._bigquery._table_cell(args, None)
    mock_get_table.return_value = table
    self.assertEqual(table, bq.commands._bigquery._table_cell(args, None))

  @mock.patch('google.datalab.Context.default')
  @mock.patch('google.datalab.bigquery.Query.execute')
  @mock.patch('google.datalab.utils.commands.get_notebook_item')
  def test_extract_cell_query(self,
mock_get_notebook_item, mock_query_execute, mock_default_context): args = {'table': None, 'view': None, 'query': None, 'path': None, 'format': None, 'delimiter': None, 'header': None, 'compress': None, 'nocache': None} with self.assertRaisesRegexp(Exception, 'A query, table, or view is needed'): bq.commands._bigquery._extract_cell(args, None) args['query'] = 'test-query' mock_get_notebook_item.return_value = None with self.assertRaisesRegexp(Exception, 'Could not find query test-query'): bq.commands._bigquery._extract_cell(args, None) mock_get_notebook_item.return_value = bq.Query('sql') mock_query_execute.return_value.failed = True mock_query_execute.return_value.fatal_error = 'test-error' with self.assertRaisesRegexp(Exception, 'Extract failed: test-error'): bq.commands._bigquery._extract_cell(args, None) mock_query_execute.return_value.failed = False mock_query_execute.return_value.errors = 'test-errors' with self.assertRaisesRegexp(Exception, 'Extract completed with errors: test-errors'): bq.commands._bigquery._extract_cell(args, None) mock_query_execute.return_value.errors = None mock_query_execute.return_value.result = lambda: 'results' self.assertEqual(bq.commands._bigquery._extract_cell(args, None), 'results') cell_body = { 'parameters': [ {'name': 'arg1', 'type': 'INT64', 'value': 5} ] } bq.commands._bigquery._extract_cell(args, json.dumps(cell_body)) mock_get_notebook_item.assert_called_with('test-query') call_args = mock_query_execute.call_args[1] found_item = False for item in call_args['query_params']: if item['name'] == 'arg1': found_item = True self.assertDictEqual(item, { 'parameterValue': {'value': 5}, 'name': 'arg1', 'parameterType': {'type': 'INT64'} }) self.assertTrue(found_item) @mock.patch('google.datalab.bigquery.Table.extract') @mock.patch('google.datalab.bigquery.commands._bigquery._get_table') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_extract_cell_table(self, mock_get_notebook_item, mock_get_table, mock_table_extract): args = {'table': 'test-table', 'path': 'test-path', 'format': 'json', 'delimiter': None, 'header': None, 'compress': None, 'nocache': None} mock_get_table.return_value = None with self.assertRaisesRegexp(Exception, 'Could not find table test-table'): bq.commands._bigquery._extract_cell(args, None) mock_get_table.return_value = bq.Table('project.test.table', self._create_context()) mock_table_extract.return_value.result = lambda: 'test-results' mock_table_extract.return_value.failed = False mock_table_extract.return_value.errors = None self.assertEqual(bq.commands._bigquery._extract_cell(args, None), 'test-results') mock_table_extract.assert_called_with('test-path', format='json', csv_delimiter=None, csv_header=None, compress=None) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Query.execute') @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_extract_cell_view(self, mock_get_notebook_item, mock_query_execute, mock_default_context): args = {'view': 'test-view', 'table': None, 'query': None, 'path': 'test-path', 'format': None, 'delimiter': None, 'header': None, 'compress': None, 'nocache': None} mock_get_notebook_item.return_value = None with self.assertRaisesRegexp(Exception, 'Could not find view test-view'): bq.commands._bigquery._extract_cell(args, None) mock_get_notebook_item.return_value = bq.View('project.test.view', self._create_context()) mock_query_execute.return_value.result = lambda: 'test-results' mock_query_execute.return_value.failed = False 
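    # A clean execution (failed=False, errors=None) should surface the extract result directly.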
mock_query_execute.return_value.errors = None self.assertEqual(bq.commands._bigquery._extract_cell(args, None), 'test-results') mock_get_notebook_item.assert_called_with('test-view') @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.bigquery.Table.create') @mock.patch('google.datalab.bigquery.Table.exists') @mock.patch('google.datalab.bigquery.Table.load') @mock.patch('google.datalab.bigquery.commands._bigquery._get_table') def test_load_cell(self, mock_get_table, mock_table_load, mock_table_exists, mock_table_create, mock_default_context): args = {'table': 'project.test.table', 'mode': 'create', 'path': 'test/path_%(_ds)s', 'skip': None, 'csv': None, 'delimiter': None, 'format': 'csv', 'strict': None, 'quote': None} context = self._create_context() mock_get_table.return_value = bq.Table('project.test.table') job = bq._query_job.QueryJob('test_id', 'project.test.table', 'test_sql', context) mock_table_exists.return_value = True with self.assertRaisesRegexp(Exception, 'already exists; use "append" or "overwrite" as mode.'): bq.commands._bigquery._load_cell(args, None) mock_table_exists.return_value = False with self.assertRaisesRegexp(Exception, 'Table does not exist, and no schema specified'): bq.commands._bigquery._load_cell(args, None) cell_body = { 'schema': [ {'name': 'col1', 'type': 'int64', 'mode': 'NULLABLE', 'description': 'description1'}, {'name': 'col1', 'type': 'STRING', 'mode': 'required', 'description': 'description1'} ], 'parameters': [ {'name': 'custom', 'type': 'FLOAT', 'value': 4.23} ] } mock_table_load.return_value = job job._is_complete = True job._fatal_error = 'fatal error' with self.assertRaisesRegexp(Exception, 'Load failed: fatal error'): bq.commands._bigquery._load_cell(args, json.dumps(cell_body)) job._fatal_error = None job._errors = 'error' with self.assertRaisesRegexp(Exception, 'Load completed with errors: error'): bq.commands._bigquery._load_cell(args, json.dumps(cell_body)) job._errors = None bq.commands._bigquery._load_cell(args, json.dumps(cell_body)) today = datetime.now().date().isoformat() mock_table_load.assert_called_with('test/path_{0}'.format(today), mode='create', source_format='csv', csv_options=mock.ANY, ignore_unknown_values=True) mock_get_table.return_value = None mock_table_exists.return_value = True args['mode'] = 'append' args['format'] = 'csv' bq.commands._bigquery._load_cell(args, None) mock_table_load.assert_called_with('test/path_{0}'.format(today), mode='append', source_format='csv', csv_options=mock.ANY, ignore_unknown_values=True) @mock.patch('google.datalab.Context.default') @mock.patch('google.datalab.storage.Bucket') @mock.patch('google.datalab.utils.commands.get_notebook_item') @mock.patch('google.datalab.utils.commands.notebook_environment') def test_pipeline_cell(self, mock_env, mock_get_notebook_item, mock_bucket_class, mock_default_context): context = TestCases._create_context() mock_default_context.return_value = context mock_bucket_class.return_value = mock.Mock() mock_get_notebook_item.return_value = bq.Query( 'SELECT * FROM publicdata.samples.wikipedia LIMIT 5') small_cell_body = """ emails: foo1@test.com schedule: start: 2009-05-05T22:28:15Z end: 2009-05-06T22:28:15Z interval: '@hourly' input: table: project.test.table transformation: query: foo_query output: table: project.test.table """ args = {'name': 'bq_pipeline_test', 'gcs_dag_bucket': 'foo_bucket', 'gcs_dag_folder': 'dags'} actual = bq.commands._bigquery._pipeline_cell(args, small_cell_body) self.assertIn("successfully deployed", actual) 
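    # Without --debug the rendered Airflow spec (which carries the email list) is not echoed
    # back in the magic's output; with --debug it is.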
self.assertNotIn("'email': ['foo1@test.com']", actual) args['debug'] = True actual = bq.commands._bigquery._pipeline_cell(args, small_cell_body) self.assertIn("successfully deployed", actual) self.assertIn("'email': ['foo1@test.com']", actual) @mock.patch('google.datalab.utils.commands._html.Html.next_id') @mock.patch('google.datalab.utils.commands._html.HtmlBuilder.render_chart_data') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.utils.commands.get_data') @mock.patch('google.datalab.utils.commands.get_field_list') @mock.patch('google.datalab.bigquery.Table.exists') def test_table_viewer(self, mock_table_exists, mock_get_field_list, mock_get_data, mock_tables_get, mock_render_chart_data, mock_next_id): test_table = bq.Table('testproject.test.table', self._create_context()) mock_table_exists.return_value = False with self.assertRaisesRegexp(Exception, 'does not exist'): bq.commands._bigquery._table_viewer(test_table) mock_table_exists.return_value = True mock_get_field_list.return_value = ['col1'] mock_get_data.return_value = ({'cols': ['col1'], 'rows': ['val1']}, 1) mock_render_chart_data.return_value = 'test_chart_data' mock_next_id.return_value = 'test_id' viewer = bq.commands._bigquery._table_viewer(test_table) mock_table_exists.assert_called() mock_get_field_list.assert_called() mock_render_chart_data.assert_called() expected_html_header = '''
test_chart_data

(testproject.test.table)
''' self.assertIn(expected_html_header, viewer) @mock.patch('google.datalab.bigquery._query_stats.QueryStats._size_formatter') @mock.patch('google.datalab.bigquery.Table.job') @mock.patch('google.datalab.utils.commands._html.Html.next_id') @mock.patch('google.datalab.utils.commands._html.HtmlBuilder.render_chart_data') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.utils.commands.get_data') @mock.patch('google.datalab.utils.commands.get_field_list') @mock.patch('google.datalab.bigquery.Table.exists') def test_query_results_table_viewer(self, mock_table_exists, mock_get_field_list, mock_get_data, mock_tables_get, mock_render_chart_data, mock_next_id, mock_table_job, mock_size_formatter): context = self._create_context() table_name = 'testproject.test.table' job = bq._query_job.QueryJob('test_id', table_name, 'test_sql', context) job._start_time, job._end_time = datetime(2017, 1, 1, 1, 1), datetime(2017, 1, 1, 1, 2) test_table = bq.QueryResultsTable(table_name, context, job) mock_table_exists.return_value = True mock_get_field_list.return_value = ['col1'] mock_get_data.return_value = ({'cols': ['col1'], 'rows': ['val1']}, 1) mock_next_id.return_value = 'test_id' mock_size_formatter.return_value = '10MB' mock_render_chart_data.return_value = 'test_chart_data' viewer = bq.commands._bigquery._table_viewer(test_table) mock_table_exists.assert_called() mock_get_field_list.assert_called() mock_render_chart_data.assert_called() expected_html_header = '''
test_chart_data

(time: 60.0s, 10MB processed, job: test_id)
''' self.assertIn(expected_html_header, viewer) job._cache_hit = True viewer = bq.commands._bigquery._table_viewer(test_table) expected_html_header = '''
test_chart_data

(time: 60.0s, cached, job: test_id)
''' self.assertIn(expected_html_header, viewer) mock_get_data.return_value = ({'rows': []}, -1) viewer = bq.commands._bigquery._table_viewer(test_table) expected_html_header = 'pageSize: 25,' self.assertIn(expected_html_header, viewer) mock_get_data.return_value = ({'rows': ['val'] * 5}, -1) viewer = bq.commands._bigquery._table_viewer(test_table, rows_per_page=10) expected_html_header = 'pageSize: 10,' self.assertIn(expected_html_header, viewer) expected_html_footer = ''' {source_index: 0, fields: 'col1'}, 0, 5); ''' self.assertIn(expected_html_footer, viewer) @mock.patch('google.datalab.utils._utils.get_credentials') @mock.patch('google.datalab.utils._utils.get_default_project_id') @mock.patch('google.datalab.utils._utils.save_project_id') def test_args_to_context(self, mock_save_project, mock_get_default_project, mock_get_credentials): mock_get_credentials.return_value = 'test_creds' mock_get_default_project.return_value = 'testproject' args = {'billing': 'billing_value'} default_context = google.datalab.Context.default() c = google.datalab.utils._utils._construct_context_for_args(args) # make sure it's not the same object self.assertNotEqual(c, default_context) self.assertEqual(c.project_id, default_context.project_id) self.assertEqual(c.credentials, default_context.credentials) # make sure the right config object was passed self.assertEqual(c.config, {'bigquery_billing_tier': 'billing_value'}) default_context.config['test_prop'] = 'test_val' c = google.datalab.utils._utils._construct_context_for_args(args) # make sure other properties in default context were copied self.assertEqual(c.config, {'bigquery_billing_tier': 'billing_value', 'test_prop': 'test_val'}) @mock.patch('google.datalab.utils.commands.get_notebook_item') def test_get_query_argument(self, mock_get_notebook_item): args = {} cell = None env = {} # an Exception should be raised if no query is specified by name or body with self.assertRaises(Exception): bq.commands._bigquery._get_query_argument(args, cell, env) # specify query name, no cell body args = {'query': 'test_query'} mock_get_notebook_item.return_value = bq.Query('test_sql') q = bq.commands._bigquery._get_query_argument(args, cell, env) self.assertEqual(q.sql, 'test_sql') # specify query in cell body, no name args = {} cell = 'test_sql2' q = bq.commands._bigquery._get_query_argument(args, cell, env) self.assertEqual(q.sql, 'test_sql2') # specify query by bad name args = {'query': 'test_query'} mock_get_notebook_item.return_value = None with self.assertRaises(Exception): bq.commands._bigquery._get_query_argument(args, cell, env) def test_get_query_parameters(self): args = {'query': None} cell_body = '' now = datetime.now() with self.assertRaises(Exception): bq.commands._bigquery.get_query_parameters(args, json.dumps(cell_body)) args['query'] = 'test_sql' params = bq.commands._bigquery.get_query_parameters(args, json.dumps(cell_body), date_time=now) # We push the params into a dict so that it's easier to compare params_dict = { item['name']: { 'type': item['parameterType']['type'], 'value': item['parameterValue']['value'] } for item in params } today = now.date() default_query_parameters = { # the datetime formatted as YYYY-MM-DD '_ds': {'type': 'STRING', 'value': today.isoformat()}, # the full ISO-formatted timestamp YYYY-MM-DDTHH:MM:SS.mmmmmm '_ts': {'type': 'STRING', 'value': now.isoformat()}, # the datetime formatted as YYYYMMDD (i.e. 
YYYY-MM-DD with 'no dashes') '_ds_nodash': {'type': 'STRING', 'value': today.strftime('%Y%m%d')}, # the timestamp formatted as YYYYMMDDTHHMMSSmmmmmm (i.e full ISO-formatted timestamp # YYYY-MM-DDTHH:MM:SS.mmmmmm with no dashes or colons). '_ts_nodash': {'type': 'STRING', 'value': now.strftime('%Y%m%d%H%M%S%f')}, '_ts_year': {'type': 'STRING', 'value': today.strftime('%Y')}, '_ts_month': {'type': 'STRING', 'value': today.strftime('%m')}, '_ts_day': {'type': 'STRING', 'value': today.strftime('%d')}, '_ts_hour': {'type': 'STRING', 'value': now.strftime('%H')}, '_ts_minute': {'type': 'STRING', 'value': now.strftime('%M')}, '_ts_second': {'type': 'STRING', 'value': now.strftime('%S')}, } self.assertDictEqual(params_dict, default_query_parameters) cell_body = { 'parameters': [ {'name': 'arg1', 'type': 'INT64', 'value': 5}, {'name': 'arg2', 'type': 'string', 'value': 'val2'}, {'name': 'arg3', 'type': 'date', 'value': 'val3'} ] } params = bq.commands._bigquery.get_query_parameters(args, json.dumps(cell_body), date_time=now) # We push the params into a dict so that it's easier to compare params_dict = { item['name']: { 'type': item['parameterType']['type'], 'value': item['parameterValue']['value'] } for item in params } cell_body_params_dict = { item['name']: { 'type': item['type'], 'value': item['value'] } for item in cell_body['parameters'] } default_query_parameters.update(cell_body_params_dict) self.assertDictEqual(params_dict, default_query_parameters) ================================================ FILE: tests/kernel/chart_data_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import json import mock import unittest # import Python so we can mock the parts we need to here. 
import IPython.core.display import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.core.display.HTML = lambda x: x IPython.core.display.JSON = lambda x: x import google.datalab.utils.commands # noqa class TestCases(unittest.TestCase): @mock.patch('google.datalab.utils.get_item') def test_get_chart_data(self, mock_get_item): IPython.get_ipython().user_ns = {} t = [ {'country': 'US', 'quantity': 100}, {'country': 'ZA', 'quantity': 50}, {'country': 'UK', 'quantity': 75}, {'country': 'AU', 'quantity': 25} ] mock_get_item.return_value = t ds = google.datalab.utils.commands.get_data_source_index('t') data = google.datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 1, 'count': 1 })) self.assertEquals({"data": {"rows": [{"c": [{"v": "ZA"}]}], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) data = google.datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 6, 'count': 1 })) self.assertEquals({"data": {"rows": [], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) data = google.datalab.utils.commands._chart_data._get_chart_data('', json.dumps({ 'source_index': ds, 'fields': 'country', 'first': 2, 'count': 0 })) self.assertEquals({"data": {"rows": [], "cols": [{"type": "string", "id": "country", "label": "country"}]}, "refresh_interval": 0, "options": {}}, data) ================================================ FILE: tests/kernel/chart_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest # import Python so we can mock the parts we need to here. 
import IPython.core.display import IPython.core.magic def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.core.display.HTML = lambda x: x IPython.core.display.JSON = lambda x: x import google.datalab.utils.commands # noqa class TestCases(unittest.TestCase): def test_chart_cell(self): t = [{'country': 'US', 'quantity': 100}, {'country': 'ZA', 'quantity': 50}] IPython.get_ipython().user_ns = {} chart = google.datalab.utils.commands._chart._chart_cell({'chart': 'geo', 'data': t, 'fields': None}, '') self.assertTrue(chart.find('charts.render(') > 0) self.assertTrue(chart.find('\'geo\'') > 0) self.assertTrue(chart.find('"fields": "*"') > 0) self.assertTrue(chart.find('{"c": [{"v": "US"}, {"v": 100}]}') > 0 or chart.find('{"c": [{"v": 100}, {"v": "US"}]}') > 0) self.assertTrue(chart.find('{"c": [{"v": "ZA"}, {"v": 50}]}') > 0 or chart.find('{"c": [{"v": 50}, {"v": "ZA"}]}') > 0) def test_chart_magic(self): # TODO(gram): complete this test pass ================================================ FILE: tests/kernel/html_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic import google.datalab def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): def test_render_table(self): builder = google.datalab.utils.commands.HtmlBuilder() builder._render_objects({ 'cols': [ {'label': 'col1'}, {'label': 'col2'}, ], 'rows': [ {'c': [ {'v': 'val1'}, {'v': 'val2'} ]}, {'c': [ {'v': 'val3'}, {'v': 'val4'} ]} ] }, ['col1', 'col2'], 'chartdata') expected_html = ''.join('''
col1 col2
val1 val2
val3 val4
'''.split()) self.assertEqual(builder._to_html(), expected_html) def test_render_text(self): # TODO(gram): complete this test pass ================================================ FILE: tests/kernel/pipeline_tests.py ================================================ # Copyright 2017 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic import mock import google.auth import google.datalab.contrib.pipeline.commands._pipeline def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() class TestCases(unittest.TestCase): # test pipeline creation sample_cell_body = """ schedule: start: 2009-05-05T22:28:15Z end: 2009-05-06T22:28:15Z interval: '@hourly' tasks: print_pdt_date: type: bash bash_command: date print_utc_date: type: bash bash_command: date -u up_stream: - print_pdt_date """ @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.Context.default') def test_create_cell_no_name(self, mock_default_context, mock_notebook_environment): env = {} mock_default_context.return_value = TestCases._create_context() mock_notebook_environment.return_value = env IPython.get_ipython().user_ns = env # test pipeline creation p_body = 'foo' # no pipeline name specified. 
should execute with self.assertRaises(Exception): google.datalab.contrib.pipeline.commands._pipeline._create_cell({'name': None}, p_body) @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.Context.default') def test_create_cell_debug(self, mock_default_context, mock_notebook_environment): env = {} mock_default_context.return_value = TestCases._create_context() mock_notebook_environment.return_value = env IPython.get_ipython().user_ns = env # cell output is empty when debug is True output = google.datalab.contrib.pipeline.commands._pipeline._create_cell( {'name': 'foo_pipeline', 'debug': True}, self.sample_cell_body) self.assertTrue(len(output) > 0) output = google.datalab.contrib.pipeline.commands._pipeline._create_cell( {'name': 'foo_pipeline', 'debug': False}, self.sample_cell_body) self.assertTrue(output is None) output = google.datalab.contrib.pipeline.commands._pipeline._create_cell( {'name': 'foo_pipeline'}, self.sample_cell_body) self.assertTrue(output is None) @mock.patch('google.datalab.utils.commands.notebook_environment') @mock.patch('google.datalab.Context.default') def test_create_cell_golden(self, mock_default_context, mock_notebook_environment): # This import is required by the test to run successfully because we dynamically check the # imports to instantiate the class-type. from airflow.operators.bash_operator import BashOperator # noqa mock_default_context.return_value = TestCases._create_context() env = {} env['foo_query'] = google.datalab.bigquery.Query( 'SELECT * FROM publicdata.samples.wikipedia LIMIT 5') mock_notebook_environment.return_value = env # TODO(rajivpb): Possibly not necessary IPython.get_ipython().user_ns = env # test pipeline creation p_body = """ emails: foo1@test.com,foo2@test.com schedule: start: 2009-05-05T22:28:15Z end: 2009-05-06T22:28:15Z interval: '@hourly' tasks: foo_task_1: type: BigQuery query: $foo_query foo_task_2: type: Bash bash_command: date foo_task_3: type: Bash bash_command: date -u up_stream: - print_pdt_date """ spec = google.datalab.contrib.pipeline.commands._pipeline._create_cell( {'name': 'p1', 'debug': True}, p_body) self.assertEqual(spec, """ import datetime from airflow import DAG from airflow.operators.bash_operator import BashOperator from airflow.contrib.operators.bigquery_operator import BigQueryOperator from airflow.contrib.operators.bigquery_table_delete_operator import BigQueryTableDeleteOperator from airflow.contrib.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator from airflow.contrib.operators.bigquery_to_gcs import BigQueryToCloudStorageOperator from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator from google.datalab.contrib.bigquery.operators._bq_load_operator import LoadOperator from google.datalab.contrib.bigquery.operators._bq_execute_operator import ExecuteOperator from google.datalab.contrib.bigquery.operators._bq_extract_operator import ExtractOperator from datetime import timedelta default_args = { 'owner': 'Google Cloud Datalab', 'email': ['foo1@test.com', 'foo2@test.com'], 'start_date': datetime.datetime.strptime('2009-05-05T22:28:15', '%Y-%m-%dT%H:%M:%S'), 'end_date': datetime.datetime.strptime('2009-05-06T22:28:15', '%Y-%m-%dT%H:%M:%S'), } dag = DAG(dag_id='p1', schedule_interval='@hourly', catchup=False, default_args=default_args) foo_task_1 = BigQueryOperator(task_id='foo_task_1_id', bql=\"\"\"SELECT * FROM publicdata.samples.wikipedia LIMIT 5\"\"\", use_legacy_sql=False, dag=dag) foo_task_2 = 
BashOperator(task_id='foo_task_2_id', bash_command=\"\"\"date\"\"\", dag=dag) foo_task_3 = BashOperator(task_id='foo_task_3_id', bash_command=\"\"\"date -u\"\"\", dag=dag) foo_task_3.set_upstream(print_pdt_date) """ # noqa ) ================================================ FILE: tests/kernel/storage_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. from __future__ import absolute_import from __future__ import unicode_literals import mock import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic import google.auth def noop_decorator(func): return func IPython.core.magic.register_line_cell_magic = noop_decorator IPython.core.magic.register_line_magic = noop_decorator IPython.core.magic.register_cell_magic = noop_decorator IPython.get_ipython = mock.Mock() import google.datalab # noqa import google.datalab.storage # noqa import google.datalab.storage.commands # noqa class TestCases(unittest.TestCase): @mock.patch('google.datalab.storage._object.Object.exists', autospec=True) @mock.patch('google.datalab.storage._bucket.Bucket.objects', autospec=True) @mock.patch('google.datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('google.datalab.Context.default') def test_expand_list(self, mock_context_default, mock_api_objects_get, mock_bucket_objects, mock_object_exists): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for testing for object existence. Fail if called with name that includes wild char. def object_exists_side_effect(*args, **kwargs): return args[0].key.find('*') < 0 mock_object_exists.side_effect = object_exists_side_effect # Mock API for getting objects in a bucket. mock_bucket_objects.side_effect = TestCases._mock_bucket_objects_return(context) # Mock API for getting object metadata. mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() objects = google.datalab.storage.commands._storage._expand_list(None) self.assertEqual([], objects) objects = google.datalab.storage.commands._storage._expand_list([]) self.assertEqual([], objects) objects = google.datalab.storage.commands._storage._expand_list('gs://bar/o*') self.assertEqual(['gs://bar/object1', 'gs://bar/object3'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://foo', 'gs://bar']) self.assertEqual(['gs://foo', 'gs://bar'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://foo/*', 'gs://bar']) self.assertEqual(['gs://foo/object1', 'gs://foo/object2', 'gs://foo/object3', 'gs://bar'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://bar/o*']) self.assertEqual(['gs://bar/object1', 'gs://bar/object3'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://bar/i*']) # Note - if no match we return the pattern. 
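    # (The mocked gs://bar bucket has no object matching 'i*', so the wildcard comes back as-is.)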
self.assertEqual(['gs://bar/i*'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://baz']) self.assertEqual(['gs://baz'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://baz/*']) self.assertEqual(['gs://baz/*'], objects) objects = google.datalab.storage.commands._storage._expand_list(['gs://foo/o*3']) self.assertEqual(['gs://foo/object3'], objects) @mock.patch('google.datalab.storage._object.Object.copy_to', autospec=True) @mock.patch('google.datalab.storage._bucket.Bucket.objects', autospec=True) @mock.patch('google.datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('google.datalab.Context.default') def test_gcs_copy(self, mock_context_default, mock_api_objects_get, mock_bucket_objects, mock_gcs_object_copy_to): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for getting objects in a bucket. mock_bucket_objects.side_effect = TestCases._mock_bucket_objects_return(context) # Mock API for getting object metadata. mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() google.datalab.storage.commands._storage._gcs_copy({ 'source': ['gs://foo/object1'], 'destination': 'gs://foo/bar1' }, None) mock_gcs_object_copy_to.assert_called_with(mock.ANY, 'bar1', bucket='foo') self.assertEquals('object1', mock_gcs_object_copy_to.call_args[0][0].key) self.assertEquals('foo', mock_gcs_object_copy_to.call_args[0][0]._bucket) with self.assertRaises(Exception) as error: google.datalab.storage.commands._storage._gcs_copy({ 'source': ['gs://foo/object*'], 'destination': 'gs://foo/bar1' }, None) self.assertEqual('More than one source but target gs://foo/bar1 is not a bucket', str(error.exception)) @mock.patch('google.datalab.storage.commands._storage._gcs_copy', autospec=True) def test_gcs_copy_magic(self, mock_gcs_copy): google.datalab.storage.commands._storage.gcs('copy --source gs://foo/object1 ' '--destination gs://foo/bar1') mock_gcs_copy.assert_called_with({ 'source': ['gs://foo/object1'], 'destination': 'gs://foo/bar1', 'func': google.datalab.storage.commands._storage._gcs_copy }, None) @mock.patch('google.datalab.storage._api.Api.buckets_insert', autospec=True) @mock.patch('google.datalab.Context.default') def test_gcs_create(self, mock_context_default, mock_api_buckets_insert): context = TestCases._create_context() mock_context_default.return_value = context errs = google.datalab.storage.commands._storage._gcs_create({ 'project': 'test', 'bucket': [ 'gs://baz' ] }, None) self.assertEqual(None, errs) mock_api_buckets_insert.assert_called_with(mock.ANY, 'baz', project_id='test') with self.assertRaises(Exception) as error: google.datalab.storage.commands._storage._gcs_create({ 'project': 'test', 'bucket': [ 'gs://foo/bar' ] }, None) self.assertEqual("Couldn't create gs://foo/bar: Invalid bucket name gs://foo/bar", str(error.exception)) @mock.patch('google.datalab.storage._api.Api.objects_list', autospec=True) @mock.patch('google.datalab.storage._api.Api.buckets_get', autospec=True) @mock.patch('google.datalab.storage._api.Api.objects_get', autospec=True) @mock.patch('google.datalab.storage._bucket.Bucket.objects', autospec=True) @mock.patch('google.datalab.storage._api.Api.objects_delete', autospec=True) @mock.patch('google.datalab.storage._api.Api.buckets_delete', autospec=True) @mock.patch('google.datalab.Context.default') def test_gcs_delete(self, mock_context_default, mock_api_bucket_delete, mock_api_objects_delete, mock_bucket_objects, mock_api_objects_get, 
mock_api_buckets_get, mock_api_objects_list): context = TestCases._create_context() mock_context_default.return_value = context # Mock API for getting objects in a bucket. mock_bucket_objects.side_effect = TestCases._mock_bucket_objects_return(context) # Mock API for getting object metadata. mock_api_objects_get.side_effect = TestCases._mock_api_objects_get() mock_api_buckets_get.side_effect = TestCases._mock_api_buckets_get() # Mock API for listing objects in a bucket. mock_api_objects_list.side_effect = {} with self.assertRaises(Exception) as error: google.datalab.storage.commands._storage._gcs_delete({ 'bucket': [ 'gs://bar', 'gs://baz' ], 'object': [ 'gs://foo/object1', 'gs://baz/object1', ] }, None) self.assertEqual('gs://baz does not exist\ngs://baz/object1 does not exist', str(error.exception)) mock_api_bucket_delete.assert_called_with(mock.ANY, 'bar') mock_api_objects_delete.assert_called_with(mock.ANY, 'foo', 'object1') @mock.patch('google.datalab.Context.default') def test_gcs_view(self, mock_context_default): context = TestCases._create_context() mock_context_default.return_value = context # TODO(gram): complete this test @mock.patch('google.datalab.Context.default') def test_gcs_write(self, mock_context_default): context = TestCases._create_context() mock_context_default.return_value = context # TODO(gram): complete this test @staticmethod def _create_context(): project_id = 'test' creds = mock.Mock(spec=google.auth.credentials.Credentials) return google.datalab.Context(project_id, creds) @staticmethod def _mock_bucket_objects_return(context): # Mock API for getting objects in a bucket. def bucket_objects_side_effect(*args, **kwargs): bucket = args[0].name # self if bucket == 'foo': return [ google.datalab.storage._object.Object(bucket, 'object1', context=context), google.datalab.storage._object.Object(bucket, 'object2', context=context), google.datalab.storage._object.Object(bucket, 'object3', context=context), ] elif bucket == 'bar': return [ google.datalab.storage._object.Object(bucket, 'object1', context=context), google.datalab.storage._object.Object(bucket, 'object3', context=context), ] else: return [] return bucket_objects_side_effect @staticmethod def _mock_api_objects_get(): # Mock API for getting object metadata. def api_objects_get_side_effect(*args, **kwargs): if args[1].find('baz') >= 0: return None key = args[2] if key.find('*') >= 0: return None return {'name': key} return api_objects_get_side_effect @staticmethod def _mock_api_buckets_get(): # Mock API for getting bucket metadata. def api_buckets_get_side_effect(*args, **kwargs): key = args[1] if key.find('*') >= 0 or key.find('baz') >= 0: return None return {'name': key} return api_buckets_get_side_effect ================================================ FILE: tests/kernel/utils_tests.py ================================================ # Copyright 2015 Google Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software distributed under the License # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
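# Tests for google.datalab.utils.commands._utils: converting lists, dicts, DataFrames and
# BigQuery tables into the {'cols': ..., 'rows': ...} chart data structure, plus the config
# validation helpers.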
from __future__ import absolute_import from __future__ import unicode_literals from builtins import range import datetime as dt import collections import mock import pandas import unittest # import Python so we can mock the parts we need to here. import IPython import IPython.core.magic import google.auth IPython.core.magic.register_line_cell_magic = mock.Mock() IPython.core.magic.register_line_magic = mock.Mock() IPython.core.magic.register_cell_magic = mock.Mock() IPython.get_ipython = mock.Mock() import google.datalab # noqa import google.datalab.bigquery # noqa import google.datalab.utils.commands._utils as _utils # noqa class TestCases(unittest.TestCase): @staticmethod def _get_expected_cols(): cols = [ {'type': 'number', 'id': 'Column1', 'label': 'Column1'}, {'type': 'number', 'id': 'Column2', 'label': 'Column2'}, {'type': 'string', 'id': 'Column3', 'label': 'Column3'}, {'type': 'boolean', 'id': 'Column4', 'label': 'Column4'}, {'type': 'number', 'id': 'Column5', 'label': 'Column5'}, {'type': 'timestamp', 'id': 'Column6', 'label': 'Column6'} ] return cols @staticmethod def _timestamp(d): return (d - dt.datetime(1970, 1, 1)).total_seconds() @staticmethod def _get_raw_rows(): rows = [ {'f': [ {'v': 1}, {'v': 2}, {'v': '3'}, {'v': 'true'}, {'v': 0.0}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 1))} ]}, {'f': [ {'v': 11}, {'v': 12}, {'v': '13'}, {'v': 'false'}, {'v': 0.2}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 2))} ]}, {'f': [ {'v': 21}, {'v': 22}, {'v': '23'}, {'v': 'true'}, {'v': 0.3}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 3))} ]}, {'f': [ {'v': 31}, {'v': 32}, {'v': '33'}, {'v': 'false'}, {'v': 0.4}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 4))} ]}, {'f': [ {'v': 41}, {'v': 42}, {'v': '43'}, {'v': 'true'}, {'v': 0.5}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 5))} ]}, {'f': [ {'v': 51}, {'v': 52}, {'v': '53'}, {'v': 'true'}, {'v': 0.6}, {'v': TestCases._timestamp(dt.datetime(2000, 1, 6))} ]} ] return rows @staticmethod def _get_expected_rows(): rows = [ {'c': [ {'v': 1}, {'v': 2}, {'v': '3'}, {'v': True}, {'v': 0.0}, {'v': dt.datetime(2000, 1, 1)} ]}, {'c': [ {'v': 11}, {'v': 12}, {'v': '13'}, {'v': False}, {'v': 0.2}, {'v': dt.datetime(2000, 1, 2)} ]}, {'c': [ {'v': 21}, {'v': 22}, {'v': '23'}, {'v': True}, {'v': 0.3}, {'v': dt.datetime(2000, 1, 3)} ]}, {'c': [ {'v': 31}, {'v': 32}, {'v': '33'}, {'v': False}, {'v': 0.4}, {'v': dt.datetime(2000, 1, 4)} ]}, {'c': [ {'v': 41}, {'v': 42}, {'v': '43'}, {'v': True}, {'v': 0.5}, {'v': dt.datetime(2000, 1, 5)} ]}, {'c': [ {'v': 51}, {'v': 52}, {'v': '53'}, {'v': True}, {'v': 0.6}, {'v': dt.datetime(2000, 1, 6)} ]} ] return rows @staticmethod def _get_test_data_as_list_of_dicts(): test_data = [ {'Column1': 1, 'Column2': 2, 'Column3': '3', 'Column4': True, 'Column5': 0.0, 'Column6': dt.datetime(2000, 1, 1)}, {'Column1': 11, 'Column2': 12, 'Column3': '13', 'Column4': False, 'Column5': 0.2, 'Column6': dt.datetime(2000, 1, 2)}, {'Column1': 21, 'Column2': 22, 'Column3': '23', 'Column4': True, 'Column5': 0.3, 'Column6': dt.datetime(2000, 1, 3)}, {'Column1': 31, 'Column2': 32, 'Column3': '33', 'Column4': False, 'Column5': 0.4, 'Column6': dt.datetime(2000, 1, 4)}, {'Column1': 41, 'Column2': 42, 'Column3': '43', 'Column4': True, 'Column5': 0.5, 'Column6': dt.datetime(2000, 1, 5)}, {'Column1': 51, 'Column2': 52, 'Column3': '53', 'Column4': True, 'Column5': 0.6, 'Column6': dt.datetime(2000, 1, 6)} ] # Use OrderedDicts to make testing the result easier. 
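    # (Sorting the keys gives a deterministic column order that matches _get_expected_cols.)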
for i in range(0, len(test_data)): test_data[i] = collections.OrderedDict(sorted(list(test_data[i].items()), key=lambda t: t[0])) return test_data def test_get_data_from_list_of_dicts(self): self._test_get_data(TestCases._get_test_data_as_list_of_dicts(), TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils._get_data_from_list_of_dicts) self._test_get_data(TestCases._get_test_data_as_list_of_dicts(), TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils.get_data) def test_get_data_from_list_of_lists(self): test_data = [ [1, 2, '3', True, 0.0, dt.datetime(2000, 1, 1)], [11, 12, '13', False, 0.2, dt.datetime(2000, 1, 2)], [21, 22, '23', True, 0.3, dt.datetime(2000, 1, 3)], [31, 32, '33', False, 0.4, dt.datetime(2000, 1, 4)], [41, 42, '43', True, 0.5, dt.datetime(2000, 1, 5)], [51, 52, '53', True, 0.6, dt.datetime(2000, 1, 6)], ] self._test_get_data(test_data, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils._get_data_from_list_of_lists) self._test_get_data(test_data, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils.get_data) def test_get_data_from_dataframe(self): df = pandas.DataFrame(self._get_test_data_as_list_of_dicts()) self._test_get_data(df, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils._get_data_from_dataframe) self._test_get_data(df, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils.get_data) @mock.patch('google.datalab.bigquery._api.Api.tabledata_list') @mock.patch('google.datalab.bigquery._table.Table.exists') @mock.patch('google.datalab.bigquery._api.Api.tables_get') @mock.patch('google.datalab.Context.default') def test_get_data_from_table(self, mock_context_default, mock_api_tables_get, mock_table_exists, mock_api_tabledata_list): data = TestCases._get_expected_rows() mock_context_default.return_value = TestCases._create_context() mock_api_tables_get.return_value = { 'numRows': len(data), 'schema': { 'fields': [ {'name': 'Column1', 'type': 'INTEGER'}, {'name': 'Column2', 'type': 'INTEGER'}, {'name': 'Column3', 'type': 'STRING'}, {'name': 'Column4', 'type': 'BOOLEAN'}, {'name': 'Column5', 'type': 'FLOAT'}, {'name': 'Column6', 'type': 'TIMESTAMP'} ] } } mock_table_exists.return_value = True raw_data = self._get_raw_rows() def tabledata_list(*args, **kwargs): start_index = kwargs['start_index'] max_results = kwargs['max_results'] if max_results < 0: max_results = len(data) return {'rows': raw_data[start_index:start_index + max_results]} mock_api_tabledata_list.side_effect = tabledata_list t = google.datalab.bigquery.Table('foo.bar') self._test_get_data(t, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils._get_data_from_table) self._test_get_data(t, TestCases._get_expected_cols(), TestCases._get_expected_rows(), 6, _utils.get_data) def test_get_data_from_empty_list(self): self._test_get_data([], [], [], 0, _utils.get_data) def test_get_data_from_malformed_list(self): with self.assertRaises(Exception) as error: self._test_get_data(['foo', 'bar'], [], [], 0, _utils.get_data) self.assertEquals('To get tabular data from a list it must contain dictionaries or lists.', str(error.exception)) def _test_get_data(self, test_data, cols, rows, expected_count, fn): self.maxDiff = None data, count = fn(test_data) self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows}, data) # Test first_row. Note that count must be set in this case so we use a value greater than the # data set size. 
for first in range(0, 6): data, count = fn(test_data, first_row=first, count=10) self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows[first:]}, data) # Test first_row + count for first in range(0, 6): data, count = fn(test_data, first_row=first, count=2) self.assertEquals(expected_count, count) self.assertEquals({'cols': cols, 'rows': rows[first:first + 2]}, data) # Test subsets of columns # No columns data, count = fn(test_data, fields=[]) self.assertEquals({'cols': [], 'rows': [{'c': []}] * expected_count}, data) # Single column data, count = fn(test_data, fields=['Column3']) if expected_count == 0: return self.assertEquals({'cols': [cols[2]], 'rows': [{'c': [row['c'][2]]} for row in rows]}, data) # Multi-columns data, count = fn(test_data, fields=['Column1', 'Column3', 'Column6']) self.assertEquals({'cols': [cols[0], cols[2], cols[5]], 'rows': [{'c': [row['c'][0], row['c'][2], row['c'][5]]} for row in rows]}, data) # Switch order data, count = fn(test_data, fields=['Column3', 'Column1']) self.assertEquals({'cols': [cols[2], cols[0]], 'rows': [{'c': [row['c'][2], row['c'][0]]} for row in rows]}, data) # Select all data, count = fn(test_data, fields=['Column1', 'Column2', 'Column3', 'Column4', 'Column5', 'Column6']) self.assertEquals({'cols': cols, 'rows': rows}, data) def test_expand_var(self): env = {'var': 'test-var-value'} resolved = _utils.expand_var('$var', env) self.assertEqual(resolved, env['var']) resolved = _utils.expand_var('', env) self.assertEqual(resolved, '') with self.assertRaisesRegexp(Exception, 'Cannot expand variable'): _utils.expand_var('$badname', env) def test_replace_vars(self): config = {'var': '$value'} env = {'value': 5} _utils.replace_vars(config, env) self.assertEqual(config, {'var': 5}) config = ['$value'] _utils.replace_vars(config, env) self.assertEqual(config, [5]) config = ({'var1': '$value1'}, ['$value2']) env = {'value1': 5, 'value2': 'stringvalue'} _utils.replace_vars(config, env) self.assertEqual(config, ({'var1': 5}, ['stringvalue'])) def test_validate_config(self): with self.assertRaisesRegexp(Exception, 'config is not dict type'): _utils.validate_config([], []) config = {'key1': 'value1', 'key2': 'value2', 'key3': 'value3'} _utils.validate_config(config, ['key1', 'key2'], ['key3']) _utils.validate_config(config, [], ['key1', 'key2', 'key3']) _utils.validate_config(config, ['key1', 'key2', 'key3']) with self.assertRaisesRegexp(Exception, 'Invalid config with unexpected keys'): _utils.validate_config(config, ['key1', 'key2']) with self.assertRaisesRegexp(Exception, 'Invalid config with missing keys'): _utils.validate_config(config, ['key1', 'key2', 'key3', 'key4']) def test_validate_config_must_have(self): config = {'key1': 'value1', 'key2': 'value2', 'key3': 'value3'} _utils.validate_config_must_have(config, ['key1', 'key2']) with self.assertRaisesRegexp(Exception, 'Invalid config with missing keys'): _utils.validate_config_must_have(config, ['key1', 'key4']) def test_validate_config_has_one_of(self): config = {'key1': 'value1', 'key2': 'value2', 'key3': 'value3'} _utils.validate_config_has_one_of(config, ['key1']) with self.assertRaisesRegexp(Exception, 'Only one of the values'): _utils.validate_config_has_one_of(config, ['key1', 'key2', 'key3']) with self.assertRaisesRegexp(Exception, 'One of the values in'): _utils.validate_config_has_one_of(config, ['key4', 'key5']) def test_validate_config_value(self): _utils.validate_config_value('val', ['val', 'val1', 'val2']) with self.assertRaisesRegexp(Exception, 'Invalid 
config value'): _utils.validate_config_value('val', ['val0', 'val1', 'val2']) def test_validate_gcs_path(self): _utils.validate_gcs_path('gs://testbucket/path/to/object', False) with self.assertRaisesRegexp(Exception, 'Invalid GCS path'): _utils.validate_gcs_path('path/to/object', False) _utils.validate_gcs_path('gs://path', False) with self.assertRaisesRegexp(Exception, 'It appears the GCS path "gs://path" is a bucket'): _utils.validate_gcs_path('gs://path', True) def test_parse_control_options_badtype(self): control = {'label': None, 'type': 'badtype'} with self.assertRaisesRegexp(Exception, 'Unknown control type badtype'): _utils.parse_control_options({'test-control': control}) def test_parse_control_options_set(self): control = {'label': None, 'choices': ['v1', 'v2'], 'min': 0, 'max': 10, 'step': 2} defaults = {'test-control': ['value1', 'value2']} control_html = _utils.parse_control_options({'test-control': control}, defaults) self.assertIn('class="gchart-control"', control_html[0]) self.assertIn('", output) def test_dive_plot(self): """Tests diveview.""" data1, _ = self._create_test_data() output = FacetsDiveview().plot(data1) # Output is an html. Ideally we can parse the html and verify nodes, but since the html # is output by a polymer component which is tested separately, we just verify # minumum keywords. self.assertIn("facets-dive", output) self.assertIn("