Repository: awslabs/mxnet-model-server Branch: master Commit: 706aa9c75557 Files: 466 Total size: 1.8 MB Directory structure: gitextract_mh90caup/ ├── .circleci/ │ ├── README.md │ ├── config.yml │ ├── images/ │ │ ├── Dockerfile.python2.7 │ │ └── Dockerfile.python3.6 │ └── scripts/ │ ├── linux_build.sh │ ├── linux_test_api.sh │ ├── linux_test_benchmark.sh │ ├── linux_test_modelarchiver.sh │ ├── linux_test_perf_regression.sh │ └── linux_test_python.sh ├── .coveragerc ├── .github/ │ └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── LICENSE ├── LICENSE.txt ├── MANIFEST.in ├── PyPiDescription.rst ├── README.md ├── _config.yml ├── benchmarks/ │ ├── README.md │ ├── benchmark.py │ ├── install_dependencies.sh │ ├── jmx/ │ │ ├── concurrentLoadPlan.jmx │ │ ├── concurrentScaleCalls.jmx │ │ ├── graphsGenerator.jmx │ │ ├── imageInputModelPlan.jmx │ │ ├── multipleModelsLoadPlan.jmx │ │ ├── pingPlan.jmx │ │ └── textInputModelPlan.jmx │ ├── lstm_ip.json │ ├── mac_install_dependencies.sh │ ├── noop_ip.txt │ └── upload_results_to_s3.sh ├── ci/ │ ├── Dockerfile.python2.7 │ ├── Dockerfile.python3.6 │ ├── README.md │ ├── buildspec.yml │ ├── dockerd-entrypoint.sh │ └── m2-settings.xml ├── docker/ │ ├── Dockerfile.cpu │ ├── Dockerfile.gpu │ ├── Dockerfile.nightly-cpu │ ├── Dockerfile.nightly-gpu │ ├── README.md │ ├── advanced-dockerfiles/ │ │ ├── Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7 │ │ ├── Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7.nightly │ │ ├── Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6 │ │ ├── Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6.nightly │ │ ├── Dockerfile.base.ubuntu_16_04.py2_7 │ │ ├── Dockerfile.base.ubuntu_16_04.py2_7.nightly │ │ ├── Dockerfile.base.ubuntu_16_04.py3_6 │ │ ├── Dockerfile.base.ubuntu_16_04.py3_6.nightly │ │ ├── config.properties │ │ └── dockerd-entrypoint.sh │ ├── advanced_settings.md │ ├── config.properties │ └── dockerd-entrypoint.sh ├── docs/ │ ├── README.md │ ├── batch_inference_with_mms.md │ ├── configuration.md │ ├── 
custom_service.md │ ├── elastic_inference.md │ ├── images/ │ │ └── helpers/ │ │ └── plugins_sdk_class_uml_diagrams.puml │ ├── inference_api.md │ ├── install.md │ ├── logging.md │ ├── management_api.md │ ├── metrics.md │ ├── migration.md │ ├── mms_endpoint_plugins.md │ ├── mms_on_fargate.md │ ├── model_zoo.md │ ├── rest_api.md │ └── server.md ├── examples/ │ ├── README.md │ ├── densenet_pytorch/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── densenet_service.py │ │ └── index_to_name.json │ ├── gluon_alexnet/ │ │ ├── README.md │ │ ├── gluon_hybrid_alexnet.py │ │ ├── gluon_imperative_alexnet.py │ │ ├── gluon_pretrained_alexnet.py │ │ ├── signature.json │ │ └── synset.txt │ ├── gluon_character_cnn/ │ │ ├── README.md │ │ ├── gluon_crepe.py │ │ ├── signature.json │ │ └── synset.txt │ ├── lstm_ptb/ │ │ ├── README.md │ │ └── lstm_ptb_service.py │ ├── metrics_cloudwatch/ │ │ ├── __init__.py │ │ └── metric_push_example.py │ ├── model_service_template/ │ │ ├── gluon_base_service.py │ │ ├── model_handler.py │ │ ├── mxnet_model_service.py │ │ ├── mxnet_utils/ │ │ │ ├── __init__.py │ │ │ ├── image.py │ │ │ ├── ndarray.py │ │ │ └── nlp.py │ │ ├── mxnet_vision_batching.py │ │ └── mxnet_vision_service.py │ ├── mxnet_vision/ │ │ └── README.md │ ├── sockeye_translate/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── config/ │ │ │ └── config.properties │ │ ├── model_handler.py │ │ ├── preprocessor.py │ │ └── sockeye_service.py │ └── ssd/ │ ├── README.md │ ├── example_outputs.md │ └── ssd_service.py ├── frontend/ │ ├── .gitignore │ ├── README.md │ ├── build.gradle │ ├── cts/ │ │ ├── build.gradle │ │ └── src/ │ │ └── main/ │ │ ├── java/ │ │ │ └── com/ │ │ │ └── amazonaws/ │ │ │ └── ml/ │ │ │ └── mms/ │ │ │ └── cts/ │ │ │ ├── Cts.java │ │ │ ├── HttpClient.java │ │ │ └── ModelInfo.java │ │ └── resources/ │ │ └── log4j2.xml │ ├── gradle/ │ │ └── wrapper/ │ │ ├── gradle-wrapper.jar │ │ └── gradle-wrapper.properties │ ├── gradle.properties │ ├── gradlew │ ├── gradlew.bat │ ├── modelarchive/ │ │ 
├── build.gradle │ │ └── src/ │ │ ├── main/ │ │ │ └── java/ │ │ │ └── com/ │ │ │ └── amazonaws/ │ │ │ └── ml/ │ │ │ └── mms/ │ │ │ └── archive/ │ │ │ ├── DownloadModelException.java │ │ │ ├── Hex.java │ │ │ ├── InvalidModelException.java │ │ │ ├── LegacyManifest.java │ │ │ ├── Manifest.java │ │ │ ├── ModelArchive.java │ │ │ ├── ModelException.java │ │ │ ├── ModelNotFoundException.java │ │ │ └── ZipUtils.java │ │ └── test/ │ │ ├── java/ │ │ │ └── com/ │ │ │ └── amazonaws/ │ │ │ └── ml/ │ │ │ └── mms/ │ │ │ ├── archive/ │ │ │ │ ├── CoverageTest.java │ │ │ │ ├── Exporter.java │ │ │ │ └── ModelArchiveTest.java │ │ │ └── test/ │ │ │ └── TestHelper.java │ │ └── resources/ │ │ └── models/ │ │ ├── custom-return-code/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── error_batch/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── init-error/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── invalid_service.py │ │ ├── invalid/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── invalid_service.py │ │ ├── loading-memory-error/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── logging/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── noop-no-manifest/ │ │ │ └── service.py │ │ ├── noop-v0.1/ │ │ │ ├── MANIFEST.json │ │ │ ├── noop_service.py │ │ │ └── signature.json │ │ ├── noop-v1.0/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── noop-v1.0-config-tests/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ ├── prediction-memory-error/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ └── service.py │ │ └── respheader-test/ │ │ ├── MAR-INF/ │ │ │ └── MANIFEST.json │ │ └── service.py │ ├── server/ │ │ ├── build.gradle │ │ └── src/ │ │ ├── main/ │ │ │ ├── java/ │ │ │ │ └── com/ │ │ │ │ └── amazonaws/ │ │ │ │ └── ml/ │ │ │ │ └── mms/ │ │ │ │ ├── ModelServer.java │ │ │ │ ├── ServerInitializer.java │ │ │ │ ├── http/ │ │ │ │ │ ├── 
ApiDescriptionRequestHandler.java │ │ │ │ │ ├── BadRequestException.java │ │ │ │ │ ├── ConflictStatusException.java │ │ │ │ │ ├── DescribeModelResponse.java │ │ │ │ │ ├── ErrorResponse.java │ │ │ │ │ ├── HttpRequestHandler.java │ │ │ │ │ ├── HttpRequestHandlerChain.java │ │ │ │ │ ├── InferenceRequestHandler.java │ │ │ │ │ ├── InternalServerException.java │ │ │ │ │ ├── InvalidPluginException.java │ │ │ │ │ ├── InvalidRequestHandler.java │ │ │ │ │ ├── ListModelsResponse.java │ │ │ │ │ ├── ManagementRequestHandler.java │ │ │ │ │ ├── MethodNotAllowedException.java │ │ │ │ │ ├── RequestTimeoutException.java │ │ │ │ │ ├── ResourceNotFoundException.java │ │ │ │ │ ├── ServiceUnavailableException.java │ │ │ │ │ ├── Session.java │ │ │ │ │ ├── StatusResponse.java │ │ │ │ │ └── messages/ │ │ │ │ │ └── RegisterModelRequest.java │ │ │ │ ├── metrics/ │ │ │ │ │ ├── Dimension.java │ │ │ │ │ ├── Metric.java │ │ │ │ │ ├── MetricCollector.java │ │ │ │ │ └── MetricManager.java │ │ │ │ ├── openapi/ │ │ │ │ │ ├── Encoding.java │ │ │ │ │ ├── Info.java │ │ │ │ │ ├── MediaType.java │ │ │ │ │ ├── OpenApi.java │ │ │ │ │ ├── OpenApiUtils.java │ │ │ │ │ ├── Operation.java │ │ │ │ │ ├── Parameter.java │ │ │ │ │ ├── Path.java │ │ │ │ │ ├── PathParameter.java │ │ │ │ │ ├── QueryParameter.java │ │ │ │ │ ├── RequestBody.java │ │ │ │ │ ├── Response.java │ │ │ │ │ └── Schema.java │ │ │ │ ├── servingsdk/ │ │ │ │ │ └── impl/ │ │ │ │ │ ├── ModelServerContext.java │ │ │ │ │ ├── ModelServerModel.java │ │ │ │ │ ├── ModelServerRequest.java │ │ │ │ │ ├── ModelServerResponse.java │ │ │ │ │ ├── ModelWorker.java │ │ │ │ │ └── PluginsManager.java │ │ │ │ ├── util/ │ │ │ │ │ ├── ConfigManager.java │ │ │ │ │ ├── Connector.java │ │ │ │ │ ├── ConnectorType.java │ │ │ │ │ ├── JsonUtils.java │ │ │ │ │ ├── NettyUtils.java │ │ │ │ │ ├── OpenSslKey.java │ │ │ │ │ ├── ServerGroups.java │ │ │ │ │ ├── codec/ │ │ │ │ │ │ ├── CodecUtils.java │ │ │ │ │ │ ├── ModelRequestEncoder.java │ │ │ │ │ │ └── ModelResponseDecoder.java │ │ 
│ │ │ ├── logging/ │ │ │ │ │ │ └── QLogLayout.java │ │ │ │ │ └── messages/ │ │ │ │ │ ├── BaseModelRequest.java │ │ │ │ │ ├── InputParameter.java │ │ │ │ │ ├── ModelInferenceRequest.java │ │ │ │ │ ├── ModelLoadModelRequest.java │ │ │ │ │ ├── ModelWorkerResponse.java │ │ │ │ │ ├── Predictions.java │ │ │ │ │ ├── RequestInput.java │ │ │ │ │ └── WorkerCommands.java │ │ │ │ └── wlm/ │ │ │ │ ├── BatchAggregator.java │ │ │ │ ├── Job.java │ │ │ │ ├── Model.java │ │ │ │ ├── ModelManager.java │ │ │ │ ├── WorkLoadManager.java │ │ │ │ ├── WorkerInitializationException.java │ │ │ │ ├── WorkerLifeCycle.java │ │ │ │ ├── WorkerState.java │ │ │ │ ├── WorkerStateListener.java │ │ │ │ └── WorkerThread.java │ │ │ └── resources/ │ │ │ └── log4j2.xml │ │ └── test/ │ │ ├── java/ │ │ │ └── com/ │ │ │ └── amazonaws/ │ │ │ └── ml/ │ │ │ └── mms/ │ │ │ ├── CoverageTest.java │ │ │ ├── ModelServerTest.java │ │ │ ├── TestUtils.java │ │ │ ├── test/ │ │ │ │ └── TestHelper.java │ │ │ └── util/ │ │ │ └── ConfigManagerTest.java │ │ └── resources/ │ │ ├── certs.pem │ │ ├── config.properties │ │ ├── config_test_env.properties │ │ ├── describe_api.json │ │ ├── inference_open_api.json │ │ ├── key.pem │ │ ├── keystore.p12 │ │ └── management_open_api.json │ ├── settings.gradle │ └── tools/ │ ├── conf/ │ │ ├── checkstyle.xml │ │ ├── findbugs-exclude.xml │ │ ├── pmd.xml │ │ └── suppressions.xml │ └── gradle/ │ ├── check.gradle │ ├── formatter.gradle │ └── launcher.gradle ├── mms/ │ ├── .gitignore │ ├── __init__.py │ ├── arg_parser.py │ ├── configs/ │ │ └── sagemaker_config.properties │ ├── context.py │ ├── export_model.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── dimension.py │ │ ├── metric.py │ │ ├── metric_collector.py │ │ ├── metric_encoder.py │ │ ├── metrics_store.py │ │ ├── process_memory_metric.py │ │ ├── system_metrics.py │ │ └── unit.py │ ├── model_loader.py │ ├── model_server.py │ ├── model_service/ │ │ ├── __init__.py │ │ ├── gluon_vision_service.py │ │ ├── model_service.py │ │ ├── 
mxnet_model_service.py │ │ └── mxnet_vision_service.py │ ├── model_service_worker.py │ ├── protocol/ │ │ ├── __init__.py │ │ └── otf_message_handler.py │ ├── service.py │ ├── tests/ │ │ ├── README.md │ │ ├── pylintrc │ │ └── unit_tests/ │ │ ├── helper/ │ │ │ ├── __init__.py │ │ │ └── pixel2pixel_service.py │ │ ├── model_service/ │ │ │ ├── dummy_model/ │ │ │ │ ├── MANIFEST.json │ │ │ │ └── dummy_model_service.py │ │ │ ├── test_mxnet_image.py │ │ │ ├── test_mxnet_ndarray.py │ │ │ ├── test_mxnet_nlp.py │ │ │ └── test_service.py │ │ ├── test_beckend_metric.py │ │ ├── test_model_loader.py │ │ ├── test_model_service_worker.py │ │ ├── test_otf_codec_protocol.py │ │ ├── test_utils/ │ │ │ ├── MAR-INF/ │ │ │ │ └── MANIFEST.json │ │ │ ├── dummy_class_model_service.py │ │ │ └── dummy_func_model_service.py │ │ ├── test_version.py │ │ └── test_worker_service.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── mxnet/ │ │ │ ├── __init__.py │ │ │ ├── image.py │ │ │ ├── ndarray.py │ │ │ └── nlp.py │ │ └── timeit_decorator.py │ └── version.py ├── model-archiver/ │ ├── .coveragerc │ ├── LICENSE │ ├── MANIFEST.in │ ├── PyPiDescription.rst │ ├── README.md │ ├── docs/ │ │ └── convert_from_onnx.md │ ├── model_archiver/ │ │ ├── __init__.py │ │ ├── arg_parser.py │ │ ├── manifest_components/ │ │ │ ├── __init__.py │ │ │ ├── engine.py │ │ │ ├── manifest.py │ │ │ ├── model.py │ │ │ └── publisher.py │ │ ├── model_archiver_error.py │ │ ├── model_packaging.py │ │ ├── model_packaging_utils.py │ │ ├── tests/ │ │ │ ├── integ_tests/ │ │ │ │ ├── configuration.json │ │ │ │ ├── resources/ │ │ │ │ │ ├── onnx_model/ │ │ │ │ │ │ ├── model.onnx │ │ │ │ │ │ └── service.py │ │ │ │ │ └── regular_model/ │ │ │ │ │ ├── dir/ │ │ │ │ │ │ └── 1.py │ │ │ │ │ ├── dummy-artifacts.txt │ │ │ │ │ └── service.py │ │ │ │ └── test_integration_model_archiver.py │ │ │ ├── pylintrc │ │ │ └── unit_tests/ │ │ │ ├── test_model_packaging.py │ │ │ ├── test_model_packaging_utils.py │ │ │ └── test_version.py │ │ └── version.py │ └── setup.py 
├── performance_regression/ │ └── imageInputModelPlan.jmx.yaml ├── plugins/ │ ├── build.gradle │ ├── endpoints/ │ │ ├── build.gradle │ │ └── src/ │ │ └── main/ │ │ ├── java/ │ │ │ └── software/ │ │ │ └── amazon/ │ │ │ └── ai/ │ │ │ └── mms/ │ │ │ └── plugins/ │ │ │ └── endpoint/ │ │ │ ├── ExecutionParameters.java │ │ │ └── Ping.java │ │ └── resources/ │ │ └── META-INF/ │ │ └── services/ │ │ └── software.amazon.ai.mms.servingsdk.ModelServerEndpoint │ ├── gradle/ │ │ └── wrapper/ │ │ ├── gradle-wrapper.jar │ │ └── gradle-wrapper.properties │ ├── gradle.properties │ ├── gradlew │ ├── gradlew.bat │ ├── settings.gradle │ └── tools/ │ ├── conf/ │ │ ├── checkstyle.xml │ │ ├── findbugs-exclude.xml │ │ ├── pmd.xml │ │ └── suppressions.xml │ └── gradle/ │ ├── check.gradle │ ├── formatter.gradle │ └── launcher.gradle ├── run_ci_tests.sh ├── run_circleci_tests.py ├── serving-sdk/ │ ├── checkstyle.xml │ ├── pom.xml │ └── src/ │ ├── main/ │ │ └── java/ │ │ └── software/ │ │ └── amazon/ │ │ └── ai/ │ │ └── mms/ │ │ └── servingsdk/ │ │ ├── Context.java │ │ ├── Model.java │ │ ├── ModelServerEndpoint.java │ │ ├── ModelServerEndpointException.java │ │ ├── Worker.java │ │ ├── annotations/ │ │ │ ├── Endpoint.java │ │ │ └── helpers/ │ │ │ └── EndpointTypes.java │ │ └── http/ │ │ ├── Request.java │ │ └── Response.java │ └── test/ │ └── java/ │ └── software/ │ └── amazon/ │ └── ai/ │ └── mms/ │ └── servingsdk/ │ └── ModelServerEndpointTest.java ├── setup.py ├── test/ │ ├── README.md │ ├── postman/ │ │ ├── environment.json │ │ ├── https_test_collection.json │ │ ├── inference_api_test_collection.json │ │ ├── inference_data.json │ │ └── management_api_test_collection.json │ ├── regression_tests.sh │ └── resources/ │ ├── certs.pem │ ├── config.properties │ └── key.pem └── tests/ └── performance/ ├── README.md ├── TESTS.md ├── agents/ │ ├── __init__.py │ ├── config.ini │ ├── configuration.py │ ├── metrics/ │ │ └── __init__.py │ ├── metrics_collector.py │ ├── metrics_monitoring_inproc.py │ ├── 
metrics_monitoring_server.py │ └── utils/ │ ├── __init__.py │ └── process.py ├── pylintrc ├── requirements.txt ├── run_performance_suite.py ├── runs/ │ ├── __init__.py │ ├── compare.py │ ├── context.py │ ├── junit.py │ ├── storage.py │ └── taurus/ │ ├── __init__.py │ ├── reader.py │ └── x2junit.py ├── tests/ │ ├── api_description/ │ │ ├── api_description.jmx │ │ ├── api_description.yaml │ │ └── environments/ │ │ └── xlarge.yaml │ ├── batch_and_single_inference/ │ │ ├── batch_and_single_inference.jmx │ │ ├── batch_and_single_inference.yaml │ │ └── environments/ │ │ └── xlarge.yaml │ ├── batch_inference/ │ │ ├── batch_inference.jmx │ │ ├── batch_inference.yaml │ │ └── environments/ │ │ └── xlarge.yaml │ ├── examples_local_criteria/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── examples_local_criteria.jmx │ │ └── examples_local_criteria.yaml │ ├── examples_local_monitoring/ │ │ ├── examples_local_monitoring.jmx │ │ └── examples_local_monitoring.yaml │ ├── examples_remote_criteria/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── examples_remote_criteria.jmx │ │ └── examples_remote_criteria.yaml │ ├── examples_remote_monitoring/ │ │ ├── examples_remote_monitoring.jmx │ │ └── examples_remote_monitoring.yaml │ ├── examples_starter/ │ │ ├── examples_starter.jmx │ │ └── examples_starter.yaml │ ├── global_config.yaml │ ├── health_check/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── health_check.jmx │ │ └── health_check.yaml │ ├── inference_multiple_models/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── inference_multiple_models.jmx │ │ └── inference_multiple_models.yaml │ ├── inference_multiple_worker/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── inference_multiple_worker.jmx │ │ └── inference_multiple_worker.yaml │ ├── inference_single_worker/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── inference_single_worker.jmx │ │ └── inference_single_worker.yaml │ ├── list_models/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── list_models.jmx │ 
│ └── list_models.yaml │ ├── model_description/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── model_description.jmx │ │ └── model_description.yaml │ ├── multiple_inference_and_scaling/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── multiple_inference_and_scaling.jmx │ │ └── multiple_inference_and_scaling.yaml │ ├── register_unregister/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── register_unregister.jmx │ │ └── register_unregister.yaml │ ├── register_unregister_multiple/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── register_unregister_multiple.jmx │ │ └── register_unregister_multiple.yaml │ ├── scale_down_workers/ │ │ ├── environments/ │ │ │ └── xlarge.yaml │ │ ├── scale_down_workers.jmx │ │ └── scale_down_workers.yaml │ └── scale_up_workers/ │ ├── environments/ │ │ └── xlarge.yaml │ ├── scale_up_workers.jmx │ └── scale_up_workers.yaml └── utils/ ├── __init__.py ├── fs.py ├── pyshell.py └── timer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .circleci/README.md ================================================ # Multi Model Server CircleCI build Model Server uses CircleCI for builds. This folder contains the config and scripts that are needed for CircleCI. ## config.yml _config.yml_ contains MMS build logic which will be used by CircleCI. ## Workflows and Jobs Currently, following _workflows_ are available - 1. smoke 2. nightly 3. weekly Following _jobs_ are executed under each workflow - 1. **build** : Builds _frontend/model-server.jar_ and executes tests from gradle 2. **modelarchiver** : Builds and tests modelarchiver module 3. **python-tests** : Executes pytests from _mms/tests/unit_tests/_ 4. **benchmark** : Executes latency benchmark using resnet-18 model 5. (NEW!) **api-tests** : Executes newman test suite for API testing Following _executors_ are available for job execution - 1. py27 2. 
py36 > Please check the _workflows_, _jobs_ and _executors_ section in _config.yml_ for an up to date list ## scripts Instead of using inline commands inside _config.yml_, job steps are configured as shell scripts. This is easier for maintenance and reduces chances of error in config.yml ## images MMS uses customized docker image for its CircleCI build. To make sure MMS is compatible with both Python2 and Python3, we use two build projects. We have published two docker images on docker hub for code build * prashantsail/mms-build:python2.7 * prashantsail/mms-build:python3.6 Following files in the _images_ folder are used to create the docker images * Dockerfile.python2.7 - Dockerfile for prashantsail/mms-build:python2.7 * Dockerfile.python3.6 - Dockerfile for prashantsail/mms-build:python3.6 ## Local CircleCI cli To make it easy for developers to debug build issues locally, MMS supports CircleCI cli for running a job in a container on your machine. #### Dependencies 1. CircleCI cli ([Quick Install](https://circleci.com/docs/2.0/local-cli/#quick-installation)) 2. PyYAML (pip install PyYaml) 3. docker (installed and running) #### Command Developers can use the following command to build MMS locally: **./run_circleci_tests.py workflow_name [-j job_name] [-e executor_name]** - _workflow_name_ This is a mandatory parameter - _-j, --job job_name_ If specified, executes only the specified job name (along with the required parents). If not specified, all jobs in the workflow are executed sequentially. - _-e, --executor executor_name_ If specified, job is executed only on the specified executor(docker image). If not specified, job is executed on all the available executors. ```bash $ cd multi-model-server $ ./run_circleci_tests.py smoke $ ./run_circleci_tests.py smoke -j modelarchiver $ ./run_circleci_tests.py smoke -e py36 $ ./run_circleci_tests.py smoke -j modelarchiver -e py36 ``` ###### Checklist > 1. Make sure docker is running before you start local execution. > 2. 
Docker containers to have **at least 4GB RAM, 2 CPU**. > 3. If you are on a network with low bandwidth, we advise you to explicitly pull the docker images - > docker pull prashantsail/mms-build:python2.7 > docker pull prashantsail/mms-build:python3.6 `To avoid Pull Request build failures on github, developers should always make sure that their local builds pass.` ================================================ FILE: .circleci/config.yml ================================================ version: 2.1 executors: py36: docker: - image: prashantsail/mms-build:python3.6 environment: _JAVA_OPTIONS: "-Xmx2048m" py27: docker: - image: prashantsail/mms-build:python2.7 environment: _JAVA_OPTIONS: "-Xmx2048m" commands: attach-mms-workspace: description: "Attach the MMS directory which was saved into workspace" steps: - attach_workspace: at: . install-mms-server: description: "Install MMS server from a wheel" steps: - run: name: Install MMS command: pip install dist/*.whl exeucute-api-tests: description: "Execute API tests from a collection" parameters: collection: type: enum enum: [management, inference, https] default: management steps: - run: name: Start MMS, Execute << parameters.collection >> API Tests, Stop MMS command: .circleci/scripts/linux_test_api.sh << parameters.collection >> - store_artifacts: name: Store server logs from << parameters.collection >> API tests path: mms_<< parameters.collection >>.log - store_artifacts: name: Store << parameters.collection >> API test results path: test/<< parameters.collection >>-api-report.html jobs: build: parameters: executor: type: executor executor: << parameters.executor >> steps: - checkout - run: name: Build frontend command: .circleci/scripts/linux_build.sh - store_artifacts: name: Store gradle testng results path: frontend/server/build/reports/tests/test - persist_to_workspace: root: . paths: - . 
python-tests: parameters: executor: type: executor executor: << parameters.executor >> steps: - attach-mms-workspace - run: name: Execute python unit tests command: .circleci/scripts/linux_test_python.sh - store_artifacts: name: Store python Test results path: htmlcov api-tests: parameters: executor: type: executor executor: << parameters.executor >> steps: - attach-mms-workspace - install-mms-server - exeucute-api-tests: collection: management - exeucute-api-tests: collection: inference - exeucute-api-tests: collection: https benchmark: parameters: executor: type: executor executor: << parameters.executor >> steps: - attach-mms-workspace - install-mms-server - run: name: Start MMS, Execute benchmark tests, Stop MMS command: .circleci/scripts/linux_test_benchmark.sh - store_artifacts: name: Store server logs from benchmark tests path: mms.log - store_artifacts: name: Store Benchmark Latency resnet-18 results path: /tmp/MMSBenchmark/out/latency/resnet-18/report/ destination: benchmark-latency-resnet-18 modelarchiver: parameters: executor: type: executor executor: << parameters.executor >> steps: - checkout - run: name: Execute lint, unit and integration tests command: .circleci/scripts/linux_test_modelarchiver.sh - store_artifacts: name: Store unit tests results from model archiver tests path: model-archiver/results_units destination: units workflows: version: 2 smoke: jobs: - &build build: name: build-<< matrix.executor >> matrix: &matrix parameters: executor: ["py27", "py36"] - &modelarchiver modelarchiver: name: modelarchiver-<< matrix.executor >> matrix: *matrix - &python-tests python-tests: name: python-tests-<< matrix.executor >> requires: - build-<< matrix.executor >> matrix: *matrix nightly: triggers: - schedule: cron: "0 0 * * *" filters: branches: only: - master jobs: - *build - *modelarchiver - *python-tests - &api-tests api-tests: name: api-tests-<< matrix.executor >> requires: - build-<< matrix.executor >> matrix: *matrix weekly: triggers: - schedule: 
cron: "0 0 * * 0" filters: branches: only: - master jobs: - *build - benchmark: name: benchmark-<< matrix.executor >> requires: - build-<< matrix.executor >> matrix: *matrix ================================================ FILE: .circleci/images/Dockerfile.python2.7 ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # FROM awsdeeplearningteam/mms-build:python2.7@sha256:2b743d6724dead806873cce1330f7b8a0197399a35af47dfd7035251fdade122 # 2020 - Updated Build and Test dependencies # Python packages for MMS Server RUN pip install psutil \ && pip install future \ && pip install Pillow \ && pip install wheel \ && pip install twine \ && pip install requests \ && pip install mock \ && pip install numpy \ && pip install Image \ && pip install mxnet==1.5.0 \ && pip install enum34 # Python packages for pytests RUN pip install pytest==4.0.0 \ && pip install pytest-cov \ && pip install pytest-mock # Python packages for benchmark RUN pip install pandas # Install NodeJS and packages for API tests RUN curl -sL https://deb.nodesource.com/setup_14.x | sudo -E bash - \ && sudo apt-get install -y nodejs \ && sudo npm install -g newman newman-reporter-html # Install jmeter for benchmark # ToDo: Remove --no-check-certificate; temporarily added to bypass jmeter-plugins.org's expired certificate RUN cd /opt \ && wget https://archive.apache.org/dist/jmeter/binaries/apache-jmeter-5.3.tgz \ && tar -xzf apache-jmeter-5.3.tgz \ && cd 
apache-jmeter-5.3 \ && ln -s /opt/apache-jmeter-5.3/bin/jmeter /usr/local/bin/jmeter \ && wget --no-check-certificate https://jmeter-plugins.org/get/ -O lib/ext/jmeter-plugins-manager-1.4.jar \ && wget http://search.maven.org/remotecontent?filepath=kg/apc/cmdrunner/2.2/cmdrunner-2.2.jar -O lib/cmdrunner-2.2.jar \ && java -cp lib/ext/jmeter-plugins-manager-1.4.jar org.jmeterplugins.repository.PluginManagerCMDInstaller \ && bin/PluginsManagerCMD.sh install jpgc-synthesis=2.1,jpgc-filterresults=2.1,jpgc-mergeresults=2.1,jpgc-cmd=2.1,jpgc-perfmon=2.1 # bzt is used for performance regression test suite # bzt requires python 3.6 runtime. # Download pyenv, use pyenv to download python 3.6.5. # The downloaded python 3.6.5 is isolated and doesn't interfere with default python(2.7) # Only before starting the performance regression suite, py 3.6.5 is local installed(pyenv local 3.6.5) in test dir # !! MMS server will continue using Python 2.7 !! RUN curl https://pyenv.run | bash \ && $HOME/.pyenv/bin/pyenv install 3.6.5 ================================================ FILE: .circleci/images/Dockerfile.python3.6 ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
# FROM awsdeeplearningteam/mms-build:python3.6@sha256:2c1afa8834907ceec641d254dffbf4bcc659ca2d00fd6f2872d7521f32c9fa2e # 2020 - Updated Build and Test dependencies # Python packages for MMS Server RUN pip install psutil \ && pip install future \ && pip install Pillow \ && pip install wheel \ && pip install twine \ && pip install requests \ && pip install mock \ && pip install numpy \ && pip install Image \ && pip install mxnet==1.5.0 # Python packages for pytests RUN pip install pytest==4.0.0 \ && pip install pytest-cov \ && pip install pytest-mock # Python packages for benchmark RUN pip install pandas # Install NodeJS and packages for API tests RUN curl -sL https://deb.nodesource.com/setup_14.x | sudo -E bash - \ && sudo apt-get install -y nodejs \ && sudo npm install -g newman newman-reporter-html # Install jmeter for benchmark # ToDo: Remove --no-check-certificate; temporarily added to bypass jmeter-plugins.org's expired certificate RUN cd /opt \ && wget https://archive.apache.org/dist/jmeter/binaries/apache-jmeter-5.3.tgz \ && tar -xzf apache-jmeter-5.3.tgz \ && cd apache-jmeter-5.3 \ && ln -s /opt/apache-jmeter-5.3/bin/jmeter /usr/local/bin/jmeter \ && wget --no-check-certificate https://jmeter-plugins.org/get/ -O lib/ext/jmeter-plugins-manager-1.4.jar \ && wget http://search.maven.org/remotecontent?filepath=kg/apc/cmdrunner/2.2/cmdrunner-2.2.jar -O lib/cmdrunner-2.2.jar \ && java -cp lib/ext/jmeter-plugins-manager-1.4.jar org.jmeterplugins.repository.PluginManagerCMDInstaller \ && bin/PluginsManagerCMD.sh install jpgc-synthesis=2.1,jpgc-filterresults=2.1,jpgc-mergeresults=2.1,jpgc-cmd=2.1,jpgc-perfmon=2.1 ================================================ FILE: .circleci/scripts/linux_build.sh ================================================ #!/bin/bash python setup.py bdist_wheel --universal ================================================ FILE: .circleci/scripts/linux_test_api.sh ================================================ #!/bin/bash 
MODEL_STORE_DIR='test/model_store' MMS_LOG_FILE_MANAGEMENT='mms_management.log' MMS_LOG_FILE_INFERENCE='mms_inference.log' MMS_LOG_FILE_HTTPS='mms_https.log' MMS_CONFIG_FILE_HTTPS='test/resources/config.properties' POSTMAN_ENV_FILE='test/postman/environment.json' POSTMAN_COLLECTION_MANAGEMENT='test/postman/management_api_test_collection.json' POSTMAN_COLLECTION_INFERENCE='test/postman/inference_api_test_collection.json' POSTMAN_COLLECTION_HTTPS='test/postman/https_test_collection.json' POSTMAN_DATA_FILE_INFERENCE='test/postman/inference_data.json' REPORT_FILE_MANAGEMENT='test/management-api-report.html' REPORT_FILE_INFERENCE='test/inference-api-report.html' REPORT_FILE_HTTPS='test/https-api-report.html' start_mms_server() { multi-model-server --start --model-store $1 >> $2 2>&1 sleep 10 } start_mms_secure_server() { multi-model-server --start --mms-config $MMS_CONFIG_FILE_HTTPS --model-store $1 >> $2 2>&1 sleep 10 } stop_mms_server() { multi-model-server --stop } trigger_management_tests(){ start_mms_server $MODEL_STORE_DIR $MMS_LOG_FILE_MANAGEMENT newman run -e $POSTMAN_ENV_FILE $POSTMAN_COLLECTION_MANAGEMENT \ -r cli,html --reporter-html-export $REPORT_FILE_MANAGEMENT --verbose stop_mms_server } trigger_inference_tests(){ start_mms_server $MODEL_STORE_DIR $MMS_LOG_FILE_INFERENCE newman run -e $POSTMAN_ENV_FILE $POSTMAN_COLLECTION_INFERENCE -d $POSTMAN_DATA_FILE_INFERENCE \ -r cli,html --reporter-html-export $REPORT_FILE_INFERENCE --verbose stop_mms_server } trigger_https_tests(){ start_mms_secure_server $MODEL_STORE_DIR $MMS_LOG_FILE_HTTPS newman run --insecure -e $POSTMAN_ENV_FILE $POSTMAN_COLLECTION_HTTPS \ -r cli,html --reporter-html-export $REPORT_FILE_HTTPS --verbose stop_mms_server } mkdir -p $MODEL_STORE_DIR case $1 in 'management') trigger_management_tests ;; 'inference') trigger_inference_tests ;; 'https') trigger_https_tests ;; 'ALL') trigger_management_tests trigger_inference_tests trigger_https_tests ;; *) echo $1 'Invalid' echo 'Please specify any 
one of - management | inference | https | ALL' exit 1 ;; esac ================================================ FILE: .circleci/scripts/linux_test_benchmark.sh ================================================ #!/bin/bash # Hack needed to make it work with existing benchmark.py # benchmark.py expects jmeter to be present at a very specific location mkdir -p /home/ubuntu/.linuxbrew/Cellar/jmeter/5.3/libexec/bin/ ln -s /opt/apache-jmeter-5.3/bin/jmeter /home/ubuntu/.linuxbrew/Cellar/jmeter/5.3/libexec/bin/jmeter multi-model-server --start >> mms.log 2>&1 sleep 30 cd benchmarks python benchmark.py latency multi-model-server --stop ================================================ FILE: .circleci/scripts/linux_test_modelarchiver.sh ================================================ #!/bin/bash cd model-archiver/ # Lint test pylint -rn --rcfile=./model_archiver/tests/pylintrc model_archiver/. # Execute python unit tests python -m pytest --cov-report html:results_units --cov=./ model_archiver/tests/unit_tests # Install model archiver module pip install . # Execute integration tests python -m pytest model_archiver/tests/integ_tests # ToDo - Report for Integration tests ? 
================================================
FILE: .circleci/scripts/linux_test_perf_regression.sh
================================================
#!/bin/bash

# Bring up MMS preloaded with squeezenet and wait for workers to scale up.
multi-model-server --start \
    --models squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar \
    >> mms.log 2>&1
sleep 90

cd performance_regression

# Only on a python 2 environment -
PY_MAJOR_VER=$(python -c 'import sys; major = sys.version_info.major; print(major);')
if [ $PY_MAJOR_VER -eq 2 ]; then
    # Hack to use python 3.6.5 for bzt installation and execution
    export PATH="/root/.pyenv/bin:/root/.pyenv/shims:$PATH"
    pyenv local 3.6.5
fi

# Install dependencies
pip install bzt

# Fetch the test image used as inference payload by the JMeter plan.
curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg

# Drive the taurus (bzt) regression run against the local JMeter install.
bzt -o modules.jmeter.path=/opt/apache-jmeter-5.3/bin/jmeter \
    -o settings.artifacts-dir=/tmp/mms-performance-regression/ \
    -o modules.console.disable=true \
    imageInputModelPlan.jmx.yaml \
    -report

multi-model-server --stop

================================================
FILE: .circleci/scripts/linux_test_python.sh
================================================
#!/bin/bash

# Lint Test
pylint -rn --rcfile=./mms/tests/pylintrc mms/.

# Execute python tests
python -m pytest --cov-report html:htmlcov --cov=mms/ mms/tests/unit_tests/

================================================
FILE: .coveragerc
================================================
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
[report] exclude_lines = pragma: no cover if __name__ == .__main__.: if __name__ == "__main__" : [run] branch = True omit = */__init__.py mms/tests/* mms/utils/model_server_error_codes.py mms/utils/timeit_decorator.py mms/storage.py mms/metrics/system_metrics.py mms/utils/mxnet/* mms/examples/metric_push_example.py mms/model_service/* ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ Before or while filing an issue please feel free to join our [ slack channel](https://join.slack.com/t/mms-awslabs/shared_invite/enQtNDk4MTgzNDc5NzE4LTBkYTAwMjBjMTVmZTdkODRmYTZkNjdjZGYxZDI0ODhiZDdlM2Y0ZGJiZTczMGY3Njc4MmM3OTQ0OWI2ZDMyNGQ) to get in touch with development team, ask questions, find out what's cooking and more! ## Issue #, if available: ## Description of changes: ## Testing done: **To run CI tests on your changes refer [README.md](https://github.com/awslabs/multi-model-server/blob/master/ci/README.md)** By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. ================================================ FILE: .gitignore ================================================ .gradle # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # dotenv .env # virtualenv .venv venv/ ENV/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # mac .DS_Store # PyCharm .idea/ # Log *.log.* # Model *.model # Pictures *.jpg # Prop file in benchmark benchmarks/*.properties # intellij files *.iml # MMS files mms/frontend mms/plugins ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: LICENSE.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: MANIFEST.in ================================================ include mms/frontend/model-server.jar include PyPiDescription.rst include mms/configs/* ================================================ FILE: PyPiDescription.rst ================================================ Project Description =================== Multi Model Server (MMS) is a flexible and easy to use tool for serving deep learning models exported from `MXNet `__ or the Open Neural Network Exchange (`ONNX `__). Use the MMS Server CLI, or the pre-configured Docker images, to start a service that sets up HTTP endpoints to handle model inference requests. Detailed documentation and examples are provided in the `docs folder `__. Prerequisites ------------- * **java 8**: Required. MMS use java to serve HTTP requests. You must install java 8 (or later) and make sure java is on available in $PATH environment variable *before* installing MMS. If you have multiple java installed, you can use $JAVA_HOME environment vairable to control which java to use. * **mxnet**: `mxnet` will not be installed by default with MMS 1.0 any more. You have to install it manually if you use MxNet. 
For ubuntu: :: sudo apt-get install openjdk-8-jre-headless For centos :: sudo yum install java-1.8.0-openjdk For Mac: :: brew tap caskroom/versions brew update brew cask install java8 Install MxNet: :: pip install mxnet MXNet offers MKL pip packages that will be much faster when running on Intel hardware. To install mkl package for CPU: :: pip install mxnet-mkl or for GPU instance: :: pip install mxnet-cu92mkl Installation ------------ :: pip install multi-model-server Development ----------- We welcome new contributors of all experience levels. For information on how to install MMS for development, refer to the `MMS docs `__. Important links --------------- - `Official source code repo `__ - `Download releases `__ - `Issue tracker `__ Source code ----------- You can check the latest source code as follows: :: git clone https://github.com/awslabs/multi-model-server.git Testing ------- After installation, try out the MMS Quickstart for - `Serving a Model `__ - `Create a Model Archive `__. 
Help and Support ---------------- - `Documentation `__ - `Forum `__ Citation -------- If you use MMS in a publication or project, please cite MMS: https://github.com/awslabs/multi-model-server ================================================ FILE: README.md ================================================ Multi Model Server ======= | ubuntu/python-2.7 | ubuntu/python-3.6 | |---------|---------| | ![Python3 Build Status](https://codebuild.us-east-1.amazonaws.com/badges?uuid=eyJlbmNyeXB0ZWREYXRhIjoicGZ6dXFmMU54UGxDaGsxUDhXclJLcFpHTnFMNld6cW5POVpNclc4Vm9BUWJNamZKMGdzbk1lOU92Z0VWQVZJTThsRUttOW8rUzgxZ2F0Ull1U1VkSHo0PSIsIml2UGFyYW1ldGVyU3BlYyI6IkJJaFc1QTEwRGhwUXY1dDgiLCJtYXRlcmlhbFNldFNlcmlhbCI6MX0%3D&branch=master) | ![Python2 Build Status](https://codebuild.us-east-1.amazonaws.com/badges?uuid=eyJlbmNyeXB0ZWREYXRhIjoiYVdIajEwVW9uZ3cvWkZqaHlaRGNUU2M0clE2aUVjelJranJoYTI3S1lHT3R5THJXdklzejU2UVM5NWlUTWdwaVVJalRwYi9GTnJ1aUxiRXIvTGhuQ2g0PSIsIml2UGFyYW1ldGVyU3BlYyI6IjArcHVCaFgvR1pTN1JoSG4iLCJtYXRlcmlhbFNldFNlcmlhbCI6MX0%3D&branch=master) | Multi Model Server (MMS) is a flexible and easy to use tool for serving deep learning models trained using any ML/DL framework. Use the MMS Server CLI, or the pre-configured Docker images, to start a service that sets up HTTP endpoints to handle model inference requests. A quick overview and examples for both serving and packaging are provided below. Detailed documentation and examples are provided in the [docs folder](docs/README.md). Join our [ slack channel](https://join.slack.com/t/mms-awslabs/shared_invite/zt-6cv1kx46-MBTOPLNDwmyBynEvFBsNkQ) to get in touch with development team, ask questions, find out what's cooking and more! 
## Contents of this Document * [Quick Start](#quick-start) * [Serve a Model](#serve-a-model) * [Other Features](#other-features) * [External demos powered by MMS](#external-demos-powered-by-mms) * [Contributing](#contributing) ## Other Relevant Documents * [Latest Version Docs](docs/README.md) * [v0.4.0 Docs](https://github.com/awslabs/multi-model-server/blob/v0.4.0/docs/README.md) * [Migrating from v0.4.0 to v1.0.0](docs/migration.md) ## Quick Start ### Prerequisites Before proceeding further with this document, make sure you have the following prerequisites. 1. Ubuntu, CentOS, or macOS. Windows support is experimental. The following instructions will focus on Linux and macOS only. 1. Python - Multi Model Server requires python to run the workers. 1. pip - Pip is a python package management system. 1. Java 8 - Multi Model Server requires Java 8 to start. You have the following options for installing Java 8: For Ubuntu: ```bash sudo apt-get install openjdk-8-jre-headless ``` For CentOS: ```bash sudo yum install java-1.8.0-openjdk ``` For macOS: ```bash brew tap homebrew/cask-versions brew update brew cask install adoptopenjdk8 ``` ### Installing Multi Model Server with pip #### Setup **Step 1:** Setup a Virtual Environment We recommend installing and running Multi Model Server in a virtual environment. It's a good practice to run and install all of the Python dependencies in virtual environments. This will provide isolation of the dependencies and ease dependency management. One option is to use Virtualenv. This is used to create virtual Python environments. 
You may install and activate a virtualenv for Python 2.7 as follows: ```bash pip install virtualenv ``` Then create a virtual environment: ```bash # Assuming we want to run python2.7 in /usr/local/bin/python2.7 virtualenv -p /usr/local/bin/python2.7 /tmp/pyenv2 # Enter this virtual environment as follows source /tmp/pyenv2/bin/activate ``` Refer to the [Virtualenv documentation](https://virtualenv.pypa.io/en/stable/) for further information. **Step 2:** Install MXNet MMS won't install the MXNet engine by default. If it isn't already installed in your virtual environment, you must install one of the MXNet pip packages. For CPU inference, `mxnet-mkl` is recommended. Install it as follows: ```bash # Recommended for running Multi Model Server on CPU hosts pip install mxnet-mkl ``` For GPU inference, `mxnet-cu92mkl` is recommended. Install it as follows: ```bash # Recommended for running Multi Model Server on GPU hosts pip install mxnet-cu92mkl ``` **Step 3:** Install or Upgrade MMS as follows: ```bash # Install latest released version of multi-model-server pip install multi-model-server ``` To upgrade from a previous version of `multi-model-server`, please refer [migration reference](docs/migration.md) document. **Notes:** * A minimal version of `model-archiver` will be installed with MMS as dependency. See [model-archiver](model-archiver/README.md) for more options and details. * See the [advanced installation](docs/install.md) page for more options and troubleshooting. ### Serve a Model Once installed, you can get MMS model server up and running very quickly. Try out `--help` to see all the CLI options available. ```bash multi-model-server --help ``` For this quick start, we'll skip over most of the features, but be sure to take a look at the [full server docs](docs/server.md) when you're ready. 
Here is an easy example for serving an object classification model: ```bash multi-model-server --start --models squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` With the command above executed, you have MMS running on your host, listening for inference requests. **Please note that if you specify model(s) during MMS start - it will automatically scale backend workers to the number equal to available vCPUs (if you run on a CPU instance) or to the number of available GPUs (if you run on a GPU instance). In case of powerful hosts with a lot of compute resources (vCPUs or GPUs) this start up and autoscaling process might take considerable time. If you would like to minimize MMS start up time you can try to avoid registering and scaling up models during start up and move that to a later point by using the corresponding [Management API](docs/management_api.md#register-a-model) calls (this allows finer-grained control over how many resources are allocated for any particular model).** To test it out, you can open a new terminal window next to the one running MMS. Then you can use `curl` to download one of these [cute pictures of a kitten](https://www.google.com/search?q=cute+kitten&tbm=isch&hl=en&cr=&safe=images) and curl's `-o` flag will name it `kitten.jpg` for you. Then you will `curl` a `POST` to the MMS predict endpoint with the kitten's image. ![kitten](docs/images/kitten_small.jpg) In the example below, we provide a shortcut for these steps. ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1:8080/predictions/squeezenet -T kitten.jpg ``` The predict endpoint will return a prediction response in JSON. 
It will look something like the following result: ```json [ { "probability": 0.8582232594490051, "class": "n02124075 Egyptian cat" }, { "probability": 0.09159987419843674, "class": "n02123045 tabby, tabby cat" }, { "probability": 0.0374876894056797, "class": "n02123159 tiger cat" }, { "probability": 0.006165083032101393, "class": "n02128385 leopard, Panthera pardus" }, { "probability": 0.0031716004014015198, "class": "n02127052 lynx, catamount" } ] ``` You will see this result in the response to your `curl` call to the predict endpoint, and in the server logs in the terminal window running MMS. It's also being [logged locally with metrics](docs/metrics.md). Other models can be downloaded from the [model zoo](docs/model_zoo.md), so try out some of those as well. Now you've seen how easy it can be to serve a deep learning model with MMS! [Would you like to know more?](docs/server.md) ### Stopping the running model server To stop the currently running model-server instance, run the following command: ```bash $ multi-model-server --stop ``` You will see output specifying that multi-model-server has stopped. ### Create a Model Archive MMS enables you to package up all of your model artifacts into a single model archive. This makes it easy to share and deploy your models. To package a model, check out [model archiver documentation](model-archiver/README.md) ## Recommended production deployments * MMS doesn't provide authentication. You have to have your own authentication proxy in front of MMS. * MMS doesn't provide throttling, it's vulnerable to DDoS attack. It's recommended to run MMS behind a firewall. * MMS only allows localhost access by default, see [Network configuration](docs/configuration.md#configure-mms-listening-port) for detail. * SSL is not enabled by default, see [Enable SSL](docs/configuration.md#enable-ssl) for detail. 
* MMS uses a config.properties file to configure MMS's behavior, see the [Manage MMS](docs/configuration.md) page for details on how to configure MMS. * For better security, we recommend running MMS inside a docker container. This project includes Dockerfiles to build containers recommended for production deployments. These containers demonstrate how to customize your own production MMS deployment. The basic usage can be found on the [Docker readme](docker/README.md). ## Other Features Browse over to the [Docs readme](docs/README.md) for the full index of documentation. This includes more examples, how to customize the API service, API endpoint details, and more. ## External demos powered by MMS Here are some example demos of deep learning applications, powered by MMS: | | | |:------:|:-----------:| | [Product Review Classification](https://thomasdelteil.github.io/TextClassificationCNNs_MXNet/) demo4 |[Visual Search](https://thomasdelteil.github.io/VisualSearch_MXNet/) demo1| | [Facial Emotion Recognition](https://thomasdelteil.github.io/FacialEmotionRecognition_MXNet/) demo2 |[Neural Style Transfer](https://thomasdelteil.github.io/NeuralStyleTransfer_MXNet/) demo3 | ## Contributing We welcome all contributions! To file a bug or request a feature, please file a GitHub issue. Pull requests are welcome. ================================================ FILE: _config.yml ================================================ theme: jekyll-theme-cayman ================================================ FILE: benchmarks/README.md ================================================ # Multi Model Server Benchmarking The benchmarks measure the performance of MMS on various models and benchmarks. It supports either a number of built-in models or a custom model passed in as a path or URL to the .model file. It also runs various benchmarks using these models (see benchmarks section below). The benchmarks are run through a python3 script on the user machine through jmeter. 
MMS is run on the same machine in a docker instance to avoid network latencies. The benchmark must be run from within the context of the full MMS repo because it executes the local code as the version of MMS (and it is recompiled between runs) for ease of development. ## Installation ### Ubuntu The script is mainly intended to run on a Ubuntu EC2 instance. For this reason, we have provided an `install_dependencies.sh` script to install everything needed to execute the benchmark on this environment. All you need to do is run this file and clone the MMS repo. ### MacOS For mac, you should have python3 and java installed. If you wish to run the default benchmarks featuring a docker-based instance of MMS, you will need to install docker as well. Finally, you will need to install jmeter with plugins which can be accomplished by running `mac_install_dependencies.sh`. ### Other For other environments, manual installation is necessary. The list of dependencies to be installed can be found below or by reading the ubuntu installation script. The benchmarking script requires the following to run: - python3 - A JDK and JRE - jmeter installed through homebrew or linuxbrew with the plugin manager and the following plugins: jpgc-synthesis=2.1,jpgc-filterresults=2.1,jpgc-mergeresults=2.1,jpgc-cmd=2.1,jpgc-perfmon=2.1 - Docker-ce with the current user added to the docker group - Nvidia-docker (for GPU) ## Models The pre-loaded models for the benchmark can be mostly found in the [MMS model zoo](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md). 
We currently support the following: - [resnet: ResNet-18 (Default)](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md#resnet-18) - [squeezenet: SqueezeNet V1.1](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md#squeezenet_v1.1) - [lstm: lstm-ptb](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md#lstm-ptb) - [noop: noop-v1.0](https://s3.amazonaws.com/model-server/models/noop/noop-v1.0.model) Simple Noop model which returns "Hello world" to any input specified. - [noop_echo: noop_echo-v1.0](https://s3.amazonaws.com/model-server/models/noop/noop_echo-v1.0.model) Simple Noop model which returns whatever input is given to it. ## Benchmarks We support several basic benchmarks: - throughput: Run inference with enough threads to occupy all workers and ensure full saturation of resources to find the throughput. The number of threads defaults to 100. - latency: Run inference with a single thread to determine the latency - ping: Test the throughput of pinging against the frontend - load: Loads the same model many times in parallel. The number of loads is given by the "count" option and defaults to 16. - repeated_scale_calls: Will scale the model up to "scale_up_workers"=16 then down to "scale_down_workers"=1 then up and down repeatedly. - multiple_models: Loads and scales up three models (1. noop, 2. lstm, and 3. resnet), at the same time, runs inferences on them, and then scales them down. Use the options "urlN", "modelN_name", "dataN" to specify the model url, model name, and the data to pass to the model respectively. data1 and data2 are of the format "'Some garbage data being passed here'" and data3 is the filesystem path to a file to upload. 
We also support compound benchmarks: - concurrent_inference: Runs the basic benchmark with different numbers of threads ## Examples Run basic latency test on default resnet-18 model\ ```./benchmark.py latency``` Run basic throughput test on default resnet-18 model.\ ```./benchmark.py throughput``` Run all benchmarks\ ```./benchmark.py --all``` Run using the noop-v1.0 model\ ```./benchmark.py latency -m noop_v1.0``` Run on GPU (4 gpus)\ ```./benchmark.py latency -g 4``` Run with a custom image\ ```./benchmark.py latency -i {imageFilePath}``` Run with a custom model (works only for CNN based models, which accept image as an input for now. We will add support for more input types in future to this command. )\ ```./benchmark.py latency -c {modelUrl} -i {imageFilePath}``` Run with custom options\ ```./benchmark.py repeated_scale_calls --options scale_up_workers 100 scale_down_workers 10``` Run against an already running instance of MMS\ ```./benchmark.py latency --mms 127.0.0.1``` (defaults to http, port 80, management port = port + 1)\ ```./benchmark.py latency --mms 127.0.0.1:8080 --management-port 8081```\ ```./benchmark.py latency --mms https://127.0.0.1:8443``` Run verbose with only a single loop\ ```./benchmark.py latency -v -l 1``` ## Benchmark options The full list of options can be found by running with the -h or --help flags. ## Profiling ### Frontend The benchmarks can be used in conjunction with standard profiling tools such as JProfiler to analyze the system performance. JProfiler can be downloaded from their [website](https://www.ej-technologies.com/products/jprofiler/overview.html). Once downloaded, open up JProfiler and follow these steps: 1. Run MMS directly through gradle (do not use docker). This can be done either on your machine or on a remote machine accessible through SSH. 2. In JProfiler, select "Attach" from the ribbon and attach to the ModelServer. The process name in the attach window should be "com.amazonaws.ml.mms.ModelServer". 
If it is on a remote machine, select "On another computer" in the attach window and enter the SSH details. For the session startup settings, you can leave it with the defaults. At this point, you should see live CPU and Memory Usage data on JProfiler's Telemetries section. 3. Select Start Recordings in JProfiler's ribbon 4. Run the Benchmark script targeting your running MMS instance. It might run something like `./benchmark.py throughput --mms https://127.0.0.1:8443`. It can be run on either your local machine or a remote machine (if you are running remote), but we recommend running the benchmark on the same machine as the model server to avoid confounding network latencies. 5. Once the benchmark script has finished running, select Stop Recordings in JProfiler's ribbon Once you have stopped recording, you should be able to analyze the data. One useful section to examine is CPU views > Call Tree and CPU views > Hot Spots to see where the processor time is going. ### Backend The benchmarks can also be used to analyze the backend performance using cProfile. It does not require any additional packages to run the benchmark, but viewing the logs does require an additional package. Run `pip install snakeviz` to install this. To run the python profiling, follow these steps: 1. In the file `mms/model_service_worker.py`, set the constant BENCHMARK to true at the top to enable benchmarking. 2. Run the benchmark and MMS. They can either be done automatically inside the docker container or separately with the "--mms" flag. 3. Run MMS directly through gradle (do not use docker). This can be done either on your machine or on a remote machine accessible through SSH. 4. Run the Benchmark script targeting your running MMS instance. It might run something like `./benchmark.py throughput --mms https://127.0.0.1:8443`. 
It can be run on either your local machine or a remote machine (if you are running remote), but we recommend running the benchmark on the same machine as the model server to avoid confounding network latencies. 5. Run `snakeviz /tmp/mmsPythonProfile.prof` to view the profiling data. It should start up a web server on your machine and automatically open the page. 6. Don't forget to set BENCHMARK = False in the model_service_worker.py file after you are finished. ================================================ FILE: benchmarks/benchmark.py ================================================ #!/usr/bin/env python3 # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Execute the MMS Benchmark. 
For instructions, run with the --help flag """ # pylint: disable=redefined-builtin import argparse import itertools import multiprocessing import os import pprint import shutil import subprocess import sys import time import traceback from functools import reduce from urllib.request import urlretrieve import pandas as pd BENCHMARK_DIR = "/tmp/MMSBenchmark/" OUT_DIR = os.path.join(BENCHMARK_DIR, 'out/') RESOURCE_DIR = os.path.join(BENCHMARK_DIR, 'resource/') RESOURCE_MAP = { 'kitten.jpg': 'https://s3.amazonaws.com/model-server/inputs/kitten.jpg' } # Listing out all the JMX files JMX_IMAGE_INPUT_MODEL_PLAN = 'imageInputModelPlan.jmx' JMX_TEXT_INPUT_MODEL_PLAN = 'textInputModelPlan.jmx' JMX_PING_PLAN = 'pingPlan.jmx' JMX_CONCURRENT_LOAD_PLAN = 'concurrentLoadPlan.jmx' JMX_CONCURRENT_SCALE_CALLS = 'concurrentScaleCalls.jmx' JMX_MULTIPLE_MODELS_LOAD_PLAN = 'multipleModelsLoadPlan.jmx' JMX_GRAPHS_GENERATOR_PLAN = 'graphsGenerator.jmx' # Listing out the models tested MODEL_RESNET_18 = 'resnet-18' MODEL_SQUEEZE_NET = 'squeezenet' MODEL_LSTM_PTB = 'lstm_ptb' MODEL_NOOP = 'noop-v1.0' MODEL_MAP = { MODEL_SQUEEZE_NET: (JMX_IMAGE_INPUT_MODEL_PLAN, {'url': 'https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model', 'model_name': MODEL_SQUEEZE_NET, 'input_filepath': 'kitten.jpg'}), MODEL_RESNET_18: (JMX_IMAGE_INPUT_MODEL_PLAN, {'url': 'https://s3.amazonaws.com/model-server/models/resnet-18/resnet-18.model', 'model_name': MODEL_RESNET_18, 'input_filepath': 'kitten.jpg'}), MODEL_LSTM_PTB: (JMX_TEXT_INPUT_MODEL_PLAN, {'url': 'https://s3.amazonaws.com/model-server/models/lstm_ptb/lstm_ptb.model', 'model_name': MODEL_LSTM_PTB, 'data': 'lstm_ip.json'}), MODEL_NOOP: (JMX_TEXT_INPUT_MODEL_PLAN, {'url': 'https://s3.amazonaws.com/model-server/models/noop/noop-v1.0.mar', 'model_name': MODEL_NOOP, 'data': 'noop_ip.txt'}) } # Mapping of which row is relevant for a given JMX Test Plan EXPERIMENT_RESULTS_MAP = { JMX_IMAGE_INPUT_MODEL_PLAN: ['Inference Request'], 
JMX_TEXT_INPUT_MODEL_PLAN: ['Inference Request'], JMX_PING_PLAN: ['Ping Request'], JMX_CONCURRENT_LOAD_PLAN: ['Load Model Request'], JMX_CONCURRENT_SCALE_CALLS: ['Scale Up Model', 'Scale Down Model'], JMX_MULTIPLE_MODELS_LOAD_PLAN: ['Inference Request'] } JMETER_RESULT_SETTINGS = { 'jmeter.reportgenerator.overall_granularity': 1000, # 'jmeter.reportgenerator.report_title': '"MMS Benchmark Report Dashboard"', 'aggregate_rpt_pct1': 50, 'aggregate_rpt_pct2': 90, 'aggregate_rpt_pct3': 99, } # Dictionary of what's present in the output csv generated v/s what we want to change the column name to for readability AGGREGATE_REPORT_CSV_LABELS_MAP = { 'aggregate_report_rate': 'Throughput', 'average': 'Average', 'aggregate_report_median': 'Median', 'aggregate_report_90%_line': 'aggregate_report_90_line', 'aggregate_report_99%_line': 'aggregate_report_99_line', 'aggregate_report_error%': 'aggregate_report_error' } CELLAR = '/home/ubuntu/.linuxbrew/Cellar/jmeter' if 'linux' in sys.platform else '/usr/local/Cellar/jmeter' JMETER_VERSION = os.listdir(CELLAR)[0] CMDRUNNER = '{}/{}/libexec/lib/ext/CMDRunner.jar'.format(CELLAR, JMETER_VERSION) JMETER = '{}/{}/libexec/bin/jmeter'.format(CELLAR, JMETER_VERSION) MMS_BASE = reduce(lambda val,func: func(val), (os.path.abspath(__file__),) + (os.path.dirname,) * 2) JMX_BASE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'jmx') CONFIG_PROP = os.path.join(MMS_BASE, 'benchmarks', 'config.properties') CONFIG_PROP_TEMPLATE = os.path.join(MMS_BASE, 'benchmarks', 'config_template.properties') DOCKER_MMS_BASE = "/multi-model-server" DOCKER_CONFIG_PROP = os.path.join(DOCKER_MMS_BASE, 'benchmarks', 'config.properties') # Commenting our NOOPs for now since there's a bug on MMS model loading for .mar files ALL_BENCHMARKS = list(itertools.product(('latency', 'throughput'), (MODEL_RESNET_18,MODEL_NOOP, MODEL_LSTM_PTB))) # + [('multiple_models', MODEL_NOOP)] # + list(itertools.product(('load', 'repeated_scale_calls'), (MODEL_RESNET_18,))) \ To 
Add once # repeated_scale_calls is fixed BENCHMARK_NAMES = ['latency', 'throughput'] class ChDir: def __init__(self, path): self.curPath = os.getcwd() self.path = path def __enter__(self): os.chdir(self.path) def __exit__(self, *args): os.chdir(self.curPath) def basename(path): return os.path.splitext(os.path.basename(path))[0] def get_resource(name): url = RESOURCE_MAP[name] path = os.path.join(RESOURCE_DIR, name) if not os.path.exists(path): directory = os.path.dirname(path) if not os.path.exists(directory): os.makedirs(directory) urlretrieve(url, path) return path def run_process(cmd, wait=True, **kwargs): output = None if pargs.verbose else subprocess.DEVNULL if pargs.verbose: print(' '.join(cmd) if isinstance(cmd, list) else cmd) if not kwargs.get('shell') and isinstance(cmd, str): cmd = cmd.split(' ') if 'stdout' not in kwargs: kwargs['stdout'] = output if 'stderr' not in kwargs: kwargs['stderr'] = output p = subprocess.Popen(cmd, **kwargs) if wait: p.wait() return p def run_single_benchmark(jmx, jmeter_args=dict(), threads=100, out_dir=None): if out_dir is None: out_dir = os.path.join(OUT_DIR, benchmark_name, basename(benchmark_model)) if os.path.exists(out_dir): shutil.rmtree(out_dir) os.makedirs(out_dir) protocol = 'http' hostname = '127.0.0.1' port = 8080 threads = pargs.threads[0] if pargs.threads else threads workers = pargs.workers[0] if pargs.workers else ( pargs.gpus[0] if pargs.gpus else multiprocessing.cpu_count() ) if pargs.mms: url = pargs.mms[0] if '://' in url: protocol, url = url.split('://') if ':' in url: hostname, port = url.split(':') port = int(port) else: hostname = url port = 80 else: # Start MMS docker = 'nvidia-docker' if pargs.gpus else 'docker' container = 'mms_benchmark_gpu' if pargs.gpus else 'mms_benchmark_cpu' docker_path = 'awsdeeplearningteam/multi-model-server:nightly-mxnet-gpu' \ if pargs.gpus else 'awsdeeplearningteam/multi-model-server:nightly-mxnet-cpu' if pargs.docker: container = 
'mms_benchmark_{}'.format(pargs.docker[0].split('/')[1]) docker_path = pargs.docker[0] run_process("{} rm -f {}".format(docker, container)) docker_run_call = "{} run --name {} -p 8080:8080 -p 8081:8081 -itd {}".format(docker, container, docker_path) run_process(docker_run_call) management_port = int(pargs.management[0]) if pargs.management else port + 1 time.sleep(300) try: # temp files tmpfile = os.path.join(out_dir, 'output.jtl') logfile = os.path.join(out_dir, 'jmeter.log') outfile = os.path.join(out_dir, 'out.csv') perfmon_file = os.path.join(out_dir, 'perfmon.csv') graphsDir = os.path.join(out_dir, 'graphs') reportDir = os.path.join(out_dir, 'report') # run jmeter run_jmeter_args = { 'hostname': hostname, 'port': port, 'management_port': management_port, 'protocol': protocol, 'min_workers': workers, 'rampup': 5, 'threads': threads, 'loops': int(pargs.loops[0]), 'perfmon_file': perfmon_file } run_jmeter_args.update(JMETER_RESULT_SETTINGS) run_jmeter_args.update(jmeter_args) run_jmeter_args.update(dict(zip(pargs.options[::2], pargs.options[1::2]))) abs_jmx = jmx if os.path.isabs(jmx) else os.path.join(JMX_BASE, jmx) jmeter_args_str = ' '.join(sorted(['-J{}={}'.format(key, val) for key, val in run_jmeter_args.items()])) jmeter_call = '{} -n -t {} {} -l {} -j {} -e -o {}'.format(JMETER, abs_jmx, jmeter_args_str, tmpfile, logfile, reportDir) run_process(jmeter_call) time.sleep(30) # run AggregateReport ag_call = 'java -jar {} --tool Reporter --generate-csv {} --input-jtl {} --plugin-type AggregateReport'.format(CMDRUNNER, outfile, tmpfile) run_process(ag_call) # Generate output graphs gLogfile = os.path.join(out_dir, 'graph_jmeter.log') graphing_args = { 'raw_output': graphsDir, 'jtl_input': tmpfile } graphing_args.update(JMETER_RESULT_SETTINGS) gjmx = os.path.join(JMX_BASE, JMX_GRAPHS_GENERATOR_PLAN) graphing_args_str = ' '.join(['-J{}={}'.format(key, val) for key, val in graphing_args.items()]) graphing_call = '{} -n -t {} {} -j {}'.format(JMETER, gjmx, 
graphing_args_str, gLogfile) run_process(graphing_call) print("Output available at {}".format(out_dir)) print("Report generated at {}".format(os.path.join(reportDir, 'index.html'))) data_frame = pd.read_csv(outfile, index_col=0) report = list() for val in EXPERIMENT_RESULTS_MAP[jmx]: for full_val in [fv for fv in data_frame.index if val in fv]: report.append(decorate_metrics(data_frame, full_val)) return report except Exception: # pylint: disable=broad-except traceback.print_exc() def run_multi_benchmark(key, xs, *args, **kwargs): out_dir = os.path.join(OUT_DIR, benchmark_name, basename(benchmark_model)) if os.path.exists(out_dir): shutil.rmtree(out_dir) os.makedirs(out_dir) reports = dict() out_dirs = [] for i, x in enumerate(xs): print("Running value {}={} (value {}/{})".format(key, x, i+1, len(xs))) kwargs[key] = x sub_out_dir = os.path.join(out_dir, str(i+1)) out_dirs.append(sub_out_dir) report = run_single_benchmark(*args, out_dir=sub_out_dir, **kwargs) reports[x] = report # files merge_results = os.path.join(out_dir, 'merge-results.properties') joined = os.path.join(out_dir, 'joined.csv') reportDir = os.path.join(out_dir, 'report') # merge runs together inputJtls = [os.path.join(out_dirs[i], 'output.jtl') for i in range(len(xs))] prefixes = ["{} {}: ".format(key, x) for x in xs] baseJtl = inputJtls[0] basePrefix = prefixes[0] for i in range(1, len(xs), 3): # MergeResults only joins up to 4 at a time with open(merge_results, 'w') as f: curInputJtls = [baseJtl] + inputJtls[i:i+3] curPrefixes = [basePrefix] + prefixes[i:i+3] for j, (jtl, p) in enumerate(zip(curInputJtls, curPrefixes)): f.write("inputJtl{}={}\n".format(j+1, jtl)) f.write("prefixLabel{}={}\n".format(j+1, p)) f.write("\n") merge_call = 'java -jar {} --tool Reporter --generate-csv joined.csv --input-jtl {} --plugin-type MergeResults'.format(CMDRUNNER, merge_results) time.sleep(30) run_process(merge_call) shutil.move('joined.csv', joined) # MergeResults ignores path given and puts result into cwd 
baseJtl = joined basePrefix = "" # build report time.sleep(30) run_process('{} -g {} -o {}'.format(JMETER, joined, reportDir)) print("Merged output available at {}".format(out_dir)) print("Merged report generated at {}".format(os.path.join(reportDir, 'index.html'))) return reports def parseModel(): if benchmark_model in MODEL_MAP: plan, jmeter_args = MODEL_MAP[benchmark_model] for k, v in jmeter_args.items(): if v in RESOURCE_MAP: jmeter_args[k] = get_resource(v) if k == 'data': jmeter_args[k] = os.path.join(MMS_BASE, 'benchmarks', v) if pargs.input: jmeter_args['input_filepath'] = pargs.input[0] else: plan = JMX_IMAGE_INPUT_MODEL_PLAN jmeter_args = { 'url': benchmark_model, 'model_name': basename(benchmark_model), 'input_filepath': pargs.input[0] } return plan, jmeter_args def decorate_metrics(data_frame, row_to_read): temp_dict = data_frame.loc[row_to_read].to_dict() result = dict() row_name = row_to_read.replace(' ', '_') for key, value in temp_dict.items(): if key in AGGREGATE_REPORT_CSV_LABELS_MAP: new_key = '{}_{}_{}_{}'.format(benchmark_name, benchmark_model, row_name, AGGREGATE_REPORT_CSV_LABELS_MAP[key]) result[new_key] = value return result class Benchmarks: """ Contains benchmarks to run """ @staticmethod def throughput(): """ Performs a simple single benchmark that measures the model throughput on inference tasks """ plan, jmeter_args = parseModel() return run_single_benchmark(plan, jmeter_args) @staticmethod def latency(): """ Performs a simple single benchmark that measures the model latency on inference tasks """ plan, jmeter_args = parseModel() return run_single_benchmark(plan, jmeter_args, threads=1) @staticmethod def ping(): """ Performs a simple ping benchmark that measures the throughput for a ping request to the frontend """ return run_single_benchmark(JMX_PING_PLAN, dict(), threads=5000) @staticmethod def load(): """ Benchmarks number of concurrent inference requests """ plan, jmeter_args = parseModel() plan = JMX_CONCURRENT_LOAD_PLAN 
jmeter_args['count'] = 8 return run_single_benchmark(plan, jmeter_args) @staticmethod def repeated_scale_calls(): """ Benchmarks number of concurrent inference requests """ plan, jmeter_args = parseModel() plan = JMX_CONCURRENT_SCALE_CALLS jmeter_args['scale_up_workers'] = 16 jmeter_args['scale_down_workers'] = 2 return run_single_benchmark(plan, jmeter_args) @staticmethod def multiple_models(): """ Tests with 3 models """ plan = JMX_MULTIPLE_MODELS_LOAD_PLAN jmeter_args = { 'url1': MODEL_MAP[MODEL_NOOP][1]['url'], 'url2': MODEL_MAP[MODEL_LSTM_PTB][1]['url'], 'url3': MODEL_MAP[MODEL_RESNET_18][1]['url'], 'model1_name': MODEL_MAP[MODEL_NOOP][1]['model_name'], 'model2_name': MODEL_MAP[MODEL_LSTM_PTB][1]['model_name'], 'model3_name': MODEL_MAP[MODEL_RESNET_18][1]['model_name'], 'data3': get_resource('kitten.jpg') } return run_single_benchmark(plan, jmeter_args) @staticmethod def concurrent_inference(): """ Benchmarks number of concurrent inference requests """ plan, jmeter_args = parseModel() return run_multi_benchmark('threads', range(1, 3*5+1, 3), plan, jmeter_args) def run_benchmark(): if hasattr(Benchmarks, benchmark_name): print("Running benchmark {} with model {}".format(benchmark_name, benchmark_model)) res = getattr(Benchmarks, benchmark_name)() pprint.pprint(res) print('\n') else: raise Exception("No benchmark benchmark_named {}".format(benchmark_name)) def modify_config_props_for_mms(pargs): shutil.copyfile(CONFIG_PROP_TEMPLATE, CONFIG_PROP) with open(CONFIG_PROP, 'a') as f: f.write('\nnumber_of_netty_threads=32') f.write('\njob_queue_size=1000') if pargs.gpus: f.write('\nnumber_of_gpu={}'.format(pargs.gpus[0])) if __name__ == '__main__': benchmark_name_options = [f for f in dir(Benchmarks) if callable(getattr(Benchmarks, f)) and f[0] != '_'] parser = argparse.ArgumentParser(prog='multi-model-server-benchmarks', description='Benchmark Multi Model Server') target = parser.add_mutually_exclusive_group(required=True) target.add_argument('name', nargs='?', 
type=str, choices=benchmark_name_options, help='The name of the benchmark to run') target.add_argument('-a', '--all', action='store_true', help='Run all benchmarks') target.add_argument('-s', '--suite', action='store_true', help='Run throughput and latency on a supplied model') model = parser.add_mutually_exclusive_group() model.add_argument('-m', '--model', nargs=1, type=str, dest='model', default=[MODEL_RESNET_18], choices=MODEL_MAP.keys(), help='A preloaded model to run. It defaults to {}'.format(MODEL_RESNET_18)) model.add_argument('-c', '--custom-model', nargs=1, type=str, dest='model', help='The path to a custom model to run. The input argument must also be passed. Currently broken') parser.add_argument('-d', '--docker', nargs=1, type=str, default=None, help='Docker hub path to use') parser.add_argument('-i', '--input', nargs=1, type=str, default=None, help='The input to feed to the test') parser.add_argument('-g', '--gpus', nargs=1, type=int, default=None, help='Number of gpus. Leave empty to run CPU only') parser.add_argument('-l', '--loops', nargs=1, type=int, default=[10], help='Number of loops to run') parser.add_argument('-t', '--threads', nargs=1, type=int, default=None, help='Number of jmeter threads to run') parser.add_argument('-w', '--workers', nargs=1, type=int, default=None, help='Number of MMS backend workers to use') parser.add_argument('--mms', nargs=1, type=str, help='Target an already running instance of MMS instead of spinning up a docker container of MMS. Specify the target with the format address:port (for http) or protocol://address:port') parser.add_argument('--management-port', dest='management', nargs=1, type=str, help='When targeting a running MMS instance, specify the management port') parser.add_argument('-v', '--verbose', action='store_true', help='Display all output') parser.add_argument('--options', nargs='*', default=[], help='Additional jmeter arguments. 
It should follow the format of --options argname1 argval1 argname2 argval2 ...') pargs = parser.parse_args() if os.path.exists(OUT_DIR): if pargs.all: shutil.rmtree(OUT_DIR) os.makedirs(OUT_DIR) else: os.makedirs(OUT_DIR) modify_config_props_for_mms(pargs) if pargs.suite: benchmark_model = pargs.model[0].lower() for benchmark_name in BENCHMARK_NAMES: run_benchmark() if not os.path.isdir(os.path.join(OUT_DIR, benchmark_name, basename(benchmark_model), 'report')): run_benchmark() elif pargs.all: for benchmark_name, benchmark_model in ALL_BENCHMARKS: run_benchmark() else: benchmark_name = pargs.name.lower() benchmark_model = pargs.model[0].lower() run_benchmark() ================================================ FILE: benchmarks/install_dependencies.sh ================================================ #!/bin/bash # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # This file contains the installation setup for running benchmarks on EC2 isntance. 
# To run on a machine with GPU : ./install_dependencies True # To run on a machine with CPU : ./install_dependencies False set -ex sudo apt-get update sudo apt-get -y upgrade echo "Setting up your Ubuntu machine to load test MMS" sudo apt-get install -y \ python \ python-pip \ python3-pip \ python3-tk \ python-psutil \ default-jre \ default-jdk \ linuxbrew-wrapper \ build-essential if [[ $1 = True ]] then echo "Installing pip packages for GPU" sudo apt install -y nvidia-cuda-toolkit pip install future psutil mxnet-cu92 pillow --user else echo "Installing pip packages for CPU" pip install future psutil mxnet pillow --user fi pip3 install pandas echo "Installing JMeter through Brew" # Script would end on errors, but everything works fine { yes '' | brew update } || { true } { brew install jmeter --with-plugins } || { true } wget https://jmeter-plugins.org/get/ -O /home/ubuntu/.linuxbrew/Cellar/jmeter/5.0/libexec/lib/ext/jmeter-plugins-manager-1.3.jar wget http://search.maven.org/remotecontent?filepath=kg/apc/cmdrunner/2.2/cmdrunner-2.2.jar -O /home/ubuntu/.linuxbrew/Cellar/jmeter/5.0/libexec/lib/cmdrunner-2.2.jar java -cp /home/ubuntu/.linuxbrew/Cellar/jmeter/5.0/libexec/lib/ext/jmeter-plugins-manager-1.3.jar org.jmeterplugins.repository.PluginManagerCMDInstaller /home/ubuntu/.linuxbrew/Cellar/jmeter/5.0/libexec/bin/PluginsManagerCMD.sh install jpgc-synthesis=2.1,jpgc-filterresults=2.1,jpgc-mergeresults=2.1,jpgc-cmd=2.1,jpgc-perfmon=2.1 echo "Install docker" sudo apt-get remove docker docker-engine docker.io sudo apt-get install -y \ apt-transport-https \ ca-certificates \ curl \ software-properties-common curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - sudo add-apt-repository \ "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ $(lsb_release -cs) \ stable" sudo apt-get update sudo apt-get install -y docker-ce { sudo groupadd docker || {true} } || { true } { gpasswd -a $USER docker } || { true } if [[ $1 = True ]] then echo 
"Installing nvidia-docker" # If you have nvidia-docker 1.0 installed: we need to remove it and all existing GPU containers { docker volume ls -q -f driver=nvidia-docker | xargs -r -I{} -n1 docker ps -q -a -f volume={} | xargs -r docker rm -f } || { true } { sudo apt-get purge -y nvidia-docker } || { true } # Add the package repositories curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | \ sudo apt-key add - distribution=$(. /etc/os-release;echo $ID$VERSION_ID) curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | \ sudo tee /etc/apt/sources.list.d/nvidia-docker.list sudo apt-get update # Install nvidia-docker2 and reload the Docker daemon configuration sudo apt-get install -y nvidia-docker2 sudo pkill -SIGHUP dockerd # Test nvidia-smi with the latest official CUDA image docker run --runtime=nvidia --rm nvidia/cuda nvidia-smi fi ================================================ FILE: benchmarks/jmx/concurrentLoadPlan.jmx ================================================ false true false model_url ${__P(url, https://s3.amazonaws.com/model-server/models/resnet-18/resnet-18.model)} = The url from where to fetch the models from model ${__P(model_name,resnet-18)} = count ${__P(count,10)} = ${__P(hostname,127.0.0.1)} ${__P(management_port,8444)} ${__P(protocol,https)} 6 continue false ${__P(loops,1)} ${count} 1 false false ${ctr} = true ctr /models?url=${model}-${__counter(FALSE, )}&model_name=${model}-${__counter(FALSE,)} POST true false true false ${count} 0 ================================================ FILE: benchmarks/jmx/concurrentScaleCalls.jmx ================================================ false true false model_url ${__P(url, https://s3.amazonaws.com/model-server/models/resnet-18/resnet-18.model)} = The url from where to fetch the models from model ${__P(model_name,resnet-18)} = Name of the model to run the tests on count ${__P(count,1)} = scale_up_workers ${__P(scale_up_workers,1)} = The number of workers to scale model 
to scale_down_workers ${__P(scale_down_workers,1)} Scale down the workers = ${__P(hostname,127.0.0.1)} ${__P(management_port,8444)} ${__P(protocol,https)} 6 continue false 1 1 1 false /models?url=${model_url} POST true false true false continue false ${__P(loops,10)} ${__P(threads,2)} ${__P(rampup,5)} false /models/${model}?min_worker=${scale_up_workers} PUT true false true false /models/${model}?min_worker=${scale_down_workers} PUT true false true false ${count} 0 continue false 1 1 1 false /models/${model}?min_worker=0 DELETE true false true false ================================================ FILE: benchmarks/jmx/graphsGenerator.jmx ================================================ false true false continue false 1 1 1 false false false false 0 ${__P(graph_prefix,g_)} 1000 600 800 Inference Request false 150 ${__P(raw_output)} true true false true ${__P(jtl_input)} ================================================ FILE: benchmarks/jmx/imageInputModelPlan.jmx ================================================ false true false cnn_url ${__P(url, https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model)} = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model ${__P(model_name,squeezenet_v1.1)} = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,200)} ${__P(threads,20)} ${__P(rampup,5)} false ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: 
benchmarks/jmx/multipleModelsLoadPlan.jmx ================================================ false true false url1 ${__P(url1,noop.model)} = The url from where to fetch noop model from url2 ${__P(url2,lstm_ptb)} = url3 ${__P(url3,https://s3.amazonaws.com/model-server/models/resnet-18/resnet-18.model)} = model1 ${__P(model1_name,noop)} = model2 ${__P(model2_name,lstm_ptb)} = model3 ${__P(model3_name,resnet-18)} = scale_up_workers ${__P(min_workers,1)} = scale_down_workers ${__P(scale_down_workers, 0)} = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false ${__P(management_port,8444)} /models?url=${url1} POST true false true false true false = ${__P(management_port,8444)} /models/${model1}?min_worker=${scale_up_workers} PUT true false true false ${__P(management_port,8444)} /models?url=${url2} POST true false true false true false = ${__P(management_port,8444)} /models/${model2}?min_worker=${scale_up_workers} PUT true false true false ${__P(management_port,8444)} /models?url=${url3} POST true false true false ${__P(management_port,8444)} /models/${model3}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,2)} ${__P(threads,2)} ${__P(rampup,5)} false false ${__P(data1,'Some garbage data being passed here')} = true data /predictions/${model1} POST true false true false false ${__P(data2,'Some garbage data being passed here')} = true data /predictions/${model2} POST true false true false ${__P(data3)} data image/jpeg /predictions/${model3} POST true false true true continue false 1 1 1 false ${__P(management_port,8444)} /models/${model1}?min_worker=${scale_down_workers} DELETE true false true false ${__P(management_port,8444)} /models/${model2}?min_worker=${scale_down_workers} DELETE true false true false ${__P(management_port,8444)} /models/${model3}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: benchmarks/jmx/pingPlan.jmx 
================================================ false true false ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,200)} ${__P(threads,200)} ${__P(rampup,5)} false /ping GET true false true false ================================================ FILE: benchmarks/jmx/textInputModelPlan.jmx ================================================ false true false noop_url ${__P(url,noop.model)} = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model ${__P(model_name,noop)} Model name = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false ${__P(management_port,8444)} /models?url=${noop_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,200)} ${__P(threads,200)} ${__P(rampup,5)} false ${__P(data,)} data application/json /predictions/${model} POST true false true true continue false 1 1 1 false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: benchmarks/lstm_ip.json ================================================ [{"input_sentence": "on the exchange floor as soon as ual stopped trading we for a panic said one top floor trader"}] ================================================ FILE: benchmarks/mac_install_dependencies.sh ================================================ #!/bin/bash # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # This file contains the installation setup for running benchmarks on EC2 instance. # To run on a machine with GPU : ./install_dependencies True # To run on a machine with CPU : ./install_dependencies False set -ex echo "Installing JMeter through Brew" # Script would end on errors, but everything works fine brew update || { } brew install jmeter --with-plugins || { } wget https://jmeter-plugins.org/get/ -O /usr/local/Cellar/jmeter/4.0/libexec/lib/ext/jmeter-plugins-manager-1.3.jar wget http://search.maven.org/remotecontent?filepath=kg/apc/cmdrunner/2.2/cmdrunner-2.2.jar -O /usr/local/Cellar/jmeter/4.0/libexec/lib/cmdrunner-2.2.jar java -cp /usr/local/Cellar/jmeter/4.0/libexec/lib/ext/jmeter-plugins-manager-1.3.jar org.jmeterplugins.repository.PluginManagerCMDInstaller /usr/local/Cellar/jmeter/4.0/libexec/bin/PluginsManagerCMD.sh install jpgc-synthesis=2.1,jpgc-filterresults=2.1,jpgc-mergeresults=2.1,jpgc-cmd=2.1,jpgc-perfmon=2.1 ================================================ FILE: benchmarks/noop_ip.txt ================================================ "[{\"input_sentence\": \"Hello World\"}]" ================================================ FILE: benchmarks/upload_results_to_s3.sh ================================================ #!/usr/bin/env bash # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied.
See the License for the specific language governing # permissions and limitations under the License. #Author: Piyush Ghai set -ex echo "uploading result files to s3" hw_type=cpu if [ "$1" = "True" ] then hw_type=gpu fi echo `pwd` cd /tmp/MMSBenchmark/out echo `pwd` today=`date +"%m-%d-%y"` echo "Saving on S3 bucket on s3://benchmarkai-metrics-prod/daily/mms/$hw_type/$today" for dir in $(ls `pwd`/) do echo $dir aws s3 cp $dir/ s3://benchmarkai-metrics-prod/daily/mms/$hw_type/$today/$dir/ --recursive done echo "Files uploaded" ================================================ FILE: ci/Dockerfile.python2.7 ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
# FROM ubuntu:14.04.5 ENV LANG="C.UTF-8" ENV DOCKER_BUCKET="download.docker.com" \ DOCKER_VERSION="17.09.0-ce" \ DOCKER_CHANNEL="stable" \ DOCKER_SHA256="a9e90a73c3cdfbf238f148e1ec0eaff5eb181f92f35bdd938fd7dab18e1c4647" \ DIND_COMMIT="3b5fac462d21ca164b3778647420016315289034" \ DOCKER_COMPOSE_VERSION="1.16.1" \ GITVERSION_VERSION="3.6.5" # Install git RUN set -ex \ && apt-get update \ && apt-get install software-properties-common -y --no-install-recommends\ && apt-add-repository ppa:git-core/ppa \ && apt-get update \ && apt-get install git -y --no-install-recommends\ && git version RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && apt-get install -y --no-install-recommends wget=1.15-* fakeroot=1.20-* ca-certificates \ autoconf=2.69-* automake=1:1.14.* less=458-* groff=1.22.* \ bzip2=1.0.* file=1:5.14-* g++=4:4.8.* gcc=4:4.8.* imagemagick=8:6.7.* \ libbz2-dev=1.0.* libc6-dev=2.19-* libcurl4-openssl-dev=7.35.* curl=7.35.* \ libdb-dev=1:5.3.* libevent-dev=2.0.* libffi-dev=3.1~* \ libgeoip-dev=1.6.* libglib2.0-dev=2.40.* libjpeg-dev=8c-* \ libkrb5-dev=1.12+* liblzma-dev=5.1.* libmagickcore-dev=8:6.7.* \ libmagickwand-dev=8:6.7.* libmysqlclient-dev=5.5.* libncurses5-dev=5.9+* \ libpng12-dev=1.2.* libpq-dev=9.3.* libreadline-dev=6.3-* libsqlite3-dev=3.8.* \ libssl-dev=1.0.* libtool=2.4.* libwebp-dev=0.4.* libxml2-dev=2.9.* \ libxslt1-dev=1.1.* libyaml-dev=0.1.* make=3.81-* patch=2.7.* xz-utils=5.1.* \ zlib1g-dev=1:1.2.* tcl=8.6.* tk=8.6.* \ e2fsprogs=1.42.* iptables=1.4.* xfsprogs=3.1.* xz-utils=5.1.* \ mono-mcs=3.2.* libcurl4-openssl-dev=7.35.* liberror-perl=0.17-* unzip=6.0-*\ && rm -rf /var/lib/apt/lists/* \ && apt-get clean # Download and set up GitVersion RUN set -ex \ && wget "https://github.com/GitTools/GitVersion/releases/download/v${GITVERSION_VERSION}/GitVersion_${GITVERSION_VERSION}.zip" -O /tmp/GitVersion_${GITVERSION_VERSION}.zip \ && mkdir -p 
/usr/local/GitVersion_${GITVERSION_VERSION} \ && unzip /tmp/GitVersion_${GITVERSION_VERSION}.zip -d /usr/local/GitVersion_${GITVERSION_VERSION} \ && rm /tmp/GitVersion_${GITVERSION_VERSION}.zip \ && echo "mono /usr/local/GitVersion_${GITVERSION_VERSION}/GitVersion.exe /output json /showvariable \$1" >> /usr/local/bin/gitversion \ && chmod +x /usr/local/bin/gitversion # Install Docker RUN set -ex \ && curl -fSL "https://${DOCKER_BUCKET}/linux/static/${DOCKER_CHANNEL}/x86_64/docker-${DOCKER_VERSION}.tgz" -o docker.tgz \ && echo "${DOCKER_SHA256} *docker.tgz" | sha256sum -c - \ && tar --extract --file docker.tgz --strip-components 1 --directory /usr/local/bin/ \ && rm docker.tgz \ && docker -v \ # set up subuid/subgid so that "--userns-remap=default" works out-of-the-box && addgroup dockremap \ && useradd -g dockremap dockremap \ && echo 'dockremap:165536:65536' >> /etc/subuid \ && echo 'dockremap:165536:65536' >> /etc/subgid \ && wget "https://raw.githubusercontent.com/docker/docker/${DIND_COMMIT}/hack/dind" -O /usr/local/bin/dind \ && curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-Linux-x86_64 > /usr/local/bin/docker-compose \ && chmod +x /usr/local/bin/dind /usr/local/bin/docker-compose \ # Ensure docker-compose works && docker-compose version VOLUME /var/lib/docker COPY dockerd-entrypoint.sh /usr/local/bin/ ENV PATH="/usr/local/bin:$PATH" \ GPG_KEY="C01E1CAD5EA2C4F0B8E3571504C367C218ADD4FF" \ PYTHON_VERSION="2.7.12" \ PYTHON_PIP_VERSION="8.1.2" RUN set -ex \ && apt-get update \ && apt-get install -y --no-install-recommends tcl-dev tk-dev \ && rm -rf /var/lib/apt/lists/* \ \ && wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && (gpg --keyserver hkp://p80.pool.sks-keyservers.net:80 
--recv-keys "$GPG_KEY" \ || gpg --keyserver pgp.mit.edu --recv-keys "$GPG_KEY" \ || gpg --keyserver keyserver.ubuntu.com --recv-keys "$GPG_KEY") \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && rm -r "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ \ && cd /usr/src/python \ && ./configure \ --enable-shared \ --enable-unicode=ucs4 \ && make -j$(nproc) \ && make install \ && ldconfig \ \ && wget -O /tmp/get-pip.py 'https://bootstrap.pypa.io/get-pip.py' \ && python2 /tmp/get-pip.py "pip==$PYTHON_PIP_VERSION" \ && rm /tmp/get-pip.py \ # we use "--force-reinstall" for the case where the version of pip we're trying to install is the same as the version bundled with Python # ("Requirement already up-to-date: pip==8.1.2 in /usr/local/lib/python3.6/site-packages") # https://github.com/docker-library/python/pull/143#issuecomment-241032683 && pip install --no-cache-dir --upgrade --force-reinstall "pip==$PYTHON_PIP_VERSION" \ && pip install awscli==1.* --no-cache-dir \ # then we use "pip list" to ensure we don't have more than one pip version installed # https://github.com/docker-library/python/pull/100 && [ "$(pip list |tac|tac| awk -F '[ ()]+' '$1 == "pip" { print $2; exit }')" = "$PYTHON_PIP_VERSION" ] \ \ && find /usr/local -depth \ \( \ \( -type d -a -name test -o -name tests \) \ -o \ \( -type f -a -name '*.pyc' -o -name '*.pyo' \) \ \) -exec rm -rf '{}' + \ && apt-get purge -y --auto-remove tcl-dev tk-dev \ && rm -rf /usr/src/python ~/.cache ENV JAVA_VERSION=8 \ JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64" \ JDK_VERSION="8u171-b11-2~14.04" \ JDK_HOME="/usr/lib/jvm/java-8-openjdk-amd64" \ JRE_HOME="/usr/lib/jvm/java-8-openjdk-amd64/jre" \ ANT_VERSION=1.9.6 \ MAVEN_VERSION=3.3.3 \ MAVEN_HOME="/usr/share/maven" \ MAVEN_CONFIG="/root/.m2" \ GRADLE_VERSION=2.7 \ PROPERTIES_COMMON_VERSIION=0.92.37.8 \ PYTHON_TOOL_VERSION="3.3-*" # Install Java RUN set -ex \ 
&& apt-get update \ && apt-get install -y software-properties-common=$PROPERTIES_COMMON_VERSIION \ && add-apt-repository ppa:openjdk-r/ppa \ && apt-get update \ && apt-get -y install python-setuptools=$PYTHON_TOOL_VERSION \ && apt-get -y install openjdk-$JAVA_VERSION-jdk=$JDK_VERSION \ && apt-get clean \ # Ensure Java cacerts symlink points to valid location && update-ca-certificates -f \ && mkdir -p /usr/src/ant \ && wget "http://archive.apache.org/dist/ant/binaries/apache-ant-$ANT_VERSION-bin.tar.gz" -O /usr/src/ant/apache-ant-$ANT_VERSION-bin.tar.gz \ && tar -xzf /usr/src/ant/apache-ant-$ANT_VERSION-bin.tar.gz -C /usr/local \ && ln -s /usr/local/apache-ant-$ANT_VERSION/bin/ant /usr/bin/ant \ && rm -rf /usr/src/ant \ && mkdir -p /usr/share/maven /usr/share/maven/ref $MAVEN_CONFIG \ && curl -fsSL "https://archive.apache.org/dist/maven/maven-3/$MAVEN_VERSION/binaries/apache-maven-$MAVEN_VERSION-bin.tar.gz" \ | tar -xzC /usr/share/maven --strip-components=1 \ && ln -s /usr/share/maven/bin/mvn /usr/bin/mvn \ && mkdir -p /usr/src/gradle \ && wget "https://services.gradle.org/distributions/gradle-$GRADLE_VERSION-bin.zip" -O /usr/src/gradle/gradle-$GRADLE_VERSION-bin.zip \ && unzip /usr/src/gradle/gradle-$GRADLE_VERSION-bin.zip -d /usr/local \ && ln -s /usr/local/gradle-$GRADLE_VERSION/bin/gradle /usr/bin/gradle \ && rm -rf /usr/src/gradle \ && rm -fr /var/lib/apt/lists/* /tmp/* /var/tmp/* COPY m2-settings.xml $MAVEN_CONFIG/settings.xml # MMS build environment RUN set -ex \ && apt-get update \ && pip install retrying \ && pip install mock \ && pip install pytest -U \ && pip install pylint # Install protobuf RUN wget https://github.com/google/protobuf/archive/v3.4.1.zip \ && unzip v3.4.1.zip && rm v3.4.1.zip \ && cd protobuf-3.4.1 && ./autogen.sh && ./configure --prefix=/usr && make && make install && cd .. 
\ && rm -r protobuf-3.4.1 ================================================ FILE: ci/Dockerfile.python3.6 ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # FROM ubuntu:14.04.5 ENV LANG="C.UTF-8" ENV DOCKER_BUCKET="download.docker.com" \ DOCKER_VERSION="17.09.0-ce" \ DOCKER_CHANNEL="stable" \ DOCKER_SHA256="a9e90a73c3cdfbf238f148e1ec0eaff5eb181f92f35bdd938fd7dab18e1c4647" \ DIND_COMMIT="3b5fac462d21ca164b3778647420016315289034" \ DOCKER_COMPOSE_VERSION="1.16.1" \ GITVERSION_VERSION="3.6.5" # Install git RUN set -ex \ && apt-get update \ && apt-get install software-properties-common -y --no-install-recommends\ && apt-add-repository ppa:git-core/ppa \ && apt-get update \ && apt-get install git -y --no-install-recommends\ && git version RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && apt-get install -y --no-install-recommends wget=1.15-* fakeroot=1.20-* ca-certificates \ autoconf=2.69-* automake=1:1.14.* less=458-* groff=1.22.* \ bzip2=1.0.* file=1:5.14-* g++=4:4.8.* gcc=4:4.8.* imagemagick=8:6.7.* \ libbz2-dev=1.0.* libc6-dev=2.19-* libcurl4-openssl-dev=7.35.* curl=7.35.* \ libdb-dev=1:5.3.* libevent-dev=2.0.* libffi-dev=3.1~* \ libgeoip-dev=1.6.* libglib2.0-dev=2.40.* libjpeg-dev=8c-* \ libkrb5-dev=1.12+* liblzma-dev=5.1.* libmagickcore-dev=8:6.7.* \ libmagickwand-dev=8:6.7.* libmysqlclient-dev=5.5.* libncurses5-dev=5.9+* \ 
libpng12-dev=1.2.* libpq-dev=9.3.* libreadline-dev=6.3-* libsqlite3-dev=3.8.* \ libssl-dev=1.0.* libtool=2.4.* libwebp-dev=0.4.* libxml2-dev=2.9.* \ libxslt1-dev=1.1.* libyaml-dev=0.1.* make=3.81-* patch=2.7.* xz-utils=5.1.* \ zlib1g-dev=1:1.2.* tcl=8.6.* tk=8.6.* \ e2fsprogs=1.42.* iptables=1.4.* xfsprogs=3.1.* xz-utils=5.1.* \ mono-mcs=3.2.* libcurl4-openssl-dev=7.35.* liberror-perl=0.17-* unzip=6.0-*\ && rm -rf /var/lib/apt/lists/* \ && apt-get clean # Download and set up GitVersion RUN set -ex \ && wget "https://github.com/GitTools/GitVersion/releases/download/v${GITVERSION_VERSION}/GitVersion_${GITVERSION_VERSION}.zip" -O /tmp/GitVersion_${GITVERSION_VERSION}.zip \ && mkdir -p /usr/local/GitVersion_${GITVERSION_VERSION} \ && unzip /tmp/GitVersion_${GITVERSION_VERSION}.zip -d /usr/local/GitVersion_${GITVERSION_VERSION} \ && rm /tmp/GitVersion_${GITVERSION_VERSION}.zip \ && echo "mono /usr/local/GitVersion_${GITVERSION_VERSION}/GitVersion.exe /output json /showvariable \$1" >> /usr/local/bin/gitversion \ && chmod +x /usr/local/bin/gitversion # Install Docker RUN set -ex \ && curl -fSL "https://${DOCKER_BUCKET}/linux/static/${DOCKER_CHANNEL}/x86_64/docker-${DOCKER_VERSION}.tgz" -o docker.tgz \ && echo "${DOCKER_SHA256} *docker.tgz" | sha256sum -c - \ && tar --extract --file docker.tgz --strip-components 1 --directory /usr/local/bin/ \ && rm docker.tgz \ && docker -v \ # set up subuid/subgid so that "--userns-remap=default" works out-of-the-box && addgroup dockremap \ && useradd -g dockremap dockremap \ && echo 'dockremap:165536:65536' >> /etc/subuid \ && echo 'dockremap:165536:65536' >> /etc/subgid \ && wget "https://raw.githubusercontent.com/docker/docker/${DIND_COMMIT}/hack/dind" -O /usr/local/bin/dind \ && curl -L https://github.com/docker/compose/releases/download/${DOCKER_COMPOSE_VERSION}/docker-compose-Linux-x86_64 > /usr/local/bin/docker-compose \ && chmod +x /usr/local/bin/dind /usr/local/bin/docker-compose \ # Ensure docker-compose works && 
docker-compose version VOLUME /var/lib/docker COPY dockerd-entrypoint.sh /usr/local/bin/ ENV PATH="/usr/local/bin:$PATH" \ GPG_KEY="0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D" \ PYTHON_VERSION="3.6.5" \ PYTHON_PIP_VERSION="10.0.0" \ LC_ALL=C.UTF-8 \ LANG=C.UTF-8 RUN apt-get update && apt-get install -y --no-install-recommends \ tcl-dev tk-dev \ && rm -rf /var/lib/apt/lists/* \ \ && wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && (gpg --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys "$GPG_KEY" \ || gpg --keyserver pgp.mit.edu --recv-keys "$GPG_KEY" \ || gpg --keyserver keyserver.ubuntu.com --recv-keys "$GPG_KEY") \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && rm -r "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ \ && cd /usr/src/python \ && ./configure \ --enable-loadable-sqlite-extensions \ --enable-shared \ && make -j$(nproc) \ && make install \ && ldconfig \ \ # explicit path to "pip3" to ensure distribution-provided "pip3" cannot interfere && if [ ! 
-e /usr/local/bin/pip3 ]; then : \ && wget -O /tmp/get-pip.py 'https://bootstrap.pypa.io/get-pip.py' \ && python3 /tmp/get-pip.py "pip==$PYTHON_PIP_VERSION" \ && rm /tmp/get-pip.py \ ; fi \ # we use "--force-reinstall" for the case where the version of pip we're trying to install is the same as the version bundled with Python # ("Requirement already up-to-date: pip==8.1.2 in /usr/local/lib/python3.6/site-packages") # https://github.com/docker-library/python/pull/143#issuecomment-241032683 && pip3 install --no-cache-dir --upgrade --force-reinstall "pip==$PYTHON_PIP_VERSION" \ && pip install awscli==1.* boto3 pipenv virtualenv --no-cache-dir \ # then we use "pip list" to ensure we don't have more than one pip version installed # https://github.com/docker-library/python/pull/100 && [ "$(pip list |tac|tac| awk -F '[ ()]+' '$1 == "pip" { print $2; exit }')" = "$PYTHON_PIP_VERSION" ] \ \ && find /usr/local -depth \ \( \ \( -type d -a -name test -o -name tests \) \ -o \ \( -type f -a -name '*.pyc' -o -name '*.pyo' \) \ \) -exec rm -rf '{}' + \ && apt-get purge -y --auto-remove tcl-dev tk-dev \ && rm -rf /usr/src/python ~/.cache \ && cd /usr/local/bin \ && { [ -e easy_install ] || ln -s easy_install-* easy_install; } \ && ln -s idle3 idle \ && ln -s pydoc3 pydoc \ && ln -s python3 python \ && ln -s python3-config python-config \ && rm -fr /var/lib/apt/lists/* /tmp/* /var/tmp/* ENV JAVA_VERSION=8 \ JAVA_HOME="/usr/lib/jvm/java-8-openjdk-amd64" \ JDK_VERSION="8u171-b11-2~14.04" \ JDK_HOME="/usr/lib/jvm/java-8-openjdk-amd64" \ JRE_HOME="/usr/lib/jvm/java-8-openjdk-amd64/jre" \ ANT_VERSION=1.9.6 \ MAVEN_VERSION=3.3.3 \ MAVEN_HOME="/usr/share/maven" \ MAVEN_CONFIG="/root/.m2" \ GRADLE_VERSION=2.7 \ PROPERTIES_COMMON_VERSIION=0.92.37.8 \ PYTHON_TOOL_VERSION="3.3-*" # Install Java RUN set -ex \ && apt-get update \ && apt-get install -y software-properties-common=$PROPERTIES_COMMON_VERSIION \ && add-apt-repository ppa:openjdk-r/ppa \ && apt-get update \ && apt-get -y install 
python-setuptools=$PYTHON_TOOL_VERSION \ && apt-get -y install openjdk-$JAVA_VERSION-jdk=$JDK_VERSION \ && apt-get clean \ # Ensure Java cacerts symlink points to valid location && update-ca-certificates -f \ && mkdir -p /usr/src/ant \ && wget "http://archive.apache.org/dist/ant/binaries/apache-ant-$ANT_VERSION-bin.tar.gz" -O /usr/src/ant/apache-ant-$ANT_VERSION-bin.tar.gz \ && tar -xzf /usr/src/ant/apache-ant-$ANT_VERSION-bin.tar.gz -C /usr/local \ && ln -s /usr/local/apache-ant-$ANT_VERSION/bin/ant /usr/bin/ant \ && rm -rf /usr/src/ant \ && mkdir -p /usr/share/maven /usr/share/maven/ref $MAVEN_CONFIG \ && curl -fsSL "https://archive.apache.org/dist/maven/maven-3/$MAVEN_VERSION/binaries/apache-maven-$MAVEN_VERSION-bin.tar.gz" \ | tar -xzC /usr/share/maven --strip-components=1 \ && ln -s /usr/share/maven/bin/mvn /usr/bin/mvn \ && mkdir -p /usr/src/gradle \ && wget "https://services.gradle.org/distributions/gradle-$GRADLE_VERSION-bin.zip" -O /usr/src/gradle/gradle-$GRADLE_VERSION-bin.zip \ && unzip /usr/src/gradle/gradle-$GRADLE_VERSION-bin.zip -d /usr/local \ && ln -s /usr/local/gradle-$GRADLE_VERSION/bin/gradle /usr/bin/gradle \ && rm -rf /usr/src/gradle \ && rm -fr /var/lib/apt/lists/* /tmp/* /var/tmp/* COPY m2-settings.xml $MAVEN_CONFIG/settings.xml # MMS build environment RUN set -ex \ && apt-get update \ && pip install retrying \ && pip install mock \ && pip install pytest -U \ && pip install pylint # Install protobuf RUN wget https://github.com/google/protobuf/archive/v3.4.1.zip \ && unzip v3.4.1.zip && rm v3.4.1.zip \ && cd protobuf-3.4.1 && ./autogen.sh && ./configure --prefix=/usr && make && make install && cd .. \ && rm -r protobuf-3.4.1 ================================================ FILE: ci/README.md ================================================ # Model Server CI build Model Server us AWS codebuild for its CI build. This folder contains scripts that needed for AWS codebuild. 
## buildspec.yml buildspec.yml contains the MMS build logic which will be used by AWS codebuild. ## Docker images MMS uses a customized docker image for its AWS codebuild. To make sure MMS is compatible with both Python2 and Python3, we use two build projects. We published two codebuild docker images on docker hub: * awsdeeplearningteam/mms-build:python2.7 * awsdeeplearningteam/mms-build:python3.6 The following files in this folder are used to create the docker images * Dockerfile.python2.7 - Dockerfile for awsdeeplearningteam/mms-build:python2.7 * Dockerfile.python3.6 - Dockerfile for awsdeeplearningteam/mms-build:python3.6 * dockerd-entrypoint.sh - AWS codebuild entrypoint script, required by AWS codebuild * m2-settings.xml - Limits which repositories can be used by maven/gradle in the docker container, provided by AWS codebuild. ## AWS codebuild local To make it easy for developers to debug build issues locally, MMS supports AWS codebuild local. Developers can use the following command to build MMS locally: ```bash $ cd multi-model-server $ ./run_ci_tests.sh ``` To avoid Pull Request build failures on GitHub, developers should always make sure the local build can pass. ================================================ FILE: ci/buildspec.yml ================================================ # Build Spec for AWS CodeBuild CI version: 0.2 phases: install: commands: - apt-get update - apt-get install -y curl - pip install pip -U - pip install future - pip install Pillow - pip install pytest==4.0.0 - pip install wheel - pip install twine - pip install pytest-mock -U - pip install requests - pip install -U -e . - pip install mxnet==1.5.0 - cd model-archiver/ && pip install -U -e .
&& cd ../ build: commands: - frontend/gradlew -p frontend build - python -m pytest mms/tests/unit_tests - cd model-archiver/ && python -m pytest model_archiver/tests/unit_tests && cd ../ - cd model-archiver/ && python -m pytest model_archiver/tests/integ_tests && cd ../ - cd serving-sdk/ && mvn clean deploy && cd ../ # integration test is broken: https://github.com/awslabs/multi-model-server/issues/437 #- python -m pytest mms/tests/integration_tests - pylint -rn --rcfile=./mms/tests/pylintrc mms/. - cd model-archiver/ && pylint -rn --rcfile=./model_archiver/tests/pylintrc model_archiver/. && cd ../ - $NIGHTLYBUILD - eval $NIGHTLYUPLOAD artifacts: files: - dist/*.whl - model_archiver/dist/*.whl - frontend/server/build/reports/**/* - frontend/modelarchive/build/reports/**/* - frontend/cts/build/reports/**/* ================================================ FILE: ci/dockerd-entrypoint.sh ================================================ #!/bin/sh set -e /usr/local/bin/dockerd \ --host=unix:///var/run/docker.sock \ --host=tcp://127.0.0.1:2375 \ --storage-driver=overlay &>/var/log/docker.log & tries=0 d_timeout=60 until docker info >/dev/null 2>&1 do if [ "$tries" -gt "$d_timeout" ]; then cat /var/log/docker.log echo 'Timed out trying to connect to internal docker host.' 
>&2 exit 1 fi tries=$(( $tries + 1 )) sleep 1 done eval "$@" ================================================ FILE: ci/m2-settings.xml ================================================ securecentral true central https://repo1.maven.org/maven2 true central https://repo1.maven.org/maven2 true ================================================ FILE: docker/Dockerfile.cpu ================================================ FROM ubuntu:18.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python3-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 RUN pip install --no-cache-dir multi-model-server \ && pip install --no-cache-dir mxnet-mkl==1.4.0 RUN useradd -m model-server \ && mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ && chown -R model-server /home/model-server EXPOSE 8080 8081 USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/Dockerfile.gpu ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu18.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python3-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl 
-O https://bootstrap.pypa.io/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 RUN pip install --no-cache-dir multi-model-server \ && pip install --no-cache-dir mxnet-cu92mkl==1.4.0 RUN useradd -m model-server \ && mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ && chown -R model-server /home/model-server EXPOSE 8080 8081 USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/Dockerfile.nightly-cpu ================================================ FROM ubuntu:18.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python3-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 RUN pip install --no-cache-dir --pre multi-model-server \ && pip install --no-cache-dir mxnet-mkl RUN useradd -m model-server \ && mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ && chown -R model-server /home/model-server EXPOSE 8080 8081 USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] 
CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/Dockerfile.nightly-gpu ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu18.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python3-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 RUN update-alternatives --install /usr/local/bin/pip pip /usr/local/bin/pip3 1 RUN pip install --no-cache-dir --pre multi-model-server \ && pip install --no-cache-dir mxnet-cu92mkl RUN useradd -m model-server \ && mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \ && chown -R model-server /home/model-server EXPOSE 8080 8081 USER model-server WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/README.md ================================================ [//]: # "All the references in this file should be actual links because this file would be used by docker hub. DO NOT use relative links or section tagging." # Using Containers with Multi Model Server Multi Model Server (MMS) can be used with any container service. In this guide, you will learn how to run MMS with Docker. 
## Contents of this Document * [Quickstart](https://github.com/awslabs/multi-model-server/blob/master/docker/README.md#quickstart) * [Available pre-built containers](https://github.com/awslabs/multi-model-server/blob/master/docker/README.md#available-pre-built-containers) * [Configuring MMS with Docker](https://github.com/awslabs/multi-model-server/blob/master/docker/README.md#configuring-mms-with-docker) ## Other Relevant Documents * [Advanced Settings](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md) * [GPU Inference](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md#gpu-inference) * [Reference Commands](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md#reference-commands) * [Docker Details](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md#docker-details) * [Description of Config File Settings](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md#description-of-config-file-settings) * [Configuring SSL](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md#configuring-ssl) * [Launch MMS as a managed inference service on AWS Fargate](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md) * [Introduction to published containers](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#familiarize-yourself-with-our-containers) * [Creating a AWS Fargate task to server SqueezeNet V1.1](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#create-a-aws-faragte-task-to-serve-squeezenet-model) * [Creating an Load Balancer](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#create-a-load-balancer) * [Creating an AWS ECS 
Service](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#creating-an-ecs-service-to-launch-our-aws-fargate-task) * [Testing your service](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#test-your-service) * [Build custom MMS containers images to serve your Deep learning models](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#customize-the-containers-to-server-your-custom-deep-learning-models) ## Quickstart Running Multi Model Server with Docker in two steps: **Step 1: Run the Docker image.** This will download the MMS Docker image and run its default configuration, serving a SqueezeNet model. ```bash docker run -itd --name mms -p 80:8080 -p 8081:8081 awsdeeplearningteam/multi-model-server multi-model-server --start --models squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` With the `-p` flag, we're setting it up so you can run the Predict API on your host computer's port `80`. This maps to the Docker image's port `8080`. It will run the Management API on your host computer's port `8081`. This maps to the Docker image's port `8081`. **Step 2: Test inference.** ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1/predictions/squeezenet -T kitten.jpg ``` After fetching this image of a kitten and posting it to the `predict` endpoint, you should see a response similar to the following: ``` { "prediction": [ [ { "class": "n02124075 Egyptian cat", "probability": 0.9408261179924011 ... ``` ### Cleaning Up Now that you have tested it out, you may stop the Docker container. The following command will stop the server and delete the container. It will retain the Docker image for trying out other models and configurations later. 
```bash docker rm -f mms ``` ## Available pre-built containers We have following container tags available on [Docker Hub](https://hub.docker.com/r/awsdeeplearningteam/multi-model-server/). 1. *latest*: This is the latest officially released MMS CPU container. This is based on the latest [Dockerfile.cpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.cpu). 2. *latest-gpu*: This is the latest officially released MMS GPU container. This is based on the latest [Dockerfile.gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.gpu). 3. *(MMS Release Tag)-mxnet-cpu*: Each released version since MMS 1.0.0 has an individual release tagged CPU MXNet container. These containers are based on [Dockerfile.cpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.cpu), in that MMS release. 4. *(MMS Release Tag)-mxnet-cpu*: Each released version since MMS 1.0.0 has an individual release tagged GPU MXNet container. These containers are based on [Dockerfile.gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.gpu), in that MMS release. 5. *nightly-mxnet-cpu*: This is the official CPU container which is built based on the nightly release of MMS pip package. This is built from [Dockerfile.nightly-cpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.nightly-cpu). 6. *nightly-mxnet-gpu*: This is the official GPU container which is built based on the nightly release of MMS pip package. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/Dockerfile.nightly-gpu). 7. *base-cpu-py2.7*: This is the official Base Python 2.7 CPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py2_7). 
Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. **WARNING: Python 2.x will be deprecated from Jan 1 2020.** 8. *base-cpu-py3.6*: This is the official Base Python 3.6 CPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py3_6). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. 9. *base-gpu-py2.7*: This is the official Base Python 2.7 GPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. **WARNING: Python 2.x will be deprecated from Jan 1 2020.** 10. *base-gpu-py3.6*: This is the official Base Python 3.6 GPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. 11. *nightly-base-cpu-py2.7*: This is the official Nightly Base Python 2.7 CPU container which contains only MMS and python. 
This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py2_7). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. **WARNING: Python 2.x will be deprecated from Jan 1 2020.** 12. *nightly-base-cpu-py3.6*: This is the official Nightly Base Python 3.6 CPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py3_6). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. 13. *nightly-base-gpu-py2.7*: This is the official Nightly Base Python 2.7 GPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. **WARNING: Python 2.x will be deprecated from Jan 1 2020.** 14. *nightly-base-gpu-py3.6*: This is the official Nightly Base Python 3.6 GPU container which contains only MMS and python. This is built from [Dockerfile.nightly-gpu](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6). Please note, this container doesn't have any DL/ML engine installed by definition, it is meant to be used for cases where you would like to bring your own engine/framework into container. 
To pull a particular container, run the following command #### Pulling the latest CPU container: Docker pull by default pulls the latest tag. This tag is associated with the latest released MMS CPU container. This tag isn't available until after an official release. ```bash docker pull awsdeeplearningteam/multi-model-server ``` #### Pulling the latest GPU container: To pull the latest officially released MMS GPU container, run the following command. This tag isn't available until after an official release. ```bash docker pull awsdeeplearningteam/multi-model-server:latest-gpu ``` #### Pulling the `nightly-mxnet-cpu` tag: To pull the latest nightly MMS CPU container, run the following command. This tracks the pre-release version of MMS. We do not recommend running this container in a production setup. ```bash docker pull awsdeeplearningteam/multi-model-server:nightly-mxnet-cpu ``` #### Pulling the `nightly-mxnet-gpu` tag: To pull the latest nightly MMS GPU container, run the following command. This tracks the pre-release version of MMS. We do not recommend running this container in a production setup. ```bash docker pull awsdeeplearningteam/multi-model-server:nightly-mxnet-gpu ```
```bash mkdir /tmp/models ``` **Step 2: Download the configuration template.** Download the template `config.properties` and place it in the `models` folder you just created: * [config.properties](https://github.com/awslabs/multi-model-server/blob/master/docker/config.properties) **Step 3: Modify the configuration template.** Edit the file you downloaded, `config.properties`. ```properties vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError model_store=/opt/ml/model load_models=ALL inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 # management_address=unix:/tmp/management.sock # number_of_netty_threads=0 # netty_client_threads=0 # default_response_timeout=120 # unregister_model_timeout=120 # default_workers_per_model=0 # job_queue_size=100 # async_logging=false # number_of_gpu=1 # cors_allowed_origin # cors_allowed_methods # cors_allowed_headers # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 # private_key_file=src/test/resources/key.pem # certificate_file=src/test/resources/certs.pem # blacklist_env_vars= ``` Modify the configuration file to suite your configuration needs before running the model server. Save the file. **Step 4: Run MMS with Docker using a shared volume.** When you run the following command, the `-v` argument and path values of `/tmp/models/:/models` will map the Docker image's `models` folder to your local `/tmp/models` folder. MMS will then be able to use the local model file. ```bash docker run -itd --name mms -p 80:8080 -p 8081:8081 -v /tmp/models/:/models awsdeeplearningteam/multi-model-server multi-model-server --start --models squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar ``` **NOTE**: If you modify the inference_address or the management_address in the configuration file, you must modify the ports exposed by Docker as well. 
**Step 5: Test inference.** You will upload the same kitten image as before, but this time you will request the `predictions/resnet` API endpoint. ```bash curl -X POST http://127.0.0.1/predictions/resnet -T @kitten.jpg ``` Given that this is a different model, the same image yields a different inference result which is something similar to the following: ``` { "prediction": [ [ { "class": "n02123159 tiger cat", "probability": 0.3630334138870239 }, ... ``` ## Conclusion You have tried the default Predictions API settings using a SqueezeNet model. You then configured your Predictions API endpoints to also serve a ResNet-18 model. Now you are ready to try some other more **advanced settings** such as: * GPU inference * MMS settings Next Step: [Advanced Settings](https://github.com/awslabs/multi-model-server/blob/master/docker/advanced_settings.md) ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7 ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu16.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python2.7 1 RUN pip install --no-cache-dir multi-model-server RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" 
================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py2_7.nightly ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu16.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python2.7 1 RUN pip install --no-cache-dir multi-model-server --pre RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6 ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu16.04 ENV PYTHONUNBUFFERED TRUE # Install python3.6 and pip3 ENV PATH /usr/local/bin:$PATH ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y --no-install-recommends \ tk-dev RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget ca-certificates vim\ autoconf automake less groff dpkg-dev \ bzip2=1.0.* file g++ gcc imagemagick \ libbz2-dev libc6-dev curl \ libdb-dev libevent-dev libffi-dev \ libgeoip-dev libglib2.0-dev libjpeg-dev \ libkrb5-dev liblzma-dev libmagickcore-dev \ 
libmagickwand-dev libmysqlclient-dev libncurses5-dev \ libpng12-dev libpq-dev libreadline-dev libsqlite3-dev \ libssl-dev libtool libwebp-dev libxml2-dev \ libxslt1-dev libyaml-dev make patch xz-utils \ zlib1g-dev tcl tk \ e2fsprogs iptables xfsprogs xz-utils openjdk-8-jdk-headless fakeroot \ mono-mcs libcurl4-openssl-dev liberror-perl unzip # Install python 3.6 ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D ENV PYTHON_VERSION 3.6.8 RUN set -ex \ \ && wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys "$GPG_KEY" \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && { command -v gpgconf > /dev/null && gpgconf --kill all || :; } \ && rm -rf "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ && cd /usr/src/python \ && gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)" \ && ./configure \ --build="$gnuArch" \ --enable-loadable-sqlite-extensions \ --enable-shared \ --with-system-expat \ --with-system-ffi \ --without-ensurepip \ && make -j "$(nproc)" \ && make install \ && ldconfig \ \ && find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' + \ && rm -rf /usr/src/python \ \ && python3 --version # make some useful symlinks that are expected to exist RUN cd /usr/local/bin \ && ln -s idle3 idle \ && ln -s pydoc3 pydoc \ && ln -s python3 python \ && ln -s python3-config python-config # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" ENV PYTHON_PIP_VERSION 19.0.3 RUN set -ex; \ \ wget -O get-pip.py 
'https://bootstrap.pypa.io/get-pip.py'; \ \ python get-pip.py \ --disable-pip-version-check \ --no-cache-dir \ "pip==$PYTHON_PIP_VERSION" \ ; \ pip --version; \ \ find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' +; \ rm -f get-pip.py RUN pip install --no-cache-dir multi-model-server RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.nvidia_cu92_ubuntu_16_04.py3_6.nightly ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu16.04 ENV PYTHONUNBUFFERED TRUE # Install python3.6 and pip3 ENV PATH /usr/local/bin:$PATH ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y --no-install-recommends \ tk-dev RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget ca-certificates vim\ autoconf automake less groff dpkg-dev \ bzip2=1.0.* file g++ gcc imagemagick \ libbz2-dev libc6-dev curl \ libdb-dev libevent-dev libffi-dev \ libgeoip-dev libglib2.0-dev libjpeg-dev \ libkrb5-dev liblzma-dev libmagickcore-dev \ libmagickwand-dev libmysqlclient-dev libncurses5-dev \ libpng12-dev libpq-dev libreadline-dev libsqlite3-dev \ libssl-dev libtool libwebp-dev libxml2-dev \ libxslt1-dev libyaml-dev make patch xz-utils \ zlib1g-dev tcl tk \ e2fsprogs iptables xfsprogs xz-utils openjdk-8-jdk-headless fakeroot \ mono-mcs 
libcurl4-openssl-dev liberror-perl unzip # Install python 3.6 ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D ENV PYTHON_VERSION 3.6.8 RUN set -ex \ \ && wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys "$GPG_KEY" \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && { command -v gpgconf > /dev/null && gpgconf --kill all || :; } \ && rm -rf "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ && cd /usr/src/python \ && gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)" \ && ./configure \ --build="$gnuArch" \ --enable-loadable-sqlite-extensions \ --enable-shared \ --with-system-expat \ --with-system-ffi \ --without-ensurepip \ && make -j "$(nproc)" \ && make install \ && ldconfig \ \ && find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' + \ && rm -rf /usr/src/python \ \ && python3 --version # make some useful symlinks that are expected to exist RUN cd /usr/local/bin \ && ln -s idle3 idle \ && ln -s pydoc3 pydoc \ && ln -s python3 python \ && ln -s python3-config python-config # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" ENV PYTHON_PIP_VERSION 19.0.3 RUN set -ex; \ \ wget -O get-pip.py 'https://bootstrap.pypa.io/get-pip.py'; \ \ python get-pip.py \ --disable-pip-version-check \ --no-cache-dir \ "pip==$PYTHON_PIP_VERSION" \ ; \ pip --version; \ \ find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec 
rm -rf '{}' +; \ rm -f get-pip.py RUN pip install --no-cache-dir multi-model-server --pre RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py2_7 ================================================ FROM ubuntu:16.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python-dev \ openjdk-8-jdk-headless \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python2.7 1 RUN pip install --no-cache-dir multi-model-server RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py2_7.nightly ================================================ FROM ubuntu:16.04 ENV PYTHONUNBUFFERED TRUE RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ fakeroot \ ca-certificates \ dpkg-dev \ g++ \ python-dev \ openjdk-8-jdk-headless \ curl 
\ vim \ && rm -rf /var/lib/apt/lists/* \ && cd /tmp \ && curl -O https://bootstrap.pypa.io/get-pip.py \ && python get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python2.7 1 RUN pip install --no-cache-dir multi-model-server --pre RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py3_6 ================================================ FROM ubuntu:16.04 ENV PYTHONUNBUFFERED TRUE # Install python3.6 and pip3 ENV PATH /usr/local/bin:$PATH ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y --no-install-recommends \ tk-dev RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget ca-certificates vim\ autoconf automake less groff dpkg-dev \ bzip2=1.0.* file g++ gcc imagemagick \ libbz2-dev libc6-dev curl \ libdb-dev libevent-dev libffi-dev \ libgeoip-dev libglib2.0-dev libjpeg-dev \ libkrb5-dev liblzma-dev libmagickcore-dev \ libmagickwand-dev libmysqlclient-dev libncurses5-dev \ libpng12-dev libpq-dev libreadline-dev libsqlite3-dev \ libssl-dev libtool libwebp-dev libxml2-dev \ libxslt1-dev libyaml-dev make patch xz-utils \ zlib1g-dev tcl tk \ e2fsprogs iptables xfsprogs xz-utils openjdk-8-jdk-headless fakeroot \ mono-mcs libcurl4-openssl-dev liberror-perl unzip # Install python 3.6 ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D ENV PYTHON_VERSION 3.6.8 RUN set -ex \ \ && wget -O python.tar.xz 
"https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 --recv-keys "$GPG_KEY" \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && { command -v gpgconf > /dev/null && gpgconf --kill all || :; } \ && rm -rf "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ && cd /usr/src/python \ && gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)" \ && ./configure \ --build="$gnuArch" \ --enable-loadable-sqlite-extensions \ --enable-shared \ --with-system-expat \ --with-system-ffi \ --without-ensurepip \ && make -j "$(nproc)" \ && make install \ && ldconfig \ \ && find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' + \ && rm -rf /usr/src/python \ \ && python3 --version # make some useful symlinks that are expected to exist RUN cd /usr/local/bin \ && ln -s idle3 idle \ && ln -s pydoc3 pydoc \ && ln -s python3 python \ && ln -s python3-config python-config # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" ENV PYTHON_PIP_VERSION 19.0.3 RUN set -ex; \ \ wget -O get-pip.py 'https://bootstrap.pypa.io/get-pip.py'; \ \ python get-pip.py \ --disable-pip-version-check \ --no-cache-dir \ "pip==$PYTHON_PIP_VERSION" \ ; \ pip --version; \ \ find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' +; \ rm -f get-pip.py RUN pip install --no-cache-dir multi-model-server RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh 
/usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/Dockerfile.base.ubuntu_16_04.py3_6.nightly ================================================ FROM ubuntu:16.04 ENV PYTHONUNBUFFERED TRUE # Install python3.6 and pip3 ENV PATH /usr/local/bin:$PATH ENV LANG C.UTF-8 RUN apt-get update && apt-get install -y --no-install-recommends \ tk-dev RUN set -ex \ && echo 'Acquire::CompressionTypes::Order:: "gz";' > /etc/apt/apt.conf.d/99use-gzip-compression \ && apt-get update \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends wget ca-certificates vim\ autoconf automake less groff dpkg-dev \ bzip2=1.0.* file g++ gcc imagemagick \ libbz2-dev libc6-dev curl \ libdb-dev libevent-dev libffi-dev \ libgeoip-dev libglib2.0-dev libjpeg-dev \ libkrb5-dev liblzma-dev libmagickcore-dev \ libmagickwand-dev libmysqlclient-dev libncurses5-dev \ libpng12-dev libpq-dev libreadline-dev libsqlite3-dev \ libssl-dev libtool libwebp-dev libxml2-dev \ libxslt1-dev libyaml-dev make patch xz-utils \ zlib1g-dev tcl tk \ e2fsprogs iptables xfsprogs xz-utils openjdk-8-jdk-headless fakeroot \ mono-mcs libcurl4-openssl-dev liberror-perl unzip # Install python 3.6 ENV GPG_KEY 0D96DF4D4110E5C43FBFB17F2D347EA6AA65421D ENV PYTHON_VERSION 3.6.8 RUN set -ex \ \ && wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz" \ && wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc" \ && export GNUPGHOME="$(mktemp -d)" \ && gpg --batch --keyserver hkp://p80.pool.sks-keyservers.net:80 
--recv-keys "$GPG_KEY" \ && gpg --batch --verify python.tar.xz.asc python.tar.xz \ && { command -v gpgconf > /dev/null && gpgconf --kill all || :; } \ && rm -rf "$GNUPGHOME" python.tar.xz.asc \ && mkdir -p /usr/src/python \ && tar -xJC /usr/src/python --strip-components=1 -f python.tar.xz \ && rm python.tar.xz \ && cd /usr/src/python \ && gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)" \ && ./configure \ --build="$gnuArch" \ --enable-loadable-sqlite-extensions \ --enable-shared \ --with-system-expat \ --with-system-ffi \ --without-ensurepip \ && make -j "$(nproc)" \ && make install \ && ldconfig \ \ && find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' + \ && rm -rf /usr/src/python \ \ && python3 --version # make some useful symlinks that are expected to exist RUN cd /usr/local/bin \ && ln -s idle3 idle \ && ln -s pydoc3 pydoc \ && ln -s python3 python \ && ln -s python3-config python-config # if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" ENV PYTHON_PIP_VERSION 19.0.3 RUN set -ex; \ \ wget -O get-pip.py 'https://bootstrap.pypa.io/get-pip.py'; \ \ python get-pip.py \ --disable-pip-version-check \ --no-cache-dir \ "pip==$PYTHON_PIP_VERSION" \ ; \ pip --version; \ \ find /usr/local -depth \ \( \ \( -type d -a \( -name test -o -name tests \) \) \ -o \ \( -type f -a \( -name '*.pyc' -o -name '*.pyo' \) \) \ \) -exec rm -rf '{}' +; \ rm -f get-pip.py RUN pip install --no-cache-dir multi-model-server --pre RUN mkdir -p /home/model-server/tmp COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh COPY config.properties /home/model-server RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh EXPOSE 8080 8081 WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="dantu@amazon.com, rakvas@amazon.com, lufen@amazon.com, 
dden@amazon.com" ================================================ FILE: docker/advanced-dockerfiles/config.properties ================================================ # vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError model_store=/opt/ml/model load_models=ALL inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 # management_address=unix:/tmp/management.sock # number_of_netty_threads=0 # netty_client_threads=0 # default_response_timeout=120 # unregister_model_timeout=120 # default_workers_per_model=0 # job_queue_size=100 # async_logging=false # number_of_gpu=1 # cors_allowed_origin # cors_allowed_methods # cors_allowed_headers # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 # private_key_file=src/test/resources/key.pem # certificate_file=src/test/resources/certs.pem # max_response_size=6553500 # max_request_size=6553500 # blacklist_env_vars= decode_input_request=false # enable_envvars_config=false ================================================ FILE: docker/advanced-dockerfiles/dockerd-entrypoint.sh ================================================ #!/bin/bash set -e if [[ "$1" = "serve" ]]; then shift 1 multi-model-server --start --mms-config config.properties else eval "$@" fi # prevent docker exit tail -f /dev/null ================================================ FILE: docker/advanced_settings.md ================================================ # Advanced Settings ## Contents of this Document * [GPU Inference](advanced_settings.md#gpu-inference) * [Reference Commands](advanced_settings.md#reference-commands) * [Docker Details](advanced_settings.md#docker-details) * [Description of Config File Settings](advanced_settings.md#description-of-config-file-settings) * [Configuring SSL](advanced_settings.md#configuring-ssl) ## Other Relevant Documents * [Quickstart](README.md#quickstart) * [Configuring MMS with 
Docker](README.md#configuring-mms-with-docker) ## GPU Inference **Step 1: Install nvidia-docker.** `nvidia-docker` is NVIDIA's customized version of Docker that makes accessing your host's GPU resources from Docker a seamless experience. All of your regular Docker commands work the same way. Follow the [instructions for installing nvidia-docker](https://github.com/NVIDIA/nvidia-docker#quickstart). Return here and follow the next step when the installation completes. **Step 2: Download the GPU configuration template.** A GPU configuration template is provided for your use. Download the template a GPU config and place it in the `/tmp/models` folder you just created: * [config.properties](config.properties) **Step 3: Modify the configuration template.** Edit the file you downloaded, `config.properties` to configure the model-server. Save the file. **Step 4: Run MMS with Docker using a shared volume.** When you run the following command, the `-v` argument and path values of `/tmp/models/:/models` will map the `models` folder you created (assuming it was in ) with a folder inside the Docker container. MMS will then be able to use the local model file. ```bash nvidia-docker run -itd --name mms -p 80:8080 -p 8081:8081 -v /tmp/models/:/models awsdeeplearningteam/multi-model-server:latest-gpu multi-model-server --start --mms-config /models/config.properties --models squeezenet=https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model ``` **Step 5: Test inference.** This configuration file is using the default Squeezenet model, so you will request the `predictions/squeezenet` API endpoint. ```bash curl -X POST http://127.0.0.1/predictions/squeezenet -F "data=@kitten.jpg" ``` Given that this is a different model, the same image yields a different inference result which will be something similar to the following: ``` { "prediction": [ [ { "class": "n02123159 tiger cat", "probability": 0.3630334138870239 }, ... 
``` ## Reference Commands Manually pull the MMS Docker CPU image: ```bash docker pull awsdeeplearningteam/multi-model-server ``` Manually pull the MMS Docker GPU image: ```bash docker pull awsdeeplearningteam/multi-model-server:latest-gpu ``` List your Docker images: ```bash docker images ``` Verify the Docker container is running: ```bash docker ps -a ``` Stop the Docker container from running: ```bash docker rm -f mms ``` Delete the MMS Docker GPU image: ```bash docker rmi awsdeeplearningteam/multi-model-server:latest-gpu ``` Delete the MMS Docker GPU image: ```bash docker rmi awsdeeplearningteam/multi-model-server:latest ``` Output the recent logs to console. ```bash docker logs mms ``` Interact with the container. This will open a shell prompt inside the container. Use `$ Ctrl-p-Ctrl-q` to detach again. ```bash docker attach mms ``` Run the MMS Docker image without starting the Model Server: ```bash docker run -itd --name mms -p 80:8080 -p 8081:8081 awsdeeplearningteam/multi-model-server /bin/bash ``` Start MMS in the Docker container (CPU config): ```bash docker exec mms multi-model-server --start --mms-config /home/model-server/config.properties ``` Start MMS in the Docker container using nvidia-docker command as follows. : ```bash nvidia-docker exec multi-model-server --start --mms-config /home/model-server/config.properties ``` **Note**: To use GPU configuration, modify the config.properties to reflect that the model-server should use GPUs. ```properties inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 ... number_of_gpu=8 ... ``` Stop MMS. ```bash docker exec mms multi-model-server --stop ``` Get MMS help. ```bash docker exec mms multi-model-server --help ``` Refer [Docker CLI](https://docs.docker.com/engine/reference/commandline/run/) to understand each parameter. 
## Docker Details ### Docker Hub Docker images are available on [Docker Hub](https://hub.docker.com/r/awsdeeplearningteam): * [CPU](https://hub.docker.com/r/awsdeeplearningteam/multi-model-server/tags) * [GPU](https://hub.docker.com/r/awsdeeplearningteam/multi-model-server/tags) ### Building a MMS Docker Image from Scratch The following are the steps to build a container image from scratch. #### Prerequisites In order to build the Docker image yourself you need the following: * Install Docker * Clone the MMS repo #### Docker Installation For macOS, you have the option of [Docker's Mac installer](https://docs.docker.com/docker-for-mac/install/) or you can simply use `brew`: ```bash brew install docker ``` For Windows, you should use [their Windows installer](https://docs.docker.com/docker-for-windows/install/). For Linux, check your favorite package manager if brew is available, otherwise use their installation instructions for [Ubuntu](https://docs.docker.com/engine/installation/linux/ubuntu/) or [CentOS](https://docs.docker.com/engine/installation/linux/centos/). #### Verify Docker When you've competed the installation, verify that Docker is running by running `docker images` in your terminal. If this works, you are ready to continue. #### Clone the MMS Repo If you haven't already, clone the MMS repo and go into the `docker` folder. ```bash git clone https://github.com/awslabs/multi-model-server.git && cd multi-model-server/docker ``` ### Building the Container Image #### Configuring the Docker Build for Use on EC2 Now you can examine how to build a Docker image with MMS and establish a public accessible endpoint on EC2 instance. You should be able to adapt this information for any cloud provider. This Docker image can be used in other production environments as well. Skip this section if you're building for local use. The first step is to create an [EC2 instance](https://aws.amazon.com/ec2/). 
### Build Step for CPU container image There are separate `Dockerfile` configuration files for CPU and GPU. They are named `Dockerfile.cpu` and `Dockerfile.gpu` respectively. The container image consists of MXNet, Java, MMS and all related python libraries. We can build the Multi Model Server image based on the Dockerfile as follows: ```bash # Building the MMS CPU image docker build -f Dockerfile.cpu -t mms_image . ``` Once this completes, run `docker images` from your terminal. You should see the Docker image listed with the tag, `mms_image:latest`. ### Build Step for GPU If your host machine has at least one GPU installed, you can use a GPU Docker image to benefit from improved inference performance. You need to install [nvidia-docker plugin](https://github.com/NVIDIA/nvidia-docker) before you can use a NVIDIA GPU with Docker. Once you install `nvidia-docker`, run following commands (for info modifying the tag, see the CPU section above): ```bash # Building the MMS GPU image docker build -f Dockerfile.gpu -t mms_image_gpu . ``` #### Running the MMS GPU Docker ```bash nvidia-docker run -itd -p 80:8080 8081:8081 --name mms -v /home/user/models/:/models mms_image_gpu /bin/bash ``` This command starts the Docker instance in a detached mode and mounts `/home/user/models` of the host system into `/models` directory inside the Docker instance. Considering that you modified and copied `config.properties` file into the models directory, before you ran the above `nvidia-docker` command, you would have this configuration file ready to use in the Docker instance. ```bash nvidia-docker exec mms multi-model-server --start --mms-config /models/config.properties ``` ### Testing the MMS Docker Now you can send a request to your server's [api-description endpoint](http://localhost/api-description) to see the list of MMS endpoints or [ping endpoint](http://localhost/ping) to check the health status of the MMS API. 
Remember to add the port if you used a custom one or the IP or DNS of your server if you configured it for that instead of localhost. Here are some handy test links for common configurations: * [http://localhost/api-description](http://localhost/api-description) * [http://localhost/ping](http://localhost/ping) If `config.properties` file is used as is, the following commands can be run to verify that the Multi Model Server is running. ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1/squeezenet/predict -F "data=@kitten.jpg" ``` The predict endpoint will return a prediction response in JSON. It will look something like the following result: ```json { "prediction": [ [ { "class": "n02124075 Egyptian cat", "probability": 0.9408261179924011 }, { "class": "n02127052 lynx, catamount", "probability": 0.055966004729270935 }, { "class": "n02123045 tabby, tabby cat", "probability": 0.0025502564385533333 }, { "class": "n02123159 tiger cat", "probability": 0.00034320182749070227 }, { "class": "n02123394 Persian cat", "probability": 0.00026897044153884053 } ] ] } ``` ## Description of Config File Settings **For config.properties:** The system settings are stored in [config.properties](config.properties). You can modify these settings to use different models, or to apply other customized settings. Notes on a couple of the parameters: * **model_store** - The directory on the local host where models must reside for serving. * **management_address** - The address:port value on which the model server would serve control plane APIs such as "GET", "PUT", "DELETE" of "models" * **inference_address** - The address:port value on which the model server would serve data plane APIs such as predictions, ping and api-description * **load_models** - List of all the models in the `model_store` which should be loaded on startup * **number_of_netty_threads** - Number of threads present to handle the incoming requests. 
* **max_workers** - * **job_queue_size** - Number of requests that can be queued. This queue is shared across models. * **number_of_gpu** - Number of GPUs available for model server when serving inferences on GPU hosts * **keystore** - SSL Key Store * **keystore_pass** - SSL password * **keystore_type** - Store of cryptographic keys and certificates * **private_key_file** - Location of the private key file * **certificate_file** - Location of the certificate file * **max_response_size** - The maximum buffer size the frontend allocates for a worker response, in bytes. * **max_request_size** - The maximum allowable request size that the MMS accepts. in the range of 0 .. (num-gpu-1) in a round-robin fashion. **By default MMS uses all the available GPUs but this parameter can be configured if user want to use only few of them**. ```properties # vmargs=-Xmx1g -XX:MaxDirectMemorySize=512m -Dlog4j.configurationFile=file:///opt/ml/conf/log4j2.xml # model_store=/opt/ml/model # load_models=ALL # inference_address=http://0.0.0.0:8080 # management_address=http://0.0.0.0:8081 # number_of_netty_threads=0 # max_workers=0 # job_queue_size=1000 # number_of_gpu=1 # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 # private_key_file=src/test/resources/key.pem # certificate_file=src/test/resources/certs.pem # max_response_size=6553500 # max_request_size=6553500 ``` ================================================ FILE: docker/config.properties ================================================ vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError model_store=/opt/ml/model load_models=ALL inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 preload_model=false # management_address=unix:/tmp/management.sock # number_of_netty_threads=0 # netty_client_threads=0 # default_response_timeout=120 # unregister_model_timeout=120 # default_workers_per_model=0 # 
job_queue_size=100 # async_logging=false # number_of_gpu=1 # cors_allowed_origin # cors_allowed_methods # cors_allowed_headers # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 # private_key_file=src/test/resources/key.pem # certificate_file=src/test/resources/certs.pem # max_response_size=6553500 # max_request_size=6553500 # blacklist_env_vars= # decode_input_request=true # enable_envvars_config=false ================================================ FILE: docker/dockerd-entrypoint.sh ================================================ #!/bin/bash set -e if [[ "$1" = "serve" ]]; then shift 1 multi-model-server --start --mms-config config.properties else eval "$@" fi # prevent docker exit tail -f /dev/null ================================================ FILE: docs/README.md ================================================ # Multi Model Server Documentation ## Basic Features * [Serving Quick Start](../README.md#serve-a-model) - Basic server usage tutorial * [Model Archive Quick Start](../model-archiver#creating-a-model-archive) - Tutorial that shows you how to package a model archive file. * [Installation](install.md) - Installation procedures and troubleshooting * [Serving Models](server.md) - Explains how to use `multi-model-server`. * [REST API](rest_api.md) - Specification on the API endpoint for MMS * [Model Zoo](model_zoo.md) - A collection of MMS model archive (.mar) files that you can use with MMS. * [Packaging Model Archive](../model-archiver/README.md) - Explains how to package model archive file, use `model-archiver`. * [Docker](../docker/README.md) - How to use MMS with Docker and cloud services * [Logging](logging.md) - How to configure logging * [Metrics](metrics.md) - How to configure metrics ## Advanced Features * [Advanced settings](configuration.md) - Describes advanced MMS configurations. * [Custom Model Service](custom_service.md) - Describes how to develop custom inference services. 
* [Unit Tests](../mms/tests/README.md) - Housekeeping unit tests for MMS. * [Benchmark](../benchmarks/README.md) - Use JMeter to run MMS through the paces and collect benchmark data. * [Model Serving with Amazon Elastic Inference](elastic_inference.md) - Run Model server on Elastic Inference enabled EC2 instances. ## Example Projects * [MMS on Fargate, Serverless Inference](mms_on_fargate.md) - The project which illustrates the step-by-step process to launch MMS as a managed inference production service, on ECS Fargate. * [MXNet Vision Service](../examples/mxnet_vision/README.md) - An example MMS project for a MXNet Image Classification model. The project takes JPEG image as input for inference. * [LSTM](../examples/lstm_ptb/README.md) - An example MMS project for a recurrent neural network (RNN) using long short-term memory (LSTM). The project takes JSON inputs for inference against a model trained with a specific vocabulary. * [Object Detection](../examples/ssd/README.md) - An example MMS project that uses a pretrained Single Shot Multi Object Detection (SSD) model that takes image inputs and infers the types and locations of several classes of objects. 
================================================ FILE: docs/batch_inference_with_mms.md ================================================ # Batch Inference with Model Server ## Contents of this Document * [Introduction](#introduction) * [Batching example](#batch-inference-with-mms-using-resnet-152-model) * [Model Handler Code](#model-handler-a.k.a.-entry-point) * [Initialization logic](#initialization-logic) * [Preprocess logic](#preprocess-logic) * [Inference logic](#inference-logic) * [Postprocess logic](#postprocess-logic) * [MMS Model Configuration](#mms-model-configuration) * [Demo](#demo-to-configure-mms-with-batch-supported-model) * [Prerequisites](#pre-requisites) * [Running MMS with batch inference](#loading-resnet-152-which-handles-batch-inferences) * [Performance results](#performance-benchmarking) * [Conclusion](#conclusion) ## Introduction Batching in the Machine-Learning/Deep-Learning is a process of aggregating inference-requests and sending this aggregated requests through the ML/DL framework for inference at once. Model Server (MMS) was designed to natively support batching of incoming inference requests. This functionality provides customer using MMS to optimally utilize their host resources, because most ML/DL frameworks are optimized for batch requests. This optimal utilization of host resources in turn reduces the operational expense of hosting an inference service using MMS. In this document we will go through an example of how this is done and compare the performance of running a batched inference against running single inference. ## Prerequisites: Before jumping into this document, please go over the following docs 1. [What is MMS?](../README.md) 1. [What is custom service code?](custom_service.md) ## Batch Inference with MMS using ResNet-152 model To support batching of inference requests, MMS needs the following: 1. MMS Model Configuration: MMS provides means to configure "Max Batch Size" and "Max Batch Delay" through "POST /models" API. 
MMS needs to know the maximum batch size that the model can handle and the maximum delay that MMS should wait for, to form this request-batch. 2. Model Handler code: MMS requires the Model Handler to handle the batch of inference requests. In this section we will go over the configuration of MMS to handle batching and the actual code changes required at the model level to handle batching. Lets begin with the "Model Handler Code" component and see how we can convert an existing Resnet-152 model from the [MMS Model Zoo](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md#resnet-152) into a model which can process a batch of requests. For a full working code, refer to [mxnet_vision_batching.py](https://github.com/awslabs/multi-model-server/blob/master/examples/model_service_template/mxnet_vision_batching.py) ### Model Handler a.k.a. entry-point The handler method is the entrypoint to the model. When MMS receives a request for a model, MMS forwards this request to the `handler` method associated with the model. A model's `handler` entry-point is expected to have the following logic: 1. Loading Model: The handler code should have logic to load the model-artifacts onto the DL/ML Framework, as a part of initialization. In our example, the DL Framework is MXNet 1. Once initialized, this handler method is given requests for processing. So, the handler is expected to have the logic for: 1. Logic for `preprocess`ing request: The handler converts the incoming request to a format understandable by the ML/DL framework. In our example, that is an NDArray. 1. Logic for `inference`: The handler should take the preprocessed data and pass it through through the DL Framework for inference. 1. Logic for `postprocess`ing: The handler should take the output of the inference-logic and `postprocess` this data. This is then sent back the clients as the final result. 
Let's look into how we define the `initialize`, `preprocess`, `inference` and `postprocess` logic in detail. #### Initialization logic As a part of the handler's parameters, MMS sends a `context` object to the handler. Besides everything else, this `context` object also contains the batch-size that this model was configured to handle. The configuration for the model comes from the `POST /models` call and this is explained in detail in the below section [MMS Model Configuration](#mms-model-configuration). Lets look at a code snippet, ```python class MXNetVisionServiceBatching(object): ... def initialize(self, context): ... self._batch_size = context.system_properties["batch_size"] data_shapes = [] # Read the data shape data_shapes[0] = self._batch_size ... # Bind the MXNet module with the data shape NCHW where N is the batch size self.mx_model.bind(for_training=False, data_shapes=data_shapes) ... ... ``` This initialization logic binds the MXNet module with a batch-size that the model is configured with. For example, if the incoming requests are images with RGB layers and the input to the network is or size [3,224, 224], the batch adds an additional dimension to this data. If the batch size configured with `POST /models` API is 8, the input to the DL Framework becomes [8, 3, 224, 224]. #### Preprocess logic Once a model is initialized, the DL framework (in our case MXNet) expects a `batch_size` number of requests. Hence, the preprocessing logic should make sure that no matter how many requests come into the handler, it always returns the `batch_size` number of NDArray elements to be used in the inference logic. Lets look at a sample `preprocess` logic below ```python import mxnet as mx class MXNetVisionServiceBatching(object): ... def initialize(self, context): ... ... 
def preprocess(self, request): img_list = [] param_name = self.signature['inputs'][0]['data_name'] input_shape = self.signature['inputs'][0]['data_shape'] # We are assuming input shape is NCHW [c, h, w] = input_shape[1:] for idx, data in enumerate(request): img = data.get(param_name) if img is None: img = data.get("body") if img is None: img = data.get("data") if img is None or len(img) == 0: logging.error("Error processing request") self.erroneous_reqs.add(idx) continue try: img_arr = mx.image.imdecode(img, 1, True, None) except Exception as e: logging.error(e, exc_info=True) self.erroneous_reqs.add(idx) continue img_arr = mx.image.imresize(img_arr, w, h, 2) img_arr = mx.nd.transpose(img_arr, (2, 0, 1)) self._num_requests = idx + 1 img_list.append(img_arr) logging.debug("Worker :{} received {} requests".format(os.getpid(), self._num_requests)) reqs = mx.nd.stack(*img_list) reqs = reqs.as_in_context(self.mxnet_ctx) if (self._batch_size - self._num_requests) != 0: padding = mx.nd.zeros((self._batch_size - self._num_requests, c, h, w), self.mxnet_ctx, 'uint8') reqs = mx.nd.concat(reqs, padding, dim=0) return reqs ``` **NOTE: The above code handles the case where the `handler` doesn't receive a `batch_size` number of requests. This is because MMS waits for a `max_batch_delay` amount of time to receive `batch_size` number of requests. If the `max_batch_delay` timer times out before receiving `batch_size` number of requests, MMS bundles what ever requests it received and sends it to the handler for processing.** #### Inference logic The inference logic is similar to the inference logic of processing single requests. Since this isn't as interesting, we will skip explaining this in detail. Sample logic is shown below for completeness of this document. ```python import mxnet as mx from collections import namedtuple class MXNetVisionServiceBatching(object): ... def initialize(self, context): ... ... def preprocess(self, request): ... ... 
def inference(self, model_input): batch = namedtuple('Batch', ['data']) if self.error is not None: return None self.mx_model.forward(batch([model_input]), is_train=False) outputs = self.mx_model.get_outputs() res = mx.ndarray.split(outputs[0], axis=0, num_outputs=outputs[0].shape[0]) res = [res] if not isinstance(res, list) else res return res ... ``` Let's move on to post-processing. #### Postprocess logic The output of inference logic is fed to the post-processing logic. As we saw before, during preprocessing, if there aren't `batch_size` number of requests, preprocessing logic pads the difference with 0's. These are artificially created requests sent to the module for inference. **In post-process we should ignore these artificially padded requests**. Let's look at the sample logic ```python import mxnet as mx from collections import namedtuple class MXNetVisionServiceBatching(object): ... def initialize(self, context): ... ... def preprocess(self, request): ... ... def inference(self, model_input): ... def postprocess(self, data): res = [] for idx, resp in enumerate(data[:self._num_requests]): if idx not in self.erroneous_reqs: res.append(self.top_probability(resp, self.labels, top=5)) else: res.append("This request was not processed successfully. Refer to mms.log for additional information") return res ... ``` In the above code, we iterate only until **self._num_requests**. This variable is assigned a value during the preprocessing step. This logic ensures that the postprocess logic is run only for the actual requests coming from external clients. ### MMS Model Configuration To configure MMS to use the batching feature, you would have to provide the batch configuration information through [**POST /models** API](https://github.com/awslabs/multi-model-server/blob/master/docs/management_api.md#register-a-model). The configuration that we are interested in is the following: 1. `batch_size`: This is the maximum batch size that a model is expected to handle. 2. 
`max_batch_delay`: This is the maximum batch delay time MMS waits to receive `batch_size` number of requests. If MMS doesn't receive `batch_size` number of requests before this timer times out, it sends whatever requests were received to the model `handler`. Let's look at an example using this configuration ```bash # The following command will register a model "resnet-152.mar" and configure MMS to use a batch_size of 8 and a max batch delay of 50 milliseconds. curl -X POST "localhost:8081/models?url=resnet-152.mar&batch_size=8&max_batch_delay=50" ``` These configurations are used both in MMS and in the model's custom-service-code (a.k.a the handler code). 1. MMS: MMS associates the batch related configuration with each model. The frontend then tries to aggregate the batch-size number of requests and send it to the backend. 2. Model Custom Handler Code: The handler code is given the information about the batch-size. The handler then uses this information to tell the DL framework about the expected batch size. ## Demo to configure MMS with batch-supported model In this section let's bring up the model server and launch the Resnet-152 model, which has been built to handle a batch of requests. ### Pre-requisites Follow the main [Readme](https://github.com/awslabs/multi-model-server/blob/master/README.md) and install all the required packages including "multi-model-server" ### Loading Resnet-152 which handles batch inferences * Start the model server. In this example, we are starting the model server to run on inference port 8080 and management port 8081. ```text $ cat config.properties ... inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 ... $ multi-model-server --start ``` * Verify that the MMS is up and running ```text $ curl localhost:8080/ping { "status": "Healthy" } ``` * Now let's launch the resnet-152 model, which we have built to handle batch inference. 
Since this is an example, we are going to launch 1 worker which handles a batch size of 8 with a max-batch-delay of 10ms. ```text $ curl -X POST "localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar&batch_size=8&max_batch_delay=10&initial_workers=1" { "status": "Processing worker updates..." } ``` * Verify that the workers were started properly ```text $ curl localhost:8081/models/resnet-152 { "modelName": "resnet-152", "modelUrl": "https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar", "runtime": "python", "minWorkers": 1, "maxWorkers": 1, "batchSize": 8, "maxBatchDelay": 10, "workers": [ { "id": "9008", "startTime": "2019-02-19T23:56:33.907Z", "status": "READY", "gpu": false, "memoryUsage": 607715328 } ] } ``` * Now let's test this service. * Get an image to test this service ```text $ curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg ``` * Run inference to test the model ```text $ curl -X POST localhost:8080/predictions/resnet-152 -T kitten.jpg { "probability": 0.7148938179016113, "class": "n02123045 tabby, tabby cat" }, { "probability": 0.22877725958824158, "class": "n02123159 tiger cat" }, { "probability": 0.04032370448112488, "class": "n02124075 Egyptian cat" }, { "probability": 0.00837081391364336, "class": "n02127052 lynx, catamount" }, { "probability": 0.0006728120497427881, "class": "n02129604 tiger, Panthera tigris" } ``` * Now that we have the service up and running, we could run performance tests with the same kitten image as follows. There are multiple tools to measure performance of web-servers. We will use [apache-bench](https://httpd.apache.org/docs/2.4/programs/ab.html) to run our performance tests. We chose `apache-bench` for our tests because of the ease of installation and ease of running tests. Before running this test, we need to first install `apache-bench` on our System. 
Since we were running this on a ubuntu host, we installed apache-bench as follows ```bash $ sudo apt-get update && sudo apt-get install apache2-utils ``` Now that installation is done, we can run performance benchmark test as follows. ```text $ ab -k -l -n 10000 -c 1000 -T "image/jpeg" -p kitten.jpg localhost:8080/predictions/resnet-152 ``` The above test simulates MMS receiving 1000 concurrent requests at once and a total of 10,000 requests. All of these requests are directed to the endpoint "localhost:8080/predictions/resnet-152", which assumes that resnet-152 is already registered and scaled-up on MMS. We had done this registration and scaling up in the above steps. ## Performance benchmarking We benchmarked MMS with batch-inference enabled Resnet-152 on a *P3.8xlarge* instance, which is an AWS-provided GPU EC2 instance. We ran MMS in our [GPU container](https://hub.docker.com/r/awsdeeplearningteam/multi-model-server/tags) which hosted the above resnet-152 model. We ran the tests for batch sizes 1, 8 and 16 and captured the results. We saw a significant gain in throughput and also saw that the GPU resources were utilized more optimally. Attached is the graph showing the throughput gains. The experiment was done with the following configuration. To understand the details of this configuration please refer to the [Configuration document](configuration.md) ```bash # MMS configuration $ cat config.properties model_store=/opt/ml/model inference_address=http://0.0.0.0:8080 management_address=http://0.0.0.0:8081 number_of_netty_threads=32 job_queue_size=1000 async_logging=true ``` ```bash # To load the model run the following command $ curl -X POST "localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar&batch_size=8&max_batch_delay=50&initial_workers=8" ``` As seen from the above command, the number of workers was set at 8 (2 per GPU), and `max_batch_delay` was set at 50 ms. 
```bash # Apache bench performance test command $ ab -k -l -n 10000 -c 1000 -T "image/jpeg" -p kitten.jpg localhost:8080/predictions/resnet-152 ``` We set the `batch_size` in the above `curl` command at 1, 8 and 16 and captured the results of `ab` command by setting the `-c` option, or concurrency, at 10, 50, 100, 500 and 1000 for each of the `batch_size`s. ![Graph](images/throughput.png) As we can see from the above diagram, batching is not a one-size-fits-all solution. For example, when the rate of requests received at MMS is lower, look at 10 concurrent clients, the batch size of 1 outperforms batch size of 8 or 16. This is because, some requests would wait until the 50ms timeout occurred before getting scheduled to be processed by the model handlers. Whereas, by definition, batch size of 1 wouldn't wait for this timeout to occur and the model handler is given the request to handle as and when the requests come in. We tried to show in an intuitive way as to why we get lower TPS when batch size is 8 or 16 as compared to batch size of 1, when the number of concurrent requests sent to MMS was set at 10. For this, we kept the batch-size at 8, number of backend workers at 8 and kept the number of concurrent requests coming in at 10 and varied the `max_batch_delay` and ran `ab`. As we can see from the graph below, as the max_batch_delay was set closer to 0, the TPS increased and gets closer to being similar to `batch_size` 1. ![tps_vs_batch_delay](images/tps_vs_delay.png) ## Conclusion The take away from the experiments is that batching is a very useful feature. In cases where the services receive heavy load of requests or each request has high I/O, it's advantageous to batch the requests. This allows for maximally utilizing the compute resources, especially GPU compute which are also more often than not more expensive. 
But customers should do their due diligence and perform enough tests to find the optimal batch size depending on the number of GPUs available and number of models loaded per GPU. Customers should also analyze their traffic patterns before enabling the batch-inference. As shown in the above experiments, services receiving a TPS lower than the batch size would lead to consistent "batch delay" timeouts and cause the response latency per request to spike. As with any cutting-edge technology, batch-inference is definitely a double-edged sword. ================================================ FILE: docs/configuration.md ================================================ # Advanced configuration One of the design goals of MMS 1.0 is ease of use. The default settings of MMS should be sufficient for most use cases. This document describes advanced configurations that allow users to deeply customize MMS's behavior. ## Environment variables Users can set environment variables to change MMS behavior; the following is a list of variables that users can set for MMS: * JAVA_HOME * PYTHONPATH * MMS_CONFIG_FILE * LOG_LOCATION * METRICS_LOCATION **Note:** environment variables have higher priority than command line or config.properties. They will override other property values. ## Command line parameters Users can use the following parameters to start MMS; these parameters will override default MMS behavior: * **--mms-config** MMS will load the specified configuration file if MMS_CONFIG_FILE is not set. * **--model-store** This parameter will override the `model_store` property in the config.properties file. * **--models** This parameter will override the `load_models` property in config.properties. * **--log-config** This parameter will override the default log4j2.xml * **--foreground** This parameter will run the model server in the foreground. If this option is disabled, the model server will run in the background. See [Running the Model Server](server.md) for details. 
## config.properties file MMS uses a `config.properties` file to store configurations. MMS uses the following order to locate this `config.properties` file: 1. if the `MMS_CONFIG_FILE` environment variable is set, MMS will load the configuration from the environment variable. 2. if the `--mms-config` parameter is passed to `multi-model-server`, MMS will load the configuration from the parameter. 3. if there is a `config.properties` in the current folder where the user starts `multi-model-server`, MMS will load the `config.properties` file from the current working directory. 4. If none of the above is specified, MMS will load a built-in configuration with default values. **Note:** The Docker image that MMS provides has slightly different default values. ### Customize JVM options To restrict the MMS frontend memory footprint, certain JVM options are set via the **vmargs** property in the `config.properties` file * default: N/A, use JVM default options * docker default: -Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:OnOutOfMemoryError='kill -9 %p' Users can adjust those JVM options to fit their memory requirements if needed. ### Load models at startup Users can configure MMS to load models at startup. MMS can load models from `model_store` or from HTTP(s) URLs. * model_store * standalone: default: N/A, loading models from local disk is disabled. * docker: default model_store location is set to: /opt/ml/model * load_models * standalone: default: N/A, no models will be loaded on startup. * docker: default: ALL, all model archives in /opt/ml/model will be loaded on startup. **Note:** The `model_store` and `load_models` properties can be overridden by command line parameters. ### Configure MMS listening port MMS doesn't support authentication natively. To avoid unauthorized access, MMS only allows localhost access by default. The Inference API is listening on port 8080 and accepting HTTP requests. The Management API is listening on port 8081 and accepting HTTP requests. 
See [Enable SSL](#enable-ssl) for configuring HTTPS. * inference_address: inference API binding address, default: http://127.0.0.1:8080 * management_address: management API binding address, default: http://127.0.0.1:8081 Here are a couple of examples: ```properties # bind inference API to all network interfaces with SSL enabled inference_address=https://0.0.0.0:8443 # bind inference API to private network interfaces inference_address=https://172.16.1.10:8080 ``` ### Enable SSL For users who want to enable HTTPs, you can change the `inference_address` or `management_address` protocol from http to https, for example: `inference_address=https://127.0.0.1`. This will make MMS listen on localhost port 443 to accept https requests. Users also must provide a certificate and private keys to enable SSL. MMS supports two ways to configure SSL: 1. Use keystore * keystore: Keystore file location; if there are multiple private key entries in the keystore, the first one will be picked. * keystore_pass: keystore password, key password (if applicable) MUST be the same as keystore password. * keystore_type: type of keystore, default: PKCS12 2. Use private-key/certificate files * private_key_file: private key file location, supports both PKCS8 and OpenSSL private keys. * certificate_file: X509 certificate chain file location. #### Self-signed certificate example This is a quick example to enable SSL with a self-signed certificate 1. Use java keytool to create a keystore ```bash keytool -genkey -keyalg RSA -alias mms -keystore keystore.p12 -storepass changeit -storetype PKCS12 -validity 3600 -keysize 2048 -dname "CN=www.MY_MMS.com, OU=Cloud Service, O=model server, L=Palo Alto, ST=California, C=US" ``` Config following property in config.properties: ```properties inference_address=https://127.0.0.1:8443 management_address=https://127.0.0.1:8444 keystore=keystore.p12 keystore_pass=changeit keystore_type=PKCS12 ``` 2. 
Use OpenSSL to create a private key and certificate ```bash openssl req -x509 -sha256 -nodes -days 365 -newkey rsa:2048 -keyout mykey.key -out mycert.pem ``` Config following property in config.properties: ```properties inference_address=https://127.0.0.1:8443 management_address=https://127.0.0.1:8444 private_key_file=mykey.key certificate_file=mycert.pem ``` ### Configure Cross-Origin Resource Sharing (CORS) CORS is a mechanism that uses additional HTTP headers to tell a browser to let a web application running at one origin (domain) have permission to access selected resources from a server at a different origin. CORS is disabled by default. Configure following properties in config.properties file to enable CORS: ```properties # cors_allowed_origin is required to enable CORS, use '*' or your domain name cors_allowed_origin=https://yourdomain.com # required if you want to use preflight request cors_allowed_methods=GET, POST, PUT, OPTIONS # required if the request has an Access-Control-Request-Headers header cors_allowed_headers=X-Custom-Header ``` ### Preloading a model The model server gives users an option to take advantage of fork() semantics, i.e., copy-on-write, on linux based systems. In order to load a model before spinning up the model workers, use the `preload_model` option. The model server, upon seeing this option set, will load the model just before scaling the first model worker. All the other workers will share the same instance of the loaded model. This way only the memory locations in the loaded model which are touched will be copied over to the individual model-worker's process memory space. ```properties preload_model=true ``` ### Prefer direct buffer The configuration parameter prefer_direct_buffer controls if the model server will be using direct memory specified by -XX:MaxDirectMemorySize. This parameter is for the model server only and doesn't affect other packages' usage of direct memory buffers. 
Default: false ```properties prefer_direct_buffer=true ``` ### Restrict backend worker to access environment variable Environment variables may contain sensitive information like AWS credentials. Backend workers will execute arbitrary model custom code, which may expose a security risk. MMS provides a `blacklist_env_vars` property which allows users to restrict which environment variables can be accessed by backend workers. * blacklist_env_vars: a regular expression to filter out environment variable names, default: all environment variables will be visible to backend workers. ### Limit GPU usage By default, MMS will use all available GPUs for inference. You can use `number_of_gpu` to limit the usage of GPUs. * number_of_gpu: max number of GPUs that MMS can use for inference, default: available GPUs in system. ### Other properties Most of those properties are designed for performance tuning. Adjusting those numbers will impact scalability and throughput. * enable_envvars_config: Enable configuring MMS through environment variables. When this option is set to "true", all the static configurations of MMS can come through environment variables as well. default: false * number_of_netty_threads: number of frontend netty threads, default: number of logical processors available to the JVM. * netty_client_threads: number of backend netty threads, default: number of logical processors available to the JVM. * default_workers_per_model: number of workers to create for each model that is loaded at startup time, default: available GPUs in system or number of logical processors available to the JVM. * job_queue_size: number of inference jobs that the frontend will queue before the backend can serve, default 100. Useful in cases where certain requests take predictably longer to complete. * async_logging: enable asynchronous logging for higher throughput, log output may be delayed if this is enabled, default: false. 
* default_response_timeout: Timeout, in seconds, used for model's backend workers before they are deemed unresponsive and rebooted. default: 120 seconds. * unregister_model_timeout: Timeout, in seconds, used when handling an unregister model request when cleaning a process before it is deemed unresponsive and an error response is sent. default: 120 seconds. * decode_input_request: Configuration to let backend workers to decode requests, when the content type is known. If this is set to "true", backend workers do "Bytearray to JSON object" conversion when the content type is "application/json" and the backend workers convert "Bytearray to utf-8 string" when the Content-Type of the request is set to "text*". default: true ### config.properties Example See [config.properties for docker](https://github.com/awslabs/multi-model-server/blob/master/docker/config.properties) ================================================ FILE: docs/custom_service.md ================================================ # Custom Service ## Contents of this Document * [Introduction](#introduction) * [Requirements for custom service file](#requirements-for-custom-service-file) * [Example Custom Service file](#example-custom-service-file) * [Creating model archive with entry point](#creating-model-archive-with-entry-point) ## Introduction A custom service , is the code that is packaged into model archive, that is executed by Multi Model Server (MMS). The custom service is responsible for handling incoming data and passing on to engine for inference. The output of the custom service is returned back as response by MMS. ## Requirements for custom service file The custom service file should define a method that acts as an entry point for execution, this function will be invoked by MMS on a inference request. 
The function can have any name, not necessarily handle, however this function should accept, the following parameters * **data** - The input data from the incoming request * **context** - Is the MMS [context](https://github.com/awslabs/multi-model-server/blob/master/mms/context.py) information passed for use with the custom service if required. The signature of a entry point function is: ```python def function_name(data,context): """ Works on data and context passed """ # Use parameters passed ``` The next section, showcases an example custom service. ## Example Custom Service file ```python # custom service file # model_handler.py """ ModelHandler defines a base model handler. """ import logging class ModelHandler(object): """ A base Model handler implementation. """ def __init__(self): self.error = None self._context = None self._batch_size = 0 self.initialized = False def initialize(self, context): """ Initialize model. This will be called during model loading time :param context: Initial context contains model server system properties. :return: """ self._context = context self._batch_size = context.system_properties["batch_size"] self.initialized = True def preprocess(self, batch): """ Transform raw input into model input data. :param batch: list of raw requests, should match batch size :return: list of preprocessed model input data """ # Take the input data and pre-process it make it inference ready assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch)) return None def inference(self, model_input): """ Internal inference methods :param model_input: transformed model input data :return: list of inference output in NDArray """ # Do some inference call to engine here and return output return None def postprocess(self, inference_output): """ Return predict result in batch. 
:param inference_output: list of inference output :return: list of predict results """ # Take output from network and post-process to desired format return ["OK"] * self._batch_size def handle(self, data, context): """ Call preprocess, inference and post-process functions :param data: input data :param context: mms context """ model_input = self.preprocess(data) model_out = self.inference(model_input) return self.postprocess(model_out) _service = ModelHandler() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ``` Here the ``` handle()``` method is our entry point that will be invoked by MMS, with the parameters data and context, it in turn can pass this information to an actual inference class object or handle all the processing in the ```handle()``` method itself. The ```initialize()``` method is used to initialize the model at load time, so after first time, the service need not be re-initialized in the the life cycle of the relevant worker. We recommend using a ```initialize()``` method, avoid initialization at prediction time. This entry point is engaged in two cases: (1) when MMS is asked to scale a model up, to increase the number of backend workers (it is done either via a ```PUT /models/{model_name}``` request or a ```POST /models``` request with `initial-workers` option or during MMS startup when you use `--models` option (```multi-model-server --start --models {model_name=model.mar}```), ie., you provide model(s) to load) or (2) when MMS gets a ```POST /predictions/{model_name}``` request. (1) is used to scale-up or scale-down workers for a model. (2) is used as a standard way to run inference against a model. (1) is also known as model load time, and that is where you would normally want to put code for model initialization. 
You can find out more about these and other MMS APIs in [MMS Management API](./management_api.md) and [MMS Inference API](./inference_api.md) ### Returning custom error codes To return a custom error code back to the user use the `PredictionException` in the `mms.service` module. ```python from mms.service import PredictionException def handler(data, context): # Some unexpected error - returning error code 513 raise PredictionException("Some Prediction Error", 513) ``` ## Creating model archive with entry point MMS identifies the entry point to the custom service from the manifest file. Thus, when creating the model archive, one needs to mention the entry point using the ```--handler``` option. The [model-archiver](https://github.com/awslabs/multi-model-server/blob/master/model-archiver/README.md) tool enables the creation of an archive understood by MMS. ```python model-archiver --model-name <model-name> --handler model_handler:handle --export-path <output-dir> --model-path <model-path> --runtime python3 ``` This will create file ```<model-name>.mar``` in the directory ```<output-dir>``` This will create a model archive with the custom handler, for python3 runtime. The ```--runtime``` parameter enables usage of a specific python version at runtime; by default it uses the default python distribution of the system. 
================================================ FILE: docs/elastic_inference.md ================================================ # Model Serving with Amazon Elastic Inference ## Contents of this Document * [Introduction](#introduction) * [Custom Service](#custom-service) * [Creating a EC2 instance with EIA support](#creating-a-ec2-instance-with-eia-support) * [Custom Service file with EIA](#custom-service-file-with-eia) * [Running elastic inference on a resnet-152](#running-elastic-inference-on-a-resnet-152) ## Introduction Amazon Elastic Inference (EI) is a service that allows you to attach low-cost GPU-powered acceleration to Amazon EC2 and Amazon SageMaker instances to reduce the cost of running deep learning inference by up to 75%. With MMS it is easy to deploy a MXNet based model, taking advantage of the attachable hardware accelerator called Elastic Inference Accelerator (EIA). In this document, we explore using EIA attached to a Compute Optimized EC2 instance. ## Custom Service The capability to run model inference with the EIA can be achieved by building a custom service to use the EIA context rather than a GPU or CPU context. An MXNet version with support for EIA is required. To understand the basics of writing a custom service file refer to the [Custom Service Documentation](https://github.com/awslabs/multi-model-server/blob/master/docs/custom_service.md). ## Creating a EC2 instance with EIA support To Create an EC2 instance with EIA support there are few pre-requisites. These include: 1. [Configuring a Security Group for Amazon EI](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/working-with-ei.html#ei-security). 2. [Configure AWS PrivateLink Endpoint Services](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/working-with-ei.html#eia-privatelink). 3. [Creating a IAM Role with EI instance policy](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/working-with-ei.html#ei-role-policy). 
The above steps are explored in detail in [AWS Elastic Inference documentation](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/working-with-ei.html) On completing the above steps, following steps need to be followed in launching an instance with EIA support 1. Open the Amazon EC2 console at https://console.aws.amazon.com/ec2/. 2. Choose Launch Instance. 3. Choose one of the Deep Learning AMIs, we recommend Deep Learning AMI v20 or later. This is required to use MXNet with EIA. 4. Choose an Instance Type, we recommend a compute optimized EC2 instance such as c5.2xlarge. 5. Choose Next: Configure Instance Details. 6. Under Configure Instance Details, check the configuration settings. Ensure that you are using the VPC with the security groups for the instance and the Amazon EI accelerator that was set up earlier. For more information, see Configuring Your Security Groups for Amazon EI. 7. For IAM role, select the role that you created in the Configuring an Instance Role with an Amazon EI Policy procedure explained in the above documentation. 8. Select Add an Amazon EI accelerator. 9. Select the size of the Amazon EI accelerator. Your options are eia1.medium, eia1.large, and eia1.xlarge. We recommend selecting the instance size, based on the model size. For larger models, larger instances offer better performance gains. 10. (Optional) You can choose to add storage and tags by choosing Next at the bottom of the page. Or, you can let the instance wizard complete the remaining configuration steps for you. 11. Review the configuration of your instance and choose Launch. 12. You are prompted to choose an existing key pair for your instance or to create a new key pair. **WARNING: Do NOT select the Proceed without a key pair option. If you launch your instance without a key pair, then you can’t connect to it.** After making your key pair selection, choose Launch Instances. It can take a few minutes for the instance to be ready so that you can connect to it. 
Check that your instance has passed its status checks. You can view this information in the Status Checks column. ## Custom Service file with EIA You use two different processing contexts with MXNet and EIA 1. mxnet.cpu() - used for loading up input data 2. mxnet.eia() - used for binding network symbols and params to an attached EIA instance We modify the [base model service template](https://github.com/awslabs/multi-model-server/blob/master/examples/model_service_template/mxnet_model_service.py) to support EIA. ```python def initialize(self, context): # NOT COMPLETE CODE, refer template above for it. #.... #.... #.... # Load MXNet module # Symbol Context set to eia self.mxnet_ctx = mx.eia() self.data_ctx = mx.cpu() sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, self.epoch) # noinspection PyTypeChecker self.mx_model = mx.mod.Module(symbol=sym, context=self.mxnet_ctx, data_names=data_names, label_names=None) self.mx_model.bind(for_training=False, data_shapes=data_shapes) self.mx_model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True) def inference(self, model_input): # NOT COMPLETE CODE, refer template above for it. #.... #.... #.... model_input = [item.as_in_context(self.data_ctx) for item in model_input] ``` The above code shows initialization of two contexts, one for data (on CPU) and other for symbols (on EIA). Once we have the code ready. We can build a model archive consumable by MMS, using the model-archiver. The [model-archiver](https://github.com/awslabs/multi-model-server/blob/master/model-archiver/README.md) tool enables to build to an archive understood by MMS. ```bash model-archiver --model-name --handler model_service:handle --export-path --model-path --runtime python ``` This will create file ```.mar``` in the directory ``````. 
## Running elastic inference on a resnet-152 A pre-built ResNet-152 model archive that uses Amazon Elastic Inference can be downloaded using the following command: ```bash $ wget https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-152-eia.mar ``` **NOTE:** The above archive will only work on EIA-enabled instances. Start the EIA-enabled EC2 instance. If using a Deep Learning AMI, there are two Conda environments (one for Python 2 and one for Python 3), both of which come with MXNet built will EI support and Multi Model Server. Either of the two can be used. ```bash # python 3 $ source activate amazonei_mxnet_p36 # python 2 $ source activate amazonei_mxnet_p27 ``` After entering one of the Conda environments, we start MMS, with Resnet-152 EIA model: ```bash # Start MMS $ multi-model-server --start --models resnet-152=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-152-eia.mar ``` Now the model is ready for some inference requests. Let us download a kitten image for classification: ```bash $ curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg $ curl -X POST http://127.0.0.1:8080/predictions/resnet-152 -T kitten.jpg ``` The predict endpoint will return a prediction response in JSON. It will look something like the following result: ```json [ { "probability": 0.7148934602737427, "class": "n02123045 tabby, tabby cat" }, { "probability": 0.22877734899520874, "class": "n02123159 tiger cat" }, { "probability": 0.04032360762357712, "class": "n02124075 Egyptian cat" }, { "probability": 0.008370809257030487, "class": "n02127052 lynx, catamount" }, { "probability": 0.0006728142034262419, "class": "n02129604 tiger, Panthera tigris" } ] ``` ResNet-152 identified the tabby cat using Elastic Inference Accelerator while being hosted on MMS. Serving on Amazon EI instances reduces inference costs while benefiting from the performance of a GPU for inference tasks. 
================================================ FILE: docs/images/helpers/plugins_sdk_class_uml_diagrams.puml ================================================ @startuml Context "1" *-- "many" Model : contains Model "1" *-- "many" Worker : contains Endpoint o-- EndpointTypes ModelServerEndpoint <.. Endpoint ModelServerEndpoint <.. Context ModelServerEndpoint <.. Request ModelServerEndpoint <.. Response interface Context { +Properties getConfig() +Map getModels() } interface Request { +List getHeaderNames() +String getRequestURI() +Map> getParameterMap() +List getParameter(String k) +String getContentType() +InputStream getInputStream() } interface Response { +void setStatus(int sc) +void setStatus(int sc, String phrase) +void setHeader(String k, String v) +void addHeader(String k, String v) +void setContentType(String ct) +OutputStream getOutputStream() } interface Model { +String getModelName() +String getModelUrl() +String getModelHandler() +List getModelWorkers() } interface Worker { +boolean isRunning() +long getWorkerMemory() } interface ModelServerEndpoint { +void doGet(Request req, Response rsp, Context ctx) +void doPost(Request req, Response rsp, Context ctx) +void doDelete(Request req, Response rsp, Context ctx) +void doPut(Request req, Response rsp, Context ctx) } annotation Endpoint Endpoint : +String urlPattern() Endpoint : +EndpointTypes endpointType() Endpoint : +String description() enum EndpointTypes { NONE, INFERENCE, MANAGEMENT } @enduml ================================================ FILE: docs/inference_api.md ================================================ # Inference API Inference API is listening on port 8080 and only accessible from localhost by default. To change the default setting, see [MMS Configuration](configuration.md). There are three type of APIs: 1. [API description](#api-description) - Describe MMS inference APIs with OpenAPI 3.0 specification 2. [Health check API](#health-check-api) - Check MMS health status 3. 
[Predictions API](#predictions-api) - Make predictions API call to MMS ## API Description To view a full list of inference API, you can use following command: ```bash curl -X OPTIONS http://localhost:8080 ``` The output is in OpenAPI 3.0.1 JSON format. You can use it to generate client code, see [swagger codegen](https://swagger.io/swagger-codegen/) for detail. * [Inference API description output](../frontend/server/src/test/resources/inference_open_api.json) ## Health check API MMS supports a `ping` API that users can use to check MMS health status: ```bash curl http://localhost:8080/ping ``` Your response, if the server is running should be: ```json { "health": "healthy!" } ``` ## Predictions API MMS 1.0 supports 0.4 style API calls, but those APIs are deprecated and will be removed in a future release. See [Deprecated APIs](#deprecated-api) for detail. For each loaded model, user can make REST call to URI: /predictions/{model_name} * POST /predictions/{model_name} **curl Example** ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://localhost:8080/predictions/resnet-18 -T kitten.jpg or: curl -X POST http://localhost:8080/predictions/resnet-18 -F "data=@kitten.jpg" ``` The result was some JSON that told us our image likely held a tabby cat. The highest prediction was: ```json { "class": "n02123045 tabby, tabby cat", "probability": 0.42514491081237793, ... } ``` ## Deprecated API MMS 0.4 style predict API is kept for backward compatibility purposes, and will be removed in a future release. * POST /{model_name}/predict **curl Example** ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://localhost:8080/resnet-18/predict -F "data=@kitten.jpg" ``` ================================================ FILE: docs/install.md ================================================ # Install MMS ## Prerequisites * **Python**: Required. Multi Model Server (MMS) works with Python 2 or 3. 
When installing MMS, we recommend that you use a Python and Conda environment to avoid conflicts with your other Apache MXNet or Open Neural Network Exchange (ONNX) installations. * **java 8**: Required. MMS use java to serve HTTP requests. You must install java 8 (or later) and make sure java is on available in $PATH environment variable *before* installing MMS. If you have multiple java installed, you can use $JAVA_HOME environment vairable to control which java to use. For ubuntu: ```bash sudo apt-get install openjdk-8-jre-headless ``` For centos ```bash sudo yum install java-1.8.0-openjdk ``` For Mac: ```bash brew tap caskroom/versions brew update brew cask install java8 ``` You can also download and install [Oracle JDK](https://www.oracle.com/technetwork/java/javase/overview/index.html) manually if you have trouble with above commands. * **MXNet**: Recommended. MMS won't install `mxnet` by default. MXNet is required for most of examples in this project. MMS won't install mxnet engine by default, you can install mxnet-mkl or mxnet-cu90mkl based on your need. And you can also choose specific version of mxnet if you want. ```bash pip install mxnet-mkl ``` or for GPU instance: ```bash pip install mxnet-cu90mkl ``` * **Curl**: Optional. Curl is used in all of the examples. Install it with your preferred package manager. * **Unzip**: Optional. Unzip allows you to easily extract model files and inspect their content. If you choose to use it, associate it with `.mar` extensions. ## Install MMS with pip To install MMS for the first time, install Python, then run the following command: ```bash pip install multi-model-server ``` To upgrade from a previous version of MMS, run: ```bash pip install -U multi-model-server ``` ## Install MMS from Source Code If you prefer, you can clone MMS from source code. 
First, run the following command: ```bash git clone https://github.com/awslabs/multi-model-server.git && cd multi-model-server ``` To install MMS, run: ```bash pip install . ``` To upgrade MMS, run: ```bash pip install -U . ``` ## Install MMS for Development If you plan to develop with MMS and change some of the source code, install it from source code and make your changes executable with this command: ```bash pip install -e . ``` To upgrade MMS from source code and make changes executable, run: ```bash pip install -U -e . ``` ## Troubleshooting Installation | Issue | Solution | |---|---| |java not found, please make sure JAVA_HOME is set properly. | Make sure java is installed. java is on the $PATH or $JAVA_HOME is set properly. | |Your PYTHONPATH points to a site-packages dir for Python 3.x but you are running Python 2.x! | You do one of following:
  • use virtualenv
  • unset PYTHONPATH
  • set PYTHONPATH properly
| ================================================ FILE: docs/logging.md ================================================ # Logging on Multi Model Server In this document we will go through logging mechanism in Multi Model Server. We will also go over how to modify the behavior of logging in model-server. Logging in Multi Model Server also covers metrics, as metrics are logged into a file. To further understand how to customize metrics or define custom logging layouts, refer to the [metrics document](metrics.md) # Pre-requisites Before getting into this tutorials, you must familiarize yourself with log4j2 configuration. Refer to this online [document](https://logging.apache.org/log4j/2.x/manual/configuration.html) on how to configure the log4j2 parameters. Similarly, familiarize yourself with the default [log4j2.xml](../frontend/server/src/main/resources/log4j2.xml) used by Multi Model Server. # Types of logs Multi Model Server currently provides three types of logs. 1. Access Logs. 1. Model Server Logs. ## Access Logs: These logs collect the access pattern to Multi Model Server. The configuration pertaining to access logs are as follows, ```xml ``` As defined in the properties file, the access logs are collected in {LOG_LOCATION}/access_log.log file. When we load the model server with a model and run inference against the server, the following logs are collected into the access_log.log ```text 2018-10-15 13:56:18,976 [INFO ] BackendWorker-9000 ACCESS_LOG - /127.0.0.1:64003 "POST /predictions/resnet-18 HTTP/1.1" 200 118 ``` The above log tells us that a successful `POST` call to `/predictions/resnet-18` was made by remote host `127.0.0.1:64003` it took `118`ms to complete this request. These logs are useful to determine the current performance of the model-server as well as understand the requests received by model-server. ## Model Server Logs These logs collect all the logs from Model Server and from the backend workers (the custom model code). 
The default configuration pertaining to mms logs are as follows: ```xml ``` This configuration by default dumps all the logs above `DEBUG` level. ### Generating and logging custom logs As a user of Multi Model Server(MMS), you might want to log custom logs into the log files. This could be for debug purposes or to log any errors. To accomplish this, simply print the required logs to `stdout/stderr`. MMS will capture the logs generated by the backend workers and log it into the log file. Some examples of logs are as follows 1. Messages printed to stderr ```text 2018-10-14 16:46:51,656 [WARN ] W-9000-stderr com.amazonaws.ml.mms.wlm.WorkerLifeCycle - [16:46:51] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v0.8.0. Attempting to upgrad\ e... 2018-10-14 16:46:51,657 [WARN ] W-9000-stderr com.amazonaws.ml.mms.wlm.WorkerLifeCycle - [16:46:51] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded! ``` 1. Messages printed to stdout ```text 2018-10-14 16:59:59,926 [INFO ] W-9000-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - preprocess time: 3.60 2018-10-14 16:59:59,926 [INFO ] W-9000-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - inference time: 117.31 2018-10-14 16:59:59,926 [INFO ] W-9000-stdout com.amazonaws.ml.mms.wlm.WorkerLifeCycle - postprocess time: 8.52 ``` # Modifying the behavior of the logs In order to modify the default behavior of the logging, you could define `log4j2.xml` file. 
There are two ways of starting model server with custom logs ### Provide with config.properties Once you define custom `log4j2.xml`, add this to the `config.properties` file as follows ```properties vmargs=-Dlog4j.configurationFile=file:///path/to/custom/log4j2.xml ``` Then start the model server as follows ```bash $ multi-model-server --start --mms-config /path/to/config.properties ``` ### Provide with CLI Alternatively, you could start the model server with the following command as well ```bash $ multi-model-server --start --log-config /path/to/custom/log4j2.xml ``` # Enable asynchronous logging If your model is super lightweight and seeking for high throughput, you can consider enable asynchronous logging. Note that log output maybe delayed and latest log might be lost if MMS is terminated unexpectedly. asynchronous logging is disabled by default. To enable asynchronous logging, add following property in `config.properties`: ```properties async_logging=true ``` ================================================ FILE: docs/management_api.md ================================================ # Management API MMS provides a set of API allow user to manage models at runtime: 1. [Register a model](#register-a-model) 2. [Increase/decrease number of workers for specific model](#scale-workers) 3. [Describe a model's status](#describe-model) 4. [Unregister a model](#unregister-a-model) 5. [List registered models](#list-models) Management API is listening on port 8081 and only accessible from localhost by default. To change the default setting, see [MMS Configuration](configuration.md). Similar as [Inference API](inference_api.md), Management API also provide a [API description](#api-description) to describe management APIs with OpenAPI 3.0 specification. ## Management APIs ### Register a model `POST /models` * url - Model archive download url. Supports the following locations: * a local model archive (.mar); the file must be directly in model_store folder. 
* a local model directory; the directory must be directly in model_store folder. This option can avoid MMS extracting .mar file to temporary folder, which will improve load time and reduce disk space usage. * a URI using the HTTP(s) protocol. MMS can download .mar files from the Internet. * model_name - the name of the model; this name will be used as {model_name} in other API as path. If this parameter is not present, modelName in MANIFEST.json will be used. * handler - the inference handler entry-point. This value will override `handler` in MANIFEST.json if present. **NOTE: Make sure that the given `handler` is in the `PYTHONPATH`. The format of handler is `module_name:method_name`.** * runtime - the runtime for the model custom service code. This value will override runtime in MANIFEST.json if present. The default value is `PYTHON`. * batch_size - the inference batch size. The default value is `1`. * max_batch_delay - the maximum delay for batch aggregation. The default value is 100 milliseconds. * initial_workers - the number of initial workers to create. The default value is `0`. MMS will not run inference until there is at least one worker assigned. * synchronous - whether or not the creation of worker is synchronous. The default value is false. MMS will create new workers without waiting for acknowledgement that the previous worker is online. * response_timeout - If the model's backend worker doesn't respond with inference response within this timeout period, the worker will be deemed unresponsive and rebooted. The unit is seconds. The default value is 120 seconds. ```bash curl -X POST "http://localhost:8081/models?url=https%3A%2F%2Fs3.amazonaws.com%2Fmodel-server%2Fmodel_archive_1.0%2Fsqueezenet_v1.1.mar" { "status": "Model \"squeezenet_v1.1\" registered" } ``` Users may want to create workers while registering. Since creating initial workers may take some time, users can choose between a synchronous or an asynchronous call to make sure initial workers are created properly. 
The asynchronous call will return before trying to create workers with HTTP code 202: ```bash curl -v -X POST "http://localhost:8081/models?initial_workers=1&synchronous=false&url=https%3A%2F%2Fs3.amazonaws.com%2Fmodel-server%2Fmodel_archive_1.0%2Fsqueezenet_v1.1.mar" < HTTP/1.1 202 Accepted < content-type: application/json < x-request-id: 29cde8a4-898e-48df-afef-f1a827a3cbc2 < content-length: 33 < connection: keep-alive < { "status": "Worker updated" } ``` The synchronous call will return after all workers has be adjusted with HTTP code 200. ```bash curl -v -X POST "http://localhost:8081/models?initial_workers=1&synchronous=true&url=https%3A%2F%2Fs3.amazonaws.com%2Fmodel-server%2Fmodel_archive_1.0%2Fsqueezenet_v1.1.mar" < HTTP/1.1 200 OK < content-type: application/json < x-request-id: c4b2804e-42b1-4d6f-9e8f-1e8901fc2c6c < content-length: 32 < connection: keep-alive < { "status": "Worker scaled" } ``` ### Scale workers `PUT /models/{model_name}` * min_worker - (optional) the minimum number of worker processes. MMS will try to maintain this minimum for specified model. The default value is `1`. * max_worker - (optional) the maximum number of worker processes. MMS will make no more that this number of workers for the specified model. The default is the same as the setting for `min_worker`. * number_gpu - (optional) the number of GPU worker processes to create. The default value is `0`. If number_gpu exceeds the number of available GPUs, the rest of workers will run on CPU. * synchronous - whether or not the call is synchronous. The default value is `false`. * timeout - the specified wait time for a worker to complete all pending requests. If exceeded, the work process will be terminated. Use `0` to terminate the backend worker process immediately. Use `-1` to wait infinitely. The default value is `-1`. **Note:** not implemented yet. Use the Scale Worker API to dynamically adjust the number of workers to better serve different inference request loads. 
There are two different flavour of this API, synchronous vs asynchronous. The asynchronous call will return immediately with HTTP code 202: ```bash curl -v -X PUT "http://localhost:8081/models/noop?min_worker=3" < HTTP/1.1 202 Accepted < content-type: application/json < x-request-id: 74b65aab-dea8-470c-bb7a-5a186c7ddee6 < content-length: 33 < connection: keep-alive < { "status": "Worker updated" } ``` The synchronous call will return after all workers has be adjusted with HTTP code 200. ```bash curl -v -X PUT "http://localhost:8081/models/noop?min_worker=3&synchronous=true" < HTTP/1.1 200 OK < content-type: application/json < x-request-id: c4b2804e-42b1-4d6f-9e8f-1e8901fc2c6c < content-length: 32 < connection: keep-alive < { "status": "Worker scaled" } ``` ### Describe model `GET /models/{model_name}` Use the Describe Model API to get detail runtime status of a model: ```bash curl http://localhost:8081/models/noop { "modelName": "noop", "modelVersion": "snapshot", "modelUrl": "noop.mar", "engine": "MXNet", "runtime": "python", "minWorkers": 1, "maxWorkers": 1, "batchSize": 1, "maxBatchDelay": 100, "workers": [ { "id": "9000", "startTime": "2018-10-02T13:44:53.034Z", "status": "READY", "gpu": false, "memoryUsage": 89247744 } ] } ``` ### Unregister a model `DELETE /models/{model_name}` Use the Unregister Model API to free up system resources: ```bash curl -X DELETE http://localhost:8081/models/noop { "status": "Model \"noop\" unregistered" } ``` ### List models `GET /models` * limit - (optional) the maximum number of items to return. It is passed as a query parameter. The default value is `100`. * next_page_token - (optional) queries for next page. It is passed as a query parameter. This value is return by a previous API call. 
Use the Models API to query current registered models: ```bash curl "http://localhost:8081/models" ``` This API supports pagination: ```bash curl "http://localhost:8081/models?limit=2&next_page_token=2" { "nextPageToken": "4", "models": [ { "modelName": "noop", "modelUrl": "noop-v1.0" }, { "modelName": "noop_v0.1", "modelUrl": "noop-v0.1" } ] } ``` ## API Description `OPTIONS /` To view a full list of inference and management APIs, you can use following command: ```bash # To view all inference APIs: curl -X OPTIONS http://localhost:8080 # To view all management APIs: curl -X OPTIONS http://localhost:8081 ``` The out is OpenAPI 3.0.1 json format. You use it to generate client code, see [swagger codegen](https://swagger.io/swagger-codegen/) for detail. Example outputs of the Inference and Management APIs: * [Inference API description output](../frontend/server/src/test/resources/inference_open_api.json) * [Management API description output](../frontend/server/src/test/resources/management_open_api.json) ================================================ FILE: docs/metrics.md ================================================ # Metrics on Model Server ## Contents of this Document * [Introduction](#introduction) * [System metrics](#system-metrics) * [Formatting](#formatting) * [Custom Metrics API](#custom-metrics-api) ## Introduction MMS collects system level metrics in regular intervals, and also provides an API for custom metrics to be collected. Metrics collected by metrics are logged and can be aggregated by metric agents. The system level metrics are collected every minute. Metrics defined by the custom service code, can be collected per request or a batch of requests. MMS logs these two sets of metrics to different log files. 
Metrics are collected by default at: * System metrics - log_directory/mms_metrics.log * Custom metrics - log directory/model_metrics.log The location of log files and metric files can be configured at [log4j2.xml](https://github.com/awslabs/multi-model-server/blob/master/frontend/server/src/main/resources/log4j2.xml) file. ## System Metrics | Metric Name | Dimension | Unit | Semantics | |---|---|---|---| | CPUUtilization | host | percentage | cpu utillization on host | | DiskAvailable | host | GB | disk available on host | | DiskUsed | host | GB | disk used on host | | DiskUtilization | host | percentage | disk used on host | | MemoryAvailable | host | MB | memory available on host | | MemoryUsed | host | MB | memory used on host | | MemoryUtilization | host | percentage | memory used on host | | Requests2XX | host | count | total number of requests that responded in 200-300 range | | Requests4XX | host | count | total number of requests that responded in 400-500 range | | Requests5XX | host | count | total number of requests that responded above 500 | ## Formatting The metrics emitted into log files by default, is in a [StatsD](https://github.com/etsy/statsd) like format. ```bash CPUUtilization.Percent:0.0|#Level:Host|#hostname:my_machine_name MemoryUsed.Megabytes:13840.328125|#Level:Host|#hostname:my_machine_name ``` To enable metric logging in JSON format, we can modify the log formatter in [log4j2.xml](https://github.com/awslabs/multi-model-server/blob/master/frontend/server/src/main/resources/log4j2.xml), This is explained in the logging [document](https://github.com/awslabs/multi-model-server/blob/master/docs/logging.md). 
Once enabled the format emitted to logs, will look as follows ```json { "MetricName": "DiskAvailable", "Value": "108.15547180175781", "Unit": "Gigabytes", "Dimensions": [ { "Name": "Level", "Value": "Host" } ], "HostName": "my_machine_name" } { "MetricName": "DiskUsage", "Value": "124.13163757324219", "Unit": "Gigabytes", "Dimensions": [ { "Name": "Level", "Value": "Host" } ], "HostName": "my_machine_name" } ``` ## Custom Metrics API MMS enables the custom service code to emit metrics, that are then logged by the system The custom service code is provided with a [context](https://github.com/awslabs/multi-model-server/blob/master/mms/context.py) of the current request. Which has metrics object. ```python # Access context metrics as follows metrics = context.metrics ``` All metrics collected with in the context ### Creating dimension object(s) Dimensions for metrics can be defined as objects ```python from mms.metrics import dimension # Dimensions are name value pairs dim1 = Dimension(name, value) dim2 = Dimension(some_name, some_value) . . . dimN= Dimension(name_n, value_n) ``` **NOTE:** Metric functions below accept a list of dimensions ### Add generic metrics One can add metrics with generic units using the following function. 
Function API ```python def add_metric(name, value, idx=None, unit=None, dimensions=None): """ Add a metric which is generic with custom metrics Parameters ---------- name : str metric name value: int, float value of metric idx: int request_id index in batch unit: str unit of metric dimensions: list list of dimensions for the metric """ ``` ```python # Add Distance as a metric # dimensions = [dim1, dim2, dim3, ..., dimN] # Assuming batch size is 1 for example metrics.add_metric('DistanceInKM', distance, None, 'km', dimensions) ``` ### Add Time based metrics Time based metrics can be added by invoking the following method Function API ```python def add_time(name, value, idx=None, unit='ms', dimensions=None): """ Add a time based metric like latency, default unit is 'ms' Parameters ---------- name : str metric name value: int value of metric idx: int request_id index in batch unit: str unit of metric, default here is ms, s is also accepted dimensions: list list of dimensions for the metric """ ``` Note that the default unit in this case is 'ms' **Supported units**: ['ms', 's'] To add custom time based metrics ```python # Add inference time # dimensions = [dim1, dim2, dim3, ..., dimN] # Assuming batch size is 1 for example metrics.add_time('InferenceTime', end_time-start_time, None, 'ms', dimensions) ``` ### Add Size based metrics Size based metrics can be added by invoking the following method Function API ```python def add_size(name, value, idx=None, unit='MB', dimensions=None): """ Add a size based metric Parameters ---------- name : str metric name value: int, float value of metric idx: int request_id index in batch unit: str unit of metric, default here is 'MB', 'kB', 'GB' also supported dimensions: list list of dimensions for the metric """ ``` Note that the default unit in this case is 'MB' **Supported units**: ['MB', 'kB', 'GB'] To add custom size based metrics ```python # Add Image size as a metric # dimensions = [dim1, dim2, dim3, ..., dimN] # Assuming batch size 
is 1 for example metrics.add_size('SizeOfImage', img_size, None, 'MB', dimensions) ``` ### Add Percentage based metrics Percentage based metrics can be added by invoking the following method Function API ```python def add_percent(name, value, idx=None, dimensions=None): """ Add a percentage based metric Parameters ---------- name : str metric name value: int, float value of metric idx: int request_id index in batch dimensions: list list of dimensions for the metric """ ``` To add custom percentage based metrics ```python # Add MemoryUtilization as a metric # dimensions = [dim1, dim2, dim3, ..., dimN] # Assuming batch size is 1 for example metrics.add_percent('MemoryUtilization', utilization_percent, None, dimensions) ``` ### Add Counter based metrics Counter based metrics can be added by invoking the following method Function API ```python def add_counter(name, value, idx=None, dimensions=None): """ Add a counter metric or increment an existing counter metric Parameters ---------- name : str metric name value: int value of metric idx: int request_id index in batch dimensions: list list of dimensions for the metric """ ``` To create, increment and decrement counter based metrics we can use the following calls ```python # Add Loop Count as a metric # dimensions = [dim1, dim2, dim3, ..., dimN] # Assuming batch size is 1 for example # Create a counter with name 'LoopCount' and dimensions, initial value metrics.add_counter('LoopCount', 1, None, dimensions) # Increment counter by 2 metrics.add_counter('LoopCount', 2, None, dimensions) # Decrement counter by 1 metrics.add_counter('LoopCount', -1, None, dimensions) # Final counter value in this case is 2 ``` ================================================ FILE: docs/migration.md ================================================ # Migration from MMS 0.4 MMS 1.0 is a major release that contains significant architecture improvement based on MMS 0.4. 
MMS 1.0 adopted micro-services based architecture, the frontend request handler is separated from backend inference worker. The frontend is a java based web service which provide REST API, and backend is python based worker which execute custom service code (Other language worker support is also planed). ## Table of Content * [Installation](#installation) * [Command line interface](#command-line-interface) * [API](#api) * [Model archive](#model-archive) * [Docker container](#docker-container) * [Logging](#logging) * [Metrics](#metrics) * [Configuration](#configuration) * [SSL](#ssl) ## Installation MMS 1.0 made following changes for pip installation package: * **java 8**: java is required for MMS 1.0. You must install java 8 (or later) and make sure java is on available in $PATH environment variable *before* installing MMS. If you have multiple java installed, you can use $JAVA_HOME environment vairable to control which java to use. * **mxnet**: `mxnet` will not be installed by default with MMS 1.0 any more. You have to install it manually. See more detail: [Install MMS](install.md) ## Command line interface MMS 1.0 made some parameter changes in `mxnet-model-server` command line tool. The following command line parameters from previous versions will no longer function: * --service, See [Register model](management_api.md#register-a-model) for how to override service entry-point. * --gen-api, See [API description](inference_api#api-description) for how to generate your swagger client code. * --port, See [Configure MMS listening port](configuration.md#configure-mms-listening-port) for how to configure MMS listening ports. * --host, See [Configure MMS listening port](configuration.md#configure-mms-listening-port) for how to bind to specific network interface. * --gpu, See [Config properties](configuration.md#other-properties) for how to limit the number of GPUs. * --log-file, See [Logging](#logging) for how to specify a log file. 
* --log-rotation-time, See [Logging](#logging) for how to configure log rotation. * --log-level, See [Logging](#logging) for how to configure log level. * --metrics-write-to, See [Metrics](#metrics) for how to configure metrics. For further information on the parameters' updates, please see the [Command Line Interface](server.md#command-line-interface) section of the server documentation. ## API You can continue to use MMS 0.4 inference API in MMS 1.0. However they are deprecated. Please migrate to new [inference API](inference_api.md) ## Model archive You can continue to use your existing MMS 0.4 model archive (`.model` file). We stronger recommend you to migrate to new Model archive (`.mar`) format. Please refer to following documents: * [Custom service code](custom_service.md) * [model-archiver tool](../model-archiver/README.md) * [Create model archive example](../examples/mxnet_vision/README.md) ### model-archiver `mxnet-model-export` is no longer supported. Instead we release a `model-archiver` CLI tool. `model-archiver` now can be installed standalone: ```bash pip install model-archiver ``` See [model-archiver](../model-archiver/README.md) for more detail. ## Docker container MMS docker image makes it easier for you to serve a model. In 0.4 release, MMS require a configuration file (mms_app_cpu.conf or mms_app_gpu.conf) to start MMS in docker container. The old conf file format is no longer supported. To make it simple, MMS no longer requires the --mms-config parameter, the default configuration should work for most of use cases. MMS will start automatically while docker container starts: ```bash docker run -itd --name mms -p 80:8080 -p 8081:8081 awsdeeplearningteam/multi-model-server ``` After docker container started, you can use [Management API](management_api.md) to load models for inference. See [Docker Image](../docker/README.md) for detail. ## Logging MMS 1.0 provides highly customizable logging feature. 
MMS 0.4 logging parameters (--log-file, --log-rotation-time and --log-level) in command line are not supported. For more detail see [Logging configuration](logging.md) ## Metrics MMS 1.0 redesigned metrics feature: * The old --metrics-write-to parameter is not supported, instead a rich configuration is provided. * The built-in CloudWatch integration is removed, instead MMS 1.0 provides a template that allows users to integrate with any metrics server. See [Metrics](metrics.md) for more detail. ## Configuration MMS 1.0 provides a rich set of configuration parameters that allow advanced users to customize/tune MMS. A completely new set of parameters are introduced in new config.properties file. The MMS 0.4 format of configuration file is not supported any more. See [Advanced configuration](configuration.md) for more detail. ### SSL MMS 0.4 supported SSL via nginx; now MMS 1.0 provides native SSL support. See [Enable SSL](configuration.md#enable-ssl) for detail. ================================================ FILE: docs/mms_endpoint_plugins.md ================================================ # Introduction In this document, we will go over how to build and load custom endpoints for MMS. We will go over the plugins based architecture for MMS and user experience. # Plugins SDK MMS currently provides an SDK for customers to develop their custom URL endpoints and drop those endpoints into MMS for custom URL handling. The SDK is currently published through Maven Central. Let's go over what is available in the SDK and how we could use this SDK to build our custom Endpoint. # Plugins SDK The Model Server plugins SDK is distributed through Maven Central. Find it on [Nexus Repository Manager](https://oss.sonatype.org/#nexus-search;quick~software.amazon.ai) The plugins SDK has multiple components. The following are the main classes: 1. **ModelServerEndpoint** - This is the main class used to create a custom endpoint. 2. **Context** - This contains the context for model-server. 
The Context object contains methods to read and modify the behavior of model-server. 3. **Worker** - This object contains all the information pertaining to a worker # Plugins architecture ### MMS plugins loading at startup At startup, MMS reads the configuration file to get the plugins directory. This plugins directory contains endpoint jars. MMS loads all the jars implementing endpoints. ![](images/mms_plugins_startup.jpg) As seen in the above diagram, MMS loads all the endpoint "jars" implementing "ModelServerEndpoint" and registers them to "Management" or "Inference" channels. ### MMS plugins at runtime When the client invokes "endpoint_1", which was loaded at startup time above, MMS looks up the registered endpoints and validates if the request is valid. If the request is valid, the custom endpoint is scheduled in a separate thread-pool. Once the custom endpoint finishes running, the output is sent back to the client by the thread pool. ![](images/mms_plugins_runtime.jpg) ## Class diagram of the MMS custom-endpoint SDK: ![](images/plugins_sdk_class_diagram.png) # Writing your own custom endpoint In this section we will cover how to develop our own endpoint and test the endpoint with MMS. We will be developing an endpoint called "GET /execution-parameters", which will return a set of configuration parameters such as "MAX_CONCURRENT_TRANSFORMS", "BATCH_STRATEGY", "MAX_PAYLOAD_IN_MB", "BATCH". ## Include Maven dependencies for your project In the Maven dependency section, include the following to get the plugins-sdk for MMS. ```xml software.amazon.ai mms-plugins-sdk 1.0.1 ``` This will bring the Plugins SDK for MMS into your project. Now we are ready to build our custom endpoint! ## Build endpoint and register the endpoint as service We create a project with a java file called "ExecutionParameters" and also create a "META-INF/services/software.amazon.ai.mms.servingsdk.ModelServerEndpoint" file which acts as registration for this service. 
Let's look into what goes into each of these folders. ### Project structure ![](images/project_structure.png) ### ExecutionParameters Here we define the behavior of the endpoint as follows ```java @Endpoint( urlPattern = "execution-parameters", endpointType = EndpointTypes.INFERENCE, description = "Execution parameters endpoint") public class ExecutionParameters extends ModelServerEndpoint { @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { Properties prop = ctx.getConfig(); HashMap r = new HashMap<>(); r.put("MAX_CONCURRENT_TRANSFORMS", prop.getProperty("NUM_WORKERS", "1")); r.put("BATCH_STRATEGY", "SINGLE_RECORD"); r.put("MAX_PAYLOAD_IN_MB", prop.getProperty("max_request_size")); r.put("BATCH", "true"); rsp.getOutputStream() .write( new GsonBuilder() .setPrettyPrinting() .create() .toJson(r) .getBytes(StandardCharsets.UTF_8)); } } ``` Here we have annotated the class with the name of the URL as "execution-parameters" and the type of the endpoint as "INFERENCE". This endpoint only implements a "doGet" method. In other words, this endpoint only supports the "GET" method. The endpoint also returns the output by writing it to the output stream. ### Service file software.amazon.ai.mms.servingsdk.ModelServerEndpoint The service file is used for service loading on the MMS. The contents of this file are as follows ```text software.amazon.ai.mms.plugins.endpoint.ExecutionParameters ``` As the contents show, when loading this JAR, MMS loads the "ExecutionParameters" class, which in turn defines the functionality of the "execution-parameters" endpoint. ### Getting the endpoint jar Once the project is built, the generated JAR is the custom endpoint jar. The project in this example builds "**execution-parameters.jar**". # Loading the custom endpoint on MMS * Let's place this JAR in the "/tmp/plugins" directory ```bash $ ls /tmp/plugins execution-parameters.jar ``` * Configure MMS to load all the endpoints in plugins directory. 
Contents of **config.properties** are: ```properties plugins_path=/tmp/plugins ``` * Start model server with this configuration file ```bash multi-model-server --start --mms-config config.properties ``` * MMS will load the endpoints configured in this directory and be ready to serve requests to this endpoint. * Test the endpoint ```bash curl -v localhost:8080/execution-parameters * TCP_NODELAY set * Connected to localhost (::1) port 8080 (#1) > GET /execution-parameters HTTP/1.1 > Host: localhost:8080 > User-Agent: curl/7.54.0 > Accept: */* > < HTTP/1.1 200 OK < x-request-id: de5b2255-33ff-4d75-bed4-b24eb7820dec < Pragma: no-cache < Cache-Control: no-cache; no-store, must-revalidate, private < Expires: Thu, 01 Jan 1970 00:00:00 UTC < content-length: 94 < connection: keep-alive < { "BATCH_STRATEGY": "SINGLE_RECORD", "MAX_CONCURRENT_TRANSFORMS": "1", "BATCH": "true" } * Connection #1 to host localhost left intact ``` ## Conclusion For more implementations of the custom endpoints, please refer to the [plugins](../plugins) repository. ================================================ FILE: docs/mms_on_fargate.md ================================================ # Serverless Inference with MMS on FARGATE This is a self-contained, step-by-step guide that shows how to launch and serve your deep learning models with MMS in a production setup. In this document you will learn how to launch MMS with AWS Fargate, in order to achieve serverless inference. ## Prerequisites Even though it is fully self-contained, we do expect the reader to have some knowledge about the following topics: * [MMS](https://github.com/awslabs/multi-model-server) * [What is Amazon Elastic Container Service (ECS)](https://aws.amazon.com/ecs) * [What is Fargate](https://aws.amazon.com/fargate) * [What is Docker](https://www.docker.com/) and how to use containers Since we are doing inference, we need to have a pre-trained model that we can use to run inference. 
For the sake of this article, we will be using [SqueezeNet model](https://github.com/awslabs/multi-model-server/blob/master/docs/model_zoo.md#squeezenet_v1.1). In short, SqueezeNet is a model that allows you to recognize objects in a picture. Now that we have the model chosen, let's discuss at a high level what our pure-container based solution will look like: ![architecture](https://s3.amazonaws.com/multi-model-server/mms-github-docs/MMS+with+Fargate+Article/AWS+Fargate+MMS.jpg) In this document we are going to walk you through all the steps of setting up MMS 1.0 on Amazon Fargate services. The steps in this process are as follows: 1. Familiarize yourself with MMS containers 2. Create a SqueezeNet task definition (with the docker container of MMS) 3. Create AWS Fargate cluster 4. Create Application Load Balancer 5. Create Squeezenet Fargate service on the cluster 6. Profit! Let the show begin... ## Familiarize Yourself With Our Containers With the current release of [MMS, 1.0](https://github.com/awslabs/multi-model-server/releases/tag/v1.0.0), Official pre-configured, optimized container images of MMS are provided on [Docker hub](https://hub.docker.com). * [awsdeeplearningteam/multi-model-server](https://hub.docker.com/r/awsdeeplearningteam/multi-model-server) ```bash docker pull awsdeeplearningteam/multi-model-server # for gpu image use following command: docker pull awsdeeplearningteam/multi-model-server:latest-gpu ``` In our article we are going to use the official CPU container image. One major constraint for using Fargate service is that there is currently no support for GPU on Fargate. The model-server container comes with a configuration file pre-baked inside the container. It is highly recommended that you understand all the parameters of the MMS configuration file. 
Familiarize yourself with the [MMS configuration](https://github.com/awslabs/multi-model-server/blob/master/docs/configuration.md) and [configuring MMS Container docs](https://github.com/awslabs/multi-model-server/blob/master/docker/README.md). When you want to launch and host your custom model, you will have to update this configuration. In this tutorial, we will be using the squeezenet model from the following S3 link. ``` https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` Since MMS can consume model files from S3 buckets, we wouldn't need to bake the containers with the actual model files. The last question that we need to address is how we should start MMS within our container. And the answer is very simple, you just need to set the following [ENTRYPOINT](https://docs.docker.com/engine/reference/builder/#entrypoint): ```bash multi-model-server --start --models https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` You will now have a running container serving the squeezenet model. At this point, you are ready to start creating the actual task definition. **Note**: To start multiple models with the model-server, you could run the following command with multiple model names ```bash # Example, following command starts model server with Resnet-18 and Squeezenet V1 models $ multi-model-server --start --models https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar ``` ## Create an AWS Fargate task to serve SqueezeNet model This is the first step towards getting your own "inference service" up and running in a production setup. 1. Login to the AWS console and go to the Elastic Cloud Service -> Task Definitions and Click “Create new Task Definition”: ![task def](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/1_Create_task_definition.png) 2. 
Now you need to specify the type of the task, you will be using the Fargate task: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/2_Select_Fargate.png) 3. The task requires some configuration, let's look at it step by step. First set the name: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3_Config_1.png) Now comes an important part: you need to create an [IAM role](https://aws.amazon.com/iam) that will be used to publish metrics to CloudWatch: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/Task+Execution+IAM+Role+.png) The containers are optimized for 8 vCPUs, however in this example you are going to use a slightly smaller task with 4 vCPUs and 8 GB of RAM: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/cpu+and+ram.png) 4. Now it is time to configure the actual container that the task should be executing. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/container+step+1.png)

*Note:* If you are using a [custom container](https://github.com/awslabs/multi-model-server/blob/master/docs/mms_on_fargate.md#customize-the-containers-to-serve-your-custom-deep-learning-models), make sure to first upload your container to Amazon ECR or Dockerhub and replace the link in this step with the link to your uploaded container. 5. The next task is to specify the port mapping. You need to expose container port 8080. This is the port that the MMS application inside the container is listening on. If needed it can be configured via the config [here](https://github.com/awslabs/multi-model-server/blob/master/docker/mms_app_cpu.conf#L40). ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/port+8080.png) Next, you will have to configure the health-checks. This is the command that ECS should run to find out whether MMS is running within the container or not. MMS has a pre-configured endpoint `/ping` that can be used for health checks. Configure ECS to reach that endpoint at `http://127.0.0.1:8080/ping` using the `curl` command as shown below: ```bash curl, http://127.0.0.1:8080/ping ``` The healthcheck portion of your container configuration should look like the image below: ![](https://s3.amazonaws.com/multi-model-server/mms-github-docs/MMS+with+Fargate+Article/add+container+healthcheck.png) After configuring the health-checks, you can go onto configuring the environment, with the entry point that we have discussed earlier: ![](https://s3.amazonaws.com/multi-model-server/mms-github-docs/MMS+with+Fargate+Article/environtment.png) Everything else can be left as default. So feel free to click `Create` to create your very first AWS Fargate-task. If everything is ok, you should now be able to see your task in the list of task definitions. In ECS, `Services` are created to run Tasks. 
A service is in charge of running multiple tasks and making sure that the required number of tasks is always running, restarting un-healthy tasks, adding more tasks when needed. To have your `inference service` accessible over the Internet, you would need to configure a load-balancer (LB). This LB will be in charge of serving the traffic from the Internet and redirecting it to these newly created tasks. Let's create an Application Load Balancer now: ## Create a Load Balancer AWS supports several different types of Load Balancers: * Application Load Balancer: works on the level 7 of the OSI model (effectively with the HTTP/HTTPS protocols) * TCP Load Balancer For your cluster you are going to use an application load balancer. 1. Login to the EC2 Console. 2. Go to the “Load balancers” section. 3. Click on Create new Load Balancer. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/1__Create_Load_Balancer.png) 5. Choose Application Load Balancer. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/2__HTTP_HTTPS+.png) 6. Set all the required details. **Make a note of the VPC of the LB**. This is important since the LB's VPC and the ECS cluster's VPC need to be same for them to communicate with each other. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3_2_Listeners_and_AZ+.png) 7. Next is configuring the security group. This is also important. Your security group should: * Allow inbound connections for port 80 (since this is the port on which the LB will be listening) * The LB's security group needs to be added to the AWS Fargate service's security group, so that all the traffic from the LB is accepted by your "inference service". ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/4.+Configure+Security+groups.png) 8. Routing configuration is simple. Here you need to create a “target group”. 
But, in your case the AWS Fargate service, that you will create later, will automatically create a target group. Therefore you will create dummy “target group” that you will delete after the creation of the LB. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/5.+Configure+Routing+(DUmmy).png) 9. Nothing needs to be done for the last two steps. `Finish` the creation and ... 10. Now you are ready to remove dummy listener and target group ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/8__Delete_the_dummy_listener.png) Now that you are `done-done-done` with the Load Balancer creation, lets move onto creating our Serverless inference service. ## Creating an ECS Service to launch our AWS Fargate task 1. Go to Elastic Container Service → Task Definitions and select the task definitions name. Click on actions and select create service. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/1.+Go+to+task+definitions.png) 2. There are two important things on the first step (apart from naming): * Platform version: It should be set to 1.1.0 . * Number of tasks that the service should maintain as healthy all of the time, for this example you will set this to 3. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/number+of+tasks.png) 3. Now it is time to configure the VPC and the security group. **You should use the same VPC that was used for the LB (and same subnets!).** ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.2.1+Use+the+existing+VPC+Edit+sg.png) 4. As for the security group, it should be either the same security group as you had for the LB, or the one that accepts traffic from the LBs security group. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.2.2+SG+Use+existing.png) 5. Now you can connect your service to the LB that was created in the previous section. 
Select the "Application Load Balancer" and set the LB name: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.2.3+Add+load+balancing.png) 6. Now you need to specify which port on the LB our service should be listening on: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.2.4+Configure+load+blancer.png) 7. You are not going to use service discovery now, so uncheck it: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.2.5+Next.png) 8. In this document, we are not using auto-scaling options. For an actual production system, it is advisable to have this configuration setup. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/3.3+Auto+scaling.png) 9. Now you are `done-done-done` creating a running service. You can move to the final chapter of the journey, which is testing the service you created. ## Test your service First find the DNS name of your LB. It should be in `AWS Console -> Service -> EC2 -> Load Balancers` and click on the LB that you created. ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/lb_dns.png) Now you can run the health checks using this load-balancer public DNS name, to verify that your newly created service is working: ```bash curl InfraLb-1624382880.us-east-1.elb.amazonaws.com/ping ``` ```text http://infralb-1624382880.us-east-1.elb.amazonaws.com/ping { "status": "Healthy!" } ``` And now you are finally ready to run our inference! 
Let's download an example image: ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg ``` The image: ![](https://s3.amazonaws.com/mms-github-assets/MMS+with+Fargate+Article/kitten.jpg) The output of this query would be as follows, ```bash curl -X POST InfraLb-1624382880.us-east-1.elb.amazonaws.com/predictions/squeezenet_v1.1 -F "data=@kitten.jpg" ``` ```text { "prediction": [ [ { "class": "n02124075 Egyptian cat", "probability": 0.8515275120735168 }, { "class": "n02123045 tabby, tabby cat", "probability": 0.09674164652824402 }, { "class": "n02123159 tiger cat", "probability": 0.03909163549542427 }, { "class": "n02128385 leopard, Panthera pardus", "probability": 0.006105933338403702 }, { "class": "n02127052 lynx, catamount", "probability": 0.003104303264990449 } ] ] } ``` ## Instead of a Conclusion There are a few things that we have not covered here and which are very useful, such as: * How to configure auto-scaling on our ECS cluster. * Running A/B testing of different versions of the model with the Fargate Deployment concepts. Each of the above topics require their own articles, so stay tuned!! ## Authors * Aaron Markham * Vamshidhar Dantu * Viacheslav Kovalevskyi (@b0noi) ================================================ FILE: docs/model_zoo.md ================================================ # Model Zoo This page lists model archives that are pre-trained and pre-packaged, ready to be served for inference with MMS. To propose a model for inclusion, please submit a [pull request](https://github.com/awslabs/multi-model-server/pulls). 
*Special thanks to the [Apache MXNet](https://mxnet.incubator.apache.org) community whose Model Zoo and Model Examples were used in generating these model archives.* | Model File | Type | Dataset | Source | Size | Download | | --- | --- | --- | --- | --- | --- | | [AlexNet](#alexnet) | Image Classification | ImageNet | ONNX | 233 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/alexnet.mar) | | [ArcFace-ResNet100](#arcface-resnet100_onnx) | Face Recognition | Refined MS-Celeb1M | ONNX | 236.4 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-arcface-resnet100.mar) | | [Character-level Convolutional Networks for Text Classification](#crepe) | Text Classification | [Amazon Product Data](http://jmcauley.ucsd.edu/data/amazon/) | Gluon | 40 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/crepe.mar) | | [CaffeNet](#caffenet) | Image Classification | ImageNet | MXNet | 216 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/caffenet/caffenet.mar) | | [FERPlus](#ferplus_onnx) | Emotion Detection | FER2013 | ONNX | 35MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/FERPlus.mar) | | [Inception v1](#inception_v1) | Image Classification | ImageNet | ONNX | 27 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-inception_v1.mar) | | [Inception v3 w/BatchNorm](#inception_v3) | Image Classification | ImageNet | MXNet | 45 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/inception-bn.mar) | | [LSTM PTB](#lstm-ptb) | Language Modeling | PennTreeBank | MXNet | 16 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/lstm_ptb.mar) | | [MobileNetv2-1.0](#mobilenetv2-1.0_onnx) | Image Classification | ImageNet | ONNX | 13.7 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-mobilenet.mar) | | [Network in Network (NiN)](#nin) | Image Classification | ImageNet | MXNet | 30 MB | 
[.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/nin.mar) | | [ResNet-152](#resnet-152) | Image Classification | ImageNet | MXNet | 241 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-152.mar) | | [ResNet-18](#resnet-18) | Image Classification | ImageNet | MXNet | 43 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar) | | [ResNet50-SSD](#resnet50-ssd) | SSD (Single Shot MultiBox Detector) | ImageNet | MXNet | 124 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/resnet50_ssd.mar) | | [ResNext101-64x4d](#resnext101) | Image Classification | ImageNet | MXNet | 334 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/resnext-101-64x4d.mar) | | [ResNet-18v1](#resnet-18v1) | Image Classification | ImageNet | ONNX | 45 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v1.mar) | | [ResNet-34v1](#resnet-34v1) | Image Classification | ImageNet | ONNX | 83 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v1.mar) | | [ResNet-50v1](#resnet-50v1) | Image Classification | ImageNet | ONNX | 98 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v1.mar) | | [ResNet-101v1](#resnet-101v1) | Image Classification | ImageNet | ONNX | 171 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v1.mar) | | [ResNet-152v1](#resnet-152v1) | Image Classification | ImageNet | ONNX | 231 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v1.mar) | | [ResNet-18v2](#resnet-18v2) | Image Classification | ImageNet | ONNX | 45 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v2.mar) | | [ResNet-34v2](#resnet-34v2) | Image Classification | ImageNet | ONNX | 83 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v2.mar) | | [ResNet-50v2](#resnet-50v2) | Image Classification | ImageNet | ONNX 
| 98 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v2.mar) | | [ResNet-101v2](#resnet-101v2) | Image Classification | ImageNet | ONNX | 171 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v2.mar) | | [ResNet-152v2](#resnet-152v2) | Image Classification | ImageNet | ONNX | 231 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v2.mar) | | [Shufflenet](#shufflenet) | Image Classification | ImageNet | ONNX | 8.1 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/shufflenet.mar) | | [SqueezeNet_v1.1](#squeezenet_v1.1_onnx) | Image Classification | ImageNet | ONNX | 5 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-squeezenet.mar) | | [SqueezeNet v1.1](#squeezenet_v1.1) | Image Classification | ImageNet | MXNet | 5 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar) | | [VGG16](#vgg16) | Image Classification | ImageNet | MXNet | 490 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/vgg16.mar) | | [VGG16](#vgg16_onnx) | Image Classification | ImageNet | ONNX | 527 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16.mar) | | [VGG16_bn](#vgg16_bn_onnx) | Image Classification | ImageNet | ONNX | 527 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16_bn.mar) | | [VGG19](#vgg19) | Image Classification | ImageNet | MXNet | 509 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/vgg19.mar) | | [VGG19](#vgg19_onnx) | Image Classification | ImageNet | ONNX | 548 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19.mar) | | [VGG19_bn](#vgg19_bn_onnx) | Image Classification | ImageNet | ONNX | 548 MB | [.mar](https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19_bn.mar) | ## Details on Each Model Each model below comes with a basic description, and where available, a link to a 
scholarly article about the model. Many of these models use a kitten image to test inference. Use the following to get one that will work: ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg ``` ## AlexNet * **Type**: Image classification trained on ImageNet * **Reference**: [Krizhevsky, et al.](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models alexnet=https://s3.amazonaws.com/model-server/model_archive_1.0/alexnet.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/alexnet -T kitten.jpg ``` ## ArcFace-ResNet100 (from ONNX model zoo) * **Type**: Face Recognition model trained on refined MS-Celeb1M dataset (model imported from ONNX) * **Reference**: [Deng et al.](https://arxiv.org/abs/1801.07698) * **Model Service**: * [arcface_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-arcface-resnet100/arcface_service.py) * [mtcnn_detector.py](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-arcface-resnet100/mtcnn_detector.py) * **Install dependencies**: ```bash pip install opencv-python pip install scikit-learn pip install easydict pip install scikit-image pip install numpy ``` * **Start Server**: ```bash multi-model-server --start --models arcface=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-arcface-resnet100.mar ``` * **Get two test images**: ```bash curl -O https://s3.amazonaws.com/model-server/inputs/arcface-input1.jpg curl -O https://s3.amazonaws.com/model-server/inputs/arcface-input2.jpg ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/arcface -F "img1=@arcface-input1.jpg" -F "img2=@arcface-input2.jpg" ``` ## CaffeNet * **Type**: Image classification 
trained on ImageNet * **Reference**: [Krizhevsky, et al.](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models caffenet=https://s3.amazonaws.com/model-server/model_archive_1.0/caffenet.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/caffenet -T kitten.jpg ``` ## Character-level Convolutional Networks for text Classification * **Type**: Character-level Convolutional network for text classification trained on [Amazon Product Data](http://jmcauley.ucsd.edu/data/amazon/). * **Reference**: [R. He, J. McAuley et al.](https://arxiv.org/abs/1602.01585), [J. McAuley, C. Targett, J. Shi, A. van den Hengel et al.](https://arxiv.org/abs/1506.04757) * **Model Service**: [gluon_crepe.py](https://github.com/awslabs/multi-model-server/blob/master/examples/gluon_character_cnn/gluon_crepe.py) * **Start Server**: ```bash multi-model-server --start --models crepe=https://s3.amazonaws.com/model-server/model_archive_1.0/crepe.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/crepe -F "data=[{\"review_title\":\"Inception is the best\",\"review\": \"great direction and story\"}]" ``` ## DUC-ResNet101 (from ONNX model zoo) * **Type**: Semantic Segmentation model trained on the Cityscapes dataset (model imported from ONNX) * **Reference**: [Wang et al.](https://arxiv.org/abs/1702.08502) * **Model Service**: * [duc_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-duc/duc_service.py) * [cityscapes_labels.py](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-duc/cityscapes_labels.py) * **Install dependencies**: ```bash pip install opencv-python pip install pillow ``` * **Start Server**: ```bash multi-model-server 
--models duc=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-duc.mar ``` * **Get the test image**: ```bash curl -O https://s3.amazonaws.com/multi-model-server/onnx-duc/city1.jpg ``` * **Download inference script**: The script makes an inference call to the server using the test image, displays the colorized segmentation map and prints the confidence score. ```bash curl -O https://s3.amazonaws.com/multi-model-server/onnx-duc/duc-inference.py ``` * **Run Prediction**: ```bash python duc-inference.py city1.jpg ``` ## FERPlus * **Type**: Emotion detection trained on FER2013 dataset (model imported from ONNX) * **Reference**: [Barsoum et al.](https://arxiv.org/abs/1608.01041) * **Model Service**: [emotion_detection_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/FERPlus/emotion_detection_service.py) * **Start Server**: ```bash multi-model-server --start --models FERPlus=https://s3.amazonaws.com/model-server/model_archive_1.0/FERPlus.mar ``` * **Get a test image**: ```bash curl -O https://s3.amazonaws.com/model-server/inputs/ferplus-input.jpg ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/FERPlus -T ferplus-input.jpg ``` ## Inception v1 * **Type**: Image classification trained on ImageNet * **Reference**: [Szegedy, et al., Google](https://arxiv.org/pdf/1512.00567.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-inception-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-inception_v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-inception-v1 -T kitten.jpg ``` ## Inception v3 * **Type**: Image classification trained on ImageNet * **Reference**: [Szegedy, et al., Google](https://arxiv.org/pdf/1512.00567.pdf) * **Model Service**: 
[mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models inception-bn=https://s3.amazonaws.com/model-server/model_archive_1.0/inception-bn.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/inception-bn -T kitten.jpg ``` ## LSTM PTB Long short-term memory network trained on the PennTreeBank dataset. * **Reference**: [Hochreiter, et al.](http://www.bioinf.jku.at/publications/older/2604.pdf) * **Model Service**: [lstm_ptb_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/lstm_ptb/lstm_ptb_service.py) * **Start Server**: ```bash multi-model-server --start --models lstm_ptb=https://s3.amazonaws.com/model-server/model_archive_1.0/lstm_ptb.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/lstm_ptb -H "Content-Type: application/json" -d '[{"input_sentence": "on the exchange floor as soon as ual stopped trading we for a panic said one top floor trader"}]' ``` ## MobileNetv2-1.0 (from ONNX model zoo) * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Sandler et al.](https://arxiv.org/abs/1801.04381) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models mobilenet=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-mobilenet.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/mobilenet -T kitten.jpg ``` ## Network in Network * **Type**: Image classification trained on ImageNet * **Reference**: [Lin, et al.](https://arxiv.org/pdf/1312.4400v3.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash 
multi-model-server --start --models nin=https://s3.amazonaws.com/model-server/model_archive_1.0/nin.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/nin -T kitten.jpg ``` ## ResNet-152 * **Type**: Image classification trained on ImageNet * **Reference**: [Lin, et al.](https://arxiv.org/pdf/1312.4400v3.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet-152=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-152.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet-152 -T kitten.jpg ``` ## ResNet-18 * **Type**: Image classification trained on ImageNet * **Reference**: [He, et al.](https://arxiv.org/pdf/1512.03385v1.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet-18=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet-18 -T kitten.jpg ``` ## ResNet50-SSD * **Type**: Image classification trained on ImageNet * **Reference**: [Liu, et al.](https://arxiv.org/pdf/1512.02325v4.pdf) * **Model Service**: [ssd_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/ssd/ssd_service.py) * **Start Server**: ```bash multi-model-server --start --models SSD=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet50_ssd.mar ``` * **Run Prediction**: ```bash curl -O https://www.dphotographer.co.uk/users/21963/thm1024/1337890426_Img_8133.jpg curl -X POST http://127.0.0.1:8080/predictions/SSD -T 1337890426_Img_8133.jpg ``` ## ResNext101-64x4d * **Type**: Image classification trained on ImageNet * **Reference**: [Xie, et 
al.](https://arxiv.org/pdf/1611.05431.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnext101=https://s3.amazonaws.com/model-server/model_archive_1.0/resnext-101-64x4d.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnext101 -T kitten.jpg ``` ## ResNet (from ONNX model zoo) ### ResNet18-v1 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1512.03385) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet18-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet18-v1 -T kitten.jpg ``` ### ResNet34-v1 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1512.03385) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet34-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet34-v1 -T kitten.jpg ``` ### ResNet50-v1 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1512.03385) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models 
resnet50-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet50-v1 -T kitten.jpg ``` ### ResNet101-v1 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1512.03385) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet101-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet101-v1 -T kitten.jpg ``` ### ResNet152-v1 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1512.03385) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet152-v1=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet152-v1 -T kitten.jpg ``` ### ResNet18-v2 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1603.05027) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet18-v2=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v2.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet18-v2 -T kitten.jpg ``` ### ResNet34-v2 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1603.05027) * 
**Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet34-v2=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v2.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet34-v2 -T kitten.jpg ``` ### ResNet50-v2 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1603.05027) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet50-v2=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v2.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet50-v2 -T kitten.jpg ``` ### ResNet101-v2 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1603.05027) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet101-v2=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v2.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/resnet101-v2 -T kitten.jpg ``` ### ResNet152-v2 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [He, et al.](https://arxiv.org/abs/1603.05027) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models resnet152-v2=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v2.mar ``` * **Run Prediction**: ```bash curl -X POST 
http://127.0.0.1:8080/predictions/resnet152-v2 -T kitten.jpg ``` ## Shufflenet_v2 * **Type**: Image classification trained on ImageNet * **Reference**: [Zhang, et al.](https://arxiv.org/abs/1707.01083) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models shufflenet=https://s3.amazonaws.com/model-server/model_archive_1.0/shufflenet.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/shufflenet -T kitten.jpg ``` ## SqueezeNet v1.1 * **Type**: Image classification trained on ImageNet * **Reference**: [Iandola, et al.](https://arxiv.org/pdf/1602.07360v4.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models squeezenet_v1.1=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/squeezenet_v1.1 -T kitten.jpg ``` ## SqueezeNet v1.1 (from ONNX model zoo) * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Iandola, et al.](https://arxiv.org/pdf/1602.07360v4.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-squeezenet.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-squeezenet -T kitten.jpg ``` ## VGG16 * **Type**: Image classification trained on ImageNet * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) * **Model Service**: 
[mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models vgg16=https://s3.amazonaws.com/model-server/model_archive_1.0/vgg16.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/vgg16 -T kitten.jpg ``` ## VGG19 * **Type**: Image classification trained on ImageNet * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) * **Model Service**: [mxnet_vision_service.py](https://github.com/awslabs/multi-model-server/blob/master/examples/mxnet_vision/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models vgg19=https://s3.amazonaws.com/model-server/model_archive_1.0/vgg19.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/vgg19 -T kitten.jpg ``` ## VGG (from ONNX model zoo) ### VGG16 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-vgg16=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-vgg16 -T kitten.jpg ``` ### VGG16_bn * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) (Batch normalization applied after each conv layer of VGG16) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-vgg16_bn=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16_bn.mar ``` * **Run 
Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-vgg16_bn -T kitten.jpg ``` ### VGG19 * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-vgg19=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-vgg19 -T kitten.jpg ``` ### VGG19_bn * **Type**: Image classification trained on ImageNet (imported from ONNX) * **Reference**: [Simonyan, et al.](https://arxiv.org/pdf/1409.1556v6.pdf) (Batch normalization applied after each conv layer of VGG19) * **Model Service**: [mxnet_vision_service.py](https://s3.amazonaws.com/model-server/model_archive_1.0/mxnet_vision_service.py) * **Start Server**: ```bash multi-model-server --start --models onnx-vgg19_bn=https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19_bn.mar ``` * **Run Prediction**: ```bash curl -X POST http://127.0.0.1:8080/predictions/onnx-vgg19_bn -T kitten.jpg ``` ================================================ FILE: docs/rest_api.md ================================================ # MMS REST API MMS uses RESTful APIs for both inference and management calls. The APIs are compliant with the [OpenAPI specification 3.0](https://swagger.io/specification/). Users can easily generate client-side code for Java, Scala, C#, and JavaScript using [swagger codegen](https://swagger.io/swagger-codegen/). When MMS starts up, it starts two web services: * [Inference API](inference_api.md) * [Management API](management_api.md) By default, MMS listens on port 8080 for the Inference API and port 8081 for the Management API. Both APIs are only accessible from localhost.
Please see [MMS Configuration](configuration.md) for how to enable access from a remote host. ================================================ FILE: docs/server.md ================================================ # Running the Model Server ## Contents of this Document * [Overview](#overview) * [Technical Details](#technical-details) * [Model Files](#model-files) * [Command Line Interface](#command-line-interface) * [Advanced Features](#advanced-features) ## Overview Multi Model Server can be used for many types of inference in production settings. It provides an easy-to-use command line interface and utilizes [REST based APIs](rest_api.md) to handle prediction requests. Support for models from a wide range of deep learning frameworks is achieved through its [ONNX model](https://onnx.ai) export feature. For example, say you want to make an app that lets your users snap a picture, and it will tell them what objects were detected in the scene and predictions on what the objects might be. You can use MMS to serve a prediction endpoint for an object detection and identification model that takes in images, then returns predictions. You can also modify MMS behavior with custom services and run multiple models. There are examples of custom services in the [examples](../examples) folder. ## Technical Details Now that you have a high level view of MMS, let's get a little into the weeds. MMS takes a deep learning model and wraps it in a set of REST APIs. Currently it comes with a built-in web server that you run from the command line. This command line call takes in the single or multiple models you want to serve, along with additional optional parameters controlling the port, host, and logging. MMS supports running custom services to handle the specific inference handling logic. These are covered in more detail in the [custom service](custom_service.md) documentation.
To try out MMS serving now, you can load the SqueezeNet model, which is under 5 MB, with this example: ```bash multi-model-server --start --models squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` With the command above executed, you have MMS running on your host, listening for inference requests. To test it out, you will need to open a new terminal window next to the one running MMS. Then we will use `curl` to download one of these [cute pictures of a kitten](https://www.google.com/search?q=cute+kitten&tbm=isch&hl=en&cr=&safe=images) and curl's `-o` flag will name it `kitten.jpg` for us. Then we will `curl` a `POST` to the MMS predictions endpoint with the kitten's image. In the example below, both of these steps are provided. ```bash curl -o kitten.jpg \ https://upload.wikimedia.org/wikipedia/commons/8/8f/Cute-kittens-12929201-1600-1200.jpg curl -X POST http://127.0.0.1:8080/predictions/squeezenet -T kitten.jpg ``` ![kitten](https://upload.wikimedia.org/wikipedia/commons/8/8f/Cute-kittens-12929201-1600-1200.jpg) The predict endpoint will return a prediction response in JSON. Each of the probabilities is a decimal fraction, and the classes are coming from a `synset` file included inside the model archive which holds the thousand ImageNet classes this model is matching against. It will look something like the following result, where the 0.86 result is an 86% probable match with an Egyptian cat: ```json [ { "probability": 0.8582232594490051, "class": "n02124075 Egyptian cat" }, { "probability": 0.09159987419843674, "class": "n02123045 tabby, tabby cat" }, { "probability": 0.0374876894056797, "class": "n02123159 tiger cat" }, { "probability": 0.006165083032101393, "class": "n02128385 leopard, Panthera pardus" }, { "probability": 0.0031716004014015198, "class": "n02127052 lynx, catamount" } ] ``` You will see this result in the response to your `curl` call to the predict endpoint, in the terminal window running MMS, and in the log files.
After this deep dive, you might also be interested in: * [Logging](logging.md): logging options that are available * [Metrics](metrics.md): details on metrics collection * [REST API Description](rest_api.md): more detail about the server's endpoints * [Model Zoo](model_zoo.md): try serving different models * [Custom Services](custom_service.md): learn about serving different kinds of model and inference types ## Model Files The rest of this topic focus on serving of model files without much discussion on the model files themselves, where they come from, and how they're made. Long story short: it's a zip archive with the parameters, weights, and metadata that define a model that has been trained already. If you want to know more about the model files, take a look at the [model-archiver documentation](../model-archiver/README.md). ## Command Line Interface ```bash $ multi-model-server --help usage: multi-model-server [-h] [--start] [--stop] [--mms-config MMS_CONFIG] [--model-store MODEL_STORE] [--models MODEL_PATH1 MODEL_NAME=MODEL_PATH2... [MODEL_PATH1 MODEL_NAME=MODEL_PATH2... ...]] [--log-config LOG_CONFIG] Multi Model Server optional arguments: -h, --help show this help message and exit --start Start the model-server --stop Stop the model-server --mms-config MMS_CONFIG Configuration file for model server --model-store MODEL_STORE Model store location where models can be loaded --models MODEL_PATH1 MODEL_NAME=MODEL_PATH2... [MODEL_PATH1 MODEL_NAME=MODEL_PATH2... ...] Models to be loaded using [model_name=]model_location format. Location can be a HTTP URL, a model archive file or directory contains model archive files in MODEL_STORE. --log-config LOG_CONFIG Log4j configuration file for model server ``` #### Arguments: Example where no models are loaded at start time: ```bash multi-model-server ``` There are no default required arguments to start the server 1. **models**: required, = pairs. a) Model path can be a local file path or URI (s3 link, or http link). 
local file path: path/to/local/model/file or file://root/path/to/model/file s3 link: s3://S3_endpoint[:port]/... http link: http://hostname/path/to/resource b) The model file has .mar extension, it is actually a zip file with a .mar extension packing trained models and model signature files. c) Multiple models loading are also supported by specifying multiple name path pairs. 1. **model-store**: optional, A location where models are stored by default, all models in this location are loaded, the model name is same as archive or folder name. 1. **mms-config**: optional, provide a [configuration](configuration.md) file in config.properties format. 1. **log-config**: optional, This parameter will override default log4j2.xml, present within the server. 1. **start**: optional, A more descriptive way to start the server. 1. **stop**: optional, Stop the server if it is already running. ## Advanced Features ### Custom Services This topic is covered in much more detail on the [custom service documentation page](custom_service.md), but let's talk about how you start up your MMS server using a custom service and why you might want one. Let's say you have a model named `super-fancy-net.mar` in `/models` folder, which can detect a lot of things, but you want an API endpoint that detects only hotdogs. You would use a name that makes sense for it, such as the "not-hot-dog" API. In this case we might invoke MMS like this: ```bash multi-model-server --start --model-store /models --models not-hot-dog=super-fancy-net.mar ``` This would serve a prediction endpoint at `predictions/not-hot-dog/` and run your custom service code in the archive, the manifest in archive would point to the entry point. ### Serving Multiple Models with MMS Example multiple model usage: ```bash multi-model-server --start --model-store /models --models name=model_location name2=model_location2 ``` Here's an example for running the resnet-18 and the vgg16 models using local model files. 
```bash multi-model-server --start --model-store /models --models resnet-18=resnet-18.mar squeezenet=squeezenet_v1.1.mar ``` If you don't have the model files locally, then you can call MMS using URLs to the model files. ```bash multi-model-server --models resnet=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar squeezenet=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar ``` This will setup a local host serving resnet-18 model and squeezenet model on the same port, using the default 8080. Check http://127.0.0.1:8081/models to see that each model has an endpoint for prediction. In this case you would see `predictions/resnet` and `predictions/squeezenet` ### Logging and Metrics For details on logging see the [logging documentation](logging.md). For details on metrics see the [metrics documentation](metrics.md). ================================================ FILE: examples/README.md ================================================ # MMS Examples The following are examples on how to create and serve model archives with MMS. 
* Gluon Models * [AlexNet](gluon_alexnet) - train a custom Gluon model, use model archiver, then serve the model archive with MMS * [Character CNN](gluon_character_cnn) - download a pre-trained model that classifies product reviews, use model archiver, then serve the model archive with MMS * Symbolic Models * [Image Classification with SqueezeNet](mxnet_vision) - download a pre-trained computer vision model, use model archiver, then serve the model archive with MMS * [LSTM with PTB](lstm_ptb) - download a pre-trained LSTM model that generates text, use model archiver, then serve the model archive with MMS * [How to setup metrics collection with AWS CloudWatch](metrics_cloudwatch) * [DenseNet PyTorch](densenet_pytorch) ================================================ FILE: examples/densenet_pytorch/Dockerfile ================================================ FROM awsdeeplearningteam/multi-model-server:base-cpu-py3.6 # Add PyTorch RUN pip install --no-cache-dir torch torchvision ================================================ FILE: examples/densenet_pytorch/README.md ================================================ # PyTorch serving This example shows how to serve PyTorch trained models for flower species recognition.. The custom handler is implemented in `densenet_service.py`. For simplicity, we'll use a pre-trained model. For simplicity we will use docker container to run Model Server. ## Getting Started With Docker Build the docker image with pytorch as backend engine: ```bash cd examples/densenet_pytorch/ docker build . -t mms_with_pytorch ``` Run the container that you have built in previous step. 
```bash docker run -it --entrypoint bash mms_with_pytorch ``` Start the server from inside the container: ```bash multi-model-server --models densenet161_pytorch=https://s3.amazonaws.com/model-server/model_archive_1.0/examples/PyTorch+models/densenet/densenet161_pytorch.mar ``` Now we can download a sample flower's image ```bash curl -O https://s3.amazonaws.com/model-server/inputs/flower.jpg ``` Get the status of the model with the following: ```bash curl -X POST http://127.0.0.1:8080/predictions/densenet161_pytorch -T flower.jpg ``` ```json [ { "canna lily": 0.01565943844616413 }, { "water lily": 0.015515935607254505 }, { "purple coneflower": 0.014358781278133392 }, { "globe thistle": 0.014226051047444344 }, { "ruby-lipped cattleya": 0.014212552458047867 } ] ``` For more information on MAR files and the built-in REST APIs, see: * https://github.com/awslabs/multi-model-server/tree/master/docs ================================================ FILE: examples/densenet_pytorch/densenet_service.py ================================================ import os import io import json import numpy as np from PIL import Image import torch from torch.autograd import Variable from torchvision import transforms import torch.nn.functional as F class PyTorchImageClassifier(): """ PyTorchImageClassifier service class. This service takes a flower image and returns the name of that flower. """ def __init__(self): self.checkpoint_file_path = None self.model = None self.mapping = None self.device = "cpu" self.initialized = False def initialize(self, context): """ Load the model and mapping file to perform infernece. 
""" properties = context.system_properties model_dir = properties.get("model_dir") # Read checkpoint file checkpoint_file_path = os.path.join(model_dir, "model.pth") if not os.path.isfile(checkpoint_file_path): raise RuntimeError("Missing model.pth file.") # Prepare the model checkpoint = torch.load(checkpoint_file_path, map_location='cpu') model = checkpoint['model'] model.classifier = checkpoint['classifier'] model.load_state_dict(checkpoint['state_dict']) model.class_to_idx = checkpoint['class_to_idx'] for param in model.parameters(): param.requires_grad = False self.model = model # Read the mapping file, index to flower mapping_file_path = os.path.join(model_dir, "index_to_name.json") if not os.path.isfile(mapping_file_path): raise RuntimeError("Missing the mapping file") with open(mapping_file_path) as f: self.mapping = json.load(f) self.initialized = True def preprocess(self, data): """ Scales, crops, and normalizes a PIL image for a PyTorch model, returns an Numpy array """ image = data[0].get("data") if image is None: image = data[0].get("body") my_preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) image = Image.open(io.BytesIO(image)) image = my_preprocess(image) return image def inference(self, img, topk=5): ''' Predict the class (or classes) of an image using a trained deep learning model. 
''' # Convert 2D image to 1D vector img = np.expand_dims(img, 0) img = torch.from_numpy(img) self.model.eval() inputs = Variable(img).to(self.device) logits = self.model.forward(inputs) ps = F.softmax(logits,dim=1) topk = ps.cpu().topk(topk) probs, classes = (e.data.numpy().squeeze().tolist() for e in topk) results = [] for i in range(len(probs)): tmp = dict() tmp[self.mapping[str(classes[i])]] = probs[i] results.append(tmp) return [results] def postprocess(self, inference_output): return inference_output # Following code is not necessary if your service class contains `handle(self, data, context)` function _service = PyTorchImageClassifier() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None data = _service.preprocess(data) data = _service.inference(data) data = _service.postprocess(data) return data ================================================ FILE: examples/densenet_pytorch/index_to_name.json ================================================ {"21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose", "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower", "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip", "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox", "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist", "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower", "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris", "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris", "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily", "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper", "42": "daffodil", "22": "pincushion flower", "2": 
"hard-leaved pocket orchid", "54": "sunflower", "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia", "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian", "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold", "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya", "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily", "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william", "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot", "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion", "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy", "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff", "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple", "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus", "78": "lotus lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani", "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower", "77": "passion flower", "51": "petunia"} ================================================ FILE: examples/gluon_alexnet/README.md ================================================ # Loading and serving Gluon models on Multi Model Server (MMS) Multi Model Server (MMS) supports loading and serving MXNet Imperative and Hybrid Gluon models. This is a short tutorial on how to write a custom Gluon model, and then serve it with MMS. This tutorial covers the following: 1. [Prerequisites](#prerequisites) 2. 
[Serve a Gluon model](#load-and-serve-a-gluon-model) * [Load and serve a pre-trained Gluon model](#load-and-serve-a-pre-trained-gluon-model) * [Load and serve a custom Gluon model](#load-and-serve-a-custom-gluon-imperative-model) * [Load and serve a custom hybrid Gluon model](#load-and-serve-a-hybrid-gluon-model) 3. [Conclusion](#conclusion) ## Prerequisites * **Basic Gluon knowledge**. If you are using Gluon for the first time, but are familiar with creating a neural network with MXNet or another framework, you may refer this 10 min Gluon crash-course: [Predict with a pre-trained model](http://gluon-crash-course.mxnet.io/predict.html). * **Gluon naming**. Fine-tuning pre-trained Gluon models requires some understanding of how the naming conventions work. Take a look at the [Naming of Gluon Parameter and Blocks](https://mxnet.incubator.apache.org/tutorials/gluon/naming.html) tutorial for more information. * **Basic MMS knowledge**. If you are using MMS for the first time, you should take advantage of the [MMS QuickStart tutorial](https://github.com/awslabs/multi-model-server#quick-start). * **MMS installed**. If you haven't already, [install MMS with pip](https://github.com/awslabs/multi-model-server/blob/master/docs/install.md#install-mms-with-pip) or [install MMS from source](https://github.com/awslabs/multi-model-server/blob/master/docs/install.md#install-mms-from-source-code). Either installation will also install MXNet. Refer to the [MXNet model zoo](https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html) documentation for examples of accessing other models. ## Load and serve a Gluon model There are three scenarios for serving a Gluon model with MMS: 1. Load and serve a pre-trained Gluon model. 2. Load and serve a custom imperative Gluon model. 3. Load and serve a custom hybrid Gluon model. 
To learn more about the differences between gluon and hybrid gluon models refer to [the following document](http://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html) ### Load and serve a pre-trained Gluon model Loading and serving a pre-trained Gluon model is the simplest of the three scenarios. These models don't require you to provide `symbols` and `params` files. It is easy to access a model with a couple of lines of code. The following code snippet shows how to load and serve a pretrained Gluon model. ```python class PretrainedAlexnetService(GluonBaseService): """ Pretrained alexnet Service """ def initialize(self, params): self.net = mxnet.gluon.model_zoo.vision.alexnet(pretrained=True) self.param_filename = "alexnet.params" super(PretrainedAlexnetService, self).initialize(params) def postprocess(self, data): idx = data.topk(k=5)[0] return [[{'class': (self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = PretrainedAlexnetService() def pretrained_gluon_alexnet(data, context): res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ``` For an actual code implementation, refer to the custom-service code which uses the [pre-trained Alexnet](https://github.com/awslabs/multi-model-server/blob/master/examples/gluon_alexnet/gluon_pretrained_alexnet.py) ### Serve pre-trained model with MMS To serve pre-trained models with MMS we would need to create an model archive file. Follow the below steps to get the example custom service loaded and served with MMS. 1. Create a `models` directory ```bash mkdir /tmp/models ``` 2. 
Copy the [example code](https://github.com/awslabs/multi-model-server/blob/master/examples/gluon_alexnet/gluon_pretrained_alexnet.py) and other required artifacts to this folder ```bash cp ../model_service_template/gluon_base_service.py ../model_service_template/mxnet_utils/ndarray.py gluon_pretrained_alexnet.py synset.txt signature.json /tmp/models/. ``` 3. Run the model-export tool on this folder. ```bash model-archiver --model-name alexnet --model-path /tmp/models --handler gluon_pretrained_alexnet:pretrained_gluon_alexnet --runtime python --export-path /tmp ``` This creates a model-archive file `/tmp/alexnet.mar`. 4. You could run the server with this model file to serve the pre-trained alexnet. ```bash multi-model-server --start --models alexnet.mar --model-store /tmp ``` 5. Test your service ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1:8080/alexnet/predict -F "data=@kitten.jpg" ``` ## Load and serve a custom Gluon imperative model To load an imperative model for use with MMS, you must activate the network in a MMS custom service. Once activated, MMS can load the pre-trained parameters and start serving the imperative model. You also need to handle pre-processing and post-processing of the image input. We created a custom imperative model using Gluon. 
Refer to the [custom service code](https://github.com/awslabs/multi-model-server/blob/master/examples/gluon_alexnet/gluon_imperative_alexnet.py)
```python class ImperativeAlexnetService(GluonBaseService): """ Gluon alexnet Service """ def initialize(self, params): self.net = GluonImperativeAlexNet() self.param_filename = "alexnet.params" super(ImperativeAlexnetService, self).initialize(params) def postprocess(self, data): idx = data.topk(k=5)[0] return [[{'class': (self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = ImperativeAlexnetService() def imperative_gluon_alexnet_inf(data, context): res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ``` ### Test your imperative Gluon model service To serve imperative Gluon models with MMS we would need to create an model archive file. Follow the below steps to get the example custom service loaded and served with MMS. 1. Create a `models` directory ```bash mkdir /tmp/models ``` 2. Copy the [example code](https://github.com/awslabs/multi-model-server/examples/gluon_alexnet/gluon_imperative_alexnet.py) and other required artifacts to this folder ```bash cp ../model_service_template/gluon_base_service.py ../model_service_template/mxnet_utils/ndarray.py gluon_imperative_alexnet.py synset.txt signature.json /tmp/models/. ``` 3. Download/copy the parameters to this `/tmp/models` directory. For this example, we have the parameters file in a S3 bucket. ```bash wget https://s3.amazonaws.com/gluon-mms-model-files/alexnet.params mv alexnet.params /tmp/models ``` 4. Run the model-export tool on this folder. ```bash model-archiver --model-name alexnet --model-path /tmp/models --handler gluon_imperative_alexnet:imperative_gluon_alexnet_inf --runtime python --export-path /tmp ``` This creates a model-archive file `/tmp/alexnet.mar`. 5. You could run the server with this model file to serve the pre-trained alexnet. ```bash multi-model-server --start --models alexnet.mar --model-store /tmp ``` 6. 
Test your service ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1:8080/alexnet/predict -F "data=@kitten.jpg" ``` The output should be close to the following: ```json {"prediction":[{"class":"lynx,","probability":0.9411474466323853},{"class":"leopard,","probability":0.016749195754528046},{"class":"tabby,","probability":0.012754007242619991},{"class":"Egyptian","probability":0.011728651821613312},{"class":"tiger","probability":0.008974711410701275}]} ``` ## Load and serve a hybrid Gluon model To serve hybrid Gluon models with MMS, let's consider `gluon_imperative_alexnet.py` in `multi-model-server/examples/gluon_alexnet` folder. We first convert the model to a `Gluon` hybrid block. For additional background on using `HybridBlocks` and the need to `hybridize` refer to this Gluon [hybridize](http://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html#) tutorial. The above network, after this conversion, would look as follows: ```python class GluonHybridAlexNet(HybridBlock): """ Hybrid Block gluon model """ def __init__(self, classes=1000, **kwargs): super(GluonHybridAlexNet, self).__init__(**kwargs) with self.name_scope(): self.features = nn.HybridSequential(prefix='') with self.features.name_scope(): self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(192, kernel_size=5, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(384, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Flatten()) self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) 
self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) x = self.output(x) return x ``` We could use the same custom service code as in the above section, ```python class HybridAlexnetService(GluonBaseService): """ Gluon alexnet Service """ def initialize(self, params): self.net = GluonHybridAlexNet() self.param_filename = "alexnet.params" super(HybridAlexnetService, self).initialize(params) self.net.hybridize() def postprocess(self, data): idx = data.topk(k=5)[0] return [[{'class': (self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = HybridAlexnetService() def hybrid_gluon_alexnet_inf(data, context): res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ``` Similar to imperative models, this model doesn't require `Symbols` as the call to `.hybridize()` compiles the neural net. This would store the `symbols` implicitly. ### Test your hybrid Gluon model service To serve Hybrid Gluon models with MMS we would need to create an model archive file. Follow the below steps to get the example custom service loaded and served with MMS. 1. Create a `models` directory ```bash mkdir /tmp/models ``` 2. Copy the [example code](https://github.com/awslabs/multi-model-server/examples/gluon_alexnet/gluon_imperative_alexnet.py) and other required artifacts to this folder ```bash cp ../model_service_template/gluon_base_service.py ../model_service_template/mxnet_utils/ndarray.py gluon_hybrid_alexnet.py synset.txt signature.json /tmp/models/. ``` 3. Download/copy the parameters to this `/tmp/models` directory. For this example, we have the parameters file in a S3 bucket. ```bash wget https://s3.amazonaws.com/gluon-mms-model-files/alexnet.params mv alexnet.params /tmp/models ``` 4. Run the model-export tool on this folder. 
```bash model-archiver --model-name alexnet --model-path /tmp/models --handler gluon_hybrid_alexnet:hybrid_gluon_alexnet_inf --runtime python --export-path /tmp ``` This creates a model-archive file `/tmp/alexnet.mar`. 5. You could run the server with this model file to serve the pre-trained alexnet. ```bash multi-model-server --start --models alexnet.mar --model-store /tmp ``` 6. Test your service ```bash curl -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg curl -X POST http://127.0.0.1:8080/alexnet/predict -F "data=@kitten.jpg" ``` The output should be close to the following: ```json {"prediction":[{"class":"lynx,","probability":0.9411474466323853},{"class":"leopard,","probability":0.016749195754528046},{"class":"tabby,","probability":0.012754007242619991},{"class":"Egyptian","probability":0.011728651821613312},{"class":"tiger","probability":0.008974711410701275}]} ``` ## Conclusion In this tutorial you learned how to serve Gluon models in three unique scenarios: a pre-trained imperative model directly from the model zoo, a custom imperative model, and a hybrid model. For further examples of customizing gluon models, try the Gluon tutorial for [Transferring knowledge through fine-tuning](http://gluon.mxnet.io/chapter08_computer-vision/fine-tuning.html). For an advanced custom service example, try the MMS [SSD example](https://github.com/awslabs/multi-model-server/tree/master/examples/ssd). ================================================ FILE: examples/gluon_alexnet/gluon_hybrid_alexnet.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. from mxnet.gluon import nn from mxnet.gluon.block import HybridBlock from gluon_base_service import GluonBaseService """ MMS examples for loading Gluon Hybrid models """ class GluonHybridAlexNet(HybridBlock): """ Hybrid Block gluon model """ def __init__(self, classes=1000, **kwargs): """ This is the network definition of Gluon Hybrid Alexnet :param classes: :param kwargs: """ super(GluonHybridAlexNet, self).__init__(**kwargs) with self.name_scope(): self.features = nn.HybridSequential(prefix='') with self.features.name_scope(): self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(192, kernel_size=5, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(384, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Flatten()) self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) self.output = nn.Dense(classes) def hybrid_forward(self, F, x): x = self.features(x) x = self.output(x) return x class HybridAlexnetService(GluonBaseService): """ Gluon alexnet Service """ def initialize(self, params): self.net = GluonHybridAlexNet() self.param_filename = "alexnet.params" super(HybridAlexnetService, self).initialize(params) self.net.hybridize() def postprocess(self, data): idx = data.topk(k=5)[0] return [[{'class': 
(self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = HybridAlexnetService() def hybrid_gluon_alexnet_inf(data, context): res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ================================================ FILE: examples/gluon_alexnet/gluon_imperative_alexnet.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
from mxnet import gluon from mxnet.gluon import nn from gluon_base_service import GluonBaseService """ MMS examples for loading Gluon Imperative models """ class GluonImperativeAlexNet(gluon.Block): """ Fully imperative gluon Alexnet model """ def __init__(self, classes=1000, **kwargs): """ This is the network definition of Imperative Alexnet :param classes: :param kwargs: """ super(GluonImperativeAlexNet, self).__init__(**kwargs) with self.name_scope(): self.features = nn.Sequential(prefix='') with self.features.name_scope(): self.features.add(nn.Conv2D(64, kernel_size=11, strides=4, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(192, kernel_size=5, padding=2, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Conv2D(384, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.Conv2D(256, kernel_size=3, padding=1, activation='relu')) self.features.add(nn.MaxPool2D(pool_size=3, strides=2)) self.features.add(nn.Flatten()) self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) self.features.add(nn.Dense(4096, activation='relu')) self.features.add(nn.Dropout(0.5)) self.output = nn.Dense(classes) def forward(self, x): x = self.features(x) x = self.output(x) return x class ImperativeAlexnetService(GluonBaseService): """ Gluon alexnet Service """ def initialize(self, params): self.net = GluonImperativeAlexNet() self.param_filename = "alexnet.params" super(ImperativeAlexnetService, self).initialize(params) def postprocess(self, data): idx = data.topk(k=5)[0] return [[{'class': (self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = ImperativeAlexnetService() def imperative_gluon_alexnet_inf(data, context): """ Handler registered for inference :param data: :param 
context: :return: """ res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ================================================ FILE: examples/gluon_alexnet/gluon_pretrained_alexnet.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. import mxnet from gluon_base_service import GluonBaseService """ Gluon Pretrained Alexnet model """ class PretrainedAlexnetService(GluonBaseService): """ Pretrained alexnet Service """ def initialize(self, params): """ Initialize the model :param params: This is the same as the Context object :return: """ self.net = mxnet.gluon.model_zoo.vision.alexnet(pretrained=True) super(PretrainedAlexnetService, self).initialize(params) def postprocess(self, data): """ Post process for the Gluon Alexnet model :param data: :return: """ idx = data.topk(k=5)[0] return [[{'class': (self.labels[int(i.asscalar())]).split()[1], 'probability': float(data[0, int(i.asscalar())].asscalar())} for i in idx]] svc = PretrainedAlexnetService() def pretrained_gluon_alexnet(data, context): """ This is the handler that needs to be registerd in the model-archive. 
:param data: :param context: :return: """ res = None if not svc.initialized: svc.initialize(context) if data is not None: res = svc.predict(data) return res ================================================ FILE: examples/gluon_alexnet/signature.json ================================================ { "inputs": [ { "data_name": "data", "data_shape": [0, 3, 224, 224] } ], "input_type": "image/jpeg", "outputs": [ { "data_name": "softmax", "data_shape": [0, 1000] } ], "output_type": "application/json" } ================================================ FILE: examples/gluon_alexnet/synset.txt ================================================ n01440764 tench, Tinca tinca n01443537 goldfish, Carassius auratus n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias n01491361 tiger shark, Galeocerdo cuvieri n01494475 hammerhead, hammerhead shark n01496331 electric ray, crampfish, numbfish, torpedo n01498041 stingray n01514668 cock n01514859 hen n01518878 ostrich, Struthio camelus n01530575 brambling, Fringilla montifringilla n01531178 goldfinch, Carduelis carduelis n01532829 house finch, linnet, Carpodacus mexicanus n01534433 junco, snowbird n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea n01558993 robin, American robin, Turdus migratorius n01560419 bulbul n01580077 jay n01582220 magpie n01592084 chickadee n01601694 water ouzel, dipper n01608432 kite n01614925 bald eagle, American eagle, Haliaeetus leucocephalus n01616318 vulture n01622779 great grey owl, great gray owl, Strix nebulosa n01629819 European fire salamander, Salamandra salamandra n01630670 common newt, Triturus vulgaris n01631663 eft n01632458 spotted salamander, Ambystoma maculatum n01632777 axolotl, mud puppy, Ambystoma mexicanum n01641577 bullfrog, Rana catesbeiana n01644373 tree frog, tree-frog n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui n01664065 loggerhead, loggerhead turtle, Caretta caretta n01665541 leatherback 
turtle, leatherback, leathery turtle, Dermochelys coriacea n01667114 mud turtle n01667778 terrapin n01669191 box turtle, box tortoise n01675722 banded gecko n01677366 common iguana, iguana, Iguana iguana n01682714 American chameleon, anole, Anolis carolinensis n01685808 whiptail, whiptail lizard n01687978 agama n01688243 frilled lizard, Chlamydosaurus kingi n01689811 alligator lizard n01692333 Gila monster, Heloderma suspectum n01693334 green lizard, Lacerta viridis n01694178 African chameleon, Chamaeleo chamaeleon n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis n01697457 African crocodile, Nile crocodile, Crocodylus niloticus n01698640 American alligator, Alligator mississipiensis n01704323 triceratops n01728572 thunder snake, worm snake, Carphophis amoenus n01728920 ringneck snake, ring-necked snake, ring snake n01729322 hognose snake, puff adder, sand viper n01729977 green snake, grass snake n01734418 king snake, kingsnake n01735189 garter snake, grass snake n01737021 water snake n01739381 vine snake n01740131 night snake, Hypsiglena torquata n01742172 boa constrictor, Constrictor constrictor n01744401 rock python, rock snake, Python sebae n01748264 Indian cobra, Naja naja n01749939 green mamba n01751748 sea snake n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus n01756291 sidewinder, horned rattlesnake, Crotalus cerastes n01768244 trilobite n01770081 harvestman, daddy longlegs, Phalangium opilio n01770393 scorpion n01773157 black and gold garden spider, Argiope aurantia n01773549 barn spider, Araneus cavaticus n01773797 garden spider, Aranea diademata n01774384 black widow, Latrodectus mactans n01774750 tarantula n01775062 wolf spider, hunting spider n01776313 tick n01784675 centipede n01795545 black grouse n01796340 ptarmigan n01797886 ruffed grouse, partridge, Bonasa umbellus n01798484 prairie chicken, prairie grouse, prairie 
fowl n01806143 peacock n01806567 quail n01807496 partridge n01817953 African grey, African gray, Psittacus erithacus n01818515 macaw n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita n01820546 lorikeet n01824575 coucal n01828970 bee eater n01829413 hornbill n01833805 hummingbird n01843065 jacamar n01843383 toucan n01847000 drake n01855032 red-breasted merganser, Mergus serrator n01855672 goose n01860187 black swan, Cygnus atratus n01871265 tusker n01872401 echidna, spiny anteater, anteater n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus n01877812 wallaby, brush kangaroo n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus n01883070 wombat n01910747 jellyfish n01914609 sea anemone, anemone n01917289 brain coral n01924916 flatworm, platyhelminth n01930112 nematode, nematode worm, roundworm n01943899 conch n01944390 snail n01945685 slug n01950731 sea slug, nudibranch n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore n01968897 chambered nautilus, pearly nautilus, nautilus n01978287 Dungeness crab, Cancer magister n01978455 rock crab, Cancer irroratus n01980166 fiddler crab n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish n01985128 crayfish, crawfish, crawdad, crawdaddy n01986214 hermit crab n01990800 isopod n02002556 white stork, Ciconia ciconia n02002724 black stork, Ciconia nigra n02006656 spoonbill n02007558 flamingo n02009229 little blue heron, Egretta caerulea n02009912 American egret, great white heron, Egretta albus n02011460 bittern n02012849 crane n02013706 limpkin, Aramus pictus n02017213 European gallinule, Porphyrio porphyrio n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana n02018795 bustard n02025239 ruddy 
turnstone, Arenaria interpres n02027492 red-backed sandpiper, dunlin, Erolia alpina n02028035 redshank, Tringa totanus n02033041 dowitcher n02037110 oystercatcher, oyster catcher n02051845 pelican n02056570 king penguin, Aptenodytes patagonica n02058221 albatross, mollymawk n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca n02074367 dugong, Dugong dugon n02077923 sea lion n02085620 Chihuahua n02085782 Japanese spaniel n02085936 Maltese dog, Maltese terrier, Maltese n02086079 Pekinese, Pekingese, Peke n02086240 Shih-Tzu n02086646 Blenheim spaniel n02086910 papillon n02087046 toy terrier n02087394 Rhodesian ridgeback n02088094 Afghan hound, Afghan n02088238 basset, basset hound n02088364 beagle n02088466 bloodhound, sleuthhound n02088632 bluetick n02089078 black-and-tan coonhound n02089867 Walker hound, Walker foxhound n02089973 English foxhound n02090379 redbone n02090622 borzoi, Russian wolfhound n02090721 Irish wolfhound n02091032 Italian greyhound n02091134 whippet n02091244 Ibizan hound, Ibizan Podenco n02091467 Norwegian elkhound, elkhound n02091635 otterhound, otter hound n02091831 Saluki, gazelle hound n02092002 Scottish deerhound, deerhound n02092339 Weimaraner n02093256 Staffordshire bullterrier, Staffordshire bull terrier n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier n02093647 Bedlington terrier n02093754 Border terrier n02093859 Kerry blue terrier n02093991 Irish terrier n02094114 Norfolk terrier n02094258 Norwich terrier n02094433 Yorkshire terrier n02095314 wire-haired fox terrier n02095570 Lakeland terrier n02095889 Sealyham terrier, Sealyham n02096051 Airedale, Airedale terrier n02096177 cairn, cairn terrier n02096294 Australian terrier n02096437 Dandie Dinmont, Dandie Dinmont terrier n02096585 Boston bull, Boston terrier n02097047 miniature schnauzer n02097130 giant schnauzer 
n02097209 standard schnauzer n02097298 Scotch terrier, Scottish terrier, Scottie n02097474 Tibetan terrier, chrysanthemum dog n02097658 silky terrier, Sydney silky n02098105 soft-coated wheaten terrier n02098286 West Highland white terrier n02098413 Lhasa, Lhasa apso n02099267 flat-coated retriever n02099429 curly-coated retriever n02099601 golden retriever n02099712 Labrador retriever n02099849 Chesapeake Bay retriever n02100236 German short-haired pointer n02100583 vizsla, Hungarian pointer n02100735 English setter n02100877 Irish setter, red setter n02101006 Gordon setter n02101388 Brittany spaniel n02101556 clumber, clumber spaniel n02102040 English springer, English springer spaniel n02102177 Welsh springer spaniel n02102318 cocker spaniel, English cocker spaniel, cocker n02102480 Sussex spaniel n02102973 Irish water spaniel n02104029 kuvasz n02104365 schipperke n02105056 groenendael n02105162 malinois n02105251 briard n02105412 kelpie n02105505 komondor n02105641 Old English sheepdog, bobtail n02105855 Shetland sheepdog, Shetland sheep dog, Shetland n02106030 collie n02106166 Border collie n02106382 Bouvier des Flandres, Bouviers des Flandres n02106550 Rottweiler n02106662 German shepherd, German shepherd dog, German police dog, alsatian n02107142 Doberman, Doberman pinscher n02107312 miniature pinscher n02107574 Greater Swiss Mountain dog n02107683 Bernese mountain dog n02107908 Appenzeller n02108000 EntleBucher n02108089 boxer n02108422 bull mastiff n02108551 Tibetan mastiff n02108915 French bulldog n02109047 Great Dane n02109525 Saint Bernard, St Bernard n02109961 Eskimo dog, husky n02110063 malamute, malemute, Alaskan malamute n02110185 Siberian husky n02110341 dalmatian, coach dog, carriage dog n02110627 affenpinscher, monkey pinscher, monkey dog n02110806 basenji n02110958 pug, pug-dog n02111129 Leonberg n02111277 Newfoundland, Newfoundland dog n02111500 Great Pyrenees n02111889 Samoyed, Samoyede n02112018 Pomeranian n02112137 chow, chow chow n02112350 
keeshond n02112706 Brabancon griffon n02113023 Pembroke, Pembroke Welsh corgi n02113186 Cardigan, Cardigan Welsh corgi n02113624 toy poodle n02113712 miniature poodle n02113799 standard poodle n02113978 Mexican hairless n02114367 timber wolf, grey wolf, gray wolf, Canis lupus n02114548 white wolf, Arctic wolf, Canis lupus tundrarum n02114712 red wolf, maned wolf, Canis rufus, Canis niger n02114855 coyote, prairie wolf, brush wolf, Canis latrans n02115641 dingo, warrigal, warragal, Canis dingo n02115913 dhole, Cuon alpinus n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus n02117135 hyena, hyaena n02119022 red fox, Vulpes vulpes n02119789 kit fox, Vulpes macrotis n02120079 Arctic fox, white fox, Alopex lagopus n02120505 grey fox, gray fox, Urocyon cinereoargenteus n02123045 tabby, tabby cat n02123159 tiger cat n02123394 Persian cat n02123597 Siamese cat, Siamese n02124075 Egyptian cat n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor n02127052 lynx, catamount n02128385 leopard, Panthera pardus n02128757 snow leopard, ounce, Panthera uncia n02128925 jaguar, panther, Panthera onca, Felis onca n02129165 lion, king of beasts, Panthera leo n02129604 tiger, Panthera tigris n02130308 cheetah, chetah, Acinonyx jubatus n02132136 brown bear, bruin, Ursus arctos n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus n02134418 sloth bear, Melursus ursinus, Ursus ursinus n02137549 mongoose n02138441 meerkat, mierkat n02165105 tiger beetle n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle n02167151 ground beetle, carabid beetle n02168699 long-horned beetle, longicorn, longicorn beetle n02169497 leaf beetle, chrysomelid n02172182 dung beetle n02174001 rhinoceros beetle n02177972 weevil n02190166 fly n02206856 bee n02219486 ant, emmet, pismire n02226429 grasshopper, hopper n02229544 cricket n02231487 walking 
stick, walkingstick, stick insect n02233338 cockroach, roach n02236044 mantis, mantid n02256656 cicada, cicala n02259212 leafhopper n02264363 lacewing, lacewing fly n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk n02268853 damselfly n02276258 admiral n02277742 ringlet, ringlet butterfly n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus n02280649 cabbage butterfly n02281406 sulphur butterfly, sulfur butterfly n02281787 lycaenid, lycaenid butterfly n02317335 starfish, sea star n02319095 sea urchin n02321529 sea cucumber, holothurian n02325366 wood rabbit, cottontail, cottontail rabbit n02326432 hare n02328150 Angora, Angora rabbit n02342885 hamster n02346627 porcupine, hedgehog n02356798 fox squirrel, eastern fox squirrel, Sciurus niger n02361337 marmot n02363005 beaver n02364673 guinea pig, Cavia cobaya n02389026 sorrel n02391049 zebra n02395406 hog, pig, grunter, squealer, Sus scrofa n02396427 wild boar, boar, Sus scrofa n02397096 warthog n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius n02403003 ox n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis n02410509 bison n02412080 ram, tup n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis n02417914 ibex, Capra ibex n02422106 hartebeest n02422699 impala, Aepyceros melampus n02423022 gazelle n02437312 Arabian camel, dromedary, Camelus dromedarius n02437616 llama n02441942 weasel n02442845 mink n02443114 polecat, fitch, foulmart, foumart, Mustela putorius n02443484 black-footed ferret, ferret, Mustela nigripes n02444819 otter n02445715 skunk, polecat, wood pussy n02447366 badger n02454379 armadillo n02457408 three-toed sloth, ai, Bradypus tridactylus n02480495 orangutan, orang, orangutang, Pongo pygmaeus n02480855 gorilla, Gorilla gorilla n02481823 chimpanzee, chimp, Pan troglodytes n02483362 gibbon, Hylobates lar n02483708 
siamang, Hylobates syndactylus, Symphalangus syndactylus n02484975 guenon, guenon monkey n02486261 patas, hussar monkey, Erythrocebus patas n02486410 baboon n02487347 macaque n02488291 langur n02488702 colobus, colobus monkey n02489166 proboscis monkey, Nasalis larvatus n02490219 marmoset n02492035 capuchin, ringtail, Cebus capucinus n02492660 howler monkey, howler n02493509 titi, titi monkey n02493793 spider monkey, Ateles geoffroyi n02494079 squirrel monkey, Saimiri sciureus n02497673 Madagascar cat, ring-tailed lemur, Lemur catta n02500267 indri, indris, Indri indri, Indri brevicaudatus n02504013 Indian elephant, Elephas maximus n02504458 African elephant, Loxodonta africana n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca n02514041 barracouta, snoek n02526121 eel n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch n02606052 rock beauty, Holocanthus tricolor n02607072 anemone fish n02640242 sturgeon n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus n02643566 lionfish n02655020 puffer, pufferfish, blowfish, globefish n02666196 abacus n02667093 abaya n02669723 academic gown, academic robe, judge's robe n02672831 accordion, piano accordion, squeeze box n02676566 acoustic guitar n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier n02690373 airliner n02692877 airship, dirigible n02699494 altar n02701002 ambulance n02704792 amphibian, amphibious vehicle n02708093 analog clock n02727426 apiary, bee house n02730930 apron n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin n02749479 assault rifle, assault gun n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack n02776631 bakery, bakeshop, bakehouse n02777292 balance beam, beam n02782093 balloon n02783161 ballpoint, ballpoint pen, ballpen, Biro n02786058 Band Aid n02787622 banjo 
n02788148 bannister, banister, balustrade, balusters, handrail n02790996 barbell n02791124 barber chair n02791270 barbershop n02793495 barn n02794156 barometer n02795169 barrel, cask n02797295 barrow, garden cart, lawn cart, wheelbarrow n02799071 baseball n02802426 basketball n02804414 bassinet n02804610 bassoon n02807133 bathing cap, swimming cap n02808304 bath towel n02808440 bathtub, bathing tub, bath, tub n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon n02814860 beacon, lighthouse, beacon light, pharos n02815834 beaker n02817516 bearskin, busby, shako n02823428 beer bottle n02823750 beer glass n02825657 bell cote, bell cot n02834397 bib n02835271 bicycle-built-for-two, tandem bicycle, tandem n02837789 bikini, two-piece n02840245 binder, ring-binder n02841315 binoculars, field glasses, opera glasses n02843684 birdhouse n02859443 boathouse n02860847 bobsled, bobsleigh, bob n02865351 bolo tie, bolo, bola tie, bola n02869837 bonnet, poke bonnet n02870880 bookcase n02871525 bookshop, bookstore, bookstall n02877765 bottlecap n02879718 bow n02883205 bow tie, bow-tie, bowtie n02892201 brass, memorial tablet, plaque n02892767 brassiere, bra, bandeau n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty n02895154 breastplate, aegis, egis n02906734 broom n02909870 bucket, pail n02910353 buckle n02916936 bulletproof vest n02917067 bullet train, bullet n02927161 butcher shop, meat market n02930766 cab, hack, taxi, taxicab n02939185 caldron, cauldron n02948072 candle, taper, wax light n02950826 cannon n02951358 canoe n02951585 can opener, tin opener n02963159 cardigan n02965783 car mirror n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig n02966687 carpenter's kit, tool kit n02971356 carton n02974003 car wheel n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM n02978881 cassette n02979186 cassette player n02980441 
castle n02981792 catamaran n02988304 CD player n02992211 cello, violoncello n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone n02999410 chain n03000134 chainlink fence n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour n03000684 chain saw, chainsaw n03014705 chest n03016953 chiffonier, commode n03017168 chime, bell, gong n03018349 china cabinet, china closet n03026506 Christmas stocking n03028079 church, church building n03032252 cinema, movie theater, movie theatre, movie house, picture palace n03041632 cleaver, meat cleaver, chopper n03042490 cliff dwelling n03045698 cloak n03047690 clog, geta, patten, sabot n03062245 cocktail shaker n03063599 coffee mug n03063689 coffeepot n03065424 coil, spiral, volute, whorl, helix n03075370 combination lock n03085013 computer keyboard, keypad n03089624 confectionery, confectionary, candy store n03095699 container ship, containership, container vessel n03100240 convertible n03109150 corkscrew, bottle screw n03110669 cornet, horn, trumpet, trump n03124043 cowboy boot n03124170 cowboy hat, ten-gallon hat n03125729 cradle n03126707 crane n03127747 crash helmet n03127925 crate n03131574 crib, cot n03133878 Crock Pot n03134739 croquet ball n03141823 crutch n03146219 cuirass n03160309 dam, dike, dyke n03179701 desk n03180011 desktop computer n03187595 dial telephone, dial phone n03188531 diaper, nappy, napkin n03196217 digital clock n03197337 digital watch n03201208 dining table, board n03207743 dishrag, dishcloth n03207941 dishwasher, dish washer, dishwashing machine n03208938 disk brake, disc brake n03216828 dock, dockage, docking facility n03218198 dogsled, dog sled, dog sleigh n03220513 dome n03223299 doormat, welcome mat n03240683 drilling platform, offshore rig n03249569 drum, membranophone, tympan n03250847 drumstick n03255030 dumbbell n03259280 Dutch oven n03271574 electric fan, blower n03272010 electric guitar n03272562 electric locomotive n03290653 
entertainment center n03291819 envelope n03297495 espresso maker n03314780 face powder n03325584 feather boa, boa n03337140 file, file cabinet, filing cabinet n03344393 fireboat n03345487 fire engine, fire truck n03347037 fire screen, fireguard n03355925 flagpole, flagstaff n03372029 flute, transverse flute n03376595 folding chair n03379051 football helmet n03384352 forklift n03388043 fountain n03388183 fountain pen n03388549 four-poster n03393912 freight car n03394916 French horn, horn n03400231 frying pan, frypan, skillet n03404251 fur coat n03417042 garbage truck, dustcart n03424325 gasmask, respirator, gas helmet n03425413 gas pump, gasoline pump, petrol pump, island dispenser n03443371 goblet n03444034 go-kart n03445777 golf ball n03445924 golfcart, golf cart n03447447 gondola n03447721 gong, tam-tam n03450230 gown n03452741 grand piano, grand n03457902 greenhouse, nursery, glasshouse n03459775 grille, radiator grille n03461385 grocery store, grocery, food market, market n03467068 guillotine n03476684 hair slide n03476991 hair spray n03478589 half track n03481172 hammer n03482405 hamper n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier n03485407 hand-held computer, hand-held microcomputer n03485794 handkerchief, hankie, hanky, hankey n03492542 hard disc, hard disk, fixed disk n03494278 harmonica, mouth organ, harp, mouth harp n03495258 harp n03496892 harvester, reaper n03498962 hatchet n03527444 holster n03529860 home theater, home theatre n03530642 honeycomb n03532672 hook, claw n03534580 hoopskirt, crinoline n03535780 horizontal bar, high bar n03538406 horse cart, horse-cart n03544143 hourglass n03584254 iPod n03584829 iron, smoothing iron n03590841 jack-o'-lantern n03594734 jean, blue jean, denim n03594945 jeep, landrover n03595614 jersey, T-shirt, tee shirt n03598930 jigsaw puzzle n03599486 jinrikisha, ricksha, rickshaw n03602883 joystick n03617480 kimono n03623198 knee pad n03627232 knot n03630383 lab coat, laboratory coat n03633091 
ladle n03637318 lampshade, lamp shade n03642806 laptop, laptop computer n03649909 lawn mower, mower n03657121 lens cap, lens cover n03658185 letter opener, paper knife, paperknife n03661043 library n03662601 lifeboat n03666591 lighter, light, igniter, ignitor n03670208 limousine, limo n03673027 liner, ocean liner n03676483 lipstick, lip rouge n03680355 Loafer n03690938 lotion n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system n03692522 loupe, jeweler's loupe n03697007 lumbermill, sawmill n03706229 magnetic compass n03709823 mailbag, postbag n03710193 mailbox, letter box n03710637 maillot n03710721 maillot, tank suit n03717622 manhole cover n03720891 maraca n03721384 marimba, xylophone n03724870 mask n03729826 matchstick n03733131 maypole n03733281 maze, labyrinth n03733805 measuring cup n03742115 medicine chest, medicine cabinet n03743016 megalith, megalithic structure n03759954 microphone, mike n03761084 microwave, microwave oven n03763968 military uniform n03764736 milk can n03769881 minibus n03770439 miniskirt, mini n03770679 minivan n03773504 missile n03775071 mitten n03775546 mixing bowl n03776460 mobile home, manufactured home n03777568 Model T n03777754 modem n03781244 monastery n03782006 monitor n03785016 moped n03786901 mortar n03787032 mortarboard n03788195 mosque n03788365 mosquito net n03791053 motor scooter, scooter n03792782 mountain bike, all-terrain bike, off-roader n03792972 mountain tent n03793489 mouse, computer mouse n03794056 mousetrap n03796401 moving van n03803284 muzzle n03804744 nail n03814639 neck brace n03814906 necklace n03825788 nipple n03832673 notebook, notebook computer n03837869 obelisk n03838899 oboe, hautboy, hautbois n03840681 ocarina, sweet potato n03841143 odometer, hodometer, mileometer, milometer n03843555 oil filter n03854065 organ, pipe organ n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO n03866082 overskirt n03868242 oxcart n03868863 oxygen mask n03871628 packet n03873416 paddle, 
boat paddle n03874293 paddlewheel, paddle wheel n03874599 padlock n03876231 paintbrush n03877472 pajama, pyjama, pj's, jammies n03877845 palace n03884397 panpipe, pandean pipe, syrinx n03887697 paper towel n03888257 parachute, chute n03888605 parallel bars, bars n03891251 park bench n03891332 parking meter n03895866 passenger car, coach, carriage n03899768 patio, terrace n03902125 pay-phone, pay-station n03903868 pedestal, plinth, footstall n03908618 pencil box, pencil case n03908714 pencil sharpener n03916031 perfume, essence n03920288 Petri dish n03924679 photocopier n03929660 pick, plectrum, plectron n03929855 pickelhaube n03930313 picket fence, paling n03930630 pickup, pickup truck n03933933 pier n03935335 piggy bank, penny bank n03937543 pill bottle n03938244 pillow n03942813 ping-pong ball n03944341 pinwheel n03947888 pirate, pirate ship n03950228 pitcher, ewer n03954731 plane, carpenter's plane, woodworking plane n03956157 planetarium n03958227 plastic bag n03961711 plate rack n03967562 plow, plough n03970156 plunger, plumber's helper n03976467 Polaroid camera, Polaroid Land camera n03976657 pole n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria n03980874 poncho n03982430 pool table, billiard table, snooker table n03983396 pop bottle, soda bottle n03991062 pot, flowerpot n03992509 potter's wheel n03995372 power drill n03998194 prayer rug, prayer mat n04004767 printer n04005630 prison, prison house n04008634 projectile, missile n04009552 projector n04019541 puck, hockey puck n04023962 punching bag, punch bag, punching ball, punchball n04026417 purse n04033901 quill, quill pen n04033995 quilt, comforter, comfort, puff n04037443 racer, race car, racing car n04039381 racket, racquet n04040759 radiator n04041544 radio, wireless n04044716 radio telescope, radio reflector n04049303 rain barrel n04065272 recreational vehicle, RV, R.V. 
n04067472 reel n04069434 reflex camera n04070727 refrigerator, icebox n04074963 remote control, remote n04081281 restaurant, eating house, eating place, eatery n04086273 revolver, six-gun, six-shooter n04090263 rifle n04099969 rocking chair, rocker n04111531 rotisserie n04116512 rubber eraser, rubber, pencil eraser n04118538 rugby ball n04118776 rule, ruler n04120489 running shoe n04125021 safe n04127249 safety pin n04131690 saltshaker, salt shaker n04133789 sandal n04136333 sarong n04141076 sax, saxophone n04141327 scabbard n04141975 scale, weighing machine n04146614 school bus n04147183 schooner n04149813 scoreboard n04152593 screen, CRT screen n04153751 screw n04154565 screwdriver n04162706 seat belt, seatbelt n04179913 sewing machine n04192698 shield, buckler n04200800 shoe shop, shoe-shop, shoe store n04201297 shoji n04204238 shopping basket n04204347 shopping cart n04208210 shovel n04209133 shower cap n04209239 shower curtain n04228054 ski n04229816 ski mask n04235860 sleeping bag n04238763 slide rule, slipstick n04239074 sliding door n04243546 slot, one-armed bandit n04251144 snorkel n04252077 snowmobile n04252225 snowplow, snowplough n04254120 soap dispenser n04254680 soccer ball n04254777 sock n04258138 solar dish, solar collector, solar furnace n04259630 sombrero n04263257 soup bowl n04264628 space bar n04265275 space heater n04266014 space shuttle n04270147 spatula n04273569 speedboat n04275548 spider web, spider's web n04277352 spindle n04285008 sports car, sport car n04286575 spotlight, spot n04296562 stage n04310018 steam locomotive n04311004 steel arch bridge n04311174 steel drum n04317175 stethoscope n04325704 stole n04326547 stone wall n04328186 stopwatch, stop watch n04330267 stove n04332243 strainer n04335435 streetcar, tram, tramcar, trolley, trolley car n04336792 stretcher n04344873 studio couch, day bed n04346328 stupa, tope n04347754 submarine, pigboat, sub, U-boat n04350905 suit, suit of clothes n04355338 sundial n04355933 sunglass n04356056 
sunglasses, dark glasses, shades n04357314 sunscreen, sunblock, sun blocker n04366367 suspension bridge n04367480 swab, swob, mop n04370456 sweatshirt n04371430 swimming trunks, bathing trunks n04371774 swing n04372370 switch, electric switch, electrical switch n04376876 syringe n04380533 table lamp n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle n04392985 tape player n04398044 teapot n04399382 teddy, teddy bear n04404412 television, television system n04409515 tennis ball n04417672 thatch, thatched roof n04418357 theater curtain, theatre curtain n04423845 thimble n04428191 thresher, thrasher, threshing machine n04429376 throne n04435653 tile roof n04442312 toaster n04443257 tobacco shop, tobacconist shop, tobacconist n04447861 toilet seat n04456115 torch n04458633 totem pole n04461696 tow truck, tow car, wrecker n04462240 toyshop n04465501 tractor n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi n04476259 tray n04479046 trench coat n04482393 tricycle, trike, velocipede n04483307 trimaran n04485082 tripod n04486054 triumphal arch n04487081 trolleybus, trolley coach, trackless trolley n04487394 trombone n04493381 tub, vat n04501370 turnstile n04505470 typewriter keyboard n04507155 umbrella n04509417 unicycle, monocycle n04515003 upright, upright piano n04517823 vacuum, vacuum cleaner n04522168 vase n04523525 vault n04525038 velvet n04525305 vending machine n04532106 vestment n04532670 viaduct n04536866 violin, fiddle n04540053 volleyball n04542943 waffle iron n04548280 wall clock n04548362 wallet, billfold, notecase, pocketbook n04550184 wardrobe, closet, press n04552348 warplane, military plane n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin n04554684 washer, automatic washer, washing machine n04557648 water bottle n04560804 water jug n04562935 water tower n04579145 whiskey jug n04579432 whistle n04584207 wig n04589890 window screen n04590129 window shade n04591157 Windsor tie 
n04591713 wine bottle n04592741 wing n04596742 wok n04597913 wooden spoon n04599235 wool, woolen, woollen n04604644 worm fence, snake fence, snake-rail fence, Virginia fence n04606251 wreck n04612504 yawl n04613696 yurt n06359193 web site, website, internet site, site n06596364 comic book n06785654 crossword puzzle, crossword n06794110 street sign n06874185 traffic light, traffic signal, stoplight n07248320 book jacket, dust cover, dust jacket, dust wrapper n07565083 menu n07579787 plate n07583066 guacamole n07584110 consomme n07590611 hot pot, hotpot n07613480 trifle n07614500 ice cream, icecream n07615774 ice lolly, lolly, lollipop, popsicle n07684084 French loaf n07693725 bagel, beigel n07695742 pretzel n07697313 cheeseburger n07697537 hotdog, hot dog, red hot n07711569 mashed potato n07714571 head cabbage n07714990 broccoli n07715103 cauliflower n07716358 zucchini, courgette n07716906 spaghetti squash n07717410 acorn squash n07717556 butternut squash n07718472 cucumber, cuke n07718747 artichoke, globe artichoke n07720875 bell pepper n07730033 cardoon n07734744 mushroom n07742313 Granny Smith n07745940 strawberry n07747607 orange n07749582 lemon n07753113 fig n07753275 pineapple, ananas n07753592 banana n07754684 jackfruit, jak, jack n07760859 custard apple n07768694 pomegranate n07802026 hay n07831146 carbonara n07836838 chocolate sauce, chocolate syrup n07860988 dough n07871810 meat loaf, meatloaf n07873807 pizza, pizza pie n07875152 potpie n07880968 burrito n07892512 red wine n07920052 espresso n07930864 cup n07932039 eggnog n09193705 alp n09229709 bubble n09246464 cliff, drop, drop-off n09256479 coral reef n09288635 geyser n09332890 lakeside, lakeshore n09399592 promontory, headland, head, foreland n09421951 sandbar, sand bar n09428293 seashore, coast, seacoast, sea-coast n09468604 valley, vale n09472597 volcano n09835506 ballplayer, baseball player n10148035 groom, bridegroom n10565667 scuba diver n11879895 rapeseed n11939491 daisy n12057211 yellow lady's 
slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum n12144580 corn n12267677 acorn n12620546 hip, rose hip, rosehip n12768682 buckeye, horse chestnut, conker n12985857 coral fungus n12998815 agaric n13037406 gyromitra n13040303 stinkhorn, carrion fungus n13044778 earthstar n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa n13054560 bolete n13133613 ear, spike, capitulum n15075141 toilet tissue, toilet paper, bathroom tissue ================================================ FILE: examples/gluon_character_cnn/README.md ================================================ # Character-level CNN Model in Gluon trained using Amazon Product Dataset In this example, we show how to create a service which classifies a review into product type using [Character-level Convolutional Network Model (CNN) model](https://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf) model by Yann LeCunn. This model is trained on [Amazon product data](http://jmcauley.ucsd.edu/data/amazon/) and training detail can be found in a detailed tutorial from Thomas Delteil on [Character CNN training.](https://github.com/ThomasDelteil/CNN_NLP_MXNet). 
# Step by step to create service ## Step 1 - Download the Gluon Char CNN model file, model parameter and classification labels file to "/tmp/crepe" ```bash # Create a model directory mkdir /tmp/crepe # Download the model file curl -O https://raw.githubusercontent.com/awslabs/multi-model-server/master/examples/gluon_character_cnn/gluon_crepe.py # Download the parameters curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/mms-char-cnn-files/crepe_gluon_epoch6.params # Download classification labels file curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/mms-char-cnn-files/synset.txt # Move required files to the following folder mv crepe_gluon_epoch6.params gluon_crepe.py synset.txt /tmp/crepe ``` ## Step 2 - Look at the Gluon model/service file For Gluon models on MMS, the models are defined, within the MMS service file, the skeletal structure of the file looks like follows. ```python class GluonCrepe(HybridBlock): """ Hybrid Block gluon Crepe model """ def __init__(self, classes=7, **kwargs): ## Define model below pass class CharacterCNNService(object): """ Gluon Character-level Convolution Service """ def __init__(self): # The 69 characters as specified in the paper self.ALPHABET = list("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}") # Map Alphabets to index self.ALPHABET_INDEX = {letter: index for index, letter in enumerate(self.ALPHABET)} # max-length in characters for one document self.FEATURE_LEN = 1014 self.initialized = False def initialize(self, params): self.net = GluonCrepe() self.param_filename = "crepe_gluon_epoch6.params" self.model_name = params.manifest["model"]["modelName"] gpu_id = params.system_properties.get("gpu_id") model_dir = params.system_properties.get("model_dir") synset_file = os.path.join(model_dir, "synset.txt") param_file_path = os.path.join(model_dir, self.param_filename) if not os.path.isfile(param_file_path): raise OSError("Parameter file not found 
{}".format(param_file_path)) if not os.path.isfile(synset_file): raise OSError("synset file not available {}".format(synset_file)) self.ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) self.net.load_parameters(param_file_path, self.ctx) self.labels = [line.strip() for line in open(synset_file).readlines()] self.initialized = True self.net.hybridize(static_shape=True, static_alloc=True) # define preprocess, inference and postprocess methods ``` As shown, the Gluon model derives from the basic gluon hybrid block. Gluon hybrid blocks provide the performance of a symbolic model with the flexibility of an imperative model. More on Gluon hybrid blocks [here](https://gluon.mxnet.io/chapter07_distributed-learning/hybridize.html). The fully defined service file can be found under [gluon_crepe.py](gluon_crepe.py); we define `preprocess`, `inference`, `postprocess` methods in this file. The input size is limited to 1014 characters, as mentioned in the paper. The output is of shape [0,7] as we classify the reviews into seven product categories. Both the input and output are passed on as 'application/json' based text content. ## Step 3 - Prepare synset.txt with list of class names [synset.txt](synset.txt) is where we define the list of all classes detected by the model. The pre-trained Character-level CNN model used in the example is trained to detect 7 classes including Books, CDs_and_Vinyl, Movies_and_TV and more. See the synset.txt file for the list of all classes. The list of classes in synset.txt will be loaded by MMS as the list of labels in inference logic. ## Step 4 - Export model files with the model-archiver CLI tool With the model file together with the signature and files in the model folder, we are ready to export them to an MMS model file.
```bash model-archiver --model-name crepe -f --model-path /tmp/crepe/ --handler gluon_crepe:crepe_inference --runtime python --export-path /tmp ``` A packaged model can be downloaded from [here.](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/mms-char-cnn-files/crepe.mar) ## Step 5 - Establish inference service `crepe.mar` file is created by exporting model files. We also defined custom service under gluon_crepe.py. We are ready to establish the Character-level CNN inference service: ```bash multi-model-server --models crepe.mar --model-store /tmp ``` The endpoint is on localhost and port 8080. You can change them by passing --host and --port when establishing the service. ## Test inference service Now we can send post requests to the endpoint we just established. The key values of application/json input are 'review_title', 'review'. This can be a different value or combined to a single input , to achieve this preprocess method in gluon_crepe.py needs to be modified. Let's take up a movie, review ```bash curl -X POST http://127.0.0.1:8080/predictions/crepe -F "data=[{\"review_title\":\"Inception is the best\",\"review\": \"great direction and story\"}]" ``` Prediction result will be: ```json { "confidence": { "Clothing_Shoes_and_Jewelry": 0.004, "Home_and_Kitchen": 0.001, "Sports_and_Outdoors": 0.001, "CDs_and_Vinyl": 0.038, "Movies_and_TV": 0.59, "Cell_Phones_and_Accessories": 0.0, "Books": 0.362 }, "predicted": "Movies_and_TV" } ``` Let's try another review, this time for a music album. ```bash curl -X POST http://127.0.0.1:8080/predictions/crepe -F "data=[{\"review_title\":\"fantastic quality\",\"review\": \"quality sound playback\"}]" ``` Prediction result will be: ```json { "confidence": { "Clothing_Shoes_and_Jewelry": 0.028, "Home_and_Kitchen": 0.012, "Sports_and_Outdoors": 0.028, "CDs_and_Vinyl": 0.727, "Movies_and_TV": 0.118, "Cell_Phones_and_Accessories": 0.068, "Books": 0.015 }, "predicted": "CDs_and_Vinyl" } ``` References 1. 
[Character-level CNN](https://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf) 2. [How to train Character-level CNN on gluon](https://github.com/ThomasDelteil/TextClassificationCNNs_MXNet) 3. [Web Demo of Character-level CNN on gluon](https://thomasdelteil.github.io/TextClassificationCNNs_MXNet/) ================================================ FILE: examples/gluon_character_cnn/gluon_crepe.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
import ast import os import mxnet as mx from mxnet import gluon, nd from mxnet.gluon import nn from mxnet.gluon.block import HybridBlock import numpy as np class GluonCrepe(HybridBlock): """ Hybrid Block gluon Crepe model """ def __init__(self, classes=7, **kwargs): super(GluonCrepe, self).__init__(**kwargs) self.NUM_FILTERS = 256 # number of convolutional filters per convolutional layer self.NUM_OUTPUTS = classes # number of classes self.FULLY_CONNECTED = 1024 # number of unit in the fully connected dense layer self.features = nn.HybridSequential() with self.name_scope(): self.features.add( nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=7, activation='relu'), nn.MaxPool1D(pool_size=3, strides=3), nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=7, activation='relu'), nn.MaxPool1D(pool_size=3, strides=3), nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'), nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'), nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'), nn.Conv1D(channels=self.NUM_FILTERS, kernel_size=3, activation='relu'), nn.MaxPool1D(pool_size=3, strides=3), nn.Flatten(), nn.Dense(self.FULLY_CONNECTED, activation='relu'), nn.Dense(self.FULLY_CONNECTED, activation='relu'), ) self.output = nn.Dense(self.NUM_OUTPUTS) def hybrid_forward(self, F, x): x = self.features(x) x = self.output(x) return x class CharacterCNNService(object): """ Gluon Character-level Convolution Service """ def __init__(self): # The 69 characters as specified in the paper self.ALPHABET = list("abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}") # Map Alphabets to index self.ALPHABET_INDEX = {letter: index for index, letter in enumerate(self.ALPHABET)} # max-length in characters for one document self.FEATURE_LEN = 1014 self.initialized = False def initialize(self, params): self.net = GluonCrepe() self.param_filename = "crepe_gluon_epoch6.params" self.model_name = params.manifest["model"]["modelName"] 
gpu_id = params.system_properties.get("gpu_id") model_dir = params.system_properties.get("model_dir") synset_file = os.path.join(model_dir, "synset.txt") param_file_path = os.path.join(model_dir, self.param_filename) if not os.path.isfile(param_file_path): raise OSError("Parameter file not found {}".format(param_file_path)) if not os.path.isfile(synset_file): raise OSError("synset file not available {}".format(synset_file)) self.ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) self.net.load_parameters(param_file_path, self.ctx) self.labels = [line.strip() for line in open(synset_file).readlines()] self.initialized = True self.net.hybridize(static_shape=True, static_alloc=True) def preprocess(self, data): """ Pre-process text to a encode it to a form, that gives spatial information to the CNN """ # build the text from the request if data[0].get('data') is not None: data = ast.literal_eval(data[0].get('data').decode('utf-8')) text = '{}|{}'.format(data[0].get('review_title'), data[0].get('review')) encoded = np.zeros([len(self.ALPHABET), self.FEATURE_LEN], dtype='float32') review = text.lower()[:self.FEATURE_LEN - 1:-1] i = 0 for letter in text: if i >= self.FEATURE_LEN: break; if letter in self.ALPHABET_INDEX: encoded[self.ALPHABET_INDEX[letter]][i] = 1 i += 1 return nd.array([encoded], ctx=self.ctx) def inference(self, data): # Call forward/hybrid_forward output = self.net(data) return output.softmax() def postprocess(self, data): # Post process and output the most likely category data = data[0] values = {val: float(int(data[i].asnumpy() * 1000) / 1000.0) for i, val in enumerate(self.labels)} index = int(nd.argmax(data, axis=0).asnumpy()[0]) predicted = self.labels[index] return [{'predicted': predicted, 'confidence': values}] def predict(self, data): data = self.preprocess(data) data = self.inference(data) return self.postprocess(data) svc = CharacterCNNService() def crepe_inference(data, context): res = "" if not svc.initialized: svc.initialize(context) if 
data is not None: res = svc.predict(data) return res ================================================ FILE: examples/gluon_character_cnn/signature.json ================================================ { "inputs": [ { "data_name": "data", "data_shape": [1,1014] } ], "input_type": "application/json", "outputs": [ { "data_name": "softmax", "data_shape": [0, 7] } ], "output_type": "application/json" } ================================================ FILE: examples/gluon_character_cnn/synset.txt ================================================ Home_and_Kitchen Books CDs_and_Vinyl Movies_and_TV Cell_Phones_and_Accessories Sports_and_Outdoors Clothing_Shoes_and_Jewelry ================================================ FILE: examples/lstm_ptb/README.md ================================================ # Sequence to Sequence inference with LSTM network trained on PenTreeBank data set In this example, we show how to create a service which generates sentences with a pre-trained LSTM model with deep model server. This model is trained on [PenTreeBank data](https://catalog.ldc.upenn.edu/ldc99t42) and training detail can be found in [MXNet example](https://github.com/apache/incubator-mxnet/tree/master/example/rnn). This model uses [MXNet Bucketing Module](https://mxnet.incubator.apache.org/how_to/bucketing.html) to deal with variable length input sentences and generates output sentences with the same length as inputs. 
# Step by step to create service ## Step 1 - Download the pre-trained LSTM model files, signature file and vocabulary dictionary file ```bash cd multi-model-server/examples/lstm_ptb curl -O https://s3.amazonaws.com/model-server/models/lstm_ptb/lstm_ptb-symbol.json curl -O https://s3.amazonaws.com/model-server/models/lstm_ptb/lstm_ptb-0100.params curl -O https://s3.amazonaws.com/model-server/models/lstm_ptb/vocab_dict.txt curl -O https://s3.amazonaws.com/model-server/models/lstm_ptb/signature.json ``` ## Step 2 - Verify signature file In this example, the provided mxnet_vision_service.py template assumes there is a `signature.json` file that describes the input parameter and shape. After [Step 1](#step-1---download-the-pre-trained-lstm-model-files-signature-file-and-vocabulary-dictionary-file) there should be a signature file in the lstm_ptb folder. Verify that this file exists before proceeding further. The signature file looks as follows. ```json { "inputs": [ { "data_name": "data", "data_shape": [ 1, 60 ], ... } ] } ``` Input data shape is (1, 60). For sequence to sequence models, the inputs can be variable length sequences. In the signature file the input shape should be set to the maximum length of the input sequence, which is the default bucket key. The bucket sizes are defined when training the model. In this example valid bucket sizes are 10, 20, 30, 40, 50 and 60. The default bucket key is the maximum value, which is 60. Check [bucketing module tutorials](https://mxnet.incubator.apache.org/faq/bucketing.html) if you want to know more about the bucketing module in MXNet. ## Step 3 - Check vocabulary dictionary file [vocab_dict.txt](https://s3.amazonaws.com/model-server/models/lstm_ptb/vocab_dict.txt) stores word-to-integer indexing information. In this example, each line in the text file represents a (word, index) pair. This file can be in a different format, which would require a correspondingly customized parsing method.
## Step 4 - Create custom service class We provide custom service class template code in [model_service_template](../model_service_template) folder: 1. [model_handler.py](../model_service_template/model_handler.py) - A generic base service class. 2. [mxnet_utils](../model_service_template/mxnet_utils) - A python package that contains utility classes. ```bash cd multi-model-server/examples cp model_service_template/model_handler.py lstm_ptb/ cp -r model_service_template/mxnet_utils lstm_ptb/ ``` In this example, we need to implement `preprocess`, `inference` and `postprocess` methods in a custom service class. Implementation details are in [lstm_ptb_service.py](lstm_ptb_service.py). ## Step 5 - Package the model with `model-archiver` CLI utility In this step, we package the following: 1. pre-trained MXNet Model we downloaded in Step 1. 2. '[signature.json](signature.json)' file we prepared in step 2. 3. '[vocab_dict.txt](vocab_dict.txt)' file we prepared in step 3. 4. custom model service files we prepared in step 4. We use the `model-archiver` command line utility (CLI) provided by MMS. Install `model-archiver` in case you have not: ```bash pip install model-archiver ``` This tool creates a .mar file that will be provided to MMS for serving inference requests. In the following command line, we specify 'lstm_ptb_service:handle' as the model archive entry point. ```bash cd multi-model-server/examples model-archiver --model-name lstm_ptb --model-path lstm_ptb --handler lstm_ptb_service:handle ``` ## Step 6 - Start the Inference Service Start the inference service by providing the 'lstm_ptb.mar' file we created in Step 5. By default, the server is started on the localhost at port 8080. ```bash cd multi-model-server multi-model-server --start --model-store examples --models lstm_ptb.mar ``` ## Test inference service Now we can send post requests to the endpoint we just established.
Since the entire range of vocabularies in the training set is only 10,000, you may not get very good results with arbitrary test sentences. Instead, we recommend that you test with sentences from the [PTB test data set](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt). That being said, if you try some random text you should know that any word that isn't in that 10k vocabulary is encoded with an "invalid label" of 0. This will create a prediction result of '\n'. Note that in PTB data set, person name is represented by ``. The key value of application/json input is 'input_sentence'. This can be a different value and preprocess method in lstm_ptb_service.py needs to be modified respectively. ```bash curl -X POST http://127.0.0.1:8080/predictions/lstm_ptb -H "Content-Type: application/json" -d '[{"input_sentence": "on the exchange floor as soon as ual stopped trading we for a panic said one top floor trader"}]' ``` Prediction result will be: ```json { "prediction": "the 's the the as the 's the the 're to a analyst company trading at " } ``` Let's try another sentence: ```bash curl -X POST http://127.0.0.1:8080/predictions/lstm_ptb -H "Content-Type: application/json" -d '[{"input_sentence": "while friday '\''s debacle involved mainly professional traders rather than investors it left the market vulnerable to continued selling this morning traders said "}]' ``` Prediction result will be: ```json { "prediction": "the 's stock were in say than were will to to to the the week \n \n \n \n \n \n \n \n \n \n " } ``` References 1. [How to use MXNet bucketing module](https://mxnet.incubator.apache.org/how_to/bucketing.html) 2. 
[LSTM trained with PennTreeBank data set](https://github.com/apache/incubator-mxnet/tree/master/example/rnn) ================================================ FILE: examples/lstm_ptb/lstm_ptb_service.py ================================================ import json import os import mxnet as mx from mxnet_utils import nlp from model_handler import ModelHandler class MXNetLSTMService(ModelHandler): """ MXNetLSTMService service class. This service consumes a sentence from length 0 to 60 and generates a sentence with the same size. """ def __init__(self): super(MXNetLSTMService, self).__init__() self.mxnet_ctx = None self.mx_model = None self.labels = None self.signature = None self.data_names = None self.data_shapes = None self.epoch = 100 self.buckets = [10, 20, 30, 40, 50, 60] self.start_label = 1 self.invalid_key = "\n" self.invalid_label = 0 self.layout = "NT" self.vocab = {} self.idx2word = {} def initialize(self, context): super(MXNetLSTMService, self).initialize(context) properties = context.system_properties model_dir = properties.get("model_dir") gpu_id = properties.get("gpu_id") batch_size = properties.get("batch_size") if batch_size > 1: raise ValueError("Batch is not supported.") # reading signature.json file signature_file_path = os.path.join(model_dir, "signature.json") if not os.path.isfile(signature_file_path): raise RuntimeError("Missing signature.json file.") with open(signature_file_path) as f: self.signature = json.load(f) self.data_names = [] self.data_shapes = [] for input_data in self.signature["inputs"]: self.data_names.append(input_data["data_name"]) self.data_shapes.append((input_data['data_name'], tuple(input_data['data_shape']))) # reading vocab_dict.txt file vocab_dict_file = os.path.join(model_dir, "vocab_dict.txt") with open(vocab_dict_file, 'r') as vocab_file: self.vocab[self.invalid_key] = self.invalid_label for line in vocab_file: word_index = line.split(' ') if len(word_index) < 2 or word_index[0] == '': continue 
self.vocab[word_index[0]] = int(word_index[1].rstrip()) for key, val in self.vocab.items(): self.idx2word[val] = key # Load pre-trained lstm bucketing module num_layers = 2 num_hidden = 200 num_embed = 200 stack = mx.rnn.FusedRNNCell(num_hidden, num_layers=num_layers, mode="lstm").unfuse() # Define symbol generation function for bucket module def sym_gen(seq_len): data = mx.sym.Variable("data") embed = mx.sym.Embedding(data=data, input_dim=len(self.vocab), output_dim=num_embed, name="embed") stack.reset() outputs, _ = stack.unroll(seq_len, inputs=embed, merge_outputs=True) pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden)) pred = mx.sym.FullyConnected(data=pred, num_hidden=len(self.vocab), name="pred") pred = mx.sym.softmax(pred, name='softmax') return pred, ('data',), None self.mxnet_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) # Create bucketing module and load weights self.mx_model = mx.mod.BucketingModule( sym_gen=sym_gen, default_bucket_key=max(self.buckets), context=self.mxnet_ctx) checkpoint_prefix = "{}/{}".format(model_dir, "lstm_ptb") self.mx_model.bind(data_shapes=self.data_shapes, for_training=False) _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, checkpoint_prefix, self.epoch) self.mx_model.set_params(arg_params, aux_params) def preprocess(self, data): """ This service doesn't support batch, always get data from first item. :param data: :return: """ input_data = data[0].get("data") if input_data is None: input_data = data[0].get("body") # Convert a string of sentence to a list of string sent = input_data[0]["input_sentence"].lower().split(" ") assert len(sent) <= self.buckets[-1], "Sentence length must be no greater than %d." 
% (self.buckets[-1]) # Encode sentence to a list of int res, _ = nlp.encode_sentences( [sent], vocab=self.vocab, start_label=self.start_label, invalid_label=self.invalid_label) return res def inference(self, data): data_batch = nlp.pad_sentence( data[0], self.buckets, invalid_label=self.invalid_label, data_name=self.data_names[0], layout=self.layout) self.mx_model.forward(data_batch) return self.mx_model.get_outputs() def postprocess(self, data): # Generate predicted sentences word_idx = mx.nd.argmax(data[0], axis=1).asnumpy() res = "" for idx in word_idx: res += self.idx2word[idx] + " " ret = {"prediction": res} return [ret] # Following code is not necessary if your service class contains `handle(self, data, context)` function _service = MXNetLSTMService() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ================================================ FILE: examples/metrics_cloudwatch/__init__.py ================================================ ================================================ FILE: examples/metrics_cloudwatch/metric_push_example.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Examples for pushing a log to boto client to cloudwatch """ import types import json import boto3 as boto from mms.metrics import system_metrics as sys_metric from mms.metrics.metric_encoder import MetricEncoder def generate_system_metrics(mod): """ Function acting as a stub for reading a log file, produces similar result :param mod: :return: """ members = dir(mod) for i in members: value = getattr(mod, i) if isinstance(value, types.FunctionType) and value.__name__ != 'collect_all': value() return json.dumps(sys_metric.system_metrics, indent=4, separators=(',', ':'), cls=MetricEncoder) def push_cloudwatch(metric_json, client): """ push metric to cloud watch, do some processing. :param metric_json: :param client: :return: """ metrics = json.loads(metric_json) cloud_metrics = [] for metric in metrics: cloud_metric = {} for key in metric.keys(): if key != 'RequestId' or key != 'HostName': cloud_metric[key] = metric[key] cloud_metrics.append(cloud_metric) client.put_metric_data( Namespace='MXNetModelServer', MetricData=cloud_metrics ) def connect_cloudwatch(): client = None try: client = boto.client('cloudwatch') except Exception as e: # pylint: disable=broad-except print(str(e)) return client if __name__ == '__main__': # Replace this with a log reader json_val = generate_system_metrics(sys_metric) cloud_client = connect_cloudwatch() if cloud_client is not None: push_cloudwatch(json_val, cloud_client) ================================================ FILE: examples/model_service_template/gluon_base_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Gluon Base service defines a Gluon base service for generic CNN """ import mxnet as mx import numpy as np import os import json import ndarray class GluonBaseService(object): """GluonBaseService defines a fundamental service for image classification task. In preprocess, input image buffer is read to NDArray and resized respect to input shape in signature. In post process, top-5 labels are returned. """ def __init__(self): self.param_filename = None self.model_name = None self.initialized = False self.ctx = None self.net = None self._signature = None self.labels = None self.signature = None def initialize(self, params): """ Initialization of the network :param params: This is the :func `Context` object :return: """ if self.net is None: raise NotImplementedError("Gluon network not defined") sys_prop = params.system_properties gpu_id = sys_prop.get("gpu_id") model_dir = sys_prop.get("model_dir") self.model_name = params.manifest["model"]["modelName"] self.ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) if self.param_filename is not None: param_file_path = os.path.join(model_dir, self.param_filename) if not os.path.isfile(param_file_path): raise OSError("Parameter file not found {}".format(param_file_path)) self.net.load_parameters(param_file_path, self.ctx) synset_file = os.path.join(model_dir, "synset.txt") signature_file_path = os.path.join(model_dir, "signature.json") if not os.path.isfile(signature_file_path): raise OSError("Signature file not found {}".format(signature_file_path)) if not os.path.isfile(synset_file): raise OSError("synset file not available {}".format(synset_file)) with open(signature_file_path) as sig_file: self.signature = json.load(sig_file) self.labels = [line.strip() for line in open(synset_file).readlines()] 
self.initialized = True def preprocess(self, data): """ This method considers only one input data :param data: Data is list of map format is [ { "parameterName": name "parameterValue": data }, {...} ] :return: """ param_name = self.signature['inputs'][0]['data_name'] input_shape = self.signature['inputs'][0]['data_shape'] img = data[0].get(param_name) if img is None: raise IOError("Invalid parameter given") # We are assuming input shape is NCHW [h, w] = input_shape[2:] img_arr = mx.img.imdecode(img) img_arr = mx.image.imresize(img_arr, w, h) img_arr = img_arr.astype(np.float32) img_arr /= 255 img_arr = mx.image.color_normalize(img_arr, mean=mx.nd.array([0.485, 0.456, 0.406]), std=mx.nd.array([0.229, 0.224, 0.225])) img_arr = mx.nd.transpose(img_arr, (2, 0, 1)) img_arr = img_arr.expand_dims(axis=0) return img_arr def inference(self, data): """ Internal inference methods for MMS service. Run forward computation and return output. Parameters ---------- data : list of NDArray Preprocessed inputs in NDArray format. Returns ------- list of NDArray Inference output. """ model_input = data.as_in_context(self.ctx) output = self.net(model_input) return output.softmax() def postprocess(self, data): assert hasattr(self, 'labels'), \ "Can't find labels attribute. Did you put synset.txt file into " \ "model archive or manually load class label file in __init__?" return [[ndarray.top_probability(d, self.labels, top=5) for d in data]] def predict(self, data): data = self.preprocess(data) data = self.inference(data) return self.postprocess(data) ================================================ FILE: examples/model_service_template/model_handler.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
# A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ ModelHandler defines a base model handler. """ import logging import time class ModelHandler(object): """ A base Model handler implementation. """ def __init__(self): self.error = None self._context = None self._batch_size = 0 self.initialized = False def initialize(self, context): """ Initialize model. This will be called during model loading time :param context: Initial context contains model server system properties. :return: """ self._context = context self._batch_size = context.system_properties["batch_size"] self.initialized = True def preprocess(self, batch): """ Transform raw input into model input data. :param batch: list of raw requests, should match batch size :return: list of preprocessed model input data """ assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch)) return None def inference(self, model_input): """ Internal inference methods :param model_input: transformed model input data :return: list of inference output in NDArray """ return None def postprocess(self, inference_output): """ Return predict result in batch. :param inference_output: list of inference output :return: list of predict results """ return ["OK"] * self._batch_size def handle(self, data, context): """ Custom service entry point function. 
:param data: list of objects, raw input from request :param context: model server context :return: list of outputs to be send back to client """ self.error = None # reset earlier errors try: preprocess_start = time.time() data = self.preprocess(data) inference_start = time.time() data = self.inference(data) postprocess_start = time.time() data = self.postprocess(data) end_time = time.time() metrics = context.metrics metrics.add_time("PreprocessTime", round((inference_start - preprocess_start) * 1000, 2)) metrics.add_time("InferenceTime", round((postprocess_start - inference_start) * 1000, 2)) metrics.add_time("PostprocessTime", round((end_time - postprocess_start) * 1000, 2)) return data except Exception as e: logging.error(e, exc_info=True) request_processor = context.request_processor request_processor.report_status(500, "Unknown inference error") return [str(e)] * self._batch_size ================================================ FILE: examples/model_service_template/mxnet_model_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ MXNetModelService defines an API for MXNet service. """ import json import os import mxnet as mx from mxnet.io import DataBatch from model_handler import ModelHandler class MXNetModelService(ModelHandler): """ MXNetBaseService defines the fundamental loading model and inference operations when serving MXNet model. This is a base class and needs to be inherited. 
""" def __init__(self): super(MXNetModelService, self).__init__() self.mxnet_ctx = None self.mx_model = None self.labels = None self.signature = None self.epoch = 0 # noinspection PyMethodMayBeStatic def get_model_files_prefix(self, context): return context.manifest["model"]["modelName"] def initialize(self, context): """ Initialize model. This will be called during model loading time :param context: Initial context contains model server system properties. :return: """ super(MXNetModelService, self).initialize(context) assert self._batch_size == 1, "Batch is not supported." properties = context.system_properties model_dir = properties.get("model_dir") gpu_id = properties.get("gpu_id") signature_file_path = os.path.join(model_dir, "signature.json") if not os.path.isfile(signature_file_path): raise RuntimeError("Missing signature.json file.") with open(signature_file_path) as f: self.signature = json.load(f) model_files_prefix = self.get_model_files_prefix(context) archive_synset = os.path.join(model_dir, "synset.txt") if os.path.isfile(archive_synset): synset = archive_synset self.labels = [line.strip() for line in open(synset).readlines()] data_names = [] data_shapes = [] for input_data in self.signature["inputs"]: data_name = input_data["data_name"] data_shape = input_data["data_shape"] # Set batch size data_shape[0] = self._batch_size # Replace 0 entry in data shape with 1 for binding executor. 
for idx in range(len(data_shape)): if data_shape[idx] == 0: data_shape[idx] = 1 data_names.append(data_name) data_shapes.append((data_name, tuple(data_shape))) checkpoint_prefix = "{}/{}".format(model_dir, model_files_prefix) # Load MXNet module self.mxnet_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, self.epoch) # noinspection PyTypeChecker self.mx_model = mx.mod.Module(symbol=sym, context=self.mxnet_ctx, data_names=data_names, label_names=None) self.mx_model.bind(for_training=False, data_shapes=data_shapes) self.mx_model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True) def preprocess(self, batch): """ Transform raw input into model input data. :param batch: list of raw requests, should match batch size :return: list of preprocessed model input data """ assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch)) ret = [] param_name = self.signature['inputs'][0]['data_name'] for idx, request in enumerate(batch): data = request.get(param_name) if data is None: data = request.get("body") if data is None: data = request.get("data") ret.append(map(mx.nd.array, data)) return ret def inference(self, model_input): """ Internal inference methods for MXNet. Run forward computation and return output. :param model_input: list of NDArray Preprocessed inputs in NDArray format. :return: list of NDArray Inference output. 
""" if self.error is not None: return None # Check input shape check_input_shape(model_input, self.signature) model_input = [item.as_in_context(self.mxnet_ctx) for item in model_input] self.mx_model.forward(DataBatch(model_input)) model_input = self.mx_model.get_outputs() # by pass lazy evaluation get_outputs either returns a list of nd arrays # a list of list of NDArray for d in model_input: if isinstance(d, list): for n in model_input: if isinstance(n, mx.ndarray.ndarray.NDArray): n.wait_to_read() elif isinstance(d, mx.ndarray.ndarray.NDArray): d.wait_to_read() return model_input def postprocess(self, inference_output): if self.error is not None: return [self.error] * self._batch_size return [str(d.asnumpy().tolist()) for d in inference_output] def check_input_shape(inputs, signature): """ Check input data shape consistency with signature. Parameters ---------- inputs : List of NDArray Input data in NDArray format. signature : dict Dictionary containing model signature. """ assert isinstance(inputs, list), 'Input data must be a list.' assert len(inputs) == len(signature['inputs']), \ "Input number mismatches with " \ "signature. %d expected but got %d." \ % (len(signature['inputs']), len(inputs)) for input_data, sig_input in zip(inputs, signature["inputs"]): assert isinstance(input_data, mx.nd.NDArray), 'Each input must be NDArray.' assert len(input_data.shape) == len(sig_input["data_shape"]), \ 'Shape dimension of input %s mismatches with ' \ 'signature. %d expected but got %d.' \ % (sig_input['data_name'], len(sig_input['data_shape']), len(input_data.shape)) for idx in range(len(input_data.shape)): if idx != 0 and sig_input['data_shape'][idx] != 0: assert sig_input['data_shape'][idx] == input_data.shape[idx], \ 'Input %s has different shape with ' \ 'signature. %s expected but got %s.' 
\ % (sig_input['data_name'], sig_input['data_shape'], input_data.shape) ================================================ FILE: examples/model_service_template/mxnet_utils/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ MXNet Utils """ ================================================ FILE: examples/model_service_template/mxnet_utils/image.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Image utils """ import base64 import sys from io import BytesIO import mxnet as mx import numpy as np from PIL import Image from mxnet import image as img def transform_shape(img_arr, dim_order='NCHW'): """ Rearrange image NDArray shape to 'NCHW' or 'NHWC' which is valid for MXNet model input. Input image NDArray should has dim_order of 'HWC'. 
:param img_arr: NDArray Image in NDArray format with shape (channel, width, height) :param dim_order: str Output image dimension order. Valid values are 'NCHW' and 'NHWC' :return: NDArray Image in NDArray format with dim_order shape """ assert dim_order in 'NCHW' or dim_order in 'NHWC', "dim_order must be 'NCHW' or 'NHWC'." if dim_order == 'NCHW': img_arr = mx.nd.transpose(img_arr, (2, 0, 1)) output = mx.nd.expand_dims(img_arr, axis=0) return output def read(buf, flag=1, to_rgb=True, out=None): """ Read and decode an image to an NDArray. Input image NDArray should has dim_order of 'HWC'. Note: `imread` uses OpenCV (not the CV2 Python library). MXNet must have been built with USE_OPENCV=1 for `imdecode` to work. :param buf: str/bytes or numpy.ndarray Binary image data as string or numpy ndarray. :param flag: {0, 1}, default 1 1 for three channel color output. 0 for grayscale output. :param to_rgb: bool, default True True for RGB formatted output (MXNet default). False for BGR formatted output (OpenCV default). :param out: NDArray, optional Output buffer. Use `None` for automatic allocation. :return: NDArray An `NDArray` containing the image. Example ------- >>> buf = open("flower.jpg", 'rb').read() >>> image.read(buf) """ return img.imdecode(buf, flag, to_rgb, out) def write(img_arr, flag=1, output_format='jpeg', dim_order='CHW'): """ Write an NDArray to a base64 string. :param img_arr: NDArray Image in NDArray format with shape (channel, width, height). :param flag: {0, 1}, default 1 1 for three channel color output. 0 for grayscale output. :param output_format: str Output image format. :param dim_order: str Input image dimension order. Valid values are 'CHW' and 'HWC' :return: str Image in base64 string format """ assert dim_order in 'CHW' or dim_order in 'HWC', "dim_order must be 'CHW' or 'HWC'." 
if dim_order == 'CHW': img_arr = mx.nd.transpose(img_arr, (1, 2, 0)) if flag == 1: mode = 'RGB' else: mode = 'L' img_arr = mx.nd.reshape(img_arr, (img_arr.shape[0], img_arr.shape[1])) img_arr = img_arr.astype(np.uint8).asnumpy() image = Image.fromarray(img_arr, mode) output = BytesIO() image.save(output, format=output_format) output.seek(0) if sys.version_info[0] < 3: return base64.b64encode(output.getvalue()) else: return base64.b64encode(output.getvalue()).decode("utf-8") def resize(src, new_width, new_height, interp=2): """ Resizes image to new_width and new_height. Input image NDArray should has dim_order of 'HWC'. :param src: NDArray Source image in NDArray format :param new_width: int Width in pixel for resized image :param new_height: int Height in pixel for resized image :param interp: int interpolation method for all resizing operations Possible values: 0: Nearest Neighbors Interpolation. 1: Bilinear interpolation. 2: Area-based (resampling using pixel area relation). It may be a preferred method for image decimation, as it gives moire-free results. But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default). 3: Bicubic interpolation over 4x4 pixel neighborhood. 4: Lanczos interpolation over 8x8 pixel neighborhood. 9: Cubic for enlarge, area for shrink, bilinear for others 10: Random select from interpolation method metioned above. Note: When shrinking an image, it will generally look best with AREA-based interpolation, whereas, when enlarging an image, it will generally look best with Bicubic (slow) or Bilinear (faster but still looks OK). More details can be found in the documentation of OpenCV, please refer to http://docs.opencv.org/master/da/d54/group__imgproc__transform.html. :return: NDArray An `NDArray` containing the resized image. """ return img.imresize(src, new_width, new_height, interp) def fixed_crop(src, x0, y0, w, h, size=None, interp=2): """ Crop src at fixed location, and (optionally) resize it to size. 
Input image NDArray should has dim_order of 'HWC'. :param src: NDArray Input image :param x0: int Left boundary of the cropping area :param y0 : int Top boundary of the cropping area :param w : int Width of the cropping area :param h : int Height of the cropping area :param size : tuple of (w, h) Optional, resize to new size after cropping :param interp : int, optional, default=2 Interpolation method. See resize for details. :return: NDArray An `NDArray` containing the cropped image. """ return img.fixed_crop(src, x0, y0, w, h, size, interp) def color_normalize(src, mean, std=None): """ Normalize src with mean and std. :param src : NDArray Input image :param mean : NDArray RGB mean to be subtracted :param std : NDArray RGB standard deviation to be divided :return: NDArray An `NDArray` containing the normalized image. """ src = src.astype(np.float32) return img.color_normalize(src, mean, std) ================================================ FILE: examples/model_service_template/mxnet_utils/ndarray.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ NDArray utils """ import mxnet as mx import numpy as np def top_probability(data, labels, top=5): """ Get top probability prediction from NDArray. 
:param data: NDArray Data to be predicted :param labels: List List of class labels :param top: :return: List List of probability: class pairs in sorted order """ dim = len(data.shape) if dim > 2: data = mx.nd.array( np.squeeze(data.asnumpy(), axis=tuple(range(dim)[2:]))) sorted_prob = mx.nd.argsort(data[0], is_ascend=False) # pylint: disable=deprecated-lambda top_prob = map(lambda x: int(x.asscalar()), sorted_prob[0:top]) return [{'probability': float(data[0, i].asscalar()), 'class': labels[i]} for i in top_prob] ================================================ FILE: examples/model_service_template/mxnet_utils/nlp.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ NLP utils """ import bisect import mxnet as mx import numpy as np def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0): """ Encode sentences and (optionally) build a mapping from string tokens to integer indices. Unknown keys will be added to vocabulary. :param sentences: list of list of str A list of sentences to encode. Each sentence should be a list of string tokens. :param vocab: None or dict of str -> int Optional input Vocabulary :param invalid_label: int, default -1 Index for invalid token, like :param invalid_key: str, default '\\n' Key for invalid token. Use '\\n' for end of sentence by default. :param start_label: int lowest index. 
def pad_sentence(sentence, buckets, invalid_label=-1, data_name='data', layout='NT'):
    """
    Pad a sentence to the closest bucket length and wrap it in a DataBatch.

    :param sentence: list of int
        An encoded sentence.
    :param buckets: list of int
        Sorted bucket sizes; the smallest bucket >= len(sentence) is used.
    :param invalid_label: int, optional
        Id used to pad the tail of the sentence.
    :param data_name: str, optional
        Input data name.
    :param layout: str, optional
        'NT' means (batch_size, length); 'TN' means (length, batch_size).
    :return: mx.io.DataBatch
        Batch of one padded sentence.
    """
    bucket_idx = bisect.bisect_left(buckets, len(sentence))
    bucket_len = buckets[bucket_idx]
    padded = np.full((bucket_len,), invalid_label, dtype='float32')
    padded[:len(sentence)] = sentence
    batch_data = mx.nd.array([padded], dtype='float32')
    if layout == 'NT':
        shape = (1, bucket_len)
    else:
        shape = (bucket_len, 1)
    desc = mx.io.DataDesc(name=data_name, shape=shape, layout=layout)
    return mx.io.DataBatch([batch_data], pad=0, bucket_key=bucket_len,
                           provide_data=[desc])
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. import mxnet as mx import json import os import numpy as np from collections import namedtuple import logging class MXNetVisionServiceBatching(object): def __init__(self): """ Initialization for MXNet Vision Service supporting batch inference """ self.mxnet_ctx = None self.mx_model = None self.labels = None self.signature = None self.epoch = 0 self._context = None self._batch_size = 0 self.initialized = False self.erroneous_reqs = set() def top_probability(self, data, labels, top=5): """ Get top probability prediction from NDArray. :param data: NDArray Data to be predicted :param labels: List List of class labels :param top: :return: List List of probability: class pairs in sorted order """ dim = len(data.shape) if dim > 2: data = mx.nd.array( np.squeeze(data.asnumpy(), axis=tuple(range(dim)[2:]))) sorted_prob = mx.nd.argsort(data[0], is_ascend=False) top_prob = map(lambda x: int(x.asscalar()), sorted_prob[0:top]) return [{'probability': float(data[0, i].asscalar()), 'class': labels[i]} for i in top_prob] def initialize(self, context): """ Initialize model. This will be called during model loading time :param context: Initial context contains model server system properties. 
:return: """ self._context = context self._batch_size = context.system_properties["batch_size"] self.initialized = True properties = context.system_properties model_dir = properties.get("model_dir") gpu_id = properties.get("gpu_id") signature_file_path = os.path.join(model_dir, "signature.json") if not os.path.isfile(signature_file_path): raise RuntimeError("Missing signature.json file.") with open(signature_file_path) as f: self.signature = json.load(f) model_files_prefix = context.manifest["model"]["modelName"] archive_synset = os.path.join(model_dir, "synset.txt") if os.path.isfile(archive_synset): synset = archive_synset self.labels = [line.strip() for line in open(synset).readlines()] data_names = [] data_shapes = [] for input_data in self.signature["inputs"]: data_name = input_data["data_name"] data_shape = input_data["data_shape"] # Set batch size data_shape[0] = self._batch_size # Replace 0 entry in data shape with 1 for binding executor. for idx in range(len(data_shape)): if data_shape[idx] == 0: data_shape[idx] = 1 data_names.append(data_name) data_shapes.append((data_name, tuple(data_shape))) checkpoint_prefix = "{}/{}".format(model_dir, model_files_prefix) # Load MXNet module self.mxnet_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id) sym, arg_params, aux_params = mx.model.load_checkpoint(checkpoint_prefix, self.epoch) self.mx_model = mx.mod.Module(symbol=sym, context=self.mxnet_ctx, data_names=data_names, label_names=None) self.mx_model.bind(for_training=False, data_shapes=data_shapes) self.mx_model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True) def inference(self, model_input): """ Internal inference methods for MXNet. Run forward computation and return output. :param model_input: list of NDArray Preprocessed inputs in NDArray format. :return: list of NDArray Inference output. 
""" batch = namedtuple('Batch', ['data']) self.mx_model.forward(batch([model_input]), is_train=False) outputs = self.mx_model.get_outputs() res = mx.ndarray.split(outputs[0], axis=0, num_outputs=outputs[0].shape[0]) res = [res] if not isinstance(res, list) else res return res def preprocess(self, request): """ Decode all input images into ndarray. Note: This implementation doesn't properly handle error cases in batch mode, If one of the input images is corrupted, all requests in the batch will fail. :param request: :return: """ img_list = [] param_name = self.signature['inputs'][0]['data_name'] input_shape = self.signature['inputs'][0]['data_shape'] # We are assuming input shape is NCHW [c, h, w] = input_shape[1:] # Clear error requests set. self.erroneous_reqs.clear() for idx, data in enumerate(request): img = data.get(param_name) if img is None: img = data.get("body") if img is None: img = data.get("data") if img is None or len(img) == 0: logging.error("Error processing request") self.erroneous_reqs.add(idx) continue try: img_arr = mx.image.imdecode(img, 1, True, None) except Exception as e: logging.error(e, exc_info=True) self.erroneous_reqs.add(idx) continue img_arr = mx.image.imresize(img_arr, w, h, 2) img_arr = mx.nd.transpose(img_arr, (2, 0, 1)) self._num_requests = idx + 1 img_list.append(img_arr) logging.debug("Worker :{} received {} requests".format(os.getpid(), self._num_requests)) reqs = mx.nd.stack(*img_list) reqs = reqs.as_in_context(self.mxnet_ctx) if (self._batch_size - self._num_requests) != 0: padding = mx.nd.zeros((self._batch_size - self._num_requests, c, h, w), self.mxnet_ctx, 'uint8') reqs = mx.nd.concat(reqs, padding, dim=0) return reqs def postprocess(self, data): res = [] for idx, resp in enumerate(data[:self._num_requests]): if idx not in self.erroneous_reqs: res.append(self.top_probability(resp, self.labels, top=5)) else: res.append("This request was not processed successfully. 
_service = MXNetVisionServiceBatching()


def handle(data, context):
    """
    MMS entry point: lazily initialize the service, then run the
    preprocess -> inference -> postprocess pipeline.

    :param data: list of raw request objects, or None
    :param context: model server context
    :return: list of responses, or None when data is None
    """
    if not _service.initialized:
        _service.initialize(context)
    if data is None:
        return None
    try:
        model_input = _service.preprocess(data)
        model_output = _service.inference(model_input)
        return _service.postprocess(model_output)
    except Exception as e:
        logging.error(e, exc_info=True)
        context.request_processor.report_status(500, "Unknown inference error")
        return [str(e)] * _service._batch_size
:param request: :return: """ img_list = [] param_name = self.signature['inputs'][0]['data_name'] input_shape = self.signature['inputs'][0]['data_shape'] for idx, data in enumerate(request): img = data.get(param_name) if img is None: img = data.get("body") if img is None: img = data.get("data") if img is None or len(img) == 0: self.error = "Empty image input" return None # We are assuming input shape is NCHW [h, w] = input_shape[2:] try: img_arr = image.read(img) except Exception as e: logging.warn(e, exc_info=True) self.error = "Corrupted image input" return None img_arr = image.resize(img_arr, w, h) img_arr = image.transform_shape(img_arr) img_list.append(img_arr) return img_list def postprocess(self, data): if self.error is not None: return [self.error] * self._batch_size assert hasattr(self, 'labels'), \ "Can't find labels attribute. Did you put synset.txt file into " \ "model archive or manually load class label file in __init__?" return [ndarray.top_probability(d, self.labels, top=5) for d in data] _service = MXNetVisionService() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ================================================ FILE: examples/mxnet_vision/README.md ================================================ # MXNet Vision Service In this example, we show how to use a pre-trained MXNet model to performing real time Image Classification with MMS We choose squeezenet in this example: [Iandola, et al.](https://arxiv.org/pdf/1602.07360v4.pdf). But the same should work for other MXNet Image Classification models. The inference service would return the response in the json format. # Objective 1. Demonstrate how to package a pre-trained squeezenet into model archive (.mar) file 2. Demonstrate how to create model service code based on provided service template 3. Demonstrate how to load model archive (.mar) file into MMS and run inference. 
## Step 1 - Download the pre-trained squeezenet Model You will need the model files in this example. Check this example's directory in case they're already downloaded. Otherwise, you can `curl` the files or download them via your browser: ```bash cd multi-model-server/examples/mxnet_vision curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/squeezenet_v1.1-symbol.json curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/squeezenet_v1.1-0000.params ``` Alternatively, use these links to download the Symbol and Params files via your browser: 1. squeezenet_v1.1-symbol.json 2. squeezenet_v1.1-0000.params ## Step 2 - Prepare the signature file Define Input and Output name, type and shape in `signature.json` file. The signature for this example looks like below: ```json { "inputs": [ { "data_name": "data", "data_shape": [ 0, 3, 224, 224 ] } ] } ``` In this pre-trained model, input name is 'data' and shape is '(1,3,224,224)'. Where, the expected input is a color image (3 channels - RGB) of shape 224*224. We also expect input type is a binary JPEG images. In provided mxnet_vision_service.py, you will see the code that take care of converting binary images to tensor NDArray used by MXNet. *Note:* Typically, if you train your own model, you define the Input and Output Layer name and shape when defining the Neural Network. If you are using a pre-trained MXNet model, to get these Input and Output name and dimensions, you can load the Model and extract the Input and Output layer details. Unfortunately, there are no APIs or easy way to extract the Input shape. 
Example code below: ```python >>> import mxnet as mx >>> load_symbol, args, auxs = mx.model.load_checkpoint("squeezenet_v1.1", 0) >>> mod = mx.mod.Module(load_symbol, label_names=None, data_names=['data'], context=mx.cpu()) >>> mod.data_names ['data'] >>> mod.bind(data_shapes=[('data', (1, 3, 224, 224))]) >>> mod.set_params(args, auxs) >>> print(mod.data_names) >>> print(mod.data_shapes) >>> print(mod.output_names) >>> print(mod.output_shapes) ['data'] [DataDesc[data,(1, 3, 224, 224),,NCHW]] ['detection_output'] [('detection_output', (1, 6132, 6))] ``` ## Step 3 - Prepare synset.txt with list of class names [synset.txt](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/synset.txt) is where we define list of all classes detected by the model. The list of classes in synset.txt will be loaded by MMS as list of labels in inference logic. You can use `curl` to download it. ```bash cd multi-model-server/examples/mxnet_vision curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/synset.txt ``` Alternatively, use following link to download: synset.txt ## Step 4 - Create custom service class We provided custom service class template code in [model_service_template](../model_service_template) folder: 1. [model_handler.py](../model_service_template/model_handler.py) - A generic based service class. 2. [mxnet_model_service.py](../model_service_template/mxnet_model_service.py) - A MXNet base service class. 3. [mxnet_vision_service.py](../model_service_template/mxnet_vision_service.py) - A MXNet Vision service class. 4. [mxnet_utils](../model_service_template/mxnet_utils) - A python package that contains utility classes. In this example, you can simple copy them into mxnet_vision folder, as use provided mxnet_vision_service.py as user model archive entry point. 
```bash cd multi-model-server/examples cp -r model_service_template/* mxnet_vision/ ``` ## Step 5 - Package the model with `model-archiver` CLI utility In this step, we package the following: 1. pre-trained MXNet Model we downloaded in Step 1. 2. '[signature.json](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/signature.json)' file we prepared in step 2. 3. '[synset.txt](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/synset.txt)' file we prepared in step 3. 4. custom model service files we prepared in step 4. We use the `model-archiver` command line utility (CLI) provided by MMS. Install `model-archiver` in case you have not: ```bash pip install model-archiver ``` This tool creates a .mar file that will be provided to MMS for serving inference requests. In the following command line, we specify 'mxnet_vision_service:handle' as the model archive entry point. ```bash cd multi-model-server/examples model-archiver --model-name squeezenet_v1.1 --model-path mxnet_vision --handler mxnet_vision_service:handle ``` ## Step 6 - Start the Inference Service Start the inference service by providing the 'squeezenet_v1.1.mar' file we created in Step 5. By default, the server is started on the localhost at port 8080. ```bash cd multi-model-server multi-model-server --start --model-store examples --models squeezenet_v1.1.mar ``` Awesome! We have successfully packaged a pre-trained MXNet model and started an inference service. `Note:` In this example, MMS loads the .mar file from the local file system. However, you can also store the model archive (.mar file) over a network-accessible storage such as AWS S3, and use a URL such as http:// or https:// to indicate the model location. MMS is capable of loading the model archive over such URLs as well. ## Step 7 - Test sample inference Let us try the inference server we just started. Use curl to make a prediction call by passing a JPEG image as input to the prediction request.
```bash cd multi-model-server curl -X POST http://127.0.0.1:8080/predictions/squeezenet_v1.1 -T docs/images/kitten_small.jpg ``` You can expect the response similar to below. The output format is in json. ```json [ { "class": "n02127052 lynx, catamount", "probability": 0.5721369385719299 }, { "class": "n02124075 Egyptian cat", "probability": 0.4079437255859375 }, { "class": "n02123045 tabby, tabby cat", "probability": 0.013694713823497295 }, { "class": "n02123394 Persian cat", "probability": 0.004954110365360975 }, { "class": "n02123159 tiger cat", "probability": 0.0012674571480602026 } ] ``` A consumer application can use this response to identify the objects in the input image and their bounding boxes. ## Step 8 - Clean up and stop MMS MMS will keep running in background. And .mar file will be extracted to system temp directory. You can clean up temp directory by unregister model and use CLI to stop MMS ```bash curl -X DELETE http://127.0.0.1:8081/models/squeezenet_v1.1 multi-model-server --stop ``` ================================================ FILE: examples/sockeye_translate/Dockerfile ================================================ FROM nvidia/cuda:9.2-cudnn7-runtime-ubuntu18.04 ENV PYTHONUNBUFFERED TRUE RUN useradd -m model-server && \ mkdir -p /home/model-server/tmp WORKDIR /home/model-server ENV TEMP=/home/model-server/tmp RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ build-essential \ python3-dev \ python3-venv \ openjdk-8-jdk-headless \ curl \ vim && \ rm -rf /var/lib/apt/lists/* COPY requirements/ requirements/ RUN python3 -m venv venv && \ . 
venv/bin/activate && \ pip install --upgrade pip setuptools wheel && \ pip install sockeye --no-deps -r requirements/sockeye/requirements.gpu-cu92.txt && \ pip install --no-cache-dir multi-model-server && \ pip install -r requirements/sockeye-serving/requirements.txt COPY config/config.properties /home/model-server COPY scripts/mms/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh && \ chown -R model-server /home/model-server EXPOSE 8080 8081 USER model-server ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"] CMD ["serve"] LABEL maintainer="james.e.woo@gmail.com" ================================================ FILE: examples/sockeye_translate/README.md ================================================ # sockeye-serving This example shows how to serve Sockeye models for machine translation. The custom handler is implemented in `sockeye_service.py`. Since Sockeye has many dependencies, it's convenient to use Docker. For simplicity, we'll use a pre-trained model and make some assumptions about how we preprocess the data. ## Getting Started With Docker Pull the latest Docker image: ```bash docker pull jwoo11/sockeye-serving ``` Download the example model archive file (MAR). This is a ZIP archive containing the parameter files and scripts needed to run translation for a particular language: * https://www.dropbox.com/s/pk7hmp7a5zjcfcj/zh.mar?dl=0 Extract the MAR file to `/tmp/models`. 
We'll use this directory as a bind mount for Docker: ```bash unzip -d /tmp/models/zh zh.mar ``` Start the server: ```bash docker run -itd --name mms -p 8080:8080 -p 8081:8081 -v /tmp/models:/opt/ml/model jwoo11/sockeye-serving serve ``` Now we can load the model using the management API provided by `multi-model-server`: ```bash curl -X POST "http://localhost:8081/models?synchronous=true&initial_workers=1&url=zh" ``` Get the status of the model with the following: ```bash curl -X GET "http://localhost:8081/models/zh" ``` ```json { "modelName": "zh", "modelUrl": "zh", "runtime": "python3", "minWorkers": 1, "maxWorkers": 1, "batchSize": 1, "maxBatchDelay": 100, "workers": [ { "id": "9000", "startTime": "2019-01-26T00:49:10.431Z", "status": "READY", "gpu": false, "memoryUsage": 601395200 } ] } ``` To translate text, use the inference API. Notice that the port is different from above. ```bash curl -X POST "http://localhost:8080/predictions/zh" -H "Content-Type: application/json" \ -d '{ "text": "我的世界是一款開放世界遊戲,玩家沒有具體要完成的目標,即玩家有超高的自由度選擇如何玩遊戲" }' ``` The translation quality depends on the model. Apparently, this one needs more training: ```json { "translation": "in my life was a life of a life of a public public, and a public, a time, a video, a play, which, it was a time of a time of a time." 
class ModelHandler(object):
    """
    A base model handler implementation.

    Subclasses override preprocess/inference/postprocess; handle() chains
    them together and reports a 500 on any failure.
    """

    def __init__(self):
        # error: last error message, if any
        # _context: model server context, set in initialize()
        # _batch_size: number of requests per batch
        # initialized: guards one-time initialization
        self.error = None
        self._context = None
        self._batch_size = 0
        self.initialized = False

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time.

        :param context: model server context with system properties
        :return: None
        """
        self._context = context
        self._batch_size = context.system_properties["batch_size"]
        self.initialized = True

    def preprocess(self, batch):
        """
        Transform raw input into model input data.

        :param batch: list of raw requests; must match the batch size
        :return: list of preprocessed model input data
        """
        assert self._batch_size == len(batch), "Invalid input batch size: {}".format(len(batch))
        return None

    def inference(self, model_input):
        """
        Run the model forward pass.

        :param model_input: transformed model input data
        :return: list of inference output in NDArray
        """
        return None

    def postprocess(self, inference_output):
        """
        Return predict results in batch.

        :param inference_output: list of inference output
        :return: list of predict results, one per request
        """
        return ["OK"] * self._batch_size

    def handle(self, data, context):
        """
        Custom service entry point: preprocess -> inference -> postprocess.

        :param data: list of objects, raw input from request
        :param context: model server context
        :return: list of outputs to be sent back to clients
        """
        try:
            model_input = self.preprocess(data)
            model_output = self.inference(model_input)
            return self.postprocess(model_output)
        except Exception as e:
            logging.error(e, exc_info=True)
            context.request_processor.report_status(500, "Unknown inference error")
            return [str(e)] * self._batch_size
class JoshuaPreprocessor(Preprocessor):
    """
    Preprocessor that normalizes, cleans and tokenizes text with the
    Joshua/Moses perl scripts before BPE encoding.
    """

    def __init__(self, bpe_code_file, joshua_path, moses_path, lang):
        super(JoshuaPreprocessor, self).__init__(bpe_code_file)
        self.lang = lang
        self.normalizer = os.path.join(joshua_path, 'normalize.pl')
        self.tokenizer = os.path.join(moses_path, 'tokenizer.perl')
        self.cleaner = os.path.join(moses_path, 'remove-non-printing-char.perl')
        # The bundled scripts may ship without the executable bit set.
        for script in [self.normalizer, self.tokenizer, self.cleaner]:
            os.chmod(script, 0o755)

    def run(self, text):
        """
        Unescape, normalize, clean and tokenize ``text``, then BPE-encode it.

        :param text: raw input string
        :return: BPE-encoded token string
        """
        text = self.unescape(text)
        # BUG FIX: a '|' inside an argv list is NOT a shell pipe -- with
        # shell=False it was passed to normalize.pl as a literal argument,
        # so the cleaner and tokenizer never actually ran. Chain the three
        # processes explicitly, feeding each one's stdout to the next.
        for argv in ([self.normalizer, self.lang],
                     [self.cleaner],
                     [self.tokenizer, '-l', self.lang, '-no-escape', '-q']):
            text = subprocess.run(argv, input=text, encoding='utf-8',
                                  stdout=subprocess.PIPE).stdout
        return self.bpe_encode(text.strip())
def decode_bytes(data):
    """
    Decode a UTF-8 byte array from a file upload into a cleaned string.

    Undecodable bytes are ignored, carriage returns are removed, and
    leading/trailing whitespace is stripped.

    :param data: a UTF-8 encoded byte array
    :return: a cleaned string
    """
    text = data.decode('utf-8', 'ignore')
    return text.replace('\r', '').strip()
def read_sockeye_args(params_path):
    """
    Read command-line arguments stored in a file.

    Each line may hold any number of whitespace-separated tokens; blank
    lines contribute nothing.

    :param params_path: path to the parameters file
    :return: flat list of argument tokens
    """
    with open(params_path) as params_file:
        return [token for line in params_file for token in line.split()]
check_condition(len(sockeye_args.models) == 1, '--skip-topk has no effect for decoding with more than 1 model') if sockeye_args.nbest_size > 1: check_condition(sockeye_args.beam_size >= sockeye_args.nbest_size, 'Size of nbest list (--nbest-size) must be smaller or equal to beam size (--beam-size).') check_condition(sockeye_args.beam_search_drop == const.BEAM_SEARCH_STOP_ALL, '--nbest-size > 1 requires beam search to only stop after all hypotheses are finished ' '(--beam-search-stop all)') if sockeye_args.output_type != const.OUTPUT_HANDLER_NBEST: logging.warning('For nbest translation, output handler must be "%s", overriding option --output-type.', const.OUTPUT_HANDLER_NBEST) sockeye_args.output_type = const.OUTPUT_HANDLER_NBEST log_basic_info(sockeye_args) output_handler = get_output_handler(sockeye_args.output_type, sockeye_args.output, sockeye_args.sure_align_threshold) with ExitStack() as exit_stack: check_condition(len(self.device_ids) == 1, 'translate only supports single device for now') translator_ctx = determine_context(device_ids=self.device_ids, use_cpu=sockeye_args.use_cpu, disable_device_locking=sockeye_args.disable_device_locking, lock_dir=sockeye_args.lock_dir, exit_stack=exit_stack)[0] logging.info('Translate Device: %s', translator_ctx) if sockeye_args.override_dtype == const.DTYPE_FP16: logging.warning('Experimental feature \'--override-dtype float16\' has been used. ' 'This feature may be removed or change its behavior in the future. 
' 'DO NOT USE IT IN PRODUCTION') models, source_vocabs, target_vocab = inference.load_models( context=translator_ctx, max_input_len=sockeye_args.max_input_len, beam_size=sockeye_args.beam_size, batch_size=sockeye_args.batch_size, model_folders=sockeye_args.models, checkpoints=sockeye_args.checkpoints, softmax_temperature=sockeye_args.softmax_temperature, max_output_length_num_stds=sockeye_args.max_output_length_num_stds, decoder_return_logit_inputs=sockeye_args.restrict_lexicon is not None, cache_output_layer_w_b=sockeye_args.restrict_lexicon is not None, override_dtype=sockeye_args.override_dtype, output_scores=output_handler.reports_score()) restrict_lexicon = None if sockeye_args.restrict_lexicon: restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab) restrict_lexicon.load(sockeye_args.restrict_lexicon, k=sockeye_args.restrict_lexicon_topk) store_beam = sockeye_args.output_type == const.OUTPUT_HANDLER_BEAM_STORE self.translator = inference.Translator(context=translator_ctx, ensemble_mode=sockeye_args.ensemble_mode, bucket_source_width=sockeye_args.bucket_width, length_penalty=inference.LengthPenalty( sockeye_args.length_penalty_alpha, sockeye_args.length_penalty_beta), beam_prune=sockeye_args.beam_prune, beam_search_stop=sockeye_args.beam_search_stop, nbest_size=sockeye_args.nbest_size, models=models, source_vocabs=source_vocabs, target_vocab=target_vocab, restrict_lexicon=restrict_lexicon, avoid_list=sockeye_args.avoid_list, store_beam=store_beam, strip_unknown_words=sockeye_args.strip_unknown_words, skip_topk=sockeye_args.skip_topk) def preprocess(self, batch): """ Preprocesses a JSON request for translation. 
:param batch: a list of JSON requests of the form { 'text': input_string } or { 'file': file_data } :return: a list of input strings to translate """ logging.info('preprocess grabbed: %s' % batch) texts = [] for req in batch: data = get_file_data(req) if data: text = decode_bytes(data) else: text = get_text(req) if text: bpe = self.preprocessor.run(text) texts.append(bpe) return texts def inference(self, texts): """ Translates the input data. :param texts: a list of strings to translate :return: a list of translation objects from Sockeye """ logging.info('inference grabbed: %s' % texts) if texts: trans_inputs = [] for t in texts: _input = inference.make_input_from_plain_string(self.sentence_id, t) trans_inputs.append(_input) outputs = self.translator.translate(trans_inputs) if len(outputs) != len(trans_inputs): logging.warning("Number of translation outputs doesn't match the number of inputs") self.sentence_id += len(trans_inputs) return outputs else: self.error = 'Input to inference is empty' return [] def postprocess(self, outputs): """ Converts the translations into a list of JSON responses. 
:param outputs: a list of translation objects from Sockeye :return: a list of translations of the form: { 'translation': output_string } """ logging.info('postprocess grabbed: %s' % outputs) res = [] for t in outputs: output = self.postprocessor.run(t) res.append({'translation': output}) return res _service = SockeyeService() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ================================================ FILE: examples/ssd/README.md ================================================ # Single Shot Multi Object Detection Inference Service In this example, we show how to use a pre-trained Single Shot Multi Object Detection (SSD) Multi model for performing real time inference using MMS The pre-trained model is trained on the [Pascal VOC 2012 dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html) The network is a SSD model built on Resnet50 as base network to extract image features. The model is trained to detect the following entities (classes): ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']. For more details about the model, you can refer [here](https://github.com/apache/incubator-mxnet/tree/master/example/ssd). The inference service would return the response in the format - '[(object_class, xmin, ymin, xmax, ymax)]. Where, xmin, ymin, xmax and ymax are the bounding box coordinates of the detected object. # Objective 1. Demonstrate how to package a a pre-trained MXNet model in MMS 2. Demonstrate how to create custom service with pre-processing and post-processing ## Step 1 - Download the pre-trained SSD Model You will need the model files to use for the export. Check this example's directory in case they're already downloaded. 
We also expect the input to be a binary JPEG image. In the provided mxnet_vision_service.py, you will see the code that takes care of converting binary images to the tensor NDArray format used by MXNet.
to generate the anchor boxes with the following 'Ratios' and 'Sizes':
In this example, you can simply copy them into the ssd folder, and use the provided mxnet_vision_service.py as the model archive entry point.
Awesome! We have successfully exported a pre-trained MXNet model, extended MMS with custom preprocess/postprocess, and started an inference service.
```json [ [ "dog", 399, 128, 570, 290 ], [ "dog", 278, 196, 417, 286 ], [ "cow", 205, 116, 297, 272 ] ] ``` A consumer application can use this response to identify the objects in the input image and their bounding boxes. For better visualization on the input and how we can use the inference output, see below: Input Image ![Street Input Image](../../docs/images/dogs-before.jpg) Output Image ![Street Output Image](../../docs/images/dogs-after.jpg) See [More example outputs](example_outputs.md) # References 1. Adapted code and pre-trained model from - https://github.com/apache/incubator-mxnet/tree/master/example/ssd 2. Learn more about SSD in this tutorial - http://gluon.mxnet.io/chapter08_computer-vision/object-detection.html ================================================ FILE: examples/ssd/example_outputs.md ================================================ # SSD Example Outputs ### Dog Beach ![dog beach](https://farm9.staticflickr.com/8184/8081332083_3a5c242b8b_z_d.jpg) ```bash curl -o dogbeach.jpg https://farm9.staticflickr.com/8184/8081332083_3a5c242b8b_z_d.jpg curl -X POST http://127.0.0.1:8080/ssd/predict -F "data=@dogbeach.jpg" { "prediction": [ [ "person", 203, 213, 248, 347 ], [ "dog", 334, 175, 403, 235 ], [ "person", 109, 211, 144, 291 ], [ "person", 529, 31, 562, 103 ], [ "person", 155, 12, 189, 98 ], [ "horse", 465, 3, 527, 40 ], [ "person", 51, 372, 96, 427 ], [ "dog", 80, 56, 131, 96 ], [ "person", 70, 89, 96, 155 ], [ "cow", 292, 188, 344, 231 ], [ "dog", 294, 186, 349, 231 ] ] } ``` ### 3 Dogs on Beach ![3 dogs on beach](https://farm9.staticflickr.com/8051/8081326814_64756479c6_z_d.jpg) ```bash curl -o 3dogs.jpg https://farm9.staticflickr.com/8051/8081326814_64756479c6_z_d.jpg curl -X POST http://127.0.0.1:8080/ssd/predict -F "data=@3dogs.jpg" { "prediction": [ [ "dog", 399, 128, 570, 290 ], [ "dog", 278, 196, 417, 286 ], [ "cow", 205, 116, 297, 272 ] ] } ``` ### Sailboat 
class SSDService(MXNetVisionService):
    """
    SSD Service to perform real time multi-object detection using pre-trained MXNet SSD model.

    This class extends MXNetVisionService to add custom preprocessing of input and preparing the output.
    Reuses input image transformation functionality of MXNetVisionService.
    """

    def __init__(self):
        super(SSDService, self).__init__()

        # Threshold is used to pick the detection boxes with score > threshold.
        # The detections from this network will be of the format - [[class_id, score, x1, y1, x2, y2]].
        # We pick all detections where 'score > threshold'.
        # You can experiment with different threshold to see the best threshold for the use-case.
        self.threshold = 0.2

        # This is used to save the original input image shape.
        # This is required for preparing the bounding box of the detected object "relative to
        # original input"
        self.input_width = None
        self.input_height = None

    def preprocess(self, batch):
        """
        Input image buffer from data is read into NDArray. Then, resized to expected shape.
        Swaps axes to convert image from BGR format to RGB.
        Returns the preprocessed NDArray as a list for next step, Inference.
        """
        # Read input; the image bytes may arrive under either the 'data' or
        # the 'body' key of the first request in the batch.
        img = batch[0].get("data")
        if img is None:
            img = batch[0].get("body")

        input_image = image.read(img)

        # Save original input image shape.
        # This is required for preparing the bounding box of the detected object relative to
        # original input
        self.input_height = input_image.shape[0]
        self.input_width = input_image.shape[1]

        # Transform input image - resize, BGR to RGB.
        # Reuse MXNetVisionService preprocess to achieve above transformations.
        return super(SSDService, self).preprocess(batch)

    def postprocess(self, data):
        """
        From the detections, prepares the output in the format of list of
        [(object_class, xmin, ymin, xmax, ymax)].

        object_class is name of the object detected. xmin, ymin, xmax, ymax
        provides the bounding box coordinates.

        Example: [(person, 555, 175, 581, 242), (dog, 306, 446, 468, 530)]
        """
        # Read the detections output after forward pass (inference)
        detections = data[0].asnumpy()
        result = []
        # Keep only rows with a non-negative class id (valid detections).
        for i in range(detections.shape[0]):
            det = detections[i, :, :]
            res = det[np.where(det[:, 0] >= 0)[0]]
            result.append(res)

        # Prepare the output
        dets = result[0]
        classes = self.labels
        width = self.input_width  # original input image width
        height = self.input_height  # original input image height
        response = []
        for i in range(dets.shape[0]):
            cls_id = int(dets[i, 0])
            if cls_id >= 0:
                score = dets[i, 1]
                if score > self.threshold:
                    # Box coordinates are scaled by the original image size here,
                    # so the network presumably emits them as fractions of the
                    # input dimensions — confirm against the SSD model output.
                    xmin = int(dets[i, 2] * width)
                    ymin = int(dets[i, 3] * height)
                    xmax = int(dets[i, 4] * width)
                    ymax = int(dets[i, 5] * height)
                    # Fall back to the numeric class id when no label list is
                    # packaged or the id is out of range.
                    class_name = str(cls_id)
                    if classes and len(classes) > cls_id:
                        class_name = classes[cls_id]
                    response.append((class_name, xmin, ymin, xmax, ymax))

        return [response]


# Module-level singleton used by the MMS entry point below.
_service = SSDService()


def handle(data, context):
    # Lazily initialize the service on the first request seen by this worker.
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    return _service.handle(data, context)
#### Start Query service: ```sh cd frontend/server ../gradlew startServer ``` #### Stop Query service: ```sh cd frontend/server ../gradlew killServer ``` ================================================ FILE: frontend/build.gradle ================================================ allprojects { version = '1.0' repositories { jcenter() } apply plugin: 'idea' idea { module { outputDir = file('build/classes/java/main') testOutputDir = file('build/classes/java/test') } } } def javaProjects() { return subprojects.findAll(); } configure(javaProjects()) { apply plugin: 'java' sourceCompatibility = 1.8 targetCompatibility = 1.8 defaultTasks 'jar' apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle") apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle") test { useTestNG() { // suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml } testLogging { showStandardStreams = true events "passed", "skipped", "failed", "standardOut", "standardError" } } test.finalizedBy(project.tasks.jacocoTestReport) compileJava { options.compilerArgs << "-proc:none" << "-Xlint:all,-options,-static" << "-Werror" } jacocoTestCoverageVerification { violationRules { rule { limit { minimum = 0.75 } } } } } ================================================ FILE: frontend/cts/build.gradle ================================================ dependencies { compile project(":server") } jar { manifest { attributes 'Main-Class': 'com.amazonaws.ml.mms.cts.Cts' } includeEmptyDirs = false from configurations.runtime.collect { it.isDirectory() ? it : zipTree(it) } exclude "META-INF/maven/**" exclude "META-INF/INDEX.LIST" exclude "META-INF/MANIFEST*" exclude "META-INF//LICENSE*" exclude "META-INF//NOTICE*" } ================================================ FILE: frontend/cts/src/main/java/com/amazonaws/ml/mms/cts/Cts.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.cts; import com.amazonaws.ml.mms.ModelServer; import com.amazonaws.ml.mms.util.ConfigManager; import io.netty.buffer.Unpooled; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder; import io.netty.handler.codec.http.multipart.MemoryFileUpload; import java.io.File; import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public final class Cts { private byte[] kitten; private byte[] player1; private byte[] player2; private List failedModels; private Cts() { failedModels = new ArrayList<>(); } public static void main(String[] args) { updateLog4jConfiguration(); Cts cts = new Cts(); cts.startTest(); } private void startTest() { ConfigManager.init(new ConfigManager.Arguments()); ConfigManager configManager = ConfigManager.getInstance(); ModelServer server = new ModelServer(configManager); Logger logger = LoggerFactory.getLogger(Cts.class); try { server.start(); kitten = loadImage( "https://s3.amazonaws.com/model-server/inputs/kitten.jpg", "kitten.jpg"); player1 = loadImage( "https://s3.amazonaws.com/multi-model-server/onnx-arcface/input1.jpg", "player1.jpg"); player2 = loadImage( 
"https://s3.amazonaws.com/multi-model-server/onnx-arcface/input2.jpg", "player1.jpg"); HttpClient client = new HttpClient(8081, 8080); for (ModelInfo info : ModelInfo.MODEL_ARCHIVE_1) { runTest(client, info, logger); } for (ModelInfo info : ModelInfo.MODEL_ARCHIVE_04) { runTest(client, info, logger); } } catch (Exception e) { logger.error("", e); } finally { try { server.stop(); } catch (Exception e) { logger.error("", e); } } if (failedModels.isEmpty()) { logger.info("All models passed CTS."); System.exit(0); } else { logger.info("Following models failed CTS:"); for (String model : failedModels) { logger.info(model); } System.exit(1); } } private void runTest(HttpClient client, ModelInfo info, Logger logger) throws HttpPostRequestEncoder.ErrorDataEncoderException, InterruptedException, IOException { String modelName = info.getModelName(); String url = info.getUrl(); int type = info.getType(); logger.info("Testing model: {}={}", modelName, url); if (!client.registerModel(modelName, url)) { failedModels.add(url); return; } try { if (!predict(client, type, modelName)) { failedModels.add(url); } } finally { if (!client.unregisterModel(modelName)) { failedModels.add(url); } } } private boolean predict(HttpClient client, int type, String modelName) throws HttpPostRequestEncoder.ErrorDataEncoderException, InterruptedException, IOException { switch (type) { case ModelInfo.FACE_RECOGNITION: // arcface DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(req, true); MemoryFileUpload body = new MemoryFileUpload( "img1", "img1.jpg", "images/jpeg", null, null, player1.length); body.setContent(Unpooled.copiedBuffer(player1)); encoder.addBodyHttpData(body); body = new MemoryFileUpload( "img2", "img2.jpg", "images/jpeg", null, null, player2.length); body.setContent(Unpooled.copiedBuffer(player2)); encoder.addBodyHttpData(body); return client.predict(modelName, req, 
encoder); case ModelInfo.SEMANTIC_SEGMENTATION: // duc return client.predict(modelName, kitten, "image/jpeg"); case ModelInfo.LANGUAGE_MODELING: // lstm byte[] json = ("[{'input_sentence': 'on the exchange floor as soon" + " as ual stopped trading we for a panic" + " said one top floor trader'}]") .getBytes(StandardCharsets.UTF_8); return client.predict(modelName, json, "application/json"); case ModelInfo.IMAGE_CLASSIFICATION: case ModelInfo.EMOTION_DETECTION: default: return client.predict(modelName, kitten, "image/jpeg"); } } private byte[] loadImage(String path, String fileName) throws IOException { File file = new File(System.getProperty("java.io.tmpdir"), fileName); if (file.exists()) { return FileUtils.readFileToByteArray(file); } byte[] buf = IOUtils.toByteArray(new URL(path)); FileUtils.writeByteArrayToFile(file, buf); return buf; } private static void updateLog4jConfiguration() { System.setProperty("LOG_LOCATION", "logs"); System.setProperty("METRICS_LOCATION", "logs"); } } ================================================ FILE: frontend/cts/src/main/java/com/amazonaws/ml/mms/cts/HttpClient.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.cts; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpClientCodec; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; import java.io.IOException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class HttpClient { static final Logger logger = LoggerFactory.getLogger(HttpClient.class); private int managementPort; private int inferencePort; private Bootstrap bootstrap; private ClientHandler handler; public HttpClient(int managementPort, int inferencePort) { this.managementPort = managementPort; this.inferencePort = inferencePort; handler = new ClientHandler(); bootstrap = bootstrap(handler); } public boolean registerModel(String modelName, String modelUrl) throws InterruptedException, IOException { Channel channel = connect(bootstrap, 
managementPort); String uri = "/models?url=" + URLEncoder.encode(modelUrl, StandardCharsets.UTF_8.name()) + "&model_name=" + modelName + "&initial_workers=1&synchronous=true"; HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri); channel.writeAndFlush(req).sync(); channel.closeFuture().sync(); int statusCode = handler.getStatusCode(); String ret = handler.getContent(); if (statusCode == 200) { logger.info("registerModel: {} success.", modelName); logger.trace(ret); return true; } logger.warn("registerModel: {} failed: {}", modelUrl, ret); return false; } public boolean unregisterModel(String modelName) throws InterruptedException, IOException { Channel channel = connect(bootstrap, managementPort); DefaultFullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/" + URLEncoder.encode(modelName, StandardCharsets.UTF_8.name())); channel.writeAndFlush(req).sync(); channel.closeFuture().sync(); int statusCode = handler.getStatusCode(); String ret = handler.getContent(); if (statusCode == 200) { logger.info("unregisterModel: {} success.", modelName); logger.trace(ret); return true; } logger.warn("unregisterModel: {} failed: {}", modelName, ret); return false; } public boolean predict(String modelName, byte[] content, CharSequence contentType) throws InterruptedException, IOException { Channel channel = connect(bootstrap, inferencePort); DefaultFullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/" + URLEncoder.encode(modelName, StandardCharsets.UTF_8.name())); req.content().writeBytes(content); HttpUtil.setContentLength(req, content.length); req.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); channel.writeAndFlush(req).sync(); channel.closeFuture().sync(); int statusCode = handler.getStatusCode(); String ret = handler.getContent(); if (statusCode == 200) { logger.info("predict: {} success.", modelName); logger.trace(ret); return true; 
} logger.warn("predict: {} failed: {}", modelName, ret); return false; } public boolean predict( String modelName, DefaultFullHttpRequest req, HttpPostRequestEncoder requestEncoder) throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException, IOException { Channel channel = connect(bootstrap, inferencePort); req.setUri("/predictions/" + URLEncoder.encode(modelName, StandardCharsets.UTF_8.name())); channel.writeAndFlush(requestEncoder.finalizeRequest()); if (requestEncoder.isChunked()) { channel.writeAndFlush(requestEncoder).sync(); } channel.closeFuture().sync(); int statusCode = handler.getStatusCode(); String ret = handler.getContent(); if (statusCode == 200) { logger.info("predict: {} success.", modelName); logger.trace(ret); return true; } logger.warn("predict: {} failed: {}", modelName, ret); return false; } private Bootstrap bootstrap(ClientHandler handler) { Bootstrap b = new Bootstrap(); b.group(new NioEventLoopGroup(1)) .channel(NioSocketChannel.class) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10 * 1000) .handler( new ChannelInitializer() { @Override public void initChannel(Channel ch) { ChannelPipeline p = ch.pipeline(); p.addLast(new ReadTimeoutHandler(10 * 60 * 1000)); p.addLast(new HttpClientCodec()); p.addLast(new HttpContentDecompressor()); p.addLast(new ChunkedWriteHandler()); p.addLast(new HttpObjectAggregator(6553600)); p.addLast(handler); } }); return b; } private Channel connect(Bootstrap b, int port) throws InterruptedException { SocketAddress address = new InetSocketAddress("127.0.0.1", port); return b.connect(address).sync().channel(); } @ChannelHandler.Sharable private static final class ClientHandler extends SimpleChannelInboundHandler { private int statusCode; private String content; public ClientHandler() {} public int getStatusCode() { return statusCode; } public String getContent() { return content; } @Override public void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { statusCode = 
msg.status().code(); content = msg.content().toString(StandardCharsets.UTF_8); ctx.close(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { if (cause instanceof IOException) { content = "Failed to connect to MMS"; } else if (cause instanceof ReadTimeoutException) { content = "Request to MMS timeout."; } else { content = cause.getMessage(); if (content == null) { content = "NullPointException"; } logger.error("Unknown exception", cause); } statusCode = 500; ctx.close(); } } } ================================================ FILE: frontend/cts/src/main/java/com/amazonaws/ml/mms/cts/ModelInfo.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/
package com.amazonaws.ml.mms.cts;

/**
 * Describes one model exercised by the compatibility test suite (CTS): the model's name, the
 * URL of its downloadable archive, and an integer task-type tag.
 */
public class ModelInfo {

    // Task-type tags. IMAGE_CLASSIFICATION is the default used by most constructors below.
    public static final int IMAGE_CLASSIFICATION = 1;
    public static final int FACE_RECOGNITION = 2;
    public static final int SEMANTIC_SEGMENTATION = 3;
    public static final int EMOTION_DETECTION = 4;
    public static final int LANGUAGE_MODELING = 5;

    // Base S3 locations: 1.0-format ".mar" archives vs. legacy ".model" archives.
    private static final String S3_PREFIX =
            "https://s3.amazonaws.com/model-server/model_archive_1.0/";
    private static final String S3_PREFIX_LEGACY =
            "https://s3.amazonaws.com/model-server/models/";

    // Models in the 1.0 archive format; each URL is derived as S3_PREFIX + name + ".mar".
    static final ModelInfo[] MODEL_ARCHIVE_1 = {
        new ModelInfo("FERPlus", ModelInfo.EMOTION_DETECTION),
        new ModelInfo("caffenet"),
        new ModelInfo("inception-bn"),
        new ModelInfo("lstm_ptb", ModelInfo.LANGUAGE_MODELING),
        new ModelInfo("nin"),
        new ModelInfo("onnx-arcface-resnet100", ModelInfo.FACE_RECOGNITION),
        new ModelInfo("onnx-duc", ModelInfo.SEMANTIC_SEGMENTATION),
        new ModelInfo("onnx-inception_v1"),
        new ModelInfo("onnx-mobilenet"),
        new ModelInfo("onnx-resnet101v1"),
        new ModelInfo("onnx-resnet101v2"),
        new ModelInfo("onnx-resnet152v1"),
        new ModelInfo("onnx-resnet152v2"),
        new ModelInfo("onnx-resnet18v1"),
        new ModelInfo("onnx-resnet18v2"),
        new ModelInfo("onnx-resnet34v1"),
        new ModelInfo("onnx-resnet34v2"),
        new ModelInfo("onnx-resnet50v1"),
        new ModelInfo("onnx-resnet50v2"),
        new ModelInfo("onnx-squeezenet"),
        new ModelInfo("onnx-vgg16"),
        new ModelInfo("onnx-vgg16_bn"),
        new ModelInfo("onnx-vgg19"),
        new ModelInfo("onnx-vgg19_bn"),
        new ModelInfo("resnet-152"),
        new ModelInfo("resnet-18"),
        new ModelInfo("resnet50_ssd"),
        new ModelInfo("resnext-101-64x4d"),
        new ModelInfo("squeezenet_v1.1"),
        new ModelInfo("squeezenet_v1.2"),
        new ModelInfo("vgg16"),
        new ModelInfo("vgg19")
    };

    // Models in the legacy 0.4 archive format. Entries either carry an explicit archive URL or
    // use the (legacy, name) constructor, which derives the URL from S3_PREFIX_LEGACY.
    static final ModelInfo[] MODEL_ARCHIVE_04 = {
        new ModelInfo(
                "FERPlus",
                "https://s3.amazonaws.com/model-server/models/FERPlus/ferplus.model",
                EMOTION_DETECTION),
        new ModelInfo(true, "caffenet"),
        new ModelInfo(
                "inception-bn",
                "https://s3.amazonaws.com/model-server/models/inception-bn/Inception-BN.model"),
        new ModelInfo(true, "lstm_ptb", LANGUAGE_MODELING),
        new ModelInfo(true, "nin"),
        new ModelInfo(
                "onnx-arcface-resnet100",
                "https://s3.amazonaws.com/multi-model-server/onnx-arcface/arcface-resnet100.model"),
        new ModelInfo(
                "onnx-duc",
                "https://s3.amazonaws.com/multi-model-server/onnx-duc/ResNet_DUC_HDC.model"),
        new ModelInfo(
                "onnx-inception_v1",
                "https://s3.amazonaws.com/model-server/models/onnx-inception_v1/inception_v1.model"),
        new ModelInfo(
                "onnx-mobilenet",
                "https://s3.amazonaws.com/multi-model-server/onnx-mobilenet/mobilenetv2-1.0.model"),
        new ModelInfo(
                "onnx-resnet101v1",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv1/resnet101v1.model"),
        new ModelInfo(
                "onnx-resnet101v2",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv2/resnet101v2.model"),
        new ModelInfo(
                "onnx-resnet152v1",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv1/resnet152v1.model"),
        new ModelInfo(
                "onnx-resnet152v2",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv2/resnet152v2.model"),
        new ModelInfo(
                "onnx-resnet18v1",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv1/resnet18v1.model"),
        new ModelInfo(
                "onnx-resnet18v2",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv2/resnet18v2.model"),
        new ModelInfo(
                "onnx-resnet34v1",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv1/resnet34v1.model"),
        new ModelInfo(
                "onnx-resnet34v2",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv2/resnet34v2.model"),
        new ModelInfo(
                "onnx-resnet50v1",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv1/resnet50v1.model"),
        new ModelInfo(
                "onnx-resnet50v2",
                "https://s3.amazonaws.com/multi-model-server/onnx-resnetv2/resnet50v2.model"),
        new ModelInfo(
                "onnx-squeezenet",
                "https://s3.amazonaws.com/model-server/models/onnx-squeezenet/squeezenet.model"),
        new ModelInfo(
                "onnx-vgg16",
                "https://s3.amazonaws.com/multi-model-server/onnx-vgg16/vgg16.model"),
        new ModelInfo(
                "onnx-vgg16_bn",
                "https://s3.amazonaws.com/multi-model-server/onnx-vgg16_bn/vgg16_bn.model"),
        new ModelInfo(
                "onnx-vgg19",
                "https://s3.amazonaws.com/model-server/models/onnx-vgg19/vgg19.model"),
        new ModelInfo(
                "onnx-vgg19_bn",
                "https://s3.amazonaws.com/multi-model-server/onnx-vgg19_bn/vgg19_bn.model"),
        new ModelInfo(true, "resnet-152"),
        new ModelInfo(true, "resnet-18"),
        new ModelInfo(
                "resnet50_ssd",
                "https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model.model"),
        new ModelInfo(true, "resnext-101-64x4d"),
        new ModelInfo(true, "squeezenet_v1.1"),
        new ModelInfo(true, "vgg16"),
        new ModelInfo(true, "vgg19")
    };

    private String modelName;
    private String url;
    private int type;

    /** 1.0-format image-classification model named {@code modelName}. */
    public ModelInfo(String modelName) {
        this(false, modelName, IMAGE_CLASSIFICATION);
    }

    /** 1.0-format model with an explicit task type. */
    public ModelInfo(String modelName, int type) {
        this(false, modelName, type);
    }

    /** Image-classification model; {@code legacy} selects the 0.4 archive location. */
    public ModelInfo(boolean legacy, String modelName) {
        this(legacy, modelName, IMAGE_CLASSIFICATION);
    }

    /**
     * Builds the download URL from the model name: legacy archives live under
     * S3_PREFIX_LEGACY/name/name.model, 1.0 archives under S3_PREFIX/name.mar.
     */
    public ModelInfo(boolean legacy, String modelName, int type) {
        this.modelName = modelName;
        if (legacy) {
            url = S3_PREFIX_LEGACY + modelName + '/' + modelName + ".model";
        } else {
            url = S3_PREFIX + modelName + ".mar";
        }
        this.type = type;
    }

    /** Image-classification model with an explicit archive URL. */
    public ModelInfo(String modelName, String url) {
        this(modelName, url, IMAGE_CLASSIFICATION);
    }

    /** Fully explicit constructor: name, archive URL and task type. */
    public ModelInfo(String modelName, String url, int type) {
        this.modelName = modelName;
        this.url = url;
        this.type = type;
    }

    public String getModelName() {
        return modelName;
    }

    public String getUrl() {
        return url;
    }

    public int getType() {
        return type;
    }
}

================================================
FILE: frontend/cts/src/main/resources/log4j2.xml
================================================


================================================
FILE: frontend/gradle/wrapper/gradle-wrapper.properties
================================================
#Thu Apr 13 16:20:04 PDT 2017
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-4.9-bin.zip ================================================ FILE: frontend/gradle.properties ================================================ org.gradle.daemon=true org.gradle.jvmargs=-Xmx1024M netty_version=4.1.109.Final slf4j_api_version=1.7.32 slf4j_log4j_version=2.17.1 gson_version=2.8.9 commons_cli_version=1.3.1 testng_version=6.8.1 mms_server_sdk_version=1.0.1 lmax_disruptor_version=3.4.4 ================================================ FILE: frontend/gradlew ================================================ #!/usr/bin/env bash ############################################################################## ## ## Gradle start up script for UN*X ## ############################################################################## # Attempt to set APP_HOME # Resolve links: $0 may be a link PRG="$0" # Need this for relative symlinks. while [ -h "$PRG" ] ; do ls=`ls -ld "$PRG"` link=`expr "$ls" : '.*-> \(.*\)$'` if expr "$link" : '/.*' > /dev/null; then PRG="$link" else PRG=`dirname "$PRG"`"/$link" fi done SAVED="`pwd`" cd "`dirname \"$PRG\"`/" >/dev/null APP_HOME="`pwd -P`" cd "$SAVED" >/dev/null APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS="" # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" warn ( ) { echo "$*" } die ( ) { echo echo "$*" echo exit 1 } # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false case "`uname`" in CYGWIN* ) cygwin=true ;; Darwin* ) darwin=true ;; MINGW* ) msys=true ;; NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar # Determine the Java command to use to start the JVM. 
if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables JAVACMD="$JAVA_HOME/jre/sh/java" else JAVACMD="$JAVA_HOME/bin/java" fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else JAVACMD="java" which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi # Increase the maximum file descriptors if we can. if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then MAX_FD_LIMIT=`ulimit -H -n` if [ $? -eq 0 ] ; then if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then MAX_FD="$MAX_FD_LIMIT" fi ulimit -n $MAX_FD if [ $? -ne 0 ] ; then warn "Could not set maximum file descriptor limit: $MAX_FD" fi else warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" fi fi # For Darwin, add options to specify how the application appears in the dock if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi # For Cygwin, switch paths to Windows format before running java if $cygwin ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` SEP="" for dir in $ROOTDIRSRAW ; do ROOTDIRS="$ROOTDIRS$SEP$dir" SEP="|" done OURCYGPATTERN="(^($ROOTDIRS))" # Add a user-defined pattern to the cygpath arguments if [ "$GRADLE_CYGPATTERN" != "" ] ; then OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" fi # Now convert the arguments - kludge to limit ourselves to /bin/sh i=0 for arg in "$@" ; do CHECK=`echo 
"$arg"|egrep -c "$OURCYGPATTERN" -` CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` else eval `echo args$i`="\"$arg\"" fi i=$((i+1)) done case $i in (0) set -- ;; (1) set -- "$args0" ;; (2) set -- "$args0" "$args1" ;; (3) set -- "$args0" "$args1" "$args2" ;; (4) set -- "$args0" "$args1" "$args2" "$args3" ;; (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; esac fi # Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules function splitJvmOpts() { JVM_OPTS=("$@") } eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong if [[ "$(uname)" == "Darwin" ]] && [[ "$HOME" == "$PWD" ]]; then cd "$(dirname "$0")" fi exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" ================================================ FILE: frontend/gradlew.bat ================================================ @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @rem @rem ########################################################################## @rem Set local scope for the variables with windows NT shell if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 if "%DIRNAME%" == "" set DIRNAME=. 
set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. set DEFAULT_JVM_OPTS= @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if "%ERRORLEVEL%" == "0" goto init echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo. echo Please set the JAVA_HOME variable in your environment to match the echo location of your Java installation. goto fail :findJavaFromJavaHome set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto init echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% echo. echo Please set the JAVA_HOME variable in your environment to match the echo location of your Java installation. goto fail :init @rem Get command-line arguments, handling Windows variants if not "%OS%" == "Windows_NT" goto win9xME_args :win9xME_args @rem Slurp the command line arguments. set CMD_LINE_ARGS= set _SKIP=2 :win9xME_args_slurp if "x%~1" == "x" goto execute set CMD_LINE_ARGS=%* :execute @rem Setup the command line set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% :end @rem End local scope for the variables with windows NT shell if "%ERRORLEVEL%"=="0" goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! 
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega

================================================
FILE: frontend/modelarchive/build.gradle
================================================
// Dependencies of the model-archive module; the *_version properties are defined in
// frontend/gradle.properties.
dependencies {
    compile "commons-io:commons-io:2.6"
    compile "org.slf4j:slf4j-api:${slf4j_api_version}"
    compile "org.apache.logging.log4j:log4j-slf4j-impl:${slf4j_log4j_version}"
    compile "com.google.code.gson:gson:${gson_version}"
    testCompile "commons-cli:commons-cli:${commons_cli_version}"
    testCompile "org.testng:testng:${testng_version}"
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/DownloadModelException.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.archive;

/** Thrown when a model archive cannot be fetched from its remote location. */
public class DownloadModelException extends ModelException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code DownloadModelException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public DownloadModelException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code DownloadModelException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public DownloadModelException(String message, Throwable cause) {
        super(message, cause);
    }
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/Hex.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/
package com.amazonaws.ml.mms.archive;

/** Utility for rendering byte arrays as lowercase hexadecimal strings. */
public final class Hex {

    private static final char[] HEX_CHARS = {
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
    };

    private Hex() {}

    /**
     * Converts the whole array to lowercase hex.
     *
     * <p>NOTE(review): this overload dereferences {@code block.length} before the null check in
     * the 3-arg overload can run, so a null input throws NPE here while the 3-arg overload
     * returns null — confirm which behavior is intended.
     */
    public static String toHexString(byte[] block) {
        return toHexString(block, 0, block.length);
    }

    /**
     * Converts {@code len} bytes starting at {@code offset} to lowercase hex.
     *
     * @param block source bytes; null yields null
     * @param offset first byte to convert
     * @param len number of bytes to convert
     * @throws IllegalArgumentException if the range falls outside the array
     */
    public static String toHexString(byte[] block, int offset, int len) {
        if (block == null) {
            return null;
        }
        if (offset < 0 || offset + len > block.length) {
            throw new IllegalArgumentException("Invalid offset or length.");
        }
        StringBuilder buf = new StringBuilder();
        for (int i = offset, size = offset + len; i < size; i++) {
            // Emit high nibble then low nibble of each byte.
            int high = (block[i] & 0xf0) >> 4;
            int low = block[i] & 0x0f;
            buf.append(HEX_CHARS[high]);
            buf.append(HEX_CHARS[low]);
        }
        return buf.toString();
    }
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/InvalidModelException.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.archive;

/** Thrown when a model archive is present but malformed or unusable. */
public class InvalidModelException extends ModelException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code InvalidModelException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public InvalidModelException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code InvalidModelException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public InvalidModelException(String message, Throwable cause) {
        super(message, cause);
    }
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/LegacyManifest.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/
package com.amazonaws.ml.mms.archive;

import com.google.gson.annotations.SerializedName;
import java.util.Map;

/**
 * Gson binding for the legacy (pre-1.0) MANIFEST.json layout. {@link #migrate()} converts an
 * instance into the current {@link Manifest} format.
 */
public class LegacyManifest {

    // NOTE(review): raw Map here — the extraction may have stripped generic type parameters;
    // migrate() reads values as Object. Confirm against the repository source.
    @SerializedName("Engine")
    private Map engine;

    @SerializedName("Model-Archive-Description")
    private String description;

    @SerializedName("License")
    private String license;

    @SerializedName("Model-Archive-Version")
    private String version;

    @SerializedName("Model-Server")
    private String serverVersion;

    @SerializedName("Model")
    private ModelInfo modelInfo;

    @SerializedName("Created-By")
    private CreatedBy createdBy;

    public LegacyManifest() {}

    public Map getEngine() {
        return engine;
    }

    public void setEngine(Map engine) {
        this.engine = engine;
    }

    public String getDescription() {
        return description;
    }

    public void setDescription(String description) {
        this.description = description;
    }

    public String getLicense() {
        return license;
    }

    public void setLicense(String license) {
        this.license = license;
    }

    public String getVersion() {
        return version;
    }

    public void setVersion(String version) {
        this.version = version;
    }

    public String getServerVersion() {
        return serverVersion;
    }

    public void setServerVersion(String serverVersion) {
        this.serverVersion = serverVersion;
    }

    public ModelInfo getModelInfo() {
        return modelInfo;
    }

    public void setModelInfo(ModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    public CreatedBy getCreatedBy() {
        return createdBy;
    }

    public void setCreatedBy(CreatedBy createdBy) {
        this.createdBy = createdBy;
    }

    /**
     * Converts this legacy manifest to the 1.0 {@link Manifest} shape.
     *
     * <p>The legacy "Service" entry becomes the handler, the original parameters/symbol file
     * names are preserved as extensions, and the model version is pinned to "snapshot".
     *
     * @return the migrated manifest
     * @throws InvalidModelException if the legacy manifest has no Service entry (no handler)
     */
    public Manifest migrate() throws InvalidModelException {
        Manifest manifest = new Manifest();
        manifest.setDescription(description);
        manifest.setLicense(license);
        // Marks the result as migrated from the 0.1 spec.
        manifest.setSpecificationVersion("0.1");
        if (createdBy != null) {
            Manifest.Publisher publisher = new Manifest.Publisher();
            publisher.setAuthor(createdBy.getAuthor());
            publisher.setEmail(createdBy.getEmail());
            manifest.setPublisher(publisher);
        }
        if (engine != null) {
            // Only an MXNet engine entry with a numeric version is carried over.
            Object engineVersion = engine.get("MXNet");
            if (engineVersion instanceof Number) {
                Manifest.Engine eng = new Manifest.Engine();
                eng.setEngineName("MXNet");
                eng.setEngineVersion(engineVersion.toString());
                manifest.setEngine(eng);
            }
        }
        Manifest.Model model = new Manifest.Model();
        model.setModelName(modelInfo.getModelName());
        model.setDescription(modelInfo.getDescription());
        model.setHandler(modelInfo.getService());
        model.setModelVersion("snapshot");
        model.addExtension("parametersFile", modelInfo.getParameters());
        model.addExtension("symbolFile", modelInfo.getSymbol());
        manifest.setModel(model);
        if (model.getHandler() == null) {
            throw new InvalidModelException("Missing Service entry in MANIFEST.json");
        }
        return manifest;
    }

    /** Gson binding for the legacy "Created-By" section. */
    public static final class CreatedBy {

        @SerializedName("Author")
        private String author;

        @SerializedName("Author-Email")
        private String email;

        public CreatedBy() {}

        public String getAuthor() {
            return author;
        }

        public void setAuthor(String author) {
            this.author = author;
        }

        public String getEmail() {
            return email;
        }

        public void setEmail(String email) {
            this.email = email;
        }
    }

    /** Gson binding for the legacy "Model" section. */
    public static final class ModelInfo {

        @SerializedName("Parameters")
        private String parameters;

        @SerializedName("Symbol")
        private String symbol;

        @SerializedName("Description")
        private String description;

        @SerializedName("Model-Name")
        private String modelName;

        @SerializedName("Service")
        private String service;

        public ModelInfo() {}

        public String getParameters() {
            return parameters;
        }

        public void setParameters(String parameters) {
            this.parameters = parameters;
        }

        public String getSymbol() {
            return symbol;
        }

        public void setSymbol(String symbol) {
            this.symbol = symbol;
        }

        public String getDescription() {
            return description;
        }

        public void setDescription(String description) {
            this.description = description;
        }

        public String getModelName() {
            return modelName;
        }

        public void setModelName(String modelName) {
            this.modelName = modelName;
        }

        public String getService() {
            return service;
        }

        public void setService(String service) {
            this.service = service;
        }
    }
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/Manifest.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.archive;

import com.google.gson.annotations.SerializedName;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Gson binding for the 1.0 model-archive manifest (MAR-INF/MANIFEST.json): archive metadata,
 * engine, model description and publisher.
 */
public class Manifest {

    private String specificationVersion;
    private String implementationVersion;
    private String description;
    private String modelServerVersion;
    private String license;
    private RuntimeType runtime;
    private Engine engine;
    private Model model;
    private Publisher publisher;

    /** Initializes a manifest with the 1.0 defaults (Apache 2.0 license, python runtime). */
    public Manifest() {
        specificationVersion = "1.0";
        implementationVersion = "1.0";
        modelServerVersion = "1.0";
        license = "Apache 2.0";
        runtime = RuntimeType.PYTHON;
        model = new Model();
    }

    public String getSpecificationVersion() {
        return specificationVersion;
    }

    public void setSpecificationVersion(String specificationVersion) {
        this.specificationVersion = specificationVersion;
    }

    public String getImplementationVersion() {
        return implementationVersion;
    }

    public void setImplementationVersion(String implementationVersion) {
        this.implementationVersion = implementationVersion;
    }

    public String getDescription() {
        return description;
    }

    public void setDescription(String description) {
        this.description = description;
    }

    public String getModelServerVersion() {
        return modelServerVersion;
    }

    public void setModelServerVersion(String modelServerVersion) {
        this.modelServerVersion = modelServerVersion;
    }

    public String getLicense() {
        return license;
    }

    public void setLicense(String license) {
        this.license = license;
    }

    public RuntimeType getRuntime() {
        return runtime;
    }

    public void setRuntime(RuntimeType runtime) {
        this.runtime = runtime;
    }

    public Engine getEngine() {
        return engine;
    }

    public void setEngine(Engine engine) {
        this.engine = engine;
    }

    public Model getModel() {
        return model;
    }

    public void setModel(Model model) {
        this.model = model;
    }

    public Publisher getPublisher() {
        return publisher;
    }

    public void setPublisher(Publisher publisher) {
        this.publisher = publisher;
    }

    /** Author name and e-mail of the archive's publisher. */
    public static final class Publisher {

        private String author;
        private String email;

        public Publisher() {}

        public String getAuthor() {
            return author;
        }

        public void setAuthor(String author) {
            this.author = author;
        }

        public String getEmail() {
            return email;
        }

        public void setEmail(String email) {
            this.email = email;
        }
    }

    /** Deep-learning engine (e.g. MXNet) the model was built for, with its version. */
    public static final class Engine {

        private String engineName;
        private String engineVersion;

        public Engine() {}

        public String getEngineName() {
            return engineName;
        }

        public void setEngineName(String engineName) {
            this.engineName = engineName;
        }

        public String getEngineVersion() {
            return engineVersion;
        }

        public void setEngineVersion(String engineVersion) {
            this.engineVersion = engineVersion;
        }
    }

    /** Model identity (name/version), handler entry point, and free-form extensions. */
    public static final class Model {

        private String modelName;
        private String description;
        private String modelVersion;
        // NOTE(review): raw Map — extraction may have stripped generic parameters; entries are
        // written as String->Object in addExtension(). Confirm against the repository source.
        private Map extensions;
        private String handler;

        public Model() {}

        public String getModelName() {
            return modelName;
        }

        public void setModelName(String modelName) {
            this.modelName = modelName;
        }

        public String getDescription() {
            return description;
        }

        public void setDescription(String description) {
            this.description = description;
        }

        public String getModelVersion() {
            return modelVersion;
        }

        public void setModelVersion(String modelVersion) {
            this.modelVersion = modelVersion;
        }

        public Map getExtensions() {
            return extensions;
        }

        public void setExtensions(Map extensions) {
            this.extensions = extensions;
        }

        /** Adds a key/value extension, lazily creating the (insertion-ordered) map. */
        public void addExtension(String key, Object value) {
            if (extensions == null) {
                extensions = new LinkedHashMap<>();
            }
            extensions.put(key, value);
        }

        public String getHandler() {
            return handler;
        }

        public void setHandler(String handler) {
            this.handler = handler;
        }
    }

    /** Python runtime flavor declared in the manifest; serialized as lowercase strings. */
    public enum RuntimeType {
        @SerializedName("python")
        PYTHON("python"),
        @SerializedName("python2")
        PYTHON2("python2"),
        @SerializedName("python3")
        PYTHON3("python3");

        String value;

        RuntimeType(String value) {
            this.value = value;
        }

        public String getValue() {
            return value;
        }

        /**
         * Looks up the enum constant for a serialized value.
         *
         * @throws IllegalArgumentException if {@code value} matches no runtime
         */
        public static RuntimeType fromValue(String value) {
            for (RuntimeType runtime : values()) {
                if (runtime.value.equals(value)) {
                    return runtime;
                }
            }
            throw new IllegalArgumentException("Invalid RuntimeType value: " + value);
        }
    }
}

================================================
FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/ModelArchive.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.archive; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.JsonObject; import com.google.gson.JsonParseException; import com.google.gson.JsonParser; import com.google.gson.JsonPrimitive; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.SocketTimeoutException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.security.DigestInputStream; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Enumeration; import java.util.regex.Pattern; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; import java.util.zip.ZipOutputStream; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ModelArchive { private static final Logger logger = LoggerFactory.getLogger(ModelArchive.class); public static final Gson GSON = new GsonBuilder().setPrettyPrinting().create(); private static final Pattern URL_PATTERN = Pattern.compile("http(s)?://.*", Pattern.CASE_INSENSITIVE); private static final String MANIFEST_FILE = "MANIFEST.json"; private Manifest manifest; private String url; private File modelDir; private boolean extracted; public ModelArchive(Manifest manifest, String url, File modelDir, boolean extracted) { this.manifest = manifest; this.url = url; this.modelDir = modelDir; this.extracted = extracted; } public static ModelArchive downloadModel(String modelStore, String url) throws ModelException, IOException { if (URL_PATTERN.matcher(url).matches()) { File modelDir = download(url); return load(url, modelDir, true); } if (url.contains("..")) { throw new ModelNotFoundException("Relative path is not allowed in 
url: " + url); } if (modelStore == null) { throw new ModelNotFoundException("Model store has not been configured."); } File modelLocation = new File(modelStore, url); if (!modelLocation.exists()) { throw new ModelNotFoundException("Model not found in model store: " + url); } if (modelLocation.isFile()) { try (InputStream is = new FileInputStream(modelLocation)) { File unzipDir = unzip(is, null); return load(url, unzipDir, true); } } return load(url, modelLocation, false); } public static void migrate(File legacyModelFile, File destination) throws InvalidModelException, IOException { boolean failed = true; try (ZipFile zip = new ZipFile(legacyModelFile); ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(destination))) { ZipEntry manifestEntry = zip.getEntry(MANIFEST_FILE); if (manifestEntry == null) { throw new InvalidModelException("Missing manifest file in model archive."); } InputStream is = zip.getInputStream(manifestEntry); Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8); JsonObject json = JsonParser.parseReader(reader).getAsJsonObject(); JsonPrimitive version = json.getAsJsonPrimitive("specificationVersion"); Manifest manifest; if (version != null && "1.0".equals(version.getAsString())) { throw new InvalidModelException("model archive is already in 1.0 version."); } LegacyManifest legacyManifest = GSON.fromJson(json, LegacyManifest.class); manifest = legacyManifest.migrate(); zos.putNextEntry(new ZipEntry("MAR-INF/")); zos.putNextEntry(new ZipEntry("MAR-INF/" + MANIFEST_FILE)); zos.write(GSON.toJson(manifest).getBytes(StandardCharsets.UTF_8)); Enumeration en = zip.entries(); while (en.hasMoreElements()) { ZipEntry entry = en.nextElement(); String name = entry.getName(); if (MANIFEST_FILE.equalsIgnoreCase(name) || name.startsWith(".")) { continue; } zos.putNextEntry(new ZipEntry(name)); if (!entry.isDirectory()) { IOUtils.copy(zip.getInputStream(entry), zos); } } failed = false; } finally { if (failed) { 
FileUtils.deleteQuietly(destination); } } } private static File download(String path) throws ModelException, IOException { HttpURLConnection conn; try { URL url = new URL(path); conn = (HttpURLConnection) url.openConnection(); if (conn.getResponseCode() != HttpURLConnection.HTTP_OK) { throw new DownloadModelException( "Failed to download model from: " + path + ", code: " + conn.getResponseCode()); } } catch (MalformedURLException | RuntimeException e) { // URLConnection may throw raw RuntimeException if port is out of range. throw new ModelNotFoundException("Invalid model url: " + path, e); } catch (IOException e) { throw new DownloadModelException("Failed to download model from: " + path, e); } try { String eTag = conn.getHeaderField("ETag"); File tmpDir = new File(System.getProperty("java.io.tmpdir")); File modelDir = new File(tmpDir, "models"); FileUtils.forceMkdir(modelDir); if (eTag != null) { if (eTag.startsWith("\"") && eTag.endsWith("\"") && eTag.length() > 2) { eTag = eTag.substring(1, eTag.length() - 1); } File dir = new File(modelDir, eTag); if (dir.exists()) { logger.info("model folder already exists: {}", eTag); return dir; } } return unzip(conn.getInputStream(), eTag); } catch (SocketTimeoutException e) { throw new DownloadModelException("Download model timeout: " + path, e); } } private static ModelArchive load(String url, File dir, boolean extracted) throws InvalidModelException, IOException { boolean failed = true; try { File manifestFile = new File(dir, "MAR-INF/" + MANIFEST_FILE); Manifest manifest; if (manifestFile.exists()) { // Must be MMS 1.0 or later manifest = readFile(manifestFile, Manifest.class); } else { manifestFile = new File(dir, MANIFEST_FILE); boolean nested = false; if (!manifestFile.exists()) { // Found MANIFEST.json in top level; manifestFile = findFile(dir, MANIFEST_FILE, true); // for 0.1 model archive nested = true; } if (manifestFile == null) { // Must be 1.0 manifest = new Manifest(); } else { // 0.1 model may have extra 
parent directory LegacyManifest legacyManifest = readFile(manifestFile, LegacyManifest.class); manifest = legacyManifest.migrate(); File modelDir = manifestFile.getParentFile(); if (extracted && nested) { // Move all file to top level, so we can clean up properly. moveToTopLevel(modelDir, dir); } } } failed = false; return new ModelArchive(manifest, url, dir, extracted); } finally { if (extracted && failed) { FileUtils.deleteQuietly(dir); } } } private static T readFile(File file, Class type) throws InvalidModelException, IOException { try (Reader r = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)) { return GSON.fromJson(r, type); } catch (JsonParseException e) { throw new InvalidModelException("Failed to parse signature.json.", e); } } private static File findFile(File dir, String fileName, boolean recursive) { File[] list = dir.listFiles(); if (list == null) { return null; } for (File file : list) { if (recursive && file.isDirectory()) { File f = findFile(file, fileName, false); if (f != null) { return f; } } else if (file.getName().equalsIgnoreCase(fileName)) { return file; } } return null; } private static void moveToTopLevel(File from, File to) throws IOException { File[] list = from.listFiles(); if (list != null) { for (File file : list) { if (file.isDirectory()) { FileUtils.moveDirectoryToDirectory(file, to, false); } else { FileUtils.moveFileToDirectory(file, to, false); } } } } public static File unzip(InputStream is, String eTag) throws IOException { File tmpDir = FileUtils.getTempDirectory(); File modelDir = new File(tmpDir, "models"); FileUtils.forceMkdir(modelDir); File tmp = File.createTempFile("model", ".download"); FileUtils.forceDelete(tmp); FileUtils.forceMkdir(tmp); MessageDigest md; try { md = MessageDigest.getInstance("SHA1"); } catch (NoSuchAlgorithmException e) { throw new AssertionError(e); } ZipUtils.unzip(new DigestInputStream(is, md), tmp); if (eTag == null) { eTag = Hex.toHexString(md.digest()); } File dir = new 
File(modelDir, eTag); if (dir.exists()) { FileUtils.deleteDirectory(tmp); logger.info("model folder already exists: {}", eTag); return dir; } FileUtils.moveDirectory(tmp, dir); return dir; } public void validate() throws InvalidModelException { Manifest.Model model = manifest.getModel(); try { if (model == null) { throw new InvalidModelException("Missing Model entry in manifest file."); } if (model.getModelName() == null) { throw new InvalidModelException("Model name is not defined."); } if (model.getHandler() == null) { throw new InvalidModelException("Model handler is not defined."); } if (manifest.getRuntime() == null) { throw new InvalidModelException("Runtime is not defined or invalid."); } if (manifest.getEngine() != null && manifest.getEngine().getEngineName() == null) { throw new InvalidModelException("engineName is required in ."); } } catch (InvalidModelException e) { clean(); throw e; } } public String getHandler() { return manifest.getModel().getHandler(); } public Manifest getManifest() { return manifest; } public String getUrl() { return url; } public File getModelDir() { return modelDir; } public String getModelName() { return manifest.getModel().getModelName(); } public void clean() { if (url != null && extracted) { FileUtils.deleteQuietly(modelDir); } } } ================================================ FILE: frontend/modelarchive/src/main/java/com/amazonaws/ml/mms/archive/ModelException.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * Base checked exception for all model-archive errors. Subclasses distinguish
 * specific failure modes (e.g. {@link ModelNotFoundException}).
 */
public class ModelException extends Exception {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code ModelException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public ModelException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code ModelException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public ModelException(String message, Throwable cause) {
        super(message, cause);
    }
}

/**
 * Thrown when a requested model cannot be located — e.g. an invalid URL, a
 * missing model-store entry, or an unconfigured model store.
 */
public class ModelNotFoundException extends ModelException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code ModelNotFoundException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public ModelNotFoundException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code ModelNotFoundException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public ModelNotFoundException(String message, Throwable cause) {
        super(message, cause);
    }
}
*/ package com.amazonaws.ml.mms.archive; import java.io.File; import java.io.FileFilter; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.zip.ZipEntry; import java.util.zip.ZipInputStream; import java.util.zip.ZipOutputStream; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; public final class ZipUtils { private ZipUtils() {} public static void zip(File src, File dest, boolean includeRootDir) throws IOException { int prefix = src.getCanonicalPath().length(); if (includeRootDir) { prefix -= src.getName().length(); } try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(dest))) { addToZip(prefix, src, null, zos); } } public static void unzip(File src, File dest) throws IOException { unzip(new FileInputStream(src), dest); } public static void unzip(InputStream is, File dest) throws IOException { try (ZipInputStream zis = new ZipInputStream(is)) { ZipEntry entry; while ((entry = zis.getNextEntry()) != null) { String name = entry.getName(); File file = new File(dest, name); if (entry.isDirectory()) { FileUtils.forceMkdir(file); } else { File parentFile = file.getParentFile(); FileUtils.forceMkdir(parentFile); try (OutputStream os = new FileOutputStream(file)) { IOUtils.copy(zis, os); } } } } } public static void addToZip(int prefix, File file, FileFilter filter, ZipOutputStream zos) throws IOException { String name = file.getCanonicalPath().substring(prefix); if (name.startsWith("/")) { name = name.substring(1); } if (file.isDirectory()) { if (!name.isEmpty()) { ZipEntry entry = new ZipEntry(name + '/'); zos.putNextEntry(entry); } File[] files = file.listFiles(filter); if (files != null) { for (File f : files) { addToZip(prefix, f, filter, zos); } } } else if (file.isFile()) { ZipEntry entry = new ZipEntry(name); zos.putNextEntry(entry); try (FileInputStream fis = new FileInputStream(file)) { IOUtils.copy(fis, 
zos); } } } } ================================================ FILE: frontend/modelarchive/src/test/java/com/amazonaws/ml/mms/archive/CoverageTest.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.archive; import com.amazonaws.ml.mms.test.TestHelper; import java.io.IOException; import org.testng.annotations.Test; public class CoverageTest { @Test public void test() throws IOException, ClassNotFoundException { TestHelper.testGetterSetters(ModelArchive.class); } } ================================================ FILE: frontend/modelarchive/src/test/java/com/amazonaws/ml/mms/archive/Exporter.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.archive; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import java.io.File; import java.io.FileFilter; import java.io.FileOutputStream; import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; public final class Exporter { private static final Gson GSON = new GsonBuilder().setPrettyPrinting().create(); private Exporter() {} public static void main(String[] args) { String jarName = getJarName(); Options options = Config.getOptions(); DefaultParser parser = new DefaultParser(); try { if (args.length == 0 || args[0].equalsIgnoreCase("-h") || args[0].equalsIgnoreCase("--help")) { printHelp("java -jar " + jarName + " ", options); return; } CommandLine cmd = parser.parse(options, args, null, false); List cmdArgs = cmd.getArgList(); if (cmdArgs.isEmpty()) { printHelp("java -jar " + jarName + " ", options); return; } Config config = new Config(cmd); String action = cmdArgs.get(0); if (!"export".equalsIgnoreCase(action)) { printHelp("java -jar " + jarName + " ", options); return; } String modelName = config.getModelName(); if (!modelName.matches("[A-Za-z][A-Za-z0-9_\\-.]+")) { System.err.println( "model-name must starts with letter and only allows alphanumeric characters, hyphens, underscore or dot."); return; } File modelPath = new File(config.getModelPath()).getCanonicalFile(); if (!modelPath.exists()) { System.err.println("model-path not found: " + modelName); return; } String output = config.getOutputFile(); File outputFile; if (output == null) { outputFile = new File(modelPath.getParentFile(), modelName + ".mar"); } else { outputFile = new 
File(output); } final String fileName = modelPath.getName(); if (modelPath.isFile() && fileName.endsWith(".model") || fileName.endsWith(".mar")) { ModelArchive.migrate(modelPath, outputFile); return; } if (!modelPath.isDirectory()) { System.err.println("model-path should be a directory or model archive file."); return; } File[] files = modelPath.listFiles(); if (files == null) { throw new AssertionError( "Failed list files in folder: " + modelPath.getAbsolutePath()); } Manifest manifest = new Manifest(); Manifest.Model model = new Manifest.Model(); manifest.setModel(model); String runtime = config.getRuntime(); if (runtime != null) { manifest.setRuntime(Manifest.RuntimeType.fromValue(runtime)); } File symbolFile = findUniqueFile(files, "-symbol.json"); if (symbolFile != null) { model.addExtension("symbolFile", symbolFile.getName()); } File paramsFile = findUniqueFile(files, ".params"); if (paramsFile != null) { model.addExtension("parametersFile", paramsFile.getName()); } String handler = config.getHandler(); if (handler == null) { File serviceFile = findUniqueFile(files, "_service.py"); if (serviceFile != null) { model.setHandler(serviceFile.getName()); } } else { Manifest.RuntimeType runtimeType = manifest.getRuntime(); if (runtimeType == Manifest.RuntimeType.PYTHON || runtimeType == Manifest.RuntimeType.PYTHON2 || runtimeType == Manifest.RuntimeType.PYTHON3) { String[] tokens = handler.split(":"); File serviceFile = new File(modelPath, tokens[0]); if (serviceFile.exists()) { System.err.println("handler file is not found in: " + modelPath); return; } } model.setHandler(handler); } try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(outputFile))) { zos.putNextEntry(new ZipEntry("MANIFEST.json")); zos.write(GSON.toJson(manifest).getBytes(StandardCharsets.UTF_8)); int prefix = modelPath.getCanonicalPath().length(); FileFilter filter = pathname -> { if (pathname.isHidden()) { return false; } String name = pathname.getName(); return 
!"MANIFEST.json".equalsIgnoreCase(name); }; for (File file : files) { if (filter.accept(file)) { ZipUtils.addToZip(prefix, file, filter, zos); } } } catch (IOException e) { e.printStackTrace(); if (!outputFile.delete()) { outputFile.deleteOnExit(); } } } catch (InvalidModelException | IOException e) { System.err.println(e.getMessage()); } catch (ParseException e) { System.err.println(e.getMessage()); printHelp("java -jar " + jarName + " ", options); } } private static void printHelp(String message, Options options) { HelpFormatter formatter = new HelpFormatter(); formatter.setLeftPadding(1); formatter.setWidth(120); formatter.printHelp(message, options); } private static String getJarName() { URL url = Exporter.class.getProtectionDomain().getCodeSource().getLocation(); String path = url.getPath(); if ("file".equalsIgnoreCase(url.getProtocol())) { File file = new File(path); if (path.toLowerCase().endsWith(".jar")) { // we only support jar file for now return file.getName(); } } return null; } private static File findUniqueFile(File[] list, String extension) throws InvalidModelException { File ret = null; for (File file : list) { if (file.getName().endsWith(extension)) { if (ret != null) { throw new InvalidModelException( "Multiple " + extension + " file found in the path."); } ret = file; } } return ret; } private static final class Config { private String modelName; private String modelPath; private String handler; private String runtime; private String outputFile; public Config(CommandLine cmd) { modelName = cmd.getOptionValue("model-name"); modelPath = cmd.getOptionValue("model-path"); handler = cmd.getOptionValue("handler"); runtime = cmd.getOptionValue("runtime"); handler = cmd.getOptionValue("handler"); outputFile = cmd.getOptionValue("output-file"); } public static Options getOptions() { Options options = new Options(); options.addOption( Option.builder("n") .longOpt("model-name") .hasArg() .required() .argName("MODEL_NAME") .desc( "Exported model name. 
Exported file will be named as model-name.model and saved in current working directory.") .build()); options.addOption( Option.builder("p") .longOpt("model-path") .hasArg() .required() .argName("MODEL_PATH") .desc( "Path to the folder containing model related files or legacy model archive. Signature file is required.") .build()); options.addOption( Option.builder("r") .longOpt("runtime") .hasArg() .argName("RUNTIME") .desc( "The runtime environment for the MMS to execute your model custom code, default python2.7") .build()); options.addOption( Option.builder("e") .longOpt("engine") .hasArg() .argName("engine") .desc("The ML framework for your model, default MXNet") .build()); options.addOption( Option.builder("s") .longOpt("handler") .hasArg() .argName("HANDLER") .desc( "The entry-point within your code that MMS can call to begin execution.") .build()); options.addOption( Option.builder("o") .longOpt("output-file") .hasArg() .argName("OUTPUT_FILE") .desc("Output model archive file path.") .build()); return options; } public String getModelName() { return modelName; } public void setModelName(String modelName) { this.modelName = modelName; } public String getModelPath() { return modelPath; } public void setModelPath(String modelPath) { this.modelPath = modelPath; } public String getHandler() { return handler; } public void setHandler(String handler) { this.handler = handler; } public String getOutputFile() { return outputFile; } public void setOutputFile(String outputFile) { this.outputFile = outputFile; } public String getRuntime() { return runtime; } public void setRuntime(String runtime) { this.runtime = runtime; } } } ================================================ FILE: frontend/modelarchive/src/test/java/com/amazonaws/ml/mms/archive/ModelArchiveTest.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). 
/**
 * Integration test covering the full archive lifecycle: loading 0.1 and 1.0
 * models from a store, downloading a model over HTTPS (requires network access
 * to S3), exporting via {@link Exporter}, and migrating a legacy archive.
 * The steps are order-dependent — later assertions reuse earlier artifacts.
 */
public class ModelArchiveTest {

    // Shared export/migration target, recreated fresh for each run.
    private File output;

    @BeforeTest
    public void beforeTest() {
        // Remove leftovers from prior runs, including the temp model cache,
        // so eTag-based caching doesn't short-circuit the download path.
        output = new File("build/tmp/test/noop.mar");
        FileUtils.deleteQuietly(output);
        FileUtils.deleteQuietly(new File("build/tmp/test/noop"));
        FileUtils.deleteQuietly(new File("build/tmp/test/noop-v0.1.mar"));
        File tmp = FileUtils.getTempDirectory();
        FileUtils.deleteQuietly(new File(tmp, "models"));
    }

    @Test
    public void test() throws ModelException, IOException {
        String modelStore = "src/test/resources/models";

        // load 0.1 model from model folder
        ModelArchive archive = ModelArchive.downloadModel(modelStore, "noop-v0.1");
        Assert.assertEquals(archive.getModelName(), "noop_v0.1");

        // load 0.1 model from model archive (zip the folder first, then load the zip)
        File src = new File(modelStore, "noop-v0.1");
        File target = new File("build/tmp/test", "noop-v0.1.mar");
        ZipUtils.zip(src, target, false);
        archive = ModelArchive.downloadModel("build/tmp/test", "noop-v0.1.mar");
        Assert.assertEquals(archive.getModelName(), "noop_v0.1");

        // load model for s3 — NOTE(review): requires network; fails offline
        archive =
                ModelArchive.downloadModel(
                        modelStore,
                        "https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model");
        Assert.assertEquals(archive.getModelName(), "squeezenet_v1.1");

        // test export: re-export the just-downloaded model dir via the CLI entry point
        String[] args = new String[4];
        args[0] = "export";
        args[1] = "--model-name=noop";
        args[2] = "--model-path=" + archive.getModelDir();
        args[3] = "--output-file=" + output.getAbsolutePath();
        Exporter.main(args);
        Assert.assertTrue(output.exists());

        // migrate the legacy 0.1 archive created above into a 1.0 archive
        FileUtils.forceDelete(output);
        ModelArchive.migrate(target, output);
        Assert.assertTrue(output.exists());

        // load 1.0 model
        archive = ModelArchive.downloadModel(modelStore, "noop-v1.0");
        Assert.assertEquals(archive.getModelName(), "noop");
    }
}
*/ package com.amazonaws.ml.mms.test; import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.net.URL; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Enumeration; import java.util.List; import java.util.jar.JarEntry; import java.util.jar.JarFile; import org.apache.commons.io.FileUtils; public final class TestHelper { private TestHelper() {} public static void testGetterSetters(Class baseClass) throws IOException, ClassNotFoundException { List> list = getClasses(baseClass); for (Class clazz : list) { Constructor[] constructors = clazz.getConstructors(); Object obj = null; for (Constructor con : constructors) { try { Class[] types = con.getParameterTypes(); Object[] args = new Object[types.length]; for (int i = 0; i < args.length; ++i) { args[i] = getMockValue(types[i]); } obj = con.newInstance(args); } catch (ReflectiveOperationException ignore) { // ignore } } if (obj == null) { continue; } Method[] methods = clazz.getMethods(); for (Method method : methods) { String methodName = method.getName(); int parameterCount = method.getParameterCount(); try { if (parameterCount == 0 && methodName.startsWith("get") || methodName.startsWith("is")) { method.invoke(obj); } else if (methodName.startsWith("set") && parameterCount == 1) { Class type = method.getParameterTypes()[0]; method.invoke(obj, getMockValue(type)); } } catch (ReflectiveOperationException ignore) { // ignore } } } } private static List> getClasses(Class clazz) throws IOException, ClassNotFoundException { URL url = clazz.getProtectionDomain().getCodeSource().getLocation(); String path = url.getPath(); if (!"file".equalsIgnoreCase(url.getProtocol())) { return Collections.emptyList(); } List> classList = new ArrayList<>(); File classPath = new File(path); if (classPath.isDirectory()) { String rootPath = classPath.getCanonicalPath(); String[] filters = new String[] {"class"}; 
Collection files = FileUtils.listFiles(classPath, filters, true); for (File file : files) { String fileName = file.getCanonicalPath(); fileName = fileName.substring(rootPath.length() + 1); fileName = fileName.substring(0, fileName.lastIndexOf(".")); fileName = fileName.replace(File.separatorChar, '.'); classList.add(Class.forName(fileName)); } } else if (path.toLowerCase().endsWith(".jar")) { try (JarFile jarFile = new JarFile(path)) { Enumeration en = jarFile.entries(); while (en.hasMoreElements()) { JarEntry entry = en.nextElement(); String fileName = entry.getName(); if (fileName.endsWith(".class")) { fileName = fileName.substring(0, fileName.lastIndexOf(".")); fileName = fileName.replace('/', '.'); classList.add(Class.forName(fileName)); } } } } return classList; } private static Object getMockValue(Class type) { if (type.isPrimitive()) { if (type == Boolean.TYPE) { return Boolean.TRUE; } if (type == Character.TYPE) { return '0'; } if (type == Byte.TYPE) { return (byte) 0; } if (type == Short.TYPE) { return (short) 0; } if (type == Integer.TYPE) { return 0; } if (type == Long.TYPE) { return 0L; } if (type == Float.TYPE) { return 0f; } if (type == Double.TYPE) { return 0d; } } return null; } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/custom-return-code/MAR-INF/MANIFEST.json ================================================ { "specificationVersion": "1.0", "implementationVersion": "1.0", "description": "noop v1.0", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "pred-custom-return-code", "description": "Tests for custom return code", "modelVersion": "1.0", "handler": "service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/custom-return-code/service.py ================================================ # Copyright 
from mms.service import PredictionException


def handle(data, ctx):
    """Test handler exercising custom error codes.

    A ping/initialization call passes ``data=None`` and succeeds; any real
    prediction request raises a PredictionException with status 599 so the
    frontend's custom-return-code path can be verified.
    """
    if data is None:
        return ["OK"]
    raise PredictionException("Some Prediction Error", 599)
"""
InvalidService defines an invalid model handler for testing purposes.

Fixture used to verify that the server reports a per-request error inside
an otherwise successful batch.
"""


def handle(data, context):
    """Mark request index 0 as failed (HTTP 507) whenever real data arrives.

    :param data: request batch; falsy for initialization calls
    :param context: model server context used to set the per-request status
    :return: a one-element "Invalid response" payload in all cases
    """
    if not data:
        # Initialization / ping path: nothing to flag.
        return ["Invalid response"]
    # Flag only the first item, to test partial-batch error reporting.
    context.set_response_status(code=507, idx=0)
    return ["Invalid response"]
""" def handle(data, context): raise RuntimeError("Initialize failure.") ================================================ FILE: frontend/modelarchive/src/test/resources/models/invalid/MAR-INF/MANIFEST.json ================================================ { "specificationVersion": "1.0", "implementationVersion": "1.0", "description": "invalid model", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "invalid", "description": "invalid model", "modelVersion": "1.0", "handler": "invalid_service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/invalid/invalid_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ InvalidService defines a invalid model handler for testing purpose. 
""" def handle(data, context): return "Invalid response" ================================================ FILE: frontend/modelarchive/src/test/resources/models/loading-memory-error/MAR-INF/MANIFEST.json ================================================ { "specificationVersion": "1.0", "implementationVersion": "1.0", "description": "noop v1.0", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "mem-err", "description": "Tests for memory error", "modelVersion": "1.0", "handler": "service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/loading-memory-error/service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. def handle(ctx, data): # Python raises MemoryError when the python program goes out of memory. MMS expects this error from the handler # if the handlers can not allocate any further memory. 
raise MemoryError("Throwing memory error") ================================================ FILE: frontend/modelarchive/src/test/resources/models/logging/MAR-INF/MANIFEST.json ================================================ { "specificationVersion": "1.0", "implementationVersion": "1.0", "description": "logging v1.0", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "logging", "description": "logging test model", "modelVersion": "1.0", "handler": "service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/logging/service.py ================================================ # Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ LoggingService defines a no operational model handler. """ import logging import time import os class LoggingService(object): """ Logging Model handler implementation. Extend from BaseModelHandler is optional """ def __init__(self): logging.info("LoggingService init") self._context = None self.initialized = False def __del__(self): logging.info("LoggingService exit") def initialize(self, context): """ Initialize model. 
This will be called during model loading time :param context: model server context :return: """ self.initialized = True self._context = context @staticmethod def inference(model_input): """ Internal inference methods :param model_input: transformed model input data :return: inference results """ time.sleep(0.01) logging.info("LoggingService inference [PID]: %d", os.getpid()) return ["{} OK\n".format(os.getpid())] * len(model_input) def handle(self, data, context): """ Custom service entry point function. :param context: model server context :param data: list of objects, raw input from request :return: list of outputs to be send back to client """ # Add your initialization code here properties = context.system_properties try: start_time = time.time() data = self.inference(data) end_time = time.time() context.set_response_content_type(0, "text/plain") content_type = context.request_processor[0].get_request_property("Content-Type") return data except Exception as e: logging.error(e, exc_info=True) context.request_processor[0].report_status(500, "Unknown inference error.") return ["Error {}".format(str(e))] * len(data) _service = LoggingService() def handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ================================================ FILE: frontend/modelarchive/src/test/resources/models/noop-no-manifest/service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. 
"""
NoopService defines a no operational model handler.
"""
import logging
import time


class NoopService(object):
    """No-op handler: passes the batch through and answers "OK" per request."""

    def __init__(self):
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """Store the server context; invoked once at model-load time.

        :param context: model server context
        :return: None
        """
        self._context = context
        self.initialized = True

    @staticmethod
    def preprocess(data):
        """Identity transform: raw request objects pass straight through.

        :param data: list of objects, raw input from request
        :return: the same list, unchanged
        """
        return data

    @staticmethod
    def inference(model_input):
        """Identity inference: the "model" does no work.

        :param model_input: transformed model input data
        :return: the input, unchanged
        """
        return model_input

    @staticmethod
    def postprocess(model_output):
        """Answer "OK" once per item in the batch."""
        return ["OK"] * len(model_output)

    def handle(self, data, context):
        """Pipeline entry: preprocess -> inference -> postprocess.

        On any failure, report HTTP 500 through the request processor and
        return one error string per request.

        :param data: list of objects, raw input from request
        :param context: model server context
        :return: list of outputs to be sent back to the client
        """
        request_processor = context.request_processor
        try:
            data = self.preprocess(data)
            data = self.inference(data)
            data = self.postprocess(data)
            context.set_response_content_type(0, "text/plain")
            return data
        except Exception as e:
            logging.error(e, exc_info=True)
            request_processor[0].report_status(500, "Unknown inference error.")
            return ["Error {}".format(str(e))] * len(data)


_service = NoopService()


def handle(data, context):
    """Module-level entry point referenced by the model archive."""
    if not _service.initialized:
        _service.initialize(context)
    return None if data is None else _service.handle(data, context)
"""
NoopService defines a noop service.

Legacy (v0.1 archive format) fixture built on the old SingleNodeService base.
"""
from mms.model_service.mxnet_model_service import SingleNodeService


class NoopService(SingleNodeService):
    """Minimal legacy service: constant inference result, no-op health check."""

    def ping(self):
        """Health-check hook; this noop fixture reports nothing."""
        return None

    def _inference(self, data):
        """Ignore *data* and answer a constant "OK"."""
        return "OK"
"""
NoopService defines a no operational model handler.
"""
import logging
import time


class NoopService(object):
    """No-op handler that also logs server properties and publishes timings."""

    def __init__(self):
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """Capture the server context; called once at model-load time.

        :param context: model server context
        :return: None
        """
        self._context = context
        self.initialized = True

    @staticmethod
    def preprocess(data):
        """Identity transform of the raw request batch."""
        return data

    @staticmethod
    def inference(model_input):
        """Identity inference step."""
        return model_input

    @staticmethod
    def postprocess(model_output):
        """One "OK" per batch item."""
        return ["OK"] * len(model_output)

    def handle(self, data, context):
        """Run the pipeline, log server properties, and publish stage timings.

        :param data: list of objects, raw input from request
        :param context: model server context
        :return: list of response strings, or error strings on failure
        """
        properties = context.system_properties
        # Emit each system property at debug level, mirroring the docs example.
        for key in ("server_name", "server_version", "model_dir",
                    "gpu_id", "batch_size"):
            logging.debug("{}: {}".format(key, properties.get(key)))

        try:
            preprocess_start = time.time()
            data = self.preprocess(data)
            inference_start = time.time()
            data = self.inference(data)
            postprocess_start = time.time()
            data = self.postprocess(data)
            end_time = time.time()

            context.set_response_content_type(0, "text/plain")
            content_type = context.request_processor[0].get_request_property(
                "Content-Type")
            logging.debug("content_type: {}".format(content_type))

            # Report per-stage latency in milliseconds.
            metrics = context.metrics
            metrics.add_time(
                "PreprocessTime",
                round((inference_start - preprocess_start) * 1000, 2))
            metrics.add_time(
                "InferenceTime",
                round((postprocess_start - inference_start) * 1000, 2))
            metrics.add_time(
                "PostprocessTime",
                round((end_time - postprocess_start) * 1000, 2))
            return data
        except Exception as e:
            logging.error(e, exc_info=True)
            context.request_processor[0].report_status(
                500, "Unknown inference error.")
            return ["Error {}".format(str(e))] * len(data)


_service = NoopService()


def handle(data, context):
    """Module-level entry point named in MANIFEST.json."""
    if not _service.initialized:
        _service.initialize(context)
    return None if data is None else _service.handle(data, context)
"description": "noop v1.0", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "noop-config", "description": "no operation model", "modelVersion": "1.0", "handler": "service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/noop-v1.0-config-tests/service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ NoopService defines a no operational model handler. """ import logging import time class NoopService(object): """ Noop Model handler implementation. Extend from BaseModelHandler is optional """ def __init__(self): self._context = None self.initialized = False def initialize(self, context): """ Initialize model. This will be called during model loading time :param context: model server context :return: """ self.initialized = True self._context = context @staticmethod def preprocess(data): """ Transform raw input into model input data. 
:param data: list of objects, raw input from request :return: list of model input data """ return data @staticmethod def inference(model_input): """ Internal inference methods :param model_input: transformed model input data :return: inference results """ return model_input @staticmethod def postprocess(model_output): return [str(model_output)] def handle(self, data, context): """ Custom service entry point function. :param context: model server context :param data: list of objects, raw input from request :return: list of outputs to be send back to client """ # Add your initialization code here properties = context.system_properties server_name = properties.get("server_name") server_version = properties.get("server_version") model_dir = properties.get("model_dir") gpu_id = properties.get("gpu_id") batch_size = properties.get("batch_size") logging.debug("server_name: {}".format(server_name)) logging.debug("server_version: {}".format(server_version)) logging.debug("model_dir: {}".format(model_dir)) logging.debug("gpu_id: {}".format(gpu_id)) logging.debug("batch_size: {}".format(batch_size)) try: preprocess_start = time.time() data = self.preprocess(data) inference_start = time.time() data = self.inference(data) postprocess_start = time.time() data = self.postprocess(data) end_time = time.time() context.set_response_content_type(0, "text/plain") content_type = context.get_request_header(0, "Content-Type") logging.debug("content_type: {}".format(content_type)) metrics = context.metrics metrics.add_time("PreprocessTime", round((inference_start - preprocess_start) * 1000, 2)) metrics.add_time("InferenceTime", round((postprocess_start - inference_start) * 1000, 2)) metrics.add_time("PostprocessTime", round((end_time - postprocess_start) * 1000, 2)) return data except Exception as e: logging.error(e, exc_info=True) context.request_processor.report_status(500, "Unknown inference error.") return ["Error {}".format(str(e))] * len(data) _service = NoopService() def 
handle(data, context): if not _service.initialized: _service.initialize(context) if data is None: return None return _service.handle(data, context) ================================================ FILE: frontend/modelarchive/src/test/resources/models/prediction-memory-error/MAR-INF/MANIFEST.json ================================================ { "specificationVersion": "1.0", "implementationVersion": "1.0", "description": "noop v1.0", "modelServerVersion": "1.0", "license": "Apache 2.0", "runtime": "python", "model": { "modelName": "pred-mem-err", "description": "Tests for memory error", "modelVersion": "1.0", "handler": "service:handle" }, "publisher": { "author": "MXNet SDK team", "email": "noreply@amazon.com" } } ================================================ FILE: frontend/modelarchive/src/test/resources/models/prediction-memory-error/service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. def handle(data, ctx): # Data is not none in prediction request # Python raises MemoryError when the python program goes out of memory. MMS expects this error from the handler # if the handlers can not allocate any further memory. 
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

"""
NoopService defines a no operational model handler.

Variant used to test custom response headers: every successful response
carries a "dummy: 1" header on request 0.
"""
import logging
import time


class NoopService(object):
    """No-op handler that attaches a custom response header."""

    def __init__(self):
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """Capture the server context; called once at model-load time."""
        self._context = context
        self.initialized = True

    @staticmethod
    def preprocess(data):
        """Identity transform of the raw request batch."""
        return data

    @staticmethod
    def inference(model_input):
        """Identity inference step."""
        return model_input

    @staticmethod
    def postprocess(model_output):
        """Stringify the whole batch into a single-element response."""
        return [str(model_output)]

    def handle(self, data, context):
        """Run the pipeline and set a custom "dummy" response header.

        :param data: list of objects, raw input from request
        :param context: model server context
        :return: one-element list with the stringified batch, or error strings
        """
        properties = context.system_properties
        for key in ("server_name", "server_version", "model_dir",
                    "gpu_id", "batch_size"):
            logging.debug("{}: {}".format(key, properties.get(key)))

        request_processor = context.request_processor
        try:
            data = self.preprocess(data)
            data = self.inference(data)
            data = self.postprocess(data)
            context.set_response_content_type(0, "text/plain")
            # The custom header under test.
            context.set_response_header(0, "dummy", "1")
            return data
        except Exception as e:
            logging.error(e, exc_info=True)
            # NOTE(review): report_status is called on request_processor itself
            # (no [0] index), unlike most sibling fixtures — confirm intended.
            request_processor.report_status(500, "Unknown inference error.")
            return ["Error {}".format(str(e))] * len(data)


_service = NoopService()


def handle(data, context):
    """Module-level entry point named in MANIFEST.json."""
    if not _service.initialized:
        _service.initialize(context)
    return None if data is None else _service.handle(data, context)
"software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}" compile "com.lmax:disruptor:${lmax_disruptor_version}" testCompile "org.testng:testng:${testng_version}" } apply from: file("${project.rootProject.projectDir}/tools/gradle/launcher.gradle") jar { manifest { attributes 'Main-Class': 'com.amazonaws.ml.mms.ModelServer' } includeEmptyDirs = false from configurations.runtime.collect { it.isDirectory() ? it : zipTree(it) } exclude "META-INF/maven/**" exclude "META-INF/INDEX.LIST" exclude "META-INF/MANIFEST*" exclude "META-INF//LICENSE*" exclude "META-INF//NOTICE*" } test { doFirst { systemProperty "mmsConfigFile", 'src/test/resources/config.properties' } environment "METRICS_LOCATION","build/logs" environment "LOG_LOCATION","build/logs" } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/ModelServer.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms; import com.amazonaws.ml.mms.archive.ModelArchive; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.metrics.MetricManager; import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.Connector; import com.amazonaws.ml.mms.util.ConnectorType; import com.amazonaws.ml.mms.util.ServerGroups; import com.amazonaws.ml.mms.wlm.ModelManager; import com.amazonaws.ml.mms.wlm.WorkLoadManager; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.FixedRecvByteBufAllocator; import io.netty.channel.ServerChannel; import io.netty.handler.ssl.SslContext; import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.Slf4JLoggerFactory; import java.io.File; import java.io.IOException; import java.lang.annotation.Annotation; import java.security.GeneralSecurityException; import java.util.ArrayList; import java.util.HashMap; import java.util.InvalidPropertiesFormatException; import java.util.List; import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; public class ModelServer { private Logger logger = 
// --- ModelServer state ---
// NOTE(review): generic type parameters (e.g. List<ChannelFuture>) appear to have
// been stripped by the extraction that produced this dump — confirm against the
// original source before compiling.
private Logger logger = LoggerFactory.getLogger(ModelServer.class);
private ServerGroups serverGroups; // netty event-loop groups + registered channels
private List futures = new ArrayList<>(2); // close-futures of the listener channels
private AtomicBoolean stopped = new AtomicBoolean(false);
private ConfigManager configManager;

/** Upper bound for the per-channel receive buffer, in bytes. */
public static final int MAX_RCVBUF_SIZE = 4096;

/** Creates a new {@code ModelServer} instance. */
public ModelServer(ConfigManager configManager) {
    this.configManager = configManager;
    serverGroups = new ServerGroups(configManager);
}

/**
 * CLI entry point: parses arguments, initializes configuration and plugins,
 * then starts the server and blocks until shutdown.
 *
 * <p>NOTE(review): the {@code finally} block always calls {@code System.exit(1)},
 * so the JVM exits with status 1 even after a clean shutdown — confirm this is
 * the intended exit-code convention before scripting around it.
 */
public static void main(String[] args) {
    Options options = ConfigManager.Arguments.getOptions();
    try {
        DefaultParser parser = new DefaultParser();
        CommandLine cmd = parser.parse(options, args, null, false);
        ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
        ConfigManager.init(arguments);
        ConfigManager configManager = ConfigManager.getInstance();
        PluginsManager.getInstance().initialize();
        // Route netty's internal logging through slf4j.
        InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
        new ModelServer(configManager).startAndWait();
    } catch (IllegalArgumentException e) {
        System.out.println("Invalid configuration: " + e.getMessage()); // NOPMD
    } catch (ParseException e) {
        // Bad CLI arguments: print usage help.
        HelpFormatter formatter = new HelpFormatter();
        formatter.setLeftPadding(1);
        formatter.setWidth(120);
        formatter.printHelp(e.getMessage(), options);
    } catch (Throwable t) {
        t.printStackTrace(); // NOPMD
    } finally {
        System.exit(1); // NOPMD
    }
}

/**
 * Starts all listeners, schedules metrics collection, and blocks on the first
 * listener channel's close future; always shuts the event-loop groups down.
 */
public void startAndWait() throws InterruptedException, IOException, GeneralSecurityException {
    try {
        List channelFutures = start();
        // Create and schedule metrics manager
        MetricManager.scheduleMetrics(configManager);
        System.out.println("Model server started."); // NOPMD
        // Block until the first listener channel closes.
        channelFutures.get(0).sync();
    } catch (InvalidPropertiesFormatException e) {
        logger.error("Invalid configuration", e);
    } finally {
        serverGroups.shutdown(true);
        logger.info("Model server stopped.");
    }
}
/**
 * Derives a default model name from a file name or URL: strips any directory
 * prefix and (for ".mar"/".model" names) the extension, then replaces
 * non-word characters with underscores.
 */
private String getDefaultModelName(String name) {
    if (name.contains(".model") || name.contains(".mar")) {
        return name.substring(name.lastIndexOf('/') + 1, name.lastIndexOf('.'))
                .replaceAll("(\\W|^_)", "_");
    } else {
        return name.substring(name.lastIndexOf('/') + 1).replaceAll("(\\W|^_)", "_");
    }
}

/**
 * Loads the initial set of models at startup.
 *
 * <p>Behavior depends on the {@code load_models} setting: when it is "ALL",
 * every non-hidden .mar/.model file (and directory) in the model store is
 * registered; otherwise it is parsed as a comma-separated list of
 * {@code [name=]url} entries. Per-model failures are logged and skipped so one
 * bad archive does not prevent the server from starting.
 */
private void initModelStore() {
    WorkLoadManager wlm = new WorkLoadManager(configManager, serverGroups.getBackendGroup());
    ModelManager.init(configManager, wlm);
    Set startupModels = ModelManager.getInstance().getStartupModels();
    String defaultModelName;
    String loadModels = configManager.getLoadModels();
    if (loadModels == null || loadModels.isEmpty()) {
        // Nothing configured to load at startup.
        return;
    }
    ModelManager modelManager = ModelManager.getInstance();
    int workers = configManager.getDefaultWorkers();
    if ("ALL".equalsIgnoreCase(loadModels)) {
        String modelStore = configManager.getModelStore();
        if (modelStore == null) {
            logger.warn("Model store is not configured.");
            return;
        }
        File modelStoreDir = new File(modelStore);
        if (!modelStoreDir.exists()) {
            logger.warn("Model store path is not found: {}", modelStore);
            return;
        }
        // Check folders to see if they can be models as well
        File[] files = modelStoreDir.listFiles();
        if (files != null) {
            for (File file : files) {
                if (file.isHidden()) {
                    continue;
                }
                String fileName = file.getName();
                // Only .mar/.model files qualify; directories are also accepted.
                if (file.isFile() && !fileName.endsWith(".mar") && !fileName.endsWith(".model")) {
                    continue;
                }
                try {
                    logger.debug(
                            "Loading models from model store: {} preload_model: {}",
                            file.getName(),
                            configManager.getPreloadModel());
                    defaultModelName = getDefaultModelName(fileName);
                    ModelArchive archive =
                            modelManager.registerModel(
                                    file.getName(),
                                    defaultModelName,
                                    configManager.getPreloadModel());
                    modelManager.updateModel(archive.getModelName(), workers, workers);
                    startupModels.add(archive.getModelName());
                } catch (ModelException
                        | IOException
                        | InterruptedException
                        | ExecutionException
                        | TimeoutException e) {
                    // Best-effort: log and continue with the remaining archives.
                    logger.warn("Failed to load model: " + file.getAbsolutePath(), e);
                }
            }
        }
        return;
    }
    // Explicit list: comma-separated "[name=]url" entries.
    String[] models = loadModels.split(",");
    for (String model : models) {
        String[] pair = model.split("=", 2);
        String modelName = null;
        String url;
        if (pair.length == 1) {
            url = pair[0];
        } else {
            modelName = pair[0];
            url = pair[1];
        }
        if (url.isEmpty()) {
            continue;
        }
        try {
            logger.info(
                    "Loading initial models: {} preload_model: {}",
                    url,
                    configManager.getPreloadModel());
            defaultModelName = getDefaultModelName(url);
            ModelArchive archive =
                    modelManager.registerModel(
                            url,
                            modelName,
                            null,
                            null,
                            1,
                            100,
                            configManager.getDefaultResponseTimeoutSeconds(),
                            defaultModelName,
                            configManager.getPreloadModel());
            modelManager.updateModel(archive.getModelName(), workers, workers);
            startupModels.add(archive.getModelName());
        } catch (ModelException
                | IOException
                | InterruptedException
                | ExecutionException
                | TimeoutException e) {
            // Best-effort: log and continue with the remaining entries.
            logger.warn("Failed to load model: " + url, e);
        }
    }
}

/** Unregisters every model currently known to the ModelManager. */
private void exitModelStore() {
    for (String modelName : ModelManager.getInstance().getModels().keySet()) {
        ModelManager.getInstance().unregisterModel(modelName);
    }
}
IOException("Failed to bind to address: " + connector, e); } throw e; } future.addListener( (ChannelFutureListener) f -> { if (!f.isSuccess()) { try { f.get(); } catch (InterruptedException | ExecutionException e) { logger.error("", e); } System.exit(-1); // NO PMD } serverGroups.registerChannel(f.channel()); }); future.sync(); ChannelFuture f = future.channel().closeFuture(); f.addListener( (ChannelFutureListener) listener -> logger.info("{} model server stopped.", purpose)); logger.info("{} API bind to: {}", purpose, connector); return f; } /** * Main Method that prepares the future for the channel and sets up the ServerBootstrap. * * @return A ChannelFuture object * @throws InterruptedException if interrupted */ public List start() throws InterruptedException, IOException, GeneralSecurityException { stopped.set(false); configManager.validateConfigurations(); logger.info(configManager.dumpConfigurations()); initModelStore(); Connector inferenceConnector = configManager.getListener(false); Connector managementConnector = configManager.getListener(true); inferenceConnector.clean(); managementConnector.clean(); EventLoopGroup serverGroup = serverGroups.getServerGroup(); EventLoopGroup workerGroup = serverGroups.getChildGroup(); futures.clear(); if (!inferenceConnector.equals(managementConnector)) { futures.add( initializeServer( inferenceConnector, serverGroup, workerGroup, ConnectorType.INFERENCE_CONNECTOR)); futures.add( initializeServer( managementConnector, serverGroup, workerGroup, ConnectorType.MANAGEMENT_CONNECTOR)); } else { futures.add( initializeServer( inferenceConnector, serverGroup, workerGroup, ConnectorType.BOTH)); } return futures; } private boolean validEndpoint(Annotation a, EndpointTypes type) { return a instanceof Endpoint && !((Endpoint) a).urlPattern().isEmpty() && ((Endpoint) a).endpointType().equals(type); } private HashMap registerEndpoints(EndpointTypes type) { ServiceLoader loader = ServiceLoader.load(ModelServerEndpoint.class); HashMap ep 
= new HashMap<>(); for (ModelServerEndpoint mep : loader) { Class modelServerEndpointClassObj = mep.getClass(); Annotation[] annotations = modelServerEndpointClassObj.getAnnotations(); for (Annotation a : annotations) { if (validEndpoint(a, type)) { ep.put(((Endpoint) a).urlPattern(), mep); } } } return ep; } public boolean isRunning() { return !stopped.get(); } public void stop() { if (stopped.get()) { return; } stopped.set(true); for (ChannelFuture future : futures) { future.channel().close(); } serverGroups.shutdown(true); serverGroups.init(); exitModelStore(); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/ServerInitializer.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms; import com.amazonaws.ml.mms.http.ApiDescriptionRequestHandler; import com.amazonaws.ml.mms.http.HttpRequestHandler; import com.amazonaws.ml.mms.http.HttpRequestHandlerChain; import com.amazonaws.ml.mms.http.InferenceRequestHandler; import com.amazonaws.ml.mms.http.InvalidRequestHandler; import com.amazonaws.ml.mms.http.ManagementRequestHandler; import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.ConnectorType; import io.netty.channel.Channel; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.ssl.SslContext; /** * A special {@link io.netty.channel.ChannelInboundHandler} which offers an easy way to initialize a * {@link io.netty.channel.Channel} once it was registered to its {@link * io.netty.channel.EventLoop}. */ public class ServerInitializer extends ChannelInitializer { private ConnectorType connectorType; private SslContext sslCtx; /** * Creates a new {@code HttpRequestHandler} instance. 
* * @param sslCtx null if SSL is not enabled * @param type true to initialize a management server instead of an API Server */ public ServerInitializer(SslContext sslCtx, ConnectorType type) { this.sslCtx = sslCtx; this.connectorType = type; } /** {@inheritDoc} */ @Override public void initChannel(Channel ch) { ChannelPipeline pipeline = ch.pipeline(); HttpRequestHandlerChain apiDescriptionRequestHandler = new ApiDescriptionRequestHandler(connectorType); HttpRequestHandlerChain invalidRequestHandler = new InvalidRequestHandler(); int maxRequestSize = ConfigManager.getInstance().getMaxRequestSize(); if (sslCtx != null) { pipeline.addLast("ssl", sslCtx.newHandler(ch.alloc())); } pipeline.addLast("http", new HttpServerCodec()); pipeline.addLast("aggregator", new HttpObjectAggregator(maxRequestSize)); HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler; if (ConnectorType.BOTH.equals(connectorType) || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) { httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new InferenceRequestHandler( PluginsManager.getInstance().getInferenceEndpoints())); } if (ConnectorType.BOTH.equals(connectorType) || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) { httpRequestHandlerChain = httpRequestHandlerChain.setNextHandler( new ManagementRequestHandler( PluginsManager.getInstance().getManagementEndpoints())); } httpRequestHandlerChain.setNextHandler(invalidRequestHandler); pipeline.addLast("handler", new HttpRequestHandler(apiDescriptionRequestHandler)); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ApiDescriptionRequestHandler.java ================================================ package com.amazonaws.ml.mms.http; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.openapi.OpenApiUtils; import com.amazonaws.ml.mms.util.ConnectorType; import com.amazonaws.ml.mms.util.NettyUtils; import 
io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.QueryStringDecoder; public class ApiDescriptionRequestHandler extends HttpRequestHandlerChain { private ConnectorType connectorType; public ApiDescriptionRequestHandler(ConnectorType type) { connectorType = type; } @Override protected void handleRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException { if (isApiDescription(segments)) { String path = decoder.path(); if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method())) || (segments.length == 2 && segments[1].equals("api-description"))) { handleApiDescription(ctx); return; } throw new MethodNotAllowedException(); } else { chain.handleRequest(ctx, req, decoder, segments); } } private boolean isApiDescription(String[] segments) { return segments.length == 0 || segments[1].equals("api-description"); } private void handleApiDescription(ChannelHandlerContext ctx) { NettyUtils.sendJsonResponse(ctx, OpenApiUtils.listApis(connectorType)); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/BadRequestException.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */

package com.amazonaws.ml.mms.http;

public class BadRequestException extends IllegalArgumentException {

    private static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code BadRequestException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public BadRequestException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code BadRequestException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public BadRequestException(Throwable cause) {
        super(cause.getMessage(), cause);
    }
}


================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ConflictStatusException.java
================================================
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

public class ConflictStatusException extends IllegalArgumentException {

    private static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code ConflictStatusException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public ConflictStatusException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code ConflictStatusException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public ConflictStatusException(Throwable cause) {
        super(cause.getMessage(), cause);
    }
}


================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/DescribeModelResponse.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.http; import java.util.ArrayList; import java.util.Date; import java.util.List; public class DescribeModelResponse { private String modelName; private String modelVersion; private String modelUrl; private String engine; private String runtime; private int minWorkers; private int maxWorkers; private int batchSize; private int maxBatchDelay; private String status; private boolean loadedAtStartup; private List workers; private Metrics metrics; public DescribeModelResponse() { workers = new ArrayList<>(); } public String getModelName() { return modelName; } public void setModelName(String modelName) { this.modelName = modelName; } public boolean getLoadedAtStartup() { return loadedAtStartup; } public void setLoadedAtStartup(boolean loadedAtStartup) { this.loadedAtStartup = loadedAtStartup; } public String getModelVersion() { return modelVersion; } public void setModelVersion(String modelVersion) { this.modelVersion = modelVersion; } public String getModelUrl() { return modelUrl; } public void setModelUrl(String modelUrl) { this.modelUrl = modelUrl; } public String getEngine() { return engine; } public void setEngine(String engine) { this.engine = engine; } public String getRuntime() { return runtime; } public void setRuntime(String runtime) { this.runtime = runtime; } public int getMinWorkers() { return minWorkers; } public void setMinWorkers(int minWorkers) { this.minWorkers = minWorkers; } public int getMaxWorkers() { return maxWorkers; } public void setMaxWorkers(int maxWorkers) { this.maxWorkers = maxWorkers; } public int getBatchSize() { return batchSize; } public void setBatchSize(int batchSize) { this.batchSize = batchSize; } public int getMaxBatchDelay() { return maxBatchDelay; } public void setMaxBatchDelay(int maxBatchDelay) { this.maxBatchDelay = maxBatchDelay; } public String getStatus() { return status; } public void setStatus(String status) { this.status = status; } public List getWorkers() { return workers; } public void 
setWorkers(List workers) { this.workers = workers; } public void addWorker( String id, long startTime, boolean isRunning, int gpuId, long memoryUsage) { Worker worker = new Worker(); worker.setId(id); worker.setStartTime(new Date(startTime)); worker.setStatus(isRunning ? "READY" : "UNLOADING"); worker.setGpu(gpuId >= 0); worker.setMemoryUsage(memoryUsage); workers.add(worker); } public Metrics getMetrics() { return metrics; } public void setMetrics(Metrics metrics) { this.metrics = metrics; } public static final class Worker { private String id; private Date startTime; private String status; private boolean gpu; private long memoryUsage; public Worker() {} public String getId() { return id; } public void setId(String id) { this.id = id; } public Date getStartTime() { return startTime; } public void setStartTime(Date startTime) { this.startTime = startTime; } public String getStatus() { return status; } public void setStatus(String status) { this.status = status; } public boolean isGpu() { return gpu; } public void setGpu(boolean gpu) { this.gpu = gpu; } public long getMemoryUsage() { return memoryUsage; } public void setMemoryUsage(long memoryUsage) { this.memoryUsage = memoryUsage; } } public static final class Metrics { private int rejectedRequests; private int waitingQueueSize; private int requests; public int getRejectedRequests() { return rejectedRequests; } public void setRejectedRequests(int rejectedRequests) { this.rejectedRequests = rejectedRequests; } public int getWaitingQueueSize() { return waitingQueueSize; } public void setWaitingQueueSize(int waitingQueueSize) { this.waitingQueueSize = waitingQueueSize; } public int getRequests() { return requests; } public void setRequests(int requests) { this.requests = requests; } } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ErrorResponse.java ================================================ /* * Copyright 2018 Amazon.com, Inc. 
or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.http; public class ErrorResponse { private int code; private String type; private String message; public ErrorResponse() {} public ErrorResponse(int code, String message) { this.code = code; this.message = message; } public ErrorResponse(int code, String type, String message) { this.code = code; this.type = type; this.message = message; } public int getCode() { return code; } public String getType() { return type; } public String getMessage() { return message; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandler.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.http; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.QueryStringDecoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A class handling inbound HTTP requests. * *

This class */ public class HttpRequestHandler extends SimpleChannelInboundHandler { private static final Logger logger = LoggerFactory.getLogger(HttpRequestHandler.class); HttpRequestHandlerChain handlerChain; /** Creates a new {@code HttpRequestHandler} instance. */ public HttpRequestHandler() {} public HttpRequestHandler(HttpRequestHandlerChain chain) { handlerChain = chain; } /** {@inheritDoc} */ @Override protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { try { NettyUtils.requestReceived(ctx.channel(), req); if (!req.decoderResult().isSuccess()) { throw new BadRequestException("Invalid HTTP message."); } QueryStringDecoder decoder = new QueryStringDecoder(req.uri()); String path = decoder.path(); String[] segments = path.split("/"); handlerChain.handleRequest(ctx, req, decoder, segments); } catch (ResourceNotFoundException | ModelNotFoundException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.NOT_FOUND, e); } catch (BadRequestException | ModelException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.BAD_REQUEST, e); } catch (ConflictStatusException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.CONFLICT, e); } catch (RequestTimeoutException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.REQUEST_TIMEOUT, e); } catch (MethodNotAllowedException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.METHOD_NOT_ALLOWED, e); } catch (ServiceUnavailableException e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.SERVICE_UNAVAILABLE, e); } catch (OutOfMemoryError e) { logger.trace("", e); NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, e); } catch (Throwable t) { logger.error("", t); NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, t); } } /** {@inheritDoc} */ @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { logger.error("", cause); if 
(cause instanceof OutOfMemoryError) { NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause); } ctx.close(); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/HttpRequestHandlerChain.java ================================================ package com.amazonaws.ml.mms.http; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerContext; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerRequest; import com.amazonaws.ml.mms.servingsdk.impl.ModelServerResponse; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.wlm.ModelManager; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpHeadersFactory; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.IOException; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.ModelServerEndpointException; public abstract class HttpRequestHandlerChain { private static final Logger logger = LoggerFactory.getLogger(HttpRequestHandler.class); protected Map endpointMap; protected HttpRequestHandlerChain chain; public HttpRequestHandlerChain() {} public HttpRequestHandlerChain(Map map) { endpointMap = map; } public HttpRequestHandlerChain setNextHandler(HttpRequestHandlerChain nextHandler) { chain = nextHandler; return chain; } protected abstract void handleRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, 
String[] segments) throws ModelNotFoundException, ModelException; private void run( ModelServerEndpoint endpoint, FullHttpRequest req, FullHttpResponse rsp, QueryStringDecoder decoder, String method) throws IOException { switch (method) { case "GET": endpoint.doGet( new ModelServerRequest(req, decoder), new ModelServerResponse(rsp), new ModelServerContext()); break; case "PUT": endpoint.doPut( new ModelServerRequest(req, decoder), new ModelServerResponse(rsp), new ModelServerContext()); break; case "DELETE": endpoint.doDelete( new ModelServerRequest(req, decoder), new ModelServerResponse(rsp), new ModelServerContext()); break; case "POST": endpoint.doPost( new ModelServerRequest(req, decoder), new ModelServerResponse(rsp), new ModelServerContext()); break; default: throw new ServiceUnavailableException("Invalid HTTP method received"); } } protected void handleCustomEndpoint( ChannelHandlerContext ctx, FullHttpRequest req, String[] segments, QueryStringDecoder decoder) { ModelServerEndpoint endpoint = endpointMap.get(segments[1]); Runnable r = () -> { Long start = System.currentTimeMillis(); FullHttpResponse rsp = new DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.directBuffer(), DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory()); try { run(endpoint, req, rsp, decoder, req.method().toString()); NettyUtils.sendHttpResponse(ctx, rsp, true); logger.info( "Running \"{}\" endpoint took {} ms", segments[0], System.currentTimeMillis() - start); } catch (ModelServerEndpointException me) { NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, me); logger.error("Error thrown by the model endpoint plugin.", me); } catch (OutOfMemoryError oom) { NettyUtils.sendError( ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, oom, "Out of memory"); } catch (IOException ioe) { NettyUtils.sendError( ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, ioe, "I/O error while running the custom endpoint"); 
logger.error("I/O error while running the custom endpoint.", ioe); } catch (Throwable e) { NettyUtils.sendError( ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e, "Unknown exception"); logger.error("Unknown exception", e); } }; ModelManager.getInstance().submitTask(r); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.http; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; import com.amazonaws.ml.mms.openapi.OpenApiUtils; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.messages.InputParameter; import com.amazonaws.ml.mms.util.messages.RequestInput; import com.amazonaws.ml.mms.util.messages.WorkerCommands; import com.amazonaws.ml.mms.wlm.Job; import com.amazonaws.ml.mms.wlm.Model; import com.amazonaws.ml.mms.wlm.ModelManager; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderValues; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory; import io.netty.handler.codec.http.multipart.HttpDataFactory; import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; /** * A class handling inbound HTTP requests to the management API. * *

This class */ public class InferenceRequestHandler extends HttpRequestHandlerChain { private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class); /** Creates a new {@code InferenceRequestHandler} instance. */ public InferenceRequestHandler(Map ep) { endpointMap = ep; } @Override protected void handleRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelException { if (isInferenceReq(segments)) { if (endpointMap.getOrDefault(segments[1], null) != null) { handleCustomEndpoint(ctx, req, segments, decoder); } else { switch (segments[1]) { case "ping": ModelManager.getInstance().workerStatus(ctx); break; case "models": case "invocations": validatePredictionsEndpoint(segments); handleInvocations(ctx, req, decoder, segments); break; case "predictions": handlePredictions(ctx, req, segments); break; default: handleLegacyPredict(ctx, req, decoder, segments); break; } } } else { chain.handleRequest(ctx, req, decoder, segments); } } private boolean isInferenceReq(String[] segments) { return segments.length == 0 || segments[1].equals("ping") || (segments.length == 4 && segments[1].equals("models")) || segments[1].equals("predictions") || segments[1].equals("api-description") || segments[1].equals("invocations") || (segments.length == 3 && segments[2].equals("predict")) || endpointMap.containsKey(segments[1]); } private void validatePredictionsEndpoint(String[] segments) { if (segments.length == 2 && "invocations".equals(segments[1])) { return; } else if (segments.length == 4 && "models".equals(segments[1]) && "invoke".equals(segments[3])) { return; } throw new ResourceNotFoundException(); } private void handlePredictions( ChannelHandlerContext ctx, FullHttpRequest req, String[] segments) throws ModelNotFoundException { if (segments.length < 3) { throw new ResourceNotFoundException(); } predict(ctx, req, null, segments[2]); } private void handleInvocations( ChannelHandlerContext ctx, 
FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException { String modelName = ("invocations".equals(segments[1])) ? NettyUtils.getParameter(decoder, "model_name", null) : segments[2]; if (modelName == null || modelName.isEmpty()) { if (ModelManager.getInstance().getStartupModels().size() == 1) { modelName = ModelManager.getInstance().getStartupModels().iterator().next(); } } predict(ctx, req, decoder, modelName); } private void handleLegacyPredict( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String[] segments) throws ModelNotFoundException { if (segments.length < 3 || !"predict".equals(segments[2])) { throw new ResourceNotFoundException(); } predict(ctx, req, decoder, segments[1]); } private void predict( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder, String modelName) throws ModelNotFoundException, BadRequestException { RequestInput input = parseRequest(ctx, req, decoder); if (modelName == null) { throw new BadRequestException("Parameter model_name is required."); } if (HttpMethod.OPTIONS.equals(req.method())) { ModelManager modelManager = ModelManager.getInstance(); Model model = modelManager.getModels().get(modelName); if (model == null) { throw new ModelNotFoundException("Model not found: " + modelName); } String resp = OpenApiUtils.getModelApi(model); NettyUtils.sendJsonResponse(ctx, resp); return; } Job job = new Job(ctx, modelName, WorkerCommands.PREDICT, input); if (!ModelManager.getInstance().addJob(job)) { throw new ServiceUnavailableException( "No worker is available to serve request for model: " + modelName + ". 
Consider increasing job queue size."); } } private static RequestInput parseRequest( ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder) { String requestId = NettyUtils.getRequestId(ctx.channel()); RequestInput inputData = new RequestInput(requestId); if (decoder != null) { for (Map.Entry> entry : decoder.parameters().entrySet()) { String key = entry.getKey(); for (String value : entry.getValue()) { inputData.addParameter(new InputParameter(key, value)); } } } CharSequence contentType = HttpUtil.getMimeType(req); for (Map.Entry entry : req.headers().entries()) { inputData.updateHeaders(entry.getKey(), entry.getValue()); } if (HttpPostRequestDecoder.isMultipart(req) || HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.contentEqualsIgnoreCase( contentType)) { HttpDataFactory factory = new DefaultHttpDataFactory(6553500); HttpPostRequestDecoder form = new HttpPostRequestDecoder(factory, req); try { while (form.hasNext()) { inputData.addParameter(NettyUtils.getFormData(form.next())); } } catch (HttpPostRequestDecoder.EndOfDataDecoderException ignore) { logger.trace("End of multipart items."); } finally { form.cleanFiles(); form.destroy(); } } else { byte[] content = NettyUtils.getBytes(req.content()); inputData.addParameter(new InputParameter("body", content, contentType)); } return inputData; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/InternalServerException.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

/** Thrown to map an unexpected server-side failure to an HTTP 500 response. */
public class InternalServerException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code InternalServerException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public InternalServerException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code InternalServerException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public InternalServerException(String message, Throwable cause) {
        super(message, cause);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/InvalidPluginException.java
================================================
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

/** InvalidPluginException is thrown when there is an error while handling a Model Server plugin */
public class InvalidPluginException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code InvalidPluginException} with a default error detail message.
     */
    public InvalidPluginException() {
        super("Registered plugin is invalid. Please re-check the configuration and the plugins.");
    }

    /**
     * Constructs an {@code InvalidPluginException} with {@code msg} as its error detail message
     *
     * @param msg : This is the error detail message
     */
    public InvalidPluginException(String msg) {
        super(msg);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/InvalidRequestHandler.java
================================================
package com.amazonaws.ml.mms.http;

import com.amazonaws.ml.mms.archive.ModelException;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;

/** Terminal handler of the request chain: any request reaching it matched no known API. */
public class InvalidRequestHandler extends HttpRequestHandlerChain {

    public InvalidRequestHandler() {}

    @Override
    protected void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException {
        // Unconditionally 404: upstream handlers have already declined this request.
        throw new ResourceNotFoundException();
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ListModelsResponse.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.http; import java.util.ArrayList; import java.util.List; public class ListModelsResponse { private String nextPageToken; private List models; public ListModelsResponse() { models = new ArrayList<>(); } public String getNextPageToken() { return nextPageToken; } public void setNextPageToken(String nextPageToken) { this.nextPageToken = nextPageToken; } public List getModels() { return models; } public void setModels(List models) { this.models = models; } public void addModel(String modelName, String modelUrl) { models.add(new ModelItem(modelName, modelUrl)); } public static final class ModelItem { private String modelName; private String modelUrl; public ModelItem() {} public ModelItem(String modelName, String modelUrl) { this.modelName = modelName; this.modelUrl = modelUrl; } public String getModelName() { return modelName; } public void setModelName(String modelName) { this.modelName = modelName; } public String getModelUrl() { return modelUrl; } public void setModelUrl(String modelUrl) { this.modelUrl = modelUrl; } } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */

package com.amazonaws.ml.mms.http;

import com.amazonaws.ml.mms.archive.Manifest;
import com.amazonaws.ml.mms.archive.ModelArchive;
import com.amazonaws.ml.mms.archive.ModelException;
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
import com.amazonaws.ml.mms.http.messages.RegisterModelRequest;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.JsonUtils;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.amazonaws.ml.mms.wlm.Model;
import com.amazonaws.ml.mms.wlm.ModelManager;
import com.amazonaws.ml.mms.wlm.WorkerThread;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.util.CharsetUtil;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;

/**
 * A class handling inbound HTTP requests to the management API.
 *
 * <p>This class dispatches the {@code /models} endpoints (list, register, describe,
 * scale, unregister) and any plugin-registered custom management endpoints; anything
 * else is forwarded down the handler chain.
 */
public class ManagementRequestHandler extends HttpRequestHandlerChain {

    /** Creates a new {@code ManagementRequestHandler} instance. */
    public ManagementRequestHandler(Map<String, ModelServerEndpoint> ep) {
        // endpointMap is inherited from HttpRequestHandlerChain.
        endpointMap = ep;
    }

    @Override
    protected void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException {
        if (isManagementReq(segments)) {
            if (endpointMap.getOrDefault(segments[1], null) != null) {
                // Plugin-provided endpoint wins over the built-in /models routes.
                handleCustomEndpoint(ctx, req, segments, decoder);
            } else {
                if (!"models".equals(segments[1])) {
                    throw new ResourceNotFoundException();
                }

                HttpMethod method = req.method();
                if (segments.length < 3) {
                    // Collection-level: GET /models lists, POST /models registers.
                    if (HttpMethod.GET.equals(method)) {
                        handleListModels(ctx, decoder);
                        return;
                    } else if (HttpMethod.POST.equals(method)) {
                        handleRegisterModel(ctx, decoder, req);
                        return;
                    }
                    throw new MethodNotAllowedException();
                }

                // Item-level: /models/{modelName}.
                if (HttpMethod.GET.equals(method)) {
                    handleDescribeModel(ctx, segments[2]);
                } else if (HttpMethod.PUT.equals(method)) {
                    handleScaleModel(ctx, decoder, segments[2]);
                } else if (HttpMethod.DELETE.equals(method)) {
                    handleUnregisterModel(ctx, segments[2]);
                } else {
                    throw new MethodNotAllowedException();
                }
            }
        } else {
            chain.handleRequest(ctx, req, decoder, segments);
        }
    }

    /** Returns true if the path targets the management API (root, /models, or a plugin endpoint). */
    private boolean isManagementReq(String[] segments) {
        // NOTE(review): when segments.length == 1 the trailing containsKey(segments[1])
        // would throw AIOOBE; presumably URI splitting never yields length 1 — confirm.
        return segments.length == 0
                || ((segments.length == 2 || segments.length == 3) && segments[1].equals("models"))
                || endpointMap.containsKey(segments[1]);
    }

    /**
     * GET /models — paginated listing of registered models, sorted by name.
     * {@code limit} is clamped to [0, 100]; {@code next_page_token} is the start index.
     */
    private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder decoder) {
        int limit = NettyUtils.getIntParameter(decoder, "limit", 100);
        int pageToken = NettyUtils.getIntParameter(decoder, "next_page_token", 0);
        if (limit > 100 || limit < 0) {
            limit = 100;
        }
        if (pageToken < 0) {
            pageToken = 0;
        }

        ModelManager modelManager = ModelManager.getInstance();
        Map<String, Model> models = modelManager.getModels();

        // Sort for a stable pagination order across calls.
        List<String> keys = new ArrayList<>(models.keySet());
        Collections.sort(keys);
        ListModelsResponse list = new ListModelsResponse();

        int last = pageToken + limit;
        if (last > keys.size()) {
            last = keys.size();
        } else {
            // More items remain: hand out the next start index as the page token.
            list.setNextPageToken(String.valueOf(last));
        }

        for (int i = pageToken; i < last; ++i) {
            String modelName = keys.get(i);
            Model model = models.get(modelName);
            list.addModel(modelName, model.getModelUrl());
        }

        NettyUtils.sendJsonResponse(ctx, list);
    }

    /**
     * GET /models/{name} — detailed status: config, manifest info and per-worker state.
     *
     * @throws ModelNotFoundException if the model is not registered
     */
    private void handleDescribeModel(ChannelHandlerContext ctx, String modelName)
            throws ModelNotFoundException {
        ModelManager modelManager = ModelManager.getInstance();
        Model model = modelManager.getModels().get(modelName);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }

        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setLoadedAtStartup(modelManager.getStartupModels().contains(modelName));

        Manifest manifest = model.getModelArchive().getManifest();
        Manifest.Engine engine = manifest.getEngine();
        if (engine != null) {
            resp.setEngine(engine.getEngineName());
        }
        resp.setModelVersion(manifest.getModel().getModelVersion());
        resp.setRuntime(manifest.getRuntime().getValue());

        List<WorkerThread> workers = modelManager.getWorkers(modelName);
        for (WorkerThread worker : workers) {
            String workerId = worker.getWorkerId();
            long startTime = worker.getStartTime();
            boolean isRunning = worker.isRunning();
            int gpuId = worker.getGpuId();
            long memory = worker.getMemory();
            resp.addWorker(workerId, startTime, isRunning, gpuId, memory);
        }

        NettyUtils.sendJsonResponse(ctx, resp);
    }

    /**
     * POST /models — registers a model archive and optionally spins up initial workers.
     * Missing request fields fall back to server-wide defaults from ConfigManager.
     *
     * @throws BadRequestException if the url parameter is missing or the runtime is unknown
     */
    private void handleRegisterModel(
            ChannelHandlerContext ctx, QueryStringDecoder decoder, FullHttpRequest req)
            throws ModelException {
        RegisterModelRequest registerModelRequest = parseRequest(req, decoder);
        String modelUrl = registerModelRequest.getModelUrl();
        if (modelUrl == null) {
            throw new BadRequestException("Parameter url is required.");
        }

        String modelName = registerModelRequest.getModelName();
        String runtime = registerModelRequest.getRuntime();
        String handler = registerModelRequest.getHandler();
        int batchSize = registerModelRequest.getBatchSize();
        int maxBatchDelay = registerModelRequest.getMaxBatchDelay();
        int initialWorkers = registerModelRequest.getInitialWorkers();
        boolean synchronous = registerModelRequest.isSynchronous();
        int responseTimeoutSeconds = registerModelRequest.getResponseTimeoutSeconds();
        String preloadModel = registerModelRequest.getPreloadModel();
        if (preloadModel == null) {
            preloadModel = ConfigManager.getInstance().getPreloadModel();
        }
        if (responseTimeoutSeconds == -1) {
            // -1 is the "not supplied" sentinel from RegisterModelRequest.
            responseTimeoutSeconds = ConfigManager.getInstance().getDefaultResponseTimeoutSeconds();
        }

        Manifest.RuntimeType runtimeType = null;
        if (runtime != null) {
            try {
                runtimeType = Manifest.RuntimeType.fromValue(runtime);
            } catch (IllegalArgumentException e) {
                throw new BadRequestException(e);
            }
        }

        ModelManager modelManager = ModelManager.getInstance();
        final ModelArchive archive;
        try {
            archive =
                    modelManager.registerModel(
                            modelUrl,
                            modelName,
                            runtimeType,
                            handler,
                            batchSize,
                            maxBatchDelay,
                            responseTimeoutSeconds,
                            null,
                            preloadModel);
        } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) {
            throw new InternalServerException("Failed to save model: " + modelUrl, e);
        }

        // The archive may supply the canonical model name when none was given.
        modelName = archive.getModelName();

        final String msg = "Model \"" + modelName + "\" registered";
        if (initialWorkers <= 0) {
            NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
            return;
        }

        // On worker-startup failure, roll back the registration via the onError callback.
        updateModelWorkers(
                ctx,
                modelName,
                initialWorkers,
                initialWorkers,
                synchronous,
                f -> {
                    modelManager.unregisterModel(archive.getModelName());
                    return null;
                });
    }

    /**
     * DELETE /models/{name} — unregisters a model, mapping cleanup failures onto
     * the matching HTTP exception types.
     */
    private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName)
            throws ModelNotFoundException, InternalServerException, RequestTimeoutException {
        ModelManager modelManager = ModelManager.getInstance();
        HttpResponseStatus httpResponseStatus = modelManager.unregisterModel(modelName);
        if (httpResponseStatus == HttpResponseStatus.NOT_FOUND) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.INTERNAL_SERVER_ERROR) {
            throw new InternalServerException("Interrupted while cleaning resources: " + modelName);
        } else if (httpResponseStatus == HttpResponseStatus.REQUEST_TIMEOUT) {
            throw new RequestTimeoutException("Timed out while cleaning resources: " + modelName);
        }
        String msg = "Model \"" + modelName + "\" unregistered";
        NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
    }

    /**
     * PUT /models/{name} — scales the model's worker pool to [min_worker, max_worker].
     *
     * @throws ModelNotFoundException if the model is not registered
     */
    private void handleScaleModel(
            ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName)
            throws ModelNotFoundException {
        int minWorkers = NettyUtils.getIntParameter(decoder, "min_worker", 1);
        int maxWorkers = NettyUtils.getIntParameter(decoder, "max_worker", minWorkers);
        if (maxWorkers < minWorkers) {
            throw new BadRequestException("max_worker cannot be less than min_worker.");
        }
        boolean synchronous =
                Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", null));

        ModelManager modelManager = ModelManager.getInstance();
        if (!modelManager.getModels().containsKey(modelName)) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }

        updateModelWorkers(ctx, modelName, minWorkers, maxWorkers, synchronous, null);
    }

    /**
     * Applies a worker-count change. Asynchronous mode replies 202 immediately;
     * synchronous mode waits for the scale operation, replying 200 on success,
     * 210 "Partial Success" if scaling is still in progress, or an error (invoking
     * {@code onError} for rollback) on failure.
     */
    private void updateModelWorkers(
            final ChannelHandlerContext ctx,
            final String modelName,
            int minWorkers,
            int maxWorkers,
            boolean synchronous,
            final Function<Void, Void> onError) {
        ModelManager modelManager = ModelManager.getInstance();
        CompletableFuture<HttpResponseStatus> future =
                modelManager.updateModel(modelName, minWorkers, maxWorkers);
        if (!synchronous) {
            NettyUtils.sendJsonResponse(
                    ctx,
                    new StatusResponse("Processing worker updates..."),
                    HttpResponseStatus.ACCEPTED);
            return;
        }
        future.thenApply(
                        v -> {
                            boolean status = modelManager.scaleRequestStatus(modelName);

                            if (HttpResponseStatus.OK.equals(v)) {
                                if (status) {
                                    NettyUtils.sendJsonResponse(
                                            ctx, new StatusResponse("Workers scaled"), v);
                                } else {
                                    // Workers accepted the request but have not all settled yet.
                                    NettyUtils.sendJsonResponse(
                                            ctx,
                                            new StatusResponse("Workers scaling in progress..."),
                                            new HttpResponseStatus(210, "Partial Success"));
                                }
                            } else {
                                NettyUtils.sendError(
                                        ctx, v, new InternalServerException("Failed to start workers"));
                                if (onError != null) {
                                    onError.apply(null);
                                }
                            }
                            return v;
                        })
                .exceptionally(
                        (e) -> {
                            if (onError != null) {
                                onError.apply(null);
                            }
                            NettyUtils.sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR, e);
                            return null;
                        });
    }

    /**
     * Parses the registration request either from a JSON body (application/json)
     * or from query-string parameters.
     */
    private RegisterModelRequest parseRequest(FullHttpRequest req, QueryStringDecoder decoder) {
        RegisterModelRequest in;
        CharSequence mime = HttpUtil.getMimeType(req);
        if (HttpHeaderValues.APPLICATION_JSON.contentEqualsIgnoreCase(mime)) {
            in =
                    JsonUtils.GSON.fromJson(
                            req.content().toString(CharsetUtil.UTF_8), RegisterModelRequest.class);
        } else {
            in = new RegisterModelRequest(decoder);
        }
        return in;
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/MethodNotAllowedException.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

/** Thrown to answer a request whose HTTP method the targeted endpoint does not support. */
public class MethodNotAllowedException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs a {@code MethodNotAllowedException} with a default error detail
     * message.
     */
    public MethodNotAllowedException() {
        super("Requested method is not allowed, please refer to API document.");
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/RequestTimeoutException.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

/** Thrown to map an operation that exceeded its time budget to an HTTP 408 response. */
public class RequestTimeoutException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs a {@code RequestTimeoutException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public RequestTimeoutException(String message) {
        super(message);
    }

    /**
     * Constructs a {@code RequestTimeoutException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is not automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public RequestTimeoutException(String message, Throwable cause) {
        super(message, cause);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ResourceNotFoundException.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */

package com.amazonaws.ml.mms.http;

/** Thrown to map a request for an unknown URL path to an HTTP 404 response. */
public class ResourceNotFoundException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs a {@code ResourceNotFoundException} with a default error detail
     * message.
     */
    public ResourceNotFoundException() {
        super("Requested resource is not found, please refer to API document.");
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/ServiceUnavailableException.java
================================================
package com.amazonaws.ml.mms.http;

/** Thrown to map "no worker / queue full" conditions to an HTTP 503 response. */
public class ServiceUnavailableException extends RuntimeException {

    static final long serialVersionUID = 1L;

    /**
     * Constructs a {@code ServiceUnavailableException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public ServiceUnavailableException(String message) {
        super(message);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/Session.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.http; import io.netty.handler.codec.http.HttpRequest; import java.util.UUID; public class Session { private String requestId; private String remoteIp; private String method; private String uri; private String protocol; private int code; private long startTime; public Session(String remoteIp, HttpRequest request) { this.remoteIp = remoteIp; this.uri = request.uri(); if (request.decoderResult().isSuccess()) { method = request.method().name(); protocol = request.protocolVersion().text(); } else { method = "GET"; protocol = "HTTP/1.1"; } requestId = UUID.randomUUID().toString(); startTime = System.currentTimeMillis(); } public String getRequestId() { return requestId; } public void setCode(int code) { this.code = code; } @Override public String toString() { long duration = System.currentTimeMillis() - startTime; return remoteIp + " \"" + method + " " + uri + ' ' + protocol + "\" " + code + ' ' + duration; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/StatusResponse.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.http; public class StatusResponse { private String status; public StatusResponse() {} public StatusResponse(String status) { this.status = status; } public String getStatus() { return status; } public void setStatus(String status) { this.status = status; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/http/messages/RegisterModelRequest.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */

package com.amazonaws.ml.mms.http.messages;

import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.NettyUtils;
import com.google.gson.annotations.SerializedName;
import io.netty.handler.codec.http.QueryStringDecoder;

/**
 * Register Model Request for Model server. Populated either by Gson from a JSON
 * body (using the {@code @SerializedName} keys) or from query-string parameters
 * via the {@link QueryStringDecoder} constructor.
 */
public class RegisterModelRequest {

    @SerializedName("model_name")
    private String modelName;

    @SerializedName("runtime")
    private String runtime;

    @SerializedName("handler")
    private String handler;

    @SerializedName("batch_size")
    private int batchSize;

    @SerializedName("max_batch_delay")
    private int maxBatchDelay;

    @SerializedName("initial_workers")
    private int initialWorkers;

    @SerializedName("synchronous")
    private boolean synchronous;

    @SerializedName("response_timeout")
    private int responseTimeoutSeconds;

    @SerializedName("url")
    private String modelUrl;

    @SerializedName("preload_model")
    private String preloadModel;

    /** Builds the request from query-string parameters, applying server defaults. */
    public RegisterModelRequest(QueryStringDecoder decoder) {
        modelName = NettyUtils.getParameter(decoder, "model_name", null);
        runtime = NettyUtils.getParameter(decoder, "runtime", null);
        handler = NettyUtils.getParameter(decoder, "handler", null);
        batchSize = NettyUtils.getIntParameter(decoder, "batch_size", 1);
        maxBatchDelay = NettyUtils.getIntParameter(decoder, "max_batch_delay", 100);
        initialWorkers =
                NettyUtils.getIntParameter(
                        decoder,
                        "initial_workers",
                        ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel());
        synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", "true"));

        // TODO Fix this so it matches the documentation, where timeouts are specified in seconds.
        // For now, we're being extra careful about backwards compatibility.
        // So, assume parameter is in minutes, and convert to seconds internally.
        responseTimeoutSeconds = 60 * NettyUtils.getIntParameter(decoder, "response_timeout", -1);
        if (responseTimeoutSeconds < 0) {
            // -1 is the "not supplied" sentinel; callers substitute the server default.
            responseTimeoutSeconds = -1;
        }

        modelUrl = NettyUtils.getParameter(decoder, "url", null);
        preloadModel = NettyUtils.getParameter(decoder, "preload_model", null);
    }

    /** No-arg constructor used by Gson; mirrors the query-parameter defaults. */
    public RegisterModelRequest() {
        batchSize = 1;
        maxBatchDelay = 100;
        synchronous = true;
        initialWorkers = ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel();
        responseTimeoutSeconds = -1;
        preloadModel = null;
    }

    public String getModelName() {
        return modelName;
    }

    public String getRuntime() {
        return runtime;
    }

    public String getHandler() {
        return handler;
    }

    public Integer getBatchSize() {
        return batchSize;
    }

    public Integer getMaxBatchDelay() {
        return maxBatchDelay;
    }

    public Integer getInitialWorkers() {
        return initialWorkers;
    }

    public Boolean isSynchronous() {
        return synchronous;
    }

    public Integer getResponseTimeoutSeconds() {
        return responseTimeoutSeconds;
    }

    public String getModelUrl() {
        return modelUrl;
    }

    public String getPreloadModel() {
        return preloadModel;
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/metrics/Dimension.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.metrics; import com.google.gson.annotations.SerializedName; public class Dimension { @SerializedName("Name") private String name; @SerializedName("Value") private String value; public Dimension() {} public Dimension(String name, String value) { this.name = name; this.value = value; } public String getName() { return name; } public void setName(String name) { this.name = name; } public String getValue() { return value; } public void setValue(String value) { this.value = value; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/metrics/Metric.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */

package com.amazonaws.ml.mms.metrics;

import com.google.gson.annotations.SerializedName;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * A single metric sample, convertible between the StatsD-like log-line format
 * (see {@link #parse} / {@link #toString}) and a JSON form via Gson.
 */
public class Metric {

    // Groups: (1) metric name, (2) unit, (3) numeric value, (4) dimensions CSV,
    // (5) hostname, (6) timestamp, (8) optional requestId (group 7 is the
    // wrapper including the leading comma).
    private static final Pattern PATTERN =
            Pattern.compile(
                    "\\s*(\\w+)\\.(\\w+):([0-9\\-,.e]+)\\|#([^|]*)\\|#hostname:([^,]+),([^,]+)(,(.*))?");

    @SerializedName("MetricName")
    private String metricName;

    @SerializedName("Value")
    private String value;

    @SerializedName("Unit")
    private String unit;

    @SerializedName("Dimensions")
    private List<Dimension> dimensions;

    @SerializedName("Timestamp")
    private String timestamp;

    @SerializedName("RequestId")
    private String requestId;

    @SerializedName("HostName")
    private String hostName;

    public Metric() {}

    public Metric(
            String metricName, String value, String unit, String hostName, Dimension... dimensions) {
        this.metricName = metricName;
        this.value = value;
        this.unit = unit;
        this.hostName = hostName;
        this.dimensions = Arrays.asList(dimensions);
    }

    public String getHostName() {
        return hostName;
    }

    public void setHostName(String hostName) {
        this.hostName = hostName;
    }

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public String getMetricName() {
        return metricName;
    }

    public void setMetricName(String metricName) {
        this.metricName = metricName;
    }

    public String getValue() {
        return value;
    }

    public void setValue(String value) {
        this.value = value;
    }

    public String getUnit() {
        return unit;
    }

    public void setUnit(String unit) {
        this.unit = unit;
    }

    public List<Dimension> getDimensions() {
        return dimensions;
    }

    public void setDimensions(List<Dimension> dimensions) {
        this.dimensions = dimensions;
    }

    public String getTimestamp() {
        return timestamp;
    }

    public void setTimestamp(String timestamp) {
        this.timestamp = timestamp;
    }

    /**
     * Parses one metric log line into a {@code Metric}, or returns {@code null}
     * when the line does not match the expected format.
     */
    public static Metric parse(String line) {
        // DiskAvailable.Gigabytes:311|#Level:Host,hostname:localhost
        Matcher matcher = PATTERN.matcher(line);
        if (!matcher.matches()) {
            return null;
        }

        Metric metric = new Metric();
        metric.setMetricName(matcher.group(1));
        metric.setUnit(matcher.group(2));
        metric.setValue(matcher.group(3));
        String dimensions = matcher.group(4);
        metric.setHostName(matcher.group(5));
        metric.setTimestamp(matcher.group(6));
        metric.setRequestId(matcher.group(8));

        if (dimensions != null) {
            // Dimensions arrive as "Name:Value,Name:Value"; malformed pairs are skipped.
            String[] dimension = dimensions.split(",");
            List<Dimension> list = new ArrayList<>(dimension.length);
            for (String dime : dimension) {
                String[] pair = dime.split(":");
                if (pair.length == 2) {
                    list.add(new Dimension(pair[0], pair[1]));
                }
            }
            metric.setDimensions(list);
        }

        return metric;
    }

    /** Renders the sample back into the log-line format accepted by {@link #parse}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(128);
        sb.append(metricName).append('.').append(unit).append(':').append(getValue()).append("|#");
        boolean first = true;
        for (Dimension dimension : getDimensions()) {
            if (first) {
                first = false;
            } else {
                sb.append(',');
            }
            sb.append(dimension.getName()).append(':').append(dimension.getValue());
        }
        sb.append("|#hostname:").append(hostName);
        if (requestId != null) {
            sb.append(",requestID:").append(requestId);
        }
        sb.append(",timestamp:").append(timestamp);
        return sb.toString();
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/metrics/MetricCollector.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.metrics; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.wlm.ModelManager; import com.amazonaws.ml.mms.wlm.WorkerThread; import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.Map; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class MetricCollector implements Runnable { static final Logger logger = LoggerFactory.getLogger(MetricCollector.class); private static final Logger loggerMetrics = LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER); private ConfigManager configManager; public MetricCollector(ConfigManager configManager) { this.configManager = configManager; } @Override public void run() { try { // Collect System level Metrics String[] args = new String[2]; args[0] = configManager.getPythonExecutable(); args[1] = "mms/metrics/metric_collector.py"; File workingDir = new File(configManager.getModelServerHome()); String pythonPath = System.getenv("PYTHONPATH"); String pythonEnv; if ((pythonPath == null || pythonPath.isEmpty()) && (!workingDir.getAbsolutePath().contains("site-package"))) { pythonEnv = "PYTHONPATH=" + workingDir.getAbsolutePath(); } else { pythonEnv = "PYTHONPATH=" + pythonPath; if (!workingDir.getAbsolutePath().contains("site-package")) { pythonEnv += File.pathSeparatorChar + workingDir.getAbsolutePath(); // NOPMD } } // sbin added for macs for python sysctl pythonpath StringBuilder path = new StringBuilder(); path.append("PATH=").append(System.getenv("PATH")); String osName = System.getProperty("os.name"); if (osName.startsWith("Mac OS X")) { path.append(File.pathSeparatorChar).append("/sbin/"); } String[] env = {pythonEnv, path.toString()}; final Process p = Runtime.getRuntime().exec(args, env, workingDir); ModelManager 
modelManager = ModelManager.getInstance(); Map workerMap = modelManager.getWorkers(); try (OutputStream os = p.getOutputStream()) { writeWorkerPids(workerMap, os); } new Thread( () -> { try { String error = IOUtils.toString( p.getErrorStream(), StandardCharsets.UTF_8); if (!error.isEmpty()) { logger.error(error); } } catch (IOException e) { logger.error("", e); } }) .start(); MetricManager metricManager = MetricManager.getInstance(); try (BufferedReader reader = new BufferedReader( new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) { List metricsSystem = new ArrayList<>(); metricManager.setMetrics(metricsSystem); String line; while ((line = reader.readLine()) != null) { if (line.isEmpty()) { break; } Metric metric = Metric.parse(line); if (metric == null) { logger.warn("Parse metrics failed: " + line); } else { loggerMetrics.info("{}", metric); metricsSystem.add(metric); } } // Collect process level metrics while ((line = reader.readLine()) != null) { String[] tokens = line.split(":"); if (tokens.length != 2) { continue; } try { Integer pid = Integer.valueOf(tokens[0]); WorkerThread worker = workerMap.get(pid); worker.setMemory(Long.parseLong(tokens[1])); } catch (NumberFormatException e) { logger.warn("Failed to parse memory utilization metrics: " + line); continue; } } } } catch (IOException e) { logger.error("", e); } } private void writeWorkerPids(Map workerMap, OutputStream os) throws IOException { boolean first = true; for (Integer pid : workerMap.keySet()) { if (pid < 0) { logger.warn("worker pid is not available yet."); continue; } if (first) { first = false; } else { IOUtils.write(",", os, StandardCharsets.UTF_8); } IOUtils.write(pid.toString(), os, StandardCharsets.UTF_8); } os.write('\n'); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/metrics/MetricManager.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.metrics; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.wlm.ModelManager; import java.util.Collections; import java.util.List; import java.util.concurrent.TimeUnit; public final class MetricManager { private static final MetricManager METRIC_MANAGER = new MetricManager(); private List metrics; private MetricManager() { metrics = Collections.emptyList(); } public static MetricManager getInstance() { return METRIC_MANAGER; } public static void scheduleMetrics(ConfigManager configManager) { MetricCollector metricCollector = new MetricCollector(configManager); ModelManager.getInstance() .getScheduler() .scheduleAtFixedRate( metricCollector, 0, configManager.getMetricTimeInterval(), TimeUnit.SECONDS); } public synchronized List getMetrics() { return metrics; } public synchronized void setMetrics(List metrics) { this.metrics = metrics; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Encoding.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.openapi; public class Encoding { private String contentType; private String style; private boolean explode; private boolean allowReserved; public Encoding() {} public Encoding(String contentType) { this.contentType = contentType; } public String getContentType() { return contentType; } public void setContentType(String contentType) { this.contentType = contentType; } public boolean isAllowReserved() { return allowReserved; } public void setAllowReserved(boolean allowReserved) { this.allowReserved = allowReserved; } public String getStyle() { return style; } public void setStyle(String style) { this.style = style; } public boolean isExplode() { return explode; } public void setExplode(boolean explode) { this.explode = explode; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Info.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; public class Info { private String title; private String description; private String termsOfService; private String version; public Info() {} public String getTitle() { return title; } public void setTitle(String title) { this.title = title; } public String getDescription() { return description; } public void setDescription(String description) { this.description = description; } public String getTermsOfService() { return termsOfService; } public void setTermsOfService(String termsOfService) { this.termsOfService = termsOfService; } public String getVersion() { return version; } public void setVersion(String version) { this.version = version; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/MediaType.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import java.util.LinkedHashMap; import java.util.Map; public class MediaType { private transient String contentType; private Schema schema; private Map encoding; public MediaType() {} public MediaType(String contentType, Schema schema) { this.contentType = contentType; this.schema = schema; } public String getContentType() { return contentType; } public void setContentType(String contentType) { this.contentType = contentType; } public Schema getSchema() { return schema; } public void setSchema(Schema schema) { this.schema = schema; } public Map getEncoding() { return encoding; } public void setEncoding(Map encoding) { this.encoding = encoding; } public void addEncoding(String contentType, Encoding encoding) { if (this.encoding == null) { this.encoding = new LinkedHashMap<>(); } this.encoding.put(contentType, encoding); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/OpenApi.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import java.util.LinkedHashMap; import java.util.Map; public class OpenApi { private String openapi = "3.0.1"; private Info info; private Map paths; public OpenApi() {} public String getOpenapi() { return openapi; } public void setOpenapi(String openapi) { this.openapi = openapi; } public Info getInfo() { return info; } public void setInfo(Info info) { this.info = info; } public Map getPaths() { return paths; } public void setPaths(Map paths) { this.paths = paths; } public void addPath(String url, Path path) { if (paths == null) { paths = new LinkedHashMap<>(); } paths.put(url, path); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/OpenApiUtils.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import com.amazonaws.ml.mms.archive.Manifest; import com.amazonaws.ml.mms.util.ConnectorType; import com.amazonaws.ml.mms.util.JsonUtils; import com.amazonaws.ml.mms.wlm.Model; import io.netty.handler.codec.http.HttpHeaderValues; import java.util.ArrayList; import java.util.List; public final class OpenApiUtils { private OpenApiUtils() {} public static String listApis(ConnectorType type) { OpenApi openApi = new OpenApi(); Info info = new Info(); info.setTitle("Model Server APIs"); info.setDescription( "Model Server is a flexible and easy to use tool for serving deep learning models"); info.setVersion("1.0.0"); openApi.setInfo(info); if (ConnectorType.BOTH.equals(type) || ConnectorType.INFERENCE_CONNECTOR.equals(type)) { listInferenceApis(openApi); } if (ConnectorType.BOTH.equals(type) || ConnectorType.MANAGEMENT_CONNECTOR.equals(type)) { listManagementApis(openApi); } return JsonUtils.GSON_PRETTY.toJson(openApi); } static void listInferenceApis(OpenApi openApi) { openApi.addPath("/", getApiDescriptionPath(false)); openApi.addPath("/{model_name}/predict", getLegacyPredictPath()); openApi.addPath("/ping", getPingPath()); openApi.addPath("/predictions/{model_name}", getPredictionsPath()); openApi.addPath("/api-description", getApiDescriptionPath(true)); openApi.addPath("/invocations", getInvocationsPath()); openApi.addPath("/models/{model_name}/invoke", getInvocationsPath()); } static void listManagementApis(OpenApi openApi) { openApi.addPath("/", getApiDescriptionPath(false)); openApi.addPath("/models", getModelsPath()); openApi.addPath("/models/{model_name}", getModelManagerPath()); } public static String getModelApi(Model model) { String modelName = model.getModelName(); OpenApi openApi = new OpenApi(); Info info = new Info(); info.setTitle("RESTful API for: " + modelName); info.setVersion("1.0.0"); openApi.setInfo(info); openApi.addPath("/prediction/" + modelName, getModelPath(modelName)); return 
JsonUtils.GSON_PRETTY.toJson(openApi); } private static Path getApiDescriptionPath(boolean legacy) { Schema schema = new Schema("object"); schema.addProperty("openapi", new Schema("string"), true); schema.addProperty("info", new Schema("object"), true); schema.addProperty("paths", new Schema("object"), true); MediaType mediaType = new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); Operation operation = new Operation("apiDescription"); operation.addResponse(new Response("200", "A openapi 3.0.1 descriptor", mediaType)); operation.addResponse(new Response("500", "Internal Server Error", getErrorResponse())); Path path = new Path(); if (legacy) { operation.setDeprecated(true); path.setGet(operation); } else { path.setOptions(operation); } return path; } private static Path getPingPath() { Schema schema = new Schema("object"); schema.addProperty( "status", new Schema("string", "Overall status of the Model Server."), true); MediaType mediaType = new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); Operation operation = new Operation("ping"); operation.addResponse(new Response("200", "Model server status", mediaType)); operation.addResponse(new Response("500", "Internal Server Error", getErrorResponse())); Path path = new Path(); path.setGet(operation); return path; } private static Path getInvocationsPath() { Schema schema = new Schema(); schema.addProperty("model_name", new Schema("string", "Name of model"), false); Schema dataProp = new Schema("string", "Inference input data"); dataProp.setFormat("binary"); schema.addProperty("data", dataProp, true); MediaType multipart = new MediaType(HttpHeaderValues.MULTIPART_FORM_DATA.toString(), schema); RequestBody requestBody = new RequestBody(); requestBody.setRequired(true); requestBody.addContent(multipart); Operation operation = new Operation("invocations", "A generic invocation entry point for all models."); operation.setRequestBody(requestBody); operation.addParameter(new 
QueryParameter("model_name", "Name of model")); MediaType error = getErrorResponse(); MediaType mediaType = new MediaType("*/*", schema); operation.addResponse(new Response("200", "Model specific output data format", mediaType)); operation.addResponse(new Response("400", "Missing model_name parameter", error)); operation.addResponse(new Response("404", "Model not found", error)); operation.addResponse(new Response("500", "Internal Server Error", error)); operation.addResponse( new Response("503", "No worker is available to serve request", error)); Path path = new Path(); path.setPost(operation); return path; } private static Path getPredictionsPath() { Operation post = new Operation( "predictions", "Predictions entry point for each model." + " Use OPTIONS method to get detailed model API input and output description."); post.addParameter(new PathParameter("model_name", "Name of model.")); Schema schema = new Schema("string"); schema.setFormat("binary"); MediaType mediaType = new MediaType("*/*", schema); RequestBody requestBody = new RequestBody(); requestBody.setDescription( "Input data format is defined by each model. Use OPTIONS method to get details for model input format."); requestBody.setRequired(true); requestBody.addContent(mediaType); post.setRequestBody(requestBody); schema = new Schema("string"); schema.setFormat("binary"); mediaType = new MediaType("*/*", schema); Response resp = new Response( "200", "Output data format is defined by each model. 
Use OPTIONS method to get details for model output and output format.", mediaType); post.addResponse(resp); MediaType error = getErrorResponse(); post.addResponse(new Response("404", "Model not found", error)); post.addResponse(new Response("500", "Internal Server Error", error)); post.addResponse(new Response("503", "No worker is available to serve request", error)); Operation options = new Operation("predictionsApi", "Display details of per model input and output."); options.addParameter(new PathParameter("model_name", "Name of model.")); mediaType = new MediaType("application/json", new Schema("object")); options.addResponse(new Response("200", "OK", mediaType)); post.addResponse(new Response("500", "Internal Server Error", error)); Path path = new Path(); path.setPost(post); path.setOptions(options); return path; } private static Path getLegacyPredictPath() { Operation operation = new Operation("predict", "A legacy predict entry point for each model."); operation.addParameter(new PathParameter("model_name", "Name of model to unregister.")); Schema schema = new Schema("string"); schema.setFormat("binary"); MediaType mediaType = new MediaType("*/*", schema); RequestBody requestBody = new RequestBody(); requestBody.setRequired(true); requestBody.setDescription("Input data format is defined by each model."); requestBody.addContent(mediaType); operation.setRequestBody(requestBody); schema = new Schema("string"); schema.setFormat("binary"); mediaType = new MediaType("*/*", schema); MediaType error = getErrorResponse(); operation.addResponse(new Response("200", "Model specific output data format", mediaType)); operation.addResponse(new Response("404", "Model not found", error)); operation.addResponse(new Response("500", "Internal Server Error", error)); operation.addResponse( new Response("503", "No worker is available to serve request", error)); operation.setDeprecated(true); Path path = new Path(); path.setPost(operation); return path; } private static Path 
getModelsPath() { Path path = new Path(); path.setGet(getListModelsOperation()); path.setPost(getRegisterOperation()); return path; } private static Path getModelManagerPath() { Path path = new Path(); path.setGet(getDescribeModelOperation()); path.setPut(getScaleOperation()); path.setDelete(getUnRegisterOperation()); return path; } private static Operation getListModelsOperation() { Operation operation = new Operation("listModels", "List registered models in Model Server."); operation.addParameter( new QueryParameter( "limit", "integer", "100", "Use this parameter to specify the maximum number of items to return. When" + " this value is present, Model Server does not return more than the specified" + " number of items, but it might return fewer. This value is optional. If you" + " include a value, it must be between 1 and 1000, inclusive. If you do not" + " include a value, it defaults to 100.")); operation.addParameter( new QueryParameter( "next_page_token", "The token to retrieve the next set of results. Model Server provides the" + " token when the response from a previous call has more results than the" + " maximum page size.")); operation.addParameter( new QueryParameter( "model_name_pattern", "A model name filter to list only matching models.")); Schema schema = new Schema("object"); schema.addProperty( "nextPageToken", new Schema( "string", "Use this parameter in a subsequent request after you receive a response" + " with truncated results. 
Set it to the value of NextMarker from the" + " truncated response you just received."), false); Schema modelProp = new Schema("object"); modelProp.addProperty("modelName", new Schema("string", "Name of the model."), true); modelProp.addProperty("modelUrl", new Schema("string", "URL of the model."), true); Schema modelsProp = new Schema("array", "A list of registered models."); modelsProp.setItems(modelProp); schema.addProperty("models", modelsProp, true); MediaType json = new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); operation.addResponse(new Response("200", "OK", json)); operation.addResponse(new Response("500", "Internal Server Error", getErrorResponse())); return operation; } private static Operation getRegisterOperation() { Operation operation = new Operation("registerModel", "Register a new model in Model Server."); operation.addParameter( new QueryParameter( "model_url", "string", null, true, "Model archive download url, support local file or HTTP(s) protocol." + " For S3, consider use pre-signed url.")); operation.addParameter( new QueryParameter( "model_name", "Name of model. This value will override modelName in MANIFEST.json if present.")); operation.addParameter( new QueryParameter( "handler", "Inference handler entry-point. This value will override handler in MANIFEST.json if present.")); Parameter runtime = new QueryParameter( "runtime", "Runtime for the model custom service code. 
This value will override runtime in MANIFEST.json if present."); operation.addParameter(runtime); operation.addParameter( new QueryParameter( "batch_size", "integer", "1", "Inference batch size, default: 1.")); operation.addParameter( new QueryParameter( "max_batch_delay", "integer", "100", "Maximum delay for batch aggregation, default: 100.")); operation.addParameter( new QueryParameter( "response_timeout", "integer", "2", "Maximum time, in seconds, the Model Server waits for a response from the model inference code, default: 120.")); operation.addParameter( new QueryParameter( "initial_workers", "integer", "0", "Number of initial workers, default: 0.")); operation.addParameter( new QueryParameter( "synchronous", "boolean", "false", "Decides whether creation of worker synchronous or not, default: false.")); operation.addParameter( new QueryParameter( "preload_model", "boolean", "false", "Decides if model should be preloaded, default: false.")); Manifest.RuntimeType[] types = Manifest.RuntimeType.values(); List runtimeTypes = new ArrayList<>(types.length); for (Manifest.RuntimeType type : types) { runtimeTypes.add(type.toString()); } runtime.getSchema().setEnumeration(runtimeTypes); MediaType status = getStatusResponse(); MediaType error = getErrorResponse(); operation.addResponse(new Response("200", "Model registered", status)); operation.addResponse(new Response("202", "Accepted", status)); operation.addResponse(new Response("210", "Partial Success", status)); operation.addResponse(new Response("400", "Bad request", error)); operation.addResponse(new Response("404", "Model not found", error)); operation.addResponse(new Response("409", "Model already registered", error)); operation.addResponse(new Response("500", "Internal Server Error", error)); return operation; } private static Operation getUnRegisterOperation() { Operation operation = new Operation( "unregisterModel", "Unregister a model from Model Server. This is an asynchronous call by default." 
+ " Caller can call listModels to confirm if all the works has be terminated."); operation.addParameter(new PathParameter("model_name", "Name of model to unregister.")); operation.addParameter( new QueryParameter( "synchronous", "boolean", "false", "Decides whether the call is synchronous or not, default: false.")); operation.addParameter( new QueryParameter( "timeout", "integer", "-1", "Waiting up to the specified wait time if necessary for" + " a worker to complete all pending requests. Use 0 to terminate backend" + " worker process immediately. Use -1 for wait infinitely.")); MediaType status = getStatusResponse(); MediaType error = getErrorResponse(); operation.addResponse(new Response("200", "Model unregistered", status)); operation.addResponse(new Response("202", "Accepted", status)); operation.addResponse(new Response("404", "Model not found", error)); operation.addResponse(new Response("408", "Request Timeout Error", error)); operation.addResponse(new Response("500", "Internal Server Error", error)); return operation; } private static Operation getDescribeModelOperation() { Operation operation = new Operation( "describeModel", "Provides detailed information about the specified model."); operation.addParameter(new PathParameter("model_name", "Name of model to describe.")); Schema schema = new Schema("object"); schema.addProperty("modelName", new Schema("string", "Name of the model."), true); schema.addProperty("modelVersion", new Schema("string", "Version of the model."), true); schema.addProperty("modelUrl", new Schema("string", "URL of the model."), true); schema.addProperty( "minWorkers", new Schema("integer", "Configured minimum number of worker."), true); schema.addProperty( "maxWorkers", new Schema("integer", "Configured maximum number of worker."), true); schema.addProperty("batchSize", new Schema("integer", "Configured batch size."), false); schema.addProperty( "maxBatchDelay", new Schema("integer", "Configured maximum batch delay in ms."), false); 
schema.addProperty( "status", new Schema("string", "Overall health status of the model"), true); Schema workers = new Schema("array", "A list of active backend workers."); Schema worker = new Schema("object"); worker.addProperty("id", new Schema("string", "Worker id"), true); worker.addProperty("startTime", new Schema("string", "Worker start time"), true); worker.addProperty("gpu", new Schema("boolean", "If running on GPU"), false); Schema workerStatus = new Schema("string", "Worker status"); List status = new ArrayList<>(); status.add("READY"); status.add("LOADING"); status.add("UNLOADING"); workerStatus.setEnumeration(status); worker.addProperty("status", workerStatus, true); workers.setItems(worker); schema.addProperty("workers", workers, true); Schema metrics = new Schema("object"); metrics.addProperty( "rejectedRequests", new Schema("integer", "Number requests has been rejected in last 10 minutes."), true); metrics.addProperty( "waitingQueueSize", new Schema("integer", "Number requests waiting in the queue."), true); metrics.addProperty( "requests", new Schema("integer", "Number requests processed in last 10 minutes."), true); schema.addProperty("metrics", metrics, true); MediaType mediaType = new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); operation.addResponse(new Response("200", "OK", mediaType)); operation.addResponse(new Response("500", "Internal Server Error", getErrorResponse())); return operation; } private static Operation getScaleOperation() { Operation operation = new Operation( "setAutoScale", "Configure number of workers for a model, This is a asynchronous call by default." 
+ " Caller need to call describeModel check if the model workers has been changed."); operation.addParameter(new PathParameter("model_name", "Name of model to describe.")); operation.addParameter( new QueryParameter( "min_worker", "integer", "1", "Minimum number of worker processes.")); operation.addParameter( new QueryParameter( "max_worker", "integer", "1", "Maximum number of worker processes.")); operation.addParameter( new QueryParameter( "number_gpu", "integer", "0", "Number of GPU worker processes to create.")); operation.addParameter( new QueryParameter( "synchronous", "boolean", "false", "Decides whether the call is synchronous or not, default: false.")); operation.addParameter( new QueryParameter( "timeout", "integer", "-1", "Waiting up to the specified wait time if necessary for" + " a worker to complete all pending requests. Use 0 to terminate backend" + " worker process immediately. Use -1 for wait infinitely.")); MediaType status = getStatusResponse(); MediaType error = getErrorResponse(); operation.addResponse(new Response("200", "Model workers updated", status)); operation.addResponse(new Response("202", "Accepted", status)); operation.addResponse(new Response("210", "Partial Success", status)); operation.addResponse(new Response("400", "Bad request", error)); operation.addResponse(new Response("404", "Model not found", error)); operation.addResponse(new Response("500", "Internal Server Error", error)); return operation; } private static Path getModelPath(String modelName) { Operation operation = new Operation(modelName, "A predict entry point for model: " + modelName + '.'); operation.addResponse(new Response("200", "OK")); operation.addResponse(new Response("500", "Internal Server Error", getErrorResponse())); Path path = new Path(); path.setPost(operation); return path; } private static MediaType getErrorResponse() { Schema schema = new Schema("object"); schema.addProperty("code", new Schema("integer", "Error code."), true); 
schema.addProperty("type", new Schema("string", "Error type."), true); schema.addProperty("message", new Schema("string", "Error message."), true); return new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); } private static MediaType getStatusResponse() { Schema schema = new Schema("object"); schema.addProperty("status", new Schema("string", "Error type."), true); return new MediaType(HttpHeaderValues.APPLICATION_JSON.toString(), schema); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Operation.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; public class Operation { private String summary; private String description; private String operationId; private List parameters = new ArrayList<>(); private RequestBody requestBody; private Map responses; private Boolean deprecated; public Operation() {} public Operation(String operationId) { this(operationId, null); } public Operation(String operationId, String description) { this.operationId = operationId; this.description = description; } public String getSummary() { return summary; } public void setSummary(String summary) { this.summary = summary; } public String getDescription() { return description; } public void setDescription(String description) { this.description = description; } public String getOperationId() { return operationId; } public void setOperationId(String operationId) { this.operationId = operationId; } public List getParameters() { return parameters; } public void setParameters(List parameters) { this.parameters = parameters; } public void addParameter(Parameter parameter) { if (parameters == null) { parameters = new ArrayList<>(); } parameters.add(parameter); } public RequestBody getRequestBody() { return requestBody; } public void setRequestBody(RequestBody requestBody) { this.requestBody = requestBody; } public Map getResponses() { return responses; } public void setResponses(Map responses) { this.responses = responses; } public void addResponse(Response response) { if (responses == null) { responses = new LinkedHashMap<>(); } responses.put(response.getCode(), response); } public Boolean getDeprecated() { return deprecated; } public void setDeprecated(Boolean deprecated) { this.deprecated = deprecated; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Parameter.java ================================================ /* * Copyright 
2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.openapi;

/**
 * Base class for OpenAPI Parameter objects. Concrete subclasses (path and query parameters)
 * set the {@code in} location and required-ness; fields are protected so subclass constructors
 * can assign them directly.
 */
@SuppressWarnings("PMD.AbstractClassWithoutAbstractMethod")
public abstract class Parameter {

    protected String type;
    protected String in; // parameter location, e.g. "path" or "query"
    protected String name;
    protected String description;
    protected boolean required;
    protected Boolean deprecated;
    protected Boolean allowEmptyValue;
    protected String style;
    protected Boolean explode;
    protected Schema schema;

    public void setType(String type) {
        this.type = type;
    }

    public String getType() {
        return type;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getIn() {
        return in;
    }

    public void setIn(String in) {
        this.in = in;
    }

    public String getDescription() {
        return description;
    }

    public void setDescription(String description) {
        this.description = description;
    }

    public boolean isRequired() {
        return required;
    }

    public void setRequired(boolean required) {
        this.required = required;
    }

    public Boolean getDeprecated() {
        return deprecated;
    }

    public void setDeprecated(Boolean deprecated) {
        this.deprecated = deprecated;
    }

    public Boolean getAllowEmptyValue() {
        return allowEmptyValue;
    }

    public void setAllowEmptyValue(Boolean allowEmptyValue) {
        this.allowEmptyValue = allowEmptyValue;
    }

    public String getStyle() {
        return style;
    }

    public void setStyle(String style) {
        this.style = style;
    }

    public Boolean getExplode() {
        return explode;
    }

    public void setExplode(Boolean explode) {
        this.explode = explode;
    }

    public Schema getSchema() {
        return schema;
    }

    public void setSchema(Schema schema) {
        this.schema = schema;
    }
}
================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Path.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.openapi;

import java.util.List;

/**
 * Models an OpenAPI Path Item object: one {@link Operation} per HTTP method, plus parameters
 * shared by all operations on the path.
 */
public class Path {

    private Operation get;
    private Operation put;
    private Operation post;
    private Operation head;
    private Operation delete;
    private Operation patch;
    private Operation options;
    private List parameters;

    public Operation getGet() {
        return get;
    }

    public void setGet(Operation get) {
        this.get = get;
    }

    public Operation getPut() {
        return put;
    }

    public void setPut(Operation put) {
        this.put = put;
    }

    public Operation getPost() {
        return post;
    }

    public void setPost(Operation post) {
        this.post = post;
    }

    public Operation getHead() {
        return head;
    }

    public void setHead(Operation head) {
        this.head = head;
    }

    public Operation getDelete() {
        return delete;
    }

    public void setDelete(Operation delete) {
        this.delete = delete;
    }

    public Operation getPatch() {
        return patch;
    }

    public void setPatch(Operation patch) {
        this.patch = patch;
    }

    public Operation getOptions() {
        return options;
    }

    public void setOptions(Operation options) {
        this.options = options;
    }

    public List getParameters() {
        return parameters;
    }

    public void setParameters(List parameters) {
        this.parameters = parameters;
    }
}
================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/PathParameter.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.openapi; public class PathParameter extends Parameter { public PathParameter() { this(null, "string", null, null); } public PathParameter(String name, String description) { this(name, "string", null, description); } public PathParameter(String name, String type, String defaultValue, String description) { this.name = name; this.description = description; in = "path"; required = true; schema = new Schema(type, null, defaultValue); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/QueryParameter.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; public class QueryParameter extends Parameter { public QueryParameter() { this(null, "string", null, false, null); } public QueryParameter(String name, String description) { this(name, "string", null, false, description); } public QueryParameter(String name, String type, String description) { this(name, type, null, false, description); } public QueryParameter(String name, String type, String defaultValue, String description) { this(name, type, defaultValue, false, description); } public QueryParameter( String name, String type, String defaultValue, boolean required, String description) { this.name = name; this.description = description; in = "query"; this.required = required; schema = new Schema(type, null, defaultValue); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/RequestBody.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import java.util.LinkedHashMap; import java.util.Map; public class RequestBody { private String description; private Map content; private boolean required; public RequestBody() {} public String getDescription() { return description; } public void setDescription(String description) { this.description = description; } public Map getContent() { return content; } public void setContent(Map content) { this.content = content; } public void addContent(MediaType mediaType) { if (content == null) { content = new LinkedHashMap<>(); } content.put(mediaType.getContentType(), mediaType); } public boolean isRequired() { return required; } public void setRequired(boolean required) { this.required = required; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Response.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import java.util.LinkedHashMap; import java.util.Map; public class Response { private transient String code; private String description; private Map content; public Response() {} public Response(String code, String description) { this.code = code; this.description = description; } public Response(String code, String description, MediaType mediaType) { this.code = code; this.description = description; content = new LinkedHashMap<>(); content.put(mediaType.getContentType(), mediaType); } public String getCode() { return code; } public String getDescription() { return description; } public void setDescription(String description) { this.description = description; } public Map getContent() { return content; } public void setContent(Map content) { this.content = content; } public void addContent(MediaType mediaType) { if (content == null) { content = new LinkedHashMap<>(); } content.put(mediaType.getContentType(), mediaType); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/openapi/Schema.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.openapi; import com.google.gson.annotations.SerializedName; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; public class Schema { private String type; private String format; private String name; private List required; private Map properties; private Schema items; private String description; private Object example; private Schema additionalProperties; private String discriminator; @SerializedName("enum") private List enumeration; @SerializedName("default") private String defaultValue; public Schema() {} public Schema(String type) { this(type, null, null); } public Schema(String type, String description) { this(type, description, null); } public Schema(String type, String description, String defaultValue) { this.type = type; this.description = description; this.defaultValue = defaultValue; } public String getType() { return type; } public void setType(String type) { this.type = type; } public String getFormat() { return format; } public void setFormat(String format) { this.format = format; } public String getName() { return name; } public void setName(String name) { this.name = name; } public List getRequired() { return required; } public void setRequired(List required) { this.required = required; } public Map getProperties() { return properties; } public void setProperties(Map properties) { this.properties = properties; } public void addProperty(String key, Schema schema, boolean requiredProperty) { if (properties == null) { properties = new LinkedHashMap<>(); } properties.put(key, schema); if (requiredProperty) { if (required == null) { required = new ArrayList<>(); } required.add(key); } } public Schema getItems() { return items; } public void setItems(Schema items) { this.items = items; } public String getDescription() { return description; } public void setDescription(String description) { this.description = description; } public Object getExample() { return example; } 
public void setExample(Object example) { this.example = example; } public Schema getAdditionalProperties() { return additionalProperties; } public void setAdditionalProperties(Schema additionalProperties) { this.additionalProperties = additionalProperties; } public String getDiscriminator() { return discriminator; } public void setDiscriminator(String discriminator) { this.discriminator = discriminator; } public List getEnumeration() { return enumeration; } public void setEnumeration(List enumeration) { this.enumeration = enumeration; } public String getDefaultValue() { return defaultValue; } public void setDefaultValue(String defaultValue) { this.defaultValue = defaultValue; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerContext.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.wlm.ModelManager; import java.util.HashMap; import java.util.Map; import java.util.Properties; import software.amazon.ai.mms.servingsdk.Context; import software.amazon.ai.mms.servingsdk.Model; public class ModelServerContext implements Context { @Override public Properties getConfig() { return ConfigManager.getInstance().getConfiguration(); } @Override public Map getModels() { HashMap r = new HashMap<>(); ModelManager.getInstance().getModels().forEach((k, v) -> r.put(k, new ModelServerModel(v))); return r; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerModel.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import com.amazonaws.ml.mms.wlm.ModelManager; import java.util.ArrayList; import java.util.List; import software.amazon.ai.mms.servingsdk.Model; import software.amazon.ai.mms.servingsdk.Worker; public class ModelServerModel implements Model { private final com.amazonaws.ml.mms.wlm.Model model; public ModelServerModel(com.amazonaws.ml.mms.wlm.Model m) { model = m; } @Override public String getModelName() { return model.getModelName(); } @Override public String getModelUrl() { return model.getModelUrl(); } @Override public String getModelHandler() { return model.getModelArchive().getHandler(); } @Override public List getModelWorkers() { ArrayList list = new ArrayList<>(); ModelManager.getInstance() .getWorkers(model.getModelName()) .forEach(r -> list.add(new ModelWorker(r))); return list; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerRequest.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.QueryStringDecoder; import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.List; import java.util.Map; import software.amazon.ai.mms.servingsdk.http.Request; public class ModelServerRequest implements Request { private FullHttpRequest req; private QueryStringDecoder decoder; public ModelServerRequest(FullHttpRequest r, QueryStringDecoder d) { req = r; decoder = d; } @Override public List getHeaderNames() { return new ArrayList<>(req.headers().names()); } @Override public String getRequestURI() { return req.uri(); } @Override public Map> getParameterMap() { return decoder.parameters(); } @Override public List getParameter(String k) { return decoder.parameters().get(k); } @Override public String getContentType() { return HttpUtil.getMimeType(req).toString(); } @Override public ByteArrayInputStream getInputStream() { return new ByteArrayInputStream(req.content().array()); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelServerResponse.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import io.netty.buffer.ByteBufOutputStream; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; import java.io.OutputStream; import software.amazon.ai.mms.servingsdk.http.Response; public class ModelServerResponse implements Response { private FullHttpResponse response; public ModelServerResponse(FullHttpResponse rsp) { response = rsp; } @Override public void setStatus(int i) { response.setStatus(HttpResponseStatus.valueOf(i)); } @Override public void setStatus(int i, String s) { response.setStatus(HttpResponseStatus.valueOf(i, s)); } @Override public void setHeader(String k, String v) { response.headers().set(k, v); } @Override public void addHeader(String k, String v) { response.headers().add(k, v); } @Override public void setContentType(String contentType) { response.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); } @Override public OutputStream getOutputStream() { return new ByteBufOutputStream(response.content()); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/ModelWorker.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import com.amazonaws.ml.mms.wlm.WorkerState; import com.amazonaws.ml.mms.wlm.WorkerThread; import software.amazon.ai.mms.servingsdk.Worker; public class ModelWorker implements Worker { boolean running; long memory; public ModelWorker(WorkerThread t) { running = t.getState() == WorkerState.WORKER_MODEL_LOADED; memory = t.getMemory(); } @Override public boolean isRunning() { return running; } @Override public long getWorkerMemory() { return memory; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/servingsdk/impl/PluginsManager.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.servingsdk.impl; import com.amazonaws.ml.mms.http.InvalidPluginException; import java.lang.annotation.Annotation; import java.util.HashMap; import java.util.Map; import java.util.ServiceLoader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; public final class PluginsManager { private static final PluginsManager INSTANCE = new PluginsManager(); private Logger logger = LoggerFactory.getLogger(PluginsManager.class); private Map inferenceEndpoints; private Map managementEndpoints; private PluginsManager() {} public static PluginsManager getInstance() { return INSTANCE; } public void initialize() { inferenceEndpoints = initInferenceEndpoints(); managementEndpoints = initManagementEndpoints(); } private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) { return a instanceof Endpoint && !((Endpoint) a).urlPattern().isEmpty() && ((Endpoint) a).endpointType().equals(type); } private HashMap getEndpoints(EndpointTypes type) throws InvalidPluginException { ServiceLoader loader = ServiceLoader.load(ModelServerEndpoint.class); HashMap ep = new HashMap<>(); for (ModelServerEndpoint mep : loader) { Class modelServerEndpointClassObj = mep.getClass(); Annotation[] annotations = modelServerEndpointClassObj.getAnnotations(); for (Annotation a : annotations) { if (validateEndpointPlugin(a, type)) { if (ep.get(((Endpoint) a).urlPattern()) != null) { throw new InvalidPluginException( "Multiple plugins found for endpoint " + "\"" + ((Endpoint) a).urlPattern() + "\""); } logger.info("Loading plugin for endpoint {}", ((Endpoint) a).urlPattern()); ep.put(((Endpoint) a).urlPattern(), mep); } } } return ep; } private HashMap initInferenceEndpoints() { return getEndpoints(EndpointTypes.INFERENCE); } private HashMap initManagementEndpoints() 
{ return getEndpoints(EndpointTypes.MANAGEMENT); } public Map getInferenceEndpoints() { return inferenceEndpoints; } public Map getManagementEndpoints() { return managementEndpoints; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.SelfSignedCertificate; import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Field; import java.net.InetAddress; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.KeyException; import java.security.KeyFactory; import java.security.KeyStore; import java.security.PrivateKey; import java.security.cert.Certificate; import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.util.Arrays; import java.util.Base64; import java.util.Collection; import java.util.Enumeration; import java.util.HashMap; import java.util.InvalidPropertiesFormatException; import java.util.List; import java.util.Properties; 
import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.regex.PatternSyntaxException; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.io.IOUtils; public final class ConfigManager { // Variables that can be configured through config.properties and Environment Variables // NOTE: Variables which can be configured through environment variables **SHOULD** have a // "MMS_" prefix private static final String MMS_DEBUG = "debug"; private static final String MMS_INFERENCE_ADDRESS = "inference_address"; private static final String MMS_MANAGEMENT_ADDRESS = "management_address"; private static final String MMS_LOAD_MODELS = "load_models"; private static final String MMS_BLACKLIST_ENV_VARS = "blacklist_env_vars"; private static final String MMS_DEFAULT_WORKERS_PER_MODEL = "default_workers_per_model"; private static final String MMS_DEFAULT_RESPONSE_TIMEOUT = "default_response_timeout"; private static final String MMS_DEFAULT_RESPONSE_TIMEOUT_SECONDS = "default_response_timeout_seconds"; private static final String MMS_UNREGISTER_MODEL_TIMEOUT = "unregister_model_timeout"; private static final String MMS_NUMBER_OF_NETTY_THREADS = "number_of_netty_threads"; private static final String MMS_NETTY_CLIENT_THREADS = "netty_client_threads"; private static final String MMS_JOB_QUEUE_SIZE = "job_queue_size"; private static final String MMS_NUMBER_OF_GPU = "number_of_gpu"; private static final String MMS_ASYNC_LOGGING = "async_logging"; private static final String MMS_CORS_ALLOWED_ORIGIN = "cors_allowed_origin"; private static final String MMS_CORS_ALLOWED_METHODS = "cors_allowed_methods"; private static final String MMS_CORS_ALLOWED_HEADERS = "cors_allowed_headers"; private static final String MMS_DECODE_INPUT_REQUEST = "decode_input_request"; private static final String MMS_KEYSTORE = "keystore"; private static final String MMS_KEYSTORE_PASS = 
"keystore_pass"; private static final String MMS_KEYSTORE_TYPE = "keystore_type"; private static final String MMS_CERTIFICATE_FILE = "certificate_file"; private static final String MMS_PRIVATE_KEY_FILE = "private_key_file"; private static final String MMS_MAX_REQUEST_SIZE = "max_request_size"; private static final String MMS_MAX_RESPONSE_SIZE = "max_response_size"; private static final String MMS_DEFAULT_SERVICE_HANDLER = "default_service_handler"; private static final String MMS_PRELOAD_MODEL = "preload_model"; private static final String MODEL_SERVER_HOME = "model_server_home"; private static final String MMS_MODEL_STORE = "model_store"; private static final String MMS_PREFER_DIRECT_BUFFER = "prefer_direct_buffer"; // Configuration which are not documented or enabled through environment variables private static final String USE_NATIVE_IO = "use_native_io"; private static final String IO_RATIO = "io_ratio"; private static final String METRIC_TIME_INTERVAL = "metric_time_interval"; private static final String ENABLE_ENVVARS_CONFIG = "enable_envvars_config"; // Variables which are local public static final String MODEL_METRICS_LOGGER = "MODEL_METRICS"; public static final String MODEL_LOGGER = "MODEL_LOG"; public static final String MODEL_SERVER_METRICS_LOGGER = "MMS_METRICS"; private Pattern blacklistPattern; private Properties prop; private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$"); private static ConfigManager instance; private String hostName; private ConfigManager(Arguments args) { prop = new Properties(); String filePath = System.getenv("MMS_CONFIG_FILE"); if (filePath == null) { filePath = args.getMmsConfigFile(); if (filePath == null) { filePath = System.getProperty("mmsConfigFile", "config.properties"); } } File file = new File(filePath); if (file.exists()) { try (FileInputStream stream = new FileInputStream(file)) { prop.load(stream); prop.put("mmsConfigFile", filePath); } catch (IOException e) { throw new 
IllegalStateException("Unable to read configuration file", e);
        }
    }

    // Expand ${ENV_VAR} placeholders in property values before anything reads them.
    resolveEnvVarVals(prop);

    // Command-line arguments take precedence over values loaded from the config file.
    String modelStore = args.getModelStore();
    if (modelStore != null) {
        prop.setProperty(MMS_MODEL_STORE, modelStore);
    }

    String[] models = args.getModels();
    if (models != null) {
        prop.setProperty(MMS_LOAD_MODELS, String.join(",", models));
    }

    // Cap the configured GPU count at the number of GPUs nvidia-smi actually reports.
    prop.setProperty(
            MMS_NUMBER_OF_GPU,
            String.valueOf(
                    Integer.min(
                            getAvailableGpu(),
                            getIntProperty(MMS_NUMBER_OF_GPU, Integer.MAX_VALUE))));

    String pythonExecutable = args.getPythonExecutable();
    if (pythonExecutable != null) {
        prop.setProperty("PYTHON_EXECUTABLE", pythonExecutable);
    }

    // Best-effort host name resolution; "Unknown" is used as a fallback
    // (hostName is consumed as a metrics dimension, see NettyUtils).
    try {
        InetAddress ip = InetAddress.getLocalHost();
        hostName = ip.getHostName();
    } catch (UnknownHostException e) {
        hostName = "Unknown";
    }

    if (Boolean.parseBoolean(prop.getProperty(MMS_ASYNC_LOGGING))) {
        enableAsyncLogging();
    }

    if (Boolean.parseBoolean(getEnableEnvVarsConfig())) {
        // Environment variables have higher precedence over the config file variables
        setSystemVars();
    }
}

/**
 * Replaces every {@code ${ENV_VAR}} occurrence inside property values with the value of
 * that environment variable.
 *
 * @param prop properties to rewrite in place
 * @throws IllegalArgumentException if a referenced environment variable is not set
 */
private void resolveEnvVarVals(Properties prop) {
    Set keys = prop.stringPropertyNames();
    for (String key : keys) {
        String val = prop.getProperty(key);
        // 'pattern' is the ${...} matcher declared earlier in this class.
        Matcher matcher = pattern.matcher(val);
        if (matcher.find()) {
            StringBuffer sb = new StringBuffer();
            do {
                String envVar = matcher.group(1);
                if (System.getenv(envVar) == null) {
                    throw new IllegalArgumentException(
                            "Invalid Environment Variable " + envVar);
                }
                matcher.appendReplacement(sb, System.getenv(envVar));
            } while (matcher.find());
            matcher.appendTail(sb);
            prop.setProperty(key, sb.toString());
        }
    }
}

/**
 * Overrides config-file values with environment variables of the same name.
 *
 * <p>Reflects over this class's {@code MMS_*} constant fields: for each constant whose name
 * is set in the environment, the property keyed by the constant's value is overwritten.
 */
private void setSystemVars() {
    Class configClass = ConfigManager.class;
    Field[] fields = configClass.getDeclaredFields();
    for (Field f : fields) {
        if (f.getName().startsWith("MMS_")) {
            String val = System.getenv(f.getName());
            if (val != null) {
                try {
                    prop.setProperty((String) f.get(ConfigManager.class), val);
                } catch (IllegalAccessException e) {
                    e.printStackTrace(); // NOPMD
                }
            }
        }
    }
}

/** Returns "true" if environment variables may override config-file values (default "false"). */
String getEnableEnvVarsConfig() {
    return prop.getProperty(ENABLE_ENVVARS_CONFIG, "false");
}

/** Returns the local host name resolved at construction time, or "Unknown". */
public String getHostName() {
    return hostName;
}

/** Initializes the process-wide singleton from parsed command-line arguments. */
public static void init(Arguments args) {
    instance = new ConfigManager(args);
}

/** Returns the singleton created by {@link #init}. */
public static ConfigManager getInstance() {
    return instance;
}

/** True if debug mode is enabled via the MMS_DEBUG system property or config property. */
public boolean isDebug() {
    return Boolean.getBoolean("MMS_DEBUG")
            || Boolean.parseBoolean(prop.getProperty(MMS_DEBUG, "false"));
}

/**
 * Parses the configured binding address into a {@link Connector}.
 *
 * @param management true for the management endpoint (default http://127.0.0.1:8081),
 *     false for inference (default http://127.0.0.1:8080)
 */
public Connector getListener(boolean management) {
    String binding;
    if (management) {
        binding = prop.getProperty(MMS_MANAGEMENT_ADDRESS, "http://127.0.0.1:8081");
    } else {
        binding = prop.getProperty(MMS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
    }
    return Connector.parse(binding, management);
}

/** Returns the preload_model setting as a string, default "false". */
public String getPreloadModel() {
    return getProperty(MMS_PRELOAD_MODEL, "false");
}

/** Whether Netty should prefer direct (off-heap) buffers; default false. */
public boolean getPreferDirectBuffer() {
    return Boolean.parseBoolean(getProperty(MMS_PREFER_DIRECT_BUFFER, "false"));
}

/** Netty server worker thread count; 0 lets Netty pick its default. */
public int getNettyThreads() {
    return getIntProperty(MMS_NUMBER_OF_NETTY_THREADS, 0);
}

/** Netty backend-client thread count; 0 lets Netty pick its default. */
public int getNettyClientThreads() {
    return getIntProperty(MMS_NETTY_CLIENT_THREADS, 0);
}

/** Maximum queued inference jobs per model; default 100. */
public int getJobQueueSize() {
    return getIntProperty(MMS_JOB_QUEUE_SIZE, 100);
}

/** Effective GPU count (already capped by getAvailableGpu() in the constructor). */
public int getNumberOfGpu() {
    return getIntProperty(MMS_NUMBER_OF_GPU, 0);
}

/** Default custom-service handler entry point, or null if unset. */
public String getMmsDefaultServiceHandler() {
    return getProperty(MMS_DEFAULT_SERVICE_HANDLER, null);
}

/**
 * Returns a snapshot of the configuration.
 *
 * <p>NOTE(review): {@code new Properties(prop)} installs {@code prop} as the defaults of the
 * returned object rather than copying entries — presumably intentional, but worth confirming
 * if callers enumerate the result with anything other than string-property lookups.
 */
public Properties getConfiguration() {
    return new Properties(prop);
}

/** Explicitly configured workers per model; 0 means "not configured". */
public int getConfiguredDefaultWorkersPerModel() {
    return getIntProperty(MMS_DEFAULT_WORKERS_PER_MODEL, 0);
}

/**
 * Resolves the default number of workers per model.
 *
 * <p>Resolution order: 1 in debug mode; explicit config value; cached NUM_WORKERS property;
 * GPU count; finally the number of available CPU cores. The result is cached back into
 * NUM_WORKERS.
 */
public int getDefaultWorkers() {
    if (isDebug()) {
        return 1;
    }
    int workers = getConfiguredDefaultWorkersPerModel();
    if ((workers == 0) && (prop.getProperty("NUM_WORKERS", null) != null)) {
        workers = getIntProperty("NUM_WORKERS", 0);
    }
    if (workers == 0) {
        workers = getNumberOfGpu();
    }
    if (workers == 0) {
        workers = Runtime.getRuntime().availableProcessors();
    }
    setProperty("NUM_WORKERS", Integer.toString(workers));
    return workers;
}

/** Metric collection interval in seconds; default 60. */
public int getMetricTimeInterval() {
    return getIntProperty(METRIC_TIME_INTERVAL, 60);
}

/**
 * Resolves the MMS home directory.
 *
 * <p>Lookup order: MODEL_SERVER_HOME environment variable, then system property, then config
 * property, finally an upward directory search for an "mms" folder (see findMmsHome). When a
 * configured value is used it must point at an existing directory.
 */
public String getModelServerHome() {
    String mmsHome = System.getenv("MODEL_SERVER_HOME");
    if (mmsHome == null) {
        mmsHome = System.getProperty(MODEL_SERVER_HOME);
        if (mmsHome == null) {
            mmsHome = getProperty(MODEL_SERVER_HOME, null);
            if (mmsHome == null) {
                mmsHome = getCanonicalPath(findMmsHome());
                return mmsHome;
            }
        }
    }
    File dir = new File(mmsHome);
    if (!dir.exists()) {
        throw new IllegalArgumentException("Model server home not exist: " + mmsHome);
    }
    mmsHome = getCanonicalPath(dir);
    return mmsHome;
}

/** Python interpreter used to launch backend workers; defaults to "python" on PATH. */
public String getPythonExecutable() {
    return prop.getProperty("PYTHON_EXECUTABLE", "python");
}

/** Canonicalized model store path, or null if not configured. */
public String getModelStore() {
    return getCanonicalPath(prop.getProperty(MMS_MODEL_STORE));
}

/** Comma-separated list of models to load at startup, or null. */
public String getLoadModels() {
    return prop.getProperty(MMS_LOAD_MODELS);
}

/** Compiled blacklist regex; set by validateConfigurations(). */
public Pattern getBlacklistPattern() {
    return blacklistPattern;
}

/** CORS Access-Control-Allow-Origin value, or null to disable the header. */
public String getCorsAllowedOrigin() {
    return prop.getProperty(MMS_CORS_ALLOWED_ORIGIN);
}

/** CORS Access-Control-Allow-Methods value, or null to disable the header. */
public String getCorsAllowedMethods() {
    return prop.getProperty(MMS_CORS_ALLOWED_METHODS);
}

/** CORS Access-Control-Allow-Headers value, or null to disable the header. */
public String getCorsAllowedHeaders() {
    return prop.getProperty(MMS_CORS_ALLOWED_HEADERS);
}

/**
 * Builds the server TLS context (TLSv1.2 only, two ECDHE-RSA cipher suites).
 *
 * <p>Key material is sourced, in order of preference, from: a keystore (first key entry
 * found), a PEM private key + certificate chain file pair, or a generated self-signed
 * certificate as a last resort.
 */
public SslContext getSslContext() throws IOException, GeneralSecurityException {
    List supportedCiphers =
            Arrays.asList(
                    "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
                    "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256");
    PrivateKey privateKey;
    X509Certificate[] chain;
    String keyStoreFile = prop.getProperty(MMS_KEYSTORE);
    String privateKeyFile = prop.getProperty(MMS_PRIVATE_KEY_FILE);
    String certificateFile = prop.getProperty(MMS_CERTIFICATE_FILE);
    if (keyStoreFile != null) {
        char[] keystorePass = getProperty(MMS_KEYSTORE_PASS, "changeit").toCharArray();
        String keystoreType = getProperty(MMS_KEYSTORE_TYPE, "PKCS12");
        KeyStore keyStore = KeyStore.getInstance(keystoreType);
        try (InputStream is = new FileInputStream(keyStoreFile)) {
            keyStore.load(is, keystorePass);
        }
        // Use the first alias that holds a private key entry.
        Enumeration en = keyStore.aliases();
        String keyAlias = null;
        while (en.hasMoreElements()) {
            String alias = en.nextElement();
            if (keyStore.isKeyEntry(alias)) {
                keyAlias = alias;
                break;
            }
        }
        if (keyAlias == null) {
            throw new KeyException("No key entry found in keystore.");
        }
        privateKey = (PrivateKey) keyStore.getKey(keyAlias, keystorePass);
        Certificate[] certs = keyStore.getCertificateChain(keyAlias);
        chain = new X509Certificate[certs.length];
        for (int i = 0; i < certs.length; ++i) {
            chain[i] = (X509Certificate) certs[i];
        }
    } else if (privateKeyFile != null && certificateFile != null) {
        privateKey = loadPrivateKey(privateKeyFile);
        chain = loadCertificateChain(certificateFile);
    } else {
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        privateKey = ssc.key();
        chain = new X509Certificate[] {ssc.cert()};
    }
    return SslContextBuilder.forServer(privateKey, chain)
            .protocols("TLSv1.2")
            .ciphers(supportedCiphers)
            .build();
}

/**
 * Loads an RSA private key from a PEM file.
 *
 * <p>Strips the PEM armor lines, base64-decodes the body, and first tries PKCS#8; on failure
 * the bytes are assumed to be a traditional OpenSSL (PKCS#1) key and are wrapped via
 * {@link OpenSslKey#convertPrivateKey}.
 */
private PrivateKey loadPrivateKey(String keyFile) throws IOException, GeneralSecurityException {
    KeyFactory keyFactory = KeyFactory.getInstance("RSA");
    try (InputStream is = new FileInputStream(keyFile)) {
        String content = IOUtils.toString(is, StandardCharsets.UTF_8);
        content = content.replaceAll("-----(BEGIN|END)( RSA)? PRIVATE KEY-----\\s*", "");
        byte[] buf = Base64.getMimeDecoder().decode(content);
        try {
            PKCS8EncodedKeySpec privKeySpec = new PKCS8EncodedKeySpec(buf);
            return keyFactory.generatePrivate(privKeySpec);
        } catch (InvalidKeySpecException e) {
            // old private key is OpenSSL format private key
            buf = OpenSslKey.convertPrivateKey(buf);
            PKCS8EncodedKeySpec privKeySpec = new PKCS8EncodedKeySpec(buf);
            return keyFactory.generatePrivate(privKeySpec);
        }
    }
}

/** Loads all X.509 certificates found in the given file, preserving order. */
private X509Certificate[] loadCertificateChain(String keyFile)
        throws IOException, GeneralSecurityException {
    CertificateFactory cf = CertificateFactory.getInstance("X.509");
    try (InputStream is = new FileInputStream(keyFile)) {
        Collection certs = cf.generateCertificates(is);
        int i = 0;
        X509Certificate[] chain = new X509Certificate[certs.size()];
        for (Certificate cert : certs) {
            chain[i++] = (X509Certificate) cert;
        }
        return chain;
    }
}

/** Generic string property accessor with a default. */
public String getProperty(String key, String def) {
    return prop.getProperty(key, def);
}

/**
 * Validates configuration values that need parsing; currently compiles the env-var
 * blacklist regex.
 *
 * @throws InvalidPropertiesFormatException if the blacklist regex does not compile
 */
public void validateConfigurations() throws InvalidPropertiesFormatException {
    String blacklistVars = prop.getProperty(MMS_BLACKLIST_ENV_VARS, "");
    try {
        blacklistPattern = Pattern.compile(blacklistVars);
    } catch (PatternSyntaxException e) {
        throw new InvalidPropertiesFormatException(e);
    }
}

/** Renders a human-readable summary of the effective configuration for startup logging. */
public String dumpConfigurations() {
    Runtime runtime = Runtime.getRuntime();
    return "\nMMS Home: "
            + getModelServerHome()
            + "\nCurrent directory: "
            + getCanonicalPath(".")
            + "\nTemp directory: "
            + System.getProperty("java.io.tmpdir")
            + "\nNumber of GPUs: "
            + getNumberOfGpu()
            + "\nNumber of CPUs: "
            + runtime.availableProcessors()
            + "\nMax heap size: "
            + (runtime.maxMemory() / 1024 / 1024)
            + " M\nPython executable: "
            + (getPythonExecutable() == null ? "N/A" : getPythonExecutable())
            + "\nConfig file: "
            + prop.getProperty("mmsConfigFile", "N/A")
            + "\nInference address: "
            + getListener(false)
            + "\nManagement address: "
            + getListener(true)
            + "\nModel Store: "
            + (getModelStore() == null ? "N/A" : getModelStore())
            + "\nInitial Models: "
            + (getLoadModels() == null ? "N/A" : getLoadModels())
            + "\nLog dir: "
            + getCanonicalPath(System.getProperty("LOG_LOCATION"))
            + "\nMetrics dir: "
            + getCanonicalPath(System.getProperty("METRICS_LOCATION"))
            + "\nNetty threads: "
            + getNettyThreads()
            + "\nNetty client threads: "
            + getNettyClientThreads()
            + "\nDefault workers per model: "
            + getDefaultWorkers()
            + "\nBlacklist Regex: "
            + prop.getProperty(MMS_BLACKLIST_ENV_VARS, "N/A")
            + "\nMaximum Response Size: "
            + prop.getProperty(MMS_MAX_RESPONSE_SIZE, "6553500")
            + "\nMaximum Request Size: "
            + prop.getProperty(MMS_MAX_REQUEST_SIZE, "6553500")
            + "\nPreload model: "
            + prop.getProperty(MMS_PRELOAD_MODEL, "false")
            + "\nPrefer direct buffer: "
            + prop.getProperty(MMS_PREFER_DIRECT_BUFFER, "false");
}

/** Whether epoll/kqueue native transports may be used; default true. */
public boolean useNativeIo() {
    return Boolean.parseBoolean(prop.getProperty(USE_NATIVE_IO, "true"));
}

/** NIO event-loop I/O ratio (percentage of time spent on I/O); default 50. */
public int getIoRatio() {
    return getIntProperty(IO_RATIO, 50);
}

/** Maximum HTTP response size in bytes; default 6553500. */
public int getMaxResponseSize() {
    return getIntProperty(MMS_MAX_RESPONSE_SIZE, 6553500);
}

/** Maximum HTTP request size in bytes; default 6553500. */
public int getMaxRequestSize() {
    return getIntProperty(MMS_MAX_REQUEST_SIZE, 6553500);
}

/** Package-private setter used by this class and tests to cache derived values. */
void setProperty(String key, String value) {
    prop.setProperty(key, value);
}

/**
 * Integer property accessor with a default.
 *
 * <p>NOTE(review): a present-but-malformed value propagates NumberFormatException to the
 * caller rather than falling back to the default.
 */
private int getIntProperty(String key, int def) {
    String value = prop.getProperty(key);
    if (value == null) {
        return def;
    }
    return Integer.parseInt(value);
}

public int getDefaultResponseTimeoutSeconds() {
    // TODO The MMS_DEFAULT_RESPONSE_TIMEOUT variable was never intended to represent minutes,
    // but due to a bug that's what it did. We'd like to remove this and match the documented
    // behavior, but for now we're being cautious about backward compatibility.
    // Check both properties, prefer seconds if provided, convert to seconds for return value
    int timeoutSeconds =
            Integer.parseInt(prop.getProperty(MMS_DEFAULT_RESPONSE_TIMEOUT_SECONDS, "-1"));
    if (timeoutSeconds < 0) {
        int timeoutMinutes =
                Integer.parseInt(prop.getProperty(MMS_DEFAULT_RESPONSE_TIMEOUT, "120"));
        timeoutSeconds = 60 * timeoutMinutes;
    }
    return timeoutSeconds;
}

/** Seconds to wait for in-flight requests when unregistering a model; default 120. */
public int getUnregisterModelTimeout() {
    return Integer.parseInt(prop.getProperty(MMS_UNREGISTER_MODEL_TIMEOUT, "120"));
}

/**
 * Walks up from the current directory looking for a folder containing "mms"; falls back to
 * the current directory itself.
 */
private File findMmsHome() {
    File cwd = new File(getCanonicalPath("."));
    File file = cwd;
    while (file != null) {
        File mms = new File(file, "mms");
        if (mms.exists()) {
            return file;
        }
        file = file.getParentFile();
    }
    return cwd;
}

/** Switches log4j2 to the async logger context selector (must run before logging starts). */
private void enableAsyncLogging() {
    System.setProperty(
            "log4j2.contextSelector",
            "org.apache.logging.log4j.core.async.AsyncLoggerContextSelector");
}

/** Subset of configuration forwarded to backend worker processes. */
public HashMap getBackendConfiguration() {
    HashMap config = new HashMap<>();
    // Append properties used by backend worker here
    config.put("MMS_DECODE_INPUT_REQUEST", prop.getProperty(MMS_DECODE_INPUT_REQUEST, "true"));
    return config;
}

/** Canonical path of a file, falling back to its absolute path on I/O error. */
private static String getCanonicalPath(File file) {
    try {
        return file.getCanonicalPath();
    } catch (IOException e) {
        return file.getAbsolutePath();
    }
}

/** Null-tolerant string overload of {@link #getCanonicalPath(File)}. */
private static String getCanonicalPath(String path) {
    if (path == null) {
        return null;
    }
    return getCanonicalPath(new File(path));
}

/**
 * Counts GPUs by invoking nvidia-smi; returns 0 when the tool is absent or fails.
 *
 * <p>The first output line is the CSV header "index", hence the size() - 1.
 */
private static int getAvailableGpu() {
    try {
        Process process =
                Runtime.getRuntime().exec("nvidia-smi --query-gpu=index --format=csv");
        int ret = process.waitFor();
        if (ret != 0) {
            return 0;
        }
        List list = IOUtils.readLines(process.getInputStream(), StandardCharsets.UTF_8);
        if (list.isEmpty() || !"index".equals(list.get(0))) {
            throw new AssertionError("Unexpected nvidia-smi response.");
        }
        return list.size() - 1;
    } catch (IOException | InterruptedException e) {
        return 0;
    }
}

/** Parsed command-line arguments for the model server. */
public static final class Arguments {

    private String mmsConfigFile;
    private String pythonExecutable;
    private String modelStore;
    private String[] models;

    public Arguments() {}

    /** Extracts recognized options from a parsed commons-cli command line. */
    public Arguments(CommandLine cmd) {
        mmsConfigFile = cmd.getOptionValue("mms-config-file");
        pythonExecutable = cmd.getOptionValue("python");
        modelStore = cmd.getOptionValue("model-store");
        models = cmd.getOptionValues("models");
    }

    /** Declares the command-line options understood by the server. */
    public static Options getOptions() {
        Options options = new Options();
        options.addOption(
                Option.builder("f")
                        .longOpt("mms-config-file")
                        .hasArg()
                        .argName("MMS-CONFIG-FILE")
                        .desc("Path to the configuration properties file.")
                        .build());
        options.addOption(
                Option.builder("e")
                        .longOpt("python")
                        .hasArg()
                        .argName("PYTHON")
                        .desc("Python runtime executable path.")
                        .build());
        options.addOption(
                Option.builder("m")
                        .longOpt("models")
                        .hasArgs()
                        .argName("MODELS")
                        .desc("Models to be loaded at startup.")
                        .build());
        options.addOption(
                Option.builder("s")
                        .longOpt("model-store")
                        .hasArg()
                        .argName("MODELS-STORE")
                        .desc("Model store location where models can be loaded.")
                        .build());
        return options;
    }

    public String getMmsConfigFile() {
        return mmsConfigFile;
    }

    public String getPythonExecutable() {
        return pythonExecutable;
    }

    public void setMmsConfigFile(String mmsConfigFile) {
        this.mmsConfigFile = mmsConfigFile;
    }

    public String getModelStore() {
        return modelStore;
    }

    public void setModelStore(String modelStore) {
        this.modelStore = modelStore;
    }

    public String[] getModels() {
        return models;
    }

    public void setModels(String[] models) {
        this.models = models;
    }
}
}


================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/Connector.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 *     http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file.
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util; import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollDomainSocketChannel; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.epoll.EpollServerDomainSocketChannel; import io.netty.channel.epoll.EpollServerSocketChannel; import io.netty.channel.epoll.EpollSocketChannel; import io.netty.channel.kqueue.KQueue; import io.netty.channel.kqueue.KQueueDomainSocketChannel; import io.netty.channel.kqueue.KQueueEventLoopGroup; import io.netty.channel.kqueue.KQueueServerDomainSocketChannel; import io.netty.channel.kqueue.KQueueServerSocketChannel; import io.netty.channel.kqueue.KQueueSocketChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.unix.DomainSocketAddress; import java.io.File; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.commons.io.FileUtils; public class Connector { private static final Pattern ADDRESS_PATTERN = Pattern.compile( "((https|http)://([^:^/]+)(:([0-9]+))?)|(unix:(/.*))", Pattern.CASE_INSENSITIVE); private static boolean useNativeIo = ConfigManager.getInstance().useNativeIo(); private boolean uds; private String socketPath; private String bindIp; private int port; private boolean ssl; private boolean management; public Connector(int port) { this(port, useNativeIo && (Epoll.isAvailable() || KQueue.isAvailable())); } private Connector(int port, boolean uds) { this.port = port; this.uds = 
uds; if (uds) { bindIp = ""; socketPath = System.getProperty("java.io.tmpdir") + "/.mms.sock." + port; } else { bindIp = "127.0.0.1"; socketPath = String.valueOf(port); } } private Connector( int port, boolean uds, String bindIp, String socketPath, boolean ssl, boolean management) { this.port = port; this.uds = uds; this.bindIp = bindIp; this.socketPath = socketPath; this.ssl = ssl; this.management = management; } public static Connector parse(String binding, boolean management) { Matcher matcher = ADDRESS_PATTERN.matcher(binding); if (!matcher.matches()) { throw new IllegalArgumentException("Invalid binding address: " + binding); } boolean uds = matcher.group(7) != null; if (uds) { if (!useNativeIo) { throw new IllegalArgumentException( "unix domain socket requires use_native_io set to true."); } String path = matcher.group(7); return new Connector(-1, true, "", path, false, management); } String protocol = matcher.group(2); String host = matcher.group(3); String listeningPort = matcher.group(5); boolean ssl = "https".equalsIgnoreCase(protocol); int port; if (listeningPort == null) { if (management) { port = ssl ? 8444 : 8081; } else { port = ssl ? 443 : 80; } } else { port = Integer.parseInt(listeningPort); } if (port >= Short.MAX_VALUE) { throw new IllegalArgumentException("Invalid port number: " + binding); } return new Connector(port, false, host, String.valueOf(port), ssl, management); } public String getSocketType() { return uds ? "unix" : "tcp"; } public String getSocketPath() { return socketPath; } public boolean isUds() { return uds; } public boolean isSsl() { return ssl; } public boolean isManagement() { return management; } public SocketAddress getSocketAddress() { return uds ? new DomainSocketAddress(socketPath) : new InetSocketAddress(bindIp, port); } public String getPurpose() { return management ? 
"Management" : "Inference"; } public static EventLoopGroup newEventLoopGroup(int threads) { if (useNativeIo && Epoll.isAvailable()) { return new EpollEventLoopGroup(threads); } else if (useNativeIo && KQueue.isAvailable()) { return new KQueueEventLoopGroup(threads); } NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup(threads); eventLoopGroup.setIoRatio(ConfigManager.getInstance().getIoRatio()); return eventLoopGroup; } public Class getServerChannel() { if (useNativeIo && Epoll.isAvailable()) { return uds ? EpollServerDomainSocketChannel.class : EpollServerSocketChannel.class; } else if (useNativeIo && KQueue.isAvailable()) { return uds ? KQueueServerDomainSocketChannel.class : KQueueServerSocketChannel.class; } return NioServerSocketChannel.class; } public Class getClientChannel() { if (useNativeIo && Epoll.isAvailable()) { return uds ? EpollDomainSocketChannel.class : EpollSocketChannel.class; } else if (useNativeIo && KQueue.isAvailable()) { return uds ? KQueueDomainSocketChannel.class : KQueueSocketChannel.class; } return NioSocketChannel.class; } public void clean() { if (uds) { FileUtils.deleteQuietly(new File(socketPath)); } } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } Connector connector = (Connector) o; return uds == connector.uds && port == connector.port && socketPath.equals(connector.socketPath) && bindIp.equals(connector.bindIp); } @Override public int hashCode() { return Objects.hash(uds, socketPath, bindIp, port); } @Override public String toString() { if (uds) { return "unix:" + socketPath; } else if (ssl) { return "https://" + bindIp + ':' + port; } return "http://" + bindIp + ':' + port; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConnectorType.java ================================================ package com.amazonaws.ml.mms.util; public enum ConnectorType { 
INFERENCE_CONNECTOR, MANAGEMENT_CONNECTOR, BOTH } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/JsonUtils.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util; import com.google.gson.Gson; import com.google.gson.GsonBuilder; public final class JsonUtils { public static final Gson GSON_PRETTY = new GsonBuilder() .setDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") .setPrettyPrinting() .create(); public static final Gson GSON = new GsonBuilder().create(); private JsonUtils() {} } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/NettyUtils.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
 */
package com.amazonaws.ml.mms.util;

import com.amazonaws.ml.mms.http.ErrorResponse;
import com.amazonaws.ml.mms.http.Session;
import com.amazonaws.ml.mms.metrics.Dimension;
import com.amazonaws.ml.mms.metrics.Metric;
import com.amazonaws.ml.mms.util.messages.InputParameter;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.multipart.Attribute;
import io.netty.handler.codec.http.multipart.FileUpload;
import io.netty.handler.codec.http.multipart.InterfaceHttpData;
import io.netty.util.AttributeKey;
import io.netty.util.CharsetUtil;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A utility class that handles Netty requests and responses. */
public final class NettyUtils {

    // Access log; one entry per completed request (see sendHttpResponse).
    private static final Logger logger = LoggerFactory.getLogger("ACCESS_LOG");

    // Response header echoing the per-request id stored in the Session.
    private static final String REQUEST_ID = "x-request-id";
    // Per-channel attribute carrying the current request's Session.
    private static final AttributeKey SESSION_KEY = AttributeKey.valueOf("session");
    private static final Dimension DIMENSION = new Dimension("Level", "Host");
    // Host-level request counters by status class, emitted to the metrics logger.
    private static final Metric REQUESTS_2_XX =
            new Metric(
                    "Requests2XX",
                    "1",
                    "Count",
                    ConfigManager.getInstance().getHostName(),
                    DIMENSION);
    private static final Metric REQUESTS_4_XX =
            new Metric(
                    "Requests4XX",
                    "1",
                    "Count",
                    ConfigManager.getInstance().getHostName(),
                    DIMENSION);
    private static final Metric REQUESTS_5_XX =
            new Metric(
                    "Requests5XX",
                    "1",
                    "Count",
                    ConfigManager.getInstance().getHostName(),
                    DIMENSION);
    private static final Logger loggerMmsMetrics =
            LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER);

    private NettyUtils() {}

    /**
     * Attaches a new {@link Session} to the channel when a request arrives.
     *
     * <p>Assumes the previous session was cleared by sendHttpResponse (getAndSet(null)).
     */
    public static void requestReceived(Channel channel, HttpRequest request) {
        Session session = channel.attr(SESSION_KEY).get();
        assert session == null;
        SocketAddress address = channel.remoteAddress();
        String remoteIp;
        if (address == null) {
            // This can be null on UDS, or in certain cases on Windows
            remoteIp = "0.0.0.0";
        } else {
            remoteIp = address.toString();
        }
        channel.attr(SESSION_KEY).set(new Session(remoteIp, request));
    }

    /** Returns the current request id for the channel, or null if no session is attached. */
    public static String getRequestId(Channel channel) {
        Session accessLog = channel.attr(SESSION_KEY).get();
        if (accessLog != null) {
            return accessLog.getRequestId();
        }
        return null;
    }

    /** Sends an object serialized as pretty-printed JSON with status 200. */
    public static void sendJsonResponse(ChannelHandlerContext ctx, Object json) {
        sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), HttpResponseStatus.OK);
    }

    /** Sends an object serialized as pretty-printed JSON with the given status. */
    public static void sendJsonResponse(
            ChannelHandlerContext ctx, Object json, HttpResponseStatus status) {
        sendJsonResponse(ctx, JsonUtils.GSON_PRETTY.toJson(json), status);
    }

    /** Sends a pre-serialized JSON string with status 200. */
    public static void sendJsonResponse(ChannelHandlerContext ctx, String json) {
        sendJsonResponse(ctx, json, HttpResponseStatus.OK);
    }

    /** Sends a pre-serialized JSON string (with trailing newline) with the given status. */
    public static void sendJsonResponse(
            ChannelHandlerContext ctx, String json, HttpResponseStatus status) {
        FullHttpResponse resp =
                new DefaultFullHttpResponse(
                        HttpVersion.HTTP_1_1,
                        status,
                        Unpooled.directBuffer(),
                        DefaultHttpHeadersFactory.headersFactory(),
                        DefaultHttpHeadersFactory.trailersFactory());
        resp.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        ByteBuf content = resp.content();
        content.writeCharSequence(json, CharsetUtil.UTF_8);
        content.writeByte('\n');
        sendHttpResponse(ctx, resp, true);
    }

    /** Sends a JSON error body built from the throwable's class name and message. */
    public static void sendError(
            ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t) {
        ErrorResponse error =
                new ErrorResponse(status.code(), t.getClass().getSimpleName(), t.getMessage());
        sendJsonResponse(ctx, error, status);
    }

    /** Sends a JSON error body with an explicit message instead of the throwable's. */
    public static void sendError(
            ChannelHandlerContext ctx, HttpResponseStatus status, Throwable t, String msg) {
        ErrorResponse error = new ErrorResponse(status.code(), t.getClass().getSimpleName(), msg);
        sendJsonResponse(ctx, error, status);
    }

    /**
     * Send HTTP response to client.
     *
     * <p>Also: writes the access-log entry, bumps the per-status-class request metric, applies
     * configured CORS headers, disables caching, and closes the connection on error statuses
     * or when keep-alive is off.
     *
     * @param ctx ChannelHandlerContext
     * @param resp HttpResponse to send
     * @param keepAlive if keep the connection
     */
    public static void sendHttpResponse(
            ChannelHandlerContext ctx, FullHttpResponse resp, boolean keepAlive) {
        // Send the response and close the connection if necessary.
        Channel channel = ctx.channel();
        // Atomically detach the session so the next request starts fresh.
        Session session = channel.attr(SESSION_KEY).getAndSet(null);
        HttpHeaders headers = resp.headers();
        ConfigManager configManager = ConfigManager.getInstance();
        if (session != null) {
            // session might be recycled if channel is closed already.
            session.setCode(resp.status().code());
            headers.set(REQUEST_ID, session.getRequestId());
            logger.info(session.toString());
        }

        int code = resp.status().code();
        if (code >= 200 && code < 300) {
            loggerMmsMetrics.info("{}", REQUESTS_2_XX);
        } else if (code >= 400 && code < 500) {
            loggerMmsMetrics.info("{}", REQUESTS_4_XX);
        } else {
            // NOTE(review): 1xx and 3xx responses are also counted as 5XX here.
            loggerMmsMetrics.info("{}", REQUESTS_5_XX);
        }

        // CORS headers are only added when configured and not already set by a handler.
        String allowedOrigin = configManager.getCorsAllowedOrigin();
        String allowedMethods = configManager.getCorsAllowedMethods();
        String allowedHeaders = configManager.getCorsAllowedHeaders();

        if (allowedOrigin != null
                && !allowedOrigin.isEmpty()
                && !headers.contains(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN)) {
            headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, allowedOrigin);
        }
        if (allowedMethods != null
                && !allowedMethods.isEmpty()
                && !headers.contains(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)) {
            headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, allowedMethods);
        }
        if (allowedHeaders != null
                && !allowedHeaders.isEmpty()
                && !headers.contains(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS)) {
            headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, allowedHeaders);
        }

        // Add cache-control headers to avoid browser cache response
        headers.set("Pragma", "no-cache");
        headers.set("Cache-Control", "no-cache; no-store, must-revalidate, private");
        headers.set("Expires", "Thu, 01 Jan 1970 00:00:00 UTC");

        HttpUtil.setContentLength(resp, resp.content().readableBytes());
        if (!keepAlive || code >= 400) {
            headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
            ChannelFuture f = channel.writeAndFlush(resp);
            f.addListener(ChannelFutureListener.CLOSE);
        } else {
            headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
            channel.writeAndFlush(resp);
        }
    }

    /** Closes the specified channel after all queued write requests are flushed. */
    public static void closeOnFlush(Channel ch) {
        if (ch.isActive()) {
            ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Extracts the readable bytes of a buffer as a byte array.
     *
     * <p>NOTE(review): when the buffer is backed by an array, the whole backing array is
     * returned without honoring arrayOffset()/readableBytes() — presumably safe for the
     * buffers produced by this server's pipeline, but verify before reusing elsewhere.
     */
    public static byte[] getBytes(ByteBuf buf) {
        if (buf.hasArray()) {
            return buf.array();
        }

        byte[] ret = new byte[buf.readableBytes()];
        int readerIndex = buf.readerIndex();
        buf.getBytes(readerIndex, ret);
        return ret;
    }

    /** Returns the first value of a query-string parameter, or the default if absent. */
    public static String getParameter(QueryStringDecoder decoder, String key, String def) {
        List param = decoder.parameters().get(key);
        if (param != null && !param.isEmpty()) {
            return param.get(0);
        }
        return def;
    }

    /**
     * Returns a query-string parameter parsed as an int; falls back to the default when the
     * parameter is absent or not a valid integer.
     */
    public static int getIntParameter(QueryStringDecoder decoder, String key, int def) {
        String value = getParameter(decoder, key, null);
        if (value == null) {
            return def;
        }
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            return def;
        }
    }

    /**
     * Converts a multipart form item (text attribute or file upload) into an
     * {@link InputParameter}; returns null for null input.
     */
    public static InputParameter getFormData(InterfaceHttpData data) {
        if (data == null) {
            return null;
        }

        String name = data.getName();
        switch (data.getHttpDataType()) {
            case Attribute:
                Attribute attribute = (Attribute) data;
                try {
                    return new InputParameter(name, attribute.getValue());
                } catch (IOException e) {
                    throw new AssertionError(e);
                }
            case FileUpload:
                FileUpload fileUpload = (FileUpload) data;
                String contentType = fileUpload.getContentType();
                try {
                    return new InputParameter(name, getBytes(fileUpload.getByteBuf()), contentType);
                } catch (IOException e) {
                    throw new AssertionError(e);
                }
            default:
                throw new IllegalArgumentException(
                        "Except form field, but got " + data.getHttpDataType());
        }
    }
}


================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/OpenSslKey.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License.
/**
 * A utility class converting an OpenSSL (PKCS#1) RSA private key into an unencrypted
 * PKCS#8 PrivateKeyInfo structure by wrapping the original DER bytes:
 *
 * <pre>
 *   PrivateKeyInfo ::= SEQUENCE {
 *       version             INTEGER (0),
 *       privateKeyAlgorithm AlgorithmIdentifier (rsaEncryption, NULL parameters),
 *       privateKey          OCTET STRING (the original key bytes)
 *   }
 * </pre>
 */
public final class OpenSslKey {

    // OID 1.2.840.113549.1.1.1 (rsaEncryption)
    private static final int[] RSA_ENCRYPTION = {1, 2, 840, 113549, 1, 1, 1};
    // DER encoding of ASN.1 NULL, used as the AlgorithmIdentifier parameters
    private static final byte[] NULL_BYTES = {0x05, 0x00};

    private OpenSslKey() {}

    /**
     * Convert OpenSSL private key to PKCS8 private key.
     *
     * @param keySpec OpenSSL key spec (raw DER bytes of the RSA key)
     * @return PKCS8 encoded private key, or null if keySpec is null
     */
    public static byte[] convertPrivateKey(byte[] keySpec) {
        if (keySpec == null) {
            return null;
        }
        // defensive copy so the caller's array is never aliased
        byte[] bytes = new byte[keySpec.length];
        System.arraycopy(keySpec, 0, bytes, 0, keySpec.length);

        byte[] octetBytes = encodeOctetString(bytes);
        byte[] oidBytes = encodeOID(RSA_ENCRYPTION);
        byte[] verBytes = {0x02, 0x01, 0x00}; // INTEGER 0 (version)

        // AlgorithmIdentifier ::= SEQUENCE { rsaEncryption OID, NULL }
        byte[][] seqBytes = new byte[4][];
        seqBytes[0] = oidBytes;
        seqBytes[1] = NULL_BYTES;
        seqBytes[2] = null;
        byte[] oidSeqBytes = encodeSequence(seqBytes);

        // PrivateKeyInfo ::= SEQUENCE { version, AlgorithmIdentifier, OCTET STRING }
        seqBytes[0] = verBytes;
        seqBytes[1] = oidSeqBytes;
        seqBytes[2] = octetBytes;
        seqBytes[3] = null;
        return encodeSequence(seqBytes);
    }

    /** Encode an OID as a DER OBJECT IDENTIFIER (tag 0x06). */
    private static byte[] encodeOID(int[] oid) {
        if (oid == null) {
            return null;
        }
        // the first two components share one byte: 40 * oid[0] + oid[1]
        int oLen = 1;
        for (int i = 2; i < oid.length; i++) {
            oLen += getOIDCompLength(oid[i]);
        }
        int len = oLen + getLengthOfLengthField(oLen) + 1;
        byte[] bytes = new byte[len];
        bytes[0] = 0x06; // ASN Object ID
        int offset = writeLengthField(bytes, oLen);
        bytes[offset++] = (byte) (40 * oid[0] + oid[1]);
        for (int i = 2; i < oid.length; i++) {
            offset = writeOIDComp(oid[i], bytes, offset);
        }
        return bytes;
    }

    /** Wrap raw bytes in a DER OCTET STRING (tag 0x04). */
    private static byte[] encodeOctetString(byte[] bytes) {
        if (bytes == null) {
            return null;
        }
        int oLen = bytes.length;
        // +1 for the tag byte
        int len = oLen + getLengthOfLengthField(oLen) + 1;
        byte[] newBytes = new byte[len];
        newBytes[0] = 0x04;
        int offset = writeLengthField(newBytes, oLen);
        if (len - oLen != offset) {
            // header size mismatch; should be unreachable
            return null;
        }
        System.arraycopy(bytes, 0, newBytes, offset, oLen);
        return newBytes;
    }

    /** Concatenate the given chunks (up to the first null entry) into a DER SEQUENCE (tag 0x30). */
    private static byte[] encodeSequence(byte[][] byteArrays) {
        if (byteArrays == null) {
            return null;
        }
        int oLen = 0;
        for (byte[] b : byteArrays) {
            if (b == null) {
                break;
            }
            oLen += b.length;
        }
        int len = oLen + getLengthOfLengthField(oLen) + 1;
        byte[] bytes = new byte[len];
        bytes[0] = 0x10 | 0x20; // ASN sequence & constructed
        int offset = writeLengthField(bytes, oLen);
        if (len - oLen != offset) {
            return null;
        }
        for (byte[] b : byteArrays) {
            if (b == null) {
                break;
            }
            System.arraycopy(b, 0, bytes, offset, b.length);
            offset += b.length;
        }
        return bytes;
    }

    /**
     * Write the DER length field for a content of {@code len} bytes starting at index 1 (index 0
     * holds the tag) and return the offset of the first content byte.
     */
    private static int writeLengthField(byte[] bytes, int len) {
        // BUG FIX: DER short form covers lengths 0..127 INCLUSIVE. The original
        // condition "len < 127" sent len == 127 down the long-form path, where
        // getLengthOfLengthField(127) == 1 makes the value loop never execute,
        // emitting a bogus 0x80 (indefinite-length) byte instead of 0x7F.
        if (len <= 127) {
            bytes[1] = (byte) len;
            return 2;
        }
        int lenOfLenField = getLengthOfLengthField(len);
        // record length of the length field (number of following length octets, high bit set)
        bytes[1] = (byte) ((lenOfLenField - 1) | 0x80);
        for (int i = lenOfLenField; i >= 2; i--) {
            // write the length, big-endian
            bytes[i] = (byte) (len >> ((lenOfLenField - i) * 8));
        }
        return lenOfLenField + 1;
    }

    /** Total size of the DER length field, including the leading count byte in long form. */
    private static int getLengthOfLengthField(int len) {
        if (len <= 127) {
            // highest bit is zero, one byte is enough
            return 1;
        } else if (len <= 0xFF) {
            // highest bit is 1, two bytes in the form {0x81, 0xab}
            return 2;
        } else if (len <= 0xFFFF) {
            // three bytes in the form {0x82, 0xab, 0xcd}
            return 3;
        } else if (len <= 0xFFFFFF) {
            // four bytes in the form {0x83, 0xab, 0xcd, 0xef}
            return 4;
        } else {
            // five bytes in the form {0x84, 0xab, 0xcd, 0xef, 0xgh}
            return 5;
        }
    }

    /** Number of base-128 bytes needed for one OID component. */
    private static int getOIDCompLength(int comp) {
        if (comp <= 0x7F) {
            return 1;
        } else if (comp <= 0x3FFF) {
            return 2;
        } else if (comp <= 0x1FFFFF) {
            return 3;
        } else if (comp <= 0xFFFFFFF) {
            return 4;
        } else {
            return 5;
        }
    }

    /** Write one OID component base-128 encoded with continuation bits; returns the new offset. */
    private static int writeOIDComp(int comp, byte[] bytes, int offset) {
        int len = getOIDCompLength(comp);
        int off = offset;
        for (int i = len - 1; i > 0; i--) {
            bytes[off++] = (byte) ((comp >>> i * 7) | 0x80);
        }
        bytes[off++] = (byte) (comp & 0x7F);
        return off;
    }
}
Connector.newEventLoopGroup(configManager.getNettyThreads()); backendGroup = Connector.newEventLoopGroup(configManager.getNettyClientThreads()); } public void shutdown(boolean graceful) { closeAllChannels(graceful); List allEventLoopGroups = new ArrayList<>(); allEventLoopGroups.add(serverGroup); allEventLoopGroups.add(childGroup); for (EventLoopGroup group : allEventLoopGroups) { if (graceful) { group.shutdownGracefully(); } else { group.shutdownGracefully(0, 0, TimeUnit.SECONDS); } } if (graceful) { for (EventLoopGroup group : allEventLoopGroups) { try { group.awaitTermination(60, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } } } public EventLoopGroup getServerGroup() { return serverGroup; } public EventLoopGroup getChildGroup() { return childGroup; } public EventLoopGroup getBackendGroup() { return backendGroup; } public void registerChannel(Channel channel) { allChannels.add(channel); } private void closeAllChannels(boolean graceful) { ChannelGroupFuture future = allChannels.close(); // if this is a graceful shutdown, log any channel closing failures. if this isn't a // graceful shutdown, ignore them. if (graceful) { try { future.await(10, TimeUnit.SECONDS); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } if (!future.isSuccess()) { for (ChannelFuture cf : future) { if (!cf.isSuccess()) { logger.info("Unable to close channel: " + cf.channel(), cf.cause()); } } } } } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/codec/CodecUtils.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. 
This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util.codec; import io.netty.buffer.ByteBuf; import io.netty.handler.codec.CorruptedFrameException; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; public final class CodecUtils { public static final int END = -1; public static final int BUFFER_UNDER_RUN = -3; private CodecUtils() {} static int readLength(ByteBuf byteBuf, int maxLength) { int size = byteBuf.readableBytes(); if (size < 4) { return BUFFER_UNDER_RUN; } int len = byteBuf.readInt(); if (len > maxLength) { throw new CorruptedFrameException("Message size exceed limit: " + len); } if (len > byteBuf.readableBytes()) { return BUFFER_UNDER_RUN; } return len; } static String readString(ByteBuf byteBuf, int len) { return new String(read(byteBuf, len), StandardCharsets.UTF_8); } static byte[] read(ByteBuf in, int len) { if (len < 0) { throw new CorruptedFrameException("Invalid message size: " + len); } byte[] buf = new byte[len]; in.readBytes(buf); return buf; } static Map readMap(ByteBuf in, int len) { HashMap ret = new HashMap<>(); for (; len > 0; len--) { int l = readLength(in, in.readableBytes()); String key = readString(in, l); l = readLength(in, in.readableBytes()); String val = readString(in, l); ret.put(key, val); } return ret; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/codec/ModelRequestEncoder.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. 
A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util.codec; import com.amazonaws.ml.mms.util.messages.BaseModelRequest; import com.amazonaws.ml.mms.util.messages.InputParameter; import com.amazonaws.ml.mms.util.messages.ModelInferenceRequest; import com.amazonaws.ml.mms.util.messages.ModelLoadModelRequest; import com.amazonaws.ml.mms.util.messages.RequestInput; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.MessageToByteEncoder; import java.nio.charset.StandardCharsets; import java.util.Map; @ChannelHandler.Sharable public class ModelRequestEncoder extends MessageToByteEncoder { public ModelRequestEncoder(boolean preferDirect) { super(preferDirect); } @Override protected void encode(ChannelHandlerContext ctx, BaseModelRequest msg, ByteBuf out) { if (msg instanceof ModelLoadModelRequest) { out.writeByte('L'); ModelLoadModelRequest request = (ModelLoadModelRequest) msg; byte[] buf = msg.getModelName().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); buf = request.getModelPath().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); int batchSize = request.getBatchSize(); if (batchSize <= 0) { batchSize = 1; } out.writeInt(batchSize); buf = request.getHandler().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); out.writeInt(request.getGpuId()); buf = request.getIoFileDescriptor().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); } else if (msg instanceof ModelInferenceRequest) { out.writeByte('I'); ModelInferenceRequest 
request = (ModelInferenceRequest) msg; for (RequestInput input : request.getRequestBatch()) { encodeRequest(input, out); } out.writeInt(-1); // End of List } } private void encodeRequest(RequestInput req, ByteBuf out) { byte[] buf = req.getRequestId().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); for (Map.Entry entry : req.getHeaders().entrySet()) { encodeField(entry.getKey(), out); encodeField(entry.getValue(), out); } out.writeInt(-1); // End of List for (InputParameter input : req.getParameters()) { encodeParameter(input, out); } out.writeInt(-1); // End of List } private void encodeParameter(InputParameter parameter, ByteBuf out) { byte[] modelInputName = parameter.getName().getBytes(StandardCharsets.UTF_8); out.writeInt(modelInputName.length); out.writeBytes(modelInputName); encodeField(parameter.getContentType(), out); byte[] buf = parameter.getValue(); out.writeInt(buf.length); out.writeBytes(buf); } private static void encodeField(CharSequence field, ByteBuf out) { if (field == null) { out.writeInt(0); return; } byte[] buf = field.toString().getBytes(StandardCharsets.UTF_8); out.writeInt(buf.length); out.writeBytes(buf); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/codec/ModelResponseDecoder.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.codec; import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse; import com.amazonaws.ml.mms.util.messages.Predictions; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import java.util.ArrayList; import java.util.List; public class ModelResponseDecoder extends ByteToMessageDecoder { private final int maxBufferSize; public ModelResponseDecoder(int maxBufferSize) { this.maxBufferSize = maxBufferSize; } @Override protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { int size = in.readableBytes(); if (size < 9) { return; } in.markReaderIndex(); boolean completed = false; try { ModelWorkerResponse resp = new ModelWorkerResponse(); // Get Response overall Code resp.setCode(in.readInt()); int len = CodecUtils.readLength(in, maxBufferSize); if (len == CodecUtils.BUFFER_UNDER_RUN) { return; } resp.setMessage(CodecUtils.readString(in, len)); List predictions = new ArrayList<>(); while ((len = CodecUtils.readLength(in, maxBufferSize)) != CodecUtils.END) { if (len == CodecUtils.BUFFER_UNDER_RUN) { return; } Predictions prediction = new Predictions(); // Set response RequestId prediction.setRequestId(CodecUtils.readString(in, len)); len = CodecUtils.readLength(in, maxBufferSize); if (len == CodecUtils.BUFFER_UNDER_RUN) { return; } // Set content type prediction.setContentType(CodecUtils.readString(in, len)); // Set per request response code int httpStatusCode = in.readInt(); prediction.setStatusCode(httpStatusCode); // Set the actual message len = CodecUtils.readLength(in, maxBufferSize); if (len == CodecUtils.BUFFER_UNDER_RUN) { return; } prediction.setReasonPhrase(CodecUtils.readString(in, len)); len = CodecUtils.readLength(in, maxBufferSize); if (len == CodecUtils.BUFFER_UNDER_RUN) { return; } prediction.setHeaders(CodecUtils.readMap(in, len)); len = CodecUtils.readLength(in, maxBufferSize); if (len == CodecUtils.BUFFER_UNDER_RUN) { 
@Plugin(
        name = "QLogLayout",
        category = Node.CATEGORY,
        elementType = Layout.ELEMENT_TYPE,
        printObject = true)
public class QLogLayout extends AbstractStringLayout {

    public QLogLayout() {
        super(null, null, null);
    }

    /**
     * Formats metric log events in "query log" format. To enable it, configure the appender
     * layout as:
     *
     * <pre>
     *     log4j.appender.model_metrics.layout = com.amazonaws.ml.mms.util.logging.QLogLayout
     * </pre>
     *
     * <p>This produces log entries such as:
     *
     * <pre>
     *     HostName=hostName
     *     RequestId=004bd136-063c-4102-a070-d7aff5add939
     *     Marketplace=US
     *     StartTime=1542275707
     *     Program=MXNetModelServer
     *     Metrics=PredictionTime=45 Milliseconds ModelName|squeezenet  Level|Model
     *     EOE
     * </pre>
     *
     * <p>Customizable entities:
     *
     * <ul>
     *   <li>Marketplace: set through the "REALM" environment variable (the line is emitted only
     *       when REALM is set)
     *   <li>Program: set through the "MXNETMODELSERVER_PROGRAM" environment variable (defaults
     *       to "MXNetModelServer")
     * </ul>
     *
     * @param event the log event; only parameters that are {@code Metric} instances are rendered
     * @return the formatted metric lines, or null when the event has no message or no parameters
     */
    @Override
    public String toSerializable(LogEvent event) {
        Message eventMessage = event.getMessage();
        if (eventMessage == null || eventMessage.getParameters() == null) {
            return null;
        }
        String programName =
                getStringOrDefault(System.getenv("MXNETMODELSERVER_PROGRAM"), "MXNetModelServer");
        String domain = getStringOrDefault(System.getenv("DOMAIN"), "Unknown");
        // fallback timestamp, used only when the metric carries none
        long currentTimeInSec = System.currentTimeMillis() / 1000;
        Object[] parameters = eventMessage.getParameters();
        StringBuilder stringBuilder = new StringBuilder();
        for (Object obj : parameters) {
            if (obj instanceof Metric) {
                Metric metric = (Metric) obj;
                String marketPlace = System.getenv("REALM");
                stringBuilder.append("HostName=").append(metric.getHostName());
                if (metric.getRequestId() != null && !metric.getRequestId().isEmpty()) {
                    stringBuilder.append("\nRequestId=").append(metric.getRequestId());
                }
                // Marketplace format: <program>:<domain>:<realm>
                if (marketPlace != null && !marketPlace.isEmpty()) {
                    stringBuilder
                            .append("\nMarketplace=")
                            .append(programName)
                            .append(':')
                            .append(domain)
                            .append(':')
                            .append(marketPlace);
                }
                stringBuilder
                        .append("\nStartTime=")
                        .append(
                                getStringOrDefault(
                                        metric.getTimestamp(), Long.toString(currentTimeInSec)));
                stringBuilder
                        .append("\nProgram=")
                        .append(programName)
                        .append("\nMetrics=")
                        .append(metric.getMetricName())
                        .append('=')
                        .append(metric.getValue())
                        .append(' ')
                        .append(metric.getUnit());

                for (Dimension dimension : metric.getDimensions()) {
                    stringBuilder
                            .append(' ')
                            .append(dimension.getName())
                            .append('|')
                            .append(dimension.getValue())
                            .append(' ');
                }
                stringBuilder.append("\nEOE\n");
            }
        }
        return stringBuilder.toString();
    }

    @PluginFactory
    public static QLogLayout createLayout() {
        return new QLogLayout();
    }

    /** Returns {@code val} unless it is null or empty, in which case {@code defVal}. */
    private static String getStringOrDefault(String val, String defVal) {
        if (val == null || val.isEmpty()) {
            return defVal;
        }
        return val;
    }
}
frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/BaseModelRequest.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util.messages; public class BaseModelRequest { private WorkerCommands command; private String modelName; public BaseModelRequest() {} public BaseModelRequest(WorkerCommands command, String modelName) { this.command = command; this.modelName = modelName; } public WorkerCommands getCommand() { return command; } public String getModelName() { return modelName; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/InputParameter.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.messages; import java.nio.charset.StandardCharsets; public class InputParameter { private String name; private byte[] value; private CharSequence contentType; public InputParameter() {} public InputParameter(String name, String value) { this.name = name; this.value = value.getBytes(StandardCharsets.UTF_8); } public InputParameter(String name, byte[] data) { this(name, data, null); } public InputParameter(String name, byte[] data, CharSequence contentType) { this.name = name; this.contentType = contentType; this.value = data; } public String getName() { return name; } public byte[] getValue() { return value; } public CharSequence getContentType() { return contentType; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/ModelInferenceRequest.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.messages; import java.util.ArrayList; import java.util.List; public class ModelInferenceRequest extends BaseModelRequest { private List batch; public ModelInferenceRequest(String modelName) { super(WorkerCommands.PREDICT, modelName); batch = new ArrayList<>(); } public List getRequestBatch() { return batch; } public void setRequestBatch(List requestBatch) { this.batch = requestBatch; } public void addRequest(RequestInput req) { batch.add(req); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/ModelLoadModelRequest.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.util.messages; import com.amazonaws.ml.mms.wlm.Model; public class ModelLoadModelRequest extends BaseModelRequest { /** * ModelLoadModelRequest is a interface between frontend and backend to notify the backend to * load a particular model. 
*/ private String modelPath; private String handler; private int batchSize; private int gpuId; private String ioFileDescriptor; public ModelLoadModelRequest(Model model, int gpuId, String fd) { super(WorkerCommands.LOAD, model.getModelName()); this.gpuId = gpuId; modelPath = model.getModelDir().getAbsolutePath(); handler = model.getModelArchive().getManifest().getModel().getHandler(); batchSize = model.getBatchSize(); ioFileDescriptor = fd; } public String getIoFileDescriptor() { return ioFileDescriptor; } public void setIoFileDescriptor(String ioFileDescriptor) { this.ioFileDescriptor = ioFileDescriptor; } public String getModelPath() { return modelPath; } public String getHandler() { return handler; } public int getBatchSize() { return batchSize; } public int getGpuId() { return gpuId; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/ModelWorkerResponse.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.messages; import java.util.List; public class ModelWorkerResponse { private int code; private String message; private List predictions; public ModelWorkerResponse() {} public int getCode() { return code; } public void setCode(int code) { this.code = code; } public String getMessage() { return message; } public void setMessage(String message) { this.message = message; } public List getPredictions() { return predictions; } public void setPredictions(List predictions) { this.predictions = predictions; } public void appendPredictions(Predictions prediction) { this.predictions.add(prediction); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/Predictions.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.messages; import java.util.Map; public class Predictions { private String requestId; private int statusCode; private String reasonPhrase; private String contentType; private Map headers; private byte[] resp; public Map getHeaders() { return headers; } public void setHeaders(Map headers) { this.headers = headers; } public Predictions() {} public String getRequestId() { return requestId; } public void setRequestId(String requestId) { this.requestId = requestId; } public byte[] getResp() { return resp; } public void setResp(byte[] resp) { this.resp = resp; } public String getContentType() { return contentType; } public void setStatusCode(int statusCode) { this.statusCode = statusCode; } public void setContentType(String contentType) { this.contentType = contentType; } public int getStatusCode() { return statusCode; } public String getReasonPhrase() { return reasonPhrase; } public void setReasonPhrase(String reasonPhrase) { this.reasonPhrase = reasonPhrase; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/RequestInput.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.util.messages; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; public class RequestInput { private String requestId; private Map headers; private List parameters; public RequestInput(String requestId) { this.requestId = requestId; headers = new HashMap<>(); parameters = new ArrayList<>(); } public String getRequestId() { return requestId; } public void setRequestId(String requestId) { this.requestId = requestId; } public Map getHeaders() { return headers; } public void setHeaders(Map headers) { this.headers = headers; } public void updateHeaders(String key, String val) { headers.put(key, val); } public List getParameters() { return parameters; } public void setParameters(List parameters) { this.parameters = parameters; } public void addParameter(InputParameter modelInput) { parameters.add(modelInput); } public String getStringParameter(String key) { for (InputParameter param : parameters) { if (key.equals(param.getName())) { return new String(param.getValue(), StandardCharsets.UTF_8); } } return null; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/util/messages/WorkerCommands.java ================================================ package com.amazonaws.ml.mms.util.messages; import com.google.gson.annotations.SerializedName; public enum WorkerCommands { @SerializedName("predict") PREDICT("predict"), @SerializedName("load") LOAD("load"), @SerializedName("unload") UNLOAD("unload"), @SerializedName("stats") STATS("stats"); private String command; WorkerCommands(String command) { this.command = command; } public String getCommand() { return command; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/BatchAggregator.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.util.messages.BaseModelRequest; import com.amazonaws.ml.mms.util.messages.ModelInferenceRequest; import com.amazonaws.ml.mms.util.messages.ModelLoadModelRequest; import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse; import com.amazonaws.ml.mms.util.messages.Predictions; import com.amazonaws.ml.mms.util.messages.RequestInput; import io.netty.handler.codec.http.HttpResponseStatus; import java.util.LinkedHashMap; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class BatchAggregator { private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class); private Model model; private Map jobs; public BatchAggregator(Model model) { this.model = model; jobs = new LinkedHashMap<>(); } public BaseModelRequest getRequest(String threadName, WorkerState state) throws InterruptedException { jobs.clear(); ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName()); model.pollBatch( threadName, (state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE, jobs); for (Job j : jobs.values()) { if (j.isControlCmd()) { if (jobs.size() > 1) { throw new IllegalStateException( "Received more than 1 control command. 
" + "Control messages should be processed/retrieved one at a time."); } RequestInput input = j.getPayload(); int gpuId = -1; String gpu = input.getStringParameter("gpu"); if (gpu != null) { gpuId = Integer.parseInt(gpu); } return new ModelLoadModelRequest(model, gpuId, threadName); } else { j.setScheduled(); req.addRequest(j.getPayload()); } } return req; } public void sendResponse(ModelWorkerResponse message) { // TODO: Handle prediction level code if (message.getCode() == 200) { if (jobs.isEmpty()) { // this is from initial load. return; } for (Predictions prediction : message.getPredictions()) { String jobId = prediction.getRequestId(); Job job = jobs.remove(jobId); if (job == null) { throw new IllegalStateException("Unexpected job: " + jobId); } job.response( prediction.getResp(), prediction.getContentType(), prediction.getStatusCode(), prediction.getReasonPhrase(), prediction.getHeaders()); } } else { for (String reqId : jobs.keySet()) { Job j = jobs.remove(reqId); if (j == null) { throw new IllegalStateException("Unexpected job: " + reqId); } j.sendError(HttpResponseStatus.valueOf(message.getCode()), message.getMessage()); } if (!jobs.isEmpty()) { throw new IllegalStateException("Not all jobs get response."); } } } public void sendError(BaseModelRequest message, String error, HttpResponseStatus status) { if (message instanceof ModelLoadModelRequest) { logger.warn("Load model failed: {}, error: {}", message.getModelName(), error); return; } if (message != null) { ModelInferenceRequest msg = (ModelInferenceRequest) message; for (RequestInput req : msg.getRequestBatch()) { String requestId = req.getRequestId(); Job job = jobs.remove(requestId); if (job == null) { logger.error("Unexpected job: " + requestId); } else { job.sendError(status, error); } } if (!jobs.isEmpty()) { jobs.clear(); logger.error("Not all jobs get response."); } } else { // Send the error message to all the jobs for (Map.Entry j : jobs.entrySet()) { String jobsId = j.getValue().getJobId(); 
Job job = jobs.remove(jobsId); if (job.isControlCmd()) { job.sendError(status, error); } else { // Data message can be handled by other workers. // If batch has gone past its batch max delay timer? model.addFirst(job); } } } } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Job.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.http.InternalServerException; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.messages.RequestInput; import com.amazonaws.ml.mms.util.messages.WorkerCommands; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpHeadersFactory; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Job { private static final Logger logger = LoggerFactory.getLogger(Job.class); private ChannelHandlerContext ctx; private String modelName; private WorkerCommands cmd; // Else its data msg or inf requests private RequestInput input; private long begin; private long scheduled; public 
Job(
            ChannelHandlerContext ctx, String modelName, WorkerCommands cmd, RequestInput input) {
        this.ctx = ctx;
        this.modelName = modelName;
        this.cmd = cmd;
        this.input = input;
        // Timestamps for the waiting-time / backend-time metrics logged on completion.
        begin = System.currentTimeMillis();
        scheduled = begin;
    }

    // The job id is the request id carried by the input payload.
    public String getJobId() {
        return input.getRequestId();
    }

    public String getModelName() {
        return modelName;
    }

    public WorkerCommands getCmd() {
        return cmd;
    }

    // Anything other than PREDICT (e.g. load/unload) is a control command.
    public boolean isControlCmd() {
        return !WorkerCommands.PREDICT.equals(cmd);
    }

    public RequestInput getPayload() {
        return input;
    }

    // Marks the moment the job was handed to a backend worker.
    public void setScheduled() {
        scheduled = System.currentTimeMillis();
    }

    /*
     * Builds and sends the HTTP response for this job from the backend worker's
     * prediction: body bytes, optional content type, status code with optional
     * custom reason phrase, and any extra response headers.
     */
    public void response(
            byte[] body,
            CharSequence contentType,
            int statusCode,
            String statusPhrase,
            Map responseHeaders) {
        HttpResponseStatus status =
                (statusPhrase == null)
                        ? HttpResponseStatus.valueOf(statusCode)
                        : HttpResponseStatus.valueOf(statusCode, statusPhrase);
        FullHttpResponse resp =
                new DefaultFullHttpResponse(
                        HttpVersion.HTTP_1_1,
                        status,
                        Unpooled.directBuffer(),
                        DefaultHttpHeadersFactory.headersFactory(),
                        DefaultHttpHeadersFactory.trailersFactory());
        if (contentType != null && contentType.length() > 0) {
            resp.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
        }
        if (responseHeaders != null) {
            for (Map.Entry e : responseHeaders.entrySet()) {
                resp.headers().set(e.getKey(), e.getValue());
            }
        }
        resp.content().writeBytes(body);
        /*
         * We can load the models based on the configuration file.Since this Job is
         * not driven by the external connections, we could have a empty context for
         * this job. We shouldn't try to send a response to ctx if this is not triggered
         * by external clients.
*/
        if (ctx != null) {
            NettyUtils.sendHttpResponse(ctx, resp, true);
        }
        // Waiting time = queueing delay; backend time = time inside the worker.
        logger.debug(
                "Waiting time: {}, Backend time: {}",
                scheduled - begin,
                System.currentTimeMillis() - scheduled);
    }

    /** Replies to this job with an error status and message (wrapped as an internal error). */
    public void sendError(HttpResponseStatus status, String error) {
        /*
         * We can load the models based on the configuration file.Since this Job is
         * not driven by the external connections, we could have a empty context for
         * this job. We shouldn't try to send a response to ctx if this is not triggered
         * by external clients.
         */
        if (ctx != null) {
            NettyUtils.sendError(ctx, status, new InternalServerException(error));
        }
        logger.debug(
                "Waiting time: {}, Inference time: {}",
                scheduled - begin,
                System.currentTimeMillis() - begin);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/Model.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.archive.ModelArchive; import com.amazonaws.ml.mms.util.ConfigManager; import java.io.File; import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class Model { public static final String DEFAULT_DATA_QUEUE = "DATA_QUEUE"; private static final Logger logger = LoggerFactory.getLogger(Model.class); private ModelArchive modelArchive; private int minWorkers; private int maxWorkers; private int batchSize; private int maxBatchDelay; private String preloadModel; private AtomicInteger port; // Port on which the model server is running private ReentrantLock lock; private int responseTimeoutSeconds; private WorkerThread serverThread; // Total number of subsequent inference request failures private AtomicInteger failedInfReqs; // Per worker thread job queue. 
This separates out the control queue from data queue private ConcurrentMap> jobsDb; public Model(ModelArchive modelArchive, int queueSize, String preloadModel) { this.modelArchive = modelArchive; this.preloadModel = preloadModel; batchSize = 1; maxBatchDelay = 100; jobsDb = new ConcurrentHashMap<>(); // Always have a queue for data jobsDb.putIfAbsent(DEFAULT_DATA_QUEUE, new LinkedBlockingDeque<>(queueSize)); failedInfReqs = new AtomicInteger(0); port = new AtomicInteger(-1); lock = new ReentrantLock(); } public String getModelName() { return modelArchive.getModelName(); } public File getModelDir() { return modelArchive.getModelDir(); } public String getModelUrl() { return modelArchive.getUrl(); } public ModelArchive getModelArchive() { return modelArchive; } public int getMinWorkers() { return minWorkers; } public void setMinWorkers(int minWorkers) { this.minWorkers = minWorkers; } public int getMaxWorkers() { return maxWorkers; } public void setMaxWorkers(int maxWorkers) { this.maxWorkers = maxWorkers; } public int getBatchSize() { return batchSize; } public void setBatchSize(int batchSize) { this.batchSize = batchSize; } public int getMaxBatchDelay() { return maxBatchDelay; } public void setMaxBatchDelay(int maxBatchDelay) { this.maxBatchDelay = maxBatchDelay; } public void addJob(String threadId, Job job) { LinkedBlockingDeque blockingDeque = jobsDb.get(threadId); if (blockingDeque == null) { blockingDeque = new LinkedBlockingDeque<>(); jobsDb.put(threadId, blockingDeque); } blockingDeque.offer(job); } public void removeJobQueue(String threadId) { if (!threadId.equals(DEFAULT_DATA_QUEUE)) { jobsDb.remove(threadId); } } public boolean addJob(Job job) { return jobsDb.get(DEFAULT_DATA_QUEUE).offer(job); } public void addFirst(Job job) { jobsDb.get(DEFAULT_DATA_QUEUE).addFirst(job); } public void pollBatch(String threadId, long waitTime, Map jobsRepo) throws InterruptedException { if (jobsRepo == null || threadId == null || threadId.isEmpty()) { throw new 
IllegalArgumentException("Invalid input given provided"); } if (!jobsRepo.isEmpty()) { throw new IllegalArgumentException( "The jobs repo provided contains stale jobs. Clear them!!"); } LinkedBlockingDeque jobsQueue = jobsDb.get(threadId); if (jobsQueue != null && !jobsQueue.isEmpty()) { Job j = jobsQueue.poll(waitTime, TimeUnit.MILLISECONDS); if (j != null) { jobsRepo.put(j.getJobId(), j); return; } } try { lock.lockInterruptibly(); long maxDelay = maxBatchDelay; jobsQueue = jobsDb.get(DEFAULT_DATA_QUEUE); Job j = jobsQueue.poll(Long.MAX_VALUE, TimeUnit.MILLISECONDS); logger.trace("get first job: {}", Objects.requireNonNull(j).getJobId()); jobsRepo.put(j.getJobId(), j); long begin = System.currentTimeMillis(); for (int i = 0; i < batchSize - 1; ++i) { j = jobsQueue.poll(maxDelay, TimeUnit.MILLISECONDS); if (j == null) { break; } long end = System.currentTimeMillis(); maxDelay -= end - begin; begin = end; jobsRepo.put(j.getJobId(), j); if (maxDelay <= 0) { break; } } logger.trace("sending jobs, size: {}", jobsRepo.size()); } finally { if (lock.isHeldByCurrentThread()) { lock.unlock(); } } } public int getPort() { return port.get(); } public void setPort(int port) { this.port.set(port); } public int incrFailedInfReqs() { return failedInfReqs.incrementAndGet(); } public void resetFailedInfReqs() { failedInfReqs.set(0); } public int getResponseTimeoutSeconds() { return ConfigManager.getInstance().isDebug() ? 
Integer.MAX_VALUE : responseTimeoutSeconds; } public void setResponseTimeoutSeconds(int responseTimeoutSeconds) { this.responseTimeoutSeconds = responseTimeoutSeconds; } public WorkerThread getServerThread() { return serverThread; } public void setServerThread(WorkerThread serverThread) { this.serverThread = serverThread; } public String preloadModel() { return preloadModel; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/ModelManager.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.archive.Manifest; import com.amazonaws.ml.mms.archive.ModelArchive; import com.amazonaws.ml.mms.archive.ModelException; import com.amazonaws.ml.mms.archive.ModelNotFoundException; import com.amazonaws.ml.mms.http.ConflictStatusException; import com.amazonaws.ml.mms.http.StatusResponse; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.http.HttpResponseStatus; import java.io.IOException; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public final class ModelManager { private static final Logger logger = LoggerFactory.getLogger(ModelManager.class); private static ModelManager modelManager; private ConfigManager configManager; private WorkLoadManager wlm; private ConcurrentHashMap models; private HashSet startupModels; private ScheduledExecutorService scheduler; private ModelManager(ConfigManager configManager, WorkLoadManager wlm) { this.configManager = configManager; this.wlm = wlm; models = new ConcurrentHashMap<>(); scheduler = Executors.newScheduledThreadPool(2); this.startupModels = new HashSet<>(); } public ScheduledExecutorService getScheduler() { return scheduler; } public static void init(ConfigManager configManager, WorkLoadManager wlm) { modelManager = new ModelManager(configManager, wlm); } public static ModelManager getInstance() { return modelManager; } public ModelArchive registerModel(String url, String defaultModelName, String preloadModel) throws ModelException, IOException, 
InterruptedException, ExecutionException, TimeoutException { return registerModel( url, null, null, null, 1, 100, configManager.getDefaultResponseTimeoutSeconds(), defaultModelName, preloadModel); } public ModelArchive registerModel( String url, String modelName, Manifest.RuntimeType runtime, String handler, int batchSize, int maxBatchDelay, int responseTimeoutSeconds, String defaultModelName, String preloadModel) throws ModelException, IOException, InterruptedException, ExecutionException, TimeoutException { ModelArchive archive = ModelArchive.downloadModel(configManager.getModelStore(), url); if (modelName == null || modelName.isEmpty()) { if (archive.getModelName() == null || archive.getModelName().isEmpty()) { archive.getManifest().getModel().setModelName(defaultModelName); } modelName = archive.getModelName(); } else { archive.getManifest().getModel().setModelName(modelName); } if (runtime != null) { archive.getManifest().setRuntime(runtime); } if (handler != null) { archive.getManifest().getModel().setHandler(handler); } else if (archive.getHandler() == null || archive.getHandler().isEmpty()) { archive.getManifest() .getModel() .setHandler(configManager.getMmsDefaultServiceHandler()); } archive.validate(); Model model = new Model(archive, configManager.getJobQueueSize(), preloadModel); model.setBatchSize(batchSize); model.setMaxBatchDelay(maxBatchDelay); model.setResponseTimeoutSeconds(responseTimeoutSeconds); Model existingModel = models.putIfAbsent(modelName, model); if (existingModel != null) { // model already exists throw new ConflictStatusException("Model " + modelName + " is already registered."); } if (configManager.isDebug()) { model.setPort(9000); } else { startBackendServer(model); } models.put(modelName, model); logger.info("Model {} loaded.", model.getModelName()); return archive; } public HttpResponseStatus unregisterModel(String modelName) { Model model = models.remove(modelName); if (model == null) { logger.warn("Model not found: " + 
modelName); return HttpResponseStatus.NOT_FOUND; } model.setMinWorkers(0); model.setMaxWorkers(0); CompletableFuture futureStatus = wlm.modelChanged(model); HttpResponseStatus httpResponseStatus = HttpResponseStatus.OK; try { httpResponseStatus = futureStatus.get(); } catch (InterruptedException | ExecutionException e) { logger.warn("Process was interrupted while cleaning resources."); httpResponseStatus = HttpResponseStatus.INTERNAL_SERVER_ERROR; } // Only continue cleaning if resource cleaning succeeded if (httpResponseStatus == HttpResponseStatus.OK) { model.getModelArchive().clean(); startupModels.remove(modelName); logger.info("Model {} unregistered.", modelName); } else { models.put(modelName, model); } return httpResponseStatus; } public void startBackendServer(Model model) throws InterruptedException, ExecutionException, TimeoutException { CompletableFuture future = new CompletableFuture<>(); if (model == null) { throw new AssertionError("Model not found"); } wlm.addServerThread(model, future); } public CompletableFuture updateModel( String modelName, int minWorkers, int maxWorkers) { Model model = models.get(modelName); if (model == null) { throw new AssertionError("Model not found: " + modelName); } model.setMinWorkers(minWorkers); model.setMaxWorkers(maxWorkers); logger.debug("updateModel: {}, count: {}", modelName, minWorkers); return wlm.modelChanged(model); } public Map getModels() { return models; } public List getWorkers(String modelName) { return wlm.getWorkers(modelName); } public Map getWorkers() { return wlm.getWorkers(); } public boolean addJob(Job job) throws ModelNotFoundException { String modelName = job.getModelName(); Model model = models.get(modelName); if (model == null) { throw new ModelNotFoundException("Model not found: " + modelName); } if (wlm.hasNoWorker(modelName)) { return false; } return model.addJob(job); } public void workerStatus(final ChannelHandlerContext ctx) { Runnable r = () -> { String response = "Healthy"; int 
numWorking = 0; int numScaled = 0; for (Map.Entry m : models.entrySet()) { numScaled += m.getValue().getMinWorkers(); numWorking += wlm.getNumRunningWorkers(m.getValue().getModelName()); } if ((numWorking > 0) && (numWorking < numScaled)) { response = "Partial Healthy"; } else if ((numWorking == 0) && (numScaled > 0)) { response = "Unhealthy"; } // TODO: Check if its OK to send other 2xx errors to ALB for "Partial Healthy" // and "Unhealthy" NettyUtils.sendJsonResponse( ctx, new StatusResponse(response), HttpResponseStatus.OK); }; wlm.scheduleAsync(r); } public boolean scaleRequestStatus(String modelName) { Model model = ModelManager.getInstance().getModels().get(modelName); int numWorkers = wlm.getNumRunningWorkers(modelName); return model == null || model.getMinWorkers() <= numWorkers; } public void submitTask(Runnable runnable) { wlm.scheduleAsync(runnable); } public Set getStartupModels() { return startupModels; } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkLoadManager.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.util.ConfigManager; import io.netty.channel.EventLoopGroup; import io.netty.handler.codec.http.HttpResponseStatus; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class WorkLoadManager { private ExecutorService threadPool; private ConcurrentHashMap> workers; private ConfigManager configManager; private EventLoopGroup backendGroup; private AtomicInteger port; private AtomicInteger gpuCounter; private AtomicInteger threadNumber; private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); public WorkLoadManager(ConfigManager configManager, EventLoopGroup backendGroup) { this.configManager = configManager; this.backendGroup = backendGroup; this.port = new AtomicInteger(9000); this.gpuCounter = new AtomicInteger(0); threadPool = Executors.newCachedThreadPool(); workers = new ConcurrentHashMap<>(); threadNumber = new AtomicInteger(0); } public List getWorkers(String modelName) { List list = workers.get(modelName); if (list == null) { return Collections.emptyList(); } return new ArrayList<>(list); } public Map getWorkers() { Map map = new HashMap<>(); for (Map.Entry> entry : workers.entrySet()) { // Add server thread String modelName = entry.getKey(); List workerThreads = entry.getValue(); WorkerThread serverThread = ModelManager.getInstance().getModels().get(modelName).getServerThread(); map.put(serverThread.getPid(), serverThread); for (WorkerThread worker : workerThreads) { 
map.put(worker.getPid(), worker); } } return map; } public boolean hasNoWorker(String modelName) { List worker = workers.get(modelName); if (worker == null) { return true; } return worker.isEmpty(); } public int getNumRunningWorkers(String modelName) { int numWorking = 0; List threads = workers.getOrDefault(modelName, null); if (threads != null) { for (WorkerThread thread : threads) { if ((thread.getState() != WorkerState.WORKER_STOPPED) && (thread.getState() != WorkerState.WORKER_ERROR) && (thread.getState() != WorkerState.WORKER_SCALED_DOWN)) { numWorking += 1; } } } return numWorking; } public CompletableFuture modelChanged(Model model) { synchronized (model.getModelName()) { CompletableFuture future = new CompletableFuture<>(); int minWorker = model.getMinWorkers(); int maxWorker = model.getMaxWorkers(); List threads; if (minWorker == 0) { threads = workers.remove(model.getModelName()); if (threads == null) { if (maxWorker == 0) { return shutdownServerThread(model, future); } future.complete(HttpResponseStatus.OK); return future; } } else { threads = workers.computeIfAbsent(model.getModelName(), k -> new ArrayList<>()); } int currentWorkers = threads.size(); if (currentWorkers < minWorker) { addThreads(threads, model, minWorker - currentWorkers, future); } else { for (int i = currentWorkers - 1; i >= maxWorker; --i) { WorkerThread thread = threads.remove(i); thread.shutdown(); } if (maxWorker == 0) { return shutdownServerThread(model, future); } future.complete(HttpResponseStatus.OK); } return future; } } private CompletableFuture shutdownServerThread( Model model, CompletableFuture future) { model.getServerThread().shutdown(); WorkerLifeCycle lifecycle = model.getServerThread().getLifeCycle(); Process workerProcess = lifecycle.getProcess(); if (workerProcess.isAlive()) { boolean workerDestroyed = false; workerProcess.destroyForcibly(); try { workerDestroyed = workerProcess.waitFor( configManager.getUnregisterModelTimeout(), TimeUnit.SECONDS); } catch 
(InterruptedException e) { logger.warn( "WorkerThread interrupted during waitFor, possible asynch resource cleanup."); future.complete(HttpResponseStatus.INTERNAL_SERVER_ERROR); return future; } if (!workerDestroyed) { logger.warn("WorkerThread timed out while cleaning, please resend request."); future.complete(HttpResponseStatus.REQUEST_TIMEOUT); return future; } } future.complete(HttpResponseStatus.OK); return future; } public void addServerThread(Model model, CompletableFuture future) throws InterruptedException, ExecutionException, TimeoutException { WorkerStateListener listener = new WorkerStateListener(future, 1); BatchAggregator aggregator = new BatchAggregator(model); synchronized (model.getModelName()) { model.setPort(port.getAndIncrement()); WorkerThread thread = new WorkerThread( configManager, backendGroup, model.getPort(), -1, model, aggregator, listener, threadNumber.getAndIncrement(), true); model.setServerThread(thread); threadPool.submit(thread); future.get(1, TimeUnit.MINUTES); } } private void addThreads( List threads, Model model, int count, CompletableFuture future) { WorkerStateListener listener = new WorkerStateListener(future, count); int maxGpu = configManager.getNumberOfGpu(); for (int i = 0; i < count; ++i) { int gpuId = -1; if (maxGpu > 0) { gpuId = gpuCounter.accumulateAndGet(maxGpu, (prev, maxGpuId) -> ++prev % maxGpuId); } BatchAggregator aggregator = new BatchAggregator(model); WorkerThread thread = new WorkerThread( configManager, backendGroup, model.getPort(), gpuId, model, aggregator, listener, threadNumber.getAndIncrement(), false); threads.add(thread); threadPool.submit(thread); } } public void scheduleAsync(Runnable r) { threadPool.execute(r); } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerInitializationException.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms.wlm;

/** Thrown when a backend worker process fails to start or initialize. */
public class WorkerInitializationException extends Exception {

    static final long serialVersionUID = 1L;

    /** Creates a new {@code WorkerInitializationException} instance. */
    public WorkerInitializationException(String message) {
        super(message);
    }

    /**
     * Constructs a new {@code WorkerInitializationException} with the specified detail message and
     * cause.
     *
     * @param message the detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method).
     * @param cause the cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent
     *     or unknown.)
     */
    public WorkerInitializationException(String message, Throwable cause) {
        super(message, cause);
    }
}

================================================
FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerLifeCycle.java
================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.archive.Manifest; import com.amazonaws.ml.mms.metrics.Metric; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.Connector; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.Scanner; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Pattern; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class WorkerLifeCycle { static final Logger logger = LoggerFactory.getLogger(WorkerLifeCycle.class); private ConfigManager configManager; private Model model; private int pid = -1; private Process process; private CountDownLatch latch; private boolean success; private Connector connector; private ReaderThread errReader; private ReaderThread outReader; public WorkerLifeCycle(ConfigManager configManager, Model model) { this.configManager = configManager; this.model = model; this.latch = new CountDownLatch(1); } private String[] getEnvString(String cwd, String modelPath, String handler) { ArrayList envList = new ArrayList<>(); Pattern blackList = configManager.getBlacklistPattern(); String handlerFile = handler; if (handler.contains(":")) { handlerFile = handler.split(":")[0]; if (handlerFile.contains("/")) { handlerFile = handlerFile.substring(0, handlerFile.lastIndexOf('/')); } } StringBuilder pythonPath = new StringBuilder(); HashMap environment = new HashMap<>(System.getenv()); environment.putAll(configManager.getBackendConfiguration()); pythonPath.append(handlerFile).append(File.pathSeparatorChar); if (System.getenv("PYTHONPATH") != null) { 
pythonPath.append(System.getenv("PYTHONPATH")).append(File.pathSeparatorChar);
        }
        // (tail of getEnvString) Build PYTHONPATH for the backend worker: inherited
        // PYTHONPATH (appended above), then the model directory, then the server's
        // working dir — unless cwd already points inside site-packages/dist-packages.
        pythonPath.append(modelPath);
        if (!cwd.contains("site-packages") && !cwd.contains("dist-packages")) {
            pythonPath.append(File.pathSeparatorChar).append(cwd);
        }
        environment.put("PYTHONPATH", pythonPath.toString());
        // Pass through every env var whose name is not matched by the black-list pattern.
        for (Map.Entry entry : environment.entrySet()) {
            if (!blackList.matcher(entry.getKey()).matches()) {
                envList.add(entry.getKey() + '=' + entry.getValue());
            }
        }
        return envList.toArray(new String[0]); // NOPMD
    }

    /**
     * Starts two reader threads that pump the worker process' stdout/stderr into the
     * server logs (and the model-metrics log for [METRICS] lines).
     *
     * @param threadName base name for the reader threads ("-stdout"/"-stderr" suffixed)
     * @param outStream worker process stdout
     * @param errStream worker process stderr
     */
    public synchronized void attachIOStreams(
            String threadName, InputStream outStream, InputStream errStream) {
        logger.warn("attachIOStreams() threadName={}", threadName);
        errReader = new ReaderThread(threadName, errStream, true, this);
        outReader = new ReaderThread(threadName, outStream, false, this);
        errReader.start();
        outReader.start();
    }

    /** Signals both reader threads (if running) to stop pumping worker output. */
    public synchronized void terminateIOStreams() {
        if (errReader != null) {
            logger.warn("terminateIOStreams() threadName={}", errReader.getName());
            errReader.terminate();
        }
        if (outReader != null) {
            logger.warn("terminateIOStreams() threadName={}", outReader.getName());
            outReader.terminate();
        }
    }

    /**
     * Launches the backend worker process (mms/model_service_worker.py) and blocks
     * until it reports "MMS worker started." on stdout or 2 minutes elapse.
     *
     * @param port port (or UDS socket number) the worker should listen on
     * @throws WorkerInitializationException if the process cannot be started, its
     *     output stream closes before the started banner, or startup times out
     * @throws InterruptedException if interrupted while waiting for startup
     */
    public void startBackendServer(int port)
            throws WorkerInitializationException, InterruptedException {
        File workingDir = new File(configManager.getModelServerHome());
        File modelPath;
        setPort(port);
        try {
            modelPath = model.getModelDir().getCanonicalFile();
        } catch (IOException e) {
            throw new WorkerInitializationException("Failed get MMS home directory", e);
        }
        // Command line: interpreter + worker script + socket/handler/model arguments.
        String[] args = new String[16];
        Manifest.RuntimeType runtime = model.getModelArchive().getManifest().getRuntime();
        if (runtime == Manifest.RuntimeType.PYTHON) {
            args[0] = configManager.getPythonExecutable();
        } else {
            args[0] = runtime.getValue();
        }
        args[1] = new File(workingDir, "mms/model_service_worker.py").getAbsolutePath();
        args[2] = "--sock-type";
        args[3] = connector.getSocketType();
        args[4] = connector.isUds() ? "--sock-name" : "--port";
        args[5] = connector.getSocketPath();
        args[6] = "--handler";
        args[7] = model.getModelArchive().getManifest().getModel().getHandler();
        args[8] = "--model-path";
        args[9] = model.getModelDir().getAbsolutePath();
        args[10] = "--model-name";
        args[11] = model.getModelName();
        args[12] = "--preload-model";
        args[13] = model.preloadModel();
        args[14] = "--tmp-dir";
        args[15] = System.getProperty("java.io.tmpdir");
        String[] envp =
                getEnvString(
                        workingDir.getAbsolutePath(),
                        modelPath.getAbsolutePath(),
                        model.getModelArchive().getManifest().getModel().getHandler());
        try {
            latch = new CountDownLatch(1);
            synchronized (this) {
                // Thread name carries port and (truncated to 25 chars) model name for logs.
                String threadName =
                        "W-"
                                + port
                                + '-'
                                + model.getModelName()
                                        .substring(0, Math.min(model.getModelName().length(), 25));
                process = Runtime.getRuntime().exec(args, envp, modelPath);
                attachIOStreams(threadName, process.getInputStream(), process.getErrorStream());
            }
            // ReaderThread counts the latch down on "MMS worker started." (success) or EOF.
            if (latch.await(2, TimeUnit.MINUTES)) {
                if (!success) {
                    throw new WorkerInitializationException("Backend stream closed.");
                }
                return;
            }
            throw new WorkerInitializationException("Backend worker startup time out.");
        } catch (IOException e) {
            throw new WorkerInitializationException("Failed start worker process", e);
        } finally {
            if (!success) {
                exit();
            }
        }
    }

    /** Force-kills the worker process, cleans up its socket and stops the IO readers. */
    public synchronized void exit() {
        if (process != null) {
            process.destroyForcibly();
            connector.clean();
            terminateIOStreams();
        }
    }

    /** Returns the process exit code, or null while it is still alive (or never started). */
    public synchronized Integer getExitValue() {
        if (process != null && !process.isAlive()) {
            return process.exitValue();
        }
        return null;
    }

    // Called from ReaderThread: records startup success/failure and releases the latch
    // so a startBackendServer() waiting on it unblocks.
    void setSuccess(boolean success) {
        this.success = success;
        latch.countDown();
    }

    public synchronized int getPid() {
        return pid;
    }

    public synchronized void setPid(int pid) {
        this.pid = pid;
    }

    private synchronized void setPort(int port) {
        connector = new Connector(port);
    }

    public Process getProcess() {
        return process;
    }

    /**
     * Pumps one of the worker's output streams line by line into the server logs,
     * recognizing the [METRICS], [PID] and "MMS worker started." control lines.
     */
    private static final class ReaderThread extends Thread {

        private InputStream is;
        private boolean error; // true => this thread reads stderr and logs at WARN level
        private WorkerLifeCycle lifeCycle;
        private AtomicBoolean isRunning = new AtomicBoolean(true);
        static final Logger loggerModelMetrics =
                LoggerFactory.getLogger(ConfigManager.MODEL_METRICS_LOGGER);

        public ReaderThread(String name, InputStream is, boolean error, WorkerLifeCycle lifeCycle) {
            super(name + (error ? "-stderr" : "-stdout"));
            this.is = is;
            this.error = error;
            this.lifeCycle = lifeCycle;
        }

        public void terminate() {
            isRunning.set(false);
        }

        @Override
        public void run() {
            try (Scanner scanner = new Scanner(is, StandardCharsets.UTF_8.name())) {
                while (isRunning.get() && scanner.hasNext()) {
                    String result = scanner.nextLine();
                    if (result == null) {
                        break;
                    }
                    if (result.startsWith("[METRICS]")) {
                        // "[METRICS]" prefix is 9 chars; remainder is a parsable metric line.
                        loggerModelMetrics.info("{}", Metric.parse(result.substring(9)));
                        continue;
                    }
                    if ("MMS worker started.".equals(result)) {
                        lifeCycle.setSuccess(true);
                    } else if (result.startsWith("[PID]")) {
                        // NOTE(review): prefix check is "[PID]" but the parse skips "[PID] "
                        // (with trailing space) — assumes the worker always emits the space.
                        lifeCycle.setPid(Integer.parseInt(result.substring("[PID] ".length())));
                    }
                    if (error) {
                        logger.warn(result);
                    } else {
                        logger.info(result);
                    }
                }
            } catch (Exception e) {
                logger.error("Couldn't create scanner - {}", getName(), e);
            } finally {
                logger.info("Stopped Scanner - {}", getName());
                // EOF or error: signal failure so a still-waiting startBackendServer unblocks.
                lifeCycle.setSuccess(false);
                try {
                    is.close();
                } catch (IOException e) {
                    logger.error("Failed to close stream for thread {}", this.getName(), e);
                }
            }
        }
    }
}

================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerState.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; public enum WorkerState { WORKER_STARTED, WORKER_MODEL_LOADED, WORKER_STOPPED, WORKER_ERROR, WORKER_SCALED_DOWN } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerStateListener.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; import io.netty.handler.codec.http.HttpResponseStatus; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; public class WorkerStateListener { private CompletableFuture future; private AtomicInteger count; public WorkerStateListener(CompletableFuture future, int count) { this.future = future; this.count = new AtomicInteger(count); } public void notifyChangeState(String modelName, WorkerState state, HttpResponseStatus status) { // Update success and fail counts if (state == WorkerState.WORKER_MODEL_LOADED) { if (count.decrementAndGet() == 0) { future.complete(status); } } if (state == WorkerState.WORKER_ERROR || state == WorkerState.WORKER_STOPPED) { future.complete(status); } } } ================================================ FILE: frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. 
All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package com.amazonaws.ml.mms.wlm; import com.amazonaws.ml.mms.metrics.Dimension; import com.amazonaws.ml.mms.metrics.Metric; import com.amazonaws.ml.mms.util.ConfigManager; import com.amazonaws.ml.mms.util.Connector; import com.amazonaws.ml.mms.util.NettyUtils; import com.amazonaws.ml.mms.util.codec.ModelRequestEncoder; import com.amazonaws.ml.mms.util.codec.ModelResponseDecoder; import com.amazonaws.ml.mms.util.messages.BaseModelRequest; import com.amazonaws.ml.mms.util.messages.InputParameter; import com.amazonaws.ml.mms.util.messages.ModelWorkerResponse; import com.amazonaws.ml.mms.util.messages.RequestInput; import com.amazonaws.ml.mms.util.messages.WorkerCommands; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.HttpResponseStatus; import java.io.FileNotFoundException; import java.io.IOException; import java.io.RandomAccessFile; import java.net.SocketAddress; import java.nio.channels.Channels; import java.util.UUID; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import 
java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runnable that owns one backend worker: either a "server thread" that spawns and
 * babysits the Python worker process, or a client thread that connects to it over
 * a Netty channel and shuttles inference/load requests back and forth.
 */
public class WorkerThread implements Runnable {

    static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);
    private static final Logger loggerMmsMetrics =
            LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER);
    private Metric workerLoadTime;
    // Fibonacci back-off (seconds) used by retry() between restart attempts.
    private static final int[] BACK_OFF = {
        0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597
    };
    // Initialization wait in minutes; effectively unbounded in debug mode.
    static final long WORKER_TIMEOUT = ConfigManager.getInstance().isDebug() ? Long.MAX_VALUE : 2L;
    static final ModelRequestEncoder ENCODER =
            new ModelRequestEncoder(ConfigManager.getInstance().getPreferDirectBuffer());
    private EventLoopGroup backendEventGroup;
    private int port;
    private Model model;
    private Channel backendChannel; // null for server threads; set by connect() otherwise
    private AtomicBoolean running = new AtomicBoolean(true);
    private int backoffIdx; // index into BACK_OFF; reset to 0 after a successful LOAD
    private BatchAggregator aggregator;
    private WorkerStateListener listener;
    ArrayBlockingQueue replies; // capacity 1: one in-flight request per worker
    private int gpuId;
    private long memory;
    private long startTime;
    // Holds the thread currently executing run(); cleared before the pool reuses it so
    // a late interrupt cannot hit the next Runnable.
    private AtomicReference currentThread = new AtomicReference<>();
    private String workerId;
    private String threadName;
    private BaseModelRequest req; // request in flight, if any (for error reporting)
    private WorkerState state;
    private WorkerLifeCycle lifeCycle;
    private boolean serverThread;
    private RandomAccessFile out; // worker stdout capture file (see runWorker LOAD case)
    private RandomAccessFile err; // worker stderr capture file
    private Connector connector;

    public WorkerState getState() {
        return state;
    }

    public WorkerLifeCycle getLifeCycle() {
        return lifeCycle;
    }

    /**
     * @param configManager server configuration
     * @param backendEventGroup Netty event loop group for the backend connection
     * @param port backend port / UDS socket number
     * @param gpuId GPU to assign, or negative for CPU
     * @param model model this worker serves
     * @param aggregator batches incoming jobs into backend requests
     * @param listener notified on worker state transitions
     * @param threadNumber ordinal used in the client thread name
     * @param serverThread true for the process-owning server thread
     */
    public WorkerThread(
            ConfigManager configManager,
            EventLoopGroup backendEventGroup,
            int port,
            int gpuId,
            Model model,
            BatchAggregator aggregator,
            WorkerStateListener listener,
            int threadNumber,
            boolean serverThread) {
        this.workerId = String.valueOf(port); // Unique across all workers.
        this.backendEventGroup = backendEventGroup;
        this.port = port;
        this.model = model;
        this.aggregator = aggregator;
        this.gpuId = gpuId;
        this.listener = listener;
        startTime = System.currentTimeMillis();
        lifeCycle = new WorkerLifeCycle(configManager, model);
        replies = new ArrayBlockingQueue<>(1);
        this.serverThread = serverThread;
        this.threadName =
                !serverThread
                        ? "W-"
                                + model.getModelName()
                                        .substring(0, Math.min(model.getModelName().length(), 25))
                                + '-'
                                + threadNumber
                        : "BackendServer-" + model.getModelName();
        workerLoadTime =
                new Metric(
                        getWorkerName(),
                        String.valueOf(System.currentTimeMillis()),
                        "ms",
                        ConfigManager.getInstance().getHostName(),
                        new Dimension("Level", "Host"));
    }

    /**
     * Client-thread main loop: pull the next batched request from the aggregator, send
     * it over the backend channel, wait for the reply (bounded by the model's response
     * timeout) and dispatch it. A missing reply throws and lets run() tear down/retry.
     */
    private void runWorker()
            throws WorkerInitializationException, InterruptedException, FileNotFoundException {
        int responseTimeoutSeconds = model.getResponseTimeoutSeconds();
        while (isRunning()) {
            req = aggregator.getRequest(backendChannel.id().asLongText(), state);
            backendChannel.writeAndFlush(req).sync();
            long begin = System.currentTimeMillis();
            // TODO: Change this to configurable param
            ModelWorkerResponse reply = replies.poll(responseTimeoutSeconds, TimeUnit.SECONDS);
            long duration = System.currentTimeMillis() - begin;
            logger.info("Backend response time: {}", duration);
            if (reply != null) {
                aggregator.sendResponse(reply);
            } else {
                int val = model.incrFailedInfReqs();
                logger.error("Number or consecutive unsuccessful inference {}", val);
                throw new WorkerInitializationException(
                        "Backend worker did not respond in given time");
            }
            switch (req.getCommand()) {
                case PREDICT:
                    model.resetFailedInfReqs();
                    break;
                case LOAD:
                    String message = reply.getMessage();
                    // Per-worker stdout/stderr capture files in the system temp dir,
                    // keyed by the backend channel id.
                    String tmpdir = System.getProperty("java.io.tmpdir");
                    out =
                            new RandomAccessFile(
                                    tmpdir + '/' + backendChannel.id().asLongText() + "-stdout",
                                    "rw");
                    err =
                            new RandomAccessFile(
                                    tmpdir + '/' + backendChannel.id().asLongText() + "-stderr",
                                    "rw");
                    if (reply.getCode() == 200) {
                        setState(WorkerState.WORKER_MODEL_LOADED, HttpResponseStatus.OK);
                        // Worker pid is embedded in the load reply after "[PID]:".
                        lifeCycle.setPid(
                                Integer.parseInt(
                                        message.substring(
                                                message.indexOf("[PID]:") + 6, message.length())));
                        lifeCycle.attachIOStreams(
                                threadName,
                                Channels.newInputStream(out.getChannel()),
                                Channels.newInputStream(err.getChannel()));
                        backoffIdx = 0; // healthy again: reset restart back-off
                    } else {
                        setState(
                                WorkerState.WORKER_ERROR,
                                HttpResponseStatus.valueOf(reply.getCode()));
                    }
                    break;
                case UNLOAD:
                case STATS:
                default:
                    break;
            }
            req = null;
        }
    }

    /**
     * Thread entry point. Client threads connect and run the request loop; server
     * threads start the backend process and block until it exits. All exits funnel
     * through the finally block, which reports errors, flips state to STOPPED and
     * schedules a retry.
     */
    @Override
    public void run() {
        Process process = null;
        Thread thread = Thread.currentThread();
        thread.setName(getWorkerName());
        currentThread.set(thread);
        HttpResponseStatus status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
        try {
            if (!serverThread) {
                connect();
                runWorker();
            } else {
                // TODO: Move this logic to a seperate ServerThread class
                // This is server thread and shouldn't come out as long as process exists in CPU.
                model.setPort(port);
                lifeCycle.startBackendServer(port);
                setState(WorkerState.WORKER_MODEL_LOADED, HttpResponseStatus.OK);
                process = lifeCycle.getProcess();
                process.waitFor();
            }
        } catch (InterruptedException e) {
            if (state == WorkerState.WORKER_SCALED_DOWN) {
                logger.debug("Shutting down the thread .. Scaling down.");
            } else {
                logger.debug(
                        "Backend worker monitoring thread interrupted or backend worker process died.",
                        e);
            }
        } catch (WorkerInitializationException e) {
            logger.error("Backend worker error", e);
        } catch (OutOfMemoryError oom) {
            logger.error("Out of memory error when creating workers", oom);
            status = HttpResponseStatus.INSUFFICIENT_STORAGE;
        } catch (Throwable t) {
            logger.warn("Backend worker thread exception.", t);
        } finally {
            // WorkerThread is running in thread pool, the thread will be assigned to next
            // Runnable once this worker is finished. If currentThread keep holding the reference
            // of the thread, currentThread.interrupt() might kill next worker.
backendChannel.disconnect(); currentThread.set(null); Integer exitValue = lifeCycle.getExitValue(); if (exitValue != null && exitValue == 137) { status = HttpResponseStatus.INSUFFICIENT_STORAGE; } if (!serverThread && req != null) { aggregator.sendError(req, "Worker died.", status); } else if (serverThread) { model.setPort(-1); if (process != null && process.isAlive()) { process.destroyForcibly(); try { process.waitFor(1, TimeUnit.SECONDS); } catch (InterruptedException e) { logger.warn( "WorkerThread interrupted during waitFor, possible asynch resource cleanup."); } } } setState(WorkerState.WORKER_STOPPED, status); lifeCycle.exit(); retry(); } } public String getWorkerId() { return workerId; } public long getMemory() { return memory; } public void setMemory(long memory) { this.memory = memory; } private void connect() throws WorkerInitializationException, InterruptedException, FileNotFoundException { if (!this.serverThread && (model.getPort() == -1)) { throw new WorkerInitializationException("Backend server is not runniing"); } String modelName = model.getModelName(); setState(WorkerState.WORKER_STARTED, HttpResponseStatus.OK); final CountDownLatch latch = new CountDownLatch(1); final int responseBufferSize = ConfigManager.getInstance().getMaxResponseSize(); try { connector = new Connector(port); Bootstrap b = new Bootstrap(); b.group(backendEventGroup) .channel(connector.getClientChannel()) .handler( new ChannelInitializer() { @Override public void initChannel(Channel ch) { ChannelPipeline p = ch.pipeline(); p.addLast(ENCODER); p.addLast(new ModelResponseDecoder(responseBufferSize)); p.addLast(new WorkerHandler()); } }); SocketAddress address = connector.getSocketAddress(); logger.info("Connecting to: {}", address); backendChannel = b.connect(address).sync().channel(); backendChannel .closeFuture() .addListener( (ChannelFutureListener) future -> { latch.countDown(); logger.info( "{} Worker disconnected. 
{}", getWorkerId(), state); Thread thread = currentThread.getAndSet(null); if (thread != null) { thread.interrupt(); } }); backendChannel .newSucceededFuture() .addListener( (ChannelFutureListener) future -> { // TODO: // use gpu, batch size in load model command RequestInput input = new RequestInput(UUID.randomUUID().toString()); if (gpuId >= 0) { input.addParameter( new InputParameter( "gpu", String.valueOf(gpuId))); } Job job = new Job( null, modelName, WorkerCommands.LOAD, input); model.addJob(backendChannel.id().asLongText(), job); latch.countDown(); }); if (!latch.await(WORKER_TIMEOUT, TimeUnit.MINUTES)) { throw new WorkerInitializationException( "Worker failed to initialize within " + WORKER_TIMEOUT + " mins"); } workerId = workerId + "-" + backendChannel.id(); running.set(true); } catch (Throwable t) { // https://github.com/netty/netty/issues/2597 if (t instanceof IOException) { throw new WorkerInitializationException("Failed to connect to worker.", t); } throw t; } } public boolean isRunning() { return running.get(); } public int getGpuId() { return gpuId; } public long getStartTime() { return startTime; } public int getPid() { return lifeCycle.getPid(); } public void shutdown() { running.set(false); setState(WorkerState.WORKER_SCALED_DOWN, HttpResponseStatus.OK); if (backendChannel != null) { model.removeJobQueue(backendChannel.id().asLongText()); backendChannel.close(); } if (this.serverThread && this.connector != null) { logger.debug("Cleaning connector socket"); this.connector.clean(); } logger.debug("Terminating IOStreams for worker thread shutdown"); lifeCycle.terminateIOStreams(); try { if (out != null) { out.close(); } if (err != null) { err.close(); } } catch (IOException e) { logger.error("Failed to close IO file handles", e); } Thread thread = currentThread.getAndSet(null); if (thread != null) { thread.interrupt(); aggregator.sendError( null, "Worker scaled down.", HttpResponseStatus.INTERNAL_SERVER_ERROR); } } public boolean isServerThread() { 
        return serverThread;
    }

    // "W-<port>-<model name truncated to 25 chars>"; used for thread and metric names.
    private final String getWorkerName() {
        String modelName = model.getModelName();
        if (modelName.length() > 25) {
            modelName = modelName.substring(0, 25);
        }
        return "W-" + port + '-' + modelName;
    }

    /**
     * Notifies the listener, then updates this worker's state. Note the ordering:
     * the log line and the WORKER_SCALED_DOWN / WORKER_MODEL_LOADED checks below read
     * the field before/after assignment deliberately, so a scaled-down worker keeps
     * its terminal state and the load-time metric fires on the transition into
     * WORKER_MODEL_LOADED.
     */
    void setState(WorkerState newState, HttpResponseStatus status) {
        listener.notifyChangeState(model.getModelName(), newState, status);
        logger.debug("{} State change {} -> {}", getWorkerName(), state, newState);
        long timeTaken = System.currentTimeMillis() - startTime;
        if (state != WorkerState.WORKER_SCALED_DOWN) {
            // Don't update the state if it was terminated on purpose.. Scaling in..
            this.state = newState;
        }
        if (state == WorkerState.WORKER_MODEL_LOADED) {
            // Record time-to-ready as a host-level metric.
            workerLoadTime.setValue(String.valueOf(timeTaken));
            workerLoadTime.setTimestamp(
                    String.valueOf(TimeUnit.MILLISECONDS.toSeconds(System.currentTimeMillis())));
            loggerMmsMetrics.info("{}", workerLoadTime);
        }
    }

    /**
     * Re-submits this worker to the model manager after a Fibonacci back-off delay,
     * unless it was deliberately scaled down.
     */
    void retry() {
        if (state == WorkerState.WORKER_SCALED_DOWN) {
            logger.debug("Worker terminated due to scale-down call.");
            return;
        }
        ModelManager manager = ModelManager.getInstance();
        if (backoffIdx < BACK_OFF.length - 1) {
            ++backoffIdx;
        }
        manager.getScheduler()
                .schedule(() -> manager.submitTask(this), BACK_OFF[backoffIdx], TimeUnit.SECONDS);
        logger.info("Retry worker: {} in {} seconds.", workerId, BACK_OFF[backoffIdx]);
    }

    /** Inbound handler: hands decoded backend responses to the waiting runWorker loop. */
    @ChannelHandler.Sharable
    private class WorkerHandler extends SimpleChannelInboundHandler {

        @Override
        public void channelRead0(ChannelHandlerContext ctx, ModelWorkerResponse msg) {
            // replies has capacity 1; a full queue means a protocol violation
            // (response arrived with no request outstanding).
            if (!replies.offer(msg)) {
                throw new IllegalStateException("Reply queue is full.");
            }
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            logger.error("Unknown exception", cause);
            if (cause instanceof OutOfMemoryError) {
                NettyUtils.sendError(ctx, HttpResponseStatus.INSUFFICIENT_STORAGE, cause);
            }
            ctx.close();
        }
    }
}

================================================ FILE: frontend/server/src/main/resources/log4j2.xml ================================================
================================================ FILE: frontend/server/src/test/java/com/amazonaws/ml/mms/CoverageTest.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms;

import com.amazonaws.ml.mms.test.TestHelper;
import java.io.IOException;
import org.testng.annotations.Test;

/** Exercises ModelServer's trivial getters/setters via TestHelper so they count toward coverage. */
public class CoverageTest {

    @Test
    public void test() throws IOException, ClassNotFoundException {
        TestHelper.testGetterSetters(ModelServer.class);
    }
}

================================================ FILE: frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.amazonaws.ml.mms;

import com.amazonaws.ml.mms.http.DescribeModelResponse;
import com.amazonaws.ml.mms.http.ErrorResponse;
import com.amazonaws.ml.mms.http.ListModelsResponse;
import com.amazonaws.ml.mms.http.StatusResponse;
import com.amazonaws.ml.mms.metrics.Dimension;
import com.amazonaws.ml.mms.metrics.Metric;
import com.amazonaws.ml.mms.metrics.MetricManager;
import com.amazonaws.ml.mms.servingsdk.impl.PluginsManager;
import com.amazonaws.ml.mms.util.ConfigManager;
import com.amazonaws.ml.mms.util.Connector;
import com.amazonaws.ml.mms.util.JsonUtils;
import com.google.gson.JsonParseException;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder;
import io.netty.handler.codec.http.multipart.MemoryFileUpload;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.netty.util.internal.logging.Slf4JLoggerFactory;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.util.List;
import java.util.Properties;
import java.util.Scanner;
import java.util.concurrent.CountDownLatch;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterSuite;
import org.testng.annotations.BeforeSuite;
import org.testng.annotations.Test;

/**
 * End-to-end test suite: boots a real ModelServer in beforeSuite and drives the
 * inference and management HTTP APIs over live Netty channels.
 */
public class ModelServerTest {

    private static final String ERROR_NOT_FOUND =
            "Requested resource is not found, please refer to API document.";
    private static final String ERROR_METHOD_NOT_ALLOWED =
            "Requested method is not allowed, please refer to API document.";

    private ConfigManager configManager;
    private ModelServer server;
    // Shared response capture: each helper resets these, issues one request and awaits
    // the latch; presumably a response handler (defined later in this file) fills them
    // in and counts the latch down — TODO confirm against the channel pipeline setup.
    CountDownLatch latch;
    HttpResponseStatus httpStatus;
    String result;
    HttpHeaders headers;
    // Expected payloads for the OPTIONS / api-description endpoints, loaded from
    // src/test/resources in beforeSuite.
    private String listInferenceApisResult;
    private String listManagementApisResult;
    private String noopApiResult;

    static {
        TestUtils.init();
    }

    /** Boots the server once for the whole suite and loads the expected API payloads. */
    @BeforeSuite
    public void beforeSuite() throws InterruptedException, IOException, GeneralSecurityException {
        ConfigManager.init(new ConfigManager.Arguments());
        configManager = ConfigManager.getInstance();
        PluginsManager.getInstance().initialize();
        InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
        server = new ModelServer(configManager);
        server.start();
        try (InputStream is = new FileInputStream("src/test/resources/inference_open_api.json")) {
            listInferenceApisResult = IOUtils.toString(is, StandardCharsets.UTF_8.name());
        }
        try (InputStream is = new FileInputStream("src/test/resources/management_open_api.json")) {
            listManagementApisResult = IOUtils.toString(is, StandardCharsets.UTF_8.name());
        }
        try (InputStream is = new FileInputStream("src/test/resources/describe_api.json")) {
            noopApiResult = IOUtils.toString(is, StandardCharsets.UTF_8.name());
        }
    }

    @AfterSuite
    public void afterSuite() {
        server.stop();
    }

    /**
     * Single ordered scenario covering all APIs. Order matters: model registration,
     * scaling and unregistration steps set up server state that later steps depend on,
     * and the negative cases run last because the server closes those channels.
     */
    @Test
    public void test()
            throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
                    IOException, NoSuchFieldException, IllegalAccessException {
        Channel channel = null;
        Channel managementChannel = null;
        // The server may still be binding its ports; retry the connects briefly.
        for (int i = 0; i < 5; ++i) {
            channel = connect(false);
            if (channel != null) {
                break;
            }
            Thread.sleep(100);
        }
        for (int i = 0; i < 5; ++i) {
            managementChannel = connect(true);
            if (managementChannel != null) {
                break;
            }
            Thread.sleep(100);
        }
        Assert.assertNotNull(channel, "Failed to connect to inference port.");
        Assert.assertNotNull(managementChannel, "Failed to connect to management port.");
        testPing(channel);
        testRoot(channel, listInferenceApisResult);
        testRoot(managementChannel, listManagementApisResult);
        testApiDescription(channel, listInferenceApisResult);
        testDescribeApi(channel);
        testUnregisterModel(managementChannel);
        testLoadModel(managementChannel);
        testSyncScaleModel(managementChannel);
        testScaleModel(managementChannel);
        testListModels(managementChannel);
        testDescribeModel(managementChannel);
        testLoadModelWithInitialWorkers(managementChannel);
        testLoadModelWithInitialWorkersWithJSONReqBody(managementChannel);
        testPredictions(channel);
        testPredictionsBinary(channel);
        testPredictionsJson(channel);
        testInvocationsJson(channel);
        testInvocationsMultipart(channel);
        testModelsInvokeJson(channel);
        testModelsInvokeMultipart(channel);
        testLegacyPredict(channel);
        testPredictionsInvalidRequestSize(channel);
        testPredictionsValidRequestSize(channel);
        testPredictionsDecodeRequest(channel, managementChannel);
        testPredictionsDoNotDecodeRequest(channel, managementChannel);
        testPredictionsModifyResponseHeader(channel, managementChannel);
        testPredictionsNoManifest(channel, managementChannel);
        testModelRegisterWithDefaultWorkers(managementChannel);
        testLogging(channel, managementChannel);
        testLoggingUnload(channel, managementChannel);
        testLoadingMemoryError();
        testPredictionMemoryError();
        testPredictionCustomErrorCode();
        testMetricManager();
        testErrorBatch();
        channel.close().sync();
        managementChannel.close().sync();
        // negative test case, channel will be closed by server
        testInvalidRootRequest();
        testInvalidInferenceUri();
        testInvalidPredictionsUri();
        testInvalidDescribeModel();
        testPredictionsModelNotFound();
        testInvalidManagementUri();
        testInvalidModelsMethod();
        testInvalidModelMethod();
        testDescribeModelNotFound();
        testRegisterModelMissingUrl();
        testRegisterModelInvalidRuntime();
        testRegisterModelNotFound();
        testRegisterModelConflict();
        testRegisterModelMalformedUrl();
        testRegisterModelConnectionFailed();
        testRegisterModelHttpError();
        testRegisterModelInvalidPath();
        testScaleModelNotFound();
        testScaleModelFailure();
        testUnregisterModelNotFound();
        testUnregisterModelTimeout();
        testInvalidModel();
    }

    // OPTIONS / should return the port-specific OpenAPI listing.
    private void testRoot(Channel channel, String expected) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");
        channel.writeAndFlush(req).sync();
        latch.await();
        Assert.assertEquals(result, expected);
    }

    // GET /ping should report Healthy and carry a request-id header.
    private void testPing(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/ping");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Healthy");
        Assert.assertTrue(headers.contains("x-request-id"));
    }

    // GET /api-description should match the OpenAPI listing for this port.
    private void testApiDescription(Channel channel, String expected) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req = new
DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/api-description");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, expected);
    }

    // OPTIONS /predictions/<model> should return the model's API description.
    private void testDescribeApi(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/predictions/noop_v0.1");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, noopApiResult);
    }

    // Asynchronous registration (synchronous=false) acks immediately with "registered".
    private void testLoadModel(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=noop-v0.1&model_name=noop_v0.1&runtime=python&synchronous=false");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Model \"noop_v0.1\" registered");
    }

    // Synchronous registration with initial workers blocks until workers are scaled.
    private void testLoadModelWithInitialWorkers(Channel channel) throws InterruptedException {
        testUnregisterModel(channel); // ensure a clean slate before re-registering
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=noop-v0.1&model_name=noop_v0.1&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Workers scaled");
    }

    // Same as above but with the registration parameters in a JSON request body.
    private void testLoadModelWithInitialWorkersWithJSONReqBody(Channel channel)
            throws InterruptedException {
        testUnregisterModel(channel);
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models");
        req.headers().add("Content-Type", "application/json");
        req.content()
                .writeCharSequence(
                        "{'url':'noop-v0.1', 'model_name':'noop_v0.1', 'initial_workers':'1', 'synchronous':'true'}",
                        CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Workers scaled");
    }

    // Asynchronous scale-up acks with "Processing worker updates...".
    private void testScaleModel(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models/noop_v0.1?min_worker=2");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Processing worker updates...");
    }

    // Synchronous scale blocks until workers reach the requested count.
    private void testSyncScaleModel(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.PUT,
                        "/models/noop_v0.1?synchronous=true&min_worker=1");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Workers scaled");
    }

    private void testUnregisterModel(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/noop_v0.1");
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), "Model \"noop_v0.1\" unregistered");
    }

    // Expects exactly the two models registered by earlier steps.
    private void testListModels(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/models?limit=200&nextPageToken=X");
        channel.writeAndFlush(req);
        latch.await();
        ListModelsResponse resp = JsonUtils.GSON.fromJson(result, ListModelsResponse.class);
        Assert.assertEquals(resp.getModels().size(), 2);
    }

    // After testScaleModel (min_worker=2) the model should report multiple workers.
    private void testDescribeModel(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/noop_v0.1");
        channel.writeAndFlush(req);
        latch.await();
        DescribeModelResponse resp = JsonUtils.GSON.fromJson(result, DescribeModelResponse.class);
        Assert.assertTrue(resp.getWorkers().size() > 1);
    }

    // POST /predictions/<model> with form-encoded body.
    private void testPredictions(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
        req.content().writeCharSequence("data=test", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers()
                .set(
                        HttpHeaderNames.CONTENT_TYPE,
                        HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // POST /predictions/<model> with a JSON body.
    private void testPredictionsJson(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // POST /predictions/<model> with a raw octet-stream body.
    private void testPredictionsBinary(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
        req.content().writeCharSequence("test", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // SageMaker-style /invocations endpoint with model_name as a query parameter.
    private void testInvocationsJson(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations?model_name=noop");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // /invocations with a multipart body: model_name attribute + file upload.
    private void testInvocationsMultipart(Channel channel)
            throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
                    IOException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations");
        HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(req, true);
        encoder.addBodyAttribute("model_name", "noop_v0.1");
        MemoryFileUpload body =
                new MemoryFileUpload("data", "test.txt", "text/plain", null, null, 4);
        body.setContent(Unpooled.copiedBuffer("test", StandardCharsets.UTF_8));
        encoder.addBodyHttpData(body);
        channel.writeAndFlush(encoder.finalizeRequest());
        if (encoder.isChunked()) {
            channel.writeAndFlush(encoder).sync();
        }
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // Legacy /models/<name>/invoke endpoint, JSON body.
    private void testModelsInvokeJson(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/models/noop/invoke");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // Legacy /models/<name>/invoke endpoint, multipart body.
    private void testModelsInvokeMultipart(Channel channel)
            throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
                    IOException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/models/noop/invoke");
        HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(req, true);
        MemoryFileUpload body =
                new MemoryFileUpload("data", "test.txt", "text/plain", null, null, 4);
        body.setContent(Unpooled.copiedBuffer("test", StandardCharsets.UTF_8));
        encoder.addBodyHttpData(body);
        channel.writeAndFlush(encoder.finalizeRequest());
        if (encoder.isChunked()) {
            channel.writeAndFlush(encoder).sync();
        }
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // A body over the configured max request size must be rejected with 413.
    private void testPredictionsInvalidRequestSize(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
        req.content().writeZero(11485760);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(httpStatus, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
    }

    // A body just under the limit must still be accepted.
    private void testPredictionsValidRequestSize(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop");
        req.content().writeZero(10385760);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM);
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
    }

    // Shared helper for the logging tests below (continues past this chunk).
    private void loadTests(Channel channel, String model, String modelName)
            throws InterruptedException {
        result = null;
        latch = new
CountDownLatch(1);
        String url =
                "/models?url="
                        + model
                        + "&model_name="
                        + modelName
                        + "&initial_workers=1&synchronous=true";

        HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, url);
        channel.writeAndFlush(req);
        latch.await();
    }

    // Helper: unregister `modelName` and assert the confirmation message.
    private void unloadTests(Channel channel, String modelName) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        String expected = "Model \"" + modelName + "\" unregistered";
        String url = "/models/" + modelName;
        HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, url);
        channel.writeAndFlush(req);
        latch.await();
        StatusResponse resp = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(resp.getStatus(), expected);
    }

    // Helper: poke a property into the running ConfigManager via reflection on its
    // private "prop" Properties field, so tests can flip server settings at runtime.
    private void setConfiguration(String key, String val)
            throws NoSuchFieldException, IllegalAccessException {
        Field f = configManager.getClass().getDeclaredField("prop");
        f.setAccessible(true);
        Properties p = (Properties) f.get(configManager);
        p.setProperty(key, val);
    }

    // Registering without initial_workers should fall back to the configured
    // default_workers_per_model (set to 1 here, restored to 0 afterwards).
    private void testModelRegisterWithDefaultWorkers(Channel mgmtChannel)
            throws NoSuchFieldException, IllegalAccessException, InterruptedException {
        setConfiguration("default_workers_per_model", "1");
        loadTests(mgmtChannel, "noop-v1.0", "noop_default_model_workers");

        result = null;
        latch = new CountDownLatch(1);
        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/noop_default_model_workers");
        mgmtChannel.writeAndFlush(req);
        latch.await();

        DescribeModelResponse resp = JsonUtils.GSON.fromJson(result, DescribeModelResponse.class);
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        Assert.assertEquals(resp.getMinWorkers(), 1);
        unloadTests(mgmtChannel, "noop_default_model_workers");
        setConfiguration("default_workers_per_model", "0");
    }

    // With decode_input_request=true the backend receives decoded input, so the
    // echoed result must NOT contain the raw "bytearray" marker.
    private void testPredictionsDecodeRequest(Channel inferChannel, Channel mgmtChannel)
            throws InterruptedException, NoSuchFieldException, IllegalAccessException {
        setConfiguration("decode_input_request", "true");
        loadTests(mgmtChannel, "noop-v1.0-config-tests", "noop-config");

        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        inferChannel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        Assert.assertFalse(result.contains("bytearray"));
        unloadTests(mgmtChannel, "noop-config");
    }

    // With decode_input_request=false the backend sees undecoded bytes, so the
    // echoed result must contain "bytearray".
    private void testPredictionsDoNotDecodeRequest(Channel inferChannel, Channel mgmtChannel)
            throws InterruptedException, NoSuchFieldException, IllegalAccessException {
        setConfiguration("decode_input_request", "false");
        loadTests(mgmtChannel, "noop-v1.0-config-tests", "noop-config");

        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/noop-config");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        inferChannel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        Assert.assertTrue(result.contains("bytearray"));
        unloadTests(mgmtChannel, "noop-config");
    }

    // The respheader-test model sets custom response headers from the handler;
    // verify "dummy" and the overridden content-type come back to the client.
    private void testPredictionsModifyResponseHeader(
            Channel inferChannel, Channel managementChannel)
            throws NoSuchFieldException, IllegalAccessException, InterruptedException {
        setConfiguration("decode_input_request", "false");
        loadTests(managementChannel, "respheader-test", "respheader");

        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/predictions/respheader");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        inferChannel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        Assert.assertEquals(headers.get("dummy"), "1");
        Assert.assertEquals(headers.get("content-type"), "text/plain");
        Assert.assertTrue(result.contains("bytearray"));
        unloadTests(managementChannel, "respheader");
    }

    // A model archive without a manifest must still be servable when
    // default_service_handler points at a handler entry point.
    private void testPredictionsNoManifest(Channel inferChannel, Channel mgmtChannel)
            throws InterruptedException, NoSuchFieldException, IllegalAccessException {
        setConfiguration("default_service_handler", "service:handle");
        loadTests(mgmtChannel, "noop-no-manifest", "nomanifest");

        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/nomanifest");
        req.content().writeCharSequence("{\"data\": \"test\"}", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        inferChannel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        Assert.assertEquals(result, "OK");
        unloadTests(mgmtChannel, "nomanifest");
    }

    // Pre-1.0 style GET /{model}/predict endpoint kept for backward compatibility.
    private void testLegacyPredict(Channel channel) throws InterruptedException {
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/noop/predict?data=test");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(result, "OK");
    }

    // GET on the inference root; the error assertions follow on the next chunk line.
    private void testInvalidRootRequest() throws InterruptedException {
        Channel channel = connect(false);
        Assert.assertNotNull(channel);

        HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.METHOD_NOT_ALLOWED.code());
        Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
    }

    // Unknown path on the inference port -> 404 with the generic not-found message.
    private void testInvalidInferenceUri() throws InterruptedException {
        Channel channel = connect(false);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
    }

    // OPTIONS on a prediction path for an unknown model -> 404 naming the model.
    private void testInvalidDescribeModel() throws InterruptedException {
        Channel channel = connect(false);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/predictions/InvalidModel");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
    }

    // /predictions without a model segment -> 404.
    private void testInvalidPredictionsUri() throws InterruptedException {
        Channel channel = connect(false);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
    }

    // GET prediction for an unregistered model -> 404 naming the model.
    private void testPredictionsModelNotFound() throws InterruptedException {
        Channel channel = connect(false);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions/InvalidModel");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
    }

    // Unknown path on the management port -> 404.
    private void testInvalidManagementUri() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), ERROR_NOT_FOUND);
    }

    // PUT on the /models collection is not supported -> 405.
    private void testInvalidModelsMethod() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.METHOD_NOT_ALLOWED.code());
        Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
    }

    // POST on a single model resource is not supported -> 405.
    private void testInvalidModelMethod() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models/noop");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.METHOD_NOT_ALLOWED.code());
        Assert.assertEquals(resp.getMessage(), ERROR_METHOD_NOT_ALLOWED);
    }

    // Describe an unregistered model -> 404.
    private void testDescribeModelNotFound() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/InvalidModel");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found: InvalidModel");
    }

    // Register without the mandatory url parameter -> 400.
    private void testRegisterModelMissingUrl() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.BAD_REQUEST.code());
        Assert.assertEquals(resp.getMessage(), "Parameter url is required.");
    }

    // Register with an unparseable runtime value -> 400.
    private void testRegisterModelInvalidRuntime() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=InvalidUrl&runtime=InvalidRuntime");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.BAD_REQUEST.code());
        Assert.assertEquals(resp.getMessage(), "Invalid RuntimeType value: InvalidRuntime");
    }

    // Register a url that is not present in the model store -> 404.
    private void testRegisterModelNotFound() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/models?url=InvalidUrl");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found in model store: InvalidUrl");
    }

    // Registering the same model name twice -> 409 Conflict on the second attempt.
    private void testRegisterModelConflict() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=noop-v0.1&model_name=noop_v0.1&runtime=python&synchronous=false");
        channel.writeAndFlush(req);
        latch.await();

        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=noop-v0.1&model_name=noop_v0.1&runtime=python&synchronous=false");
        channel.writeAndFlush(req);
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.CONFLICT.code());
        Assert.assertEquals(resp.getMessage(), "Model noop_v0.1 is already registered.");
    }

    // URL-encoded model url with an invalid port -> 404 "Invalid model url".
    private void testRegisterModelMalformedUrl() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=http%3A%2F%2Flocalhost%3Aaaaa");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Invalid model url: http://localhost:aaaa");
    }

    // Download from an unreachable host -> 400 with a download-failure message.
    private void testRegisterModelConnectionFailed() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=http%3A%2F%2Flocalhost%3A18888%2Ffake.mar&synchronous=false");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp =
                JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.BAD_REQUEST.code());
        Assert.assertEquals(
                resp.getMessage(),
                "Failed to download model from: http://localhost:18888/fake.mar");
    }

    // Download answered with an HTTP error (404) -> 400 including the status code.
    private void testRegisterModelHttpError() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=https%3A%2F%2Flocalhost%3A8443%2Ffake.mar&synchronous=false");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.BAD_REQUEST.code());
        Assert.assertEquals(
                resp.getMessage(),
                "Failed to download model from: https://localhost:8443/fake.mar, code: 404");
    }

    // Relative paths in the model url are rejected (path traversal guard).
    private void testRegisterModelInvalidPath() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=..%2Ffake.mar&synchronous=false");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Relative path is not allowed in url: ../fake.mar");
    }

    // Scale an unregistered model -> 404.
    private void testScaleModelNotFound() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models/fake");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found: fake");
    }

    private void testUnregisterModelNotFound() throws
InterruptedException {
        // Unregister an unregistered model -> 404.
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        HttpRequest req =
                new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/fake");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.NOT_FOUND.code());
        Assert.assertEquals(resp.getMessage(), "Model not found: fake");
    }

    // With unregister_model_timeout=0 the cleanup cannot finish in time -> 408;
    // with the timeout restored to 120 the same unregister succeeds.
    private void testUnregisterModelTimeout()
            throws InterruptedException, NoSuchFieldException, IllegalAccessException {
        Channel channel = connect(true);
        setConfiguration("unregister_model_timeout", "0");

        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/noop_v0.1");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        // NOTE: a stray System.out.print debug statement was removed here; the
        // assertions below cover the same information.
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.REQUEST_TIMEOUT.code());
        Assert.assertEquals(resp.getMessage(), "Timed out while cleaning resources: noop_v0.1");

        channel = connect(true);
        setConfiguration("unregister_model_timeout", "120");

        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/noop_v0.1");
        channel.writeAndFlush(req).sync();
        channel.closeFuture().sync();
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
    }

    // The init-error model registers fine asynchronously, but a synchronous scale
    // must surface the worker start failure as a 500.
    private void testScaleModelFailure() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        httpStatus = null;
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=init-error&model_name=init-error&synchronous=false");
        channel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);

        httpStatus = null;
        result = null;
        latch = new CountDownLatch(1);
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.PUT,
                        "/models/init-error?synchronous=true&min_worker=1");
        channel.writeAndFlush(req);
        latch.await();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(httpStatus, HttpResponseStatus.INTERNAL_SERVER_ERROR);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.INTERNAL_SERVER_ERROR.code());
        Assert.assertEquals(resp.getMessage(), "Failed to start workers");
    }

    // A model whose handler produces malformed output: registration succeeds, but
    // predictions come back 503 "Invalid model predict output".
    private void testInvalidModel() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        httpStatus = null;
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=invalid&model_name=invalid&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();

        StatusResponse status = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(status.getStatus(), "Workers scaled");

        channel.close().sync();
        channel = connect(false);
        Assert.assertNotNull(channel);

        result = null;
        latch = new CountDownLatch(1);
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/invalid");
        req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers()
                .set(
                        HttpHeaderNames.CONTENT_TYPE,
                        HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED);
        channel.writeAndFlush(req);
        latch.await();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(httpStatus, HttpResponseStatus.SERVICE_UNAVAILABLE);
        Assert.assertEquals(resp.getCode(), HttpResponseStatus.SERVICE_UNAVAILABLE.code());
        Assert.assertEquals(resp.getMessage(), "Invalid model predict output");
    }

    // Worker runs out of memory while loading the model -> 507.
    private void testLoadingMemoryError() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);
        result = null;
        latch = new CountDownLatch(1);

        HttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=loading-memory-error&model_name=memory_error&runtime=python&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.INSUFFICIENT_STORAGE);
        channel.close().sync();
    }

    // Worker runs out of memory during prediction -> 507 on the inference call;
    // the model loads and unloads cleanly around it.
    private void testPredictionMemoryError() throws InterruptedException {
        // Load the model
        Channel channel = connect(true);
        Assert.assertNotNull(channel);
        result = null;
        latch = new CountDownLatch(1);

        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=prediction-memory-error&model_name=pred-err&runtime=python&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        channel.close().sync();

        // Test for prediction
        channel = connect(false);
        Assert.assertNotNull(channel);
        result = null;
        latch = new CountDownLatch(1);
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/pred-err");
        req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);

        channel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.INSUFFICIENT_STORAGE);
        channel.close().sync();

        // Unload the model
        channel = connect(true);
        httpStatus = null;
        latch = new CountDownLatch(1);
        Assert.assertNotNull(channel);
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/pred-err");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
    }

    // A handler may set an arbitrary custom status code (599 here) and message.
    private void testPredictionCustomErrorCode() throws InterruptedException {
        // Load the model
        Channel channel = connect(true);
        Assert.assertNotNull(channel);
        result = null;
        latch = new CountDownLatch(1);

        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=custom-return-code&model_name=custom-return-code&runtime=python&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();
        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
        channel.close().sync();

        // Test for prediction
        channel = connect(false);
        Assert.assertNotNull(channel);
        result = null;
        latch = new CountDownLatch(1);
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/custom-return-code");
        req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);

        channel.writeAndFlush(req);
        latch.await();

        ErrorResponse resp = JsonUtils.GSON.fromJson(result, ErrorResponse.class);
        Assert.assertEquals(resp.getMessage(), "Some Prediction Error");
        Assert.assertEquals(resp.getCode(), 599);
    }

    // A batch handler returning a malformed batch response -> 507 with the raw
    // "Invalid response" body.
    private void testErrorBatch() throws InterruptedException {
        Channel channel = connect(true);
        Assert.assertNotNull(channel);

        httpStatus = null;
        result = null;
        latch = new CountDownLatch(1);
        DefaultFullHttpRequest req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1,
                        HttpMethod.POST,
                        "/models?url=error_batch&model_name=err_batch&initial_workers=1&synchronous=true");
        channel.writeAndFlush(req);
        latch.await();

        StatusResponse status = JsonUtils.GSON.fromJson(result, StatusResponse.class);
        Assert.assertEquals(status.getStatus(), "Workers scaled");

        channel.close().sync();
        channel = connect(false);
        Assert.assertNotNull(channel);

        result = null;
        latch = new CountDownLatch(1);
        httpStatus = null;
        req =
                new DefaultFullHttpRequest(
                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/err_batch");
        req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
        HttpUtil.setContentLength(req, req.content().readableBytes());
        req.headers()
                .set(
                        HttpHeaderNames.CONTENT_TYPE,
                        HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED);
        channel.writeAndFlush(req);
        latch.await();

        Assert.assertEquals(httpStatus, HttpResponseStatus.INSUFFICIENT_STORAGE);
        Assert.assertEquals(result, "Invalid response");
    }

    private void
testMetricManager() throws JsonParseException, InterruptedException { MetricManager.scheduleMetrics(configManager); MetricManager metricManager = MetricManager.getInstance(); List metrics = metricManager.getMetrics(); // Wait till first value is read in int count = 0; while (metrics.isEmpty()) { Thread.sleep(500); metrics = metricManager.getMetrics(); Assert.assertTrue(++count < 5); } for (Metric metric : metrics) { if (metric.getMetricName().equals("CPUUtilization")) { Assert.assertEquals(metric.getUnit(), "Percent"); } if (metric.getMetricName().equals("MemoryUsed")) { Assert.assertEquals(metric.getUnit(), "Megabytes"); } if (metric.getMetricName().equals("DiskUsed")) { List dimensions = metric.getDimensions(); for (Dimension dimension : dimensions) { if (dimension.getName().equals("Level")) { Assert.assertEquals(dimension.getValue(), "Host"); } } } } } private void testLogging(Channel inferChannel, Channel mgmtChannel) throws NoSuchFieldException, IllegalAccessException, InterruptedException, IOException { setConfiguration("default_workers_per_model", "2"); loadTests(mgmtChannel, "logging", "logging"); int niter = 5; int expected = 2; for (int i = 0; i < niter; i++) { latch = new CountDownLatch(1); DefaultFullHttpRequest req = new DefaultFullHttpRequest( HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/logging"); req.content().writeCharSequence("data=test", CharsetUtil.UTF_8); HttpUtil.setContentLength(req, req.content().readableBytes()); req.headers() .set( HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED); inferChannel.writeAndFlush(req); latch.await(); Assert.assertEquals(httpStatus, HttpResponseStatus.OK); } File logfile = new File("build/logs/mms_log.log"); Assert.assertTrue(logfile.exists()); Scanner logscanner = new Scanner(logfile, "UTF-8"); int count = 0; while (logscanner.hasNextLine()) { String line = logscanner.nextLine(); if (line.contains("LoggingService inference [PID]:")) { count = count + 1; } } Logger logger 
= LoggerFactory.getLogger(ModelServerTest.class); logger.info("testLogging, found {}, min expected {}.", count, expected); Assert.assertTrue(count >= expected); unloadTests(mgmtChannel, "logging"); } private void testLoggingUnload(Channel inferChannel, Channel mgmtChannel) throws NoSuchFieldException, IllegalAccessException, InterruptedException, IOException { setConfiguration("default_workers_per_model", "2"); loadTests(mgmtChannel, "logging", "logging"); unloadTests(mgmtChannel, "logging"); int expected = 1; int count = 0; File logfile = new File("build/logs/mms_log.log"); Assert.assertTrue(logfile.exists()); Scanner logscanner = new Scanner(logfile, "UTF-8"); while (logscanner.hasNextLine()) { String line = logscanner.nextLine(); if (line.contains("Model logging unregistered")) { count = count + 1; } } Assert.assertTrue(count >= expected); } private Channel connect(boolean management) { Logger logger = LoggerFactory.getLogger(ModelServerTest.class); final Connector connector = configManager.getListener(management); try { Bootstrap b = new Bootstrap(); final SslContext sslCtx = SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE) .build(); b.group(Connector.newEventLoopGroup(1)) .channel(connector.getClientChannel()) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000) .handler( new ChannelInitializer() { @Override public void initChannel(Channel ch) { ChannelPipeline p = ch.pipeline(); if (connector.isSsl()) { p.addLast(sslCtx.newHandler(ch.alloc())); } p.addLast(new ReadTimeoutHandler(30)); p.addLast(new HttpClientCodec()); p.addLast(new HttpContentDecompressor()); p.addLast(new ChunkedWriteHandler()); p.addLast(new HttpObjectAggregator(6553600)); p.addLast(new TestHandler()); } }); return b.connect(connector.getSocketAddress()).sync().channel(); } catch (Throwable t) { logger.warn("Connect error.", t); } return null; } @ChannelHandler.Sharable private class TestHandler extends SimpleChannelInboundHandler { @Override public void 
channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { httpStatus = msg.status(); result = msg.content().toString(StandardCharsets.UTF_8); headers = msg.headers(); latch.countDown(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { Logger logger = LoggerFactory.getLogger(TestHandler.class); logger.error("Unknown exception", cause); ctx.close(); latch.countDown(); } } } ================================================ FILE: frontend/server/src/test/java/com/amazonaws/ml/mms/TestUtils.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import java.security.GeneralSecurityException; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; public final class TestUtils { private TestUtils() {} public static void init() { // set up system properties for local IDE debug if (System.getProperty("mmsConfigFile") == null) { System.setProperty("mmsConfigFile", "src/test/resources/config.properties"); } if (System.getProperty("METRICS_LOCATION") == null) { System.setProperty("METRICS_LOCATION", "build/logs"); } if (System.getProperty("LOG_LOCATION") == null) { System.setProperty("LOG_LOCATION", "build/logs"); } try { SSLContext context = SSLContext.getInstance("TLS"); context.init(null, InsecureTrustManagerFactory.INSTANCE.getTrustManagers(), null); HttpsURLConnection.setDefaultSSLSocketFactory(context.getSocketFactory()); HttpsURLConnection.setDefaultHostnameVerifier((s, sslSession) -> true); } catch (GeneralSecurityException e) { // ignore } } } ================================================ FILE: frontend/server/src/test/java/com/amazonaws/ml/mms/test/TestHelper.java ================================================ /* * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. 
*/ package com.amazonaws.ml.mms.test; import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.net.URL; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Enumeration; import java.util.List; import java.util.jar.JarEntry; import java.util.jar.JarFile; import org.apache.commons.io.FileUtils; public final class TestHelper { private TestHelper() {} public static void testGetterSetters(Class baseClass) throws IOException, ClassNotFoundException { List> list = getClasses(baseClass); for (Class clazz : list) { Constructor[] constructors = clazz.getConstructors(); Object obj = null; for (Constructor con : constructors) { try { Class[] types = con.getParameterTypes(); Object[] args = new Object[types.length]; for (int i = 0; i < args.length; ++i) { args[i] = getMockValue(types[i]); } obj = con.newInstance(args); } catch (ReflectiveOperationException ignore) { // ignore } } if (obj == null) { continue; } Method[] methods = clazz.getMethods(); for (Method method : methods) { String methodName = method.getName(); int parameterCount = method.getParameterCount(); try { if (parameterCount == 0 && methodName.startsWith("get") || methodName.startsWith("is")) { method.invoke(obj); } else if (methodName.startsWith("set") && parameterCount == 1) { Class type = method.getParameterTypes()[0]; method.invoke(obj, getMockValue(type)); } } catch (ReflectiveOperationException ignore) { // ignore } } } } private static List> getClasses(Class clazz) throws IOException, ClassNotFoundException { URL url = clazz.getProtectionDomain().getCodeSource().getLocation(); String path = url.getPath(); if (!"file".equalsIgnoreCase(url.getProtocol())) { return Collections.emptyList(); } List> classList = new ArrayList<>(); File classPath = new File(path); if (classPath.isDirectory()) { String rootPath = classPath.getCanonicalPath(); String[] filters = new String[] {"class"}; 
Collection files = FileUtils.listFiles(classPath, filters, true);
            // Turn each .class file path (relative to the classpath root) into a
            // dot-separated fully-qualified class name and load it.
            // NOTE(review): generic parameters appear stripped by extraction here
            // (Collection<File>, Enumeration<JarEntry>); restore them when editing.
            for (File file : files) {
                String fileName = file.getCanonicalPath();
                fileName = fileName.substring(rootPath.length() + 1);
                fileName = fileName.substring(0, fileName.lastIndexOf("."));
                fileName = fileName.replace(File.separatorChar, '.');
                classList.add(Class.forName(fileName));
            }
        } else if (path.toLowerCase().endsWith(".jar")) {
            // Jar code source: walk the jar entries instead of the file system.
            try (JarFile jarFile = new JarFile(path)) {
                Enumeration en = jarFile.entries();
                while (en.hasMoreElements()) {
                    JarEntry entry = en.nextElement();
                    String fileName = entry.getName();
                    if (fileName.endsWith(".class")) {
                        fileName = fileName.substring(0, fileName.lastIndexOf("."));
                        fileName = fileName.replace('/', '.');
                        classList.add(Class.forName(fileName));
                    }
                }
            }
        }
        return classList;
    }

    /**
     * Returns a usable default value for a primitive parameter type; reference types get
     * null. Used to satisfy constructor and setter parameters reflectively.
     */
    private static Object getMockValue(Class type) {
        if (type.isPrimitive()) {
            if (type == Boolean.TYPE) {
                return Boolean.TRUE;
            }
            if (type == Character.TYPE) {
                return '0';
            }
            if (type == Byte.TYPE) {
                return (byte) 0;
            }
            if (type == Short.TYPE) {
                return (short) 0;
            }
            if (type == Integer.TYPE) {
                return 0;
            }
            if (type == Long.TYPE) {
                return 0L;
            }
            if (type == Float.TYPE) {
                return 0f;
            }
            if (type == Double.TYPE) {
                return 0d;
            }
        }
        return null;
    }
}

================================================ FILE: frontend/server/src/test/java/com/amazonaws/ml/mms/util/ConfigManagerTest.java ================================================
/*
 * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
*/

package com.amazonaws.ml.mms.util;

import com.amazonaws.ml.mms.TestUtils;
import com.amazonaws.ml.mms.metrics.Dimension;
import com.amazonaws.ml.mms.metrics.Metric;
import io.netty.handler.ssl.SslContext;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;

/** Tests for ConfigManager: env-var overrides, metric/log appenders, and SSL context setup. */
public class ConfigManagerTest {

    static {
        // One-time system-property / TLS bootstrap shared by the frontend tests.
        TestUtils.init();
    }

    /** Builds a model-level Metric (unit Milliseconds, fixed timestamp) for appender tests. */
    private Metric createMetric(String metricName, String requestId) {
        List dimensions = new ArrayList<>();
        Metric metric = new Metric();
        metric.setMetricName(metricName);
        metric.setRequestId(requestId);
        metric.setUnit("Milliseconds");
        metric.setTimestamp("1542157988");
        Dimension dimension = new Dimension();
        dimension.setName("Level");
        dimension.setValue("Model");
        dimensions.add(dimension);
        metric.setDimensions(dimensions);
        return metric;
    }

    /**
     * Injects an environment variable into the running JVM by reflectively mutating the JDK's
     * internal environment maps (java.lang.ProcessEnvironment when available, otherwise the
     * unmodifiable map backing System.getenv()). Test-only hack; presumably breaks on JDKs
     * that forbid this reflection — confirm before upgrading the test JDK.
     */
    @SuppressWarnings("unchecked")
    private void modifyEnv(String key, String val)
            throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
        try {
            Class processEnvironmentClass = Class.forName("java.lang.ProcessEnvironment");
            Field theEnvironmentField = processEnvironmentClass.getDeclaredField("theEnvironment");
            theEnvironmentField.setAccessible(true);
            Map env = (Map) theEnvironmentField.get(null);
            env.put(key, val);
            Field theCIEField =
                    processEnvironmentClass.getDeclaredField("theCaseInsensitiveEnvironment");
            theCIEField.setAccessible(true);
            Map cienv = (Map) theCIEField.get(null);
            cienv.put(key, val);
        } catch (NoSuchFieldException e) {
            // Fallback: locate Collections$UnmodifiableMap and rewrite its backing map "m".
            Class[] classes = Collections.class.getDeclaredClasses();
            Map env = System.getenv();
            for (Class cl : classes) {
                if ("java.util.Collections$UnmodifiableMap".equals(cl.getName())) {
                    Field field = cl.getDeclaredField("m");
                    field.setAccessible(true);
                    Object obj = field.get(env);
Map map = (Map) obj;
                    // Replaces the environment wholesale: after this, System.getenv()
                    // exposes only the injected key/value.
                    map.clear();
                    map.put(key, val);
                }
            }
        }
    }

    @Test
    public void test()
            throws IOException, GeneralSecurityException, IllegalAccessException,
                    NoSuchFieldException, ClassNotFoundException {
        // MMS_DEFAULT_RESPONSE_TIMEOUT is interpreted in minutes: 130 -> 60 * 130 seconds.
        modifyEnv("MMS_DEFAULT_RESPONSE_TIMEOUT", "130");
        ConfigManager.Arguments args = new ConfigManager.Arguments();
        args.setModels(new String[] {"noop_v0.1"});
        ConfigManager.init(args);
        ConfigManager configManager = ConfigManager.getInstance();
        configManager.setProperty("keystore", "src/test/resources/keystore.p12");
        Assert.assertEquals("true", configManager.getEnableEnvVarsConfig());
        Assert.assertEquals(60 * 130, configManager.getDefaultResponseTimeoutSeconds());
        Dimension dimension; // NOTE(review): unused local, kept as-is
        List metrics = new ArrayList<>();
        // Create two metrics and add them to a list
        metrics.add(createMetric("TestMetric1", "12345"));
        metrics.add(createMetric("TestMetric2", "23478"));
        // Writing through each dedicated logger must create its backing log file.
        Logger logger = LoggerFactory.getLogger(ConfigManager.MODEL_SERVER_METRICS_LOGGER);
        logger.debug("{}", metrics);
        Assert.assertTrue(new File("build/logs/mms_metrics.log").exists());
        logger = LoggerFactory.getLogger(ConfigManager.MODEL_METRICS_LOGGER);
        logger.debug("{}", metrics);
        Assert.assertTrue(new File("build/logs/model_metrics.log").exists());
        Logger modelLogger = LoggerFactory.getLogger(ConfigManager.MODEL_LOGGER);
        modelLogger.debug("test model_log");
        Assert.assertTrue(new File("build/logs/model_log.log").exists());
        SslContext ctx = configManager.getSslContext();
        Assert.assertNotNull(ctx);
    }

    // With enable_envvars_config=false in config_test_env.properties, the env override must
    // be ignored and the config-file default (default_response_timeout=120) used instead.
    @Test
    public void testNoEnvVars()
            throws IllegalAccessException, NoSuchFieldException, ClassNotFoundException {
        System.setProperty("mmsConfigFile", "src/test/resources/config_test_env.properties");
        modifyEnv("MMS_DEFAULT_RESPONSE_TIMEOUT", "130");
        ConfigManager.Arguments args = new ConfigManager.Arguments();
        args.setModels(new String[] {"noop_v0.1"});
        ConfigManager.init(args);
        ConfigManager configManager = ConfigManager.getInstance();
        Assert.assertEquals("false", configManager.getEnableEnvVarsConfig());
Assert.assertEquals(60 * 120, configManager.getDefaultResponseTimeoutSeconds()); } @Test public void testResponseTimeoutSeconds() throws IOException, GeneralSecurityException, IllegalAccessException, NoSuchFieldException, ClassNotFoundException { System.setProperty("mmsConfigFile", "src/test/resources/config.properties"); modifyEnv("MMS_DEFAULT_RESPONSE_TIMEOUT_SECONDS", "130"); ConfigManager.Arguments args = new ConfigManager.Arguments(); args.setModels(new String[] {"noop_v0.1"}); ConfigManager.init(args); ConfigManager configManager = ConfigManager.getInstance(); Assert.assertEquals("true", configManager.getEnableEnvVarsConfig()); Assert.assertEquals(130, configManager.getDefaultResponseTimeoutSeconds()); } } ================================================ FILE: frontend/server/src/test/resources/certs.pem ================================================ -----BEGIN CERTIFICATE----- MIICiDCCAfGgAwIBAgIEeC8zQzANBgkqhkiG9w0BAQsFADB2MQswCQYDVQQGEwJV UzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJUGFsbyBBbHRvMRIwEAYD VQQKEwlBbWF6b24gQUkxDjAMBgNVBAsTBU14TmV0MRowGAYDVQQDExFtbXMuYW1h em9uYXdzLmNvbTAgFw0xODA2MjAwMjExMjhaGA8yMTE3MDExMjAyMTEyOFowdjEL MAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVBhbG8g QWx0bzESMBAGA1UEChMJQW1hem9uIEFJMQ4wDAYDVQQLEwVNeE5ldDEaMBgGA1UE AxMRbW1zLmFtYXpvbmF3cy5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB AMcbCEP6kn9pUcap5+kYO/5xEl7SL965gSQ2TbFrVv+sLVkLSK8yTtcILr7RUINz FsD151Q7VyQCvpVzkOFew2s2mAFWWxPJYmxo1j/R3IkJakrrTrMy1R3jsqOQMrxY TLGR5LIe2pjdAnb9xWe2NB125619WDG7RrdHWZDfvSPxAgMBAAGjITAfMB0GA1Ud DgQWBBRWjdEyNchYAkdPoyudKJY9YP3JPzANBgkqhkiG9w0BAQsFAAOBgQBMAvqG cqvD3ColO2Ihgb/LCfCdV14e1YhusVFeKyZkSKFYyQR+MoBOxqMQqJ24gVzgqTU/ h+LkMqZcxxJAME08BzPgP5b06DBM4K0o0XUfYUViFpYXB0qCG5CA/0S7ONldBGaZ fv6JrnQ/a1NYBi92AaqXA4VmuaowWLVEFuPV1A== -----END CERTIFICATE----- ================================================ FILE: frontend/server/src/test/resources/config.properties ================================================ # debug=true # vmargs=-Xmx128m 
-XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError inference_address=https://127.0.0.1:8443 management_address=unix:/tmp/management.sock # model_server_home=../.. model_store=../modelarchive/src/test/resources/models load_models=noop-v0.1,noop-v1.0 preload_model=false # number_of_netty_threads=0 # netty_client_threads=0 # default_workers_per_model=0 # job_queue_size=100 # plugins_path=/tmp/plugins async_logging=true default_response_timeout=120 unregister_model_timeout=120 # number_of_gpu=1 # cors_allowed_origin # cors_allowed_methods # cors_allowed_headers # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 private_key_file=src/test/resources/key.pem certificate_file=src/test/resources/certs.pem # max_response_size=6553500 max_request_size=10485760 # blacklist_env_vars=.*USERNAME.*|.*PASSWORD.* # enable_env_vars_config=false # decode_input_request=true enable_envvars_config=true # default_service_handler=/path/to/service.py:handle ================================================ FILE: frontend/server/src/test/resources/config_test_env.properties ================================================ # debug=true # vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError inference_address=https://127.0.0.1:8443 management_address=unix:/tmp/management.sock # model_server_home=../.. 
model_store=../modelarchive/src/test/resources/models load_models=noop-v0.1,noop-v1.0 # number_of_netty_threads=0 # netty_client_threads=0 # default_workers_per_model=0 # job_queue_size=100 async_logging=true default_response_timeout=120 unregister_model_timeout=120 # number_of_gpu=1 # cors_allowed_origin # cors_allowed_methods # cors_allowed_headers # keystore=src/test/resources/keystore.p12 # keystore_pass=changeit # keystore_type=PKCS12 private_key_file=src/test/resources/key.pem certificate_file=src/test/resources/certs.pem # max_response_size=6553500 max_request_size=10485760 # blacklist_env_vars=.*USERNAME.*|.*PASSWORD.* # enable_env_vars_config=false # decode_input_request=true # enable_envvars_config=false ================================================ FILE: frontend/server/src/test/resources/describe_api.json ================================================ { "openapi": "3.0.1", "info": { "title": "RESTful API for: noop_v0.1", "version": "1.0.0" }, "paths": { "/prediction/noop_v0.1": { "post": { "description": "A predict entry point for model: noop_v0.1.", "operationId": "noop_v0.1", "parameters": [], "responses": { "200": { "description": "OK" }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." 
} } } } } } } } } } } ================================================ FILE: frontend/server/src/test/resources/inference_open_api.json ================================================ { "openapi": "3.0.1", "info": { "title": "Model Server APIs", "description": "Model Server is a flexible and easy to use tool for serving deep learning models", "version": "1.0.0" }, "paths": { "/": { "options": { "operationId": "apiDescription", "parameters": [], "responses": { "200": { "description": "A openapi 3.0.1 descriptor", "content": { "application/json": { "schema": { "type": "object", "required": [ "openapi", "info", "paths" ], "properties": { "openapi": { "type": "string" }, "info": { "type": "object" }, "paths": { "type": "object" } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } }, "/{model_name}/predict": { "post": { "description": "A legacy predict entry point for each model.", "operationId": "predict", "parameters": [ { "in": "path", "name": "model_name", "description": "Name of model to unregister.", "required": true, "schema": { "type": "string" } } ], "requestBody": { "description": "Input data format is defined by each model.", "content": { "*/*": { "schema": { "type": "string", "format": "binary" } } }, "required": true }, "responses": { "200": { "description": "Model specific output data format", "content": { "*/*": { "schema": { "type": "string", "format": "binary" } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." 
}, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "503": { "description": "No worker is available to serve request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } }, "deprecated": true } }, "/ping": { "get": { "operationId": "ping", "parameters": [], "responses": { "200": { "description": "Model server status", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Overall status of the Model Server." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } }, "/predictions/{model_name}": { "post": { "description": "Predictions entry point for each model. 
Use OPTIONS method to get detailed model API input and output description.", "operationId": "predictions", "parameters": [ { "in": "path", "name": "model_name", "description": "Name of model.", "required": true, "schema": { "type": "string" } } ], "requestBody": { "description": "Input data format is defined by each model. Use OPTIONS method to get details for model input format.", "content": { "*/*": { "schema": { "type": "string", "format": "binary" } } }, "required": true }, "responses": { "200": { "description": "Output data format is defined by each model. Use OPTIONS method to get details for model output and output format.", "content": { "*/*": { "schema": { "type": "string", "format": "binary" } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "503": { "description": "No worker is available to serve request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." 
} } } } } } } }, "options": { "description": "Display details of per model input and output.", "operationId": "predictionsApi", "parameters": [ { "in": "path", "name": "model_name", "description": "Name of model.", "required": true, "schema": { "type": "string" } } ], "responses": { "200": { "description": "OK", "content": { "application/json": { "schema": { "type": "object" } } } } } } }, "/api-description": { "get": { "operationId": "apiDescription", "parameters": [], "responses": { "200": { "description": "A openapi 3.0.1 descriptor", "content": { "application/json": { "schema": { "type": "object", "required": [ "openapi", "info", "paths" ], "properties": { "openapi": { "type": "string" }, "info": { "type": "object" }, "paths": { "type": "object" } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." 
} } } } } } }, "deprecated": true } }, "/invocations": { "post": { "description": "A generic invocation entry point for all models.", "operationId": "invocations", "parameters": [ { "in": "query", "name": "model_name", "description": "Name of model", "required": false, "schema": { "type": "string" } } ], "requestBody": { "content": { "multipart/form-data": { "schema": { "required": [ "data" ], "properties": { "model_name": { "type": "string", "description": "Name of model" }, "data": { "type": "string", "format": "binary", "description": "Inference input data" } } } } }, "required": true }, "responses": { "200": { "description": "Model specific output data format", "content": { "*/*": { "schema": { "required": [ "data" ], "properties": { "model_name": { "type": "string", "description": "Name of model" }, "data": { "type": "string", "format": "binary", "description": "Inference input data" } } } } } }, "400": { "description": "Missing model_name parameter", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." 
}, "message": { "type": "string", "description": "Error message." } } } } } }, "503": { "description": "No worker is available to serve request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } }, "/models/{model_name}/invoke": { "post": { "description": "A generic invocation entry point for all models.", "operationId": "invocations", "parameters": [ { "in": "query", "name": "model_name", "description": "Name of model", "required": false, "schema": { "type": "string" } } ], "requestBody": { "content": { "multipart/form-data": { "schema": { "required": [ "data" ], "properties": { "model_name": { "type": "string", "description": "Name of model" }, "data": { "type": "string", "format": "binary", "description": "Inference input data" } } } } }, "required": true }, "responses": { "200": { "description": "Model specific output data format", "content": { "*/*": { "schema": { "required": [ "data" ], "properties": { "model_name": { "type": "string", "description": "Name of model" }, "data": { "type": "string", "format": "binary", "description": "Inference input data" } } } } } }, "400": { "description": "Missing model_name parameter", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." 
}, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "503": { "description": "No worker is available to serve request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } } } } ================================================ FILE: frontend/server/src/test/resources/key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQDHGwhD+pJ/aVHGqefpGDv+cRJe0i/euYEkNk2xa1b/rC1ZC0iv Mk7XCC6+0VCDcxbA9edUO1ckAr6Vc5DhXsNrNpgBVlsTyWJsaNY/0dyJCWpK606z MtUd47KjkDK8WEyxkeSyHtqY3QJ2/cVntjQdduetfVgxu0a3R1mQ370j8QIDAQAB AoGANRoxlyfSQKcPR2PzVUjAX3k6xA1c9RMWrVjKWeJd/qymH5SR2yAYxOMKzJu4 1IYycF5lRyLYd+M/f06mOmVyysH3D7hkrNz57Z07UrZ0dO/mmUKRL7zc44mo22ck JtQRwWJMplgew7N8OyqEZbcLOpahjlkL4+KZIWOuO7X5m30CQQDob/rzNY8gfhEm oEHHQ4dCqa/b5as2OqpFoGBZ+iX3dumBf+UKuSHlvEozt4ZMm29DYSjhiGXgLUFw 6NBhWxpXAkEA20oNdGiYAyyGJ6TKkD3FNZYoqB5+E/Cq6c0AACssB4OrJtiGiBFq R1h5HTEwYMe+ciZ4CI5MvBukjAdlfn7W9wJAXOIqyTe060oVdncB8ivlCFmgweHk ajZFRq+Q8UPKGjq1kx9VmtRiXFjC2inTjBds/eL8oCuOcmgDR6hxZQYv3wJAcMLv kECIinlGsvQGRY297wQ7+9dSNaa3/Gmx6mRIy8RlKiCFbUqnP/C6tswoeFu+DqzB ZITn6IK+ZlMXWaiXmQJBAK7V4rR+4VdpYUu1OqPRxChkcM+Y4Wa985A46/8yoo3i 0PEenvApVzhVzS9jF6WEqIKcffBAmOxaXOn3kmn8w2A= -----END RSA PRIVATE KEY----- ================================================ FILE: 
frontend/server/src/test/resources/management_open_api.json ================================================ { "openapi": "3.0.1", "info": { "title": "Model Server APIs", "description": "Model Server is a flexible and easy to use tool for serving deep learning models", "version": "1.0.0" }, "paths": { "/": { "options": { "operationId": "apiDescription", "parameters": [], "responses": { "200": { "description": "A openapi 3.0.1 descriptor", "content": { "application/json": { "schema": { "type": "object", "required": [ "openapi", "info", "paths" ], "properties": { "openapi": { "type": "string" }, "info": { "type": "object" }, "paths": { "type": "object" } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } }, "/models": { "get": { "description": "List registered models in Model Server.", "operationId": "listModels", "parameters": [ { "in": "query", "name": "limit", "description": "Use this parameter to specify the maximum number of items to return. When this value is present, Model Server does not return more than the specified number of items, but it might return fewer. This value is optional. If you include a value, it must be between 1 and 1000, inclusive. If you do not include a value, it defaults to 100.", "required": false, "schema": { "type": "integer", "default": "100" } }, { "in": "query", "name": "next_page_token", "description": "The token to retrieve the next set of results. 
Model Server provides the token when the response from a previous call has more results than the maximum page size.", "required": false, "schema": { "type": "string" } }, { "in": "query", "name": "model_name_pattern", "description": "A model name filter to list only matching models.", "required": false, "schema": { "type": "string" } } ], "responses": { "200": { "description": "OK", "content": { "application/json": { "schema": { "type": "object", "required": [ "models" ], "properties": { "nextPageToken": { "type": "string", "description": "Use this parameter in a subsequent request after you receive a response with truncated results. Set it to the value of NextMarker from the truncated response you just received." }, "models": { "type": "array", "items": { "type": "object", "required": [ "modelName", "modelUrl" ], "properties": { "modelName": { "type": "string", "description": "Name of the model." }, "modelUrl": { "type": "string", "description": "URL of the model." } } }, "description": "A list of registered models." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } }, "post": { "description": "Register a new model in Model Server.", "operationId": "registerModel", "parameters": [ { "in": "query", "name": "model_url", "description": "Model archive download url, support local file or HTTP(s) protocol. For S3, consider use pre-signed url.", "required": true, "schema": { "type": "string" } }, { "in": "query", "name": "model_name", "description": "Name of model. 
This value will override modelName in MANIFEST.json if present.", "required": false, "schema": { "type": "string" } }, { "in": "query", "name": "handler", "description": "Inference handler entry-point. This value will override handler in MANIFEST.json if present.", "required": false, "schema": { "type": "string" } }, { "in": "query", "name": "runtime", "description": "Runtime for the model custom service code. This value will override runtime in MANIFEST.json if present.", "required": false, "schema": { "type": "string", "enum": [ "PYTHON", "PYTHON2", "PYTHON3" ] } }, { "in": "query", "name": "batch_size", "description": "Inference batch size, default: 1.", "required": false, "schema": { "type": "integer", "default": "1" } }, { "in": "query", "name": "max_batch_delay", "description": "Maximum delay for batch aggregation, default: 100.", "required": false, "schema": { "type": "integer", "default": "100" } }, { "in": "query", "name": "response_timeout", "description": "Maximum time, in seconds, the Model Server waits for a response from the model inference code, default: 120.", "required": false, "schema": { "type": "integer", "default": "2" } }, { "in": "query", "name": "initial_workers", "description": "Number of initial workers, default: 0.", "required": false, "schema": { "type": "integer", "default": "0" } }, { "in": "query", "name": "synchronous", "description": "Decides whether creation of worker synchronous or not, default: false.", "required": false, "schema": { "type": "boolean", "default": "false" } }, { "in": "query", "name": "preload_model", "description": "Decides if model should be preloaded, default: false.", "required": false, "schema": { "type": "boolean", "default": "false" } } ], "responses": { "200": { "description": "Model registered", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." 
} } } } } }, "202": { "description": "Accepted", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." } } } } } }, "210": { "description": "Partial Success", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." } } } } } }, "400": { "description": "Bad request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "409": { "description": "Model already registered", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." 
} } } } } } } } }, "/models/{model_name}": { "get": { "description": "Provides detailed information about the specified model.", "operationId": "describeModel", "parameters": [ { "in": "path", "name": "model_name", "description": "Name of model to describe.", "required": true, "schema": { "type": "string" } } ], "responses": { "200": { "description": "OK", "content": { "application/json": { "schema": { "type": "object", "required": [ "modelName", "modelVersion", "modelUrl", "minWorkers", "maxWorkers", "status", "workers", "metrics" ], "properties": { "modelName": { "type": "string", "description": "Name of the model." }, "modelVersion": { "type": "string", "description": "Version of the model." }, "modelUrl": { "type": "string", "description": "URL of the model." }, "minWorkers": { "type": "integer", "description": "Configured minimum number of worker." }, "maxWorkers": { "type": "integer", "description": "Configured maximum number of worker." }, "batchSize": { "type": "integer", "description": "Configured batch size." }, "maxBatchDelay": { "type": "integer", "description": "Configured maximum batch delay in ms." }, "status": { "type": "string", "description": "Overall health status of the model" }, "workers": { "type": "array", "items": { "type": "object", "required": [ "id", "startTime", "status" ], "properties": { "id": { "type": "string", "description": "Worker id" }, "startTime": { "type": "string", "description": "Worker start time" }, "gpu": { "type": "boolean", "description": "If running on GPU" }, "status": { "type": "string", "description": "Worker status", "enum": [ "READY", "LOADING", "UNLOADING" ] } } }, "description": "A list of active backend workers." }, "metrics": { "type": "object", "required": [ "rejectedRequests", "waitingQueueSize", "requests" ], "properties": { "rejectedRequests": { "type": "integer", "description": "Number requests has been rejected in last 10 minutes." 
}, "waitingQueueSize": { "type": "integer", "description": "Number requests waiting in the queue." }, "requests": { "type": "integer", "description": "Number requests processed in last 10 minutes." } } } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } }, "put": { "description": "Configure number of workers for a model, This is a asynchronous call by default. Caller need to call describeModel check if the model workers has been changed.", "operationId": "setAutoScale", "parameters": [ { "in": "path", "name": "model_name", "description": "Name of model to describe.", "required": true, "schema": { "type": "string" } }, { "in": "query", "name": "min_worker", "description": "Minimum number of worker processes.", "required": false, "schema": { "type": "integer", "default": "1" } }, { "in": "query", "name": "max_worker", "description": "Maximum number of worker processes.", "required": false, "schema": { "type": "integer", "default": "1" } }, { "in": "query", "name": "number_gpu", "description": "Number of GPU worker processes to create.", "required": false, "schema": { "type": "integer", "default": "0" } }, { "in": "query", "name": "synchronous", "description": "Decides whether the call is synchronous or not, default: false.", "required": false, "schema": { "type": "boolean", "default": "false" } }, { "in": "query", "name": "timeout", "description": "Waiting up to the specified wait time if necessary for a worker to complete all pending requests. Use 0 to terminate backend worker process immediately. 
Use -1 for wait infinitely.", "required": false, "schema": { "type": "integer", "default": "-1" } } ], "responses": { "200": { "description": "Model workers updated", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." } } } } } }, "202": { "description": "Accepted", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." } } } } } }, "210": { "description": "Partial Success", "content": { "application/json": { "schema": { "type": "object", "required": [ "status" ], "properties": { "status": { "type": "string", "description": "Error type." } } } } } }, "400": { "description": "Bad request", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "404": { "description": "Model not found", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } }, "delete": { "description": "Unregister a model from Model Server. This is an asynchronous call by default. 
Caller can call listModels to confirm whether all the workers have been terminated.", "operationId": "unregisterModel"
} } } } } }, "500": { "description": "Internal Server Error", "content": { "application/json": { "schema": { "type": "object", "required": [ "code", "type", "message" ], "properties": { "code": { "type": "integer", "description": "Error code." }, "type": { "type": "string", "description": "Error type." }, "message": { "type": "string", "description": "Error message." } } } } } } } } } } } ================================================ FILE: frontend/settings.gradle ================================================ include 'server', 'modelarchive', 'cts' ================================================ FILE: frontend/tools/conf/checkstyle.xml ================================================ ================================================ FILE: frontend/tools/conf/findbugs-exclude.xml ================================================ ================================================ FILE: frontend/tools/conf/pmd.xml ================================================ Java Rule in PMD ================================================ FILE: frontend/tools/conf/suppressions.xml ================================================ ================================================ FILE: frontend/tools/gradle/check.gradle ================================================ apply plugin: 'findbugs' findbugs { excludeFilter = file("${rootProject.projectDir}/tools/conf/findbugs-exclude.xml") ignoreFailures = false findbugsTest.enabled = true } tasks.withType(FindBugs) { reports { xml.enabled false html.enabled true } } apply plugin: 'pmd' pmd { ignoreFailures = false pmdTest.enabled = false ruleSets = [] // workaround pmd gradle plugin bug ruleSetFiles = files("${rootProject.projectDir}/tools/conf/pmd.xml") } tasks.withType(Pmd){ reports{ xml.enabled=true html.enabled=true } } apply plugin: 'checkstyle' checkstyle { toolVersion = '7.1.2' ignoreFailures = false checkstyleTest.enabled = true configProperties = [ "checkstyle.suppressions.file" : 
file("${rootProject.projectDir}/tools/conf/suppressions.xml")] configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml") } checkstyleMain { classpath += configurations.compile } tasks.withType(Checkstyle) { reports { xml.enabled false html.enabled true } } apply plugin: "jacoco" jacoco { toolVersion = "0.8.1" } jacocoTestReport { group = "Reporting" reports { xml.enabled true csv.enabled false } } check.dependsOn jacocoTestReport check.dependsOn jacocoTestCoverageVerification ================================================ FILE: frontend/tools/gradle/formatter.gradle ================================================ buildscript { repositories { maven { url "https://plugins.gradle.org/m2/" } } dependencies { classpath 'com.google.googlejavaformat:google-java-format:1.6' } } apply plugin: FormatterPlugin check.dependsOn verifyJava import com.google.googlejavaformat.java.Formatter import com.google.googlejavaformat.java.ImportOrderer import com.google.googlejavaformat.java.JavaFormatterOptions import com.google.googlejavaformat.java.Main import com.google.googlejavaformat.java.RemoveUnusedImports class FormatterPlugin implements Plugin { void apply(Project project) { project.task('formatJava') { doLast { Main formatter = new Main(new PrintWriter(System.out, true), new PrintWriter(System.err, true), System.in) Project rootProject = project.getRootProject() for (item in project.sourceSets) { for (File file : item.getAllSource()) { if (!file.getName().endsWith(".java")) { continue } if (formatter.format("-a", "-i", file.getAbsolutePath()) != 0) { throw new GradleException("Format java failed: " + file.getAbsolutePath()) } } } } } project.task('verifyJava') { doLast { def options = JavaFormatterOptions.builder().style(JavaFormatterOptions.Style.AOSP).build() Formatter formatter = new Formatter(options) Project rootProject = project.getRootProject() for (item in project.sourceSets) { for (File file : item.getAllSource()) { if 
(!file.getName().endsWith(".java")) { continue } String src = file.text String formatted = formatter.formatSource(src) formatted = RemoveUnusedImports.removeUnusedImports(formatted, RemoveUnusedImports.JavadocOnlyImports.KEEP) formatted = ImportOrderer.reorderImports(formatted); if (!src.equals(formatted)) { throw new GradleException("File not formatted: " + file.getAbsolutePath()) } } } } } } } ================================================ FILE: frontend/tools/gradle/launcher.gradle ================================================ apply plugin: LauncherPlugin clean.dependsOn killServer import org.gradle.internal.jvm.Jvm class LauncherPlugin implements Plugin { void apply(Project project) { project.task('startServer') { dependsOn project.jar doLast { def pidFile = getPidFile() if (pidFile.exists()) { throw new GradleException("Server already running!") } def list = [] list.addAll(project.configurations.runtime.getFiles()) list.add(project.jar.outputs.files.singleFile) String cp = CollectionUtils.join(File.pathSeparator, list) String jvmPath = Jvm.current().getJavaExecutable() def cmd = [jvmPath, "-agentlib:jdwp=transport=dt_socket,address=0.0.0.0:4000,server=y,suspend=n", "-DmmsConfigFile=${project.projectDir}/src/test/resources/config.properties", "-DLOG_LOCATION=${project.buildDir}/logs", "-DMETRICS_LOCATION=${project.buildDir}/logs", "-cp", cp, "com.amazonaws.ml.mms.ModelServer"] as String[] def builder = new ProcessBuilder(cmd) builder.redirectErrorStream(true) builder.directory(project.projectDir) Process process = builder.start() ReaderThread rt = new ReaderThread(process.getInputStream()) rt.start() new ReaderThread(process.getErrorStream()).start() try { while (!rt.done) { try { process.exitValue(); throw new GradleException("MMS stop unexpectedly.") } catch(IllegalThreadStateException ex) { Thread.sleep(500); } } def pidField = process.class.getDeclaredField('pid') pidField.accessible = true pidFile << pidField.getInt(process) logger.quiet "MMS service 
started." } catch (IllegalThreadStateException ignored) { } } } project.task('killServer') { doLast { def pidFile = getPidFile() if(!pidFile.exists()) { logger.quiet "No server running!" return } def pid = pidFile.text def process = "kill $pid".execute() try { process.waitFor() } finally { pidFile.delete() } } } project.task('restartServer') { dependsOn project.killServer dependsOn project.startServer } } private File getPidFile() { return new File("build/server.pid") } } class ReaderThread extends Thread { private InputStream is private boolean done; public ReaderThread(InputStream is) { this.is = is } public void run() { long begin = System.currentTimeMillis() def line def reader = new BufferedReader(new InputStreamReader(is)) while ((line = reader.readLine()) != null) { if (!done) { done = line.matches("Model server started.*") println line } } } } ================================================ FILE: mms/.gitignore ================================================ frontend ================================================ FILE: mms/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This module does the following: a. Starts model-server. b. Creates end-points based on the configured models. c. Exposes standard "ping" and "api-description" endpoints. d. Waits for servicing inference requests. """ from . 
import version __version__ = version.__version__ ================================================ FILE: mms/arg_parser.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This module parses the arguments given through the multi-model-server command-line. This is used by model-server at runtime. """ import argparse # noinspection PyTypeChecker class ArgParser(object): """ Argument parser for multi-model-server and multi-model-export commands More detailed example is available at https://github.com/awslabs/multi-model-server/blob/master/README.md """ @staticmethod def mms_parser(): """ Argument parser for multi-model-server start service """ parser = argparse.ArgumentParser(prog='multi-model-server', description='Multi Model Server') sub_parse = parser.add_mutually_exclusive_group(required=False) sub_parse.add_argument('--start', action='store_true', help='Start the model-server') sub_parse.add_argument('--stop', action='store_true', help='Stop the model-server') parser.add_argument('--mms-config', dest='mms_config', help='Configuration file for model server') parser.add_argument('--model-store', dest='model_store', help='Model store location where models can be loaded') parser.add_argument('--models', metavar='MODEL_PATH1 MODEL_NAME=MODEL_PATH2...', nargs='+', help='Models to be loaded using [model_name=]model_location format. 
' 'Location can be a HTTP URL, a model archive file or directory ' 'contains model archive files in MODEL_STORE.') parser.add_argument('--log-config', dest='log_config', help='Log4j configuration file for model server') parser.add_argument('--foreground', help='Run the model server in foreground. If this option is disabled, the model server' ' will run in the background.', action='store_true') return parser @staticmethod def str2bool(v): if v.lower() in ('yes', 'true', 'y', '1'): return True if v.lower() in ('no', 'false', 'n', '0'): return False raise argparse.ArgumentTypeError('Boolean value expected {}'.format(v)) @staticmethod def model_service_worker_args(): """ ArgParser for backend worker. Takes the socket name and socket type. :return: """ parser = argparse.ArgumentParser(prog='model-server-worker', description='Model Server Worker') parser.add_argument('--sock-type', required=True, dest="sock_type", type=str, choices=["unix", "tcp"], help='Socket type the model service worker would use. The options are\n' 'unix: The model worker expects to unix domain-socket\n' 'tcp: The model worker expects a host-name and port-number') parser.add_argument('--sock-name', required=False, dest="sock_name", type=str, help='If \'sock-type\' is \'unix\', sock-name is expected to be a string. 
' 'Eg: --sock-name \"test_sock\"') parser.add_argument('--host', type=str, help='If \'sock-type\' is \'tcp\' this is expected to have a host IP address') parser.add_argument('--port', type=str, help='If \'sock-type\' is \'tcp\' this is expected to have the host port to bind on') parser.add_argument('--handler', type=str, help='Entry point to the Model Server') parser.add_argument('--model-path', type=str, help='Path to the actual model location') parser.add_argument('--model-name', type=str, help='Name of the model') parser.add_argument('--preload-model', dest="preload_model", required=True, type=ArgParser.str2bool, help='Determines if initialization should occur before spawning/forking child process') parser.add_argument('--tmp-dir', dest="tmp_dir", required=True, type=str, help='Location of temporaty file descriptors') return parser @staticmethod def extract_args(args=None): parser = ArgParser.mms_parser() return parser.parse_args(args) if args else parser.parse_args() ================================================ FILE: mms/configs/sagemaker_config.properties ================================================ vmargs=-XX:-UseContainerSupport model_store=$$SAGEMAKER_MODEL_STORE$$ inference_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ management_address=http://0.0.0.0:$$SAGEMAKER_BIND_TO_PORT$$ default_response_timeout=$$SAGEMAKER_RESPONSE_TIMEOUT$$ unregister_model_timeout=$$SAGEMAKER_RESPONSE_TIMEOUT$$ default_workers_per_model=$$SAGEMAKER_NUM_MODEL_WORKERS$$ default_service_handler=$$SAGEMAKER_HANDLER$$ async_logging=true max_response_size=$$SAGEMAKER_MAX_RESPONSE_SIZE$$ max_request_size=$$SAGEMAKER_MAX_REQUEST_SIZE$$ decode_input_request=false ================================================ FILE: mms/context.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). 
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

"""
Context object of incoming request
"""


class Context(object):
    """
    Context stores model relevant worker information
    Some fixed during load times and some
    """

    def __init__(self, model_name, model_dir, manifest, batch_size, gpu, mms_version):
        self.model_name = model_name
        self.manifest = manifest
        # Static per-worker properties, fixed at model load time.
        self._system_properties = {
            "model_dir": model_dir,
            "gpu_id": gpu,
            "batch_size": batch_size,
            "server_name": "MMS",
            "server_version": mms_version
        }
        self.request_ids = None
        # NOTE: properties are class attributes, so this assignment routes
        # through the request_processor setter below and actually initializes
        # self._request_processor to None.
        self.request_processor = None
        self._metrics = None

    @property
    def system_properties(self):
        # Read-only accessor for the static worker properties dict.
        return self._system_properties

    @property
    def request_processor(self):
        return self._request_processor

    @request_processor.setter
    def request_processor(self, request_processor):
        self._request_processor = request_processor

    @property
    def metrics(self):
        return self._metrics

    @metrics.setter
    def metrics(self, metrics):
        self._metrics = metrics

    def get_request_id(self, idx=0):
        # NOTE(review): uses .get(idx) — request_ids is presumably a mapping
        # keyed by batch index; confirm against the worker that populates it.
        return self.request_ids.get(idx)

    def get_request_header(self, idx, key):
        # Per-request header lookups are delegated to the idx-th RequestProcessor.
        return self._request_processor[idx].get_request_property(key)

    def get_all_request_header(self, idx):
        return self._request_processor[idx].get_request_properties()

    def set_response_content_type(self, idx, value):
        self.set_response_header(idx, 'content-type', value)

    def get_response_content_type(self, idx):
        return self.get_response_headers(idx).get('content-type')

    def get_response_status(self, idx):
        # Returns a (status_code, reason_phrase) tuple for the idx-th request.
        return self._request_processor[idx].get_response_status_code(), \
            self._request_processor[idx].get_response_status_phrase()

    def set_response_status(self, code=200, phrase="", idx=0):
        """
        Set the status code of individual requests
        :param phrase:
        :param idx: The index data in the list(data) that is sent to the handle() method
        :param code:
        :return:
        """
        # Guard: the processor list may not be populated yet.
        if self._request_processor is not None and self._request_processor[idx] is not None:
            self._request_processor[idx].report_status(code,
                                                       reason_phrase=phrase)

    def set_all_response_status(self, code=200, phrase=""):
        """
        Set the status code of individual requests
        :param phrase:
        :param code:
        :return:
        """
        # Apply the same status/phrase to every request in the batch.
        for idx, _ in enumerate(self._request_processor):
            self._request_processor[idx].report_status(code, reason_phrase=phrase)

    def get_response_headers(self, idx):
        return self._request_processor[idx].get_response_headers()

    def set_response_header(self, idx, key, value):
        self._request_processor[idx].add_response_property(key, value)

    # TODO: Should we add "add_header()" interface, to have multiple values for a single header. EG: Accept headers.

    def __eq__(self, other):
        # Structural equality: same concrete type and identical attribute dict.
        return isinstance(other, Context) and self.__dict__ == other.__dict__


class RequestProcessor(object):
    """
    Request processor
    """

    def __init__(self, request_header):
        # Defaults to HTTP 200 with no reason phrase until report_status is called.
        self._status_code = 200
        self._reason_phrase = None
        self._response_header = {}
        self._request_header = request_header

    def get_request_property(self, key):
        return self._request_header.get(key)

    def report_status(self, code, reason_phrase=None):
        self._status_code = code
        self._reason_phrase = reason_phrase

    def get_response_status_code(self):
        return self._status_code

    def get_response_status_phrase(self):
        return self._reason_phrase

    def add_response_property(self, key, value):
        self._response_header[key] = value

    def get_response_headers(self):
        return self._response_header

    def get_response_header(self, key):
        return self._response_header.get(key)

    def get_request_properties(self):
        return self._request_header

================================================ FILE: mms/export_model.py ================================================

# Copyright 2017 Amazon.com, Inc. or its affiliates.
All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This command line interface is no longer used. Please refer to model-archiver tool for the new CLI for exporting models. """ def main(): print('\033[93m' # Red Color start + "multi-model-export is no longer supported.\n" "Please use model-archiver to create 1.0 model archive.\n" "For more detail, see: https://pypi.org/project/model-archiver" + '\033[0m') # Red Color end if __name__ == '__main__': main() ================================================ FILE: mms/metrics/__init__.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This is a folder for all python worker metrics. """ from . import dimension from . import metric from . import metric_encoder from . import metrics_store from . import system_metrics from . import unit ================================================ FILE: mms/metrics/dimension.py ================================================ # Copyright 2018 Amazon.com, Inc. 
or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Dimension class for model server metrics """ class Dimension(object): """ Dimension class defining key value pair """ def __init__(self, name, value): """ Constructor for Dimension class Parameters ---------- name: str NAme of dimension value : str Unique Value of dimension """ self.name = name self.value = value def __str__(self): """ Return a string value :return: """ return "{}:{}".format(self.name, self.value) def to_dict(self): """ return an dictionary """ return {'Name': self.name, 'Value': self.value} ================================================ FILE: mms/metrics/metric.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Metric class for model server """
import time
import socket
from collections import OrderedDict
from builtins import str
from mms.metrics.unit import Units

MetricUnit = Units()


class Metric(object):
    """
    Class for generating metrics and printing it to stdout of the worker
    """

    def __init__(self, name, value, unit, dimensions, request_id=None, metric_method=None):
        """
        Constructor for Metric class

        Metric class will spawn a thread and report collected metrics to stdout of worker

        Parameters
        ----------
        name: str
            Name of metric
        value : int, float
            Can be integer or float
        unit: str
            unit can be one of ms, percent, count, MB, GB or a generic string
        dimensions: list
            list of dimension objects
        request_id: str
            req_id of metric
        metric_method: str
           useful for defining different operations, optional
        """
        self.name = name
        self.unit = unit
        # Normalize abbreviated units (e.g. "ms") to their canonical name when known.
        if unit in list(MetricUnit.units.keys()):
            self.unit = MetricUnit.units[unit]
        self.metric_method = metric_method
        self.value = value
        self.dimensions = dimensions
        self.request_id = request_id

    def update(self, value):
        """
        Update function for Metric class

        Parameters
        ----------
        value : int, float
            metric to be updated
        """
        # 'counter' metrics accumulate; every other metric_method overwrites.
        if self.metric_method == 'counter':
            self.value += value
        else:
            self.value = value

    def __str__(self):
        # Text form: "name.unit:value|#dim1,dim2|#hostname:host,ts[,request_id]".
        dims = ",".join([str(d) for d in self.dimensions])
        if self.request_id:
            return "{}.{}:{}|#{}|#hostname:{},{},{}".format(
                self.name, self.unit, self.value, dims,
                socket.gethostname(), int(time.time()), self.request_id)

        return "{}.{}:{}|#{}|#hostname:{},{}".format(
            self.name, self.unit, self.value, dims,
            socket.gethostname(), int(time.time()))

    def to_dict(self):
        """
        return an Ordered Dictionary
        """
        return OrderedDict({'MetricName': self.name, 'Value': self.value,
                            'Unit': self.unit, 'Dimensions': self.dimensions,
                            'Timestamp': int(time.time()),
                            'HostName': socket.gethostname(),
                            'RequestId': self.request_id})

================================================ FILE: mms/metrics/metric_collector.py ================================================

#
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

"""
Single start point for system metrics and process metrics script
"""
import logging
import sys

from mms.metrics import system_metrics
from mms.metrics.process_memory_metric import check_process_mem_usage

if __name__ == '__main__':
    # Bare "%(message)s" format: each metric is emitted as a plain line on
    # stdout — presumably consumed by the frontend's log parser; confirm.
    logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO)
    # Run every collector function defined in mms.metrics.system_metrics.
    system_metrics.collect_all(sys.modules['mms.metrics.system_metrics'])
    # The caller supplies a comma-separated pid list on stdin; log per-process
    # memory usage for each pid.
    check_process_mem_usage(sys.stdin)
""" Metric Encoder class for json dumps """ import json from json import JSONEncoder from mms.metrics.dimension import Dimension from mms.metrics.metric import Metric class MetricEncoder(JSONEncoder): """ Encoder class for json encoding Metric Object """ def default(self, obj): # pylint: disable=arguments-differ, method-hidden """ Override only when object is of type Metric """ if isinstance(obj, (Metric, Dimension)): return obj.to_dict() return json.JSONEncoder.default(self, obj) ================================================ FILE: mms/metrics/metrics_store.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Metrics collection module """ from builtins import str from mms.metrics.dimension import Dimension from mms.metrics.metric import Metric class MetricsStore(object): """ Class for creating, modifying different metrics. 
class MetricsStore(object):
    """
    Class for creating, modifying different metrics. And keep them in a dictionary
    """

    def __init__(self, request_ids, model_name):
        """
        Initialize metrics map, model name and request map

        Parameters
        ----------
        request_ids : dict or str or None
            request ids of the current batch (batch index -> request id)
        model_name : str
            name of the model these metrics belong to
        """
        self.store = list()   # every Metric created, in creation order
        self.request_ids = request_ids
        self.model_name = model_name
        self.cache = {}       # cache key -> Metric, enables in-place update

    def _add_or_update(self, name, value, req_id, unit, metrics_method=None, dimensions=None):
        """
        Add a metric key value pair, or update the cached metric if the same
        (name, unit, request id, dimensions) combination was seen before.

        Parameters
        ----------
        name : str
            metric name
        value: int, float
            value of metric
        req_id: str
            request id
        unit: str
            unit of metric
        metrics_method: str, optional
            indicates type of metric operation if it is defined
        dimensions: list, optional
            list of Dimension objects supplied by the caller

        Raises
        ------
        ValueError
            if dimensions is given but is not a list
        """
        # Copy the caller's list: this method appends Level/ModelName
        # dimensions below, and must not mutate a caller-owned list that may
        # be reused across calls.
        if dimensions is None:
            dimensions = list()
        elif not isinstance(dimensions, list):
            raise ValueError("Please provide a list of dimensions")
        else:
            dimensions = list(dimensions)
        # A metric without a request id is reported as an error-level metric.
        if req_id is None:
            dimensions.append(Dimension("Level", "Error"))
        else:
            dimensions.append(Dimension("ModelName", self.model_name))
            dimensions.append(Dimension("Level", "Model"))
        # Cache the metric with an unique key for update
        dim_str = [name, unit, str(req_id)] + [str(d) for d in dimensions]
        dim_str = '-'.join(dim_str)
        if dim_str not in self.cache:
            metric = Metric(name, value, unit, dimensions, req_id, metrics_method)
            self.store.append(metric)
            self.cache[dim_str] = metric
        else:
            self.cache[dim_str].update(value)

    def _get_req(self, idx):
        """
        Provide the request id dimension

        Parameters
        ----------
        idx : int
            request_id index in batch; None (or an unknown index) yields the
            joined list of all request ids in the batch
        """
        # check if request id for the metric is given, if so use it else have a list of all.
        req_id = self.request_ids
        if isinstance(req_id, dict):
            req_id = ','.join(self.request_ids.values())
        if idx is not None and self.request_ids is not None and idx in self.request_ids:
            req_id = self.request_ids[idx]
        return req_id

    def add_counter(self, name, value, idx=None, dimensions=None):
        """
        Add a counter metric or increment an existing counter metric

        Parameters
        ----------
        name : str
            metric name
        value: int
            value of metric
        idx: int
            request_id index in batch
        dimensions: list
            list of dimensions for the metric
        """
        unit = 'count'
        req_id = self._get_req(idx)
        self._add_or_update(name, value, req_id, unit, 'counter', dimensions)

    def add_time(self, name, value, idx=None, unit='ms', dimensions=None):
        """
        Add a time based metric like latency, default unit is 'ms'

        Parameters
        ----------
        name : str
            metric name
        value: int
            value of metric
        idx: int
            request_id index in batch
        unit: str
            unit of metric,  default here is ms, s is also accepted
        dimensions: list
            list of dimensions for the metric

        Raises
        ------
        ValueError
            if unit is not 'ms' or 's'
        """
        if unit not in ['ms', 's']:
            raise ValueError("the unit for a timed metric should be one of ['ms', 's']")
        req_id = self._get_req(idx)
        # Pass dimensions by keyword: positionally it would land in the
        # metrics_method parameter and be silently dropped.
        self._add_or_update(name, value, req_id, unit, dimensions=dimensions)

    def add_size(self, name, value, idx=None, unit='MB', dimensions=None):
        """
        Add a size based metric

        Parameters
        ----------
        name : str
            metric name
        value: int, float
            value of metric
        idx: int
            request_id index in batch
        unit: str
            unit of metric, default here is 'MB', 'kB', 'GB' also supported
        dimensions: list
            list of dimensions for the metric

        Raises
        ------
        ValueError
            if unit is not one of 'MB', 'kB', 'GB', 'B'
        """
        if unit not in ['MB', 'kB', 'GB', 'B']:
            raise ValueError("The unit for size based metric is one of ['MB','kB', 'GB', 'B']")
        req_id = self._get_req(idx)
        self._add_or_update(name, value, req_id, unit, dimensions=dimensions)

    def add_percent(self, name, value, idx=None, dimensions=None):
        """
        Add a percentage based metric

        Parameters
        ----------
        name : str
            metric name
        value: int, float
            value of metric
        idx: int
            request_id index in batch
        dimensions: list
            list of dimensions for the metric
        """
        unit = 'percent'
        req_id = self._get_req(idx)
        self._add_or_update(name, value, req_id, unit, dimensions=dimensions)

    def add_error(self, name, value, dimensions=None):
        """
        Add a Error Metric

        Parameters
        ----------
        name : str
            metric name
        value: str
            value of metric, in this case a str
        dimensions: list
            list of dimensions for the metric
        """
        unit = ''
        # noinspection PyTypeChecker
        self._add_or_update(name, value, None, unit, dimensions=dimensions)

    def add_metric(self, name, value, idx=None, unit=None, dimensions=None):
        """
        Add a metric which is generic with custom metrics

        Parameters
        ----------
        name : str
            metric name
        value: int, float
            value of metric
        idx: int
            request_id index in batch
        unit: str
            unit of metric
        dimensions: list
            list of dimensions for the metric
        """
        req_id = self._get_req(idx)
        self._add_or_update(name, value, req_id, unit, dimensions=dimensions)
""" Collect process memory usage metrics here Pass a json, collection of pids and gpuID """ import logging import psutil def get_cpu_usage(pid): """ use psutil for cpu memory :param pid: str :return: int """ try: process = psutil.Process(int(pid)) except psutil.Error: logging.error("Failed to get process for pid: %s", pid, exc_info=True) return 0 mem_utilization = process.memory_info()[0] if mem_utilization == 0: logging.error("Failed to get memory utilization for pid: %s", pid, exc_info=True) return 0 return mem_utilization def check_process_mem_usage(stdin): """ Return ------ mem_utilization: float """ process_list = stdin.readline().strip().split(",") for process in process_list: if not process: continue mem_utilization = get_cpu_usage(process) if mem_utilization != 0: logging.info("%s:%d", process, mem_utilization) ================================================ FILE: mms/metrics/system_metrics.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Module to collect system metrics for front-end """ import logging import types from builtins import str import psutil from mms.metrics.dimension import Dimension from mms.metrics.metric import Metric system_metrics = [] dimension = [Dimension('Level', 'Host')] def cpu_utilization(): data = psutil.cpu_percent() system_metrics.append(Metric('CPUUtilization', data, 'percent', dimension)) def memory_used(): data = psutil.virtual_memory().used / (1024 * 1024) # in MB system_metrics.append(Metric('MemoryUsed', data, 'MB', dimension)) def memory_available(): data = psutil.virtual_memory().available / (1024 * 1024) # in MB system_metrics.append(Metric('MemoryAvailable', data, 'MB', dimension)) def memory_utilization(): data = psutil.virtual_memory().percent system_metrics.append(Metric('MemoryUtilization', data, 'percent', dimension)) def disk_used(): data = psutil.disk_usage('/').used / (1024 * 1024 * 1024) # in GB system_metrics.append(Metric('DiskUsage', data, 'GB', dimension)) def disk_utilization(): data = psutil.disk_usage('/').percent system_metrics.append(Metric('DiskUtilization', data, 'percent', dimension)) def disk_available(): data = psutil.disk_usage('/').free / (1024 * 1024 * 1024) # in GB system_metrics.append(Metric('DiskAvailable', data, 'GB', dimension)) def collect_all(mod): """ Collect all system metrics. :param mod: :return: """ members = dir(mod) for i in members: value = getattr(mod, i) if isinstance(value, types.FunctionType) and value.__name__ not in ('collect_all', 'log_msg'): value() for met in system_metrics: logging.info(str(met)) logging.info("") ================================================ FILE: mms/metrics/unit.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
class Units(object):
    """
    Define a unit of elements

    Maps the short unit codes used throughout the metrics code to their
    long display names. The ``None`` key covers metrics with no unit.
    """

    def __init__(self):
        # Build the mapping from logically-grouped pieces.
        time_units = {'ms': "Milliseconds", 's': 'Seconds'}
        size_units = {'B': 'Bytes', 'kB': 'Kilobytes', 'MB': 'Megabytes', 'GB': 'Gigabytes'}
        other_units = {'percent': 'Percent', 'count': 'Count', None: 'unit'}
        self.units = {}
        self.units.update(time_units)
        self.units.update(size_units)
        self.units.update(other_units)
""" import importlib import inspect import json import logging import os import sys import uuid from abc import ABCMeta, abstractmethod from builtins import str from mms.metrics.metrics_store import MetricsStore from mms.service import Service class ModelLoaderFactory(object): """ ModelLoaderFactory """ @staticmethod def get_model_loader(model_dir): manifest_file = os.path.join(model_dir, "MAR-INF/MANIFEST.json") if os.path.exists(manifest_file): return MmsModelLoader() elif os.path.exists(os.path.join(model_dir, "MANIFEST.json")): return LegacyModelLoader() else: return MmsModelLoader() class ModelLoader(object): """ Base Model Loader class. """ __metaclass__ = ABCMeta @abstractmethod def load(self, model_name, model_dir, handler, gpu_id, batch_size): """ Load model from file. :param model_name: :param model_dir: :param handler: :param gpu_id: :param batch_size: :return: Model """ pass # pylint: disable=unnecessary-pass @staticmethod def list_model_services(module, parent_class=None): """ Parse user defined module to get all model service classes in it. :param module: :param parent_class: :return: List of model service class definitions """ # Parsing the module to get all defined classes classes = [cls[1] for cls in inspect.getmembers(module, lambda member: inspect.isclass(member) and member.__module__ == module.__name__)] # filter classes that is subclass of parent_class if parent_class is not None: return [c for c in classes if issubclass(c, parent_class)] return classes class MmsModelLoader(ModelLoader): """ MMS 1.0 Model Loader """ def load(self, model_name, model_dir, handler, gpu_id, batch_size): """ Load MMS 1.0 model from file. :param model_name: :param model_dir: :param handler: :param gpu_id: :param batch_size: :return: """ logging.debug("Loading model - working dir: %s", os.getcwd()) # TODO: Request ID is not given. UUID is a temp UUID. 
class MmsModelLoader(ModelLoader):
    """
    MMS 1.0 Model Loader
    """

    def load(self, model_name, model_dir, handler, gpu_id, batch_size):
        """
        Load MMS 1.0 model from file.

        :param model_name:
        :param model_dir:
        :param handler: "module[:function]" entry-point spec
        :param gpu_id:
        :param batch_size:
        :return: Service wrapping the loaded model's entry point
        """
        logging.debug("Loading model - working dir: %s", os.getcwd())
        # TODO: Request ID is not given. UUID is a temp UUID.
        metrics = MetricsStore(uuid.uuid4(), model_name)
        manifest_file = os.path.join(model_dir, "MAR-INF/MANIFEST.json")
        manifest = None
        if os.path.exists(manifest_file):
            with open(manifest_file) as f:
                manifest = json.load(f)
        # Split "module:function"; function part is optional.
        temp = handler.split(":", 1)
        module_name = temp[0]
        function_name = None if len(temp) == 1 else temp[1]
        if module_name.endswith(".py"):
            module_name = module_name[:-3]
        # Strip any path prefix: the module is imported by bare name from sys.path.
        module_name = module_name.split("/")[-1]
        self.module = importlib.import_module(module_name)
        if self.module is None:
            raise ValueError("Unable to load module {}, make sure it is added to python path".format(module_name))
        if function_name is None:
            function_name = "handle"
        if hasattr(self.module, function_name):
            # Function-style handler: the module exposes the entry point directly.
            entry_point = getattr(self.module, function_name)
            service = Service(model_name, model_dir, manifest, entry_point, gpu_id, batch_size)
            service.context.metrics = metrics
            # initialize model at load time
            entry_point(None, service.context)
        else:
            # Class-style handler: expect exactly one service class in the module.
            model_class_definitions = ModelLoader.list_model_services(self.module)
            if len(model_class_definitions) != 1:
                raise ValueError("Expected only one class in custom service code or a function entry point {}".format(
                    model_class_definitions))
            model_class = model_class_definitions[0]
            model_service = model_class()
            handle = getattr(model_service, "handle")
            if handle is None:
                raise ValueError("Expect handle method in class {}".format(str(model_class)))
            service = Service(model_name, model_dir, manifest, model_service.handle, gpu_id, batch_size)
            initialize = getattr(model_service, "initialize")
            if initialize is not None:
                # Best-effort initialization: a failing initialize() is logged
                # but does not prevent the service object from being returned.
                # noinspection PyBroadException
                try:
                    model_service.initialize(service.context)
                # pylint: disable=broad-except
                except Exception as e:
                    # noinspection PyBroadException
                    logging.exception(e)
                    try:
                        # sys.exc_clear exists only on Python 2; on Python 3
                        # the AttributeError is swallowed below.
                        sys.exc_clear()
                    # pylint: disable=broad-except
                    except Exception:
                        pass
        return service

    def unload(self):
        # to make sure logs emitted from model on exit get into mms logs,
        # do not delete logging python module
        module_vars = [var for var in vars(self.module)
                       if not var.startswith('__') and not var == "logging"]
        for var in module_vars:
            delattr(self.module, var)
        del self.module


class LegacyModelLoader(ModelLoader):
    """
    MMS 0.4 Model Loader
    """

    def load(self, model_name, model_dir, handler, gpu_id, batch_size):
        """
        Load MMS 0.3 model from file.

        :param model_name:
        :param model_dir:
        :param handler: service file name (".py" suffix optional)
        :param gpu_id:
        :param batch_size:
        :return: Service wrapping the legacy service class's handle method
        """
        manifest_file = os.path.join(model_dir, "MANIFEST.json")
        manifest = None
        if os.path.isfile(manifest_file):
            with open(manifest_file) as f:
                manifest = json.load(f)
        if not handler.endswith(".py"):
            handler = handler + ".py"
        service_file = os.path.join(model_dir, handler)
        name = os.path.splitext(os.path.basename(service_file))[0]
        # Load the service source file directly from disk; the mechanism
        # differs between Python 3 (importlib.util) and Python 2 (imp).
        if sys.version_info[0] > 2:
            from importlib import util
            spec = util.spec_from_file_location(name, service_file)
            module = util.module_from_spec(spec)
            spec.loader.exec_module(module)
        else:
            import imp
            module = imp.load_source(name, service_file)
        if module is None:
            raise ValueError("Unable to load module {}".format(service_file))
        from mms.model_service.mxnet_model_service import SingleNodeService
        model_class_definitions = ModelLoader.list_model_services(module, SingleNodeService)
        # NOTE(review): assumes at least one SingleNodeService subclass exists;
        # an empty list would raise IndexError here.
        module_class = model_class_definitions[0]
        module = module_class(model_name, model_dir, manifest, gpu_id)
        service = Service(model_name, model_dir, manifest, module.handle, gpu_id, batch_size)
        module.initialize(service.context)
        return service
def old_start():
    """
    This is the entry point for model server when using the old CLI name (mxnet-model-server).
    Please migrate to multi-model-server in the future
    :return:
    """
    print("Warning: Calling MMS with mxnet-model-server. Please move to multi-model-server.")
    start()


def start():
    """
    This is the entry point for model server

    Parses CLI arguments, stops a running server when --stop is given,
    otherwise assembles the java command line for the frontend and launches
    it as a subprocess, recording its pid in a temp-dir pid file.
    :return:
    """
    args = ArgParser.mms_parser().parse_args()
    # The pid file lets later invocations find the frontend JVM process.
    pid_file = os.path.join(tempfile.gettempdir(), ".model_server.pid")
    pid = None
    if os.path.isfile(pid_file):
        with open(pid_file, "r") as f:
            pid = int(f.readline())

    # pylint: disable=too-many-nested-blocks
    if args.stop:
        if pid is None:
            print("Model server is not currently running.")
        else:
            try:
                parent = psutil.Process(pid)
                # Terminate the whole tree: children first (terminate, then
                # kill survivors), then the parent itself.
                for child in parent.children(recursive=True):
                    child.terminate()
                for child in parent.children(recursive=True):
                    if psutil.pid_exists(child.pid):
                        child.kill()
                parent.terminate()
                if psutil.pid_exists(parent.pid):
                    parent.kill()
                print("Model server stopped.")
            except (OSError, psutil.Error):
                print("Model server already stopped.")
            os.remove(pid_file)
    else:
        if pid is not None:
            try:
                # Process lookup succeeding means a server is still alive.
                psutil.Process(pid)
                print("Model server is already running, please use multi-model-server --stop to stop MMS.")
                exit(1)
            except psutil.Error:
                print("Removing orphan pid file.")
                os.remove(pid_file)

        java_home = os.environ.get("JAVA_HOME")
        java = "java" if not java_home else "{}/bin/java".format(java_home)

        mms_home = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        cmd = [java, "-Dmodel_server_home={}".format(mms_home)]
        if args.log_config:
            log_config = os.path.realpath(args.log_config)
            if not os.path.isfile(log_config):
                print("--log-config file not found: {}".format(log_config))
                exit(1)
            cmd.append("-Dlog4j.configurationFile=file://{}".format(log_config))

        tmp_dir = os.environ.get("TEMP")
        if tmp_dir:
            if not os.path.isdir(tmp_dir):
                print("Invalid temp directory: {}, please check TEMP environment variable.".format(tmp_dir))
                exit(1)
            cmd.append("-Djava.io.tmpdir={}".format(tmp_dir))

        mms_config = args.mms_config
        mms_conf_file = None
        if mms_config:
            # "sagemaker" is a shortcut for the bundled SageMaker properties file.
            if mms_config == "sagemaker":
                mms_config = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                          "configs", "sagemaker_config.properties")
            if not os.path.isfile(mms_config):
                print("--mms-config file not found: {}".format(mms_config))
                exit(1)
            mms_conf_file = mms_config
        else:
            mms_conf_file = "config.properties"

        class_path = \
            ".:{}".format(os.path.join(mms_home, "mms/frontend/*"))

        if os.path.isfile(mms_conf_file):
            props = load_properties(mms_conf_file)
            vm_args = props.get("vmargs")
            if vm_args:
                print("Warning: MMS is using non-default JVM parameters: {}".format(vm_args))
                arg_list = vm_args.split()
                # A --log-config given on the CLI overrides any log4j setting
                # embedded in the config file's vmargs.
                if args.log_config:
                    for word in arg_list[:]:
                        if word.startswith("-Dlog4j.configurationFile="):
                            arg_list.remove(word)
                cmd.extend(arg_list)

            plugins = props.get("plugins_path", None)
            if plugins:
                # Append "/*" unless the configured path already has a wildcard.
                class_path += ":" + plugins + "/*" if "*" not in plugins else ":" + plugins

        cmd.append("-cp")
        cmd.append(class_path)

        cmd.append("com.amazonaws.ml.mms.ModelServer")

        # model-server.jar command line parameters
        cmd.append("--python")
        cmd.append(sys.executable)

        if mms_conf_file is not None:
            cmd.append("-f")
            cmd.append(mms_conf_file)

        if args.model_store:
            if not os.path.isdir(args.model_store):
                print("--model-store directory not found: {}".format(args.model_store))
                exit(1)
            cmd.append("-s")
            cmd.append(args.model_store)

        if args.models:
            cmd.append("-m")
            cmd.extend(args.models)
            if not args.model_store:
                # Without a model store, every model must be a URL (or "ALL").
                pattern = re.compile(r"(.+=)?http(s)?://.+", re.IGNORECASE)
                for model_url in args.models:
                    if not pattern.match(model_url) and model_url != "ALL":
                        print("--model-store is required to load model locally.")
                        exit(1)

        try:
            process = subprocess.Popen(cmd)
            pid = process.pid
            with open(pid_file, "w") as pf:
                pf.write(str(pid))
            if args.foreground:
                process.wait()
        except OSError as e:
            # errno 2 (ENOENT) means the java binary itself was not found.
            if e.errno == 2:
                print("java not found, please make sure JAVA_HOME is set properly.")
            else:
                print("start java frontend failed:", sys.exc_info())
""" props = {} with open(file_path, "rt") as f: for line in f: line = line.strip() if not line.startswith("#"): pair = line.split("=", 1) if len(pair) > 1: key = pair[0].strip() props[key] = pair[1].strip() return props if __name__ == "__main__": start() ================================================ FILE: mms/model_service/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Model services code """ import warnings from . import model_service from . import mxnet_model_service from . import mxnet_vision_service warnings.warn("Module mms.model_service is deprecated, please migrate to model archive 1.0 format.", DeprecationWarning, stacklevel=2) ================================================ FILE: mms/model_service/gluon_vision_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""`Gluon vision service` defines a Gluon base vision service """ import numpy as np import mxnet from mms.model_service.mxnet_model_service import GluonImperativeBaseService from mms.utils.mxnet import ndarray class GluonVisionService(GluonImperativeBaseService): """MXNetVisionService defines a fundamental service for image classification task. In preprocess, input image buffer is read to NDArray and resized respect to input shape in signature. In post process, top-5 labels are returned. """ def _preprocess(self, data): img_list = [] for idx, img in enumerate(data): input_shape = self.signature['inputs'][idx]['data_shape'] # We are assuming input shape is NCHW [h, w] = input_shape[2:] img_arr = mxnet.img.imdecode(img) img_arr = mxnet.image.imresize(img_arr, w, h) img_arr = img_arr.astype(np.float32) img_arr /= 255 img_arr = mxnet.image.color_normalize(img_arr, mean=mxnet.nd.array([0.485, 0.456, 0.406]), std=mxnet.nd.array([0.229, 0.224, 0.225])) img_arr = mxnet.nd.transpose(img_arr, (2, 0, 1)) img_arr = img_arr.expand_dims(axis=0) img_list.append(img_arr) return img_list def _inference(self, data): """ Internal inference methods for MMS service. Run forward computation and return output. Parameters ---------- data : list of NDArray Preprocessed inputs in NDArray format. Returns ------- list of NDArray Inference output. """ # Check input shape super(GluonVisionService, self)._inference(data) output = self.net(data[0]) return output.softmax() def _postprocess(self, data): assert hasattr(self, 'labels'), \ "Can't find labels attribute. Did you put synset.txt file into " \ "model archive or manually load class label file in __init__?" return [ndarray.top_probability(d, self.labels, top=5) for d in data] ================================================ FILE: mms/model_service/model_service.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
class ModelService(object):
    """
    ModelService wraps up all preprocessing, inference and postprocessing
    functions used by model service. It is defined in a flexible manner to
    be easily extended to support different frameworks.
    """
    __metaclass__ = ABCMeta  # NOTE(review): py2-style; has no effect under Python 3

    # noinspection PyUnusedLocal
    def __init__(self, model_name, model_dir, manifest, gpu=None):  # pylint: disable=unused-argument
        self.ctx = None          # compute context, set by subclasses
        self._context = None     # MMS context, set in initialize()
        self._signature = None   # parsed signature.json, set in initialize()

    def initialize(self, context):
        """
        Internal initialize ModelService.

        Loads and parses the signature file named in the manifest.

        :param context: MMS context object
        :return:
        :raises ValueError: if the signature file does not exist
        """
        self._context = context
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        signature_file_path = os.path.join(model_dir, context.manifest['Model']['Signature'])
        if not os.path.isfile(signature_file_path):
            raise ValueError("Signature file is not found.")

        with open(signature_file_path) as f:
            self._signature = json.load(f)

    @abstractmethod
    def inference(self, data):
        """
        Wrapper function to run pre-process, inference and post-process functions.

        Parameters
        ----------
        data : list of object
            Raw input from request.

        Returns
        -------
        list of outputs to be sent back to client.
            data to be sent back
        """
        # pylint: disable=unnecessary-pass
        pass

    @abstractmethod
    def ping(self):
        """
        Ping to get system's health.

        Returns
        -------
        String
            A message, "health": "healthy!", to show system is healthy.
        """
        pass  # pylint: disable=unnecessary-pass

    def signature(self):
        """
        Signature for model service.

        Returns
        -------
        Dict
            Model service signature.
        """
        return self._signature

    # noinspection PyUnusedLocal
    def handle(self, data, context):  # pylint: disable=unused-argument
        """
        Backward compatible handle function.

        :param data: list of request dicts from the frontend
        :param context: MMS context object (unused)
        :return: list of inference outputs
        """
        input_type = self._signature['input_type']
        input_data = []
        data_name = self._signature["inputs"][0]["data_name"]
        # The payload may arrive keyed by the signature's data_name, or under
        # the generic "body"/"data" keys depending on how the client posted it.
        form_data = data[0].get(data_name)
        if form_data is None:
            form_data = data[0].get("body")

        if form_data is None:
            form_data = data[0].get("data")

        if input_type == "application/json":
            # user might not send content in HTTP request
            if isinstance(form_data, (bytes, bytearray)):
                # NOTE(review): literal_eval parses Python literals, not strict
                # JSON — confirm clients never send true/false/null tokens.
                form_data = ast.literal_eval(form_data.decode("utf-8"))

        input_data.append(form_data)

        ret = self.inference(input_data)
        # Always hand a list back to the frontend.
        if isinstance(ret, list):
            return ret
        return [ret]


class SingleNodeService(ModelService):
    """
    SingleNodeModel defines abstraction for model service which loads a
    single model.
    """

    def inference(self, data):
        """
        Wrapper function to run preprocess, inference and postprocess functions.

        Also logs the wall-clock duration (ms) of each stage.

        Parameters
        ----------
        data : list of object
            Raw input from request.

        Returns
        -------
        list of outputs to be sent back to client.
            data to be sent back
        """
        preprocess_start = time.time()
        data = self._preprocess(data)
        inference_start = time.time()
        data = self._inference(data)
        postprocess_start = time.time()
        data = self._postprocess(data)
        end_time = time.time()

        logging.info("preprocess time: %.2f", (inference_start - preprocess_start) * 1000)
        logging.info("inference time: %.2f", (postprocess_start - inference_start) * 1000)
        logging.info("postprocess time: %.2f", (end_time - postprocess_start) * 1000)

        return data

    @abstractmethod
    def _inference(self, data):
        """
        Internal inference methods. Run forward computation and
        return output.

        Parameters
        ----------
        data : list of NDArray
            Preprocessed inputs in NDArray format.

        Returns
        -------
        list of NDArray
            Inference output.
        """
        return data

    def _preprocess(self, data):
        """
        Internal preprocess methods. Do transformation on raw
        inputs and convert them to NDArray.

        Parameters
        ----------
        data : list of object
            Raw inputs from request.

        Returns
        -------
        list of NDArray
            Processed inputs in NDArray format.
        """
        return data

    def _postprocess(self, data):
        """
        Internal postprocess methods. Do transformation on inference output
        and convert them to MIME type objects.

        Parameters
        ----------
        data : list of NDArray
            Inference output.

        Returns
        -------
        list of object
            list of outputs to be sent back.
        """
        return data
assert len(inputs) == len(signature['inputs']), \ "Input number mismatches with " \ "signature. %d expected but got %d." \ % (len(signature['inputs']), len(inputs)) for input_data, sig_input in zip(inputs, signature["inputs"]): assert isinstance(input_data, mx.nd.NDArray), 'Each input must be NDArray.' assert len(input_data.shape) == len(sig_input["data_shape"]), \ 'Shape dimension of input %s mismatches with ' \ 'signature. %d expected but got %d.' \ % (sig_input['data_name'], len(sig_input['data_shape']), len(input_data.shape)) for idx in range(len(input_data.shape)): if idx != 0 and sig_input['data_shape'][idx] != 0: assert sig_input['data_shape'][idx] == input_data.shape[idx], \ 'Input %s has different shape with ' \ 'signature. %s expected but got %s.' \ % (sig_input['data_name'], sig_input['data_shape'], input_data.shape) class MXNetBaseService(SingleNodeService): """ MXNetBaseService defines the fundamental loading model and inference operations when serving MXNet model. This is a base class and needs to be inherited. """ def __init__(self, model_name, model_dir, manifest, gpu=None): super(MXNetBaseService, self).__init__(model_name, model_dir, manifest, gpu) self.param_filename = None self.model_name = model_name self.ctx = mx.gpu(int(gpu)) if gpu is not None else mx.cpu() signature_file_path = os.path.join(model_dir, manifest['Model']['Signature']) if not os.path.isfile(signature_file_path): raise RuntimeError('Signature file is not found. Please put signature.json ' 'into the model file directory...' + signature_file_path) try: signature_file = open(signature_file_path) self._signature = json.load(signature_file) except Exception: raise Exception('Failed to open model signature file: %s' % signature_file_path) data_names = [] data_shapes = [] epoch = 0 for input_data in self._signature['inputs']: data_names.append(input_data['data_name']) # Replace 0 entry in data shape with 1 for binding executor. 
# Set batch size as 1 data_shape = input_data['data_shape'] data_shape[0] = 1 # pylint: disable=consider-using-enumerate for idx in range(len(data_shape)): if data_shape[idx] == 0: data_shape[idx] = 1 data_shapes.append((input_data['data_name'], tuple(data_shape))) # Load MXNet module # noinspection PyBroadException try: self.param_filename = manifest['Model']['Parameters'] epoch = int(self.param_filename[len(model_name) + 1: -len('.params')]) except Exception: # pylint: disable=broad-except logging.info("Failed to parse epoch from param file, setting epoch to 0") sym, arg_params, aux_params = mx.model.load_checkpoint('%s/%s' % (model_dir, manifest['Model']['Symbol'][:-12]), epoch) self.mx_model = mx.mod.Module(symbol=sym, context=self.ctx, data_names=data_names, label_names=None) self.mx_model.bind(for_training=False, data_shapes=data_shapes) self.mx_model.set_params(arg_params, aux_params, allow_missing=True, allow_extra=True) # Read synset file # If synset is not specified, check whether model archive contains synset file. archive_synset = os.path.join(model_dir, 'synset.txt') if os.path.isfile(archive_synset): synset = archive_synset self.labels = [line.strip() for line in open(synset).readlines()] def _preprocess(self, data): return map(mx.nd.array, data) def _postprocess(self, data): return [str(d.asnumpy().tolist()) for d in data] def _inference(self, data): """Internal inference methods for MXNet. Run forward computation and return output. Parameters ---------- data : list of NDArray Preprocessed inputs in NDArray format. Returns ------- list of NDArray Inference output. 
""" # Check input shape check_input_shape(data, self.signature) data = [item.as_in_context(self.ctx) for item in data] self.mx_model.forward(DataBatch(data)) data = self.mx_model.get_outputs() # by pass lazy evaluation get_outputs either returns a list of nd arrays # a list of list of NDArray for d in data: if isinstance(d, list): for n in data: if isinstance(n, mx.ndarray.ndarray.NDArray): n.wait_to_read() elif isinstance(d, mx.ndarray.ndarray.NDArray): d.wait_to_read() return data def ping(self): """ Ping to get system's health. Returns ------- String MXNet version to show system is healthy. """ return mx.__version__ @property def signature(self): """ Signature for model service. Returns ------- Dict Model service signiture. """ return self._signature class GluonImperativeBaseService(SingleNodeService): """GluonImperativeBaseService defines the fundamental loading model and inference operations when serving Gluon model. This is a base class and needs to be inherited. """ def __init__(self, model_name, model_dir, manifest, net=None, gpu=None): super(GluonImperativeBaseService, self).__init__(model_name, model_dir, manifest, gpu) self.param_filename = None self.model_name = model_name self.ctx = mx.gpu(int(gpu)) if gpu is not None else mx.cpu() self.net = net signature_file_path = os.path.join(model_dir, manifest['Model']['Signature']) if not os.path.isfile(signature_file_path): raise RuntimeError('Signature file is not found. Please put signature.json ' 'into the model file directory...' 
+ signature_file_path) try: signature_file = open(signature_file_path) self._signature = json.load(signature_file) except Exception: raise Exception('Failed to open model signature file: %s' % signature_file_path) # Load MXNet module # noinspection PyBroadException try: self.param_filename = manifest['Model']['Parameters'] if self.param_filename or self.net is not None: self.net.load_params(os.path.join(model_dir, self.param_filename), ctx=self.ctx) else: logging.info("No parameters file given for this imperative service") except Exception: # pylint: disable=broad-except logging.info("Failed to parse epoch from param file, setting epoch to 0") # Read synset file # If synset is not specified, check whether model archive contains synset file. archive_synset = os.path.join(model_dir, 'synset.txt') if os.path.isfile(archive_synset): synset = archive_synset self.labels = [line.strip() for line in open(synset).readlines()] def _preprocess(self, data): pass def _postprocess(self, data): pass def _inference(self, data): check_input_shape(data, self.signature) def ping(self): """ Ping to get system's health. Returns ------- String MXNet version to show system is healthy. """ return mx.__version__ @property def signature(self): """ Signature for model service. Returns ------- Dict Model service signature. """ return self._signature ================================================ FILE: mms/model_service/mxnet_vision_service.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. 
class MXNetVisionService(MXNetBaseService):
    """Vision-classification service built on MXNetBaseService.

    ``_preprocess`` reads each image buffer into an NDArray and resizes it to
    the input shape declared in the signature; ``_postprocess`` maps model
    outputs to the top-5 class labels.
    """

    def _preprocess(self, data):
        # NOTE(review): the signature entry is indexed by the position of the
        # request item; this assumes one signature input per item -- confirm
        # behaviour for batched requests.
        processed = []
        for idx, img_buf in enumerate(data):
            input_shape = self.signature['inputs'][idx]['data_shape']
            # We are assuming input shape is NCHW
            height, width = input_shape[2:]
            nd_img = image.read(img_buf)
            nd_img = image.resize(nd_img, width, height)
            processed.append(image.transform_shape(nd_img))
        return processed

    def _postprocess(self, data):
        assert hasattr(self, 'labels'), \
            "Can't find labels attribute. Did you put synset.txt file into " \
            "model archive or manually load class label file in __init__?"
        return [ndarray.top_probability(out, self.labels, top=5) for out in data]
Communication message format: binary encoding """ # pylint: disable=redefined-builtin import logging import os import multiprocessing import platform import socket import sys import signal from mms.arg_parser import ArgParser from mms.model_loader import ModelLoaderFactory from mms.protocol.otf_message_handler import retrieve_msg, create_load_model_response from mms.service import emit_metrics MAX_FAILURE_THRESHOLD = 5 SOCKET_ACCEPT_TIMEOUT = 30.0 DEBUG = False class MXNetModelServiceWorker(object): """ Backend worker to handle Model Server's python service code """ def __init__(self, s_type=None, s_name=None, host_addr=None, port_num=None, model_request=None, preload_model=False, tmp_dir="/tmp"): if os.environ.get("OMP_NUM_THREADS") is None: os.environ["OMP_NUM_THREADS"] = "1" if os.environ.get("MXNET_USE_OPERATOR_TUNING") is None: # work around issue: https://github.com/apache/incubator-mxnet/issues/12255 os.environ["MXNET_USE_OPERATOR_TUNING"] = "0" self.sock_type = s_type if s_type == "unix": if s_name is None: raise ValueError("Wrong arguments passed. No socket name given.") self.sock_name, self.port = s_name, -1 try: os.remove(s_name) except OSError: if os.path.exists(s_name): raise RuntimeError("socket already in use: {}.".format(s_name)) elif s_type == "tcp": self.sock_name = host_addr if host_addr is not None else "127.0.0.1" if port_num is None: raise ValueError("Wrong arguments passed. 
No socket port given.") self.port = port_num else: raise ValueError("Invalid socket type provided") logging.info("Listening on port: %s", s_name) socket_family = socket.AF_INET if s_type == "tcp" else socket.AF_UNIX self.sock = socket.socket(socket_family, socket.SOCK_STREAM) self.preload = preload_model self.service = None self.model_meta_data = model_request self.out = self.err = None self.tmp_dir = tmp_dir self.socket_name = s_name def load_model(self, load_model_request=None): """ Expected command { "command" : "load", string "modelPath" : "/path/to/model/file", string "modelName" : "name", string "gpu" : None if CPU else gpu_id, int "handler" : service handler entry point if provided, string "batchSize" : batch size, int } :param load_model_request: :return: """ try: model_dir = load_model_request["modelPath"].decode("utf-8") model_name = load_model_request["modelName"].decode("utf-8") handler = load_model_request["handler"].decode("utf-8") batch_size = 1 if "batchSize" in load_model_request: batch_size = int(load_model_request["batchSize"]) gpu = None if "gpu" in load_model_request: gpu = int(load_model_request["gpu"]) io_fd = None if "ioFileDescriptor" in load_model_request: io_fd = load_model_request.get("ioFileDescriptor").decode("utf-8") self._create_io_files(self.tmp_dir, io_fd) if self.service is None or self.preload is False: self.model_loader = ModelLoaderFactory.get_model_loader(model_dir) self.service = self.model_loader.load(model_name, model_dir, handler, gpu, batch_size) logging.info("Model %s loaded io_fd=%s", model_name, str(io_fd)) return "loaded model {}. 
[PID]:{}".format(model_name, os.getpid()), 200 except MemoryError: return "System out of memory", 507 def _create_io_files(self, tmp_dir, io_fd): self.out = tmp_dir + '/' + io_fd + "-stdout" self.err = tmp_dir + '/' + io_fd + "-stderr" # TODO: Windows support os.mkfifo(self.out) os.mkfifo(self.err) def _remap_io(self): out_fd = open(self.out, "w") err_fd = open(self.err, "w") os.dup2(out_fd.fileno(), sys.stdout.fileno()) os.dup2(err_fd.fileno(), sys.stderr.fileno()) def handle_connection(self, cl_socket): """ Handle socket connection. :param cl_socket: :return: """ logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) cl_socket.setblocking(True) while True: cmd, msg = retrieve_msg(cl_socket) if cmd == b'I': resp = self.service.predict(msg) cl_socket.send(resp) elif cmd == b'L': result, code = self.load_model(msg) resp = bytearray() resp += create_load_model_response(code, result) cl_socket.send(resp) self._remap_io() if code != 200: raise RuntimeError("{} - {}".format(code, result)) else: raise ValueError("Received unknown command: {}".format(cmd)) if self.service is not None and self.service.context is not None \ and self.service.context.metrics is not None: emit_metrics(self.service.context.metrics.store) def sigterm_handler(self): for node in [self.socket_name, self.out, self.err]: try: os.remove(node) except OSError: pass def start_worker(self, cl_socket): """ Method to start the worker threads. These worker threads use multiprocessing to spawn a new worker. 
:param cl_socket: :return: """ self.sock.close() # close listening socket in the fork try: signal.signal(signal.SIGTERM, lambda signum, frame: self.sigterm_handler()) self.handle_connection(cl_socket) except Exception: # pylint: disable=broad-except logging.error("Backend worker process died.", exc_info=True) finally: try: self.model_loader.unload() sys.stdout.flush() os.remove(self.out) os.remove(self.err) finally: cl_socket.shutdown(socket.SHUT_RDWR) cl_socket.close() sys.exit(0) def run_server(self): """ Run the backend worker process and listen on a socket :return: """ if self.sock_type == "unix": self.sock.bind(self.sock_name) else: self.sock.bind((self.sock_name, int(self.port))) self.sock.listen(128) logging.info("[PID] %d", os.getpid()) logging.info("MMS worker started.") logging.info("Python runtime: %s", platform.python_version()) while True: if self.service is None and self.preload is True: # Lazy loading the models self.load_model(self.model_meta_data) (cl_socket, _) = self.sock.accept() # workaround error(35, 'Resource temporarily unavailable') on OSX cl_socket.setblocking(True) logging.info("Connection accepted: %s.", cl_socket.getsockname()) p = multiprocessing.Process(target=self.start_worker, args=(cl_socket,)) p.start() cl_socket.close() # close accepted socket in the parent if __name__ == "__main__": # Remove mms dir from python path to avoid module name conflict. 
mms_path = os.path.dirname(os.path.realpath(__file__)) while mms_path in sys.path: sys.path.remove(mms_path) sock_type = None socket_name = None # noinspection PyBroadException try: logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) logging.info("model_service_worker started with args: %s", " ".join(sys.argv[1:])) model_req = dict() args = ArgParser.model_service_worker_args().parse_args() socket_name = args.sock_name sock_type = args.sock_type host = args.host port = args.port model_req["handler"] = args.handler.encode('utf-8') model_req["modelPath"] = args.model_path.encode('utf-8') model_req["modelName"] = args.model_name.encode('utf-8') worker = MXNetModelServiceWorker(sock_type, socket_name, host, port, model_req, args.preload_model, args.tmp_dir) worker.run_server() except socket.timeout: logging.error("Backend worker did not receive connection in: %d", SOCKET_ACCEPT_TIMEOUT) except Exception: # pylint: disable=broad-except logging.error("Backend worker process died", exc_info=True) finally: if sock_type == 'unix' and os.path.exists(socket_name): os.remove(socket_name) exit(1) ================================================ FILE: mms/protocol/__init__.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: mms/protocol/otf_message_handler.py ================================================ # Copyright 2018 Amazon.com, Inc. 
def encode_response_headers(resp_hdr_map):
    """
    Serialize a header map as a header count followed by length-prefixed
    UTF-8 key/value pairs (all lengths are big-endian int32).

    :param resp_hdr_map: dict of header name -> header value (str)
    :return: bytearray with the encoded headers
    """
    buf = bytearray(struct.pack('!i', len(resp_hdr_map)))
    for key, value in resp_hdr_map.items():
        for part in (key.encode('utf-8'), value.encode('utf-8')):
            buf += struct.pack('!i', len(part))
            buf += part
    return buf
:param context: :param ret: :param req_id_map: :param message: :param code: :return: """ msg = bytearray() msg += struct.pack('!i', code) buf = message.encode("utf-8") msg += struct.pack('!i', len(buf)) msg += buf for idx in req_id_map: req_id = req_id_map.get(idx).encode('utf-8') msg += struct.pack("!i", len(req_id)) msg += req_id # Encoding Content-Type if context is None: msg += struct.pack('!i', 0) # content_type else: content_type = context.get_response_content_type(idx) if content_type is None or len(content_type) == 0: msg += struct.pack('!i', 0) # content_type else: msg += struct.pack('!i', len(content_type)) msg += content_type.encode('utf-8') # Encoding the per prediction HTTP response code if context is None: # status code and reason phrase set to none msg += struct.pack('!i', code) msg += struct.pack('!i', 0) # No code phrase is returned # Response headers none msg += struct.pack('!i', 0) else: sc, phrase = context.get_response_status(idx) http_code = sc if sc is not None else 200 http_phrase = phrase if phrase is not None else "" msg += struct.pack('!i', http_code) msg += struct.pack("!i", len(http_phrase)) msg += http_phrase.encode("utf-8") # Response headers msg += encode_response_headers(context.get_response_headers(idx)) if ret is None: buf = b"error" msg += struct.pack('!i', len(buf)) msg += buf else: val = ret[idx] # NOTE: Process bytes/bytearray case before processing the string case. 
if isinstance(val, (bytes, bytearray)): msg += struct.pack('!i', len(val)) msg += val elif isinstance(val, str): buf = val.encode("utf-8") msg += struct.pack('!i', len(buf)) msg += buf else: try: json_value = json.dumps(val, indent=2).encode("utf-8") msg += struct.pack('!i', len(json_value)) msg += json_value except TypeError: logging.warning("Unable to serialize model output.", exc_info=True) return create_predict_response(None, req_id_map, "Unsupported model output data type.", 503) msg += struct.pack('!i', -1) # End of list return msg def create_load_model_response(code, message): """ Create load model response. :param code: :param message: :return: """ msg = bytearray() msg += struct.pack('!i', code) buf = message.encode("utf-8") msg += struct.pack('!i', len(buf)) msg += buf msg += struct.pack('!i', -1) # no predictions return msg def _retrieve_buffer(conn, length): data = bytearray() while length > 0: pkt = conn.recv(length) if len(pkt) == 0: logging.info("Frontend disconnected.") raise ValueError("Frontend disconnected") data += pkt length -= len(pkt) return data def _retrieve_int(conn): data = _retrieve_buffer(conn, int_size) return struct.unpack("!i", data)[0] def _retrieve_load_msg(conn): """ MSG Frame Format: | cmd value | | int model-name length | model-name value | | int model-path length | model-path value | | int batch-size length | | int handler length | handler value | | int gpu id | :param conn: :return: """ msg = dict() length = _retrieve_int(conn) msg["modelName"] = _retrieve_buffer(conn, length) length = _retrieve_int(conn) msg["modelPath"] = _retrieve_buffer(conn, length) msg["batchSize"] = _retrieve_int(conn) length = _retrieve_int(conn) msg["handler"] = _retrieve_buffer(conn, length) gpu_id = _retrieve_int(conn) if gpu_id >= 0: msg["gpu"] = gpu_id length = _retrieve_int(conn) msg["ioFileDescriptor"] = _retrieve_buffer(conn, length) return msg def _retrieve_inference_msg(conn): """ MSG Frame Format: | cmd value | | batch: list of requests | 
""" msg = [] while True: request = _retrieve_request(conn) if request is None: break msg.append(request) return msg def _retrieve_request(conn): """ MSG Frame Format: | request_id | | request_headers: list of request headers| | parameters: list of request parameters | """ length = _retrieve_int(conn) if length == -1: return None request = dict() request["requestId"] = _retrieve_buffer(conn, length) headers = [] while True: header = _retrieve_reqest_header(conn) if header is None: break headers.append(header) request["headers"] = headers model_inputs = [] while True: input_data = _retrieve_input_data(conn) if input_data is None: break model_inputs.append(input_data) request["parameters"] = model_inputs return request def _retrieve_reqest_header(conn): """ MSG Frame Format: | parameter_name | | content_type | | input data in bytes | """ length = _retrieve_int(conn) if length == -1: return None header = dict() header["name"] = _retrieve_buffer(conn, length) length = _retrieve_int(conn) header["value"] = _retrieve_buffer(conn, length) return header def _retrieve_input_data(conn): """ MSG Frame Format: | parameter_name | | content_type | | input data in bytes | """ decode_req = os.environ.get("MMS_DECODE_INPUT_REQUEST") length = _retrieve_int(conn) if length == -1: return None model_input = dict() model_input["name"] = _retrieve_buffer(conn, length).decode("utf-8") length = _retrieve_int(conn) content_type = _retrieve_buffer(conn, length).decode("utf-8") model_input["contentType"] = content_type length = _retrieve_int(conn) value = _retrieve_buffer(conn, length) if content_type == "application/json" and (decode_req is None or decode_req == "true"): model_input["value"] = json.loads(value.decode("utf-8")) elif content_type.startswith("text") and (decode_req is None or decode_req == "true"): model_input["value"] = value.decode("utf-8") else: model_input["value"] = value return model_input ================================================ FILE: mms/service.py 
================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ CustomService class definitions """ import logging import time from builtins import str import mms from mms.context import Context, RequestProcessor from mms.metrics.metrics_store import MetricsStore from mms.protocol.otf_message_handler import create_predict_response PREDICTION_METRIC = 'PredictionTime' logger = logging.getLogger(__name__) class Service(object): """ Wrapper for custom entry_point """ def __init__(self, model_name, model_dir, manifest, entry_point, gpu, batch_size): self._context = Context(model_name, model_dir, manifest, batch_size, gpu, mms.__version__) self._entry_point = entry_point @property def context(self): return self._context @staticmethod def retrieve_data_for_inference(batch): """ REQUEST_INPUT = { "requestId" : "111-222-3333", "parameters" : [ PARAMETER ] } PARAMETER = { "name" : parameter name "contentType": "http-content-types", "value": "val1" } :param batch: :return: """ if batch is None: raise ValueError("Received invalid inputs") req_to_id_map = {} headers = [] input_batch = [] for batch_idx, request_batch in enumerate(batch): req_id = request_batch.get('requestId').decode("utf-8") parameters = request_batch['parameters'] model_in_headers = dict() model_in = dict() # Parameter level headers are updated here. multipart/form-data can have multiple headers. 
for parameter in parameters: model_in.update({parameter["name"]: parameter["value"]}) model_in_headers.update({parameter["name"]: {"content-type": parameter["contentType"]}}) # Request level headers are populated here if request_batch.get("headers") is not None: for h in request_batch.get("headers"): model_in_headers.update({h['name'].decode('utf-8'): h['value'].decode('utf-8')}) headers.append(RequestProcessor(model_in_headers)) input_batch.append(model_in) req_to_id_map[batch_idx] = req_id return headers, input_batch, req_to_id_map def predict(self, batch): """ PREDICT COMMAND = { "command": "predict", "batch": [ REQUEST_INPUT ] } :param batch: list of request :return: """ headers, input_batch, req_id_map = Service.retrieve_data_for_inference(batch) self.context.request_ids = req_id_map self.context.request_processor = headers metrics = MetricsStore(req_id_map, self.context.model_name) self.context.metrics = metrics start_time = time.time() # noinspection PyBroadException try: ret = self._entry_point(input_batch, self.context) except PredictionException as e: logger.error("Prediction error", exc_info=True) return create_predict_response(None, req_id_map, e.message, e.error_code) except MemoryError: logger.error("System out of memory", exc_info=True) return create_predict_response(None, req_id_map, "Out of resources", 507) except Exception: # pylint: disable=broad-except logger.warning("Invoking custom service failed.", exc_info=True) return create_predict_response(None, req_id_map, "Prediction failed", 503) if not isinstance(ret, list): logger.warning("model: %s, Invalid return type: %s.", self.context.model_name, type(ret)) return create_predict_response(None, req_id_map, "Invalid model predict output", 503) if len(ret) != len(input_batch): logger.warning("model: %s, number of batch response mismatched, expect: %d, got: %d.", self.context.model_name, len(input_batch), len(ret)) return create_predict_response(None, req_id_map, "number of batch response 
mismatched", 503) duration = round((time.time() - start_time) * 1000, 2) metrics.add_time(PREDICTION_METRIC, duration) return create_predict_response(ret, req_id_map, "Prediction success", 200, context=self.context) class PredictionException(Exception): def __init__(self, message, error_code=500): self.message = message self.error_code = error_code super(PredictionException, self).__init__(message) def __str__(self): return "{message} : {error_code}".format(message=self.message, error_code=self.error_code) def emit_metrics(metrics): """ Emit the metrics in the provided Dictionary Parameters ---------- metrics: Dictionary A dictionary of all metrics, when key is metric_name value is a metric object """ if metrics: for met in metrics: logger.info("[METRICS]%s", str(met)) ================================================ FILE: mms/tests/README.md ================================================ # Testing MMS ## Pre-requisites You will need some additional Python modules to run the unit tests and linting. ```bash pip install mock pytest pylint ``` You will also need the source for the project, so clone the project first. ```bash git clone https://github.com/awslabs/multi-model-server.git cd multi-model-server ``` ## Unit Tests You can run the unit tests with the following: ```bash python -m pytest mms/tests/unit_tests/ ``` To get the coverage report of unit tests, you can run : ```bash python -m pytest --cov-report term-missing --cov=mms/ mms/tests/unit_tests/ ``` or: ```bash python -m pytest --cov-report html:htmlcov --cov=mms/ mms/tests/unit_tests/ ``` ## Lint test You can run the lint tests with the following: ```bash pylint -rn --rcfile=./mms/tests/pylintrc mms/. ``` ================================================ FILE: mms/tests/pylintrc ================================================ [MASTER] # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore=CVS # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=8 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist=numpy,opencv # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. 
Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. It has one side effect, the resulting # AST will be different than the one from reality. This option is deprecated # and it will be removed in Pylint 2.0. optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable=indexing-exception,old-raise-syntax # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,superfluous-parens,invalid-name,no-else-return,useless-super-delegation,len-as-condition,invalid-unary-operand-type,useless-object-inheritance # disable=unicode-builtin,delslice-method,using-cmp-argument,setslice-method,dict-view-method,parameter-unpacking,range-builtin-not-iterating,print-statement,file-builtin,old-raise-syntax,basestring-builtin,execfile-builtin,indexing-exception,import-star-module-level,coerce-method,long-builtin,old-ne-operator,old-division,no-absolute-import,raw_input-builtin,old-octal-literal,oct-method,xrange-builtin,hex-method,unpacking-in-except,nonzero-method,raising-string,intern-builtin,reload-builtin,metaclass-assignment,cmp-method,filter-builtin-not-iterating,apply-builtin,map-builtin-not-iterating,next-method-called,unichr-builtin,buffer-builtin,dict-iter-method,input-builtin,coerce-builtin,getslice-method,useless-suppression,standarderror-builtin,zip-builtin-not-iterating,suppressed-message,cmp-builtin,backtick,long-suffix,reduce-builtin,round-builtin [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. 
files-output=no # Tells whether to display a full report or only the messages reports=no # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [FORMAT] # Maximum number of characters on a single line. max-line-length=120 # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check=trailing-comma,dict-separator # Maximum number of lines in a module max-module-lines=1000 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. 
spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [VARIABLES] # Tells whether we should check for unused import in __init__ files. 
init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,future.builtins,builtins [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=i,j,_,a,b,op,x,y,wd,lr,kv,k,v,s,p,h,c,m,n,X,t,g,f # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. 
property-classes=abc.abstractproperty # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct constant names const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for method names method-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for argument names argument-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for attribute names attr-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for variable names variable-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct function names function-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for function names function-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. 
docstring-min-length=10 [ELIF] # Maximum number of nested blocks for function / method body max-nested-blocks=5 [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=optparse # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [DESIGN] # Maximum number of arguments for function / method max-args=5 # Argument names that match this expression will be ignored. 
Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=15 # Maximum number of return / yield for function / method body max-returns=6 # Maximum number of branch for function / method body max-branches=12 # Maximum number of statements in function / method body max-statements=50 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Minimum number of public methods for a class (see R0903). min-public-methods=2 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of boolean expressions in a if statement max-bool-expr=5 [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception" overgeneral-exceptions=Exception ================================================ FILE: mms/tests/unit_tests/helper/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. ================================================ FILE: mms/tests/unit_tests/helper/pixel2pixel_service.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
# A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. import mxnet as mx import numpy as np import sys sys.path.append('../../..') from mms.model_service.mxnet_model_service import MXNetBaseService, check_input_shape from mms.utils.mxnet import image from mxnet import ndarray as nd from mxnet.gluon.nn import Dense, Activation, Conv2D, Conv2DTranspose, \ BatchNorm, LeakyReLU, Flatten, HybridSequential, HybridBlock, Dropout # Define Unet generator skip block class UnetSkipUnit(HybridBlock): def __init__(self, inner_channels, outer_channels, inner_block=None, innermost=False, outermost=False, use_dropout=False, use_bias=False): super(UnetSkipUnit, self).__init__() with self.name_scope(): self.outermost = outermost en_conv = Conv2D(channels=inner_channels, kernel_size=4, strides=2, padding=1, in_channels=outer_channels, use_bias=use_bias) en_relu = LeakyReLU(alpha=0.2) en_norm = BatchNorm(momentum=0.1, in_channels=inner_channels) de_relu = Activation(activation='relu') de_norm = BatchNorm(momentum=0.1, in_channels=outer_channels) if innermost: de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=inner_channels, use_bias=use_bias) encoder = [en_relu, en_conv] decoder = [de_relu, de_conv, de_norm] model = encoder + decoder elif outermost: de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=inner_channels * 2) encoder = [en_conv] decoder = [de_relu, de_conv, Activation(activation='tanh')] model = encoder + [inner_block] + decoder else: de_conv = Conv2DTranspose(channels=outer_channels, kernel_size=4, strides=2, padding=1, in_channels=inner_channels * 2, use_bias=use_bias) encoder 
= [en_relu, en_conv, en_norm] decoder = [de_relu, de_conv, de_norm] model = encoder + [inner_block] + decoder if use_dropout: model += [Dropout(rate=0.5)] self.model = HybridSequential() with self.model.name_scope(): for block in model: self.model.add(block) def hybrid_forward(self, F, x): if self.outermost: return self.model(x) else: return F.concat(self.model(x), x, dim=1) # Define Unet generator class UnetGenerator(HybridBlock): def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True): super(UnetGenerator, self).__init__() # Build unet generator structure with self.name_scope(): unet = UnetSkipUnit(ngf * 8, ngf * 8, innermost=True) for _ in range(num_downs - 5): unet = UnetSkipUnit(ngf * 8, ngf * 8, unet, use_dropout=use_dropout) unet = UnetSkipUnit(ngf * 8, ngf * 4, unet) unet = UnetSkipUnit(ngf * 4, ngf * 2, unet) unet = UnetSkipUnit(ngf * 2, ngf * 1, unet) unet = UnetSkipUnit(ngf, in_channels, unet, outermost=True) self.model = unet def hybrid_forward(self, F, x): return self.model(x) class Pixel2pixelService(MXNetBaseService): def __init__(self, model_name, path): self.mx_model = UnetGenerator(in_channels=3, num_downs=8) self.mx_model.load_params('%s/%s.params' % (path, model_name), ctx=mx.cpu()) def _preprocess(self, data): input_shape = self.signature['inputs'][0]['data_shape'] height, width = input_shape[2:] img_arr = image.read(data[0]) img_arr = image.resize(img_arr, width, height) img_arr = image.color_normalize(img_arr, nd.array([127.5]), nd.array([127.5])) img_arr = image.transform_shape(img_arr) return [img_arr] def _inference(self, data): check_input_shape(data, self.signature) return self.mx_model(*data) def _postprocess(self, data): img_arr = ((data[0] + 1.0) * 127.5).astype(np.uint8) return [image.write(img_arr)] ================================================ FILE: mms/tests/unit_tests/model_service/dummy_model/MANIFEST.json ================================================ { "Engine": { "MXNet": 0.12 }, "Model-Archive-Description": 
"dummy", "License": "Apache 2.0", "Model-Archive-Version": 0.1, "Model-Server": 0.1, "Model": { "Description": "dummy model", "Service": "dummy_model_service.py", "Model-Name": "dummy", } } ================================================ FILE: mms/tests/unit_tests/model_service/dummy_model/dummy_model_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. from mms.model_service.model_service import SingleNodeService """ This file is a dummy file for the purpose of unit-testing test_service_manager.py """ class DummyNodeService(SingleNodeService): def _inference(self, data): pass def signature(self): pass def ping(self): pass def inference(self): pass class SomeOtherClass: def __init__(self): pass ================================================ FILE: mms/tests/unit_tests/model_service/test_mxnet_image.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. 
See the License for the specific language governing # permissions and limitations under the License. import os import sys curr_path = os.path.dirname(os.path.abspath(__file__)) sys.path.append(curr_path + '/../../..') import PIL import unittest import numpy as np import mxnet as mx import mms.utils.mxnet.image as image from io import BytesIO class TestMXNetImageUtils(unittest.TestCase): def _write_image(self, img_arr, flag=1): img_arr = mx.nd.transpose(img_arr, (1, 2, 0)) mode = 'RGB' if flag == 1 else 'L' if flag == 0: img_arr = mx.nd.reshape(img_arr, shape=(img_arr.shape[0], img_arr.shape[1])) img_arr = img_arr.astype(np.uint8).asnumpy() image = PIL.Image.fromarray(img_arr, mode) output = BytesIO() image.save(output, format='jpeg') return output.getvalue() def test_transform_shape(self): input1 = mx.nd.random.uniform(0, 255, shape=(32, 32, 3)) output1 = image.transform_shape(input1) assert output1.shape == (1, 3, 32, 32), "transform_shape method fail. Got %s shape." % (str(output1.shape)) input2 = mx.nd.random.uniform(0, 255, shape=(28, 28, 3)) output2 = image.transform_shape(input2, dim_order='NHWC') assert output2.shape == (1, 28, 28, 3), "transform_shape method fail. Got %s shape." % (str(output2.shape)) def test_read(self): input1 = mx.nd.random.uniform(0, 255, shape=(3, 256, 256)) input_buf1 = self._write_image(input1) output1 = image.read(input_buf1) assert output1.shape == (256, 256, 3), "Read method failed. Got %s shape." % (str(output1.shape)) input2 = mx.nd.random.uniform(0, 255, shape=(1, 128, 128)) input_buf2 = self._write_image(input2, flag=0) output2 = image.read(input_buf2, flag=0) assert output2.shape == (128, 128, 1), "Read method failed. Got %s shape." % (str(output2.shape)) def test_write(self): input1 = mx.nd.random.uniform(0, 255, shape=(3, 256, 256)) output1 = image.write(input1) assert isinstance(output1, str), "Write method failed. Output is not a string." 
input2 = mx.nd.random.uniform(0, 255, shape=(256, 256, 1)) output2 = image.write(input2, flag=0, dim_order='HWC') assert isinstance(output2, str), "Write method failed. Output is not a string." def test_resize(self): input1 = mx.nd.random.uniform(0, 255, shape=(245, 156, 3)) output1 = image.resize(input1, 128, 256) assert output1.shape == (256, 128, 3), "Resize method failed. Got %s shape." % (str(output1.shape)) def test_fix_crop(self): input1 = mx.nd.random.uniform(0, 255, shape=(100, 100, 3)) output1 = image.fixed_crop(input1, 10, 20, 50, 70) assert output1.shape == (70, 50, 3), "Resize method failed. Got %s shape." % (str(output1.shape)) def test_color_normalize(self): input1 = mx.nd.random.uniform(0, 255, shape=(1, 10, 10)) output1 = image.color_normalize(input1, 127.5, 127.5).asnumpy() assert (output1 >= -1.0).all() and (output1 <= 1.0).all(), "color_normalize method failed." def runTest(self): self.test_transform_shape() self.test_read() self.test_write() self.test_resize() self.test_fix_crop() self.test_color_normalize() ================================================ FILE: mms/tests/unit_tests/model_service/test_mxnet_ndarray.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
import os import sys curr_path = os.path.dirname(os.path.abspath(__file__)) sys.path.append(curr_path + '/../../..') import unittest import mxnet as mx import utils.mxnet.ndarray as ndarray class TestMXNetNDArrayUtils(unittest.TestCase): def test_top_prob(self): labels = ['dummay' for _ in range(100)] data = mx.nd.random.uniform(0, 1, shape=(1, 100)) top = 13 output = ndarray.top_probability(data, labels, top=top) assert len(output) == top, "top_probability method failed." def runTest(self): self.test_top_prob() ================================================ FILE: mms/tests/unit_tests/model_service/test_mxnet_nlp.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. import os import sys curr_path = os.path.dirname(os.path.abspath(__file__)) sys.path.append(curr_path + '/../../..') import unittest import utils.mxnet.nlp as nlp from random import randint class TestMXNetNLPUtils(unittest.TestCase): def test_encode_sentence(self): vocab = {} sentence = [] for i in range(100): vocab['word%d' % (i)] = i sen_vec = [0, 56, 8, 10] for i in sen_vec: sentence.append('word%d' % (i)) res1, out1 = nlp.encode_sentences([sentence], vocab) assert res1[0] == sen_vec, "encode_sentence method failed. " \ "Result vector invalid." assert len(out1) == len(vocab), "encode_sentence method failed. " \ "Generated vocab incorrect." 
res2, out2 = nlp.encode_sentences([sentence]) assert res2[0] == [i for i in range(len(sentence))], \ "encode_sentence method failed. Result vector invalid." assert len(out2) == len(sentence) + 1, "encode_sentence method failed. " \ "Generated vocab incorrect." def test_pad_sentence(self): buckets = [10, 20, 30, 40, 50, 60] for _ in range(5): sent_length = randint(1, 60) sentence = [i for i in range(sent_length)] databatch = nlp.pad_sentence(sentence, buckets) assert databatch.data[0].shape[1] in buckets, "pad_sentence failed. Padded sentence has length %d." \ % (databatch.data[0].shape[1]) ================================================ FILE: mms/tests/unit_tests/model_service/test_service.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
import json
import os
import shutil
import sys
import tempfile
import unittest
from io import BytesIO

import PIL
import mxnet as mx
import numpy as np
import pytest

from helper.pixel2pixel_service import UnetGenerator
from mms.model_service.mxnet_model_service import MXNetBaseService, GluonImperativeBaseService

curr_path = os.path.dirname(os.path.abspath(__file__))
# NOTE(review): this path append runs after the `helper`/`mms` imports above,
# so those imports rely on the test runner's working directory or an
# installed package — confirm before reordering.
sys.path.append(curr_path + '/../..')


def empty_file(path):
    """Create an empty file at `path` (equivalent of `touch`)."""
    open(path, 'a').close()


def module_dir(tmpdir):
    """Create a skeleton model directory under `tmpdir`.

    Contains empty symbol/params/synset files and a two-input image
    signature.json. Returns the directory path.
    """
    path = '{}/test'.format(tmpdir)
    os.mkdir(path)
    empty_file('{}/test-symbol.json'.format(path))
    empty_file('{}/test-0000.params'.format(path))
    empty_file('{}/synset.txt'.format(path))
    with open('{}/signature.json'.format(path), 'w') as sig:
        signature = {
            "input_type": "image/jpeg",
            "inputs": [
                {
                    'data_name': 'data1',
                    'data_shape': [1, 3, 64, 64]
                },
                {
                    'data_name': 'data2',
                    'data_shape': [1, 3, 32, 32]
                }
            ],
            "output_type": "application/json",
            "outputs": [
                {
                    'data_name': 'softmax',
                    'data_shape': [1, 10]
                }
            ]
        }
        json.dump(signature, sig)
    return path


def create_symbolic_manifest(path):
    """Write a MANIFEST.json describing an MXNet-Symbolic model into `path`."""
    with open('{}/MANIFEST.json'.format(path), 'w') as man:
        manifest = {
            "Engine": {
                "MXNet": 0.12
            },
            "Model-Archive-Description": "test",
            "License": "Apache 2.0",
            "Model-Archive-Version": 0.1,
            "Model-Server": 0.1,
            "Model": {
                "Description": "test",
                "Service": "test",
                "Symbol": "",
                "Parameters": "test-0000.params",
                "Signature": "signature.json",
                "Model-Name": "test",
                "Model-Format": "MXNet-Symbolic"
            }
        }
        json.dump(manifest, man)


def create_imperative_manifest(path):
    """Write a MANIFEST.json describing a Gluon-Imperative model into `path`."""
    with open('{}/MANIFEST.json'.format(path), 'w') as man:
        manifest = {
            "Engine": {
                "MXNet": 0.12
            },
            "Model-Archive-Description": "test",
            "License": "Apache 2.0",
            "Model-Archive-Version": 0.1,
            "Model-Server": 0.1,
            "Model": {
                "Description": "test",
                "Service": "test",
                "Symbol": "",
                "Parameters": "",
                "Signature": "signature.json",
                "Model-Name": "test",
                "Model-Format": "Gluon-Imperative"
            }
        }
        json.dump(manifest, man)


class TestService(unittest.TestCase):
    """Integration-style tests for the MXNet base model service classes."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _train_and_export(self, path):
        """Build a tiny two-input symbolic net, run one training step, write synset.txt.

        Output goes to `<curr_path>/<path>`. Despite the name, no model export
        happens here — presumably callers run export separately; confirm
        against export_model.py usage.
        """
        model_path = curr_path + '/' + path
        if not os.path.isdir(model_path):
            os.mkdir(model_path)
        num_class = 10
        data1 = mx.sym.Variable('data1')
        data2 = mx.sym.Variable('data2')
        conv1 = mx.sym.Convolution(data=data1, kernel=(2, 2), num_filter=2, stride=(2, 2))
        conv2 = mx.sym.Convolution(data=data2, kernel=(3, 3), num_filter=3, stride=(1, 1))
        pooling1 = mx.sym.Pooling(data=conv1, kernel=(2, 2), stride=(1, 1), pool_type="avg")
        pooling2 = mx.sym.Pooling(data=conv2, kernel=(2, 2), stride=(1, 1), pool_type="max")
        flatten1 = mx.sym.flatten(data=pooling1)
        flatten2 = mx.sym.flatten(data=pooling2)
        summary = mx.sym.sum(data=flatten1, axis=1) + mx.sym.sum(data=flatten2, axis=1)
        fc = mx.sym.FullyConnected(data=summary, num_hidden=num_class)
        sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
        dshape1 = (10, 3, 64, 64)
        dshape2 = (10, 3, 32, 32)
        lshape = (10,)
        mod = mx.mod.Module(symbol=sym, data_names=('data1', 'data2'),
                            label_names=('softmax_label',))
        mod.bind(data_shapes=[('data1', dshape1), ('data2', dshape2)],
                 label_shapes=[('softmax_label', lshape)])
        mod.init_params()
        mod.init_optimizer(optimizer_params={'learning_rate': 0.01})
        data_batch = mx.io.DataBatch(data=[mx.nd.random.uniform(0, 9, dshape1),
                                           mx.nd.random.uniform(5, 15, dshape2)],
                                     label=[mx.nd.ones(lshape)])
        mod.forward(data_batch)
        mod.backward()
        mod.update()
        with open('%s/synset.txt' % model_path, 'w') as synset:
            for i in range(10):
                synset.write('test label %d\n' % i)

    def _write_image(self, img_arr):
        """Encode a CHW NDArray as JPEG bytes."""
        img_arr = mx.nd.transpose(img_arr, (1, 2, 0)).astype(np.uint8).asnumpy()
        mode = 'RGB'
        image = PIL.Image.fromarray(img_arr, mode)
        output = BytesIO()
        image.save(output, format='jpeg')
        return output.getvalue()

    def test_vision_init(self):
        path = 'test'
        self._train_and_export(path)
        model_path = curr_path + '/' + path
        os.system('rm -rf %s' % model_path)

    def test_vision_inference(self):
        path = 'test'
        self._train_and_export(path)
        os.system('rm -rf %s/test' % curr_path)

    def test_gluon_inference(self):
        """Save a Gluon UNet generator and drive export_model.py over it."""
        path = 'gluon'
        model_name = 'gluon1'
        model_path = curr_path + '/' + path
        os.mkdir(model_path)
        ctx = mx.cpu()
        net_g = UnetGenerator(in_channels=3, num_downs=8)
        data = mx.nd.random_uniform(0, 255, shape=(1, 3, 256, 256))
        net_g.initialize(mx.init.Normal(0.02), ctx=ctx)
        # Forward once so deferred-init parameters get concrete shapes.
        net_g(data)
        net_g.save_params('%s/%s.params' % (model_path, model_name))
        with open('%s/signature.json' % model_path, 'w') as sig:
            signature = {
                "input_type": "image/jpeg",
                "inputs": [
                    {
                        'data_name': 'data',
                        'data_shape': [1, 3, 256, 256]
                    },
                ],
                "output_type": "image/jpeg",
                "outputs": [
                    {
                        'data_name': 'output',
                        'data_shape': [1, 3, 256, 256]
                    }
                ]
            }
            json.dump(signature, sig)
        cmd = 'python %s/../../export_model.py --model-name %s --model-path %s' \
              % (curr_path, model_name, model_path)
        os.system(cmd)
        os.system('rm -rf %s %s/%s.model %s/%s'
                  % (model_path, os.getcwd(), model_name, os.getcwd(), model_name))

    def test_mxnet_model_service(self):
        mod_dir = module_dir(self.test_dir)
        model_path = os.path.expanduser(mod_dir) if mod_dir.startswith('~') else mod_dir
        create_symbolic_manifest(model_path)
        # FIX: close the manifest file deterministically (was json.load(open(...))).
        with open(os.path.join(model_path, 'MANIFEST.json')) as man:
            manifest = json.load(man)
        # The symbol/params files are empty, so construction must raise.
        with pytest.raises(Exception):
            MXNetBaseService('test', model_path, manifest)
        os.system('rm -rf %s' % model_path)

    def test_gluon_model_service(self):
        mod_dir = module_dir(self.test_dir)
        model_path = os.path.expanduser(mod_dir) if mod_dir.startswith('~') else mod_dir
        create_imperative_manifest(model_path)
        # FIX: close the manifest file deterministically (was json.load(open(...))).
        with open(os.path.join(model_path, 'MANIFEST.json')) as man:
            manifest = json.load(man)
        GluonImperativeBaseService('test', model_path, manifest,
                                   mx.gluon.model_zoo.vision.alexnet(pretrained=True))
        os.system('rm -rf %s' % model_path)

    def runTest(self):
        self.test_vision_init()
        self.test_vision_inference()
        self.test_gluon_inference()
        self.test_mxnet_model_service()
        self.test_gluon_model_service()
        # FIX: removed call to self.test_incorrect_service(), which is not
        # defined anywhere in this class and raised AttributeError at runtime.
================================================ FILE: mms/tests/unit_tests/test_beckend_metric.py ================================================
"""Unit tests for the backend MetricsStore and the emit_metrics helper."""
import logging
import sys

import pytest

from mms.metrics.dimension import Dimension
from mms.metrics.metrics_store import MetricsStore
from mms.service import emit_metrics

# Route log output through stdout so pytest's caplog fixture can capture it.
logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO)


def get_model_key(name, unit, req_id, model_name):
    """Rebuild the cache key MetricsStore uses for a model-level metric:
    name, unit, request id and stringified dimensions joined with '-'."""
    dimensions = list()
    dimensions.append(Dimension("ModelName", model_name))
    dimensions.append(Dimension("Level", "Model"))
    dim_str = [name, unit, str(req_id)] + [str(d) for d in dimensions]
    return '-'.join(dim_str)


def get_error_key(name, unit):
    """Rebuild the cache key MetricsStore uses for an error-level metric.

    Error metrics carry no request id, hence the literal 'None' slot.
    """
    dimensions = list()
    dimensions.append(Dimension("Level", "Error"))
    dim_str = [name, unit, 'None'] + [str(d) for d in dimensions]
    return '-'.join(dim_str)


def test_metrics(caplog):
    """
    Test if metric classes methods behave as expected
    Also checks global metric service methods
    """
    caplog.set_level(logging.INFO)
    # Create a batch of request ids
    request_ids = {0: 'abcd', 1: "xyz", 2: "qwerty", 3: "hjshfj"}
    # NOTE(review): relies on dicts preserving insertion order (CPython 3.6+/
    # 3.7 guarantee) so this joined string matches what MetricsStore builds.
    all_req_ids = ','.join(request_ids.values())
    model_name = "dummy model"
    # Create a metrics object
    metrics = MetricsStore(request_ids, model_name)
    # Counter tests
    metrics.add_counter('CorrectCounter', 1, 1)
    test_metric = metrics.cache[get_model_key('CorrectCounter', 'count', 'xyz', model_name)]
    assert 'CorrectCounter' == test_metric.name
    metrics.add_counter('CorrectCounter', 1, 1)
    metrics.add_counter('CorrectCounter', 1, 3)
    # No request id: the counter is attributed to the whole batch of ids.
    metrics.add_counter('CorrectCounter', 1)
    test_metric = metrics.cache[get_model_key('CorrectCounter', 'count', all_req_ids, model_name)]
    assert 'CorrectCounter' == test_metric.name
    metrics.add_counter('CorrectCounter', 3)
    # Two increments of 1 went to request index 1 ('xyz').
    test_metric = metrics.cache[get_model_key('CorrectCounter', 'count', 'xyz', model_name)]
    assert test_metric.value == 2
    test_metric = metrics.cache[get_model_key('CorrectCounter', 'count', 'hjshfj', model_name)]
    assert test_metric.value == 1
    # Batch-level counter accumulated 1 + 3.
    test_metric = metrics.cache[get_model_key('CorrectCounter', 'count', all_req_ids, model_name)]
    assert test_metric.value == 4
    # Check what is emitted is correct
    emit_metrics(metrics.store)
    assert "hjshfj" in caplog.text
    assert "ModelName:dummy model" in caplog.text
    # Adding other types of metrics
    # Check for time metric
    with pytest.raises(Exception) as e_info:
        metrics.add_time('WrongTime', 20, 1, 'ns')
    assert "the unit for a timed metric should be one of ['ms', 's']" == e_info.value.args[0]
    metrics.add_time('CorrectTime', 20, 2, 's')
    metrics.add_time('CorrectTime', 20, 0)  # default unit is 'ms' (asserted below)
    test_metric = metrics.cache[get_model_key('CorrectTime', 'ms', 'abcd', model_name)]
    assert test_metric.value == 20
    assert test_metric.unit == 'Milliseconds'
    test_metric = metrics.cache[get_model_key('CorrectTime', 's', 'qwerty', model_name)]
    assert test_metric.value == 20
    assert test_metric.unit == 'Seconds'
    # Size based metrics
    with pytest.raises(Exception) as e_info:
        metrics.add_size('WrongSize', 20, 1, 'TB')
    assert "The unit for size based metric is one of ['MB','kB', 'GB', 'B']" == e_info.value.args[0]
    metrics.add_size('CorrectSize', 200, 0, 'GB')
    metrics.add_size('CorrectSize', 10, 2)  # default unit is 'MB' (asserted below)
    test_metric = metrics.cache[get_model_key('CorrectSize', 'GB', 'abcd', model_name)]
    assert test_metric.value == 200
    assert test_metric.unit == 'Gigabytes'
    test_metric = metrics.cache[get_model_key('CorrectSize', 'MB', 'qwerty', model_name)]
    assert test_metric.value == 10
    assert test_metric.unit == 'Megabytes'
    # Check a percentage metric
    metrics.add_percent('CorrectPercent', 20.0, 3)
    test_metric = metrics.cache[get_model_key('CorrectPercent', 'percent', 'hjshfj', model_name)]
    assert test_metric.value == 20.0
    assert test_metric.unit == 'Percent'
    # Check an error metric
    metrics.add_error('CorrectError', 'Wrong values')
    test_metric = metrics.cache[get_error_key('CorrectError', '')]
    assert test_metric.value == 'Wrong values'
================================================ FILE:
mms/tests/unit_tests/test_model_loader.py ================================================
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Unit tests for ModelLoaderFactory and the legacy/MMS model loaders."""
import importlib
import inspect
import os
import sys
import types
from collections import namedtuple

import mock
import pytest

from mms.model_loader import LegacyModelLoader
from mms.model_loader import MmsModelLoader
from mms.model_loader import ModelLoaderFactory
from mms.model_service.model_service import SingleNodeService


# noinspection PyClassHasNoInit
# @pytest.mark.skip(reason="Disabling it currently until the PR #467 gets merged")
class TestModelFactory:
    # The factory inspects the model directory and returns the matching
    # loader implementation (see mms.model_loader).
    def test_model_loader_factory_legacy(self):
        model_loader = ModelLoaderFactory.get_model_loader(
            os.path.abspath('mms/tests/unit_tests/model_service/dummy_model'))
        assert isinstance(model_loader, LegacyModelLoader)

    def test_model_loader_factory(self):
        model_loader = ModelLoaderFactory.get_model_loader(
            os.path.abspath('mms/tests/unit_tests/test_utils/'))
        assert isinstance(model_loader, MmsModelLoader)


# noinspection PyClassHasNoInit
class TestListModels:
    def test_list_models_legacy(self):
        # Legacy loader filters module members to subclasses of the given base.
        model_loader = ModelLoaderFactory.get_model_loader("legacy_mms")
        sys.path.append(os.path.abspath('mms/tests/unit_tests/model_service/dummy_model'))
        module = importlib.import_module('dummy_model_service')
        classes = model_loader.list_model_services(module, SingleNodeService)
        assert len(classes) == 1
        assert issubclass(classes[0], SingleNodeService)

    def test_list_models(self):
        model_loader = ModelLoaderFactory.get_model_loader("mms")
        sys.path.append(os.path.abspath('mms/tests/unit_tests/test_utils/'))
        module = importlib.import_module('dummy_class_model_service')
        classes = model_loader.list_model_services(module)
        assert len(classes) == 1
        assert classes[0].__name__ == 'CustomService'


# noinspection PyProtectedMember
# noinspection PyClassHasNoInit
class TestLoadModels:
    # Fixture data shared by all load tests.
    model_name = 'testmodel'
    model_dir = os.path.abspath('mms/tests/unit_tests/model_service/dummy_model')
    mock_manifest = '{"Model":{"Service":"dummy_class_model_service.py",' \
                    '"Signature":"signature.json","Model-Name":"testmodel"}}'

    @pytest.fixture()
    def patches(self, mocker):
        # Patch file access in both mms.model_loader and model_service so no
        # real manifest/signature files are needed.
        Patches = namedtuple('Patches', ['mock_open', 'os_path', "is_file", "open_signature"])
        patches = Patches(
            mocker.patch('mms.model_loader.open'),
            mocker.patch('os.path.exists'),
            mocker.patch('os.path.isfile'),
            mocker.patch('mms.model_service.model_service.open')
        )
        return patches

    def test_load_model_legacy(self, patches):
        patches.mock_open.side_effect = [mock.mock_open(read_data=self.mock_manifest).return_value]
        patches.open_signature.side_effect = [mock.mock_open(read_data='{}').return_value]
        patches.is_file.return_value = True
        patches.os_path.side_effect = [False, True]
        sys.path.append(self.model_dir)
        handler = 'dummy_model_service'
        model_loader = ModelLoaderFactory.get_model_loader(self.model_dir)
        assert isinstance(model_loader, LegacyModelLoader)
        service = model_loader.load(self.model_name, self.model_dir, handler, 0, 1)
        assert inspect.ismethod(service._entry_point)

    def test_load_class_model(self, patches):
        patches.mock_open.side_effect = [mock.mock_open(read_data=self.mock_manifest).return_value]
        sys.path.append(os.path.abspath('mms/tests/unit_tests/test_utils/'))
        patches.os_path.return_value = True
        handler = 'dummy_class_model_service'
        # NOTE(review): 'mms/unit_tests/test_utils/' (missing 'tests/') differs
        # from the path appended to sys.path above — the directory does not
        # exist; confirm the factory's fallback behaviour makes this intentional.
        model_loader = ModelLoaderFactory.get_model_loader(os.path.abspath('mms/unit_tests/test_utils/'))
        service = model_loader.load(self.model_name, self.model_dir,
                                    handler, 0, 1)
        assert inspect.ismethod(service._entry_point)

    def test_load_func_model(self, patches):
        patches.mock_open.side_effect = [mock.mock_open(read_data=self.mock_manifest).return_value]
        sys.path.append(os.path.abspath('mms/tests/unit_tests/test_utils/'))
        patches.os_path.return_value = True
        # 'module:function' form selects a function entry point.
        handler = 'dummy_func_model_service:infer'
        model_loader = ModelLoaderFactory.get_model_loader(os.path.abspath('mms/unit_tests/test_utils/'))
        service = model_loader.load(self.model_name, self.model_dir,
                                    handler, 0, 1)
        assert isinstance(service._entry_point, types.FunctionType)
        assert service._entry_point.__name__ == 'infer'

    def test_load_func_model_with_error(self, patches):
        patches.mock_open.side_effect = [mock.mock_open(read_data=self.mock_manifest).return_value]
        sys.path.append(os.path.abspath('mms/tests/unit_tests/test_utils/'))
        patches.os_path.return_value = True
        # A function name that does not exist in the module.
        handler = 'dummy_func_model_service:wrong'
        model_loader = ModelLoaderFactory.get_model_loader(os.path.abspath('mms/unit_tests/test_utils/'))
        with pytest.raises(ValueError, match=r"Expected only one class .*"):
            model_loader.load(self.model_name, self.model_dir,
                              handler, 0, 1)

    def test_load_model_with_error(self, patches):
        # Manifest without the expected "Model" key.
        patches.mock_open.side_effect = [
            mock.mock_open(read_data='{"test" : "h"}').return_value]
        sys.path.append(os.path.abspath('mms/tests/unit_tests/test_utils/'))
        patches.os_path.return_value = True
        handler = 'dummy_func_model_service'
        model_loader = ModelLoaderFactory.get_model_loader(os.path.abspath('mms/unit_tests/test_utils/'))
        with pytest.raises(ValueError, match=r"Expected only one class .*"):
            model_loader.load(self.model_name, self.model_dir,
                              handler, 0, 1)
================================================ FILE: mms/tests/unit_tests/test_model_service_worker.py ================================================
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
ModelServiceWorker is the worker that is started by the MMS front-end.
"""
import socket
from collections import namedtuple

import mock
import pytest
from mock import Mock

from mms.model_service_worker import MXNetModelServiceWorker
from mms.service import Service


@pytest.fixture()
def socket_patches(mocker):
    """Patch socket.socket and queue a length-prefixed LOAD command on recv."""
    Patches = namedtuple('Patches', ['socket'])
    mock_patch = Patches(mocker.patch('socket.socket'))
    mock_patch.socket.recv.side_effect = [
        b"L",
        b"\x00\x00\x00\x0a", b"model_name",
        b"\x00\x00\x00\x0a", b"model_path",
        b"\x00\x00\x00\x01",
        b"\x00\x00\x00\x07", b"handler",
        b"\x00\x00\x00\x01"
    ]
    return mock_patch


@pytest.fixture()
def model_service_worker(socket_patches):
    """Worker wired to the mocked socket with a minimal Service attached."""
    model_service_worker = MXNetModelServiceWorker('unix', 'my-socket', None, None)
    model_service_worker.sock = socket_patches.socket
    model_service_worker.service = Service('name', 'mpath', 'testmanifest', None, 0, 1)
    return model_service_worker


# noinspection PyClassHasNoInit
class TestInit:
    socket_name = "sampleSocketName"

    def test_missing_socket_name(self):
        # No arguments at all -> the socket type is invalid.
        with pytest.raises(ValueError, match="Invalid socket type provided.*"):
            MXNetModelServiceWorker()

    def test_socket_in_use(self, mocker):
        # An existing socket file that cannot be removed means another
        # worker already owns it.
        remove = mocker.patch('os.remove')
        path_exists = mocker.patch('os.path.exists')
        remove.side_effect = OSError()
        path_exists.return_value = True
        with pytest.raises(Exception, match=r".*socket already in use: sampleSocketName.*"):
            MXNetModelServiceWorker('unix', self.socket_name)

    @pytest.fixture()
    def patches(self, mocker):
        Patches = namedtuple('Patches', ['remove', 'socket'])
        patches = Patches(
            mocker.patch('os.remove'),
            mocker.patch('socket.socket')
        )
        return patches

    def test_success(self, patches):
        # Successful init removes any stale socket file and opens AF_UNIX.
        MXNetModelServiceWorker('unix', self.socket_name)
        patches.remove.assert_called_once_with(self.socket_name)
        patches.socket.assert_called_once_with(socket.AF_UNIX, socket.SOCK_STREAM)


# noinspection PyClassHasNoInit
class TestRunServer:
    accept_result = (mock.MagicMock(), None)

    def test_with_socket_bind_error(self, socket_patches, model_service_worker):
        # bind() failing must abort before listen() is reached.
        bind_exception = socket.error("binding error")
        socket_patches.socket.bind.side_effect = bind_exception
        with pytest.raises(Exception):
            model_service_worker.run_server()
        socket_patches.socket.bind.assert_called()
        socket_patches.socket.listen.assert_not_called()

    def test_with_timeout(self, socket_patches, model_service_worker):
        exception = socket.timeout("Some Exception")
        socket_patches.socket.accept.side_effect = exception
        with pytest.raises(socket.timeout):
            model_service_worker.run_server()
        socket_patches.socket.bind.assert_called()
        socket_patches.socket.listen.assert_called()
        socket_patches.socket.accept.assert_called()

    def test_with_run_server_debug(self, socket_patches, model_service_worker, mocker):
        # With DEBUG on, generic accept() errors still propagate.
        exception = Exception("Some Exception")
        socket_patches.socket.accept.side_effect = exception
        mocker.patch('mms.model_service_worker.DEBUG', True)
        model_service_worker.handle_connection = Mock()
        with pytest.raises(Exception):
            model_service_worker.run_server()
        socket_patches.socket.bind.assert_called()
        socket_patches.socket.listen.assert_called()
        socket_patches.socket.accept.assert_called()

    def test_success(self, model_service_worker):
        model_service_worker.sock.accept.return_value = self.accept_result
        model_service_worker.sock.recv.return_value = b""
        # SystemExit is used to break out of the accept loop after one pass.
        exception = SystemExit
        model_service_worker.sock.accept.side_effect = exception
        with pytest.raises(SystemExit):
            model_service_worker.run_server()
        model_service_worker.sock.accept.assert_called_once()


# noinspection PyClassHasNoInit
class TestLoadModel:
    data = {'modelPath': b'mpath', 'modelName': b'name', 'handler': b'handled'}

    @pytest.fixture()
    def patches(self, mocker):
        Patches = namedtuple('Patches', ['loader'])
        patches = Patches(mocker.patch('mms.model_service_worker.ModelLoaderFactory'))
        return patches

    def test_load_model(self, patches, model_service_worker):
        patches.loader.get_model_loader.return_value = Mock()
        model_service_worker.load_model(self.data)
        patches.loader.get_model_loader.assert_called()

    # noinspection PyUnusedLocal
    @pytest.mark.parametrize('batch_size', [(None, None), ('1', 1)])
    @pytest.mark.parametrize('gpu', [(None, None), ('2', 2)])
    def test_optional_args(self, patches, model_service_worker, batch_size, gpu):
        # batchSize and gpu are optional request fields; load_model must
        # accept any combination of their presence.
        data = self.data.copy()
        if batch_size[0]:
            data['batchSize'] = batch_size[0]
        if gpu[0]:
            data['gpu'] = gpu[0]
        model_service_worker.load_model(data)


# noinspection PyClassHasNoInit
class TestHandleConnection:
    data = {'modelPath': b'mpath', 'modelName': b'name', 'handler': b'handled'}

    @pytest.fixture()
    def patches(self, mocker):
        Patches = namedtuple("Patches", ["retrieve_msg"])
        patches = Patches(
            mocker.patch("mms.model_service_worker.retrieve_msg")
        )
        return patches

    def test_handle_connection(self, patches, model_service_worker):
        # Feed a LOAD, an INFER, then an unknown command: the first two are
        # dispatched and the unknown one must raise.
        patches.retrieve_msg.side_effect = [(b"L", ""), (b"I", ""), (b"U", "")]
        model_service_worker.load_model = Mock()
        model_service_worker.service.predict = Mock()
        model_service_worker._remap_io = Mock()
        service = Mock()
        service.context = None
        model_service_worker.load_model.return_value = ("", 200)
        model_service_worker.service.predict.return_value = ("OK")
        model_service_worker._remap_io.return_value = ("")
        cl_socket = Mock()
        with pytest.raises(ValueError, match=r"Received unknown command.*"):
            model_service_worker.handle_connection(cl_socket)
        cl_socket.send.assert_called()
================================================ FILE: mms/tests/unit_tests/test_otf_codec_protocol.py ================================================
# coding=utf-8
# Copyright 2018 Amazon.com, Inc.
# or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
On The Fly Codec tester
"""
from collections import namedtuple

import pytest

import mms.protocol.otf_message_handler as codec
# builtins.bytes gives the Python3-style bytes(str, encoding) constructor
# on Python 2 as well.
from builtins import bytes


@pytest.fixture()
def socket_patches(mocker):
    # Default recv() result; tests override recv.side_effect with the exact
    # byte sequence (1-byte command + length-prefixed fields) they need.
    Patches = namedtuple('Patches', ['socket'])
    mock_patch = Patches(mocker.patch('socket.socket'))
    mock_patch.socket.recv.return_value = b'1'
    return mock_patch


# noinspection PyClassHasNoInit
class TestOtfCodecHandler:
    # In the sequences below b"\xFF\xFF\xFF\xFF" (-1) is the sentinel for
    # "value absent" or end-of-list, and 4-byte big-endian ints prefix each
    # variable-length field.
    def test_retrieve_msg_unknown(self, socket_patches):
        socket_patches.socket.recv.side_effect = [b"U", b"\x00\x00\x00\x03"]
        with pytest.raises(ValueError, match=r"Invalid command: .*"):
            codec.retrieve_msg(socket_patches.socket)

    def test_retrieve_msg_load_gpu(self, socket_patches):
        expected = {"modelName": b"model_name", "modelPath": b"model_path",
                    "batchSize": 1, "handler": b"handler", "gpu": 1,
                    "ioFileDescriptor": b"0123456789"}
        socket_patches.socket.recv.side_effect = [
            b"L",
            b"\x00\x00\x00\x0a", b"model_name",
            b"\x00\x00\x00\x0a", b"model_path",
            b"\x00\x00\x00\x01",
            b"\x00\x00\x00\x07", b"handler",
            b"\x00\x00\x00\x01",
            b"\x00\x00\x00\x0a", b"0123456789"
        ]
        cmd, ret = codec.retrieve_msg(socket_patches.socket)
        assert cmd == b"L"
        assert ret == expected

    def test_retrieve_msg_load_no_gpu(self, socket_patches):
        # gpu id is -1 (sentinel) -> no "gpu" key in the decoded message.
        expected = {"modelName": b"model_name", "modelPath": b"model_path",
                    "batchSize": 1, "handler": b"handler",
                    "ioFileDescriptor": b"0123456789"}
        socket_patches.socket.recv.side_effect = [
            b"L",
            b"\x00\x00\x00\x0a", b"model_name",
            b"\x00\x00\x00\x0a", b"model_path",
            b"\x00\x00\x00\x01",
            b"\x00\x00\x00\x07", b"handler",
            b"\xFF\xFF\xFF\xFF",
            b"\x00\x00\x00\x0a", b"0123456789"
        ]
        cmd, ret = codec.retrieve_msg(socket_patches.socket)
        assert cmd == b"L"
        assert ret == expected

    def test_retrieve_msg_predict(self, socket_patches):
        # application/json content is decoded into a dict value.
        expected = [{
            "requestId": b"request_id",
            "headers": [],
            "parameters": [
                {"name": "input_name",
                 "contentType": "application/json",
                 "value": {"data": "value"}
                 }
            ]
        }]
        socket_patches.socket.recv.side_effect = [
            b"I",
            b"\x00\x00\x00\x0a", b"request_id",
            b"\xFF\xFF\xFF\xFF",
            b"\x00\x00\x00\x0a", b"input_name",
            b"\x00\x00\x00\x0F", b"application/json",
            b"\x00\x00\x00\x0F", b'{"data":"value"}',
            b"\xFF\xFF\xFF\xFF",  # end of parameters
            b"\xFF\xFF\xFF\xFF"  # end of batch
        ]
        cmd, ret = codec.retrieve_msg(socket_patches.socket)
        assert cmd == b'I'
        assert ret == expected

    def test_retrieve_msg_predict_text(self, socket_patches):
        # text/plain content is decoded to unicode text (UTF-8).
        expected = [{
            "requestId": b"request_id",
            "headers": [],
            "parameters": [
                {"name": "input_name",
                 "contentType": "text/plain",
                 "value": u"text_value测试"
                 }
            ]
        }]
        socket_patches.socket.recv.side_effect = [
            b"I",
            b"\x00\x00\x00\x0a", b"request_id",
            b"\xFF\xFF\xFF\xFF",
            b"\x00\x00\x00\x0a", b"input_name",
            b"\x00\x00\x00\x0a", b"text/plain",
            b"\x00\x00\x00\x0a", bytes(u"text_value测试", "utf-8"),
            b"\xFF\xFF\xFF\xFF",  # end of parameters
            b"\xFF\xFF\xFF\xFF"  # end of batch
        ]
        cmd, ret = codec.retrieve_msg(socket_patches.socket)
        assert cmd == b'I'
        assert ret == expected

    def test_retrieve_msg_predict_binary(self, socket_patches):
        # Empty content type -> payload is passed through as raw bytes.
        expected = [{
            "requestId": b"request_id",
            "headers": [],
            "parameters": [
                {"name": "input_name",
                 "contentType": "",
                 "value": b"binary"
                 }
            ]
        }]
        socket_patches.socket.recv.side_effect = [
            b"I",
            b"\x00\x00\x00\x0a", b"request_id",
            b"\xFF\xFF\xFF\xFF",
            b"\x00\x00\x00\x0a", b"input_name",
            b"\x00\x00\x00\x00",
            b"\x00\x00\x00\x06", b"binary",
            b"\xFF\xFF\xFF\xFF",  # end of parameters
            b"\xFF\xFF\xFF\xFF"  # end of batch
        ]
        cmd, ret = codec.retrieve_msg(socket_patches.socket)
        assert cmd == b'I'
        assert ret == expected

    def test_create_load_model_response(self):
        msg = codec.create_load_model_response(200, "model_loaded")
        assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x0cmodel_loaded\xff\xff\xff\xff'

    def test_create_predict_response(self):
        msg = codec.create_predict_response(["OK"], {0: "request_id"}, "success", 200)
        assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x07success\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00' \
                      b'\x00\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02OK\xff\xff\xff\xff'

    def test_create_predict_response_with_error(self):
        # A None result encodes the literal payload "error".
        msg = codec.create_predict_response(None, {0: "request_id"}, "failed", 200)
        assert msg == b'\x00\x00\x00\xc8\x00\x00\x00\x06failed\x00\x00\x00\nrequest_id\x00\x00\x00\x00\x00\x00\x00' \
                      b'\xc8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05error\xff\xff\xff\xff'
================================================ FILE: mms/tests/unit_tests/test_utils/MAR-INF/MANIFEST.json ================================================
================================================ FILE: mms/tests/unit_tests/test_utils/dummy_class_model_service.py ================================================
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
""" Dummy custom service which is class based """ # noinspection PyUnusedLocal class CustomService(object): def initialize(self, context): pass # noinspection PyMethodMayBeStatic def handle(self, data, context): from mms.context import Context return ["OK"] ================================================ FILE: mms/tests/unit_tests/test_utils/dummy_func_model_service.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Dummy custom service which is function based """ from mms.context import Context # noinspection PyUnusedLocal def infer(data, context): return isinstance(context, Context) ================================================ FILE: mms/tests/unit_tests/test_version.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
import os import re import mms def test_mms_version(): with open(os.path.join("mms", "version.py")) as f: exec(f.read(), globals()) assert __version__ == str(mms.__version__), "Versions don't match" ================================================ FILE: mms/tests/unit_tests/test_worker_service.py ================================================ import logging import os import sys import pytest from mms.context import Context from mms.service import Service from mms.service import emit_metrics logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) # noinspection PyClassHasNoInit class TestService: model_name = 'testmodel' model_dir = os.path.abspath('mms/tests/unit_tests/test_utils/') manifest = "testmanifest" data = [ {"requestId": b"123", "parameters": [ {"name": "xyz", "value": "abc", "contentType": "text/csv"} ], "data": b""} ] @pytest.fixture() def service(self, mocker): service = object.__new__(Service) service._entry_point = mocker.MagicMock(return_value=['prediction']) service._context = Context(self.model_name, self.model_dir, self.manifest, 1, 0, '1.0') return service def test_predict(self, service, mocker): create_predict_response = mocker.patch("mms.service.create_predict_response") service.predict(self.data) create_predict_response.assert_called() def test_with_nil_request(self, service): with pytest.raises(ValueError, match=r"Received invalid inputs"): service.retrieve_data_for_inference(None) def test_valid_req(self, service): headers, input_batch, req_to_id_map = service.retrieve_data_for_inference(self.data) assert headers[0].get_request_property("xyz").get("content-type") == "text/csv" assert input_batch[0] == {"xyz": "abc"} assert req_to_id_map == {0: "123"} # noinspection PyClassHasNoInit class TestEmitMetrics: def test_emit_metrics(self, caplog): caplog.set_level(logging.INFO) metrics = {'test_emit_metrics': True} emit_metrics(metrics) assert "[METRICS]" in caplog.text ================================================ 
FILE: mms/utils/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Util files for MMS """ from . import timeit_decorator ================================================ FILE: mms/utils/mxnet/__init__.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ MXNet Utils """ import warnings warnings.warn("Module mms.utils.mxnet is deprecated, please avoid using mms internal modules.", DeprecationWarning, stacklevel=2) ================================================ FILE: mms/utils/mxnet/image.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
# A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Image utils """ import sys import base64 from io import BytesIO import numpy as np from PIL import Image import mxnet as mx from mxnet import image as img def transform_shape(img_arr, dim_order='NCHW'): """Rearrange image NDArray shape to 'NCHW' or 'NHWC' which is valid for MXNet model input. Input image NDArray should has dim_order of 'HWC'. Parameters ---------- img_arr : NDArray Image in NDArray format with shape (channel, width, height) dim_order : str Output image dimension order. Valid values are 'NCHW' and 'NHWC' Returns ------- output : NDArray Image in NDArray format with dim_order shape """ assert dim_order in 'NCHW' or dim_order in 'NHWC', "dim_order must be 'NCHW' or 'NHWC'." if dim_order == 'NCHW': img_arr = mx.nd.transpose(img_arr, (2, 0, 1)) output = mx.nd.expand_dims(img_arr, axis=0) return output def read(buf, flag=1, to_rgb=True, out=None): """Read and decode an image to an NDArray. Input image NDArray should has dim_order of 'HWC'. Note: `imread` uses OpenCV (not the CV2 Python library). MXNet must have been built with USE_OPENCV=1 for `imdecode` to work. Parameters ---------- buf : str/bytes or numpy.ndarray Binary image data as string or numpy ndarray. flag : {0, 1}, default 1 1 for three channel color output. 0 for grayscale output. to_rgb : bool, default True True for RGB formatted output (MXNet default). False for BGR formatted output (OpenCV default). out : NDArray, optional Output buffer. Use `None` for automatic allocation. Returns ------- NDArray An `NDArray` containing the image. 
Example ------- >>> buf = open("flower.jpg", 'rb').read() >>> image.read(buf) """ return img.imdecode(buf, flag, to_rgb, out) # TODO: Check where this is used and rename format def write(img_arr, flag=1, format='jpeg', dim_order='CHW'): # pylint: disable=redefined-builtin """Write an NDArray to a base64 string Parameters ---------- img_arr : NDArray Image in NDArray format with shape (channel, width, height). flag : {0, 1}, default 1 1 for three channel color output. 0 for grayscale output. format : str Output image format. dim_order : str Input image dimension order. Valid values are 'CHW' and 'HWC' Returns ------- str Image in base64 string format """ assert dim_order in 'CHW' or dim_order in 'HWC', "dim_order must be 'CHW' or 'HWC'." if dim_order == 'CHW': img_arr = mx.nd.transpose(img_arr, (1, 2, 0)) if flag == 1: mode = 'RGB' else: mode = 'L' img_arr = mx.nd.reshape(img_arr, (img_arr.shape[0], img_arr.shape[1])) img_arr = img_arr.astype(np.uint8).asnumpy() image = Image.fromarray(img_arr, mode) output = BytesIO() image.save(output, format=format) output.seek(0) if sys.version_info[0] < 3: return base64.b64encode(output.getvalue()) else: return base64.b64encode(output.getvalue()).decode("utf-8") def resize(src, new_width, new_height, interp=2): """Resizes image to new_width and new_height. Input image NDArray should has dim_order of 'HWC'. Parameters ---------- src : NDArray Source image in NDArray format new_width : int Width in pixel for resized image new_height : int Height in pixel for resized image interp : int interpolation method for all resizing operations Possible values: 0: Nearest Neighbors Interpolation. 1: Bilinear interpolation. 2: Area-based (resampling using pixel area relation). It may be a preferred method for image decimation, as it gives moire-free results. But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default). 3: Bicubic interpolation over 4x4 pixel neighborhood. 
4: Lanczos interpolation over 8x8 pixel neighborhood. 9: Cubic for enlarge, area for shrink, bilinear for others 10: Random select from interpolation method metioned above. Note: When shrinking an image, it will generally look best with AREA-based interpolation, whereas, when enlarging an image, it will generally look best with Bicubic (slow) or Bilinear (faster but still looks OK). More details can be found in the documentation of OpenCV, please refer to http://docs.opencv.org/master/da/d54/group__imgproc__transform.html. Returns ------- NDArray An `NDArray` containing the resized image. """ return img.imresize(src, new_width, new_height, interp) def fixed_crop(src, x0, y0, w, h, size=None, interp=2): """Crop src at fixed location, and (optionally) resize it to size. Input image NDArray should has dim_order of 'HWC'. Parameters ---------- src : NDArray Input image x0 : int Left boundary of the cropping area y0 : int Top boundary of the cropping area w : int Width of the cropping area h : int Height of the cropping area size : tuple of (w, h) Optional, resize to new size after cropping interp : int, optional, default=2 Interpolation method. See resize for details. Returns ------- NDArray An `NDArray` containing the cropped image. """ return img.fixed_crop(src, x0, y0, w, h, size, interp) def color_normalize(src, mean, std=None): """Normalize src with mean and std. Parameters ---------- src : NDArray Input image mean : NDArray RGB mean to be subtracted std : NDArray RGB standard deviation to be divided Returns ------- NDArray An `NDArray` containing the normalized image. """ src = src.astype(np.float32) return img.color_normalize(src, mean, std) ================================================ FILE: mms/utils/mxnet/ndarray.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). 
# You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ NDArray utils """ import numpy as np import mxnet as mx def top_probability(data, labels, top=5): """Get top probability prediction from NDArray. Parameters ---------- data : NDArray Data to be predicted labels : List List of class labels Returns ------- List List of probability: class pairs in sorted order """ dim = len(data.shape) if dim > 2: data = mx.nd.array( np.squeeze(data.asnumpy(), axis=tuple(range(dim)[2:]))) sorted_prob = mx.nd.argsort(data[0], is_ascend=False) # pylint: disable=deprecated-lambda top_prob = map(lambda x: int(x.asscalar()), sorted_prob[0:top]) return [{'probability': float(data[0, i].asscalar()), 'class': labels[i]} for i in top_prob] ================================================ FILE: mms/utils/mxnet/nlp.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" NLP utils """ import bisect import numpy as np import mxnet as mx def encode_sentences(sentences, vocab=None, invalid_label=-1, invalid_key='\n', start_label=0): """Encode sentences and (optionally) build a mapping from string tokens to integer indices. Unknown keys will be added to vocabulary. Parameters ---------- sentences : list of list of str A list of sentences to encode. Each sentence should be a list of string tokens. vocab : None or dict of str -> int Optional input Vocabulary invalid_label : int, default -1 Index for invalid token, like invalid_key : str, default '\\n' Key for invalid token. Use '\\n' for end of sentence by default. start_label : int lowest index. Returns ------- result : list of list of int encoded sentences vocab : dict of str -> int result vocabulary """ idx = start_label if vocab is None: vocab = {invalid_key: invalid_label} new_vocab = True else: new_vocab = False res = [] for sent in sentences: coded = [] for word in sent: if word not in vocab: if not new_vocab: coded.append(invalid_label) continue else: if idx == invalid_label: idx += 1 vocab[word] = idx idx += 1 coded.append(vocab[word]) res.append(coded) return res, vocab def pad_sentence(sentence, buckets, invalid_label=-1, data_name='data', layout='NT'): """Pad a sentence to closest length in provided buckets. Parameters ---------- sentence : list of int A list of integer representing an encoded sentence. buckets : list of int Size of the data buckets. invalid_label : int, optional Index for invalid token, like . data_name : str, optional Input data name. layoutlayout : str, optional Format of data and label. 'NT' means (batch_size, length) and 'TN' means (length, batch_size). Returns ------- result : mx.io.DataBatch DataBatch contains sentence. 
""" buck = bisect.bisect_left(buckets, len(sentence)) buff = np.full((buckets[buck],), invalid_label, dtype='float32') buff[:len(sentence)] = sentence sent_bucket = buckets[buck] pad_sent = mx.nd.array([buff], dtype='float32') shape = (1, sent_bucket) if layout == 'NT' else (sent_bucket, 1) return mx.io.DataBatch([pad_sent], pad=0, bucket_key=sent_bucket, provide_data=[mx.io.DataDesc( name=data_name, shape=shape, layout=layout)]) ================================================ FILE: mms/utils/timeit_decorator.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ timeit decorator """ import time from functools import wraps def timeit(func): """ Use this decorator on a method to find it's execution time. :param func: :return: """ @wraps(func) def time_and_log(*args, **kwargs): start = time.time() start_cpu = time.clock() result = func(*args, **kwargs) end = time.time() end_cpu = time.clock() print("func: %r took a total of %2.4f sec to run and %2.4f sec of CPU time\n", (func.__name__, (end-start), (end_cpu - start_cpu))) return result return time_and_log ================================================ FILE: mms/version.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
# A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This is the current version of MMS """ __version__ = '1.1.11' ================================================ FILE: model-archiver/.coveragerc ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. [report] exclude_lines = pragma: no cover if __name__ == .__main__.: if __name__ == "__main__" : [run] branch = True omit = */__init__.py model_archiver/tests/* model_archiver/manifest_components/* model_archiver/arg_parser.py model_archiver/setup.py ================================================ FILE: model-archiver/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: model-archiver/MANIFEST.in ================================================ include PyPiDescription.rst ================================================ FILE: model-archiver/PyPiDescription.rst ================================================ Project Description =================== Model Archiver is a tool used for creating archives of trained neural net models that can be consumed by Multi-Model-Server inference. Use the Model Archiver CLI to start create a ``.mar`` file. Model Archiver is part of `MMS `__. However,you ca install Model Archiver stand alone. Detailed documentation and examples are provided in the `README `__. Prerequisites ------------- ONNX support is optional in `model-archiver` tool. It's not installed by default with `model-archiver`. If you wish to package a ONNX model, you will need to first install a ``protobuf`` compiler, ``onnx`` and ``mxnet`` manually. 
`Instructions for installing Model Archiver with ONNX `__. Installation ------------ :: pip install model-archiver Development ----------- We welcome new contributors of all experience levels. For information on how to install MMS for development, refer to the `MMS docs `__. Important links --------------- - `Official source code repo `__ - `Download releases `__ - `Issue tracker `__ Source code ----------- You can check the latest source code as follows: :: git clone https://github.com/awslabs/multi-model-server.git Testing ------- After installation, try out the MMS Quickstart for `Create a model archive `__ and `Serving a Model `__. Help and Support ---------------- - `Documentation `__ - `Forum `__ Citation -------- If you use MMS in a publication or project, please cite MMS: https://github.com/awslabs/multi-model-server ================================================ FILE: model-archiver/README.md ================================================ # Model archiver for MMS ## Contents of this Document * [Overview](#overview) * [Model Archiver CLI](#model-archiver-command-line-interface) * [Artifact Details](#artifact-details) * [MAR-INFO](#mar-inf) * [Model name](#model-name) * [Runtime](#runtime) * [Handler](#handler) * [Quick Start: Creating a Model Archive](#creating-a-model-archive) ## Other Relevant Documents * [Model Archive Examples](../examples/README.md) * [Packaging an ONNX Model](docs/convert_from_onnx.md) ## Overview A key feature of MMS is the ability to package all model artifacts into a single model archive file. It is a separate command line interface (CLI), `model-archiver`, that can take model checkpoints and package them into a `.mar` file. This file can then be redistributed and served by anyone using MMS. 
It takes in the following model artifacts: a model composed of one or more files, the description of the model's inputs in the form of a signature file, a service file describing how to handle inputs and outputs, and other optional assets that may be required to serve the model. The CLI creates a `.mar` file that MMS's server CLI uses to serve the models. **Important**: Make sure you try the [Quick Start: Creating a Model Archive](#creating-a-model-archive) tutorial for a short example of using `model-archiver`. MMS can support any arbitrary model file. It is the custom service code's responsibility to locate and load the model files. The following information is required to create a standalone model archive: 1. [Model name](#model-name) 2. [Model path](#model-path) 3. [Handler](#handler) ## Model Archiver Command Line Interface Now let's cover the details on using the CLI tool: `model-archiver`. Here is an example usage with the squeezenet_v1.1 model archive which you can download or create by following the example in the [main README](../README.md): ```bash model-archiver --model-name squeezenet_v1.1 --model-path squeezenet --handler mxnet_vision_service:handle ``` ### Arguments ``` $ model-archiver -h usage: model-archiver [-h] --model-name MODEL_NAME --model-path MODEL_PATH --handler HANDLER [--runtime {python,python2,python3}] [--export-path EXPORT_PATH] [-f] Model Archiver Tool optional arguments: -h, --help show this help message and exit --model-name MODEL_NAME Exported model name. Exported file will be named as model-name.mar and saved in current working directory if no --export-path is specified, else it will be saved under the export path --model-path MODEL_PATH Path to the folder containing model related files. --handler HANDLER Handler path to handle custom MMS inference logic. --runtime {python,python2,python3} The runtime specifies which language to run your inference code on. The default runtime is RuntimeType.PYTHON. 
At the present moment we support the following runtimes python, python2, python3 --export-path EXPORT_PATH Path where the exported .mar file will be saved. This is an optional parameter. If --export-path is not specified, the file will be saved in the current working directory. --archive-format {tgz,default} The format in which the model artifacts are archived. "tgz": This creates the model-archive in .tar.gz format. If platform hosting MMS requires model-artifacts to be in ".tar.gz" use this option. "no-archive": This option creates an non-archived version of model artifacts at "export-path/{model-name}" location. As a result of this choice, MANIFEST file will be created at "export-path/{model-name}" location without archiving these model files "default": This creates the model-archive in .mar format. This is the default archiving format. Models archived in this format will be readily hostable on native MMS. -f, --force When the -f or --force flag is specified, an existing .mar file with same name as that provided in --model- name in the path specified by --export-path will overwritten ``` ## Artifact Details ### MAR-INF **MAR-INF** is a reserved folder name that will be used inside `.mar` file. This folder contains the model archive metadata files. Users should avoid using **MAR-INF** in their model path. ### Runtime ### Model name A valid model name must begin with a letter of the alphabet and can only contains letters, digits, underscores (_), dashes (-) and periods (.). **Note**: The model name can be overridden when you register the model with [Register Model API](../docs/management_api.md#register-a-model). ### Model path A folder that contains all necessary files needed to run inference code for the model. All the files and sub-folders (except [excluded files](#excluded-files)) will be packaged into the `.mar` file. #### excluded files The following types of file will be excluded during model archive packaging: 1. hidden files 2. 
Mac system files: __MACOSX and .DS_Store 3. MANIFEST.json 4. python compiled byte code (.pyc) files and cache folder __pycache__ ### handler A handler is a python entry point that MMS can invoke to execute inference code. The format of a Python handler is: * python_module_name[:function_name] (for example: lstm-service:handle). The function name is optional if the provided python module follows one of predefined conventions: 1. There is a `handle()` function available in the module 2. The module contains only one Class and that class contains a `handle()` function. Further details and specifications are found on the [custom service](../docs/custom_service.md) page. ## Creating a Model Archive **1. Download these sample SqueezeNet model artifacts (if you don't have them handy)** ```bash mkdir squeezenet curl -o squeezenet/squeezenet_v1.1-symbol.json https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/squeezenet_v1.1-symbol.json curl -o squeezenet/squeezenet_v1.1-0000.params https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/squeezenet_v1.1-0000.params curl -o squeezenet/signature.json https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/signature.json curl -o squeezenet/synset.txt https://s3.amazonaws.com/model-server/model_archive_1.0/examples/squeezenet_v1.1/synset.txt ``` The downloaded model artifact files are: * **Model Definition** (json file) - contains the layers and overall structure of the neural network. * **Model Params and Weights** (params file) - contains the parameters and the weights. * **Model Signature** (json file) - defines the inputs and outputs that MMS is expecting to hand-off to the API. * **assets** (text files) - auxiliary files that support model inference such as vocabularies, labels, etc. These vary depending on the model. **2. Download the model archiver source** ```bash git clone https://github.com/awslabs/multi-model-server.git ``` **3. 
Prepare your model custom service code** You can implement your own model customer service code with a model archive entry point. Here we are going to use the MXNet vision service `model_service_template`. This template is one of several provided with MMS. Download the template and place it in your `squeezenet` folder. ```bash cp -r multi-model-server/examples/model_service_template/* squeezenet/ ``` **4. Package your model** With the model artifacts available locally, you can use the `model-archiver` CLI to generate a `.mar` file that can be used to serve an inference API with MMS. In this next step we'll run `model-archiver` and tell it our model's prefix is `squeezenet_v1.1` with the `model-name` argument. Then we're giving it the `model-path` to the model's assets. **Note**: For mxnet models, `model-name` must match prefix of the symbol and param file name. ```bash model-archiver --model-name squeezenet_v1.1 --model-path squeezenet --handler mxnet_vision_service:handle ``` This will package all the model artifacts files located in the `squeezenet` directory and output `squeezenet_v1.1.mar` in the current working directory. This `.mar` file is all you need to run MMS, serving inference requests for a simple image recognition API. Go back to the [Serve a Model tutorial](../README.md#serve-a-model) and try to run this model archive that you just created! ================================================ FILE: model-archiver/docs/convert_from_onnx.md ================================================ # Converting an ONNX Model ## Install model-archiver with ONNX support ONNX support is optional in `model-archiver` tool. It's not installed by default with `model-archiver`. 
To install MMS with ONNX support, you will need to have the [protobuf compiler](https://github.com/onnx/onnx#installation) installed: for Ubuntu run: ```bash sudo apt-get install protobuf-compiler libprotoc-dev pip install model-archiver[onnx] ``` Or for Mac run: ```bash conda install -c conda-forge protobuf numpy pip install model-archiver[onnx] ``` MXNet is also required for conversion. You can choose different flavor is mxnet: ```bash pip install mxnet or pip install mxnet-mkl or pip install mxnet-cu90mkl ``` ## ONNX model archive example You can download a model from the [ONNX Model Zoo](https://github.com/onnx/models) then use `model-archiver` to covert it to a `.mar` file. **Note**: Some ONNX model authors upload their models to the zoo in the `.pb` or `.pb2` format. Just change the extension to `.onnx` before attempting to convert. Let's use the SqueezeNet ONNX model as an example. ### Prepare ONNX model and labels To create a model archive for MMS, you can get `.onnx` file and optionally a labels file (synset.txt) from our S3: * [SqueezeNet ONNX model](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/squeezenet.onnx): a `.onnx` model file from the [ONNX Model Zoo](https://github.com/onnx/models) * [label file](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/synset.txt): has the labels for 1,000 ImageNet classes ```bash cd multi-model-server/examples mkdir onnx-squeezenet cd onnx-squeezenet curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/squeezenet.onnx curl -O https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/synset.txt ``` ### Prepare your model custom service code You can implement your own model customer service code as model archive entry point. 
In this example we just copy provided mxnet vision service template: ```bash cd multi-model-server/examples cp -r model_service_template/* onnx-squeezenet/ ``` The mxnet_vision_service.py assume there is a signature.json file that describes input parameter name and shape. You can download example from: [signature file](https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/signature.json). ```bash cd multi-model-server/examples/onnx-squeezenet curl -o signature.json https://s3.amazonaws.com/model-server/model_archive_1.0/examples/onnx-squeezenet/signature.json ``` ### Create a `.mar` file from onnx model The model file in this example contains`.onnx` extension. In order to convert the model with `.onnx` extension to an MXNet model, we would need to use the `-c` option of the model-archiver tool. Now you can use the `model-archiver` command to output `onnx-squeezenet.mar` file. ```bash cd multi-model-server/examples model-archiver --model-name onnx-squeezenet --model-path onnx-squeezenet --handler mxnet_vision_service:handle -c -f ``` Now start the server: ```bash cd multi-model-server multi-model-server --start --model-store examples --models squeezenet=onnx-squeezenet.mar ``` After your server starts, you can use the following command to see the prediction results. ```bash curl -X POST http://127.0.0.1:8080/predictions/squeezenet -T docs/images/kitten_small.jpg ``` ================================================ FILE: model-archiver/model_archiver/__init__.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
# noinspection PyTypeChecker
class ArgParser(object):
    """
    Argument parser for model-export-tool commands
    More detailed example is available at
    https://github.com/awslabs/multi-model-server/blob/master/README.md
    """

    @staticmethod
    def export_model_args_parser():
        """ Argument parser for multi-model-export

        Returns an argparse.ArgumentParser configured with every option the
        ``model-archiver`` CLI accepts. RawTextHelpFormatter is used so the
        newlines embedded in the help strings render as written.
        """
        parser_export = argparse.ArgumentParser(prog='model-archiver',
                                                description='Model Archiver Tool',
                                                formatter_class=argparse.RawTextHelpFormatter)

        parser_export.add_argument('--model-name',
                                   required=True,
                                   type=str,
                                   default=None,
                                   help='Exported model name. Exported file will be named as\n'
                                        'model-name.mar and saved in current working directory if no --export-path is\n'
                                        'specified, else it will be saved under the export path')

        parser_export.add_argument('--model-path',
                                   required=True,
                                   type=str,
                                   default=None,
                                   help='Path to the folder containing model related files.')

        parser_export.add_argument('--handler',
                                   required=True,
                                   dest="handler",
                                   type=str,
                                   default=None,
                                   help='Handler path to handle custom MMS inference logic.')

        # Runtime choices are derived from the RuntimeType enum so new
        # runtimes only need to be added in one place.
        parser_export.add_argument('--runtime',
                                   required=False,
                                   type=str,
                                   default=RuntimeType.PYTHON.value,
                                   choices=[s.value for s in RuntimeType],
                                   help='The runtime specifies which language to run your inference code on.\n'
                                        'The default runtime is "python".')

        parser_export.add_argument('--export-path',
                                   required=False,
                                   type=str,
                                   default=os.getcwd(),
                                   help='Path where the exported .mar file will be saved. This is an optional\n'
                                        'parameter. If --export-path is not specified, the file will be saved in the\n'
                                        'current working directory. ')

        parser_export.add_argument('--archive-format',
                                   required=False,
                                   type=str,
                                   default="default",
                                   choices=["tgz", "no-archive", "default"],
                                   help='The format in which the model artifacts are archived.\n'
                                        '"tgz": This creates the model-archive in .tar.gz format.\n'
                                        'If platform hosting MMS requires model-artifacts to be in ".tar.gz"\n'
                                        'use this option.\n'
                                        '"no-archive": This option creates an non-archived version of model artifacts\n'
                                        'at "export-path/{model-name}" location. As a result of this choice, \n'
                                        'MANIFEST file will be created at "export-path/{model-name}" location\n'
                                        'without archiving these model files\n'
                                        '"default": This creates the model-archive in .mar format.\n'
                                        'This is the default archiving format. Models archived in this format\n'
                                        'will be readily hostable on native MMS.\n')

        # Help-text fix: "will overwritten" -> "will be overwritten".
        parser_export.add_argument('-f', '--force',
                                   required=False,
                                   action='store_true',
                                   help='When the -f or --force flag is specified, an existing .mar file with same\n'
                                        'name as that provided in --model-name in the path specified by --export-path\n'
                                        'will be overwritten')

        # Help-text fix: the original said "an Multi model" -- an artifact of
        # the project-wide MXNet->Multi rename. The -c flow converts an ONNX
        # model into an MXNet model (see docs/convert_from_onnx.md).
        parser_export.add_argument('-c', '--convert',
                                   required=False,
                                   action='store_true',
                                   help='When this option is used, model-archiver looks for special files and tries\n'
                                        'preprocesses them. For example, if this option is chosen when running\n'
                                        'model-archiver tool on a model with ".onnx" extension, the tool will try and\n'
                                        'convert ".onnx" model into an MXNet model.')

        return parser_export
# pylint: disable=missing-docstring
import json
from enum import Enum


class EngineType(Enum):
    MXNET = "MXNet"
    # TODO Add more engines here as and when MMS supports more DL Frameworks


class Engine(object):
    """
    Part of the final manifest.json: records which DL framework (and,
    optionally, which version of it) the inference should run on.
    """

    def __init__(self, engine_name, engine_version=None):
        # EngineType(...) validates the name; an unknown engine raises ValueError.
        self.engine_name = EngineType(engine_name)
        self.engine_version = engine_version
        self.engine_dict = self.__to_dict__()

    def __to_dict__(self):
        serialized = {'engineName': self.engine_name.value}
        if self.engine_version is not None:
            serialized['engineVersion'] = self.engine_version
        return serialized

    def __str__(self):
        return json.dumps(self.engine_dict)

    def __repr__(self):
        return json.dumps(self.engine_dict)
# pylint: disable=redefined-builtin
# pylint: disable=missing-docstring
import json
from enum import Enum


class RuntimeType(Enum):
    PYTHON = "python"
    PYTHON2 = "python2"
    PYTHON3 = "python3"
    # TODO : Add more runtimes here when we support more runtimes such as Java/Go/Scala etc..


class Manifest(object):
    """
    The main manifest object which gets written into the model archive as MANIFEST.json
    """

    def __init__(self, runtime, model, engine=None, specification_version='1.0', implementation_version='1.0',
                 description=None, publisher=None, model_server_version='1.0', license=None, user_data=None):
        # RuntimeType(...) validates the runtime string; unknown values raise ValueError.
        self.runtime = RuntimeType(runtime)
        self.engine = engine
        self.model = model
        self.publisher = publisher
        self.specification_version = specification_version
        self.implementation_version = implementation_version
        self.model_server_version = model_server_version
        self.license = license
        self.description = description
        self.user_data = user_data
        self.manifest_dict = self.__to_dict__()

    def __to_dict__(self):
        # Mandatory entries first; every optional entry is emitted only when set,
        # preserving the original key insertion order.
        serialized = {'runtime': self.runtime.value, 'model': self.model.__to_dict__()}
        if self.engine is not None:
            serialized['engine'] = self.engine.__to_dict__()
        if self.license is not None:
            serialized['license'] = self.license
        if self.model_server_version is not None:
            serialized['modelServerVersion'] = self.model_server_version
        if self.description is not None:
            serialized['description'] = self.description
        if self.implementation_version is not None:
            serialized['implementationVersion'] = self.implementation_version
        if self.specification_version is not None:
            serialized['specificationVersion'] = self.specification_version
        if self.user_data is not None:
            serialized['userData'] = self.user_data
        if self.publisher is not None:
            serialized['publisher'] = self.publisher.__to_dict__()
        return serialized

    def __str__(self):
        return json.dumps(self.manifest_dict, indent=2)

    def __repr__(self):
        return json.dumps(self.manifest_dict, indent=2)
# pylint: disable=missing-docstring
import json


class Model(object):
    """
    Part of manifest.json: the model's identity (name, optional version and
    description) plus the handler entry point into the service code.
    """

    def __init__(self, model_name, handler, description=None, model_version=None, extensions=None):
        self.model_name = model_name
        self.description = description
        self.model_version = model_version
        self.extensions = extensions
        self.handler = handler
        self.model_dict = self.__to_dict__()

    def __to_dict__(self):
        # Mandatory keys first, then the optional ones that were actually provided.
        serialized = {'modelName': self.model_name, 'handler': self.handler}
        for key, value in (('description', self.description),
                           ('modelVersion', self.model_version),
                           ('extensions', self.extensions)):
            if value is not None:
                serialized[key] = value
        return serialized

    def __str__(self):
        return json.dumps(self.model_dict)

    def __repr__(self):
        return json.dumps(self.model_dict)
# pylint: disable=missing-docstring
import json


class Publisher(object):
    """
    Author/email pair embedded in Manifest.json.
    """

    def __init__(self, author, email):
        self.author = author
        self.email = email
        self.pub_dict = self.__to_dict__()

    def __to_dict__(self):
        return {'author': self.author, 'email': self.email}

    def __str__(self):
        return json.dumps(self.pub_dict)

    def __repr__(self):
        return json.dumps(self.pub_dict)
""" Model Archiver Error """ class ModelArchiverError(Exception): """ Error for Model Archiver module """ def __init__(self, message): super(ModelArchiverError, self).__init__(message) ================================================ FILE: model-archiver/model_archiver/model_packaging.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Command line interface to export model files to be used for inference by Multi Model Server """ import logging import sys from .arg_parser import ArgParser from .model_packaging_utils import ModelExportUtils from .model_archiver_error import ModelArchiverError def package_model(args, manifest): """ Internal helper for the exporting model command line interface. 
""" model_path = args.model_path model_name = args.model_name export_file_path = args.export_path temp_files = [] try: ModelExportUtils.validate_inputs(model_path, model_name, export_file_path) # Step 1 : Check if .mar already exists with the given model name export_file_path = ModelExportUtils.check_mar_already_exists(model_name, export_file_path, args.force, args.archive_format) # Step 2 : Check if any special handling is required for custom models like onnx models files_to_exclude = [] if args.convert: t, files_to_exclude = ModelExportUtils.check_custom_model_types(model_path, model_name) temp_files.extend(t) # Step 3 : Zip 'em all up ModelExportUtils.archive(export_file_path, model_name, model_path, files_to_exclude, manifest, args.archive_format) logging.info("Successfully exported model %s to file %s", model_name, export_file_path) except ModelArchiverError as e: logging.error(e) sys.exit(1) finally: ModelExportUtils.clean_temp_files(temp_files) def generate_model_archive(): """ Generate a model archive file :return: """ logging.basicConfig(format='%(levelname)s - %(message)s') args = ArgParser.export_model_args_parser().parse_args() manifest = ModelExportUtils.generate_manifest_json(args) package_model(args, manifest=manifest) if __name__ == '__main__': generate_model_archive() ================================================ FILE: model-archiver/model_archiver/model_packaging_utils.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. 
"""
Helper utils for Model Export tool
"""

import json
import logging
import os
import re
import zipfile
import shutil

from .model_archiver_error import ModelArchiverError
from .manifest_components.engine import Engine
from .manifest_components.manifest import Manifest
from .manifest_components.model import Model
from .manifest_components.publisher import Publisher

# Maps the user-facing --archive-format choice to the produced file suffix.
archiving_options = {
    "tgz": ".tar.gz",
    "no-archive": "",
    "default": ".mar"
}

MODEL_SERVER_VERSION = '1.0'
MODEL_ARCHIVE_VERSION = '1.0'
MANIFEST_FILE_NAME = 'MANIFEST.json'
MAR_INF = 'MAR-INF'
ONNX_TYPE = '.onnx'


class ModelExportUtils(object):
    """
    Helper utils for Model Archiver tool.
    This class lists out all the methods such as validations for model archiving,
    ONNX model checking etc.
    """

    @staticmethod
    def get_archive_export_path(export_file_path, model_name, archive_format):
        """
        Return the full destination path of the archive, e.g. <export>/<name>.mar.

        :param export_file_path: directory the archive is written to
        :param model_name: archive base name
        :param archive_format: one of "tgz", "no-archive", "default"
        :return: joined path including the format-specific suffix
        """
        return os.path.join(export_file_path, '{}{}'.format(model_name, archiving_options.get(archive_format)))

    @staticmethod
    def check_mar_already_exists(model_name, export_file_path, overwrite, archive_format="default"):
        """
        Function to check if .mar already exists.

        :param model_name:
        :param export_file_path: may be None, in which case the CWD is used
        :param overwrite: when True an existing archive is overwritten (with a warning)
        :param archive_format:
        :return: the resolved export directory path
        :raises ModelArchiverError: if the archive exists and overwrite is False
        """
        if export_file_path is None:
            export_file_path = os.getcwd()

        export_file = ModelExportUtils.get_archive_export_path(export_file_path, model_name, archive_format)

        if os.path.exists(export_file):
            if overwrite:
                logging.warning("Overwriting %s ...", export_file)
            else:
                # BUG FIX: the original message concatenated the path after the
                # final sentence and left a literal "%s" unfilled; format it properly.
                raise ModelArchiverError("{} already exists.\n"
                                         "Please specify --force/-f option to overwrite the model archive "
                                         "output file.\n"
                                         "See -h/--help for more details.".format(export_file))

        return export_file_path

    @staticmethod
    def check_custom_model_types(model_path, model_name=None):
        """
        This functions checks whether any special handling is required for custom model
        extensions such as .onnx, or in the future, for Tensorflow and PyTorch extensions.

        :param model_path:
        :param model_name:
        :return: (temp_files, files_to_exclude) -- conversion byproducts to delete
                 afterwards, and source files that must not be packaged
        """
        temp_files = []  # List of temp files added to handle custom models
        files_to_exclude = []  # List of files to be excluded from .mar packaging.

        files_set = set(os.listdir(model_path))
        onnx_file = ModelExportUtils.find_unique(files_set, ONNX_TYPE)
        if onnx_file is not None:
            logging.debug("Found ONNX files. Converting ONNX file to model archive...")
            symbol_file, params_file = ModelExportUtils.convert_onnx_model(model_path, onnx_file, model_name)
            files_to_exclude.append(onnx_file)
            temp_files.append(os.path.join(model_path, symbol_file))
            temp_files.append(os.path.join(model_path, params_file))

        # More cases will go here as an if-else block

        return temp_files, files_to_exclude

    @staticmethod
    def find_unique(files, suffix):
        """
        Function to find unique model params file.

        :param files: iterable of file names
        :param suffix: required file-name suffix
        :return: the single matching file name, or None when there is no match
        :raises ModelArchiverError: when more than one file matches
        """
        match = [f for f in files if f.endswith(suffix)]
        count = len(match)

        if count == 0:
            return None
        elif count == 1:
            return match[0]
        else:
            raise ModelArchiverError("model-archiver expects only one {} file in the folder."
                                     " Found {} files {} in model-path.".format(suffix, count, match))

    @staticmethod
    def convert_onnx_model(model_path, onnx_file, model_name):
        """
        Util to convert onnx model to MXNet model.

        :param model_path:
        :param onnx_file:
        :param model_name:
        :return: (symbol_file, params_file) names written into model_path
        :raises ModelArchiverError: when mxnet or onnx is not installed
        """
        try:
            import mxnet as mx
            from mxnet.contrib import onnx as onnx_mxnet
        except ImportError:
            raise ModelArchiverError("MXNet package is not installed. Run command: pip install mxnet to install it.")

        try:
            import onnx
        except ImportError:
            raise ModelArchiverError("Onnx package is not installed. Run command: pip install onnx to install it.")

        symbol_file = '%s-symbol.json' % model_name
        params_file = '%s-0000.params' % model_name
        signature_file = 'signature.json'

        # Find input symbol name and shape
        try:
            model_proto = onnx.load(os.path.join(model_path, onnx_file))
        except Exception:  # narrowed from a bare except; error is logged then re-raised
            logging.error("Failed to load the %s model. Verify if the model file is valid", onnx_file)
            raise

        graph = model_proto.graph
        _params = set()
        for tensor_vals in graph.initializer:
            _params.add(tensor_vals.name)

        input_data = []
        for graph_input in graph.input:
            shape = []
            # Graph inputs that are also initializers are weights, not data inputs.
            if graph_input.name not in _params:
                for val in graph_input.type.tensor_type.shape.dim:
                    shape.append(val.dim_value)
                input_data.append((graph_input.name, tuple(shape)))

        try:
            sym, arg_params, aux_params = onnx_mxnet.import_model(os.path.join(model_path, onnx_file))
            # UNION of argument and auxillary parameters
            params = dict(arg_params, **aux_params)
        except Exception:
            logging.error("Failed to import %s file to onnx. Verify if the model file is valid", onnx_file)
            raise

        try:
            # rewrite input data_name correctly
            with open(os.path.join(model_path, signature_file), 'r') as f:
                data = json.loads(f.read())
                data['inputs'][0]['data_name'] = input_data[0][0]
                data['inputs'][0]['data_shape'] = [int(i) for i in input_data[0][1]]
            with open(os.path.join(model_path, signature_file), 'w') as f:
                f.write(json.dumps(data, indent=2))

            with open(os.path.join(model_path, symbol_file), 'w') as f:
                f.write(sym.tojson())
        except Exception:
            logging.error("Failed to write the signature or symbol files for %s model", onnx_file)
            raise

        save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in params.items()}
        mx.nd.save(os.path.join(model_path, params_file), save_dict)
        return symbol_file, params_file

    @staticmethod
    def generate_publisher(publisherargs):
        """Build the Publisher manifest component from parsed args."""
        publisher = Publisher(author=publisherargs.author, email=publisherargs.email)
        return publisher

    @staticmethod
    def generate_engine(engineargs):
        """Build the Engine manifest component from parsed args."""
        engine = Engine(engine_name=engineargs.engine)
        return engine

    @staticmethod
    def generate_model(modelargs):
        """Build the Model manifest component from parsed args."""
        model = Model(model_name=modelargs.model_name, handler=modelargs.handler)
        return model

    @staticmethod
    def generate_manifest_json(args):
        """
        Function to generate manifest as a json string from the inputs provided by the
        user in the command line.

        :param args: parsed command-line namespace
        :return: the manifest serialized as a JSON string
        """
        arg_dict = vars(args)

        publisher = ModelExportUtils.generate_publisher(args) if 'author' in arg_dict and 'email' in arg_dict else None
        engine = ModelExportUtils.generate_engine(args) if 'engine' in arg_dict else None
        model = ModelExportUtils.generate_model(args)

        manifest = Manifest(runtime=args.runtime, model=model, engine=engine, publisher=publisher)
        return str(manifest)

    @staticmethod
    def clean_temp_files(temp_files):
        """Delete the temporary files produced during conversion."""
        for f in temp_files:
            os.remove(f)

    @staticmethod
    def make_dir(d):
        """Create directory d (and parents) if it does not already exist."""
        if not os.path.isdir(d):
            os.makedirs(d)

    @staticmethod
    def archive(export_file, model_name, model_path, files_to_exclude, manifest, archive_format="default"):
        """
        Create a model-archive.

        :param export_file: export directory
        :param model_name:
        :param model_path: directory containing the model artifacts
        :param files_to_exclude: file names to leave out of the archive
        :param manifest: manifest JSON string to embed under MAR-INF/
        :param archive_format: "tgz", "no-archive" or "default" (.mar)
        """
        mar_path = ModelExportUtils.get_archive_export_path(export_file, model_name, archive_format)
        try:
            if archive_format == "tgz":
                import tarfile
                from io import BytesIO
                with tarfile.open(mar_path, 'w:gz') as z:
                    ModelExportUtils.archive_dir(model_path, z, set(files_to_exclude),
                                                 archive_format, model_name)
                    # Write the manifest here now as a json
                    tar_manifest = tarfile.TarInfo(name=os.path.join(model_name, MAR_INF, MANIFEST_FILE_NAME))
                    tar_manifest.size = len(manifest.encode('utf-8'))
                    z.addfile(tarinfo=tar_manifest, fileobj=BytesIO(manifest.encode()))
                    # note: the redundant z.close() inside the with-block was removed
            elif archive_format == "no-archive":
                if model_path != mar_path:
                    # Copy files to export path
                    ModelExportUtils.archive_dir(model_path, mar_path, set(files_to_exclude),
                                                 archive_format, model_name)
                # Write the MANIFEST in place
                manifest_path = os.path.join(mar_path, MAR_INF)
                ModelExportUtils.make_dir(manifest_path)
                with open(os.path.join(manifest_path, MANIFEST_FILE_NAME), "w") as f:
                    f.write(manifest)
            else:
                with zipfile.ZipFile(mar_path, 'w', zipfile.ZIP_DEFLATED) as z:
                    ModelExportUtils.archive_dir(model_path, z, set(files_to_exclude),
                                                 archive_format, model_name)
                    # Write the manifest here now as a json
                    z.writestr(os.path.join(MAR_INF, MANIFEST_FILE_NAME), manifest)
        except IOError:
            logging.error("Failed to save the model-archive to model-path \"%s\". "
                          "Check the file permissions and retry.", export_file)
            raise
        except Exception:
            logging.error("Failed to convert %s to the model-archive.", model_name)
            raise

    @staticmethod
    def archive_dir(path, dst, files_to_exclude, archive_format, model_name):
        """
        This method zips the dir and filters out some files based on a expression.

        :param path: source directory walked recursively
        :param dst: tarfile / destination dir / zipfile depending on archive_format
        :param files_to_exclude:
        :param archive_format:
        :param model_name:
        """
        unwanted_dirs = {'__MACOSX', '__pycache__'}

        for root, directories, files in os.walk(path):
            # In-place slice assignment so os.walk skips the filtered directories.
            directories[:] = [d for d in directories if ModelExportUtils.directory_filter(d, unwanted_dirs)]
            files[:] = [f for f in files if ModelExportUtils.file_filter(f, files_to_exclude)]
            for f in files:
                file_path = os.path.join(root, f)
                if archive_format == "tgz":
                    dst.add(file_path, arcname=os.path.join(model_name, os.path.relpath(file_path, path)))
                elif archive_format == "no-archive":
                    dst_dir = os.path.dirname(os.path.join(dst, os.path.relpath(file_path, path)))
                    ModelExportUtils.make_dir(dst_dir)
                    shutil.copy(file_path, dst_dir)
                else:
                    dst.write(file_path, os.path.relpath(file_path, path))

    @staticmethod
    def directory_filter(directory, unwanted_dirs):
        """
        This method weeds out unwanted hidden directories from the model archive .mar file.

        :param directory:
        :param unwanted_dirs:
        :return: True when the directory should be kept
        """
        if directory in unwanted_dirs:
            return False
        if directory.startswith('.'):
            return False
        return True

    @staticmethod
    def file_filter(current_file, files_to_exclude):
        """
        This method weeds out unwanted files.

        :param current_file:
        :param files_to_exclude: NOTE -- mutated: 'MANIFEST.json' is added to the set
        :return: True when the file should be kept
        """
        files_to_exclude.add('MANIFEST.json')
        if current_file in files_to_exclude:
            return False
        elif current_file.endswith(('.pyc', '.DS_Store', '.mar')):
            return False
        return True

    @staticmethod
    def check_model_name_regex_or_exit(model_name):
        """
        Method checks whether model name passes regex filter.
        If the regex filter fails, a ModelArchiverError is raised.

        :param model_name:
        :raises ModelArchiverError: on names with disallowed characters
        """
        if not re.match(r'^[A-Za-z0-9][A-Za-z0-9_\-.]*$', model_name):
            raise ModelArchiverError("Model name contains special characters.\n"
                                     "The allowed regular expression filter for model "
                                     "name is: ^[A-Za-z0-9][A-Za-z0-9_\\-.]*$")

    @staticmethod
    def validate_inputs(model_path, model_name, export_path):
        """
        Validate the user-supplied name and paths before archiving starts.

        :raises ModelArchiverError: on a bad model name or missing directory
        """
        ModelExportUtils.check_model_name_regex_or_exit(model_name)
        if not os.path.isdir(os.path.abspath(export_path)):
            raise ModelArchiverError("Given export-path {} is not a directory. "
                                     "Point to a valid export-path directory.".format(export_path))

        if not os.path.isdir(os.path.abspath(model_path)):
            raise ModelArchiverError("Given model-path {} is not a valid directory. "
                                     "Point to a valid model-path directory.".format(model_path))
"service:handle", "exportPath": "/tmp/model", "archiveFormat": "default", "iterations": 2, "force": true } ] ================================================ FILE: model-archiver/model_archiver/tests/integ_tests/resources/onnx_model/service.py ================================================ """ This is a dummy source fiile """ def handle(): return "Dummy model" ================================================ FILE: model-archiver/model_archiver/tests/integ_tests/resources/regular_model/dir/1.py ================================================ __version__ = 1 ================================================ FILE: model-archiver/model_archiver/tests/integ_tests/resources/regular_model/dummy-artifacts.txt ================================================ Some random text. ================================================ FILE: model-archiver/model_archiver/tests/integ_tests/resources/regular_model/service.py ================================================ """ This is a dummy source fiile """ def handle(): return "Dummy model" ================================================ FILE: model-archiver/model_archiver/tests/integ_tests/test_integration_model_archiver.py ================================================ import errno import json import os import shutil import subprocess import requests DEFAULT_MODEL_PATH = "model_archiver/tests/integ_tests/resources/regular_model" DEFAULT_HANDLER = "service:handle" DEFAULT_RUNTIME = "python" DEFAULT_MODEL_NAME = "model" DEFAULT_EXPORT_PATH = "/tmp/model" MANIFEST_FILE = "MAR-INF/MANIFEST.json" def update_tests(test): test["modelName"] = test.get("modelName", DEFAULT_MODEL_NAME) test["modelPath"] = test.get("modelPath", DEFAULT_MODEL_PATH) test["handler"] = test.get("handler", DEFAULT_HANDLER) test["runtime"] = test.get("runtime", DEFAULT_RUNTIME) test["exportPath"] = test.get("exportPath", DEFAULT_EXPORT_PATH) test["archiveFormat"] = test.get("archiveFormat", "default") return test def create_file_path(path): try: 
os.makedirs(path) except OSError as exc: if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise def delete_file_path(path): try: if os.path.isfile(path): os.remove(path) if os.path.isdir(path): shutil.rmtree(path) except OSError: pass def run_test(test, cmd): it = test.get("iterations") if test.get("iterations") is not None else 1 for i in range(it): try: subprocess.check_call(cmd, shell=True) except subprocess.CalledProcessError as exc: if test.get("expectError") is not True: assert 0, "{}".format(exc.output) else: return 0 return 1 def validate_archive_exists(test): fmt = test.get("archiveFormat") if fmt == "tgz": assert os.path.isfile(os.path.join(test.get("exportPath"), test.get("modelName")+".tar.gz")) elif fmt == "no-archive": assert os.path.isdir(os.path.join(test.get("exportPath"), test.get("modelName"))) else: assert os.path.isfile(os.path.join(test.get("exportPath"), test.get("modelName")+".mar")) def validate_manifest_file(manifest, test): """ Validate the MANIFEST file :param manifest: :param test: :return: """ assert manifest.get("runtime") == test.get("runtime") assert manifest.get("model").get("modelName") == test.get("modelName") assert manifest.get("model").get("handler") == test.get("handler") def validate_files(file_list, prefix, regular): assert os.path.join(prefix, MANIFEST_FILE) in file_list assert os.path.join(prefix, "service.py") in file_list if regular: assert os.path.join(prefix, "dummy-artifacts.txt") in file_list assert os.path.join(prefix, "dir/1.py") in file_list else: assert os.path.join(prefix, "model.onnx") in file_list def validate_tar_archive(test_cfg): import tarfile file_name = os.path.join(test_cfg.get("exportPath"), test_cfg.get("modelName") + ".tar.gz") f = tarfile.open(file_name, "r:gz") manifest = json.loads(f.extractfile(os.path.join(test_cfg.get("modelName"), MANIFEST_FILE)).read()) validate_manifest_file(manifest, test_cfg) validate_files(f.getnames(), test_cfg.get("modelName"), "regular_model" in 
test_cfg.get("modelPath")) def validate_noarchive_archive(test): file_name = os.path.join(test.get("exportPath"), test.get("modelName"), MANIFEST_FILE) manifest = json.loads(open(file_name).read()) validate_manifest_file(manifest, test) def validate_mar_archive(test): import zipfile file_name = os.path.join(test.get("exportPath"), test.get("modelName") + ".mar") zf = zipfile.ZipFile(file_name, "r") manifest = json.loads(zf.open(MANIFEST_FILE).read()) validate_manifest_file(manifest, test) def validate_archive_content(test): fmt = test.get("archiveFormat") if fmt == "tgz": validate_tar_archive(test) if fmt == "no-archive": validate_noarchive_archive(test) if fmt == "default": validate_mar_archive(test) def validate(test): validate_archive_exists(test) validate_archive_content(test) def test_model_archiver(): f = open("model_archiver/tests/integ_tests/configuration.json", "r") tests = json.loads(f.read()) for t in tests: try: delete_file_path(t.get("exportPath")) create_file_path(t.get("exportPath")) t = update_tests(t) cmd = "model-archiver " \ "--model-name {} " \ "--model-path {} " \ "--handler {} " \ "--runtime {} " \ "--export-path {} " \ "--archive-format {}".format(t.get("modelName"), t.get("modelPath"), t.get("handler"), t.get("runtime"), t.get("exportPath"), t.get("archiveFormat")) if t.get("force"): cmd += " -f" # TODO: Add tests to check for "convert" functionality if run_test(t, cmd): validate(t) finally: delete_file_path(t.get("exportPath")) if __name__ == "__main__": test_model_archiver() ================================================ FILE: model-archiver/model_archiver/tests/pylintrc ================================================ [MASTER] # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. 
The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore=CVS # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=8 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist=numpy,opencv # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. 
It has one side effect, the resulting # AST will be different than the one from reality. This option is deprecated # and it will be removed in Pylint 2.0. optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable=indexing-exception,old-raise-syntax # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,superfluous-parens,invalid-name,no-else-return,useless-super-delegation,len-as-condition,invalid-unary-operand-type,useless-object-inheritance # disable=unicode-builtin,delslice-method,using-cmp-argument,setslice-method,dict-view-method,parameter-unpacking,range-builtin-not-iterating,print-statement,file-builtin,old-raise-syntax,basestring-builtin,execfile-builtin,indexing-exception,import-star-module-level,coerce-method,long-builtin,old-ne-operator,old-division,no-absolute-import,raw_input-builtin,old-octal-literal,oct-method,xrange-builtin,hex-method,unpacking-in-except,nonzero-method,raising-string,intern-builtin,reload-builtin,metaclass-assignment,cmp-method,filter-builtin-not-iterating,apply-builtin,map-builtin-not-iterating,next-method-called,unichr-builtin,buffer-builtin,dict-iter-method,input-builtin,coerce-builtin,getslice-method,useless-suppression,standarderror-builtin,zip-builtin-not-iterating,suppressed-message,cmp-builtin,backtick,long-suffix,reduce-builtin,round-builtin [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. 
files-output=no

# Tells whether to display a full report or only the messages
reports=no

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=


[FORMAT]

# Maximum number of characters on a single line.
max-line-length=120

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,dict-separator

# Maximum number of lines in a module
max-module-lines=1000

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=


[SPELLING]

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=

# List of comma separated words that should not be checked.
spelling-ignore-words=

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [VARIABLES] # Tells whether we should check for unused import in __init__ files. 
init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,future.builtins [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=i,j,_,a,b,op,x,y,wd,lr,kv,k,v,s,p,h,c,m,n,X,t,g,f # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. 
property-classes=abc.abstractproperty # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct constant names const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for method names method-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for argument names argument-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for attribute names attr-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for variable names variable-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct function names function-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for function names function-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. 
docstring-min-length=10 [ELIF] # Maximum number of nested blocks for function / method body max-nested-blocks=5 [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=optparse # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [DESIGN] # Maximum number of arguments for function / method max-args=5 # Argument names that match this expression will be ignored. 
Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=15 # Maximum number of return / yield for function / method body max-returns=6 # Maximum number of branch for function / method body max-branches=12 # Maximum number of statements in function / method body max-statements=50 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Minimum number of public methods for a class (see R0903). min-public-methods=2 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of boolean expressions in a if statement max-bool-expr=5 [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception" overgeneral-exceptions=Exception ================================================ FILE: model-archiver/model_archiver/tests/unit_tests/test_model_packaging.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
from collections import namedtuple import pytest from model_archiver.manifest_components.engine import EngineType from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_packaging import generate_model_archive, package_model from model_archiver.model_packaging_utils import ModelExportUtils # noinspection PyClassHasNoInit class TestModelPackaging: class Namespace: def __init__(self, **kwargs): self.__dict__.update(kwargs) def update(self, **kwargs): self.__dict__.update(kwargs) author = 'ABC' email = 'ABC@XYZ.com' engine = EngineType.MXNET.value model_name = 'my-model' model_path = 'my-model/' handler = 'a.py::my-awesome-func' export_path = '/Users/dummyUser/' args = Namespace(author=author, email=email, engine=engine, model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value, model_path=model_path, export_path=export_path, force=False, archive_format="default", convert=False) @pytest.fixture() def patches(self, mocker): Patches = namedtuple('Patches', ['arg_parse', 'export_utils', 'export_method']) patches = Patches(mocker.patch('model_archiver.model_packaging.ArgParser'), mocker.patch('model_archiver.model_packaging.ModelExportUtils'), mocker.patch('model_archiver.model_packaging.package_model')) return patches def test_gen_model_archive(self, patches): patches.arg_parse.export_model_args_parser.parse_args.return_value = self.args generate_model_archive() patches.export_method.assert_called() def test_export_model_method(self, patches): patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() patches.export_utils.clean_temp_files.assert_called() def 
test_export_model_method_tar(self, patches): self.args.update(archive_format="tar") patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() patches.export_utils.clean_temp_files.assert_called() def test_export_model_method_noarchive(self, patches): self.args.update(archive_format="no-archive") patches.export_utils.check_mar_already_exists.return_value = '/Users/dummyUser/' patches.export_utils.check_custom_model_types.return_value = '/Users/dummyUser', ['a.txt', 'b.txt'] patches.export_utils.zip.return_value = None package_model(self.args, ModelExportUtils.generate_manifest_json(self.args)) patches.export_utils.validate_inputs.assert_called() patches.export_utils.archive.assert_called() patches.export_utils.clean_temp_files.assert_called() ================================================ FILE: model-archiver/model_archiver/tests/unit_tests/test_model_packaging_utils.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
import json import os import pytest from collections import namedtuple from model_archiver.model_packaging_utils import ModelExportUtils from model_archiver.manifest_components.engine import EngineType from model_archiver.manifest_components.manifest import RuntimeType from model_archiver.model_archiver_error import ModelArchiverError # noinspection PyClassHasNoInit class TestExportModelUtils: # noinspection PyClassHasNoInit class TestMarExistence: @pytest.fixture() def patches(self, mocker): Patches = namedtuple('Patches', ['getcwd', 'path_exists']) patches = Patches(mocker.patch('os.getcwd'), mocker.patch('os.path.exists')) patches.getcwd.return_value = '/Users/dummyUser' return patches def test_export_file_is_none(self, patches): patches.path_exists.return_value = False ret_val = ModelExportUtils.check_mar_already_exists('some-model', None, False) patches.path_exists.assert_called_once_with("/Users/dummyUser/some-model.mar") assert ret_val == "/Users/dummyUser" def test_export_file_is_not_none(self, patches): patches.path_exists.return_value = False ModelExportUtils.check_mar_already_exists('some-model', '/Users/dummyUser/', False) patches.path_exists.assert_called_once_with('/Users/dummyUser/some-model.mar') def test_export_file_already_exists_with_override(self, patches): patches.path_exists.return_value = True ModelExportUtils.check_mar_already_exists('some-model', None, True) patches.path_exists.assert_called_once_with('/Users/dummyUser/some-model.mar') def test_export_file_already_exists_with_override_false(self, patches): patches.path_exists.return_value = True with pytest.raises(ModelArchiverError): ModelExportUtils.check_mar_already_exists('some-model', None, False) patches.path_exists.assert_called_once_with('/Users/dummyUser/some-model.mar') def test_export_file_is_none_tar(self, patches): patches.path_exists.return_value = False ret_val = ModelExportUtils.check_mar_already_exists('some-model', None, False, archive_format='tgz') 
patches.path_exists.assert_called_once_with("/Users/dummyUser/some-model.tar.gz") assert ret_val == "/Users/dummyUser" def test_export_file_is_none_tar(self, patches): patches.path_exists.return_value = False ret_val = ModelExportUtils.check_mar_already_exists('some-model', None, False, archive_format='no-archive') patches.path_exists.assert_called_once_with("/Users/dummyUser/some-model") assert ret_val == "/Users/dummyUser" # noinspection PyClassHasNoInit class TestArchiveTypes: def test_archive_types(self): from model_archiver.model_packaging_utils import archiving_options as ar_opts assert ar_opts.get("tgz") == ".tar.gz" assert ar_opts.get("no-archive") == "" assert ar_opts.get("default") == ".mar" assert len(ar_opts) == 3 # noinspection PyClassHasNoInit class TestCustomModelTypes: model_path = '/Users/dummyUser' @pytest.fixture() def patches(self, mocker): Patches = namedtuple('Patches', ['utils', 'listdir']) patch = Patches(mocker.patch('model_archiver.model_packaging_utils.ModelExportUtils'), mocker.patch('os.listdir')) patch.listdir.return_value = {'a', 'b', 'c'} return patch def test_onnx_file_is_none(self, patches): patches.utils.find_unique.return_value = None ModelExportUtils.check_custom_model_types(model_path=self.model_path, model_name=None) patches.utils.find_unique.assert_called() patches.utils.convert_onnx_model.assert_not_called() def test_onnx_file_is_not_none(self, patches): onnx_file = 'some-file.onnx' patches.utils.find_unique.return_value = onnx_file patches.utils.convert_onnx_model.return_value = ('sym', 'param') temp, exclude = ModelExportUtils.check_custom_model_types(self.model_path) patches.utils.convert_onnx_model.assert_called_once_with(self.model_path, onnx_file, None) assert len(temp) == 2 assert len(exclude) == 1 assert temp[0] == os.path.join(self.model_path, 'sym') assert temp[1] == os.path.join(self.model_path, 'param') assert exclude[0] == onnx_file # noinspection PyClassHasNoInit class TestFindUnique: def 
test_with_count_zero(self): files = ['a.txt', 'b.txt', 'c.txt'] suffix = '.mxnet' val = ModelExportUtils.find_unique(files, suffix) assert val is None def test_with_count_one(self): files = ['a.mxnet', 'b.txt', 'c.txt'] suffix = '.mxnet' val = ModelExportUtils.find_unique(files, suffix) assert val == 'a.mxnet' def test_with_exit(self): files = ['a.onnx', 'b.onnx', 'c.txt'] suffix = '.onnx' with pytest.raises(ModelArchiverError): ModelExportUtils.find_unique(files, suffix) # noinspection PyClassHasNoInit class TestCleanTempFiles: @pytest.fixture() def patches(self, mocker): Patches = namedtuple('Patches', ['remove']) patches = Patches(mocker.patch('os.remove')) patches.remove.return_value = True return patches def test_clean_call(self, patches): temp_files = ['a', 'b', 'c'] ModelExportUtils.clean_temp_files(temp_files) patches.remove.assert_called() assert patches.remove.call_count == len(temp_files) # noinspection PyClassHasNoInit class TestGenerateManifestProps: class Namespace: def __init__(self, **kwargs): self.__dict__.update(kwargs) author = 'ABC' email = 'ABC@XYZ.com' engine = EngineType.MXNET.value model_name = 'my-model' handler = 'a.py::my-awesome-func' args = Namespace(author=author, email=email, engine=engine, model_name=model_name, handler=handler, runtime=RuntimeType.PYTHON.value) def test_publisher(self): pub = ModelExportUtils.generate_publisher(self.args) assert pub.email == self.email assert pub.author == self.author def test_engine(self): eng = ModelExportUtils.generate_engine(self.args) assert eng.engine_name == EngineType(self.engine) def test_model(self): mod = ModelExportUtils.generate_model(self.args) assert mod.model_name == self.model_name assert mod.handler == self.handler def test_manifest_json(self): manifest = ModelExportUtils.generate_manifest_json(self.args) manifest_json = json.loads(manifest) assert manifest_json['runtime'] == RuntimeType.PYTHON.value assert 'engine' in manifest_json assert 'model' in manifest_json assert 
'publisher' in manifest_json assert 'license' not in manifest_json # noinspection PyClassHasNoInit class TestModelNameRegEx: def test_regex_pass(self): model_names = ['my-awesome-model', 'Aa.model', 'a', 'aA.model', 'a1234.model', 'a-A-A.model', '123-abc'] for m in model_names: ModelExportUtils.check_model_name_regex_or_exit(m) def test_regex_fail(self): model_names = ['abc%', '123$abc', 'abc!123', '@123', '(model', 'mdoel)', '12*model-a.model', '##.model', '-.model'] for m in model_names: with pytest.raises(ModelArchiverError): ModelExportUtils.check_model_name_regex_or_exit(m) # noinspection PyClassHasNoInit class TestFileFilter: files_to_exclude = {'abc.onnx'} def test_with_return_false(self): assert ModelExportUtils.file_filter('abc.onnx', self.files_to_exclude) is False def test_with_pyc(self): assert ModelExportUtils.file_filter('abc.pyc', self.files_to_exclude) is False def test_with_ds_store(self): assert ModelExportUtils.file_filter('.DS_Store', self.files_to_exclude) is False def test_with_return_true(self): assert ModelExportUtils.file_filter('abc.mxnet', self.files_to_exclude) is True # noinspection PyClassHasNoInit class TestDirectoryFilter: unwanted_dirs = {'__MACOSX', '__pycache__'} def test_with_unwanted_dirs(self): assert ModelExportUtils.directory_filter('__MACOSX', self.unwanted_dirs) is False def test_with_starts_with_dot(self): assert ModelExportUtils.directory_filter('.gitignore', self.unwanted_dirs) is False def test_with_return_true(self): assert ModelExportUtils.directory_filter('my-model', self.unwanted_dirs) is True ================================================ FILE: model-archiver/model_archiver/tests/unit_tests/test_version.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. 
# A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. import os import model_archiver def test_model_export_tool_version(): """ Test the model archive version :return: """ with (open(os.path.join('model_archiver', 'version.py'))) as f: exec(f.read(), globals()) assert __version__ == str(model_archiver.__version__), "Versions do not match" ================================================ FILE: model-archiver/model_archiver/version.py ================================================ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ This is the current version of Model Archiver Tool """ __version__ = '1.0.4' ================================================ FILE: model-archiver/setup.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # To build and upload a new version, follow the steps below. # Notes: # - this is a "Universal Wheels" package that is pure Python and supports both Python2 and Python3 # - Twine is a secure PyPi upload package # - Make sure you have bumped the version! at mms/version.py # $ pip install twine # $ pip install wheel # $ python setup.py bdist_wheel --universal # *** TEST YOUR PACKAGE WITH TEST PI ****** # twine upload --repository-url https://test.pypi.org/legacy/ dist/* # If this is successful then push it to actual pypi # $ twine upload dist/* """ Setup.py for the model-archiver tool """ from datetime import date import sys from setuptools import setup, find_packages # pylint: disable = relative-import import model_archiver pkgs = find_packages() def pypi_description(): """Imports the long description for the project page""" with open('PyPiDescription.rst') as df: return df.read() def detect_model_archiver_version(): if "--release" in sys.argv: sys.argv.remove("--release") # pylint: disable = relative-import return model_archiver.__version__.strip() # pylint: disable = relative-import return model_archiver.__version__.strip() + 'b' + str(date.today()).replace('-', '') if __name__ == '__main__': version = detect_model_archiver_version() requirements = ['future', 'enum-compat'] setup( name='model-archiver', version=version, description='Model Archiver is used for creating archives of trained neural net models that can be consumed ' 'by Multi-Model-Server inference', long_description=pypi_description(), author='Trinity', author_email='noreply@amazon.com', url='https://github.com/awslabs/multi-model-server/model-archiver/', keywords='Multi Model Archive Archiver MMS Server Serving Deep Learning Inference AI', packages=pkgs, 
install_requires=requirements, extras_require={ 'mxnet-mkl': ['mxnet-mkl==1.3.1'], 'mxnet-cu90mkl': ['mxnet-cu90mkl==1.3.1'], 'mxnet-cu92mkl': ['mxnet-cu92mkl==1.3.1'], 'mxnet': ['mxnet==1.3.1'], 'onnx': ['onnx==1.1.1'] }, entry_points={ 'console_scripts': ['model-archiver=model_archiver.model_packaging:generate_model_archive'] }, include_package_data=True, license='Apache License Version 2.0' ) ================================================ FILE: performance_regression/imageInputModelPlan.jmx.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 5s hold-for: 1m scenario: Inference scenarios: Inference: default-address: ${__P(protocol,https)}://${__P(hostname,127.0.0.1)}:${__P(port,8443)}/ requests: - follow-redirects: true label: Inference Request method: POST upload-files: - mime-type: image/jpeg param: data path: ${__P(input_filepath)} url: ${__P(protocol,http)}://${__P(hostname,127.0.0.1)}:${__P(port,8080)}/predictions/${model} store-cache: false store-cookie: false use-dns-cache-mgr: false variables: cnn_url: ${__P(url, https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model)} model: ${__P(model_name,squeezenet_v1.1)} scale_down_workers: '0' scale_up_workers: ${__P(min_workers,1)} modules: jmeter: properties: input_filepath : kitten.jpg model_name : squeezenet services: - module: monitoring local: - interval: 2s logging: True metrics: - cpu - disk-space - mem reporting: - module: passfail criteria: - class: bzt.modules.monitoring.MonitoringCriteria subject: local/cpu condition: '>' threshold: 100 timeframe: 6s - module: junit-xml filename: ${TAURUS_ARTIFACTS_DIR}/output/results.xml data-source: pass-fail ================================================ FILE: plugins/build.gradle ================================================ /* * This file was generated by the Gradle 'init' task. * * This generated file contains a sample Java Library project to get you started. 
* For more details take a look at the Java Libraries chapter in the Gradle * User Manual available at https://docs.gradle.org/5.4.1/userguide/java_library_plugin.html */ allprojects { apply plugin: 'idea' apply plugin: 'java' version = '1.0' repositories { jcenter() } idea { module { outputDir = file('build/classes/java/main') testOutputDir = file('build/classes/java/test') } } task buildSagemaker("type": Jar) { doFirst{ task -> println "building $task.project.name" } with project.jar doLast { copy { def fromDir = project.jar def intoDir = "${rootProject.projectDir}/build/plugins" from fromDir into intoDir println "Copying files from" + fromDir + " into " + intoDir } } } buildSagemaker.onlyIf {project.hasProperty("sagemaker")} } def javaProjects() { return subprojects.findAll() } configure(javaProjects()) { sourceCompatibility = 1.8 targetCompatibility = 1.8 defaultTasks 'jar' apply from: file("${rootProject.projectDir}/tools/gradle/formatter.gradle") apply from: file("${rootProject.projectDir}/tools/gradle/check.gradle") test { useTestNG() { // suiteXmlFiles << new File(rootDir, "testng.xml") //This is how to add custom testng.xml } testLogging { showStandardStreams = true events "passed", "skipped", "failed", "standardOut", "standardError" } } test.finalizedBy(project.tasks.jacocoTestReport) compileJava { options.compilerArgs << "-Xlint:all,-options,-static" << "-Werror" } jacocoTestCoverageVerification { violationRules { rule { limit { minimum = 0.75 } } } } } ================================================ FILE: plugins/endpoints/build.gradle ================================================ dependencies { compile "com.google.code.gson:gson:${gson_version}" compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}" } project.ext{ sagemaker = true } jar { includeEmptyDirs = false exclude "META-INF/maven/**" exclude "META-INF/INDEX.LIST" exclude "META-INF/MANIFEST*" exclude "META-INF//LICENSE*" exclude "META-INF//NOTICE*" } 
================================================ FILE: plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoint/ExecutionParameters.java ================================================ package software.amazon.ai.mms.plugins.endpoint; import com.google.gson.GsonBuilder; import com.google.gson.annotations.SerializedName; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Properties; import software.amazon.ai.mms.servingsdk.Context; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; import software.amazon.ai.mms.servingsdk.http.Request; import software.amazon.ai.mms.servingsdk.http.Response; @Endpoint( urlPattern = "execution-parameters", endpointType = EndpointTypes.INFERENCE, description = "Execution parameters endpoint") public class ExecutionParameters extends ModelServerEndpoint { @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { Properties prop = ctx.getConfig(); // 6 * 1024 * 1024 int maxRequestSize = Integer.parseInt(prop.getProperty("max_request_size", "6291456")); ExecutionParametersResponse r = new ExecutionParametersResponse(); r.setMaxConcurrentTransforms(Integer.parseInt(prop.getProperty("NUM_WORKERS", "1"))); r.setBatchStrategy("MULTI_RECORD"); r.setMaxPayloadInMB(maxRequestSize / (1024 * 1024)); rsp.getOutputStream() .write( new GsonBuilder() .setPrettyPrinting() .create() .toJson(r) .getBytes(StandardCharsets.UTF_8)); } /** Response for Model server endpoint */ public static class ExecutionParametersResponse { @SerializedName("MaxConcurrentTransforms") private int maxConcurrentTransforms; @SerializedName("BatchStrategy") private String batchStrategy; @SerializedName("MaxPayloadInMB") private int maxPayloadInMB; public ExecutionParametersResponse() { maxConcurrentTransforms = 4; batchStrategy = "MULTI_RECORD"; 
maxPayloadInMB = 6; } public int getMaxConcurrentTransforms() { return maxConcurrentTransforms; } public String getBatchStrategy() { return batchStrategy; } public int getMaxPayloadInMB() { return maxPayloadInMB; } public void setMaxConcurrentTransforms(int newMaxConcurrentTransforms) { maxConcurrentTransforms = newMaxConcurrentTransforms; } public void setBatchStrategy(String newBatchStrategy) { batchStrategy = newBatchStrategy; } public void setMaxPayloadInMB(int newMaxPayloadInMB) { maxPayloadInMB = newMaxPayloadInMB; } } } ================================================ FILE: plugins/endpoints/src/main/java/software/amazon/ai/mms/plugins/endpoint/Ping.java ================================================ package software.amazon.ai.mms.plugins.endpoint; import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Map; import software.amazon.ai.mms.servingsdk.Context; import software.amazon.ai.mms.servingsdk.Model; import software.amazon.ai.mms.servingsdk.ModelServerEndpoint; import software.amazon.ai.mms.servingsdk.Worker; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; import software.amazon.ai.mms.servingsdk.http.Request; import software.amazon.ai.mms.servingsdk.http.Response; @Endpoint( urlPattern = "ping", endpointType = EndpointTypes.INFERENCE, description = "Ping endpoint for sagemaker containers.") public class Ping extends ModelServerEndpoint { private boolean init; private byte[] success = "{\n\t\"Status\": \"Healthy\"\n}\n".getBytes(StandardCharsets.UTF_8); private boolean modelsLoaded(Context ctx) { Map modelMap = ctx.getModels(); for (Map.Entry entry : modelMap.entrySet()) { for (Worker w : entry.getValue().getModelWorkers()) { if (w.isRunning()) { return true; } } } return false; } private boolean validConfig(String svc) { String fileName = svc; if (svc.contains(":")) { fileName = svc.substring(0, 
svc.lastIndexOf(':')); } if (!fileName.contains(".py")) { fileName = fileName.concat(".py"); } return new File(fileName).exists(); } @Override public void doGet(Request req, Response rsp, Context ctx) throws IOException { rsp.setStatus(200); String isMultiModelMode = System.getenv("SAGEMAKER_MULTI_MODE"); if (isMultiModelMode == null || "false".equalsIgnoreCase(isMultiModelMode)) { if (!init && !modelsLoaded(ctx)) { rsp.setStatus(503, "Model loading..."); rsp.getOutputStream() .write("Models are not loaded".getBytes(StandardCharsets.UTF_8)); } else { init = true; rsp.getOutputStream().write(success); } } else { String svcFile = ctx.getConfig().getProperty("default_service_handler"); if ((svcFile == null) || !validConfig(svcFile)) { rsp.setStatus(503, "Service file unavailable"); rsp.getOutputStream() .write("Service file unavailable".getBytes(StandardCharsets.UTF_8)); } else { rsp.getOutputStream().write(success); } } } } ================================================ FILE: plugins/endpoints/src/main/resources/META-INF/services/software.amazon.ai.mms.servingsdk.ModelServerEndpoint ================================================ software.amazon.ai.mms.plugins.endpoint.Ping ================================================ FILE: plugins/gradle/wrapper/gradle-wrapper.properties ================================================ #Tue Jun 04 12:29:27 PDT 2019 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists distributionUrl=https\://services.gradle.org/distributions/gradle-4.10-all.zip ================================================ FILE: plugins/gradle.properties ================================================ gson_version=2.8.9 mms_server_sdk_version=1.0.1 ================================================ FILE: plugins/gradlew ================================================ #!/usr/bin/env sh ############################################################################## ## ## Gradle start up 
script for UN*X ## ############################################################################## # Attempt to set APP_HOME # Resolve links: $0 may be a link PRG="$0" # Need this for relative symlinks. while [ -h "$PRG" ] ; do ls=`ls -ld "$PRG"` link=`expr "$ls" : '.*-> \(.*\)$'` if expr "$link" : '/.*' > /dev/null; then PRG="$link" else PRG=`dirname "$PRG"`"/$link" fi done SAVED="`pwd`" cd "`dirname \"$PRG\"`/" >/dev/null APP_HOME="`pwd -P`" cd "$SAVED" >/dev/null APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS="" # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" warn () { echo "$*" } die () { echo echo "$*" echo exit 1 } # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false case "`uname`" in CYGWIN* ) cygwin=true ;; Darwin* ) darwin=true ;; MINGW* ) msys=true ;; NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables JAVACMD="$JAVA_HOME/jre/sh/java" else JAVACMD="$JAVA_HOME/bin/java" fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else JAVACMD="java" which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi # Increase the maximum file descriptors if we can. if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then MAX_FD_LIMIT=`ulimit -H -n` if [ $? 
-eq 0 ] ; then if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then MAX_FD="$MAX_FD_LIMIT" fi ulimit -n $MAX_FD if [ $? -ne 0 ] ; then warn "Could not set maximum file descriptor limit: $MAX_FD" fi else warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" fi fi # For Darwin, add options to specify how the application appears in the dock if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi # For Cygwin, switch paths to Windows format before running java if $cygwin ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` # We build the pattern for arguments to be converted via cygpath ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` SEP="" for dir in $ROOTDIRSRAW ; do ROOTDIRS="$ROOTDIRS$SEP$dir" SEP="|" done OURCYGPATTERN="(^($ROOTDIRS))" # Add a user-defined pattern to the cygpath arguments if [ "$GRADLE_CYGPATTERN" != "" ] ; then OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" fi # Now convert the arguments - kludge to limit ourselves to /bin/sh i=0 for arg in "$@" ; do CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` else eval `echo args$i`="\"$arg\"" fi i=$((i+1)) done case $i in (0) set -- ;; (1) set -- "$args0" ;; (2) set -- "$args0" "$args1" ;; (3) set -- "$args0" "$args1" "$args2" ;; (4) set -- "$args0" "$args1" "$args2" "$args3" ;; (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" 
"$args6" "$args7" "$args8" ;; esac fi # Escape application args save () { for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done echo " " } APP_ARGS=$(save "$@") # Collect all arguments for the java command, following the shell quoting and substitution rules eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then cd "$(dirname "$0")" fi exec "$JAVACMD" "$@" ================================================ FILE: plugins/gradlew.bat ================================================ @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @rem @rem ########################################################################## @rem Set local scope for the variables with windows NT shell if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 if "%DIRNAME%" == "" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. set DEFAULT_JVM_OPTS= @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 if "%ERRORLEVEL%" == "0" goto init echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo. echo Please set the JAVA_HOME variable in your environment to match the echo location of your Java installation. goto fail :findJavaFromJavaHome set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe if exist "%JAVA_EXE%" goto init echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% echo. 
echo Please set the JAVA_HOME variable in your environment to match the echo location of your Java installation. goto fail :init @rem Get command-line arguments, handling Windows variants if not "%OS%" == "Windows_NT" goto win9xME_args :win9xME_args @rem Slurp the command line arguments. set CMD_LINE_ARGS= set _SKIP=2 :win9xME_args_slurp if "x%~1" == "x" goto execute set CMD_LINE_ARGS=%* :execute @rem Setup the command line set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar @rem Execute Gradle "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% :end @rem End local scope for the variables with windows NT shell if "%ERRORLEVEL%"=="0" goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 exit /b 1 :mainEnd if "%OS%"=="Windows_NT" endlocal :omega ================================================ FILE: plugins/settings.gradle ================================================ /* * This file was generated by the Gradle 'init' task. * * The settings file is used to specify which projects to include in your build. 
* * Detailed information about configuring a multi-project build in Gradle can be found * in the user manual at https://docs.gradle.org/5.4.1/userguide/multi_project_builds.html */ rootProject.name = 'plugins' include 'endpoints' ================================================ FILE: plugins/tools/conf/checkstyle.xml ================================================ ================================================ FILE: plugins/tools/conf/findbugs-exclude.xml ================================================ ================================================ FILE: plugins/tools/conf/pmd.xml ================================================ Java Rule in PMD ================================================ FILE: plugins/tools/conf/suppressions.xml ================================================ ================================================ FILE: plugins/tools/gradle/check.gradle ================================================ apply plugin: 'findbugs' findbugs { excludeFilter = file("${rootProject.projectDir}/tools/conf/findbugs-exclude.xml") ignoreFailures = false findbugsTest.enabled = true } tasks.withType(FindBugs) { reports { xml.enabled false html.enabled true } } apply plugin: 'pmd' pmd { ignoreFailures = false pmdTest.enabled = false ruleSets = [] // workaround pmd gradle plugin bug ruleSetFiles = files("${rootProject.projectDir}/tools/conf/pmd.xml") } tasks.withType(Pmd){ reports{ xml.enabled=true html.enabled=true } } apply plugin: 'checkstyle' checkstyle { toolVersion = '7.1.2' ignoreFailures = false checkstyleTest.enabled = true configProperties = [ "checkstyle.suppressions.file" : file("${rootProject.projectDir}/tools/conf/suppressions.xml")] configFile = file("${rootProject.projectDir}/tools/conf/checkstyle.xml") } checkstyleMain { classpath += configurations.compile } tasks.withType(Checkstyle) { reports { xml.enabled false html.enabled true } } apply plugin: "jacoco" jacoco { toolVersion = "0.8.1" } jacocoTestReport { group = "Reporting" reports { 
xml.enabled true csv.enabled false } } check.dependsOn jacocoTestReport check.dependsOn jacocoTestCoverageVerification ================================================ FILE: plugins/tools/gradle/formatter.gradle ================================================ buildscript { repositories { maven { url "https://plugins.gradle.org/m2/" } } dependencies { classpath 'com.google.googlejavaformat:google-java-format:1.6' } } apply plugin: FormatterPlugin check.dependsOn verifyJava import com.google.googlejavaformat.java.Formatter import com.google.googlejavaformat.java.ImportOrderer import com.google.googlejavaformat.java.JavaFormatterOptions import com.google.googlejavaformat.java.Main import com.google.googlejavaformat.java.RemoveUnusedImports class FormatterPlugin implements Plugin { void apply(Project project) { project.task('formatJava') { doLast { Main formatter = new Main(new PrintWriter(System.out, true), new PrintWriter(System.err, true), System.in) Project rootProject = project.getRootProject() for (item in project.sourceSets) { for (File file : item.getAllSource()) { if (!file.getName().endsWith(".java")) { continue } if (formatter.format("-a", "-i", file.getAbsolutePath()) != 0) { throw new GradleException("Format java failed: " + file.getAbsolutePath()) } } } } } project.task('verifyJava') { doLast { def options = JavaFormatterOptions.builder().style(JavaFormatterOptions.Style.AOSP).build() Formatter formatter = new Formatter(options) Project rootProject = project.getRootProject() for (item in project.sourceSets) { for (File file : item.getAllSource()) { if (!file.getName().endsWith(".java")) { continue } String src = file.text String formatted = formatter.formatSource(src) formatted = RemoveUnusedImports.removeUnusedImports(formatted, RemoveUnusedImports.JavadocOnlyImports.KEEP) formatted = ImportOrderer.reorderImports(formatted) if (!src.equals(formatted)) { throw new GradleException("File not formatted: " + file.getAbsolutePath()) } } } } } } } 
================================================ FILE: plugins/tools/gradle/launcher.gradle ================================================ apply plugin: LauncherPlugin clean.dependsOn killServer import org.gradle.internal.jvm.Jvm class LauncherPlugin implements Plugin { void apply(Project project) { project.task('startServer') { dependsOn project.jar doLast { def pidFile = getPidFile() if (pidFile.exists()) { throw new GradleException("Server already running!") } def list = [] list.addAll(project.configurations.runtime.getFiles()) list.add(project.jar.outputs.files.singleFile) String cp = CollectionUtils.join(File.pathSeparator, list) String jvmPath = Jvm.current().getJavaExecutable() def cmd = [jvmPath, "-agentlib:jdwp=transport=dt_socket,address=0.0.0.0:4000,server=y,suspend=n", "-DmmsConfigFile=${project.projectDir}/src/test/resources/config.properties", "-DLOG_LOCATION=${project.buildDir}/logs", "-DMETRICS_LOCATION=${project.buildDir}/logs", "-cp", cp, "com.amazonaws.ml.mms.ModelServer"] as String[] def builder = new ProcessBuilder(cmd) builder.redirectErrorStream(true) builder.directory(project.projectDir) Process process = builder.start() ReaderThread rt = new ReaderThread(process.getInputStream()) rt.start() new ReaderThread(process.getErrorStream()).start() try { while (!rt.done) { try { process.exitValue() throw new GradleException("MMS stop unexpectedly.") } catch(IllegalThreadStateException ex) { Thread.sleep(500) } } def pidField = process.class.getDeclaredField('pid') pidField.accessible = true pidFile << pidField.getInt(process) logger.quiet "MMS service started." } catch (IllegalThreadStateException ignored) { } } } project.task('killServer') { doLast { def pidFile = getPidFile() if(!pidFile.exists()) { logger.quiet "No server running!" 
return } def pid = pidFile.text def process = "kill $pid".execute() try { process.waitFor() } finally { pidFile.delete() } } } project.task('restartServer') { dependsOn project.killServer dependsOn project.startServer } } private File getPidFile() { return new File("build/server.pid") } } class ReaderThread extends Thread { private InputStream is private boolean done ReaderThread(InputStream is) { this.is = is } void run() { long begin = System.currentTimeMillis() def line def reader = new BufferedReader(new InputStreamReader(is)) while ((line = reader.readLine()) != null) { if (!done) { done = line.matches("Model server started.*") println line } } } } ================================================ FILE: run_ci_tests.sh ================================================ #!/usr/bin/env bash # # A shell script to build MMS locally. # Developer should make sure local build passed before submit PR. set -e which docker if [ $? -ne 0 ] then echo "Please install docker." exit 1 fi MMS_HOME="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd )" BUILDSPEC="ci/buildspec.yml" docker pull amazon/aws-codebuild-local:latest docker pull awsdeeplearningteam/mms-build:python2.7@sha256:2b743d6724dead806873cce1330f7b8a0197399a35af47dfd7035251fdade122 docker pull awsdeeplearningteam/mms-build:python3.6@sha256:2c1afa8834907ceec641d254dffbf4bcc659ca2d00fd6f2872d7521f32c9fa2e find . -name __pycache__ | xargs rm -rf docker run -it --rm -v /var/run/docker.sock:/var/run/docker.sock -e "IMAGE_NAME=awsdeeplearningteam/mms-build:python2.7" -e "ARTIFACTS=${MMS_HOME}/build/artifacts2.7" -e "SOURCE=${MMS_HOME}" -e "BUILDSPEC=${BUILDSPEC}" amazon/aws-codebuild-local find . 
-name __pycache__ | xargs rm -rf
docker run -it --rm -v /var/run/docker.sock:/var/run/docker.sock -e "IMAGE_NAME=awsdeeplearningteam/mms-build:python3.6" -e "ARTIFACTS=${MMS_HOME}/build/artifacts3.6" -e "SOURCE=${MMS_HOME}" -e "BUILDSPEC=${BUILDSPEC}" amazon/aws-codebuild-local

================================================
FILE: run_circleci_tests.py
================================================
#!/usr/bin/env python

"""
- This script helps to execute circleci jobs in a container on developer's local machine.
- The script accepts workflow(mandatory), job(optional) and executor(optional) arguments.
- The script uses circleci cli's process command to generate a processed yaml.
- The processed yaml is parsed and tweaked to generate a new transformed yaml.
- The transformed yaml contains a single job, which is a merged and ordered list of job steps
  from the specified and required parent jobs.
"""

# Make sure you have following dependencies installed on your local machine
# 1. PyYAML (pip install PyYaml)
# 2. CircleCI cli from - https://circleci.com/docs/2.0/local-cli/#installation
# 3.
docker from collections import OrderedDict from functools import reduce import subprocess import sys import copy import argparse import yaml parser = argparse.ArgumentParser(description='Execute circleci jobs in a container \ on your local machine') parser.add_argument('workflow', type=str, help='Workflow name from config.yml') parser.add_argument('-j', '--job', type=str, help='Job name from config.yml') parser.add_argument('-e', '--executor', type=str, help='Executor name from config.yml') args = parser.parse_args() workflow = args.workflow job = args.job executor = args.executor CCI_CONFIG_FILE = '.circleci/config.yml' PROCESSED_FILE = '.circleci/processed.yml' XFORMED_FILE = '.circleci/xformed.yml' CCI_CONFIG = {} PROCESSED_CONFIG = {} XFORMED_CONFIG = {} XFORMED_JOB_NAME = 'mms_xformed_job' BLACKLISTED_STEPS = ['persist_to_workspace', 'attach_workspace', 'store_artifacts'] # Read CircleCI's config with open(CCI_CONFIG_FILE) as fstream: try: CCI_CONFIG = yaml.safe_load(fstream) except yaml.YAMLError as err: print(err) # Create processed YAML using circleci cli's 'config process' commands PROCESS_CONFIG_CMD = 'circleci config process {} > {}'.format(CCI_CONFIG_FILE, PROCESSED_FILE) print("Executing command : ", PROCESS_CONFIG_CMD) subprocess.check_call(PROCESS_CONFIG_CMD, shell=True) # Read the processed config with open(PROCESSED_FILE) as fstream: try: PROCESSED_CONFIG = yaml.safe_load(fstream) except yaml.YAMLError as err: print(err) # All executors available in the config file available_executors = list(CCI_CONFIG['executors']) # All jobs available under the specified workflow jobs_in_workflow = PROCESSED_CONFIG['workflows'][workflow]['jobs'] def get_processed_job_sequence(processed_job_name): """ Recursively iterate over jobs in the workflow to generate an ordered list of parent jobs """ jobs_in_sequence = [] job_dict = next((jd for jd in jobs_in_workflow \ if isinstance(jd, dict) and processed_job_name == list(jd)[0]), None) if job_dict: # Find all parent 
jobs, recurse to find their respective ancestors parent_jobs = job_dict[processed_job_name].get('requires', []) for pjob in parent_jobs: jobs_in_sequence += get_processed_job_sequence(pjob) return jobs_in_sequence + [processed_job_name] def get_jobs_to_exec(job_name): """ Returns a dictionary of executors and a list of jobs to be executed in them """ jobs_dict = {} executors = [executor] if executor else available_executors for exectr_name in executors: if job_name is None: # List of all job names(as string) jobs_dict[exectr_name] = map(lambda j: j if isinstance(j, str) \ else list(j)[0], jobs_in_workflow) # Filter processed job names as per the executor # "job_name-executor_name" is a convention set in config.yml # pylint: disable=cell-var-from-loop jobs_dict[exectr_name] = filter(lambda j: exectr_name in j, jobs_dict[exectr_name]) else: # The list might contain duplicate parent jobs due to multiple fan-ins like config # - Remove the duplicates # "job_name-executor_name" is a convention set in config.yml jobs_dict[exectr_name] = \ OrderedDict.fromkeys(get_processed_job_sequence(job_name + '-' + exectr_name)) jobs_dict[exectr_name] = list(jobs_dict[exectr_name]) return jobs_dict # jobs_to_exec is a dict, with executor(s) as the key and list of jobs to be executed as its value jobs_to_exec = get_jobs_to_exec(job) def get_jobs_steps(steps, job_name): """ Merge all the steps from list of jobs to execute """ job_steps = PROCESSED_CONFIG['jobs'][job_name]['steps'] filtered_job_steps = list(filter(lambda step: list(step)[0] not in BLACKLISTED_STEPS, \ job_steps)) return steps + filtered_job_steps result_dict = {} for exectr, jobs in jobs_to_exec.items(): merged_steps = reduce(get_jobs_steps, jobs, []) # Create a new job, using the first job as a reference # This ensures configs like executor, environment, etc are maintained from the first job first_job = jobs[0] xformed_job = copy.deepcopy(PROCESSED_CONFIG['jobs'][first_job]) # Add the merged steps to this newly 
introduced job xformed_job['steps'] = merged_steps # Create a duplicate config(transformed) with the newly introduced job as the only job in config XFORMED_CONFIG = copy.deepcopy(PROCESSED_CONFIG) XFORMED_CONFIG['jobs'] = {} XFORMED_CONFIG['jobs'][XFORMED_JOB_NAME] = xformed_job # Create a transformed yaml with open(XFORMED_FILE, 'w+') as fstream: yaml.dump(XFORMED_CONFIG, fstream) try: # Locally execute the newly created job # This newly created job has all the steps (ordered and merged from steps in parent job(s)) LOCAL_EXECUTE_CMD = 'circleci local execute -c {} --job {}'.format(XFORMED_FILE, \ XFORMED_JOB_NAME) print('Executing command : ', LOCAL_EXECUTE_CMD) result_dict[exectr] = subprocess.check_call(LOCAL_EXECUTE_CMD, shell=True) except subprocess.CalledProcessError as err: result_dict[exectr] = err.returncode # Clean up, remove the processed and transformed yml files CLEANUP_CMD = 'rm {} {}'.format(PROCESSED_FILE, XFORMED_FILE) print('Executing command : ', CLEANUP_CMD) subprocess.check_call(CLEANUP_CMD, shell=True) # Print job execution details for exectr, retcode in result_dict.items(): colorcode, status = ('\033[0;37;42m', 'successful') if retcode == 0 \ else ('\033[0;37;41m', 'failed') print("{} Job execution {} using {} executor \x1b[0m".format(colorcode, status, exectr)) # Exit as per overall status SYS_EXIT_CODE = 0 if all(retcode == 0 for exectr, retcode in result_dict.items()) else 1 sys.exit(SYS_EXIT_CODE) ================================================ FILE: serving-sdk/checkstyle.xml ================================================ ================================================ FILE: serving-sdk/pom.xml ================================================ 4.0.0 software.amazon.ai mms-plugins-sdk jar 1.0.1 mms-plugins-sdk SDK for Model Server plugins The Apache Software License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0.txt repo scm:git:git://github.com/awslabs/multi-model-server.git 
scm:git:ssh://git@github.com:deep-learning-mms-bot/multi-model-server.git https://github.com/awslabs/multi-model-server/tree/master/serving-sdk HEAD Multi Model Server https://github.com/awslabs/multi-model-server modelserver.io http://maven.apache.org file://${project.build.directory}/repo true UTF-8 UTF-8 junit junit 4.13.1 test org.mockito mockito-all 1.10.19 test staging https://oss.sonatype.org/service/local/staging/deploy/maven2/ false org.apache.maven.plugins maven-checkstyle-plugin 3.1.0 checkstyle validate check ${project.basedir}/checkstyle.xml true org.apache.maven.plugins maven-pmd-plugin 3.12.0 compile check cpd-check org.apache.maven.plugins maven-source-plugin 3.1.0 attach-sources jar org.apache.maven.plugins maven-javadoc-plugin 3.1.0 attach-javadocs jar org.apache.maven.plugins maven-gpg-plugin 1.6 ${skip.gpg} sign-artifacts verify sign org.apache.maven.plugins maven-deploy-plugin 2.8.2 false deployrepo::default::${repo_url} org.apache.maven.plugins maven-compiler-plugin 8 8 ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/Context.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk; import java.util.Map; import java.util.Properties; /** * This interface provides access to the current running Model Server. 
*/ public interface Context { /** * Get the configuration of the current running Model Server * @return Properties */ Properties getConfig(); /** * Get a list of Models registered with the Model Server * @return List of models */ Map getModels(); } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/Model.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk; import java.util.List; /** * This provides information about the model which is currently registered with Model Server */ public interface Model { /** * Get the name of this model * @return The name of this model */ String getModelName(); /** * Get the URL of the Model location * @return models URL */ String getModelUrl(); /** * Get the model's entry-point * @return "handler" invoked to handle requests */ String getModelHandler(); /** * Returns the current list of workers for this model * @return list of Worker objects */ List getModelWorkers(); } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/ModelServerEndpoint.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. 
A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk; import software.amazon.ai.mms.servingsdk.http.Request; import software.amazon.ai.mms.servingsdk.http.Response; import java.io.IOException; /** * This class defines the abstract class for ModelServerEndpoint */ public abstract class ModelServerEndpoint { /** * This method is called when a HTTP GET method is invoked for the defined custom model server endpoint * @param req - Incoming request * @param res - Outgoing response * @param ctx - ModelServer's context which defines the current model-server system information * @throws IOException if I/O error occurs */ public void doGet(Request req, Response res, Context ctx) throws ModelServerEndpointException, IOException { throw new ModelServerEndpointException("No implementation found .. Default implementation invoked"); } /** * This method is called when a HTTP PUT method is invoked for the defined custom model server endpoint * @param req - Incoming request * @param res - Outgoing response * @param ctx - ModelServer's context which defines the current model-server system information * @throws IOException if I/O error occurs */ public void doPut(Request req, Response res, Context ctx) throws ModelServerEndpointException, IOException { throw new ModelServerEndpointException("No implementation found .. 
Default implementation invoked"); } /** * This method is called when a HTTP POST method is invoked for the defined custom model server endpoint * @param req - Incoming request * @param res - Outgoing response * @param ctx - ModelServer's context which defines the current model-server system information * @throws IOException if I/O error occurs */ public void doPost(Request req, Response res, Context ctx) throws ModelServerEndpointException, IOException { throw new ModelServerEndpointException("No implementation found .. Default implementation invoked"); } /** * This method is called when a HTTP DELETE method is invoked for the defined custom model server endpoint * @param req - Incoming request * @param res - Outgoing response * @param ctx - ModelServer's context which defines the current model-server system information * @throws IOException if I/O error occurs */ public void doDelete(Request req, Response res, Context ctx) throws ModelServerEndpointException, IOException { throw new ModelServerEndpointException("No implementation found .. Default implementation invoked"); } } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/ModelServerEndpointException.java ================================================ /* * Copyright (c) 2019 Amazon.com, Inc. or its affiliates. * All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. A copy of the License is located at * http://aws.amazon.com/apache2.0/ or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for * the specific language governing permissions and limitations under the License. 
*/ package software.amazon.ai.mms.servingsdk; /** * Runtime exception for custom model server endpoint plugins */ public class ModelServerEndpointException extends RuntimeException { public ModelServerEndpointException(String err) {super(err);} public ModelServerEndpointException(String err, Throwable t) {super(err, t);} } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/Worker.java ================================================ /* * Copyright (c) 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache * License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of * the License is located at http://aws.amazon.com/apache2.0/ or in the "license" file accompanying this file. This file * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing permissions and limitations under the License. */ package software.amazon.ai.mms.servingsdk; /** * Describe the model worker */ public interface Worker { /** * Get the current running status of this model's worker * @return True - if the worker is currently running. False - the worker is currently not running. */ boolean isRunning(); /** * Get the current memory footprint of this worker * @return Current memory usage of this worker */ long getWorkerMemory(); } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/annotations/Endpoint.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. 
A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk.annotations; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; @Target(ElementType.TYPE) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface Endpoint { /** * @return URL pattern to which this class applies */ String urlPattern() default ""; /** * @return Type of this endpoint. Default NONE */ EndpointTypes endpointType() default EndpointTypes.NONE; /** * @return Description of this endpoint. Default "" */ String description() default ""; } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/annotations/helpers/EndpointTypes.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions * and limitations under the License. 
*/ package software.amazon.ai.mms.servingsdk.annotations.helpers; /** * Types of ModelServer endpoints */ public enum EndpointTypes { NONE, INFERENCE, MANAGEMENT; } ================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/http/Request.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk.http; import java.io.IOException; import java.io.InputStream; import java.util.List; import java.util.Map; /** * This defines the request object given to the custom endpoint */ public interface Request { /** * Get all header names in the request object * @return List of request header names */ List getHeaderNames(); /** * Get the URI of the request * @return URI of the endpoint */ String getRequestURI(); /** * Get all query parameters coming in for this endpoint * @return a dictionary of all the parameters in the query */ Map> getParameterMap(); /** * Get a query parameter * @param k - Parameter name * @return - value of the parameter */ List getParameter(String k); /** * Get the content-type of the incoming request object * @return content-type string in the request */ String getContentType(); /** * Get the body content stream of the incoming request * @return the request content input stream * @throws IOException if there is an I/O error */ InputStream getInputStream() throws IOException; } 
================================================ FILE: serving-sdk/src/main/java/software/amazon/ai/mms/servingsdk/http/Response.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk.http; import java.io.IOException; import java.io.OutputStream; /** * Interface defining the response object sent to the custom defined endpoints */ public interface Response { /** * Set HTTP response status * @param sc - status code */ void setStatus(int sc); /** * Set HTTP response status code and status phrase * @param sc - Integer value representing the status code of this response * @param phrase - String phrase representing the status phrase of this response */ void setStatus(int sc, String phrase); /** * Set HTTP headers * @param k - Header name * @param v - Header value */ void setHeader(String k, String v); /** * Add HTTP headers for an existing header name * @param k - Header name * @param v - Header value */ void addHeader(String k, String v); /** * Set content type header in the response object * @param ct - Content-Type */ void setContentType(String ct); /** * Get the output stream object for response * @return response body content as OutputStream * @throws IOException if I/O error occurs */ OutputStream getOutputStream() throws IOException; } ================================================ FILE: 
serving-sdk/src/test/java/software/amazon/ai/mms/servingsdk/ModelServerEndpointTest.java ================================================ /* * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at * * http://aws.amazon.com/apache2.0/ * * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ package software.amazon.ai.mms.servingsdk; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; import org.mockito.stubbing.Answer; import software.amazon.ai.mms.servingsdk.annotations.Endpoint; import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes; import software.amazon.ai.mms.servingsdk.http.Request; import software.amazon.ai.mms.servingsdk.http.Response; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Properties; /** * Unit test for simple App. 
*/ public class ModelServerEndpointTest { Context c; Model m; Worker w; ModelServerEndpoint mse; Request req; Response rsp; ByteArrayOutputStream outputStream; Endpoint ea; @Before public void beforeSuite() throws IOException { c = Mockito.mock(Context.class); m = Mockito.mock(Model.class); w = Mockito.mock(Worker.class); ea = Mockito.mock(Endpoint.class); mse = Mockito.mock(ModelServerEndpoint.class); req = Mockito.mock(Request.class); rsp = Mockito.mock(Response.class); outputStream = new ByteArrayOutputStream(); Properties p = new Properties(); HashMap map = new HashMap<>(); List l = new ArrayList<>(); map.put("squeezenet", m); p.setProperty("Hello", "World"); c.getConfig(); l.add(w); Mockito.when(c.getConfig()).thenReturn(p); Mockito.when(c.getModels()).thenReturn(map); Mockito.when(m.getModelWorkers()).thenReturn(l); Mockito.when(m.getModelHandler()).thenReturn("mxnet_service:handle"); Mockito.when(m.getModelUrl()).thenReturn("/tmp/model/squeezenet.mar"); Mockito.when(m.getModelName()).thenReturn("squeezenet"); Mockito.when(w.getWorkerMemory()).thenReturn((long)100); Mockito.when(w.isRunning()).thenReturn(false); Mockito.when(ea.urlPattern()).thenReturn("myEndpoint"); Mockito.when(ea.description()).thenReturn("This is a test endpoint"); Mockito.when(ea.endpointType()).thenReturn(EndpointTypes.INFERENCE); Mockito.when(rsp.getOutputStream()).thenReturn(outputStream); } @Test public void test() throws IOException { testContextInterface(); testEndpointAnnotation(); testEndpointInterface(); } private void testEndpointInterface() throws IOException { Class ep = ModelServerEndpoint.class; Assert.assertEquals(4, ep.getDeclaredMethods().length); for(Method m : ep.getDeclaredMethods()) { switch (m.getName()) { case "doGet": case "doPost": case "doDelete": case "doPut": break; default: Assert.fail("Invalid method found"); } } // Check signatures Mockito.doAnswer((Answer) i -> { Object rq = i.getArguments()[0]; Object rs = i.getArguments()[1]; Object ctx = 
i.getArguments()[2]; ((Response)rs).getOutputStream().write("This is a test".getBytes()); return null; }).when(mse).doGet(req, rsp, c); mse.doGet(req, rsp, c); Assert.assertEquals("This is a test", outputStream.toString()); outputStream.reset(); // Check signatures Mockito.doAnswer((Answer) i -> { Object rq = i.getArguments()[0]; Object rs = i.getArguments()[1]; Object ctx = i.getArguments()[2]; ((Response)rs).getOutputStream().write("This is a test".getBytes()); return null; }).when(mse).doPost(req, rsp, c); mse.doPost(req, rsp, c); Assert.assertEquals("This is a test", outputStream.toString()); outputStream.reset(); // Check signatures Mockito.doAnswer((Answer) i -> { Object rq = i.getArguments()[0]; Object rs = i.getArguments()[1]; Object ctx = i.getArguments()[2]; ((Response)rs).getOutputStream().write("This is a test".getBytes()); return null; }).when(mse).doPut(req, rsp, c); mse.doPut(req, rsp, c); Assert.assertEquals("This is a test", outputStream.toString()); outputStream.reset(); // Check signatures Mockito.doAnswer((Answer) i -> { Object rq = i.getArguments()[0]; Object rs = i.getArguments()[1]; Object ctx = i.getArguments()[2]; ((Response)rs).getOutputStream().write("This is a test".getBytes()); return null; }).when(mse).doDelete(req, rsp, c); mse.doDelete(req, rsp, c); Assert.assertEquals("This is a test", outputStream.toString()); } private void testEndpointAnnotation() { Assert.assertEquals(3, Endpoint.class.getDeclaredMethods().length); Assert.assertEquals("myEndpoint", ea.urlPattern()); Assert.assertEquals(EndpointTypes.INFERENCE, ea.endpointType()); Assert.assertEquals("This is a test endpoint", ea.description()); Assert.assertEquals(3, EndpointTypes.class.getFields().length); } private void testWorkerInterface(Worker w) { Assert.assertNotNull(w); Assert.assertFalse( w.isRunning()); Assert.assertEquals(100, w.getWorkerMemory()); Assert.assertEquals(2, Worker.class.getDeclaredMethods().length); } private void testModelInterface(Model m) { 
Assert.assertEquals("squeezenet", m.getModelName()); Assert.assertEquals("/tmp/model/squeezenet.mar", m.getModelUrl()); Assert.assertEquals("mxnet_service:handle", m.getModelHandler()); Assert.assertEquals(1, m.getModelWorkers().size()); Assert.assertEquals(4, Model.class.getDeclaredMethods().length); testWorkerInterface(m.getModelWorkers().get(0)); } private void testContextInterface() { Assert.assertNotNull(c.getModels()); Assert.assertTrue(c.getModels().containsKey("squeezenet")); Assert.assertTrue(c.getConfig().containsKey("Hello")); Assert.assertEquals("World", c.getConfig().getProperty("Hello")); Assert.assertEquals(2, Context.class.getDeclaredMethods().length); testModelInterface(c.getModels().get("squeezenet")); } } ================================================ FILE: setup.py ================================================ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. # To build and upload a new version, follow the steps below. # Notes: # - this is a "Universal Wheels" package that is pure Python and supports both Python2 and Python3 # - Twine is a secure PyPi upload package # - Make sure you have bumped the version! 
at mms/version.py # $ pip install twine # $ pip install wheel # $ python setup.py bdist_wheel --universal # *** TEST YOUR PACKAGE WITH TEST PI ****** # twine upload --repository-url https://test.pypi.org/legacy/ dist/* # If this is successful then push it to actual pypi # $ twine upload dist/* """ Setup.py for the model server package """ import errno import os import subprocess import sys from datetime import date from shutil import copy2, rmtree import setuptools.command.build_py from setuptools import setup, find_packages, Command import mms pkgs = find_packages() def pypi_description(): """ Imports the long description for the project page """ with open('PyPiDescription.rst') as df: return df.read() def detect_model_server_version(): sys.path.append(os.path.abspath("mms")) if "--release" in sys.argv: sys.argv.remove("--release") return mms.__version__.strip() return mms.__version__.strip() + 'b' + str(date.today()).replace('-', '') class BuildFrontEnd(setuptools.command.build_py.build_py): """ Class defined to run custom commands. """ description = 'Build Model Server Frontend' source_server_file = os.path.abspath('frontend/server/build/libs/server-1.0.jar') dest_file_name = os.path.abspath('mms/frontend/model-server.jar') # noinspection PyMethodMayBeStatic def run(self): """ Actual method called to run the build command :return: """ front_end_bin_dir = os.path.abspath('.') + '/mms/frontend' try: os.mkdir(front_end_bin_dir) except OSError as exc: if exc.errno == errno.EEXIST and os.path.isdir(front_end_bin_dir): pass else: raise if os.path.exists(self.source_server_file): os.remove(self.source_server_file) # Remove build/lib directory. 
if os.path.exists('build/lib/'): rmtree('build/lib/') try: subprocess.check_call('frontend/gradlew -p frontend clean build', shell=True) except OSError: assert 0, "build failed" copy2(self.source_server_file, self.dest_file_name) class BuildPy(setuptools.command.build_py.build_py): """ Class to invoke the custom command defined above. """ def run(self): sys.stderr.flush() self.run_command('build_frontend') setuptools.command.build_py.build_py.run(self) class BuildPlugins(Command): description = 'Build Model Server Plugins' user_options = [('plugins=', 'p', 'Plugins installed')] source_plugin_dir = \ os.path.abspath('plugins/build/plugins') def initialize_options(self): self.plugins = None def finalize_options(self): if self.plugins is None: print("No plugin option provided. Defaulting to 'default'") self.plugins = "default" # noinspection PyMethodMayBeStatic def run(self): if os.path.isdir(self.source_plugin_dir): rmtree(self.source_plugin_dir) try: if self.plugins == "sagemaker": subprocess.check_call('plugins/gradlew -p plugins clean bS', shell=True) else: raise OSError("No such rule exists") except OSError: assert 0, "build failed" self.run_command('build_py') if __name__ == '__main__': version = detect_model_server_version() requirements = ['Pillow', 'psutil', 'future', 'model-archiver'] setup( name='multi-model-server', version=version, description='Multi Model Server is a tool for serving neural net models for inference', author='Trinity team', author_email='noreply@amazon.com', long_description=pypi_description(), url='https://github.com/awslabs/multi-model-server', keywords='Multi Model Server Serving Deep Learning Inference AI', packages=pkgs, cmdclass={ 'build_frontend': BuildFrontEnd, 'build_plugins': BuildPlugins, 'build_py': BuildPy, }, install_requires=requirements, extras_require={ 'mxnet-mkl': ['mxnet-mkl'], 'mxnet-cu90mkl': ['mxnet-cu90mkl'], 'mxnet': ['mxnet'], }, entry_points={ 'console_scripts': [ 'multi-model-server=mms.model_server:start', 
'mxnet-model-server=mms.model_server:old_start', 'multi-model-export=mms.export_model:main' ] }, include_package_data=True, license='Apache License Version 2.0' ) ================================================ FILE: test/README.md ================================================ # MMS Regression Tests This folder contains regression tests executed against MMS master. These tests use [POSTMAN](https://www.postman.com/downloads/) for exercising all the Management & Inference APIs. ### Running the test manually. Pull the pre-built multi-model-server Docker image ``` docker pull awsdeeplearningteam/multi-model-server ``` This pulls the awsdeeplearningteam/multi-model-server:latest Docker image, in which we will run our regression tests. ``` docker run -it --user root awsdeeplearningteam/multi-model-server:latest /bin/bash ``` In the Docker CLI execute the following commands. ``` apt-get update apt-get install -y git wget sudo git clone https://github.com/awslabs/multi-model-server.git cd multi-model-server ``` To execute tests on master run: `./test/regression_tests.sh ` To execute tests on a different branch run: `./test/regression_tests.sh <branch_name>` You can view the logs for the test execution & the multi-model-server in the /tmp dir. ``` cat /tmp/test_exec.log cat /tmp/mms.log ``` ### Adding tests To add to the tests, import a collection (in /postman) to Postman and add new requests. Specifically, to test for inference against a new model * Open /postman/inference_data.json * Add a new JSON object with the new model url and payload. ![POSTMAN UI](screenshot/postman.png) Afterwards, export the collection as a v2.1 collection and replace the existing exported collection. To add a new suite of tests, add a new collection to /postman and update regression_tests.sh to run the new collection and buildspec.yml to keep track of the report. 
================================================ FILE: test/postman/environment.json ================================================ { "id": "b9eb6b86-585b-4fca-abd0-972ec686e3e8", "name": "Multi-Model-server", "values": [ { "key": "hostname", "value": "localhost", "enabled": true }, { "key": "protocol", "value": "http", "enabled": true }, { "key": "mgmt-port", "value": "8081", "enabled": true }, { "key": "model-name", "value": "densenet161", "enabled": true }, { "key": "pred-port", "value": "8080", "enabled": true }, { "key": "sec-mgmt-port", "value": "8444", "enabled": true }, { "key": "sec-pred-port", "value": "8443", "enabled": true } ], "_postman_variable_scope": "environment", "_postman_exported_at": "2020-05-15T07:34:40.974Z", "_postman_exported_using": "Postman/7.24.0" } ================================================ FILE: test/postman/https_test_collection.json ================================================ { "info": { "_postman_id": "cd002524-ce12-4ff7-ab75-932a05403f83", "name": "MMS - https_test_collection", "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" }, "item": [ { "name": "HTTPS Inference API Description", "event": [ { "listen": "test", "script": { "id": "638ec081-ebf4-4634-a9ea-f675613c2127", "exec": [ "pm.test(\"Status code is 200\", function () {", " pm.response.to.have.status(200);", "});" ], "type": "text/javascript" } } ], "request": { "method": "OPTIONS", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-pred-port}}", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-pred-port}}" } }, "response": [] }, { "name": "HTTPS Management API Description", "event": [ { "listen": "test", "script": { "id": "81b8730a-0b89-4569-b042-1076266563ba", "exec": [ "pm.test(\"Status code is 200\", function () {", " pm.response.to.have.status(200);", "});" ], "type": "text/javascript" } } ], "request": { "method": "OPTIONS", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-pred-port}}", 
"protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-pred-port}}" } }, "response": [] }, { "name": "HTTPS Register Model - SqueezeNet", "event": [ { "listen": "test", "script": { "id": "7e9a9528-e9f1-446c-b1ba-7dce112ffa30", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar&model_name=squeezenetv1_1&initial_workers=1&synchronous=true", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" }, { "key": "model_name", "value": "squeezenetv1_1" }, { "key": "initial_workers", "value": "1" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "HTTPS Get SqueezeNet Model Description", "event": [ { "listen": "test", "script": { "id": "cd11b9cc-335a-415a-8315-54ddda6f6d8a", "exec": [ "pm.test(\"Successful GET request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-mgmt-port}}/models/squeezenetv1_1", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-mgmt-port}}", "path": [ "models", "squeezenetv1_1" ] } }, "response": [] }, { "name": "HTTPS Scale up Workers - Synchronous", "event": [ { "listen": "test", "script": { "id": "ba06ca94-e630-4f72-adf4-49d1a80ded31", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200, 201, 202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": 
"https://{{hostname}}:{{sec-mgmt-port}}/models/squeezenetv1_1?min_worker=5&max_worker=5&synchronous=true", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-mgmt-port}}", "path": [ "models", "squeezenetv1_1" ], "query": [ { "key": "min_worker", "value": "5" }, { "key": "max_worker", "value": "5" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "HTTPS Scale up Workers - Asynchronous", "event": [ { "listen": "test", "script": { "id": "5a13e310-65a9-4982-84c7-d89f29d6c27a", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-mgmt-port}}/models/squeezenetv1_1?min_worker=6&max_worker=6&synchronous=false", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-mgmt-port}}", "path": [ "models", "squeezenetv1_1" ], "query": [ { "key": "min_worker", "value": "6" }, { "key": "max_worker", "value": "6" }, { "key": "synchronous", "value": "false" } ] } }, "response": [] }, { "name": "HTTPS - Inference - SqueezeNet", "event": [ { "listen": "test", "script": { "id": "7c1c4eaa-48f8-4734-8737-78b4b2766b29", "exec": [ "pm.test(\"Status code is 200\", function () {", " pm.response.to.have.status(200);", "});" ], "type": "text/javascript" } } ], "protocolProfileBehavior": { "disabledSystemHeaders": { "content-type": true } }, "request": { "method": "POST", "header": [], "body": { "mode": "file", "file": { "src": "../examples/image_classifier/kitten.jpg" }, "options": { "raw": { "language": "text" } } }, "url": { "raw": "https://{{hostname}}:{{sec-pred-port}}/predictions/squeezenetv1_1", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-pred-port}}", "path": [ "predictions", "squeezenetv1_1" ] } }, "response": [] }, { "name": "HTTPS UnRegister Model SqueezeNet", "event": [ { "listen": "test", "script": { "id": 
"de94e7b6-d4fa-4e10-8b54-4052753de19e", "exec": [ "pm.test(\"Successful DELETE request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "DELETE", "header": [], "url": { "raw": "https://{{hostname}}:{{sec-mgmt-port}}/models/squeezenetv1_1", "protocol": "https", "host": [ "{{hostname}}" ], "port": "{{sec-mgmt-port}}", "path": [ "models", "squeezenetv1_1" ] } }, "response": [] } ], "protocolProfileBehavior": {} } ================================================ FILE: test/postman/inference_api_test_collection.json ================================================ { "info": { "_postman_id": "e69000d9-d3c8-49bd-879a-ad42b95b042a", "name": "MMS - inference_api_collection", "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" }, "item": [ { "name": "Model Zoo - Register Model", "event": [ { "listen": "test", "script": { "id": "80fa33ea-ff6a-4535-9328-ddffb980062a", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url={{url}}&model_name={{model_name}}&initial_workers={{worker}}&synchronous={{synchronous}}", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "{{url}}" }, { "key": "model_name", "value": "{{model_name}}" }, { "key": "initial_workers", "value": "{{worker}}" }, { "key": "synchronous", "value": "{{synchronous}}" } ] } }, "response": [] }, { "name": "Model Zoo - Inference Model", "event": [ { "listen": "test", "script": { "id": "d6b1f2cf-6ffb-4850-b276-108f7f65fbd9", "exec": [ "var type_response = pm.iterationData.get(\"content-type\");", "validators = {", " image_classification: validate_image_classification,", " default_json: validate_default", 
"};", "", "pm.test(\"Successful POST request\", function() {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});", "", "if (type_response === \"text/plain\") {", " pm.test(\"Test expected TEXT response\", function() {", " pm.response.to.have.body(pm.iterationData.get(\"expected\"));", " });", "", "} else if (type_response === \"application/json\") {", " if (pm.iterationData.has(\"validator\")) {", " var validator = pm.iterationData.get(\"validator\"); ", " } else {", " var validator = \"default_json\";", " }", " pm.test(\"Test expected JSON response\", function() {", " var actual_obj = pm.response.json();", " var expected_obj = pm.iterationData.get(\"expected\");", " pm.expect(validators[validator](actual_obj, expected_obj)).to.be.true;", " });", "", "}", "", "function get_tolerance_value(expected_val) {", " var tolerance_percent = pm.iterationData.get(\"tolerance\")", " return (expected_val * tolerance_percent) / 100;", " ", "}", "", "function validate_image_classification(actual_obj,expected_obj) {", " if (_.size(expected_obj) != _.size(actual_obj)) {", " return false;", " }", "", " for (i = 0; i < expected_obj.length; i += 1) {", " if(actual_obj[i][\"class\"] !== expected_obj[i][\"class\"]){", " return false", " }", "", " expected_val = expected_obj[i][\"probability\"]", " actual_val = actual_obj[i][\"probability\"]", "", " tolerance_value = get_tolerance_value(expected_val);", " if (!(Math.abs(expected_val - actual_val) < tolerance_value)) {", " return false;", " }", " }", " return true;", "}", "", "", "", "/* Simple and nested json object can be compared using validate_default when key and value are constant.", "-Notes-", "The order of keys within an object may change.", "If the output is array of objects then the objects compared are positional and cannot change order.", "*/", "function validate_default(actual_obj, expected_obj) {", " return _.isEqual(actual_obj, expected_obj);", "}", "" ], "type": "text/javascript" } } ], "protocolProfileBehavior": { 
"disabledSystemHeaders": { "content-type": true } }, "request": { "method": "POST", "header": [], "body": { "mode": "file", "file": { "src": "{{file}}" } }, "url": { "raw": "{{protocol}}://{{hostname}}:{{pred-port}}/predictions/{{model_name}}", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{pred-port}}", "path": [ "predictions", "{{model_name}}" ] } }, "response": [] }, { "name": "Model Zoo - Unregister model", "event": [ { "listen": "test", "script": { "id": "a14dd390-4176-45e7-af00-999676685f4a", "exec": [ "pm.test(\"Successful DELETE request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "DELETE", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/{{model_name}}", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "{{model_name}}" ] } }, "response": [] } ], "protocolProfileBehavior": {} } ================================================ FILE: test/postman/inference_data.json ================================================ [ { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/alexnet.mar", "model_name":"alexnet", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.4766186773777008, "class":"n02127052 lynx, catamount" }, { "probability":0.20687074959278107, "class":"n02128757 snow leopard, ounce, Panthera uncia" }, { "probability":0.135288268327713, "class":"n02124075 Egyptian cat" }, { "probability":0.09019536525011063, "class":"n02123045 tabby, tabby cat" }, { "probability":0.04814659804105759, "class":"n02123159 tiger cat" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/caffenet.mar", "model_name":"caffenet", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", 
"content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.7166884541511536, "class":"n02127052 lynx, catamount" }, { "probability":0.09750154614448547, "class":"n02123045 tabby, tabby cat" }, { "probability":0.0745730996131897, "class":"n02123159 tiger cat" }, { "probability":0.06743090599775314, "class":"n02124075 Egyptian cat" }, { "probability":0.03200334310531616, "class":"n02128757 snow leopard, ounce, Panthera uncia" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-inception_v1.mar", "model_name":"inception_v1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.3833286464214325, "class":"n02123159 tiger cat" }, { "probability":0.33825597167015076, "class":"n02124075 Egyptian cat" }, { "probability":0.17002010345458984, "class":"n02123045 tabby, tabby cat" }, { "probability":0.09881989657878876, "class":"n02127052 lynx, catamount" }, { "probability":0.001223195344209671, "class":"n02123394 Persian cat" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/inception-bn.mar", "model_name":"inception-bn", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.5972931981086731, "class":"n02123045 tabby, tabby cat" }, { "probability":0.30086880922317505, "class":"n02123159 tiger cat" }, { "probability":0.08586657792329788, "class":"n02124075 Egyptian cat" }, { "probability":0.009726772084832191, "class":"n02127052 lynx, catamount" }, { "probability":0.0008877559448592365, "class":"n03598930 jigsaw puzzle" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-mobilenet.mar", "model_name":"mobilenet", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", 
"content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":216.0907745361328, "class":"n02948072 candle, taper, wax light" }, { "probability":186.8531494140625, "class":"n04456115 torch" }, { "probability":178.81483459472656, "class":"n03347037 fire screen, fireguard" }, { "probability":169.5437469482422, "class":"n03666591 lighter, light, igniter, ignitor" }, { "probability":145.08824157714844, "class":"n09472597 volcano" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/nin.mar", "model_name":"nin", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.978486180305481, "class":"n02123045 tabby, tabby cat" }, { "probability":0.021282507106661797, "class":"n02123159 tiger cat" }, { "probability":9.809057519305497e-05, "class":"n02124075 Egyptian cat" }, { "probability":8.281067130155861e-05, "class":"n02128385 leopard, Panthera pardus" }, { "probability":2.5670484319562092e-05, "class":"n02128757 snow leopard, ounce, Panthera uncia" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-152.mar", "model_name":"resnet-152", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.7149006128311157, "class":"n02123045 tabby, tabby cat" }, { "probability":0.22877003252506256, "class":"n02123159 tiger cat" }, { "probability":0.040323738008737564, "class":"n02124075 Egyptian cat" }, { "probability":0.008370742201805115, "class":"n02127052 lynx, catamount" }, { "probability":0.00067278987262398, "class":"n02129604 tiger, Panthera tigris" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar", "model_name":"resnet-18", "worker":1, "synchronous":"true", 
"file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.3633013665676117, "class":"n02123159 tiger cat" }, { "probability":0.2988913953304291, "class":"n02124075 Egyptian cat" }, { "probability":0.18073132634162903, "class":"n02123045 tabby, tabby cat" }, { "probability":0.07343611121177673, "class":"n02127052 lynx, catamount" }, { "probability":0.02035967819392681, "class":"n02128757 snow leopard, ounce, Panthera uncia" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/resnext-101-64x4d.mar", "model_name":"resnext101", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.5961099863052368, "class":"n02123159 tiger cat" }, { "probability":0.4005117118358612, "class":"n02123045 tabby, tabby cat" }, { "probability":0.0012956340797245502, "class":"n02124075 Egyptian cat" }, { "probability":0.0011538631515577435, "class":"n04074963 remote control, remote" }, { "probability":0.000405249185860157, "class":"n04286575 spotlight, spot" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v1.mar", "model_name":"resnet18-v1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":16.35452651977539, "class":"n02999410 chain" }, { "probability":16.15776824951172, "class":"n10148035 groom, bridegroom" }, { "probability":15.857562065124512, "class":"n04141076 sax, saxophone" }, { "probability":15.083999633789062, "class":"n04507155 umbrella" }, { "probability":14.938633918762207, "class":"n09229709 bubble" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v1.mar", "model_name":"resnet34-v1", "worker":1, "synchronous":"true", 
"file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":12.617600440979004, "class":"n04286575 spotlight, spot" }, { "probability":10.216148376464844, "class":"n03637318 lampshade, lamp shade" }, { "probability":9.676478385925293, "class":"n03942813 ping-pong ball" }, { "probability":9.529446601867676, "class":"n02708093 analog clock" }, { "probability":9.494606018066406, "class":"n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v1.mar", "model_name":"resnet50-v1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":177.36087036132812, "class":"n03982430 pool table, billiard table, snooker table" }, { "probability":174.36256408691406, "class":"n03942813 ping-pong ball" }, { "probability":172.44488525390625, "class":"n03661043 library" }, { "probability":163.6439971923828, "class":"n02788148 bannister, banister, balustrade, balusters, handrail" }, { "probability":159.4976043701172, "class":"n03065424 coil, spiral, volute, whorl, helix" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v1.mar", "model_name":"resnet101-v1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":107.36915588378906, "class":"n02823428 beer bottle" }, { "probability":99.51375579833984, "class":"n04485082 tripod" }, { "probability":95.99050903320312, "class":"n04069434 reflex camera" }, { "probability":84.3740463256836, "class":"n04557648 water bottle" }, { "probability":83.7496566772461, "class":"n02841315 binoculars, field glasses, opera glasses" } ], "tolerance":1 }, { 
"url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v1.mar", "model_name":"resnet152-v1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":61.71363067626953, "class":"n07930864 cup" }, { "probability":60.88921356201172, "class":"n03832673 notebook, notebook computer" }, { "probability":60.11431121826172, "class":"n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system" }, { "probability":53.656005859375, "class":"n07920052 espresso" }, { "probability":53.115779876708984, "class":"n03492542 hard disc, hard disk, fixed disk" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet18v2.mar", "model_name":"resnet18-v2", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":24.102333068847656, "class":"n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin" }, { "probability":23.69866943359375, "class":"n04254120 soap dispenser" }, { "probability":19.214643478393555, "class":"n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin" }, { "probability":18.93875503540039, "class":"n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system" }, { "probability":18.90488052368164, "class":"n04201297 shoji" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet34v2.mar", "model_name":"resnet34-v2", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":9.2193603515625, "class":"n02708093 analog clock" }, { "probability":7.803028583526611, "class":"n03782006 monitor" }, { "probability":7.681037425994873, 
"class":"n04286575 spotlight, spot" }, { "probability":7.129834175109863, "class":"n03028079 church, church building" }, { "probability":7.0597100257873535, "class":"n04152593 screen, CRT screen" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet50v2.mar", "model_name":"resnet50-v2", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":15.76949691772461, "class":"n06359193 web site, website, internet site, site" }, { "probability":14.25102710723877, "class":"n07565083 menu" }, { "probability":13.321331024169922, "class":"n03944341 pinwheel" }, { "probability":11.99173641204834, "class":"n06596364 comic book" }, { "probability":11.768353462219238, "class":"n03291819 envelope" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet101v2.mar", "model_name":"resnet101-v2", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":7.887596130371094, "class":"n04380533 table lamp" }, { "probability":7.870771884918213, "class":"n03627232 knot" }, { "probability":7.605418682098389, "class":"n02093859 Kerry blue terrier" }, { "probability":7.55618143081665, "class":"n04033995 quilt, comforter, comfort, puff" }, { "probability":7.256267547607422, "class":"n03884397 panpipe, pandean pipe, syrinx" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-resnet152v2.mar", "model_name":"resnet152-v2", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":36.13174819946289, "class":"n03970156 plunger, plumber's helper" }, { "probability":29.110092163085938, "class":"n02708093 analog clock" }, { 
"probability":28.711875915527344, "class":"n04152593 screen, CRT screen" }, { "probability":28.073928833007812, "class":"n01930112 nematode, nematode worm, roundworm" }, { "probability":28.058176040649414, "class":"n04404412 television, television system" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/shufflenet.mar", "model_name":"shufflenet", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.0010000000474974513, "class":"n03792972 mountain tent" }, { "probability":0.0010000000474974513, "class":"n03773504 missile" }, { "probability":0.0010000000474974513, "class":"n03775071 mitten" }, { "probability":0.0010000000474974513, "class":"n03775546 mixing bowl" }, { "probability":0.0010000000474974513, "class":"n03776460 mobile home, manufactured home" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-squeezenet.mar", "model_name":"onnx-squeezenet", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.6482629179954529, "class":"n02124075 Egyptian cat" }, { "probability":0.2318260818719864, "class":"n02123045 tabby, tabby cat" }, { "probability":0.10045160353183746, "class":"n02123159 tiger cat" }, { "probability":0.013487360440194607, "class":"n02127052 lynx, catamount" }, { "probability":0.003135664388537407, "class":"n02128757 snow leopard, ounce, Panthera uncia" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar", "model_name":"squeezenet_v1.1", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.8582231402397156, "class":"n02124075 Egyptian cat" }, { 
"probability":0.09159984439611435, "class":"n02123045 tabby, tabby cat" }, { "probability":0.03748767822980881, "class":"n02123159 tiger cat" }, { "probability":0.006165081635117531, "class":"n02128385 leopard, Panthera pardus" }, { "probability":0.0031715999357402325, "class":"n02127052 lynx, catamount" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/vgg16.mar", "model_name":"vgg16", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.8178463578224182, "class":"n02123159 tiger cat" }, { "probability":0.12500879168510437, "class":"n02123045 tabby, tabby cat" }, { "probability":0.05412120372056961, "class":"n02124075 Egyptian cat" }, { "probability":0.0020657714921981096, "class":"n02127052 lynx, catamount" }, { "probability":0.0005614628316834569, "class":"n02128757 snow leopard, ounce, Panthera uncia" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16.mar", "model_name":"onnx-vgg16", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":101.13871765136719, "class":"n02124075 Egyptian cat" }, { "probability":89.77296447753906, "class":"n02123045 tabby, tabby cat" }, { "probability":88.40411376953125, "class":"n02123159 tiger cat" }, { "probability":76.17413330078125, "class":"n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor" }, { "probability":74.09810638427734, "class":"n07930864 cup" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg16_bn.mar", "model_name":"onnx-vgg16_bn", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":16.06477165222168, 
"class":"n03196217 digital clock" }, { "probability":13.64653491973877, "class":"n04286575 spotlight, spot" }, { "probability":13.565534591674805, "class":"n03692522 loupe, jeweler's loupe" }, { "probability":13.479013442993164, "class":"n03388043 fountain" }, { "probability":12.639715194702148, "class":"n03666591 lighter, light, igniter, ignitor" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/vgg19.mar", "model_name":"vgg19", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.5058671832084656, "class":"n02123159 tiger cat" }, { "probability":0.28164851665496826, "class":"n02124075 Egyptian cat" }, { "probability":0.20637290179729462, "class":"n02123045 tabby, tabby cat" }, { "probability":0.0046674045734107494, "class":"n02127052 lynx, catamount" }, { "probability":0.00011857607023557648, "class":"n03598930 jigsaw puzzle" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19.mar", "model_name":"onnx-vgg19", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":0.3949933350086212, "class":"n02124075 Egyptian cat" }, { "probability":0.34417903423309326, "class":"n02123159 tiger cat" }, { "probability":0.25566166639328003, "class":"n02123045 tabby, tabby cat" }, { "probability":0.0027325130067765713, "class":"n02127052 lynx, catamount" }, { "probability":0.0002417804644210264, "class":"n07930864 cup" } ], "tolerance":1 }, { "url":"https://s3.amazonaws.com/model-server/model_archive_1.0/onnx-vgg19_bn.mar", "model_name":"onnx-vgg19_bn", "worker":1, "synchronous":"true", "file":"test/resources/kitten.jpg", "content-type":"application/json", "validator":"image_classification", "expected":[ { "probability":17.26873016357422, "class":"n04589890 window 
screen" }, { "probability":16.25399398803711, "class":"n03347037 fire screen, fireguard" }, { "probability":16.093460083007812, "class":"n04286575 spotlight, spot" }, { "probability":16.02733039855957, "class":"n04590129 window shade" }, { "probability":15.910074234008789, "class":"n03637318 lampshade, lamp shade" } ], "tolerance":1 } ] ================================================ FILE: test/postman/management_api_test_collection.json ================================================ { "info": { "_postman_id": "5c45fe29-4bb1-4df4-9e13-a4fd410e4775", "name": "MMS - management_api_collection", "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" }, "item": [ { "name": "Register Model", "event": [ { "listen": "test", "script": { "id": "4f706340-af84-431b-affd-6d6f53c89e80", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar&model_name=squeezenet1_1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" }, { "key": "model_name", "value": "squeezenet1_1" } ] } }, "response": [] }, { "name": "Get Valid Model", "event": [ { "listen": "test", "script": { "id": "3d6c7158-f729-42cf-a2ec-4e62b7d764d0", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/squeezenet1_1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", 
"squeezenet1_1" ] } }, "response": [] }, { "name": "List Models", "event": [ { "listen": "test", "script": { "id": "54f1209e-4819-4f06-9d26-f39593a889d9", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ] } }, "response": [] }, { "name": "Scale Min Workers - Asynchronous", "event": [ { "listen": "test", "script": { "id": "47ac7b07-04c2-4128-9662-387e26446f7f", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/squeezenet1_1?min_worker=3", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "squeezenet1_1" ], "query": [ { "key": "min_worker", "value": "3" } ] } }, "response": [] }, { "name": "Scale Min Workers - Synchronous", "event": [ { "listen": "test", "script": { "id": "3a8cf27e-b0d9-4df9-bdcf-9eadb239807d", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/squeezenet1_1?min_worker=4&synchronous=true", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "squeezenet1_1" ], "query": [ { "key": "min_worker", "value": "4" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "Scale Min Workers with GPU", "event": [ { "listen": "test", "script": { "id": 
"b73619d3-3abc-4628-8e52-ecbcda561191", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/squeezenet1_1?min_worker=6&number_gpu=1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "squeezenet1_1" ], "query": [ { "key": "min_worker", "value": "6" }, { "key": "number_gpu", "value": "1" } ] } }, "response": [] }, { "name": "UnRegister Model", "event": [ { "listen": "test", "script": { "id": "0e10d089-7183-4adf-ba81-d14678f5ecc7", "exec": [ "pm.test(\"Successful DELETE request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "DELETE", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/squeezenet1_1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "squeezenet1_1" ] } }, "response": [] }, { "name": "Register Model with Additional Params", "event": [ { "listen": "test", "script": { "id": "5f79bffa-f260-440d-9d49-a3553972b8eb", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar&model_name=squeezenet1_1&initial_workers=1&synchronous=true", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" }, { "key": "model_name", "value": "squeezenet1_1" }, { "key": "initial_workers", 
"value": "1" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "Register Model Synchronous", "event": [ { "listen": "test", "script": { "id": "74b42a61-dda4-4d70-9543-e61a5be2dd9b", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar&model_name=resnet18&synchronous=true", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar" }, { "key": "model_name", "value": "resnet18" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "UnRegister Model", "event": [ { "listen": "test", "script": { "id": "d74f53dc-5aac-4afc-8acf-849a309054e5", "exec": [ "pm.test(\"Successful DELETE request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "DELETE", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ] } }, "response": [] }, { "name": "Register Resnet Model", "event": [ { "listen": "test", "script": { "id": "11206673-aec5-4d26-a06c-96fab1fc81f3", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar&model_name=resnet18&synchronous=false", "protocol": "{{protocol}}", 
"host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar" }, { "key": "model_name", "value": "resnet18" }, { "key": "synchronous", "value": "false" } ] } }, "response": [] }, { "name": "List with Limit", "event": [ { "listen": "test", "script": { "id": "55acdbb9-cae8-4ad1-b2a4-b50a3549e8be", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?limit=1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "limit", "value": "1" } ] } }, "response": [] }, { "name": "List with Pagination", "event": [ { "listen": "test", "script": { "id": "d1b17cfc-90af-4b61-af07-081224ba7f81", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?limit=1&next_page_token=1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "limit", "value": "1" }, { "key": "next_page_token", "value": "1" } ] } }, "response": [] }, { "name": "Update GPU Count", "event": [ { "listen": "test", "script": { "id": "e00d8108-3ce9-4a02-9fe1-3e38907d7f7b", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?number_gpu=10", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": 
"{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "number_gpu", "value": "10" } ] } }, "response": [] }, { "name": "Scale up Workers - Synchronous", "event": [ { "listen": "test", "script": { "id": "f575937f-1665-4cf4-91a7-31c351eb9424", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?min_worker=5&max_worker=5&synchronous=true", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "min_worker", "value": "5" }, { "key": "max_worker", "value": "5" }, { "key": "synchronous", "value": "true" } ] } }, "response": [] }, { "name": "Scale up Workers - Asynchronous", "event": [ { "listen": "test", "script": { "id": "a65c2a28-de18-4ac5-a893-9aa2cf3cef77", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?min_worker=6&max_worker=6&synchronous=false", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "min_worker", "value": "6" }, { "key": "max_worker", "value": "6" }, { "key": "synchronous", "value": "false" } ] } }, "response": [] }, { "name": "Update Timeout to -1", "event": [ { "listen": "test", "script": { "id": "96e4be74-20d4-4343-b950-2738d6034341", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": 
"{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?timeout=-1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "timeout", "value": "-1" } ] } }, "response": [] }, { "name": "Update Timeout to 0", "event": [ { "listen": "test", "script": { "id": "d2c89755-a457-48ef-8dea-fdd61e88d1e5", "exec": [ "pm.test(\"Successful PUT request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200,201,202]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?timeout=0", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "timeout", "value": "0" } ] } }, "response": [] }, { "name": "Register Model - Invalid URL", "event": [ { "listen": "test", "script": { "id": "427b68f9-de96-49b8-9477-450fa580e257", "exec": [ "pm.test(\"Invalid URL POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([400]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "POST", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/invalid-resnet-18.mar&model_name=invalid-resnet18", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "url", "value": "https://s3.amazonaws.com/model-server/model_archive_1.0/invalid-resnet-18.mar" }, { "key": "model_name", "value": "invalid-resnet18" } ] } }, "response": [] }, { "name": "Get Model - Invalid Model", "event": [ { "listen": "test", "script": { "id": "36836026-4477-493b-898a-b642628a51d3", "exec": [ "pm.test(\"Valid ERROR message\", function () {", " pm.expect(pm.response.code).to.be.oneOf([404]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": 
[], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/invalid_squeezenet1_1", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "invalid_squeezenet1_1" ] } }, "response": [] }, { "name": "List Models - Invalid Next Page Token", "event": [ { "listen": "test", "script": { "id": "ccc07f12-6f45-4213-b4f5-e37f7d0ceadb", "exec": [ "pm.test(\"Successful POST request\", function () {", " pm.expect(pm.response.code).to.be.oneOf([200]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "GET", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models?next_page_token=12", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models" ], "query": [ { "key": "next_page_token", "value": "12" } ] } }, "response": [] }, { "name": "Update Worker with Invalid Worker Count", "event": [ { "listen": "test", "script": { "id": "ec7d233e-54bf-456a-a148-811385a34cc4", "exec": [ "pm.test(\"Valid ERROR message\", function () {", " pm.expect(pm.response.code).to.be.oneOf([400]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "PUT", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/resnet18?min_worker=10&max_worker=9", "protocol": "{{protocol}}", "host": [ "{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "resnet18" ], "query": [ { "key": "min_worker", "value": "10" }, { "key": "max_worker", "value": "9" } ] } }, "response": [] }, { "name": "UnRegister Invalid Model Name", "event": [ { "listen": "test", "script": { "id": "08052d51-fd23-4249-a17f-933e557bdacc", "exec": [ "pm.test(\"Valid ERROR message\", function () {", " pm.expect(pm.response.code).to.be.oneOf([404]);", "});" ], "type": "text/javascript" } } ], "request": { "method": "DELETE", "header": [], "url": { "raw": "{{protocol}}://{{hostname}}:{{mgmt-port}}/models/invalid_squeezenet1_1", "protocol": "{{protocol}}", "host": [ 
"{{hostname}}" ], "port": "{{mgmt-port}}", "path": [ "models", "invalid_squeezenet1_1" ] } }, "response": [] } ] } ================================================ FILE: test/regression_tests.sh ================================================ #!/bin/bash set -x set -e MMS_REPO="https://github.com/awslabs/multi-model-server.git" BRANCH=${1:-master} ROOT_DIR="/workspace/" CODEBUILD_WD=$(pwd) MODEL_STORE=$ROOT_DIR"/model_store" MMS_LOG_FILE="/tmp/mms.log" TEST_EXECUTION_LOG_FILE="/tmp/test_exec.log" install_mms_from_source() { echo "Cloning & Building Multi Model Server Repo from " $1 sudo apt-get -y install nodejs-dev node-gyp libssl1.0-dev sudo apt-get -y install npm sudo npm install -g n sudo n latest export PATH="$PATH" sudo npm install -g newman newman-reporter-html pip install mxnet-mkl # Clone & Build MMS echo "Installing MMS from source" git clone -b $2 $1 cd multi-model-server pip install . cd - echo "MMS Succesfully installed" } start_mms() { # Start MMS with Model Store multi-model-server --start --model-store $1 &>> $2 sleep 10 curl http://127.0.0.1:8081/models } stop_mms_serve() { multi-model-server --stop } start_secure_mms() { # Start MMS with Model Store multi-model-server --start --mms-config test/resources/config.properties --model-store $1 &>> $2 sleep 10 curl --insecure -X GET https://127.0.0.1:8444/models } run_postman_test() { # Run Postman Scripts mkdir $ROOT_DIR/report/ cd $CODEBUILD_WD/ set +e # Run Management API Tests stop_mms_serve start_mms $MODEL_STORE $MMS_LOG_FILE newman run -e test/postman/environment.json --bail --verbose test/postman/management_api_test_collection.json \ -r cli,html --reporter-html-export $ROOT_DIR/report/management_report.html >>$1 2>&1 # Run Inference API Tests after Restart stop_mms_serve start_mms $MODEL_STORE $MMS_LOG_FILE newman run -e test/postman/environment.json --bail --verbose test/postman/inference_api_test_collection.json \ -d test/postman/inference_data.json -r cli,html --reporter-html-export 
$ROOT_DIR/report/inference_report.html >>$1 2>&1 # Run Https test cases stop_mms_serve start_secure_mms $MODEL_STORE $MMS_LOG_FILE newman run --insecure -e test/postman/environment.json --bail --verbose test/postman/https_test_collection.json \ -r cli,html --reporter-html-export $ROOT_DIR/report/MMS_https_test_report.html >>$1 2>&1 stop_mms_serve set -e cd - } sudo rm -rf $ROOT_DIR && sudo mkdir $ROOT_DIR sudo chown -R $USER:$USER $ROOT_DIR cd $ROOT_DIR mkdir $MODEL_STORE sudo rm -f $TEST_EXECUTION_LOG_FILE $MMS_LOG_FILE echo "** Execuing MMS Regression Test Suite executon for " $MMS_REPO " **" install_mms_from_source $MMS_REPO $BRANCH run_postman_test $TEST_EXECUTION_LOG_FILE echo "** Tests Complete ** " exit 0 ================================================ FILE: test/resources/certs.pem ================================================ -----BEGIN CERTIFICATE----- MIICiDCCAfGgAwIBAgIEeC8zQzANBgkqhkiG9w0BAQsFADB2MQswCQYDVQQGEwJV UzETMBEGA1UECBMKQ2FsaWZvcm5pYTESMBAGA1UEBxMJUGFsbyBBbHRvMRIwEAYD VQQKEwlBbWF6b24gQUkxDjAMBgNVBAsTBU14TmV0MRowGAYDVQQDExFtbXMuYW1h em9uYXdzLmNvbTAgFw0xODA2MjAwMjExMjhaGA8yMTE3MDExMjAyMTEyOFowdjEL MAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExEjAQBgNVBAcTCVBhbG8g QWx0bzESMBAGA1UEChMJQW1hem9uIEFJMQ4wDAYDVQQLEwVNeE5ldDEaMBgGA1UE AxMRbW1zLmFtYXpvbmF3cy5jb20wgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGB AMcbCEP6kn9pUcap5+kYO/5xEl7SL965gSQ2TbFrVv+sLVkLSK8yTtcILr7RUINz FsD151Q7VyQCvpVzkOFew2s2mAFWWxPJYmxo1j/R3IkJakrrTrMy1R3jsqOQMrxY TLGR5LIe2pjdAnb9xWe2NB125619WDG7RrdHWZDfvSPxAgMBAAGjITAfMB0GA1Ud DgQWBBRWjdEyNchYAkdPoyudKJY9YP3JPzANBgkqhkiG9w0BAQsFAAOBgQBMAvqG cqvD3ColO2Ihgb/LCfCdV14e1YhusVFeKyZkSKFYyQR+MoBOxqMQqJ24gVzgqTU/ h+LkMqZcxxJAME08BzPgP5b06DBM4K0o0XUfYUViFpYXB0qCG5CA/0S7ONldBGaZ fv6JrnQ/a1NYBi92AaqXA4VmuaowWLVEFuPV1A== -----END CERTIFICATE----- ================================================ FILE: test/resources/config.properties ================================================ inference_address=https://127.0.0.1:8443 
management_address=https://127.0.0.1:8444 private_key_file=test/resources/key.pem certificate_file=test/resources/certs.pem ================================================ FILE: test/resources/key.pem ================================================ -----BEGIN RSA PRIVATE KEY----- MIICXAIBAAKBgQDHGwhD+pJ/aVHGqefpGDv+cRJe0i/euYEkNk2xa1b/rC1ZC0iv Mk7XCC6+0VCDcxbA9edUO1ckAr6Vc5DhXsNrNpgBVlsTyWJsaNY/0dyJCWpK606z MtUd47KjkDK8WEyxkeSyHtqY3QJ2/cVntjQdduetfVgxu0a3R1mQ370j8QIDAQAB AoGANRoxlyfSQKcPR2PzVUjAX3k6xA1c9RMWrVjKWeJd/qymH5SR2yAYxOMKzJu4 1IYycF5lRyLYd+M/f06mOmVyysH3D7hkrNz57Z07UrZ0dO/mmUKRL7zc44mo22ck JtQRwWJMplgew7N8OyqEZbcLOpahjlkL4+KZIWOuO7X5m30CQQDob/rzNY8gfhEm oEHHQ4dCqa/b5as2OqpFoGBZ+iX3dumBf+UKuSHlvEozt4ZMm29DYSjhiGXgLUFw 6NBhWxpXAkEA20oNdGiYAyyGJ6TKkD3FNZYoqB5+E/Cq6c0AACssB4OrJtiGiBFq R1h5HTEwYMe+ciZ4CI5MvBukjAdlfn7W9wJAXOIqyTe060oVdncB8ivlCFmgweHk ajZFRq+Q8UPKGjq1kx9VmtRiXFjC2inTjBds/eL8oCuOcmgDR6hxZQYv3wJAcMLv kECIinlGsvQGRY297wQ7+9dSNaa3/Gmx6mRIy8RlKiCFbUqnP/C6tswoeFu+DqzB ZITn6IK+ZlMXWaiXmQJBAK7V4rR+4VdpYUu1OqPRxChkcM+Y4Wa985A46/8yoo3i 0PEenvApVzhVzS9jF6WEqIKcffBAmOxaXOn3kmn8w2A= -----END RSA PRIVATE KEY----- ================================================ FILE: tests/performance/README.md ================================================ # Performance Regression Suite ## Motivation The goal of this test suite is to ensure that performance regressions are detected early on. Ideally, with every commit made into the source control system. The salient features of the performance regression suite are * Non-intrusive - Does not need any code-changes or instrumentation on the server being monitored. * It can be used to monitor a wide variety of server metrics - memory, cpu, io - in addition to traditional API level metrics such as latency, throughput etc. * It is easy to add custom metrics. For example, in MMS server, `the number of workers spawned` would be an interesting metric to track. The platform allows for easy addition of these metrics. 
* Test cases are specified in human readable yaml files. Every test case has a pass or fail status. This is determined by evaluating expressions specified in the test case. Every expression checks metrics against threshold values. For example, `memory consumed by all workers < 500M`, `number of worker processes < 3`. * Test cases execute against compute environments. The threshold values are specific to the compute environment. It is possible to specify multiple compute environments against which the test cases will run. It follows that each compute environment, will have its own threshold values. * This suite leverages the open source [Taurus framework](https://gettaurus.org/). * This suite extends the Taurus framework in the following ways * Adds resource monitoring service. This allows MMS specific metrics to be added. * Environments as described earlier. * Specification of pass/fail criterion between two commits. For example, memory consumed by workers should not increase by more than 10% between two commits for the given test case. * Custom reporting of results. The building blocks of the performance regression suite and flow is captured in the following drawing ![](assets/blocks.png) ## Quickstart ### A. Installation 1. Install Taurus. Refer the [link](https://gettaurus.org/docs/Installation/) for more details on installation. ```bash pip install bzt # Needs python3.6+ ``` 2. Install performance regression suite dependencies. ```bash export MMS_HOME= pip install -r $MMS_HOME/tests/performance/requirements.txt ``` 3. Make sure that `git` is installed and the test suites are run from the MMS working directory. ### B. Running the test suite 1. Make sure parameters set in [tests/common/global_config.yaml](tests/performance/tests/global_config.yaml) are correct. 2. To run the test suite execute [run_performance_suite.py](run_performance_suite.py) with the following parameters * `--artifacts-dir` or `-a` is a directory where the test case results will be stored. 
The default value is `$MMS_HOME/tests/performance/run_artifacts`. * `--test-dir` or `-t` is a directory containing the test cases. The default value is `$MMS_HOME/tests/performance/tests`. * `--pattern` or `-p` glob pattern picks up certain test cases for execution within the `test-dir`. The default value picks up all test cases. * `--exclude-pattern` or `-x` glob pattern excludes certain test cases for execution within the `test-dir`. The default value excludes nothing. * `--env-name` or `-e` specifies the environment name to use while running the test cases. The environment name is the name of the file (minus the extension) found inside the environments folder in each test case. They encapsulate parameter values which are specific to the execution environment. This is a mandatory parameter. The script does the following: 1. Starts the metrics monitoring server. 2. Collects all the tests from test-dir satisfying the pattern 3. Executes the tests 4. Generates artifacts in the artifacts-dir against each test case. 3. Check the console logs, $artifacts-dir$//performance_results.html report, comparison.csv, comparison.html and other artifacts. **Steps are provided below** ```bash export MMS_HOME= cd $MMS_HOME/tests/performance # Note that MMS server started and stopped by the individual test suite. # check variables such as MMS server PORT etc # vi tests/common/global_config.yaml #all tests python -m run_performance_suite -e xlarge #run a specific test python -m run_performance_suite -e xlarge -p inference_single_worker ``` ### C. Understanding the test suite artifacts and reports 1. The $artifacts-dir$//performance_results.html is a summary report of the test run. 2. Each test yaml is treated as a test suite. Each criteria in the test suite is treated as a test case. If the test suite does not specify any criteria, then the test suite is reported as skipped with 0 test cases. 3. For each test suite, a sub-directory is created containing relevant run artifacts. 
Important files in this directory are * metrics.csv -- contains the values of the various system-monitored metrics over time * finals_stats.csv -- contains the values of the various api metrics over time 4. The $artifacts-dir$//comparison_results.html is a summary report which shows performance difference between the last two commits. 5. The run completes with a console summary of the performance and comparision suites which have failed ![](assets/console.png) ## Add a new test Follow these three steps to add a new test case to the test suite. 1. Add scenario (a.k.a test suite) 2. Add metrics to monitor 3. Add pass/fail criteria (a.k.a test case) #### 1. Add scenario (a.k.a test suite) Create a folder for the test under `test_dir` location. A test generally comprises of a jmeter file - containing the load scenario and a yaml file which contains test scenarios specifying the conditions for failure or success. The file-names should be identical to the folder name with their respective extensions. An example [jmeter script](tests/examples_starter/examples_starter.jmx) and a [scenario](tests/examples_starter/examples_starter.yaml) is provided as a template to get started. Please note that various global configuration settings used by examples_starter.jmx script are specified in [tests/global_config.yaml](tests/performance/tests/global_config.yaml) file. ```tests/examples_starter/examples_starter.yaml execution: - concurrency: 1 ramp-up: 1s hold-for: 40s scenario: Inference scenarios: Inference: script: examples_starter.jmx ``` To execute this test suite, run the following command ```bash export MMS_HOME= cd $MMS_HOME/tests/performance python -m run_performance_suite -p examples_starter -e xlarge ``` **Note**: Taurus provides support for different executors such as JMeter. Supported executor types can be found [here](https://gettaurus.org/docs/ExecutionSettings/). Details about how to use an existing JMeter script are provided [here](https://gettaurus.org/docs/JMeter/). 
#### 2. Add metrics to monitor Specify the metrics of interest in the services/monitoring section of the yaml. 1. Standalone monitoring server Use this technique if MMS and the tests execute on different machines. Before running the test cases, please start the [metrics_monitoring_server.py](metrics_monitoring_server.py) script. It will communicate server metric data with the test client over sockets. The monitoring server runs on port 9009 by default. To start the monitoring server, run the following commands on the MMS host: ```bash export MMS_HOME= pip install -r $MMS_HOME/tests/performance/requirements.txt python $MMS_HOME/tests/performance/metrics_monitoring_server.py --start ``` The monitoring section configuration is shown below. ```yaml services: - module: monitoring server-agent: - address: :9009 # metric monitoring service address label: mms-inference-server # Specified label will be used in reports instead of ip:port interval: 1s # polling interval logging: True # those logs will be saved to "SAlogs_192.168.0.1_9009.csv" in the artifacts dir metrics: # metrics should be supported by monitoring service - sum_cpu_percent # cpu percent used by all the mms server processes and workers - sum_memory_percent - sum_num_handles - server_workers # no of mms workers ``` The complete yaml can be found [here](tests/examples_remote_monitoring/examples_remote_monitoring.yaml) Use the command below to run the test suite. ```bash export MMS_HOME= cd $MMS_HOME/tests/performance python -m run_performance_suite -p examples_remote_monitoring -e xlarge ``` 2. Local monitoring plugin Use this technique if both MMS and the tests run on the same host. The monitoring section configuration is shown below. ```yaml modules: server_local_monitoring: # metrics_monitoring_taurus and dependencies should be in python path class : metrics_monitoring_taurus.Monitor # monitoring class. 
services: - module: server_local_monitoring # should be added in modules section ServerLocalClient: # keyword from metrics_monitoring_taurus.Monitor - interval: 1s metrics: - cpu - disk-space - mem - sum_memory_percent ``` The complete yaml can be found [here](tests/examples_local_monitoring/examples_local_monitoring.yaml). Use the command below to run the test suite. ```bash export MMS_HOME= cd $MMS_HOME/tests/performance python -m run_performance_suite -p examples_local_monitoring -e xlarge ``` #### 3. Add pass/fail criteria (a.k.a test case) 1. **Specify the pass/fail criteria**. Each pass-fail criterion maps to a test case in the generated report. We leverage the pass-fail module from Taurus to achieve this functionality. More details can be found [here](https://gettaurus.org/docs/PassFail/). A sample criterion is shown below ```yaml reporting: - module: passfail criteria: - class: bzt.modules.monitoring.MonitoringCriteria subject: mms-inference-server/sum_num_handles condition: '>' threshold: 180 timeframe: 1s fail: true stop: true ``` 2. Specify the pass/fail criterion vis-a-vis a prior run. On completion, the test suite runner script compares the monitoring metrics with values from a previous run which was executed on same environment. The run results are stored in either a local folder or a S3 bucket based on the `compare-local` option. Metrics which have 'diff_percent' value specified in the pass/fail criterion are used for comparison with the previous run. A sample criterion is shown below ```yaml reporting: - module: passfail criteria: - class: bzt.modules.monitoring.MonitoringCriteria subject: mms-inference-server/sum_num_handles condition: '>' threshold: 180 timeframe: 1s fail: true stop: true diff_percent : 30 ``` Note that 1. At least one test suite run on the same environment should have happened in order to do the comparison. 2. 
The $artifacts-dir$//comparison_results.html is a summary report which shows performance difference between the last two commits. 3. The test case fails if the diff_percent is greater than the specified value across runs. 3. Metrics available for pass-fail criteria **System Metrics** > disk_used, memory_percent, read_count, write_count, read_bytes, write_byte | Syntax | Examples | | ------ | -------- | | system_{metricname} | system_disk_used, system_memory_percent, system_write_count | **Process Metrics** > cpu_percent, memory_percent, cpu_user_time, cpu_system_time, cpu_iowait_time, memory_rss, memory_vms, io_read_count, io_write_count, io_read_bytes, io_write_bytes, file_descriptors, threads - Frontend. Represents the Java process hosting the REST APIs | Syntax | Examples | | ------ | -------- | | frontend_{metricname} | frontend_cpu_percent, frontend_memory_percent, frontend_cpu_iowait_time, frontend_memory_rss, frontend_io_write_bytes, frontend_threads | - Workers. Represents the python worker processes. Metrics for worker(s) are always available with an aggregate > Aggregates > sum, avg, min, max | Syntax | Examples | | ------ | -------- | | {aggregate}\_workers\_{metricname} | total_workers, sum_workers_memory_percent, avg_workers_iowait_time, min_workers_io_write_bytes, max_workers_threads | - All (Frontend + Workers). Represents aggregate metrics for both frontend and worker processes. | Syntax | Examples | | ------ | -------- | | {aggregate}\_all\_{metricname} | sum_all_memory_percent, avg_all_iowait_time, min_all_io_write_bytes, max_all_threads | - Miscellaneous * total_processes - Total number of processes spawned for frontend & workers * total_workers - Total number of workers spawned * orphans - Total number of orphan processes ## Test Strategy & Cases More details about our testing strategy and test cases can be found [here](TESTS.md) ## FAQ Q1. Is it possible to use the performance regression framework to test MMS on Python2.7? Yes. 
Even though the performance regression framework itself needs Python 3.7+ (as Taurus requires Python 3.7+), there are two possible ways to achieve this:
* Create a Python 2.7 virtual env which runs MMS and a Python 3.7 virtual env which runs the test framework and the test cases.
* Alternatively, deploy the standalone monitoring agent on the MMS instance and run the test cases against the remote server. Note that the standalone monitoring agent works on both Python 2 and 3.
================================================
FILE: tests/performance/TESTS.md
================================================
|CODE|Test Types|Comments|
|----|----------|--------|
|A|Operation time is deterministic for given inputs and env configuration|Model Inference Time = (Wait time for worker to be free) + (Time taken by worker to infer)|
|B|Minimum drift in user SLA and system metrics over time|Current architecture doesn't support recycling of workers. Need to check whether params specified in sheet 2 remain within acceptable bounds|
|C|Demonstrate performance isolation across registered models|User-uploaded handlers cannot cause denial of service (DoS) for other model-type inference|
|D|Operations should scale linearly on node|Time to scale N workers == (Time to scale 1 worker) * N|
|E|Demonstrate expected service concurrency commensurate with environment configuration|Set up the environment in a way that minimizes false positives/negatives|
|F|Demonstrate preservation of performance characteristics|Do not spawn additional workers or accept model registrations if they hamper system SLAs. Is this implemented currently?
| |G |Demonstrate cleanup releases system resources |Unregistering model should free up commensurate resources held by workers | |H |Demonstrate that cleanup/termination of operations should be graceful |Scale down should wait for current inference operation to succeed. Is this current behavior? | |I |Demonstrate that operations rollback in case request cannot be satisifed |Ongoing inference should complete before scaledown operation is allowed to start | |J |Demonstrate that operations are idempotent |Multiple simultaneous scale operations with the same parameter value should result in the same system state | |API|CODE|YAML| |---|----|---| |Register Model|A,B,F,J|[register_unregister](tests/register_unregister)| |Inference|A,B,C|[inference_single_worker](tests/inference_single_worker), [inference_multiple_worker](tests/inference_multiple_worker)| |Batch Inference|A,B,C|[batch_inference](tests/batch_inference)| |Custom Model Handlers|C|[batch_and_single_inference](tests/batch_and_single_inference)| |Scale Workers - UP/DOWN|D,G,I,F,J|[scale_up_workers](tests/scale_up_workers), [scale_down_workers](tests/scale_down_workers)| |Unregister Models|D,G,I,J|[register_unregister](tests/register_unregister)| |Health Check|A,B,E|[health_check](tests/health_check)| |API Description|A,B,E|[api_description](tests/api_description)| |Model Describe|A,B,E|[model_description](tests/model_description)| |List Models|A,B,E|[list_models](tests/list_models)| ================================================ FILE: tests/performance/agents/__init__.py ================================================ ================================================ FILE: tests/performance/agents/config.ini ================================================ [server] pid_file = model_server.pid [monitoring] HOST = PORT = 9009 [suite] s3_bucket = mms-performance-regression-reports ================================================ FILE: tests/performance/agents/configuration.py 
================================================ #!/usr/bin/env python3 # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Read configuration file """ # pylint: disable=redefined-builtin, bare-except import os import configparser import pathlib config = configparser.ConfigParser() path = pathlib.Path(__file__).parent.absolute() config.read(os.path.join(path, 'config.ini')) def get(section, key, default=''): try: return config[section][key] except: return default ================================================ FILE: tests/performance/agents/metrics/__init__.py ================================================ #!/usr/bin/env python3 """ Customised system and mms process metrics for monitoring and pass-fail criteria in taurus""" # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
from enum import Enum from statistics import mean import psutil from psutil import NoSuchProcess, ZombieProcess class ProcessType(Enum): """ Type of MMS processes to compute metrics on """ FRONTEND = 1 WORKER = 2 ALL = 3 operators = { 'sum': sum, 'avg': mean, 'min': min, 'max': max } process_metrics = { # cpu 'cpu_percent': lambda p: p.get('cpu_percent', 0), 'cpu_user_time': lambda p: getattr(p.get('cpu_times', {}), 'user', 0), 'cpu_system_time': lambda p: getattr(p.get('cpu_times', {}), 'system', 0), 'cpu_iowait_time': lambda p: getattr(p.get('cpu_times', {}), 'iowait', 0), # memory 'memory_percent': lambda p: p.get('memory_percent', 0), 'memory_rss': lambda p: getattr(p.get('memory_info', {}), 'rss', 0), 'memory_vms': lambda p: getattr(p.get('memory_info', {}), 'vms', 0), # io 'io_read_count': lambda p: getattr(p.get('io_counters', {}), 'read_count', 0), 'io_write_count': lambda p: getattr(p.get('io_counters', {}), 'write_count', 0), 'io_read_bytes': lambda p: getattr(p.get('io_counters', {}), 'read_bytes', 0), 'io_write_bytes': lambda p: getattr(p.get('io_counters', {}), 'write_bytes', 0), 'file_descriptors': lambda p: p.get('num_fds', 0), # processes 'threads': lambda p: p.get('num_threads', 0) } system_metrics = { 'system_disk_used': None, 'system_memory_percent': None, 'system_read_count': None, 'system_write_count': None, 'system_read_bytes': None, 'system_write_bytes': None, } misc_metrics = { 'total_processes': None, 'total_workers': None, 'orphans': None } AVAILABLE_METRICS = list(system_metrics) + list(misc_metrics) WORKER_NAME = 'model_service_worker.py' for metric in list(process_metrics): for ptype in list(ProcessType): if ptype == ProcessType.WORKER: PNAME = 'workers' for op in list(operators): AVAILABLE_METRICS.append('{}_{}_{}'.format(op, PNAME, metric)) elif ptype == ProcessType.FRONTEND: PNAME = 'frontend' AVAILABLE_METRICS.append('{}_{}'.format(PNAME, metric)) else: PNAME = 'all' for op in list(operators): 
AVAILABLE_METRICS.append('{}_{}_{}'.format(op, PNAME, metric)) children = set() def get_metrics(server_process, child_processes, logger): """ Get Server processes specific metrics """ result = {} children.update(child_processes) logger.debug("children : {0}".format(",".join([str(c.pid) for c in children]))) def update_metric(metric_name, proc_type, stats): stats = list(filter(lambda x: isinstance(x, (float, int)), stats)) stats = stats if len(stats) else [0] if proc_type == ProcessType.WORKER: proc_name = 'workers' elif proc_type == ProcessType.FRONTEND: proc_name = 'frontend' result[proc_name + '_' + metric_name] = stats[0] return else: proc_name = 'all' for op_name in operators: result['{}_{}_{}'.format(op_name, proc_name, metric_name)] = operators[op_name](stats) processes_stats = [] reclaimed_pids = [] try: # as_dict() gets all stats in one shot processes_stats.append({'type': ProcessType.FRONTEND, 'stats': server_process.as_dict()}) except: pass for child in children: try: child_cmdline = child.cmdline() if psutil.pid_exists(child.pid) and len(child_cmdline) >= 2 and WORKER_NAME in child_cmdline[1]: processes_stats.append({'type': ProcessType.WORKER, 'stats': child.as_dict()}) else: reclaimed_pids.append(child) logger.debug('child {0} no longer available'.format(child.pid)) except (NoSuchProcess, ZombieProcess): reclaimed_pids.append(child) logger.debug('child {0} no longer available'.format(child.pid)) for p in reclaimed_pids: children.remove(p) ### PROCESS METRICS ### worker_stats = list(map(lambda x: x['stats'], \ filter(lambda x: x['type'] == ProcessType.WORKER, processes_stats))) server_stats = list(map(lambda x: x['stats'], \ filter(lambda x: x['type'] == ProcessType.FRONTEND, processes_stats))) all_stats = list(map(lambda x: x['stats'], processes_stats)) for k in process_metrics: update_metric(k, ProcessType.WORKER, list(map(process_metrics[k], worker_stats))) update_metric(k, ProcessType.ALL, list(map(process_metrics[k], all_stats))) update_metric(k, 
ProcessType.FRONTEND, list(map(process_metrics[k], server_stats))) # Total processes result['total_processes'] = len(worker_stats) + 1 result['total_workers'] = max(len(worker_stats) - 1, 0) result['orphans'] = len(list(filter(lambda p: p['ppid'] == 1, worker_stats))) ### SYSTEM METRICS ### result['system_disk_used'] = psutil.disk_usage('/').used result['system_memory_percent'] = psutil.virtual_memory().percent system_disk_io_counters = psutil.disk_io_counters() result['system_read_count'] = system_disk_io_counters.read_count result['system_write_count'] = system_disk_io_counters.write_count result['system_read_bytes'] = system_disk_io_counters.read_bytes result['system_write_bytes'] = system_disk_io_counters.write_bytes return result ================================================ FILE: tests/performance/agents/metrics_collector.py ================================================ #!/usr/bin/env python3 # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Server metrics collector """ # pylint: disable=redefined-builtin, broad-except, unused-variable import argparse import logging import os import sys import tempfile import time import gevent import psutil from utils.process import get_process_pid_from_file, get_child_processes, \ get_server_processes, get_server_pidfile from metrics import AVAILABLE_METRICS, get_metrics import configuration logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) TMP_DIR = tempfile.gettempdir() METRICS_LOG_FILE = os.path.join(TMP_DIR, "server_metrics_{}.log".format(int(time.time()))) METRICS_COLLECTOR_PID_FILE = os.path.join(TMP_DIR, "metrics_collector.pid") PID_FILE = configuration.get('server', 'pid_file', 'model_server.pid') MONITOR_INTERVAL = 1 def store_pid(pid_file): """ Store the current process id to pid_file""" process = psutil.Process() pid_file = os.path.join(pid_file) with open(pid_file, "w") as pf: pf.write(str(process.pid)) def stop_process(pid_file): """This will stop already running process . Note at a time only one pid file will be available. """ pid = get_process_pid_from_file(pid_file) if pid: try: process = psutil.Process(pid) if process.is_running(): logger.info("Process with pid %s is running. Killing it.", process.pid) process.kill() except Exception as e: pass else: logger.info("Dead process with pid %s found in '%s'.", process.pid, pid_file) logger.info("Removing pid file '%s'.", pid_file) os.remove(pid_file) def check_is_running(pid_file): """check if pid is running""" pid = get_process_pid_from_file(pid_file) if pid: try: perf_mon_process = psutil.Process(pid) except Exception as e: stop_process(pid_file) else: if perf_mon_process.is_running(): logger.error("Performance monitoring script already running. 
" "Stop it using stop option.") sys.exit() def store_metrics_collector_pid(): """ Store the process id of metrics collector process""" store_pid(METRICS_COLLECTOR_PID_FILE) def stop_metrics_collector_process(): """This will stop already running metrics collector process. Note at a time only one pid file will be available. """ stop_process(METRICS_COLLECTOR_PID_FILE) def monitor_processes(server_process, metrics, interval, socket): """ Monitor the metrics of server_process and its child processes """ while True: message = [] collected_metrics = get_metrics(server_process, get_child_processes(server_process), logger) metrics_msg = [] for metric in metrics: message.append(str(collected_metrics.get(metric, 0))) if collected_metrics.get(metric) is not None: metrics_msg.append("{0} : {1}".format(metric, collected_metrics.get(metric, 0))) message = "\t".join(message) + "\t\n" logger.info("%s", " -- ".join(metrics_msg)) if socket: try: socket.send(message.encode("latin-1")) except BrokenPipeError: logger.info("Stopping monitoring as socket connection is closed.") break # TODO - log metrics to a file METRICS_LOG_FILE if METRICS_LOG_FILE is provided gevent.sleep(interval) def start_metric_collection(server_process, metrics, interval, socket): bad_metrics = set(metrics) - set(AVAILABLE_METRICS) if bad_metrics: raise Exception("Metrics not available for monitoring {}.".format(bad_metrics)) logger.info("Started metric collection for target server processes.....") thread = gevent.spawn(monitor_processes, server_process, metrics, interval, socket) gevent.joinall([thread]) def start_metric_collector_process(): """Spawn a metric collection process and keep on monitoring """ check_is_running(METRICS_COLLECTOR_PID_FILE) store_metrics_collector_pid() server_pid = get_process_pid_from_file(get_server_pidfile(PID_FILE)) server_process = get_server_processes(server_pid) start_metric_collection(server_process, AVAILABLE_METRICS, MONITOR_INTERVAL, None) if __name__ == "__main__": 
    # CLI entry point: plain log messages to stdout, --start XOR --stop.
    logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser(prog='metric-collector',
                                     description='System Performance Metrics collector')
    sub_parse = parser.add_mutually_exclusive_group(required=True)
    sub_parse.add_argument('--start', action='store_true', help='Start the metric-collector')
    sub_parse.add_argument('--stop', action='store_true', help='Stop the metric-collector')

    args = parser.parse_args()
    if args.start:
        start_metric_collector_process()
    elif args.stop:
        stop_metrics_collector_process()



================================================
FILE: tests/performance/agents/metrics_monitoring_inproc.py
================================================
#!/usr/bin/env python3

# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

"""
Taurus Local plugin for server monitoring.
Should be used when server and Taurus are running on same machine.
This file should be placed in Python Path along with monitoring package.
"""
# pylint: disable=redefined-builtin, unnecessary-comprehension

import csv
import sys

from bzt import TaurusConfigError
from bzt.modules import monitoring
from bzt.utils import dehumanize_time

import configuration
from metrics import get_metrics, AVAILABLE_METRICS as AVAILABLE_SERVER_METRICS
from utils.process import get_process_pid_from_file, get_server_processes, \
    get_child_processes, get_server_pidfile

PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3

PID_FILE = configuration.get('server', 'pid_file', 'model_server.pid')


class Monitor(monitoring.Monitoring):
    """Add ServerLocalClient to Monitoring by patching to monitoring.Monitoring """

    def __init__(self):
        super(Monitor, self).__init__()
        # Register our client class so Taurus configs can refer to it by name.
        self.client_classes.update({'ServerLocalClient': ServerLocalClient})


class ServerLocalClient(monitoring.LocalClient):
    """Custom server local client """
    # Taurus host metrics plus the server-process metrics defined in `metrics`.
    AVAILABLE_METRICS = monitoring.LocalClient.AVAILABLE_METRICS + \
                        AVAILABLE_SERVER_METRICS

    def __init__(self, parent_log, label, config, engine=None):
        super(ServerLocalClient, self).__init__(parent_log, label, config, engine=engine)
        if label:
            self.label = label
        else:
            self.label = 'ServerLocalClient'

    def connect(self):
        """Validate configured metric names and set up the monitor.

        Unknown metric names are warned about and dropped; if none remain,
        a TaurusConfigError is raised.
        """
        exc = TaurusConfigError('Metric is required in Local monitoring client')
        metric_names = self.config.get('metrics', exc)

        bad_list = set(metric_names) - set(self.AVAILABLE_METRICS)
        if bad_list:
            self.log.warning('Wrong metrics found: %s', bad_list)

        good_list = set(metric_names) & set(self.AVAILABLE_METRICS)
        if not good_list:
            raise exc

        self.metrics = list(set(good_list))

        self.monitor = ServerLocalMonitor(self.log, self.metrics, self.engine)
        self.interval = dehumanize_time(self.config.get("interval", self.engine.check_interval))

        if self.config.get("logging", False):
            # csv.writer(..., newline='') handling requires Python 3.
            if not PY3:
                self.log.warning("Logging option doesn't work on python2.")
            else:
                self.logs_file = self.engine.create_artifact("local_monitoring_logs", ".csv")
                # Write the CSV header row (timestamp followed by sorted metric names).
                with open(self.logs_file, "a", newline='') as mon_logs:
                    logs_writer = csv.writer(mon_logs, delimiter=',')
                    metrics = ['ts'] + sorted([metric for metric in good_list])
                    logs_writer.writerow(metrics)


class ServerLocalMonitor(monitoring.LocalMonitor):
    """Custom server local monitor"""

    def _calc_resource_stats(self, interval):
        """Merge server-process metrics into Taurus' host stats and filter
        the result down to the metrics this monitor was configured with."""
        result = super()._calc_resource_stats(interval)
        server_pid = get_process_pid_from_file(get_server_pidfile(PID_FILE))
        server_process = get_server_processes(server_pid)
        result.update(get_metrics(server_process, get_child_processes(server_process), self.log))
        metrics_msg = []
        updated_result = {}
        for key in self.metrics:
            if result.get(key) is not None:
                metrics_msg.append("{0} : {1}".format(key, result[key]))
                updated_result[key] = result.get(key)
        self.log.info("{0}".format(" -- ".join(metrics_msg)))
        return updated_result



================================================
FILE: tests/performance/agents/metrics_monitoring_server.py
================================================
#!/usr/bin/env python3

# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
""" Remote server monitoring script """ # pylint: disable=redefined-builtin, wrong-import-position, too-many-nested-blocks, broad-except import argparse import logging import sys import tempfile import os from gevent import monkey from gevent import select from gevent import socket monkey.patch_select() monkey.patch_socket() from metrics_collector import start_metric_collection, stop_process, store_pid, check_is_running from utils.process import get_process_pid_from_file, \ get_server_processes, get_server_pidfile import configuration logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) TMP_DIR = tempfile.gettempdir() METRICS_MON_SERVER_PID_FILE = os.path.join(TMP_DIR, ".metrics_monitoring_server.pid") PID_FILE = configuration.get('server', 'pid_file', 'model_server.pid') HOST = str(configuration.get('monitoring', 'HOST')) PORT = int(configuration.get('monitoring', 'PORT', 9009)) SOCKET_LIST = [] RECV_BUFFER = 4096 interval = 1 def process_data(sock): """ process data recieved on socket""" # receiving data from the socket. 
data = sock.recv(RECV_BUFFER).decode() if data: if data == 'test\n': send_message(sock, "Yep\n") elif data == 'exit\n': close_socket(sock) elif data.startswith('interval'): try: global interval interval = int(data.split(":")[1][:-1]) except Exception: send_message(sock, "In-correct interval data") elif data.startswith('metrics'): metrics = data[:-1].split("metrics:")[1].split("\t") server_pid = get_process_pid_from_file(get_server_pidfile(PID_FILE)) server_process = get_server_processes(server_pid) start_metric_collection(server_process, metrics, interval, sock) else: # TODO - decide what to do here pass else: # remove the socket that's broken if sock in SOCKET_LIST: SOCKET_LIST.remove(sock) def perf_server(): """ start performance moniting server on a socket """ server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server_socket.bind((HOST, PORT)) server_socket.listen(10) SOCKET_LIST.append(server_socket) logger.info("Started metrics monitoring server on port %s", PORT) while True: ready_to_read, _, _ = select.select(SOCKET_LIST, [], [], 0) for sock in ready_to_read: # a new connection request recieved if sock == server_socket: sockfd, addr = server_socket.accept() SOCKET_LIST.append(sockfd) logger.info("client (%s, %s) connected", addr[0], addr[1]) # a message from a client, not a new connection else: try: process_data(sock) except Exception as e: logger.warning("Error %s", str(e)) continue server_socket.close() def send_message(socket_, message): try: socket_.send(message.encode("latin-1")) except Exception as e: logger.warning("Error while sending the message %s. 
Closing the socket.", str(e)) close_socket(socket_) def close_socket(socket_): socket_.close() if socket_ in SOCKET_LIST: SOCKET_LIST.remove(socket_) if __name__ == "__main__": logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) parser = argparse.ArgumentParser(prog='perf-mon-script', description='System Performance Monitoring') sub_parse = parser.add_mutually_exclusive_group(required=True) sub_parse.add_argument('--start', action='store_true', help='Start the perf-mon-script') sub_parse.add_argument('--stop', action='store_true', help='Stop the perf-mon-script') args = parser.parse_args() if args.start: check_is_running(METRICS_MON_SERVER_PID_FILE) store_pid(METRICS_MON_SERVER_PID_FILE) perf_server() elif args.stop: stop_process(METRICS_MON_SERVER_PID_FILE) ================================================ FILE: tests/performance/agents/utils/__init__.py ================================================ ================================================ FILE: tests/performance/agents/utils/process.py ================================================ #!/usr/bin/env python3 # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Utility methods for process related information """ # pylint: disable=redefined-builtin import os import tempfile import psutil def find_procs_by_name(name): """Return a list of processes matching 'name'.""" ls = [] for p in psutil.process_iter(["name", "exe", "cmdline"]): if name == p.info['name'] or \ p.info['exe'] and os.path.basename(p.info['exe']) == name or \ p.info['cmdline'] and p.info['cmdline'][0] == name: ls.append(p) if len(ls) > 1: raise Exception("Multiple processes found with name {}.".format(name)) return ls[0] def get_process_pid_from_file(file_path): """Get the process pid from pid file. """ pid = None if os.path.isfile(file_path): with open(file_path, "r") as f: pid = int(f.readline()) return pid def get_child_processes(process): """Get all running child processes recursively""" child_processes = set() for p in process.children(recursive=True): child_processes.add(p) return child_processes def get_server_processes(server_process_pid): """ It caches the main server and child processes at module level. Ensure that you call this process so that MMS process """ try: server_process = psutil.Process(server_process_pid) except Exception as e: print("Server process not found. Error: {}".format(str(e))) raise return server_process def get_server_pidfile(file): return os.path.join(tempfile.gettempdir(), ".{}".format(file)) ================================================ FILE: tests/performance/pylintrc ================================================ [MASTER] # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # Specify a configuration file. #rcfile= # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). #init-hook= # Add files or directories to the blacklist. They should be base names, not # paths. ignore=CVS # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. ignore-patterns= # Pickle collected data for later comparisons. persistent=yes # List of plugins (as comma separated values of python modules names) to load, # usually to register additional checkers. load-plugins= # Use multiple processes to speed up Pylint. jobs=8 # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code extension-pkg-whitelist=numpy,opencv # Allow optimization of some AST trees. This will activate a peephole AST # optimizer, which will apply various small optimizations. For instance, it can # be used to obtain the result of joining multiple strings with the addition # operator. Joining a lot of strings can lead to a maximum recursion error in # Pylint and this flag can prevent that. It has one side effect, the resulting # AST will be different than the one from reality. This option is deprecated # and it will be removed in Pylint 2.0. 
optimize-ast=no [MESSAGES CONTROL] # Only show warnings with the listed confidence levels. Leave empty to show # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED confidence= # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. enable=indexing-exception,old-raise-syntax # Disable the message, report, category or checker with the given id(s). You # can either give multiple identifiers separated by comma (,) or put this # option multiple times (only on the command line, not in the configuration # file where it should appear only once).You can also use "--disable=all" to # disable everything first and then reenable specific checks. For example, if # you want to run only the similarities checker, you can use "--disable=all # --enable=similarities". 
If you want to run only the classes checker, but have # no Warning level messages displayed, use"--disable=all --enable=classes # --disable=W" disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,protected-access,superfluous-parens,invalid-name,no-else-return,useless-super-delegation,len-as-condition,invalid-unary-operand-type,useless-object-inheritance # disable=unicode-builtin,delslice-method,using-cmp-argument,setslice-method,dict-view-method,parameter-unpacking,range-builtin-not-iterating,print-statement,file-builtin,old-raise-syntax,basestring-builtin,execfile-builtin,indexing-exception,import-star-module-level,coerce-method,long-builtin,old-ne-operator,old-division,no-absolute-import,raw_input-builtin,old-octal-literal,oct-method,xrange-builtin,hex-method,unpacking-in-except,nonzero-method,raising-string,intern-builtin,reload-builtin,metaclass-assignment,cmp-method,filter-builtin-not-iterating,apply-builtin,map-builtin-not-iterating,next-method-called,unichr-builtin,buffer-builtin,dict-iter-method,input-builtin,coerce-builtin,getslice-method,useless-suppression,standarderror-builtin,zip-builtin-not-iterating,suppressed-message,cmp-builtin,backtick,long-suffix,reduce-builtin,round-builtin [REPORTS] # Set the output format. Available formats are text, parseable, colorized, msvs # (visual studio) and html. You can also give a reporter class, eg # mypackage.mymodule.MyReporterClass. output-format=text # Put messages in a separate file for each module / package specified on the # command line instead of printing them on stdout. Reports (if any) will be # written in a file name "pylint_global.[txt|html]". This option is deprecated # and it will be removed in Pylint 2.0. 
files-output=no # Tells whether to display a full report or only the messages reports=no # Python expression which should return a note less than 10 (10 is the highest # note). You have access to the variables errors warning, statement which # respectively contain the number of errors / warnings messages and the total # number of statements analyzed. This is used by the global evaluation report # (RP0004). evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details #msg-template= [FORMAT] # Maximum number of characters on a single line. max-line-length=120 # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines=^\s*(# )??$ # Allow the body of an if to be on the same line as the test if there is no # else. single-line-if-stmt=no # List of optional constructs for which whitespace checking is disabled. `dict- # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. # `trailing-comma` allows a space between comma and closing bracket: (a, ). # `empty-line` allows space-only lines. no-space-check=trailing-comma,dict-separator # Maximum number of lines in a module max-module-lines=1000 # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 # tab). indent-string=' ' # Number of spaces of indent required inside a hanging or continued line. indent-after-paren=4 # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format= [SPELLING] # Spelling dictionary name. Available dictionaries: none. To make it working # install python-enchant package. spelling-dict= # List of comma separated words that should not be checked. spelling-ignore-words= # A path to a file that contains private dictionary; one word per line. 
spelling-private-dict-file= # Tells whether to store unknown words to indicated private dictionary in # --spelling-private-dict-file option instead of raising a message. spelling-store-unknown-words=no [MISCELLANEOUS] # List of note tags to take in consideration, separated by a comma. notes=FIXME,XXX,TODO [TYPECHECK] # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). ignore-mixin-members=yes # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis. It # supports qualified module names, as well as Unix pattern matching. ignored-modules= # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes=optparse.Values,thread._local,_thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. generated-members= # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. contextmanager-decorators=contextlib.contextmanager [LOGGING] # Logging modules to check that the string format arguments are in logging # function parameter format logging-modules=logging [SIMILARITIES] # Minimum lines number of a similarity. min-similarity-lines=4 # Ignore comments when computing similarities. ignore-comments=yes # Ignore docstrings when computing similarities. ignore-docstrings=yes # Ignore imports when computing similarities. ignore-imports=no [VARIABLES] # Tells whether we should check for unused import in __init__ files. 
init-import=no # A regular expression matching the name of dummy variables (i.e. expectedly # not used). dummy-variables-rgx=(_+[a-zA-Z0-9]*?$)|dummy # List of additional names supposed to be defined in builtins. Remember that # you should avoid to define new builtins when possible. additional-builtins= # List of strings which can identify a callback function by name. A callback # name must start or end with one of those strings. callbacks=cb_,_cb # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules=six.moves,future.builtins,builtins [BASIC] # Good variable names which should always be accepted, separated by a comma good-names=i,j,_,a,b,op,x,y,wd,lr,kv,k,v,s,p,h,c,m,n,X,t,g,f # Bad variable names which should always be refused, separated by a comma bad-names= # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. name-group= # Include a hint for the correct naming format with invalid-name include-naming-hint=no # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. 
property-classes=abc.abstractproperty # Regular expression matching correct module names module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Naming hint for module names module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ # Regular expression matching correct constant names const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Naming hint for constant names const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ # Regular expression matching correct inline iteration names inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ # Naming hint for inline iteration names inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ # Regular expression matching correct method names method-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for method names method-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class attribute names class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Naming hint for class attribute names class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ # Regular expression matching correct argument names argument-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for argument names argument-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for attribute names attr-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct variable names variable-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for variable names variable-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct function names function-rgx=[a-z_][a-z0-9_]{2,30}$ # Naming hint for function names function-name-hint=[a-z_][a-z0-9_]{2,30}$ # Regular expression matching correct class names class-rgx=[A-Z_][a-zA-Z0-9]+$ # Naming hint for class names class-name-hint=[A-Z_][a-zA-Z0-9]+$ # Regular expression which should only match function or class names that do # not require a docstring. no-docstring-rgx=^_ # Minimum line length for functions/classes that require docstrings, shorter # ones are exempt. 
docstring-min-length=10 [ELIF] # Maximum number of nested blocks for function / method body max-nested-blocks=5 [CLASSES] # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__,__new__,setUp # List of valid names for the first argument in a class method. valid-classmethod-first-arg=cls # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg=mcs # List of member names, which should be excluded from the protected access # warning. exclude-protected=_asdict,_fields,_replace,_source,_make [IMPORTS] # Deprecated modules which should not be used, separated by a comma deprecated-modules=optparse # Create a graph of every (i.e. internal and external) dependencies in the # given file (report RP0402 must not be disabled) import-graph= # Create a graph of external dependencies in the given file (report RP0402 must # not be disabled) ext-import-graph= # Create a graph of internal dependencies in the given file (report RP0402 must # not be disabled) int-import-graph= # Force import order to recognize a module as part of the standard # compatibility libraries. known-standard-library= # Force import order to recognize a module as part of a third party library. known-third-party=enchant # Analyse import fallback blocks. This can be used to support both Python 2 and # 3 compatible code, which means that the block might have code that exists # only in one or another interpreter, leading to false positives when analysed. analyse-fallback-blocks=no [DESIGN] # Maximum number of arguments for function / method max-args=5 # Argument names that match this expression will be ignored. 
Default to name # with leading underscore ignored-argument-names=_.* # Maximum number of locals for function / method body max-locals=15 # Maximum number of return / yield for function / method body max-returns=6 # Maximum number of branch for function / method body max-branches=12 # Maximum number of statements in function / method body max-statements=50 # Maximum number of parents for a class (see R0901). max-parents=7 # Maximum number of attributes for a class (see R0902). max-attributes=7 # Minimum number of public methods for a class (see R0903). min-public-methods=2 # Maximum number of public methods for a class (see R0904). max-public-methods=20 # Maximum number of boolean expressions in a if statement max-bool-expr=5 [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to # "Exception" overgeneral-exceptions=Exception ================================================ FILE: tests/performance/requirements.txt ================================================ gevent==20.5.2 junitparser==1.4.1 git+https://github.com/maheshambule/vjunit.git#egg=vjunit tqdm==4.40.0 pathlib==1.0.1 boto3==1.14.3 awscli==1.18.80 click==7.1.2 tabulate==0.8.7 pandas==1.0.3 termcolor==1.1.0 ================================================ FILE: tests/performance/run_performance_suite.py ================================================ # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Run Performance Regression Test Cases and Generate Reports """ # pylint: disable=redefined-builtin, no-value-for-parameter import logging import os import subprocess import sys import time import click import pathlib from runs.context import ExecutionEnv from runs.taurus import get_taurus_options, x2junit, update_taurus_metric_files from tqdm import tqdm from utils import run_process, Timer, get_sub_dirs logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) ROOT_PATH = pathlib.Path(__file__).parent.absolute() RUN_ARTIFACTS_PATH = os.path.join(ROOT_PATH, "run_artifacts") GLOBAL_CONFIG_PATH = os.path.join(ROOT_PATH, "tests", "global_config.yaml") MONITORING_AGENT = os.path.join(ROOT_PATH, "agents", "metrics_monitoring_server.py") def get_artifacts_dir(ctx, param, value): commit_id = subprocess.check_output('git rev-parse --short HEAD'.split()).decode("utf-8")[:-1] run_name = "{}__{}__{}".format(ctx.params['env_name'], commit_id, int(time.time())) if value is None: artifacts_dir = os.path.join(RUN_ARTIFACTS_PATH, run_name) else: artifacts_dir = os.path.abspath(value) artifacts_dir = os.path.join(artifacts_dir, run_name) return artifacts_dir def validate_env(ctx, param, value): try: if '__' in value: raise ValueError return value except ValueError: raise click.BadParameter('Environment name should not have double underscores in it.') @click.command() @click.option('-a', '--artifacts-dir', help='Directory to store artifacts.', type=click.Path(writable=True), callback=get_artifacts_dir) @click.option('-t', '--test-dir', help='Directory containing tests.', type=click.Path(exists=True), default=os.path.join(ROOT_PATH, "tests")) @click.option('-p', '--pattern', help='Test case folder name glob pattern', default="*") @click.option('-x', '--exclude-pattern', help='Test case folder name glob pattern to exclude', default=None) @click.option('-j', '--jmeter-path', help='JMeter executable path.') 
@click.option('-e', '--env-name', help='Environment filename without the extension. Contains threshold values.', required=True, callback=validate_env) @click.option('--monit/--no-monit', help='Start Monitoring server', default=True) @click.option('--compare-local/--no-compare-local', help='Compare with previous run with files stored' ' in artifacts directory', default=True) def run_test_suite(artifacts_dir, test_dir, pattern, exclude_pattern, jmeter_path, env_name, monit, compare_local): """Collect test suites, run them and generate reports""" logger.info("Artifacts will be stored in directory %s", artifacts_dir) test_dirs = get_sub_dirs(test_dir, exclude_list=[], include_pattern=pattern, exclude_pattern=exclude_pattern) if not test_dirs: logger.info("No test cases are collected...Exiting.") sys.exit(3) else: logger.info("Collected tests %s", test_dirs) with ExecutionEnv(MONITORING_AGENT, artifacts_dir, env_name, compare_local, monit) as prt: pre_command = 'export PYTHONPATH={}:$PYTHONPATH;'.format(os.path.join(str(ROOT_PATH), "agents")) for suite_name in tqdm(test_dirs, desc="Test Suites"): with Timer("Test suite {} execution time".format(suite_name)) as t: suite_artifacts_dir = os.path.join(artifacts_dir, suite_name) options_str = get_taurus_options(suite_artifacts_dir, jmeter_path) env_yaml_path = os.path.join(test_dir, suite_name, "environments", "{}.yaml".format(env_name)) env_yaml_path = "" if not os.path.exists(env_yaml_path) else env_yaml_path test_file = os.path.join(test_dir, suite_name, "{}.yaml".format(suite_name)) with x2junit.X2Junit(suite_name, suite_artifacts_dir, prt.reporter, t, env_name) as s: s.code, s.err = run_process("{} bzt {} {} {} {}".format(pre_command, options_str, test_file, env_yaml_path, GLOBAL_CONFIG_PATH)) update_taurus_metric_files(suite_artifacts_dir, test_file) if __name__ == "__main__": run_test_suite() ================================================ FILE: tests/performance/runs/__init__.py 
# ================================================
# ================================================
# FILE: tests/performance/runs/compare.py
# ================================================
#!/usr/bin/env python
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

""" Compare artifacts between runs """
# pylint: disable=redefined-builtin, self-assigning-variable, broad-except
import csv
import glob
import logging
import sys
import os

import pandas as pd
from junitparser import TestCase, TestSuite, JUnitXml, Skipped, Error, Failure

from runs.taurus import reader as taurus_reader
from runs.storage import LocalStorage, S3Storage
from utils import Timer, get_sub_dirs

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO)


class CompareReportGenerator():
    """Generates a JUnit report plus a CSV comparing this run with the previous one."""

    def __init__(self, path, env_name, local_run):
        self.artifacts_dir = path
        self.current_run_name = os.path.basename(path)
        self.env_name = env_name
        # Previous-run artifacts come either from the local run_artifacts tree or from S3.
        storage_class = LocalStorage if local_run else S3Storage
        self.storage = storage_class(self.artifacts_dir, self.env_name)
        self.junit_reporter = None
        self.pandas_result = None
        self.pass_fail = True

    def gen(self):
        """Driver method to get comparison directory, do the comparison of it
        with current run directory and then store results

        :return: junitparser.JUnitXml report, or None when no previous run exists
        """
        compare_dir, compare_run_name = self.storage.get_dir_to_compare()
        if compare_run_name:
            self.junit_reporter, self.pandas_result = compare_artifacts(self.storage.artifacts_dir, compare_dir,
                                                                        self.storage.current_run_name,
                                                                        compare_run_name)
            self.pandas_result.to_csv(os.path.join(self.artifacts_dir, "comparison_result.csv"))
        else:
            logger.warning("The latest run not found for env.")
        self.storage.store_results()
        return self.junit_reporter


class CompareTestSuite():
    """ Wrapper helper class over JUnit parser Test Suite """

    # outcome -> (junitparser result factory, TestSuite counter attribute to bump)
    result_types = {"pass": [lambda x: None, "tests"],
                    "fail": [Failure, "failures"],
                    "error": [Error, "errors"],
                    "skip": [Skipped, "skipped"]}

    def __init__(self, name, hostname, t):
        self.ts = TestSuite(name)
        self.ts.errors, self.ts.failures, self.ts.skipped, self.ts.tests = 0, 0, 0, 0
        self.ts.hostname = hostname
        self.ts.timestamp = t.start

    def add_test_case(self, name, msg, type):
        """Add a test case with the given outcome type and bump the matching counter."""
        tc = TestCase(name)
        result_type = CompareTestSuite.result_types[type]
        tc.result = result_type[0](msg)
        self.ts.add_testcase(tc)
        setattr(self.ts, result_type[1], getattr(self.ts, result_type[1]) + 1)


def get_log_file(dir, sub_dir):
    """Get metric monitoring log files; returns None when the file is absent."""
    metrics_file = os.path.join(dir, sub_dir, "metrics.csv")
    return metrics_file if os.path.exists(metrics_file) else None


def get_aggregate_val(df, agg_func, col):
    """Get aggregate value of a pandas dataframe column for a given aggregate function.

    Returns None when the column is missing or is not numeric.
    """
    val = None
    if str(col) in df:
        try:
            val = float(getattr(df[str(col)], agg_func)())
        except TypeError:
            val = None
    return val


def compare_values(val1, val2, diff_percent, run_name1, run_name2):
    """ Compare percentage diff values of val1 and val2

    :return: (diff, pass_fail, msg) where pass_fail is "pass", "fail" or "error"
    """
    if pd.isna(val1) or pd.isna(val2):
        msg = "Either of the value can not be determined. The run1 value is '{}' and " \
              "run2 value is {}.".format(val1, val2)
        pass_fail, diff, msg = "error", "NA", msg
    else:
        try:
            if val2 != val1:
                # symmetric percentage difference relative to the mean of the two values
                diff = (abs(val2 - val1) / ((val2 + val1) / 2)) * 100
                if diff < float(diff_percent):
                    pass_fail, diff, msg = "pass", diff, "passed"
                else:
                    msg = "The diff_percent criteria has failed. The expected diff_percent is '{}' and actual " \
                          "diff percent is '{}' and the '{}' run value is '{}' and '{}' run value is '{}'. ". \
                        format(diff_percent, diff, run_name1, val1, run_name2, val2)
                    pass_fail, diff, msg = "fail", diff, msg
            else:
                # special case of 0
                pass_fail, diff, msg = "pass", 0, ""
        except Exception as e:
            msg = "error while calculating the diff for val1={} and val2={}." \
                  "Error is: {}".format(val1, val2, str(e))
            logger.info(msg)
            # Best-effort: computation errors are logged but do not fail the comparison.
            pass_fail, diff, msg = "pass", "NA", msg
    return diff, pass_fail, msg


def compare_artifacts(dir1, dir2, run_name1, run_name2):
    """Compare artifacts from dir1 with dir2 and return (JUnitXml report, DataFrame)."""
    logger.info("Comparing artifacts from %s with %s", dir1, dir2)
    sub_dirs_1 = get_sub_dirs(dir1)
    over_all_pass = True
    aggregates = ["mean", "max", "min"]
    header = ["run_name1", "run_name2", "test_suite", "metric", "run1", "run2",
              "percentage_diff", "expected_diff", "result", "message"]
    rows = [header]
    reporter = JUnitXml()
    for sub_dir1 in sub_dirs_1:
        with Timer("Comparison test suite {} execution time".format(sub_dir1)) as t:
            # BUG FIX: hostname previously concatenated run_name1 with itself
            # instead of naming both runs being compared.
            comp_ts = CompareTestSuite(sub_dir1, run_name1 + " and " + run_name2, t)

            metrics_file1, metrics_file2 = get_log_file(dir1, sub_dir1), get_log_file(dir2, sub_dir1)
            if not (metrics_file1 and metrics_file2):
                msg = "Metrics monitoring logs are not captured for {} in either " \
                      "of the runs.".format(sub_dir1)
                logger.info(msg)
                rows.append([run_name1, run_name2, sub_dir1, "metrics_log_file_availability",
                             "NA", "NA", "NA", "NA", "pass", msg])
                comp_ts.add_test_case("metrics_log_file_availability", msg, "skip")
                # NOTE(review): a skipped suite is not added to the reporter here
                # (continue bypasses add_testsuite below) — confirm this is intended.
                continue

            metrics_from_file1 = pd.read_csv(metrics_file1)
            metrics_from_file2 = pd.read_csv(metrics_file2)
            metrics, diff_percents = taurus_reader.get_compare_metric_list(dir1, sub_dir1)
            for col, diff_percent in zip(metrics, diff_percents):
                for agg_func in aggregates:
                    name = "{}_{}".format(agg_func, str(col))
                    val1 = get_aggregate_val(metrics_from_file1, agg_func, col)
                    val2 = get_aggregate_val(metrics_from_file2, agg_func, col)
                    diff, pass_fail, msg = compare_values(val1, val2, diff_percent, run_name1, run_name2)
                    if over_all_pass:
                        over_all_pass = pass_fail == "pass"
                    result_row = [run_name1, run_name2, sub_dir1, name, val1, val2,
                                  diff, diff_percent, pass_fail, msg]
                    rows.append(result_row)
                    test_name = "{}: diff_percent < {}".format(name, diff_percent)
                    comp_ts.add_test_case(test_name, msg, pass_fail)

            comp_ts.ts.time = t.diff()
            comp_ts.ts.update_statistics()
            reporter.add_testsuite(comp_ts.ts)

    dataframe = pd.DataFrame(rows[1:], columns=rows[0])
    return reporter, dataframe

# ================================================
# FILE: tests/performance/runs/context.py
# ================================================
#!/usr/bin/env python
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
""" Start and stop monitoring server """ # pylint: disable=redefined-builtin import logging import os import sys import time import webbrowser from termcolor import colored from junitparser import JUnitXml from runs.compare import CompareReportGenerator from runs.junit import JunitConverter, junit2tabulate from utils import run_process logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) class ExecutionEnv(object): """ Context Manager class to run the performance regression suites """ def __init__(self, agent, artifacts_dir, env, local_run, use=True, check_mms_server_status=False): self.monitoring_agent = agent self.artifacts_dir = artifacts_dir self.use = use self.env = env self.local_run = local_run self.check_mms_server_status = check_mms_server_status self.reporter = JUnitXml() self.compare_reporter_generator = CompareReportGenerator(self.artifacts_dir, self.env, self.local_run) def __enter__(self): if self.use: start_monitoring_server = "{} {} --start".format(sys.executable, self.monitoring_agent) run_process(start_monitoring_server, wait=False) time.sleep(2) return self @staticmethod def open_report(file_path): if os.path.exists(file_path): return webbrowser.open_new_tab('file://' + os.path.realpath(file_path)) return False @staticmethod def report_summary(reporter, suite_name): if reporter and os.path.exists(reporter.junit_html_path): status = reporter.junit_xml.errors or reporter.junit_xml.failures or reporter.junit_xml.skipped status, code, color = ("failed", 3, "red") if status else ("passed", 0, "green") msg = "{} run has {}.".format(suite_name, status) logger.info(colored(msg, color, attrs=['reverse', 'blink'])) logger.info("%s report - %s", suite_name, reporter.junit_html_path) logger.info("%s summary:", suite_name) print(junit2tabulate(reporter.junit_xml)) ExecutionEnv.open_report(reporter.junit_html_path) return code else: msg = "{} run report is not generated.".format(suite_name) 
logger.info(colored(msg, "yellow", attrs=['reverse', 'blink'])) return 0 def __exit__(self, type, value, traceback): if self.use: stop_monitoring_server = "{} {} --stop".format(sys.executable, self.monitoring_agent) run_process(stop_monitoring_server) junit_reporter = JunitConverter(self.reporter, self.artifacts_dir, 'performance_results') junit_reporter.generate_junit_report() junit_compare = self.compare_reporter_generator.gen() junit_compare_reporter = None if junit_compare: junit_compare_reporter = JunitConverter(junit_compare, self.artifacts_dir, 'comparison_results') junit_compare_reporter.generate_junit_report() compare_exit_code = ExecutionEnv.report_summary(junit_compare_reporter, "Comparison Test suite") exit_code = ExecutionEnv.report_summary(junit_reporter, "Performance Regression Test suite") sys.exit(0 if 0 == exit_code == compare_exit_code else 3) ================================================ FILE: tests/performance/runs/junit.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Start and stop monitoring server """ # pylint: disable=redefined-builtin import os import html import textwrap import tabulate from utils import run_process from junitparser import JUnitXml header = ["suite_name", "test_case", "result", "message"] class JunitConverter(): def __init__(self, junit_xml, out_dir, report_name): self.junit_xml = junit_xml self.junit_xml_path = os.path.join(out_dir, '{}.xml'.format(report_name)) self.junit_html_path = os.path.join(out_dir, '{}.html'.format(report_name)) def generate_junit_report(self): self.junit_xml.update_statistics() self.junit_xml.write(self.junit_xml_path) # vjunit pip package is used here run_process("vjunit -f {} -o {}".format(self.junit_xml_path, self.junit_html_path)) def pretty_text(data): """Unsescape the html characters from the data & wrap it""" if data is not None: return textwrap.fill(html.unescape(html.unescape(data)), width=60) else: return "" def junit2array(junit_xml): """convert junit xml junitparser.JUnitXml object to 2d array """ rows = [header] for i, suite in enumerate(junit_xml): if len(suite) == 0: rows.append([suite.name, "", "skipped", "No criteria specified or there is an error."]) else: for case in suite: result = case.result tag, msg = (result._tag, result.message) if result else ("passed", "") rows.append([suite.name, pretty_text(case.name), tag, pretty_text(msg)]) return rows def junit2tabulate(junit_xml): """convert junit xml junitparser.JUnitXml object or a Junit xml to tabulate string """ if not isinstance(junit_xml, JUnitXml): if os.path.exists(junit_xml): junit_xml = JUnitXml.fromfile(junit_xml) else: return tabulate.tabulate([[header]], headers='firstrow') data = junit2array(junit_xml) return tabulate.tabulate(data, headers='firstrow', tablefmt="grid") ================================================ FILE: tests/performance/runs/storage.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. 
# All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

""" Result store classes """
# pylint: disable=redefined-builtin
import logging
import os
import sys
import shutil

import boto3
import pathlib

from agents import configuration
from utils import run_process

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO)

S3_BUCKET = configuration.get('suite', 's3_bucket')


class Storage():
    """Class to store and retrieve artifacts"""

    def __init__(self, path, env_name):
        self.artifacts_dir = path
        self.current_run_name = os.path.basename(path)
        self.env_name = env_name

    def get_dir_to_compare(self):
        """get the artifacts dir to compare to"""

    def store_results(self):
        """Store the results"""

    @staticmethod
    def get_latest(names, env_name, exclude_name):
        """ Get latest directory for same env_name name given a list of them.
        :param names: list of folder names in the format env_name__commitid__timestamp
        :param env_name: filter for env_name
        :param exclude_name: any name to exclude
        :return: latest directory name, '' when none matches
        """
        max_ts = 0
        latest_run = ''
        for run_name in names:
            run_name_list = run_name.split('__')
            # Robustness: skip names that do not follow env__commit__timestamp.
            if len(run_name_list) < 3:
                continue
            if env_name == run_name_list[0] and run_name != exclude_name:
                if int(run_name_list[2]) > max_ts:
                    max_ts = int(run_name_list[2])
                    latest_run = run_name
        return latest_run


class LocalStorage(Storage):
    """ Compare the monitoring metrics for current and previous run for the same env_name """

    def get_dir_to_compare(self):
        """Get latest run directory name to be compared with"""
        parent_dir = pathlib.Path(self.artifacts_dir).parent
        names = [di for di in os.listdir(parent_dir) if os.path.isdir(os.path.join(parent_dir, di))]
        latest_run = self.get_latest(names, self.env_name, self.current_run_name)
        # When latest_run is '' callers treat the empty run name as "nothing to compare".
        return os.path.join(parent_dir, latest_run), latest_run


class S3Storage(Storage):
    """Compare current run results with the results stored on S3"""

    def get_dir_to_compare(self):
        """Get latest run result artifacts directory for same env_name from S3 bucket
        and store it locally for further comparison
        """
        comp_data_path = os.path.join(self.artifacts_dir, "comp_data")
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(S3_BUCKET)
        result = bucket.meta.client.list_objects(Bucket=bucket.name, Delimiter='/')
        run_names = []
        # BUG FIX: 'CommonPrefixes' is absent from the response when the bucket
        # holds no runs yet; guard against iterating None.
        for o in result.get('CommonPrefixes') or []:
            run_names.append(o.get('Prefix')[:-1])
        latest_run = self.get_latest(run_names, self.env_name, self.current_run_name)
        if not latest_run:
            logger.info("No run found for env_id %s", self.env_name)
            return '', ''
        if not os.path.exists(comp_data_path):
            os.makedirs(comp_data_path)
        tgt_path = os.path.join(comp_data_path, latest_run)
        run_process("aws s3 cp s3://{}/{} {} --recursive".format(bucket.name, latest_run, tgt_path))
        return tgt_path, latest_run

    def store_results(self):
        """Store the run results back to S3"""
        comp_data_path = os.path.join(self.artifacts_dir, "comp_data")
        # Do not re-upload the previous run's data we downloaded for comparison.
        if os.path.exists(comp_data_path):
            shutil.rmtree(comp_data_path)
        run_process("aws s3 cp {} s3://{}/{} --recursive".format(self.artifacts_dir, S3_BUCKET,
                                                                 self.current_run_name))

# ================================================
# FILE: tests/performance/runs/taurus/__init__.py
# ================================================
#!/usr/bin/env python
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

""" Convert the Taurus Test suite XML to Junit XML """
# pylint: disable=redefined-builtin
import glob
import shutil
import os

from .reader import get_mon_metrics_list


def get_taurus_options(artifacts_dir, jmeter_path=None):
    """The options for Taurus BZT command"""
    options = []
    if jmeter_path:
        options.append('-o modules.jmeter.path={}'.format(jmeter_path))
    options.append('-o settings.artifacts-dir={}'.format(artifacts_dir))
    options.append('-o modules.console.disable=true')
    options.append('-o settings.env.BASEDIR={}'.format(artifacts_dir))
    options_str = ' '.join(options)
    return options_str


def update_taurus_metric_files(suite_artifacts_dir, test_file):
    """ It renames the server and local metric monitoring log files to metrics.csv.
    The order of the columns in header of server metric monitoring SALogs file generated
    by taurus is not inline with data. So as a work around this function rewrites the
    header based on order defined in the test yaml.
    """
    metrics_new_file = os.path.join(suite_artifacts_dir, "metrics.csv")
    server_metric_file_pattern = os.path.join(suite_artifacts_dir, "SAlogs_*")
    metrics_log_file = glob.glob(server_metric_file_pattern)
    if metrics_log_file:
        metrics = get_mon_metrics_list(test_file)
        if metrics:
            # BUG FIX: the file was previously opened for writing while still being
            # read, truncating it mid-copy. Read the remainder into memory first,
            # then rewrite the file with the corrected header.
            with open(metrics_log_file[0]) as from_file:
                line = from_file.readline()
                remainder = from_file.read()
            with open(metrics_log_file[0], mode="w") as to_file:
                to_file.write(','.join(line.split(',')[0:1] + metrics) + "\n")
                to_file.write(remainder)
        os.rename(metrics_log_file[0], metrics_new_file)
    else:
        metrics_log_file = os.path.join(suite_artifacts_dir, "local_monitoring_logs.csv")
        if os.path.exists(metrics_log_file):
            os.rename(metrics_log_file, metrics_new_file)

# ================================================
# FILE: tests/performance/runs/taurus/reader.py
# ================================================
#!/usr/bin/env python
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#     http://www.apache.org/licenses/LICENSE-2.0
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
""" Run shell command utilities """ # pylint: disable=redefined-builtin import os import yaml def get_mon_metrics_list(test_yaml_path): """Utility method to get list of server-agent metrics which are being monitored from a test yaml file""" metrics = [] with open(test_yaml_path) as test_yaml: test_yaml = yaml.safe_load(test_yaml) for rep_section in test_yaml.get('services', []): if rep_section.get('module', None) == 'monitoring' and "server-agent" in rep_section: for mon_section in rep_section.get('server-agent', []): if isinstance(mon_section, dict): metrics.extend(mon_section.get('metrics', [])) return metrics def get_compare_metric_list(dir, sub_dir): """Utility method to get list of compare monitoring metrics identified by diff_percent property""" diff_percents = [] metrics = [] test_yaml = os.path.join(dir, sub_dir, "effective.yml") with open(test_yaml) as test_yaml: test_yaml = yaml.safe_load(test_yaml) for rep_section in test_yaml.get('reporting', []): if rep_section.get('module', None) == 'passfail': for criterion in rep_section.get('criteria', []): if isinstance(criterion, dict) and 'monitoring' in criterion.get('class', ''): subject = criterion["subject"] metric = subject.rsplit('/', 1) metric = metric[1] if len(metric) == 2 else metric[0] diff_percent = criterion.get("diff_percent", None) if diff_percent: metrics.append(metric) diff_percents.append(diff_percent) return metrics, diff_percents ================================================ FILE: tests/performance/runs/taurus/x2junit.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
# This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

""" Convert the Taurus Test suite XML to Junit XML """
# pylint: disable=redefined-builtin
import os

from junitparser import TestCase, TestSuite, JUnitXml, Skipped, Error, Failure


class X2Junit(object):
    """
    Context Manager class to do convert Taurus Test suite XML report which is
    in Xunit specifications to JUnit XML report.
    """

    def __init__(self, name, artifacts_dir, junit_xml, timer, env_name):
        self.ts = TestSuite(name)
        self.name = name
        self.junit_xml = junit_xml
        self.timer = timer
        self.artifacts_dir = artifacts_dir
        self.env_name = env_name

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        report_path = os.path.join(self.artifacts_dir, "xunit.xml")
        counters = {"tests": 0, "failures": 0, "skipped": 0, "errors": 0}

        if os.path.exists(report_path):
            taurus_xml = JUnitXml.fromfile(report_path)
            for idx, suite in enumerate(taurus_xml):
                for case in suite:
                    outcome = case.result
                    # NOTE(review): Error and Failure are deliberately swapped when
                    # translating the Taurus report — confirm this mapping is intended.
                    if isinstance(outcome, Error):
                        counters["failures"] += 1
                        outcome = Failure(outcome.message, outcome.type)
                    elif isinstance(outcome, Failure):
                        counters["errors"] += 1
                        outcome = Error(outcome.message, outcome.type)
                    elif isinstance(outcome, Skipped):
                        counters["skipped"] += 1
                    else:
                        # No result object means the case passed.
                        counters["tests"] += 1
                    converted = TestCase("scenario_{}: {}".format(idx, case.name))
                    converted.result = outcome
                    self.ts.add_testcase(converted)
        else:
            # No xunit report produced: record a single skipped placeholder case.
            placeholder = TestCase(self.name)
            placeholder.result = Skipped()
            self.ts.add_testcase(placeholder)

        self.ts.hostname = self.env_name
        self.ts.timestamp = self.timer.start
        self.ts.time = self.timer.diff()
        self.ts.tests = counters["tests"]
        self.ts.failures = counters["failures"]
        self.ts.skipped = counters["skipped"]
        self.ts.errors = counters["errors"]
        self.ts.update_statistics()
        self.junit_xml.add_testsuite(self.ts)

# ================================================
# FILE: tests/performance/tests/api_description/api_description.jmx
# ================================================
# (extraction residue from the JMX XML file follows on later lines:
#  "false true false")
${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} OPTIONS true false true false ${__P(port,8443)} OPTIONS true false true false ================================================ FILE: tests/performance/tests/api_description/api_description.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: api_description scenarios: api_description: script: api_description.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of ManagementAPIDescription<${MGMT_DESC_SUCC}, stop as failed - success of InferenceAPIDescription<${INFR_DESC_SUCC}, stop as failed - avg-rt of ManagementAPIDescription>${MGMT_DESC_RT}, stop as failed - avg-rt of InferenceAPIDescription>${INFR_DESC_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 
5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/api_description/environments/xlarge.yaml ================================================ --- settings: env: MGMT_DESC_SUCC: 100% INFR_DESC_SUCC: 100% MGMT_DESC_RT : 10ms INFR_DESC_RT : 10ms TOTAL_PROCS : 1 TOTAL_FDS : 73 TOTAL_MEM: 100000000 #100MB ================================================ FILE: tests/performance/tests/batch_and_single_inference/batch_and_single_inference.jmx ================================================ false true false model1 ${__P(model_name1,resnet-152)} = model2 ${__P(model_name2,squeezenet_v1.1)} = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,200)} ${__P(threads,20)} ${__P(rampup,5)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model1} POST true false true true ${__P(input_filepath)} data image/jpeg /predictions/${model2} POST true false true true ================================================ FILE: tests/performance/tests/batch_and_single_inference/batch_and_single_inference.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 5s hold-for: 20s scenario: Inference scenarios: Inference: script: batch_and_single_inference.jmx modules: server_local_monitoring: # metrics_monitoring_inproc and dependencies should be in python path class : metrics_monitoring_inproc.Monitor # monitoring class. 
services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar&batch_size=8&max_batch_delay=50" # uncomment below and comment prev and use downloaded model with model-store #- curl -s -X POST "http://localhost:8081/models?url=resnet-152.mar&batch_size=8&max_batch_delay=60&initial_workers=1" - "curl -s -X PUT http://localhost:8081/models/resnet-152?min_worker=2&synchronous=true" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=2&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring # should be added in modules section ServerLocalClient: # keyword from metrics_monitoring_inproc.Monitor - interval: 1s logging : True metrics: - sum_workers_memory_rss - sum_workers_file_descriptors - total_workers - orphans reporting: - module: passfail criteria: - subject: avg-rt # required label: 'Inference1' # optional, default is '' condition: '>' # required threshold: ${INFR1_RT} # required logic: for # optional, logic to aggregate values within timeframe. # Default 'for' means take latest, # 'within' and 'over' means take sum/avg of all values within interval timeframe: 1s # optional, default is none stop: true # optional, default is true. false for nonstop testing until the end fail: true # optional, default is true - subject: avg-rt # required label: 'Inference2' # optional, default is '' condition: '>' # required threshold: ${INFR2_RT} # required logic: for # optional, logic to aggregate values within timeframe. 
# Default 'for' means take latest, # 'within' and 'over' means take sum/avg of all values within interval timeframe: 1s # optional, default is none stop: true # optional, default is true. false for nonstop testing until the end fail: true # optional, default is true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '>' threshold: ${TOTAL_WORKERS_MEM} timeframe: 1s stop : true fail : true diff_percent : 30 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/orphans condition: '>' threshold: ${TOTAL_ORPHANS} timeframe: 1s stop : true fail : true diff_percent : 0 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers condition: '>' threshold: ${TOTAL_WORKERS} timeframe: 1s stop: true fail: true diff_percent: 0 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '>' threshold: ${TOTAL_WORKERS_FDS} timeframe: 1s stop: true fail: true diff_percent: 30 ================================================ FILE: tests/performance/tests/batch_and_single_inference/environments/xlarge.yaml ================================================ --- settings: env: INFR1_RT : 6s INFR2_RT : 0.08s TOTAL_WORKERS_MEM : 4000000000 #4GB TOTAL_WORKERS : 9 TOTAL_ORPHANS : 0 TOTAL_WORKERS_FDS : 78 ================================================ FILE: tests/performance/tests/batch_inference/batch_inference.jmx ================================================ false true false model ${__P(model_name,resnet-152)} = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,200)} ${__P(threads,20)} ${__P(rampup,5)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true ================================================ FILE: tests/performance/tests/batch_inference/batch_inference.yaml ================================================ --- 
execution: - concurrency: 10 ramp-up: 5s hold-for: 20s scenario: Inference scenarios: Inference: script: batch_inference.jmx modules: server_local_monitoring: # metrics_monitoring_inproc and dependencies should be in python path class : metrics_monitoring_inproc.Monitor # monitoring class. services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/examples/resnet-152-batching/resnet-152.mar&batch_size=8&max_batch_delay=50" # uncomment below and comment prev and use downloaded model with model-store #- "curl -s -X POST http://localhost:8081/models?url=resnet-152.mar&batch_size=8&max_batch_delay=60&initial_workers=1" - "curl -s -X PUT http://localhost:8081/models/resnet-152?min_worker=2&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring # should be added in modules section ServerLocalClient: # keyword from metrics_monitoring_inproc.Monitor - interval: 1s logging : True metrics: - sum_workers_memory_rss - sum_workers_file_descriptors - total_workers - orphans reporting: - module: passfail criteria: - subject: avg-rt # required label: 'Inference' # optional, default is '' condition: '>' # required threshold: ${INFR_RT} # required logic: for # optional, logic to aggregate values within timeframe. # Default 'for' means take latest, # 'within' and 'over' means take sum/avg of all values within interval timeframe: 1s # optional, default is none stop: true # optional, default is true. 
false for nonstop testing until the end fail: true # optional, default is true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '>' threshold: ${TOTAL_WORKERS_MEM} timeframe: 1s stop : true fail : true diff_percent : 30 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/orphans condition: '>' threshold: ${TOTAL_ORPHANS} timeframe: 1s stop : true fail : true diff_percent : 0 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers condition: '>' threshold: ${TOTAL_WORKERS} timeframe: 1s stop: true fail: true diff_percent: 0 - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '>' threshold: ${TOTAL_WORKERS_FDS} timeframe: 1s stop: true fail: true diff_percent: 30 ================================================ FILE: tests/performance/tests/batch_inference/environments/xlarge.yaml ================================================ --- settings: env: INFR_RT : 1.5s TOTAL_WORKERS_MEM : 3000000000 #3GB TOTAL_WORKERS : 4 TOTAL_ORPHANS : 0 TOTAL_WORKERS_FDS : 38 ================================================ FILE: tests/performance/tests/examples_local_criteria/environments/xlarge.yaml ================================================ --- settings: env: FAIL : 100% P90 : 290ms AVG_RT : 1s TOTAL_WORKERS_MEM : 132000000 PERCENT_DIFF_TOTAL_WORKERS_MEM : 5 ================================================ FILE: tests/performance/tests/examples_local_criteria/examples_local_criteria.jmx ================================================ false true false cnn_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model squeezenet_v1.1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 
6 continue false 1 1 1 false true ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: tests/performance/tests/examples_local_criteria/examples_local_criteria.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 5s hold-for: 20s scenario: Inference scenarios: Inference: script: examples_local_criteria.jmx modules: server_local_monitoring: # metrics_monitoring_inproc and dependencies should be in python path class : metrics_monitoring_inproc.Monitor # monitoring class. 
services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring # should be added in modules section ServerLocalClient: # keyword from metrics_monitoring_inproc.Monitor - interval: 1s logging : True metrics: - cpu - disk-space - mem - sum_workers_memory_rss reporting: - module: passfail criteria: - fail >${FAIL}, stop as failed - p90 >${P90} , stop as failed - avg-rt >${AVG_RT} , stop as failed - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '>' threshold: ${TOTAL_WORKERS_MEM} timeframe: 1s stop : true fail : true diff_percent : ${PERCENT_DIFF_TOTAL_WORKERS_MEM} ================================================ FILE: tests/performance/tests/examples_local_monitoring/examples_local_monitoring.jmx ================================================ false true false cnn_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model squeezenet_v1.1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false true ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: 
tests/performance/tests/examples_local_monitoring/examples_local_monitoring.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 5s hold-for: 20s scenario: Inference scenarios: Inference: script: examples_local_monitoring.jmx modules: server_local_monitoring: # metrics_monitoring_inproc and dependencies should be in python path class : metrics_monitoring_inproc.Monitor # monitoring class. services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring # should be added in modules section ServerLocalClient: # keyword from metrics_monitoring_inproc.Monitor - interval: 1s metrics: - cpu - disk-space - mem - sum_workers_memory_percent ================================================ FILE: tests/performance/tests/examples_remote_criteria/environments/xlarge.yaml ================================================ --- settings: env: FAIL : 50% P90 : 250ms AVG_RT : 1s TOTAL_WORKERS_FDS : 80 ================================================ FILE: tests/performance/tests/examples_remote_criteria/examples_remote_criteria.jmx ================================================ false true false cnn_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model squeezenet_v1.1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false true ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true 
${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: tests/performance/tests/examples_remote_criteria/examples_remote_criteria.yaml ================================================ execution: - concurrency: 4 ramp-up: 1s hold-for: 20s scenario: Inference scenarios: Inference: script: examples_remote_criteria.jmx services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: monitoring server-agent: - address: localhost:9009 # metric monitoring service address label: mms-inference-server # if you specify label, it will be used in reports instead of ip:port interval: 1s # polling interval logging: True # those logs will be saved to "SAlogs_192.168.0.1_9009.csv" in the artifacts dir metrics: # metrics should be supported by monitoring service - sum_workers_cpu_percent # cpu percent used by all the mms server processes and workers - sum_workers_memory_percent - sum_workers_file_descriptors - total_workers # no of mms workers reporting: - module: passfail criteria: - fail >${FAIL}, stop as failed - p90 >${P90} , stop as failed - avg-rt >${AVG_RT} , stop as failed - class: bzt.modules.monitoring.MonitoringCriteria subject: mms-inference-server/sum_workers_file_descriptors condition: '>' threshold: ${TOTAL_WORKERS_FDS} timeframe: 1s fail: true stop: true diff_percent : 35 ================================================ FILE: tests/performance/tests/examples_remote_monitoring/examples_remote_monitoring.jmx ================================================ false true false cnn_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar 
= The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model squeezenet_v1.1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false true ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: tests/performance/tests/examples_remote_monitoring/examples_remote_monitoring.yaml ================================================ execution: - concurrency: 4 ramp-up: 1s hold-for: 20s scenario: Inference scenarios: Inference: script: examples_remote_monitoring.jmx services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: monitoring server-agent: - address: localhost:9009 # metric monitoring service address label: mms-inference-server # if you specify label, it will be used in reports instead of ip:port interval: 1s # polling interval logging: True # those logs will be saved to "SAlogs_192.168.0.1_9009.csv" in the artifacts dir metrics: # metrics should be supported by monitoring service - sum_all_cpu_percent # cpu percent used by all the mms server processes and workers - sum_workers_memory_percent - frontend_file_descriptors - total_workers # no of mms workers ================================================ FILE: 
tests/performance/tests/examples_starter/examples_starter.jmx ================================================ false true false cnn_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar = The url from where to fetch noop model from scale_up_workers ${__P(min_workers,1)} = The workers to scale No op model to scale_down_workers 0 Offload the No-Op Model = model squeezenet_v1.1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false 1 1 1 false true ${__P(management_port,8444)} /models?url=${cnn_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers} PUT true false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true continue false 1 1 1 false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers} DELETE true false true false ================================================ FILE: tests/performance/tests/examples_starter/examples_starter.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 1s hold-for: 40s scenario: Inference scenarios: Inference: script: examples_starter.jmx services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" ================================================ FILE: tests/performance/tests/global_config.yaml ================================================ modules: jmeter: # These are JMeter test case properties. These variables are used in jmx files. 
# Change the values as per your setup properties: hostname : 127.0.0.1 # MMS properties port : 8080 management_port : 8081 protocol : http input_filepath : kitten.jpg # make sure jpg is available at this path # if relative path is provided this will be relative to current working directory # DO-NOT change properties below unless you know what you are doing. # They are needed for performance test suite runner script. reporting: - module: passfail # this is to enable passfail module - module: junit-xml data-source: pass-fail - module: junit-xml data-source: sample-labels - module: final-stats dump-csv : ${BASEDIR}/final_stats.csv settings: env: BASEDIR : '.' ================================================ FILE: tests/performance/tests/health_check/environments/xlarge.yaml ================================================ --- settings: env: HLTH_CHK_SUCC : 100% HLTH_CHK_RT : 14ms TOTAL_PROCS : 1 TOTAL_FDS : 67 TOTAL_MEM : 750000000 #750MB ================================================ FILE: tests/performance/tests/health_check/health_check.jmx ================================================ false true false continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} /ping GET true false true false ================================================ FILE: tests/performance/tests/health_check/health_check.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: health_check scenarios: health_check: script: health_check.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss 
reporting: - module: passfail criteria: # Inbuilt Criteria - success of HealthCheck<${HLTH_CHK_SUCC}, stop as failed - avg-rt of HealthCheck>${HLTH_CHK_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/inference_multiple_models/environments/xlarge.yaml ================================================ --- settings: env: INFR1_SUCC : 100% INFR2_SUCC: 100% INFR1_RT : 290ms INFR2_RT: 450ms TOTAL_PROCS : 5 TOTAL_FDS : 107 TOTAL_MEM : 600000000 #600MB ================================================ FILE: tests/performance/tests/inference_multiple_models/inference_multiple_models.jmx ================================================ false true false model1 squeezenet_v1.1 = Model1 Name model2 resnet-18 Model2 Name = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model1} POST true false true true continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model2} POST true false true true ================================================ FILE: 
tests/performance/tests/inference_multiple_models/inference_multiple_models.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: inference_multiple_models scenarios: inference_multiple_models: script: inference_multiple_models.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=1&synchronous=true" - "curl -s -X PUT http://localhost:8081/models/resnet-18?min_worker=1&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of Inference1<${INFR1_SUCC}, stop as failed - success of Inference2<${INFR2_SUCC}, stop as failed - avg-rt of Inference1>${INFR1_RT}, stop as failed - avg-rt of Inference2>${INFR2_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} 
timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/inference_multiple_worker/environments/xlarge.yaml ================================================ --- settings: env: INFR_SUCC : 100% INFR_RT : 140ms TOTAL_PROCS : 6 TOTAL_FDS : 126 TOTAL_MEM : 750000000 #750MB ================================================ FILE: tests/performance/tests/inference_multiple_worker/inference_multiple_worker.jmx ================================================ false true false model squeezenet_v1.1 = Model Name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true ================================================ FILE: tests/performance/tests/inference_multiple_worker/inference_multiple_worker.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 1m iterations: 100 scenario: inference_multiple_worker scenarios: inference_multiple_worker: script: inference_multiple_worker.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=4&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True 
metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of Inference<${INFR_SUCC}, stop as failed - avg-rt of Inference>${INFR_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 1s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/inference_single_worker/environments/xlarge.yaml ================================================ --- settings: env: INFR_SUCC : 100% INFR_RT : 290ms TOTAL_PROCS : 3 TOTAL_FDS : 90 TOTAL_MEM : 290000000 #290MB ================================================ FILE: tests/performance/tests/inference_single_worker/inference_single_worker.jmx ================================================ false true false model squeezenet_v1.1 = Model Name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model} POST true false true true ================================================ FILE: tests/performance/tests/inference_single_worker/inference_single_worker.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 1m iterations: 100 scenario: 
inference_single_worker scenarios: inference_single_worker: script: inference_single_worker.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=1&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of Inference<${INFR_SUCC}, stop as failed - avg-rt of Inference>${INFR_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 1s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/list_models/environments/xlarge.yaml ================================================ --- settings: env: LST_MODLS_SUCC : 100% LST_MODLS_RT : 14ms TOTAL_PROCS : 3 TOTAL_FDS : 86 TOTAL_MEM : 185000000 #185MB 
================================================ FILE: tests/performance/tests/list_models/list_models.jmx ================================================ false true false model squeezenet_v1.1 = Model name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models GET true false true false ================================================ FILE: tests/performance/tests/list_models/list_models.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: list_models scenarios: list_models: script: list_models.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/shufflenet.mar" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of ListModels<${LST_MODLS_SUCC}, stop as failed - avg-rt of ListModels>${LST_MODLS_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors 
condition: '>' threshold: ${TOTAL_FDS} timeframe: 1s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/model_description/environments/xlarge.yaml ================================================ --- settings: env: MODL_DESC_SUCC : 100% MODL_DESC_RT : 14ms TOTAL_PROCS : 3 TOTAL_FDS : 90 TOTAL_MEM : 300000000 #300MB ================================================ FILE: tests/performance/tests/model_description/model_description.jmx ================================================ false true false model squeezenet_v1.1 = Model name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models/${model} GET true false true false ================================================ FILE: tests/performance/tests/model_description/model_description.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: model_description scenarios: model_description: script: model_description.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=1&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - 
success of ModelDescription<${MODL_DESC_SUCC}, stop as failed - avg-rt of ModelDescription>${MODL_DESC_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 1s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/sum_all_memory_rss # condition: '>' # threshold: ${TOTAL_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/multiple_inference_and_scaling/environments/xlarge.yaml ================================================ --- settings: env: INFR1_SUCC : 100% INFR2_SUCC: 100% INFR1_RT : 290ms INFR2_RT: 450ms TOTAL_PROCS : 14 TOTAL_FDS : 300 TOTAL_MEM : 2000000000 #~2GB TOTAL_ORPHANS : 0 FRNTEND_MEM : 1000000000 #~1GB ================================================ FILE: tests/performance/tests/multiple_inference_and_scaling/multiple_inference_and_scaling.jmx ================================================ false true false model1 squeezenet_v1.1 = Model1 Name model2 resnet-18 Model2 Name = scale_up_workers1 4 = scale_down_workers1 1 = scale_up_workers2 4 = scale_down_workers2 1 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model1} POST true false true true 1 0 20 1000 100 ${__P(management_port,8444)} /models/${model1}?min_worker=${scale_down_workers1}&synchronous=true PUT true false true false 1 0 20 10000 
100 ${__P(management_port,8444)} /models/${model1}?min_worker=${scale_up_workers1}&synchronous=true PUT true false true false 1 0 20 10000 100 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(input_filepath)} data image/jpeg /predictions/${model2} POST true false true true 1 0 20 1000 100 ${__P(management_port,8444)} /models/${model2}?min_worker=${scale_down_workers2}&synchronous=true PUT true false true false 1 0 20 10000 100 ${__P(management_port,8444)} /models/${model2}?min_worker=${scale_up_workers2}&synchronous=true PUT true false true false 1 0 20 10000 100 ================================================ FILE: tests/performance/tests/multiple_inference_and_scaling/multiple_inference_and_scaling.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 300s scenario: inference_multiple_models scenarios: inference_multiple_models: script: multiple_inference_and_scaling.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=1&synchronous=true" - "curl -s -X PUT http://localhost:8081/models/resnet-18?min_worker=1&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - sum_all_memory_rss - frontend_memory_rss - orphans reporting: - module: passfail criteria: # Inbuilt Criteria - success 
of Inference1<${INFR1_SUCC}, stop as failed - success of Inference2<${INFR2_SUCC}, stop as failed - avg-rt of Inference1>${INFR1_RT}, stop as failed - avg-rt of Inference2>${INFR2_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 10s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_memory_rss condition: '>' threshold: ${TOTAL_MEM} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/orphans condition: '>' threshold: ${TOTAL_ORPHANS} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/frontend_memory_rss condition: '>' threshold: ${FRNTEND_MEM} timeframe: 5s stop : true fail : true ================================================ FILE: tests/performance/tests/register_unregister/environments/xlarge.yaml ================================================ --- settings: env: REG_SUCC : 100% UNREG_SUCC: 100% REG_RT : 15s UNREG_RT: 10ms TOTAL_PROCS : 1 TOTAL_FDS : 66 TOTAL_ORPHANS : 0 FRNTEND_MEM : 75000000 #75MB ================================================ FILE: tests/performance/tests/register_unregister/register_unregister.jmx ================================================ false true false model squeezenet_v1.1 = Model name model_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar URL to model store on s3 = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models?url=${model_url} POST true false true false 
${__P(management_port,8444)} /models/${model} DELETE true false true false ================================================ FILE: tests/performance/tests/register_unregister/register_unregister.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 0s # hold-for: 5h iterations: 5 scenario: register_unregister scenarios: register_unregister: script: register_unregister.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 10s" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - frontend_memory_rss - orphans reporting: - module: passfail criteria: # Inbuilt Criteria - success of RegisterModel<${REG_SUCC}, stop as failed - success of UnregisterModel<${UNREG_SUCC}, stop as failed - avg-rt of RegisterModel>${REG_RT}, stop as failed - avg-rt of UnregisterModel>${UNREG_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/orphans condition: '>' threshold: ${TOTAL_ORPHANS} timeframe: 1s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/frontend_memory_rss # condition: '>' # threshold: ${FRNTEND_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: 
tests/performance/tests/register_unregister_multiple/environments/xlarge.yaml ================================================ --- settings: env: REG_SUCC : 100% SCL_UP_SUCC: 100% UNREG_SUCC: 100% REG_RT : 15s SCL_UP_RT: 1.5s UNREG_RT: 18ms TOTAL_PROCS : 2 TOTAL_FDS : 73 FRNTEND_MEM : 120000000 #120MB ================================================ FILE: tests/performance/tests/register_unregister_multiple/register_unregister_multiple.jmx ================================================ false true false model squeezenet_v1.1 = Model name model_url https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar URL to model store on s3 = scale_up_workers 2 Numer of workers to scale up to = ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models?url=${model_url} POST true false true false ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers}&synchronous=true PUT true false true false ${__P(management_port,8444)} /models/${model} DELETE true false true false ================================================ FILE: tests/performance/tests/register_unregister_multiple/register_unregister_multiple.yaml ================================================ --- execution: - concurrency: 1 ramp-up: 0s iterations: 5 scenario: register_unregister_multiple scenarios: register_unregister_multiple: script: register_unregister_multiple.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "curl -s -O https://s3.amazonaws.com/model-server/inputs/kitten.jpg" - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/resnet-18.mar" post-process: - "multi-model-server --stop > /dev/null 2>&1" - "rm kitten.jpg" - module: server_local_monitoring 
ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - sum_all_file_descriptors - frontend_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of RegisterModel<${REG_SUCC}, stop as failed - success of ScaleUp<${SCL_UP_SUCC}, stop as failed - success of UnregisterModel<${UNREG_SUCC}, stop as failed - avg-rt of RegisterModel>${REG_RT}, stop as failed - avg-rt of ScaleUp>${SCL_UP_RT}, stop as failed - avg-rt of UnregisterModel>${UNREG_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_all_file_descriptors condition: '>' threshold: ${TOTAL_FDS} timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/frontend_memory_rss # condition: '>' # threshold: ${FRNTEND_MEM} # timeframe: 5s # stop : true # fail : true ================================================ FILE: tests/performance/tests/scale_down_workers/environments/xlarge.yaml ================================================ --- settings: env: SCL_DWN_SUCC : 100% SCL_DWN_RT : 10ms TOTAL_PROCS_B4_SCL_DWN : 6 TOTAL_PROCS_AFTR_SCL_DWN : 4 TOTAL_WRKRS_B4_SCL_DWN : 4 TOTAL_WRKRS_AFTR_SCL_DWN : 2 FRNTEND_FDS : 78 TOTAL_WRKRS_FDS_B4_SCL_DWN: 38 TOTAL_WRKRS_FDS_AFTR_SCL_DWN: 23 FRNTEND_MEM : 290000000 #290MB TOTAL_WRKRS_MEM_B4_SCL_DWN : 450000000 #450MB TOTAL_WRKRS_MEM_AFTR_SCL_DWN : 210000000 #210MB ================================================ FILE: tests/performance/tests/scale_down_workers/scale_down_workers.jmx ================================================ false true false scale_down_workers 2 Number of workers to scale down to = model squeezenet_v1.1 = Model name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} 
${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_down_workers}&synchronous=true PUT true false true false ================================================ FILE: tests/performance/tests/scale_down_workers/scale_down_workers.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: scaledown scenarios: scaledown: script: scale_down_workers.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=4&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - total_workers - frontend_file_descriptors - sum_workers_file_descriptors - frontend_memory_rss - sum_workers_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of ScaleDown<${SCL_DWN_SUCC}, stop as failed - avg-rt of ScaleDown>${SCL_DWN_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS_B4_SCL_DWN} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS_AFTR_SCL_DWN} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers condition: '>' threshold: ${TOTAL_WRKRS_B4_SCL_DWN} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers 
condition: '<' threshold: ${TOTAL_WRKRS_AFTR_SCL_DWN} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/frontend_file_descriptors condition: '>' threshold: ${FRNTEND_FDS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '>' threshold: ${TOTAL_WRKRS_FDS_B4_SCL_DWN} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '<' threshold: ${TOTAL_WRKRS_FDS_AFTR_SCL_DWN} timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/frontend_memory_rss # condition: '>' # threshold: ${FRNTEND_MEM} # timeframe: 5s # stop : true # fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '>' threshold: ${TOTAL_WRKRS_MEM_B4_SCL_DWN} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '<' threshold: ${TOTAL_WRKRS_MEM_AFTR_SCL_DWN} timeframe: 5s stop : true fail : true ================================================ FILE: tests/performance/tests/scale_up_workers/environments/xlarge.yaml ================================================ --- settings: env: SCL_UP_SUCC : 100% SCL_UP_RT : 10ms TOTAL_PROCS_AFTR_SCL_UP : 6 TOTAL_PROCS_B4_SCL_UP : 3 TOTAL_WRKRS_AFTR_SCL_UP : 4 TOTAL_WRKRS_B4_SCL_UP : 1 FRNTEND_FDS : 88 TOTAL_WRKRS_FDS_AFTR_SCL_UP : 38 TOTAL_WRKRS_FDS_B4_SCL_UP : 11 FRNTEND_MEM : 290000000 #290MB TOTAL_WRKRS_MEM_AFTR_SCL_UP : 450000000 #450MB TOTAL_WRKRS_MEM_B4_SCL_UP : 115000000 #115MB ================================================ FILE: tests/performance/tests/scale_up_workers/scale_up_workers.jmx ================================================ false true false scale_up_workers 
4 = Number of workers to scale up to model squeezenet_v1.1 = Model name ${__P(hostname,127.0.0.1)} ${__P(port,8443)} ${__P(protocol,https)} 6 continue false ${__P(loops,4)} ${__P(threads,10)} ${__P(rampup,10)} false true ${__P(management_port,8444)} /models/${model}?min_worker=${scale_up_workers}&synchronous=true PUT true false true false ================================================ FILE: tests/performance/tests/scale_up_workers/scale_up_workers.yaml ================================================ --- execution: - concurrency: 10 ramp-up: 1s hold-for: 30s scenario: scaleup scenarios: scaleup: script: scale_up_workers.jmx modules: server_local_monitoring: class : metrics_monitoring_inproc.Monitor services: - module: shellexec prepare: - "multi-model-server --start > /dev/null 2>&1" - "sleep 20s" - "curl -s -X POST http://localhost:8081/models?url=https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar" - "curl -s -X PUT http://localhost:8081/models/squeezenet_v1.1?min_worker=1&synchronous=true" post-process: - "multi-model-server --stop > /dev/null 2>&1" - module: server_local_monitoring ServerLocalClient: - interval: 1s logging : True metrics: - total_processes - total_workers - frontend_file_descriptors - sum_workers_file_descriptors - frontend_memory_rss - sum_workers_memory_rss reporting: - module: passfail criteria: # Inbuilt Criteria - success of ScaleUp<${SCL_UP_SUCC}, stop as failed - avg-rt of ScaleUp>${SCL_UP_RT}, stop as failed # Custom Criteria - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '>' threshold: ${TOTAL_PROCS_AFTR_SCL_UP} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_processes condition: '<' threshold: ${TOTAL_PROCS_B4_SCL_UP} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers condition: '>' threshold: 
${TOTAL_WRKRS_AFTR_SCL_UP} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/total_workers condition: '<' threshold: ${TOTAL_WRKRS_B4_SCL_UP} timeframe: 1s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/frontend_file_descriptors condition: '>' threshold: ${FRNTEND_FDS} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '>' threshold: ${TOTAL_WRKRS_FDS_AFTR_SCL_UP} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_file_descriptors condition: '<' threshold: ${TOTAL_WRKRS_FDS_B4_SCL_UP} timeframe: 5s stop : true fail : true # - class: bzt.modules.monitoring.MonitoringCriteria # subject: ServerLocalClient/frontend_memory_rss # condition: '>' # threshold: ${FRNTEND_MEM} # timeframe: 5s # stop : true # fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '>' threshold: ${TOTAL_WRKRS_MEM_AFTR_SCL_UP} timeframe: 5s stop : true fail : true - class: bzt.modules.monitoring.MonitoringCriteria subject: ServerLocalClient/sum_workers_memory_rss condition: '<' threshold: ${TOTAL_WRKRS_MEM_B4_SCL_UP} timeframe: 5s stop : true fail : true ================================================ FILE: tests/performance/utils/__init__.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. 
This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. """ Run Tarus test cases and generate the Junit XML report """ # pylint: disable=redefined-builtin from .fs import get_sub_dirs from .timer import Timer from .pyshell import run_process ================================================ FILE: tests/performance/utils/fs.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" File system utilities """ # pylint: disable=redefined-builtin, logging-format-interpolation, dangerous-default-value import logging import sys import os import glob logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) def get_sub_dirs(dir, exclude_list=['comp_data'], include_pattern='*', exclude_pattern=None): """Utility method to get list of folders in a directory""" dir = dir.strip() if not os.path.exists(dir): msg = "The path {} does not exit".format(dir) logger.error("The path {} does not exit".format(dir)) raise Exception(msg) pattern_list = glob.glob(dir + "/" + include_pattern) exclude_pattern_list = glob.glob(dir + "/" + exclude_pattern) if exclude_pattern is not None else [] return list([x for x in os.listdir(dir) if os.path.isdir(dir + "/" + x) and x not in exclude_list and dir + "/" + x in pattern_list and dir + "/" + x not in exclude_pattern_list]) ================================================ FILE: tests/performance/utils/pyshell.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Run shell command utilities """ # pylint: disable=redefined-builtin, logging-format-interpolation import logging import sys import os import subprocess logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) def run_process(cmd, wait=True): """Utility method to run the shell commands""" logger.info("running command : %s", cmd) if wait: os.environ["PYTHONUNBUFFERED"] = "1" process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) lines = [] while True: line = process.stdout.readline().decode('utf-8').rstrip() if not line: break lines.append(line) logger.info(line) return process.returncode, '\n'.join(lines) else: process = subprocess.Popen(cmd, shell=True) return process.returncode, '' ================================================ FILE: tests/performance/utils/timer.py ================================================ #!/usr/bin/env python # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at # http://www.apache.org/licenses/LICENSE-2.0 # or in the "license" file accompanying this file. This file is distributed # on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either # express or implied. See the License for the specific language governing # permissions and limitations under the License. 
""" Timer utilities """ # pylint: disable=redefined-builtin import logging import sys import time logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, format="%(message)s", level=logging.INFO) class Timer(object): """ Helper context manager class to capture time diff """ def __init__(self, description): self.description = description def __enter__(self): self.start = int(time.time()) return self def __exit__(self, type, value, traceback): logger.info("%s: %ss", self.description, self.diff()) def diff(self): return int(time.time()) - self.start